feat(khaosz): improve model parameter save/load logic

ViperEkura 2025-09-28 14:00:21 +08:00
parent 4fcdc87c95
commit f25a249291
2 changed files with 54 additions and 115 deletions

File 1 of 2

@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.optim as optim
 from dataclasses import dataclass, field
-from typing import Optional, Self, Union
+from typing import Any, Dict, List, Optional, Self, Union
 from pathlib import Path
 from khaosz.core.tokenizer import BpeTokenizer
@@ -84,11 +84,9 @@ class ModelParameter(BaseModelIO):
     )

     def save(self, save_dir: Union[str, Path]):
-        """Save model parameters."""
         self.save_components(save_dir)

     def load(self, load_dir: Union[str, Path]) -> Self:
-        """Load model parameters."""
         return self.load_components(load_dir)
@@ -108,26 +106,16 @@ class Checkpoint(BaseModelIO):
         default_factory=TransformerConfig,
         metadata={"help": "Transformer model configuration."}
     )
-    loss_list: list[float] = field(
-        default_factory=list,
-        metadata={"help": "List of training losses."}
-    )
-    current_iter: int = field(
-        default=0,
-        metadata={"help": "Current training iteration."}
-    )
-    optimizer: Optional[optim.Optimizer] = field(
+    optim_state: Dict[str, Any] = field(
         default=None,
         metadata={"help": "Optimizer state."}
     )
-
-    def __post_init__(self):
-        # Ensure current_iter matches loss list length if not explicitly set
-        if self.current_iter == 0 and self.loss_list:
-            self.current_iter = len(self.loss_list)
+    loss_list: List[float] = field(
+        default_factory=list,
+        metadata={"help": "List of training losses."}
+    )

     def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
-        """Get file paths for training-specific files."""
         paths = self._get_file_paths(directory)
         paths.update({
             "loss_list": paths["model"].parent / "loss.pkl",
@@ -137,7 +125,6 @@ class Checkpoint(BaseModelIO):
         return paths

     def save_training_state(self, save_dir: Union[str, Path]):
-        """Save training-specific state."""
         paths = self._get_training_paths(save_dir)

         # Save loss plot
@@ -148,25 +135,21 @@ class Checkpoint(BaseModelIO):
             pkl.dump(self.loss_list, f)

         # Save optimizer state
-        if self.optimizer is not None:
-            with open(str(paths["optimizer"]), "wb") as f:
-                pkl.dump(self.optimizer.state_dict(), f)
+        with open(str(paths["optimizer"]), "wb") as f:
+            pkl.dump(self.optim_state, f)

     def load_training_state(self, load_dir: Union[str, Path]) -> Self:
-        """Load training-specific state."""
         paths = self._get_training_paths(load_dir)

         # Load loss list
         if paths["loss_list"].exists():
             with open(str(paths["loss_list"]), "rb") as f:
                 self.loss_list = pkl.load(f)
-                self.current_iter = len(self.loss_list)

         # Load optimizer state
-        if paths["optimizer"].exists() and self.optimizer is not None:
+        if paths["optimizer"].exists():
             with open(str(paths["optimizer"]), "rb") as f:
-                optim_state = pkl.load(f)
-                self.optimizer.load_state_dict(optim_state)
+                self.optim_state = pkl.load(f)

         return self
@@ -175,9 +158,11 @@ class Checkpoint(BaseModelIO):
         if not self.loss_list:
             return

+        current_iter = len(self.loss_list)
+
         plt.figure(figsize=(10, 6))
         plt.plot(self.loss_list)
-        plt.title(f"Training Loss - Iteration {self.current_iter}")
+        plt.title(f"Training Loss - Iteration {current_iter}")
         plt.xlabel("Batch")
         plt.ylabel("Loss")
         plt.grid(True)
@@ -224,7 +209,7 @@ class ParameterLoader:
         tokenizer: BpeTokenizer,
         config: TransformerConfig,
         loss_list: Optional[list[float]] = None,
-        optimizer: Optional[optim.Optimizer] = None
+        optimizer: Optional[optim.Optimizer] = None,
     ) -> Checkpoint:
         """Convenience method to create a training checkpoint."""
         return Checkpoint(
@@ -232,7 +217,7 @@ class ParameterLoader:
             tokenizer=tokenizer,
             config=config,
             loss_list=loss_list or [],
-            optimizer=optimizer
+            optimizer_state=optimizer
         )
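For orientation, a minimal usage sketch of the reworked Checkpoint (not part of the commit): it assumes Checkpoint.save and load_training_state behave as this diff shows, while the optimizer class, model/tokenizer/config objects, loss values, and directory names are placeholders.

# Hypothetical sketch: the checkpoint now stores the optimizer state_dict
# (a plain, picklable dict) instead of the optimizer object itself.
import torch.optim as optim

optimizer = optim.AdamW(model.parameters())        # placeholder optimizer and model

ckpt = Checkpoint(
    model=model,
    tokenizer=tokenizer,
    config=config,
    loss_list=[2.31, 2.05, 1.98],                  # illustrative losses
    optim_state=optimizer.state_dict(),
)
ckpt.save("checkpoints/iter_3")                    # assumed to also write loss.pkl and the optimizer pickle

# Restoring: rebuild the dataclass, reload the training files, then hand the
# state back to a freshly constructed optimizer.
restored = Checkpoint(model=model, tokenizer=tokenizer, config=config)
restored.load_training_state("checkpoints/iter_3")
optimizer.load_state_dict(restored.optim_state)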

File 2 of 2

@@ -1,8 +1,7 @@
 import os
 import torch
-import logging
-from typing import Tuple
+from typing import Optional
 from torch.nn.utils import clip_grad_norm_
 from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.data import DataLoader, RandomSampler
@@ -15,82 +14,57 @@ from khaosz.trainer.strategy import SchedulerFactory, StrategyFactory, TrainConf
 class Trainer:
     def __init__(
         self,
-        parameter: ModelParameter,
-        log_path: str="./train_log.log"
+        parameter: ModelParameter
     ):
-        logger = logging.getLogger()
-        logger.setLevel(level = logging.INFO)
-        handler = logging.FileHandler(log_path)
-        handler.setLevel(logging.INFO)
-        handler.setFormatter(logging.Formatter('%(asctime)s: %(message)s'))
-        logger.addHandler(handler)
-        logger.info("initializing trainer ...")
-
-        self.logger = logger
-        self.model = parameter.model
-        self.tokenizer = parameter.tokenizer
-        self.config = parameter.config
+        self.checkpoint = Checkpoint(
+            model=parameter.model,
+            tokenizer=parameter.tokenizer,
+            config=parameter.config,
+        )

     def save_checkpoint(
         self,
         loss_list: list,
-        ckpt_dir: str,
-        current_iter: int,
-        last_ckpt_iter: int
+        train_config: TrainConfig
     ):
-        save_path = os.path.join(ckpt_dir, f"iter_{current_iter}")
-        Checkpoint(
-            self.model,
-            self.tokenizer,
-            self.config,
-            loss_list,
-            current_iter
-        ).save(save_path)
-
-        diff_iter = current_iter - last_ckpt_iter
-        avg_loss = sum(loss_list[last_ckpt_iter:current_iter]) / diff_iter
-        self.logger.info(f"iter: {current_iter} loss: {avg_loss}")
-
-        return current_iter
-
-    def load_checkpoint(self, train_checkpoint: Checkpoint) -> Tuple[list, int]:
-        self.model = train_checkpoint.model
-        self.tokenizer = train_checkpoint.tokenizer
-        self.config = train_checkpoint.config
-        loss_list = train_checkpoint.loss_list
-        last_ckpt_iter = train_checkpoint.current_iter
-
-        return loss_list, last_ckpt_iter
+        current_iter = len(loss_list)
+        save_path = os.path.join(train_config.ckpt_dir, f"iter_{current_iter}")
+        self.checkpoint.loss_list = loss_list
+        self.checkpoint.optim_state = train_config.optimizer.state_dict()
+        self.checkpoint.save(save_path)

     def train(
         self,
         train_config: TrainConfig,
         schedule_config: ScheduleConfig,
-        train_checkpoint: Checkpoint = None
-    ):
+        train_checkpoint: Optional[Checkpoint] = None
+    ) -> Checkpoint:
         assert schedule_config.schedule_type in ["cosine", "sgdr"]
         assert train_config.train_type in ["seq", "sft", "dpo"]

         if train_checkpoint:
-            loss_list, last_ckpt_iter = self.load_checkpoint(train_checkpoint)
-            current_iter = train_checkpoint.current_iter + 1
-            self.logger.info(f"Resuming training from checkpoint: iter {current_iter}")
-        else:
-            current_iter = 0
-            last_ckpt_iter = 0
-            loss_list = []
+            self.checkpoint = train_checkpoint
+            train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
+
+        loss_list = self.checkpoint.loss_list
+        current_iter = len(self.checkpoint.loss_list)
+        last_ckpt_iter = current_iter

         lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
             **schedule_config.get_kwargs()
         )

+        strategy_kwargs = {
+            "bos_token_id": self.checkpoint.tokenizer.bos_id,
+            "eos_token_id": self.checkpoint.tokenizer.eos_id,
+            "pad_token_id": self.checkpoint.tokenizer.pad_id,
+            "dpo_beta": train_config.dpo_beta
+        }
+
         strategy = StrategyFactory.load(
-            self.model,
-            train_type=train_config.train_type,
-            bos_token_id=self.tokenizer.bos_id,
-            eos_token_id=self.tokenizer.eos_id,
-            pad_token_id=self.tokenizer.pad_id,
-            dpo_beta=train_config.dpo_beta
+            self.checkpoint.model,
+            train_config.train_type,
+            **strategy_kwargs
         )

         scheduler = LambdaLR(
@@ -104,11 +78,8 @@ class Trainer:
         sampler = RandomSampler(train_config.dataset, generator=generator)
         remaining_epochs = train_config.n_epoch - current_iter // (len(train_config.dataset) // train_config.batch_size)

-        self.logger.info(f"Starting {train_config.train_type.upper()} training for {train_config.n_epoch} epochs")
-        self.logger.info(f"Checkpoint interval: {train_config.n_iter_ckpt} iterations")
-
         for epoch in range(remaining_epochs):
-            self.model.train()
+            self.checkpoint.model.train()
             dataloader = DataLoader(
                 train_config.dataset,
                 batch_size=train_config.batch_size,
@@ -128,7 +99,7 @@ class Trainer:
                 #step
                 if current_iter % train_config.n_iter_step == 0:
                     clip_grad_norm_(
-                        self.model.parameters(),
+                        self.checkpoint.model.parameters(),
                         train_config.max_grad_norm
                     )
                     train_config.optimizer.step()
@@ -142,28 +113,11 @@ class Trainer:
                 })

                 #save checkpotint
                 if current_iter - last_ckpt_iter >= train_config.n_iter_ckpt:
-                    last_ckpt_iter = self.save_checkpoint(
-                        loss_list,
-                        train_config.ckpt_dir,
-                        current_iter,
-                        last_ckpt_iter
-                    )
+                    self.save_checkpoint(loss_list, train_config)
+                    last_ckpt_iter = current_iter

         if current_iter != last_ckpt_iter:
-            last_ckpt_iter = self.save_checkpoint(
-                loss_list,
-                train_config.ckpt_dir,
-                current_iter,
-                last_ckpt_iter
-            )
+            self.save_checkpoint(loss_list, train_config)
+            last_ckpt_iter = current_iter

-        self.logger.info("Training completed")
-        return Checkpoint(
-            self.model,
-            self.tokenizer,
-            self.config,
-            loss_list,
-            current_iter,
-            train_config.optimizer
-        )
+        return self.checkpoint
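A hedged sketch of the resume flow the reworked Trainer implies (again, not part of the commit): only the TrainConfig/ScheduleConfig field names visible in the diff are used, the constructors and all concrete values are assumptions, and imports are omitted because the module paths are not shown here.

# Hypothetical resume flow: the iteration count is derived from len(loss_list),
# and the optimizer state travels through Checkpoint.optim_state.
trainer = Trainer(parameter)                               # parameter: a loaded ModelParameter

prev = Checkpoint(
    model=parameter.model,
    tokenizer=parameter.tokenizer,
    config=parameter.config,
).load_training_state("checkpoints/iter_5000")             # placeholder directory

train_config = TrainConfig(                                # assumed constructor; fields mirror the diff
    train_type="sft",
    dataset=dataset,                                       # placeholder dataset
    optimizer=optimizer,                                   # fresh optimizer; train() restores its state
    batch_size=8,
    n_epoch=3,
    n_iter_step=1,
    n_iter_ckpt=1000,
    max_grad_norm=1.0,
    ckpt_dir="checkpoints",
    dpo_beta=0.1,
)
schedule_config = ScheduleConfig(schedule_type="cosine")   # assumed constructor

final_ckpt = trainer.train(train_config, schedule_config, train_checkpoint=prev)
final_ckpt.save(os.path.join("checkpoints", "final"))      # carries loss_list and optim_state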