fix: 统一 state_dict 处理方式

This commit is contained in:
ViperEkura 2026-03-13 22:41:56 +08:00
parent e35cb0d84a
commit c4feab96fe
1 changed file with 4 additions and 6 deletions

View File

@@ -1,5 +1,5 @@
import json
import torch
import safetensors.torch as st
import torch.distributed as dist
from pathlib import Path
@@ -35,8 +35,7 @@ class Checkpoint:
with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2)
with open(save_path / f"state_dict.pt", "wb") as f:
torch.save(self.state_dict, f)
st.save_file(self.state_dict, save_path / f"state_dict.safetensors")
@classmethod
def load(
@@ -57,8 +56,7 @@ class Checkpoint:
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
with open(save_path / f"state_dict.pt", "rb") as f:
state_dict = torch.load(f)
state_dict = st.load_file(save_path / f"state_dict.safetensors")
return cls(
state_dict=state_dict,