diff --git a/khaosz/data/checkpoint.py b/khaosz/data/checkpoint.py index d18ee2f..272aaf4 100644 --- a/khaosz/data/checkpoint.py +++ b/khaosz/data/checkpoint.py @@ -1,5 +1,5 @@ import json -import torch +import safetensors.torch as st import torch.distributed as dist from pathlib import Path @@ -34,9 +34,8 @@ class Checkpoint: } with open(save_path / "meta.json", "w") as f: json.dump(meta, f, indent=2) - - with open(save_path / f"state_dict.pt", "wb") as f: - torch.save(self.state_dict, f) + + st.save_file(self.state_dict, save_path / f"state_dict.safetensors") @classmethod def load( @@ -57,8 +56,7 @@ class Checkpoint: dist.broadcast_object_list(meta_list, src=0) meta = meta_list[0] - with open(save_path / f"state_dict.pt", "rb") as f: - state_dict = torch.load(f) + state_dict = st.load_file(save_path / f"state_dict.safetensors") return cls( state_dict=state_dict,