fix: 统一 state_dict 处理方式

This commit is contained in:
ViperEkura 2026-03-13 22:41:56 +08:00
parent e35cb0d84a
commit c4feab96fe
1 changed file with 4 additions and 6 deletions

View File

@@ -1,5 +1,5 @@
import json import json
import torch import safetensors.torch as st
import torch.distributed as dist import torch.distributed as dist
from pathlib import Path from pathlib import Path
@@ -35,8 +35,7 @@ class Checkpoint:
with open(save_path / "meta.json", "w") as f: with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2) json.dump(meta, f, indent=2)
with open(save_path / f"state_dict.pt", "wb") as f: st.save_file(self.state_dict, save_path / f"state_dict.safetensors")
torch.save(self.state_dict, f)
@classmethod @classmethod
def load( def load(
@@ -57,8 +56,7 @@ class Checkpoint:
dist.broadcast_object_list(meta_list, src=0) dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0] meta = meta_list[0]
with open(save_path / f"state_dict.pt", "rb") as f: state_dict = st.load_file(save_path / f"state_dict.safetensors")
state_dict = torch.load(f)
return cls( return cls(
state_dict=state_dict, state_dict=state_dict,