From e0e9942e4a3c3e4e0a443816bca9085016eb327f Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 29 Sep 2025 19:49:35 +0800 Subject: [PATCH] =?UTF-8?q?feat(khaosz/core/parameter):=20=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E9=87=87=E6=A0=B7=E5=99=A8=E7=8A=B6=E6=80=81=E7=9A=84?= =?UTF-8?q?=E4=BF=9D=E5=AD=98=E4=B8=8E=E5=8A=A0=E8=BD=BD=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/core/parameter.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/khaosz/core/parameter.py b/khaosz/core/parameter.py index 75c8fba..a0809a9 100644 --- a/khaosz/core/parameter.py +++ b/khaosz/core/parameter.py @@ -110,6 +110,10 @@ class Checkpoint(BaseModelIO): default=None, metadata={"help": "Optimizer state."} ) + sampler_state: Dict[str, Any] = field( + default=None, + metadata={"help": "Sampler state."} + ) loss_list: List[float] = field( default_factory=list, metadata={"help": "List of training losses."} @@ -120,7 +124,8 @@ class Checkpoint(BaseModelIO): paths.update({ "loss_list": paths["model"].parent / "loss.pkl", "loss_plot": paths["model"].parent / "loss.png", - "optimizer": paths["model"].parent / "optimizer.pkl" + "optim_state": paths["model"].parent / "optim_state.pkl", + "sampler_state": paths["model"].parent / "sampler_state.pkl" }) return paths @@ -135,8 +140,12 @@ class Checkpoint(BaseModelIO): pkl.dump(self.loss_list, f) # Save optimizer state - with open(str(paths["optimizer"]), "wb") as f: + with open(str(paths["optim_state"]), "wb") as f: pkl.dump(self.optim_state, f) + + # Save sampler state + with open(str(paths["sampler_state"]), "wb") as f: + pkl.dump(self.sampler_state, f) def load_training_state(self, load_dir: Union[str, Path]) -> Self: paths = self._get_training_paths(load_dir) @@ -147,10 +156,15 @@ class Checkpoint(BaseModelIO): self.loss_list = pkl.load(f) # Load optimizer state - if paths["optimizer"].exists(): - with open(str(paths["optimizer"]), "rb") as f: + if paths["optim_state"].exists(): + with open(str(paths["optim_state"]), "rb") as f: self.optim_state = pkl.load(f) + # Load sampler state + if paths["sampler_state"].exists(): + with open(str(paths["sampler_state"]), "rb") as f: + self.sampler_state = pkl.load(f) + return self def _plot_loss(self, save_path: str):