feat(khaosz): optimize model parameter saving and loading logic

ViperEkura 2025-09-28 14:00:21 +08:00
parent 4fcdc87c95
commit f25a249291
2 changed files with 54 additions and 115 deletions

View File

@@ -5,7 +5,7 @@ import torch.nn as nn
import torch.optim as optim
from dataclasses import dataclass, field
from typing import Optional, Self, Union
from typing import Any, Dict, List, Optional, Self, Union
from pathlib import Path
from khaosz.core.tokenizer import BpeTokenizer
@@ -84,11 +84,9 @@ class ModelParameter(BaseModelIO):
)
def save(self, save_dir: Union[str, Path]):
"""Save model parameters."""
self.save_components(save_dir)
def load(self, load_dir: Union[str, Path]) -> Self:
"""Load model parameters."""
return self.load_components(load_dir)
@@ -108,26 +106,16 @@ class Checkpoint(BaseModelIO):
default_factory=TransformerConfig,
metadata={"help": "Transformer model configuration."}
)
loss_list: list[float] = field(
default_factory=list,
metadata={"help": "List of training losses."}
)
current_iter: int = field(
default=0,
metadata={"help": "Current training iteration."}
)
optimizer: Optional[optim.Optimizer] = field(
optim_state: Dict[str, Any] = field(
default=None,
metadata={"help": "Optimizer state."}
)
def __post_init__(self):
# Ensure current_iter matches loss list length if not explicitly set
if self.current_iter == 0 and self.loss_list:
self.current_iter = len(self.loss_list)
loss_list: List[float] = field(
default_factory=list,
metadata={"help": "List of training losses."}
)
def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
"""Get file paths for training-specific files."""
paths = self._get_file_paths(directory)
paths.update({
"loss_list": paths["model"].parent / "loss.pkl",
@@ -137,7 +125,6 @@ class Checkpoint(BaseModelIO):
return paths
def save_training_state(self, save_dir: Union[str, Path]):
"""Save training-specific state."""
paths = self._get_training_paths(save_dir)
# Save loss plot
@@ -148,25 +135,21 @@ class Checkpoint(BaseModelIO):
pkl.dump(self.loss_list, f)
# Save optimizer state
if self.optimizer is not None:
with open(str(paths["optimizer"]), "wb") as f:
pkl.dump(self.optimizer.state_dict(), f)
pkl.dump(self.optim_state, f)
def load_training_state(self, load_dir: Union[str, Path]) -> Self:
"""Load training-specific state."""
paths = self._get_training_paths(load_dir)
# Load loss list
if paths["loss_list"].exists():
with open(str(paths["loss_list"]), "rb") as f:
self.loss_list = pkl.load(f)
self.current_iter = len(self.loss_list)
# Load optimizer state
if paths["optimizer"].exists() and self.optimizer is not None:
if paths["optimizer"].exists():
with open(str(paths["optimizer"]), "rb") as f:
optim_state = pkl.load(f)
self.optimizer.load_state_dict(optim_state)
self.optim_state = pkl.load(f)
return self
@@ -175,9 +158,11 @@ class Checkpoint(BaseModelIO):
if not self.loss_list:
return
current_iter = len(self.loss_list)
plt.figure(figsize=(10, 6))
plt.plot(self.loss_list)
plt.title(f"Training Loss - Iteration {self.current_iter}")
plt.title(f"Training Loss - Iteration {current_iter}")
plt.xlabel("Batch")
plt.ylabel("Loss")
plt.grid(True)
@@ -224,7 +209,7 @@ class ParameterLoader:
tokenizer: BpeTokenizer,
config: TransformerConfig,
loss_list: Optional[list[float]] = None,
optimizer: Optional[optim.Optimizer] = None
optimizer: Optional[optim.Optimizer] = None,
) -> Checkpoint:
"""Convenience method to create a training checkpoint."""
return Checkpoint(
@@ -232,7 +217,7 @@ class ParameterLoader:
tokenizer=tokenizer,
config=config,
loss_list=loss_list or [],
optimizer=optimizer
optim_state=optimizer.state_dict() if optimizer else None
)
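
For orientation, a minimal usage sketch of the revised Checkpoint API as it reads in this diff: the field that used to hold a live optim.Optimizer is replaced by a plain optim_state dict, so callers snapshot and restore optimizer state explicitly. Class, field, and method names follow the diff; the model, tokenizer, config, and optimizer setup are assumed, and the sketch itself is not part of the commit.

# Sketch only, assuming it sits alongside the Checkpoint class defined above,
# so no extra import of Checkpoint is needed; model, tokenizer, and config are
# assumed to come from elsewhere in khaosz.
import torch.optim as optim

optimizer = optim.AdamW(model.parameters(), lr=1e-4)

ckpt = Checkpoint(
    model=model,
    tokenizer=tokenizer,
    config=config,
    loss_list=[],
    optim_state=optimizer.state_dict(),  # plain dict instead of the optimizer object
)
ckpt.save_training_state("checkpoints/iter_0")

# On resume, load_training_state repopulates loss_list and optim_state from disk,
# and the caller pushes the dict back into a freshly built optimizer.
resumed = Checkpoint(model=model, tokenizer=tokenizer, config=config)
resumed.load_training_state("checkpoints/iter_0")
optimizer.load_state_dict(resumed.optim_state)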

View File

@@ -1,8 +1,7 @@
import os
import torch
import logging
from typing import Tuple
from typing import Optional
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, RandomSampler
@@ -15,82 +14,57 @@ from khaosz.trainer.strategy import SchedulerFactory, StrategyFactory, TrainConf
class Trainer:
def __init__(
self,
parameter: ModelParameter,
log_path: str="./train_log.log"
parameter: ModelParameter
):
logger = logging.getLogger()
logger.setLevel(level = logging.INFO)
handler = logging.FileHandler(log_path)
handler.setLevel(logging.INFO)
handler.setFormatter(logging.Formatter('%(asctime)s: %(message)s'))
logger.addHandler(handler)
logger.info("initializing trainer ...")
self.logger = logger
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
self.checkpoint = Checkpoint(
model=parameter.model,
tokenizer=parameter.tokenizer,
config=parameter.config,
)
def save_checkpoint(
self,
loss_list: list,
ckpt_dir: str,
current_iter: int,
last_ckpt_iter: int
train_config: TrainConfig
):
save_path = os.path.join(ckpt_dir, f"iter_{current_iter}")
Checkpoint(
self.model,
self.tokenizer,
self.config,
loss_list,
current_iter
).save(save_path)
diff_iter = current_iter - last_ckpt_iter
avg_loss = sum(loss_list[last_ckpt_iter:current_iter]) / diff_iter
self.logger.info(f"iter: {current_iter} loss: {avg_loss}")
return current_iter
def load_checkpoint(self, train_checkpoint: Checkpoint) -> Tuple[list, int]:
self.model = train_checkpoint.model
self.tokenizer = train_checkpoint.tokenizer
self.config = train_checkpoint.config
loss_list = train_checkpoint.loss_list
last_ckpt_iter = train_checkpoint.current_iter
return loss_list, last_ckpt_iter
current_iter = len(loss_list)
save_path = os.path.join(train_config.ckpt_dir, f"iter_{current_iter}")
self.checkpoint.loss_list = loss_list
self.checkpoint.optim_state = train_config.optimizer.state_dict()
self.checkpoint.save(save_path)
def train(
self,
train_config: TrainConfig,
schedule_config: ScheduleConfig,
train_checkpoint: Checkpoint = None
):
train_checkpoint: Optional[Checkpoint] = None
) -> Checkpoint:
assert schedule_config.schedule_type in ["cosine", "sgdr"]
assert train_config.train_type in ["seq", "sft", "dpo"]
if train_checkpoint:
loss_list, last_ckpt_iter = self.load_checkpoint(train_checkpoint)
current_iter = train_checkpoint.current_iter + 1
self.logger.info(f"Resuming training from checkpoint: iter {current_iter}")
else:
current_iter = 0
last_ckpt_iter = 0
loss_list = []
self.checkpoint = train_checkpoint
train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
loss_list = self.checkpoint.loss_list
current_iter = len(self.checkpoint.loss_list)
last_ckpt_iter = current_iter
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
**schedule_config.get_kwargs()
)
strategy_kwargs = {
"bos_token_id": self.checkpoint.tokenizer.bos_id,
"eos_token_id": self.checkpoint.tokenizer.eos_id,
"pad_token_id": self.checkpoint.tokenizer.pad_id,
"dpo_beta": train_config.dpo_beta
}
strategy = StrategyFactory.load(
self.model,
train_type=train_config.train_type,
bos_token_id=self.tokenizer.bos_id,
eos_token_id=self.tokenizer.eos_id,
pad_token_id=self.tokenizer.pad_id,
dpo_beta=train_config.dpo_beta
self.checkpoint.model,
train_config.train_type,
**strategy_kwargs
)
scheduler = LambdaLR(
@@ -104,11 +78,8 @@ class Trainer:
sampler = RandomSampler(train_config.dataset, generator=generator)
remaining_epochs = train_config.n_epoch - current_iter // (len(train_config.dataset) // train_config.batch_size)
self.logger.info(f"Starting {train_config.train_type.upper()} training for {train_config.n_epoch} epochs")
self.logger.info(f"Checkpoint interval: {train_config.n_iter_ckpt} iterations")
for epoch in range(remaining_epochs):
self.model.train()
self.checkpoint.model.train()
dataloader = DataLoader(
train_config.dataset,
batch_size=train_config.batch_size,
@@ -128,7 +99,7 @@ class Trainer:
#step
if current_iter % train_config.n_iter_step == 0:
clip_grad_norm_(
self.model.parameters(),
self.checkpoint.model.parameters(),
train_config.max_grad_norm
)
train_config.optimizer.step()
@@ -142,28 +113,11 @@ class Trainer:
})
#save checkpoint
if current_iter - last_ckpt_iter >= train_config.n_iter_ckpt:
last_ckpt_iter = self.save_checkpoint(
loss_list,
train_config.ckpt_dir,
current_iter,
last_ckpt_iter
)
self.save_checkpoint(loss_list, train_config)
last_ckpt_iter = current_iter
if current_iter != last_ckpt_iter:
last_ckpt_iter = self.save_checkpoint(
loss_list,
train_config.ckpt_dir,
current_iter,
last_ckpt_iter
)
self.save_checkpoint(loss_list, train_config)
last_ckpt_iter = current_iter
self.logger.info("Training completed")
return Checkpoint(
self.model,
self.tokenizer,
self.config,
loss_list,
current_iter,
train_config.optimizer
)
return self.checkpoint
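
Putting the two files together, a hedged sketch of the training/resume flow after this commit: Trainer builds an internal Checkpoint in __init__, train() pulls loss_list and optim_state from a passed-in checkpoint instead of tracking a separate current_iter, and save_checkpoint derives the iteration count from len(loss_list). Only the names visible in the diff are taken from the code; the dataset, the TrainConfig/ScheduleConfig construction, and the keyword construction of ModelParameter (assumed to expose model/tokenizer/config fields, since the trainer reads them) are assumptions.

# Sketch only (not part of the commit); assumes train_config and schedule_config
# are built with the fields this diff references (optimizer, dataset, ckpt_dir,
# n_epoch, batch_size, n_iter_ckpt, n_iter_step, max_grad_norm, ...).
parameter = ModelParameter(model=model, tokenizer=tokenizer, config=config)
trainer = Trainer(parameter)

# First run: no checkpoint is passed, so training starts from an empty loss_list.
first_ckpt = trainer.train(train_config, schedule_config)

# Resume: train() restores train_config.optimizer from optim_state and recovers
# the iteration counter as len(loss_list) before continuing.
resumed_ckpt = trainer.train(train_config, schedule_config, train_checkpoint=first_ckpt)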