From c4feab96fe58071d4ccfa42a64be4085b2d15200 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 13 Mar 2026 22:41:56 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E7=BB=9F=E4=B8=80state=5Fdict=20?= =?UTF-8?q?=E5=A4=84=E7=90=86=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/data/checkpoint.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/khaosz/data/checkpoint.py b/khaosz/data/checkpoint.py index d18ee2f..272aaf4 100644 --- a/khaosz/data/checkpoint.py +++ b/khaosz/data/checkpoint.py @@ -1,5 +1,5 @@ import json -import torch +import safetensors.torch as st import torch.distributed as dist from pathlib import Path @@ -34,9 +34,8 @@ class Checkpoint: } with open(save_path / "meta.json", "w") as f: json.dump(meta, f, indent=2) - - with open(save_path / f"state_dict.pt", "wb") as f: - torch.save(self.state_dict, f) + + st.save_file(self.state_dict, save_path / f"state_dict.safetensors") @classmethod def load( @@ -57,8 +56,7 @@ class Checkpoint: dist.broadcast_object_list(meta_list, src=0) meta = meta_list[0] - with open(save_path / f"state_dict.pt", "rb") as f: - state_dict = torch.load(f) + state_dict = st.load_file(save_path / f"state_dict.safetensors") return cls( state_dict=state_dict,