diff --git a/khaosz/core/parameter.py b/khaosz/core/parameter.py
index 75c8fba..a0809a9 100644
--- a/khaosz/core/parameter.py
+++ b/khaosz/core/parameter.py
@@ -110,7 +110,11 @@ class Checkpoint(BaseModelIO):
         default=None,
         metadata={"help": "Optimizer state."}
     )
+    sampler_state: Dict[str, Any] = field(
+        default=None,
+        metadata={"help": "Sampler state."}
+    )
     loss_list: List[float] = field(
         default_factory=list,
         metadata={"help": "List of training losses."}
     )
@@ -120,7 +124,8 @@ class Checkpoint(BaseModelIO):
         paths.update({
             "loss_list": paths["model"].parent / "loss.pkl",
             "loss_plot": paths["model"].parent / "loss.png",
-            "optimizer": paths["model"].parent / "optimizer.pkl"
+            "optim_state": paths["model"].parent / "optim_state.pkl",
+            "sampler_state": paths["model"].parent / "sampler_state.pkl"
         })
         return paths
 
@@ -135,8 +140,12 @@ class Checkpoint(BaseModelIO):
             pkl.dump(self.loss_list, f)
 
         # Save optimizer state
-        with open(str(paths["optimizer"]), "wb") as f:
+        with open(str(paths["optim_state"]), "wb") as f:
             pkl.dump(self.optim_state, f)
+
+        # Save sampler state
+        with open(str(paths["sampler_state"]), "wb") as f:
+            pkl.dump(self.sampler_state, f)
 
     def load_training_state(self, load_dir: Union[str, Path]) -> Self:
         paths = self._get_training_paths(load_dir)
@@ -147,10 +156,15 @@ class Checkpoint(BaseModelIO):
             self.loss_list = pkl.load(f)
 
         # Load optimizer state
-        if paths["optimizer"].exists():
-            with open(str(paths["optimizer"]), "rb") as f:
+        if paths["optim_state"].exists():
+            with open(str(paths["optim_state"]), "rb") as f:
                 self.optim_state = pkl.load(f)
 
+        # Load sampler state
+        if paths["sampler_state"].exists():
+            with open(str(paths["sampler_state"]), "rb") as f:
+                self.sampler_state = pkl.load(f)
+
         return self
 
     def _plot_loss(self, save_path: str):