feat(khaosz/core/parameter): 添加采样器状态的保存与加载功能
This commit is contained in:
parent
198c1ac55c
commit
e0e9942e4a
|
|
@ -110,6 +110,10 @@ class Checkpoint(BaseModelIO):
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Optimizer state."}
|
metadata={"help": "Optimizer state."}
|
||||||
)
|
)
|
||||||
|
sampler_state: Dict[str, Any] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Sampler state."}
|
||||||
|
)
|
||||||
loss_list: List[float] = field(
|
loss_list: List[float] = field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
metadata={"help": "List of training losses."}
|
metadata={"help": "List of training losses."}
|
||||||
|
|
@ -120,7 +124,8 @@ class Checkpoint(BaseModelIO):
|
||||||
paths.update({
|
paths.update({
|
||||||
"loss_list": paths["model"].parent / "loss.pkl",
|
"loss_list": paths["model"].parent / "loss.pkl",
|
||||||
"loss_plot": paths["model"].parent / "loss.png",
|
"loss_plot": paths["model"].parent / "loss.png",
|
||||||
"optimizer": paths["model"].parent / "optimizer.pkl"
|
"optim_state": paths["model"].parent / "optim_state.pkl",
|
||||||
|
"sampler_state": paths["model"].parent / "sampler_state.pkl"
|
||||||
})
|
})
|
||||||
return paths
|
return paths
|
||||||
|
|
||||||
|
|
@ -135,8 +140,12 @@ class Checkpoint(BaseModelIO):
|
||||||
pkl.dump(self.loss_list, f)
|
pkl.dump(self.loss_list, f)
|
||||||
|
|
||||||
# Save optimizer state
|
# Save optimizer state
|
||||||
with open(str(paths["optimizer"]), "wb") as f:
|
with open(str(paths["optim_state"]), "wb") as f:
|
||||||
pkl.dump(self.optim_state, f)
|
pkl.dump(self.optim_state, f)
|
||||||
|
|
||||||
|
# Save sampler state
|
||||||
|
with open(str(paths["sampler_state"]), "wb") as f:
|
||||||
|
pkl.dump(self.sampler_state, f)
|
||||||
|
|
||||||
def load_training_state(self, load_dir: Union[str, Path]) -> Self:
|
def load_training_state(self, load_dir: Union[str, Path]) -> Self:
|
||||||
paths = self._get_training_paths(load_dir)
|
paths = self._get_training_paths(load_dir)
|
||||||
|
|
@ -147,10 +156,15 @@ class Checkpoint(BaseModelIO):
|
||||||
self.loss_list = pkl.load(f)
|
self.loss_list = pkl.load(f)
|
||||||
|
|
||||||
# Load optimizer state
|
# Load optimizer state
|
||||||
if paths["optimizer"].exists():
|
if paths["optim_state"].exists():
|
||||||
with open(str(paths["optimizer"]), "rb") as f:
|
with open(str(paths["optim_state"]), "rb") as f:
|
||||||
self.optim_state = pkl.load(f)
|
self.optim_state = pkl.load(f)
|
||||||
|
|
||||||
|
# Load sampler state
|
||||||
|
if paths["sampler_state"].exists():
|
||||||
|
with open(str(paths["sampler_state"]), "rb") as f:
|
||||||
|
self.sampler_state = pkl.load(f)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _plot_loss(self, save_path: str):
|
def _plot_loss(self, save_path: str):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue