fix: 统一state_dict 处理方式
This commit is contained in:
parent
e35cb0d84a
commit
c4feab96fe
|
|
@ -1,5 +1,5 @@
|
|||
import json
|
||||
import torch
|
||||
import safetensors.torch as st
|
||||
import torch.distributed as dist
|
||||
|
||||
from pathlib import Path
|
||||
|
|
@ -34,9 +34,8 @@ class Checkpoint:
|
|||
}
|
||||
with open(save_path / "meta.json", "w") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
with open(save_path / f"state_dict.pt", "wb") as f:
|
||||
torch.save(self.state_dict, f)
|
||||
|
||||
st.save_file(self.state_dict, save_path / f"state_dict.safetensors")
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
|
|
@ -57,8 +56,7 @@ class Checkpoint:
|
|||
dist.broadcast_object_list(meta_list, src=0)
|
||||
meta = meta_list[0]
|
||||
|
||||
with open(save_path / f"state_dict.pt", "rb") as f:
|
||||
state_dict = torch.load(f)
|
||||
state_dict = st.load_file(save_path / f"state_dict.safetensors")
|
||||
|
||||
return cls(
|
||||
state_dict=state_dict,
|
||||
|
|
|
|||
Loading…
Reference in New Issue