refactor(khaosz/core/parameter): 修改参数名称

This commit is contained in:
ViperEkura 2025-10-06 20:11:46 +08:00
parent 183f481692
commit f9b6331ad7
1 changed files with 7 additions and 7 deletions

View File

@ -106,7 +106,7 @@ class Checkpoint(BaseModelIO):
default_factory=TransformerConfig,
metadata={"help": "Transformer model configuration."}
)
optim_state: Dict[str, Any] = field(
optimizer_state: Dict[str, Any] = field(
default=None,
metadata={"help": "Optimizer state."}
)
@ -124,7 +124,7 @@ class Checkpoint(BaseModelIO):
paths.update({
"loss_list": paths["model"].parent / "loss.pkl",
"loss_plot": paths["model"].parent / "loss.png",
"optim_state": paths["model"].parent / "optim_state.pkl",
"optimizer_state": paths["model"].parent / "optimizer_state.pkl",
"sampler_state": paths["model"].parent / "sampler_state.pkl"
})
return paths
@ -140,8 +140,8 @@ class Checkpoint(BaseModelIO):
pkl.dump(self.loss_list, f)
# Save optimizer state
with open(str(paths["optim_state"]), "wb") as f:
pkl.dump(self.optim_state, f)
with open(str(paths["optimizer_state"]), "wb") as f:
pkl.dump(self.optimizer_state, f)
# Save sampler state
with open(str(paths["sampler_state"]), "wb") as f:
@ -156,9 +156,9 @@ class Checkpoint(BaseModelIO):
self.loss_list = pkl.load(f)
# Load optimizer state
if paths["optim_state"].exists():
with open(str(paths["optim_state"]), "rb") as f:
self.optim_state = pkl.load(f)
if paths["optimizer_state"].exists():
with open(str(paths["optimizer_state"]), "rb") as f:
self.optimizer_state = pkl.load(f)
# Load sampler state
if paths["sampler_state"].exists():