feat(khaosz/core/parameter): 添加采样器状态的保存与加载功能

This commit is contained in:
ViperEkura 2025-09-29 19:49:35 +08:00
parent 198c1ac55c
commit e0e9942e4a
1 changed file with 18 additions and 4 deletions

View File

@@ -110,6 +110,10 @@ class Checkpoint(BaseModelIO):
default=None, default=None,
metadata={"help": "Optimizer state."} metadata={"help": "Optimizer state."}
) )
sampler_state: Dict[str, Any] = field(
default=None,
metadata={"help": "Sampler state."}
)
loss_list: List[float] = field( loss_list: List[float] = field(
default_factory=list, default_factory=list,
metadata={"help": "List of training losses."} metadata={"help": "List of training losses."}
@@ -120,7 +124,8 @@ class Checkpoint(BaseModelIO):
paths.update({ paths.update({
"loss_list": paths["model"].parent / "loss.pkl", "loss_list": paths["model"].parent / "loss.pkl",
"loss_plot": paths["model"].parent / "loss.png", "loss_plot": paths["model"].parent / "loss.png",
"optimizer": paths["model"].parent / "optimizer.pkl" "optim_state": paths["model"].parent / "optim_state.pkl",
"sampler_state": paths["model"].parent / "sampler_state.pkl"
}) })
return paths return paths
@@ -135,8 +140,12 @@ class Checkpoint(BaseModelIO):
pkl.dump(self.loss_list, f) pkl.dump(self.loss_list, f)
# Save optimizer state # Save optimizer state
with open(str(paths["optimizer"]), "wb") as f: with open(str(paths["optim_state"]), "wb") as f:
pkl.dump(self.optim_state, f) pkl.dump(self.optim_state, f)
# Save sampler state
with open(str(paths["sampler_state"]), "wb") as f:
pkl.dump(self.sampler_state, f)
def load_training_state(self, load_dir: Union[str, Path]) -> Self: def load_training_state(self, load_dir: Union[str, Path]) -> Self:
paths = self._get_training_paths(load_dir) paths = self._get_training_paths(load_dir)
@@ -147,10 +156,15 @@ class Checkpoint(BaseModelIO):
self.loss_list = pkl.load(f) self.loss_list = pkl.load(f)
# Load optimizer state # Load optimizer state
if paths["optimizer"].exists(): if paths["optim_state"].exists():
with open(str(paths["optimizer"]), "rb") as f: with open(str(paths["optim_state"]), "rb") as f:
self.optim_state = pkl.load(f) self.optim_state = pkl.load(f)
# Load sampler state
if paths["sampler_state"].exists():
with open(str(paths["sampler_state"]), "rb") as f:
self.sampler_state = pkl.load(f)
return self return self
def _plot_loss(self, save_path: str): def _plot_loss(self, save_path: str):