fix: 统一state_dict 处理方式
This commit is contained in:
parent
e35cb0d84a
commit
c4feab96fe
|
|
@ -1,5 +1,5 @@
|
||||||
import json
|
import json
|
||||||
import torch
|
import safetensors.torch as st
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@ -34,9 +34,8 @@ class Checkpoint:
|
||||||
}
|
}
|
||||||
with open(save_path / "meta.json", "w") as f:
|
with open(save_path / "meta.json", "w") as f:
|
||||||
json.dump(meta, f, indent=2)
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
with open(save_path / f"state_dict.pt", "wb") as f:
|
st.save_file(self.state_dict, save_path / f"state_dict.safetensors")
|
||||||
torch.save(self.state_dict, f)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(
|
def load(
|
||||||
|
|
@ -57,8 +56,7 @@ class Checkpoint:
|
||||||
dist.broadcast_object_list(meta_list, src=0)
|
dist.broadcast_object_list(meta_list, src=0)
|
||||||
meta = meta_list[0]
|
meta = meta_list[0]
|
||||||
|
|
||||||
with open(save_path / f"state_dict.pt", "rb") as f:
|
state_dict = st.load_file(save_path / f"state_dict.safetensors")
|
||||||
state_dict = torch.load(f)
|
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue