From f9b6331ad722b63c5a6f59a7785df04c03179844 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 6 Oct 2025 20:11:46 +0800 Subject: [PATCH] =?UTF-8?q?refactor(khaosz/core/parameter):=20=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E5=8F=82=E6=95=B0=E5=90=8D=E7=A7=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/core/parameter.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/khaosz/core/parameter.py b/khaosz/core/parameter.py index a0809a9..2daa894 100644 --- a/khaosz/core/parameter.py +++ b/khaosz/core/parameter.py @@ -106,7 +106,7 @@ class Checkpoint(BaseModelIO): default_factory=TransformerConfig, metadata={"help": "Transformer model configuration."} ) - optim_state: Dict[str, Any] = field( + optimizer_state: Dict[str, Any] = field( default=None, metadata={"help": "Optimizer state."} ) @@ -124,7 +124,7 @@ class Checkpoint(BaseModelIO): paths.update({ "loss_list": paths["model"].parent / "loss.pkl", "loss_plot": paths["model"].parent / "loss.png", - "optim_state": paths["model"].parent / "optim_state.pkl", + "optimizer_state": paths["model"].parent / "optimizer_state.pkl", "sampler_state": paths["model"].parent / "sampler_state.pkl" }) return paths @@ -140,8 +140,8 @@ class Checkpoint(BaseModelIO): pkl.dump(self.loss_list, f) # Save optimizer state - with open(str(paths["optim_state"]), "wb") as f: - pkl.dump(self.optim_state, f) + with open(str(paths["optimizer_state"]), "wb") as f: + pkl.dump(self.optimizer_state, f) # Save sampler state with open(str(paths["sampler_state"]), "wb") as f: @@ -156,9 +156,9 @@ class Checkpoint(BaseModelIO): self.loss_list = pkl.load(f) # Load optimizer state - if paths["optim_state"].exists(): - with open(str(paths["optim_state"]), "rb") as f: - self.optim_state = pkl.load(f) + if paths["optimizer_state"].exists(): + with open(str(paths["optimizer_state"]), "rb") as f: + self.optimizer_state = pkl.load(f) # Load sampler state if paths["sampler_state"].exists():