feat(khaosz): 优化模型参数保存与加载逻辑
This commit is contained in:
parent
4fcdc87c95
commit
f25a249291
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
|||
import torch.optim as optim
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Self, Union
|
||||
from typing import Any, Dict, List, Optional, Self, Union
|
||||
from pathlib import Path
|
||||
|
||||
from khaosz.core.tokenizer import BpeTokenizer
|
||||
|
|
@ -84,11 +84,9 @@ class ModelParameter(BaseModelIO):
|
|||
)
|
||||
|
||||
def save(self, save_dir: Union[str, Path]):
|
||||
"""Save model parameters."""
|
||||
self.save_components(save_dir)
|
||||
|
||||
def load(self, load_dir: Union[str, Path]) -> Self:
|
||||
"""Load model parameters."""
|
||||
return self.load_components(load_dir)
|
||||
|
||||
|
||||
|
|
@ -108,26 +106,16 @@ class Checkpoint(BaseModelIO):
|
|||
default_factory=TransformerConfig,
|
||||
metadata={"help": "Transformer model configuration."}
|
||||
)
|
||||
loss_list: list[float] = field(
|
||||
default_factory=list,
|
||||
metadata={"help": "List of training losses."}
|
||||
)
|
||||
current_iter: int = field(
|
||||
default=0,
|
||||
metadata={"help": "Current training iteration."}
|
||||
)
|
||||
optimizer: Optional[optim.Optimizer] = field(
|
||||
optim_state: Dict[str, Any] = field(
|
||||
default=None,
|
||||
metadata={"help": "Optimizer state."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Ensure current_iter matches loss list length if not explicitly set
|
||||
if self.current_iter == 0 and self.loss_list:
|
||||
self.current_iter = len(self.loss_list)
|
||||
loss_list: List[float] = field(
|
||||
default_factory=list,
|
||||
metadata={"help": "List of training losses."}
|
||||
)
|
||||
|
||||
def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
|
||||
"""Get file paths for training-specific files."""
|
||||
paths = self._get_file_paths(directory)
|
||||
paths.update({
|
||||
"loss_list": paths["model"].parent / "loss.pkl",
|
||||
|
|
@ -137,7 +125,6 @@ class Checkpoint(BaseModelIO):
|
|||
return paths
|
||||
|
||||
def save_training_state(self, save_dir: Union[str, Path]):
|
||||
"""Save training-specific state."""
|
||||
paths = self._get_training_paths(save_dir)
|
||||
|
||||
# Save loss plot
|
||||
|
|
@ -148,25 +135,21 @@ class Checkpoint(BaseModelIO):
|
|||
pkl.dump(self.loss_list, f)
|
||||
|
||||
# Save optimizer state
|
||||
if self.optimizer is not None:
|
||||
with open(str(paths["optimizer"]), "wb") as f:
|
||||
pkl.dump(self.optimizer.state_dict(), f)
|
||||
pkl.dump(self.optim_state, f)
|
||||
|
||||
def load_training_state(self, load_dir: Union[str, Path]) -> Self:
|
||||
"""Load training-specific state."""
|
||||
paths = self._get_training_paths(load_dir)
|
||||
|
||||
# Load loss list
|
||||
if paths["loss_list"].exists():
|
||||
with open(str(paths["loss_list"]), "rb") as f:
|
||||
self.loss_list = pkl.load(f)
|
||||
self.current_iter = len(self.loss_list)
|
||||
|
||||
# Load optimizer state
|
||||
if paths["optimizer"].exists() and self.optimizer is not None:
|
||||
if paths["optimizer"].exists():
|
||||
with open(str(paths["optimizer"]), "rb") as f:
|
||||
optim_state = pkl.load(f)
|
||||
self.optimizer.load_state_dict(optim_state)
|
||||
self.optim_state = pkl.load(f)
|
||||
|
||||
return self
|
||||
|
||||
|
|
@ -175,9 +158,11 @@ class Checkpoint(BaseModelIO):
|
|||
if not self.loss_list:
|
||||
return
|
||||
|
||||
current_iter = len(self.loss_list)
|
||||
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(self.loss_list)
|
||||
plt.title(f"Training Loss - Iteration {self.current_iter}")
|
||||
plt.title(f"Training Loss - Iteration {current_iter}")
|
||||
plt.xlabel("Batch")
|
||||
plt.ylabel("Loss")
|
||||
plt.grid(True)
|
||||
|
|
@ -224,7 +209,7 @@ class ParameterLoader:
|
|||
tokenizer: BpeTokenizer,
|
||||
config: TransformerConfig,
|
||||
loss_list: Optional[list[float]] = None,
|
||||
optimizer: Optional[optim.Optimizer] = None
|
||||
optimizer: Optional[optim.Optimizer] = None,
|
||||
) -> Checkpoint:
|
||||
"""Convenience method to create a training checkpoint."""
|
||||
return Checkpoint(
|
||||
|
|
@ -232,7 +217,7 @@ class ParameterLoader:
|
|||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
loss_list=loss_list or [],
|
||||
optimizer=optimizer
|
||||
optimizer_state=optimizer
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
import os
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from typing import Tuple
|
||||
from typing import Optional
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
|
|
@ -15,82 +14,57 @@ from khaosz.trainer.strategy import SchedulerFactory, StrategyFactory, TrainConf
|
|||
class Trainer:
|
||||
def __init__(
|
||||
self,
|
||||
parameter: ModelParameter,
|
||||
log_path: str="./train_log.log"
|
||||
parameter: ModelParameter
|
||||
):
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(level = logging.INFO)
|
||||
handler = logging.FileHandler(log_path)
|
||||
handler.setLevel(logging.INFO)
|
||||
handler.setFormatter(logging.Formatter('%(asctime)s: %(message)s'))
|
||||
logger.addHandler(handler)
|
||||
logger.info("initializing trainer ...")
|
||||
|
||||
self.logger = logger
|
||||
self.model = parameter.model
|
||||
self.tokenizer = parameter.tokenizer
|
||||
self.config = parameter.config
|
||||
self.checkpoint = Checkpoint(
|
||||
model=parameter.model,
|
||||
tokenizer=parameter.tokenizer,
|
||||
config=parameter.config,
|
||||
)
|
||||
|
||||
def save_checkpoint(
|
||||
self,
|
||||
loss_list: list,
|
||||
ckpt_dir: str,
|
||||
current_iter: int,
|
||||
last_ckpt_iter: int
|
||||
train_config: TrainConfig
|
||||
):
|
||||
save_path = os.path.join(ckpt_dir, f"iter_{current_iter}")
|
||||
Checkpoint(
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
self.config,
|
||||
loss_list,
|
||||
current_iter
|
||||
).save(save_path)
|
||||
|
||||
diff_iter = current_iter - last_ckpt_iter
|
||||
avg_loss = sum(loss_list[last_ckpt_iter:current_iter]) / diff_iter
|
||||
self.logger.info(f"iter: {current_iter} loss: {avg_loss}")
|
||||
|
||||
return current_iter
|
||||
|
||||
def load_checkpoint(self, train_checkpoint: Checkpoint) -> Tuple[list, int]:
|
||||
self.model = train_checkpoint.model
|
||||
self.tokenizer = train_checkpoint.tokenizer
|
||||
self.config = train_checkpoint.config
|
||||
loss_list = train_checkpoint.loss_list
|
||||
last_ckpt_iter = train_checkpoint.current_iter
|
||||
|
||||
return loss_list, last_ckpt_iter
|
||||
current_iter = len(loss_list)
|
||||
save_path = os.path.join(train_config.ckpt_dir, f"iter_{current_iter}")
|
||||
self.checkpoint.loss_list = loss_list
|
||||
self.checkpoint.optim_state = train_config.optimizer.state_dict()
|
||||
self.checkpoint.save(save_path)
|
||||
|
||||
def train(
|
||||
self,
|
||||
train_config: TrainConfig,
|
||||
schedule_config: ScheduleConfig,
|
||||
train_checkpoint: Checkpoint = None
|
||||
):
|
||||
train_checkpoint: Optional[Checkpoint] = None
|
||||
) -> Checkpoint:
|
||||
assert schedule_config.schedule_type in ["cosine", "sgdr"]
|
||||
assert train_config.train_type in ["seq", "sft", "dpo"]
|
||||
|
||||
if train_checkpoint:
|
||||
loss_list, last_ckpt_iter = self.load_checkpoint(train_checkpoint)
|
||||
current_iter = train_checkpoint.current_iter + 1
|
||||
self.logger.info(f"Resuming training from checkpoint: iter {current_iter}")
|
||||
else:
|
||||
current_iter = 0
|
||||
last_ckpt_iter = 0
|
||||
loss_list = []
|
||||
self.checkpoint = train_checkpoint
|
||||
train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
|
||||
|
||||
loss_list = self.checkpoint.loss_list
|
||||
current_iter = len(self.checkpoint.loss_list)
|
||||
last_ckpt_iter = current_iter
|
||||
|
||||
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
|
||||
**schedule_config.get_kwargs()
|
||||
)
|
||||
|
||||
strategy_kwargs = {
|
||||
"bos_token_id": self.checkpoint.tokenizer.bos_id,
|
||||
"eos_token_id": self.checkpoint.tokenizer.eos_id,
|
||||
"pad_token_id": self.checkpoint.tokenizer.pad_id,
|
||||
"dpo_beta": train_config.dpo_beta
|
||||
}
|
||||
|
||||
strategy = StrategyFactory.load(
|
||||
self.model,
|
||||
train_type=train_config.train_type,
|
||||
bos_token_id=self.tokenizer.bos_id,
|
||||
eos_token_id=self.tokenizer.eos_id,
|
||||
pad_token_id=self.tokenizer.pad_id,
|
||||
dpo_beta=train_config.dpo_beta
|
||||
self.checkpoint.model,
|
||||
train_config.train_type,
|
||||
**strategy_kwargs
|
||||
)
|
||||
|
||||
scheduler = LambdaLR(
|
||||
|
|
@ -104,11 +78,8 @@ class Trainer:
|
|||
sampler = RandomSampler(train_config.dataset, generator=generator)
|
||||
remaining_epochs = train_config.n_epoch - current_iter // (len(train_config.dataset) // train_config.batch_size)
|
||||
|
||||
self.logger.info(f"Starting {train_config.train_type.upper()} training for {train_config.n_epoch} epochs")
|
||||
self.logger.info(f"Checkpoint interval: {train_config.n_iter_ckpt} iterations")
|
||||
|
||||
for epoch in range(remaining_epochs):
|
||||
self.model.train()
|
||||
self.checkpoint.model.train()
|
||||
dataloader = DataLoader(
|
||||
train_config.dataset,
|
||||
batch_size=train_config.batch_size,
|
||||
|
|
@ -128,7 +99,7 @@ class Trainer:
|
|||
#step
|
||||
if current_iter % train_config.n_iter_step == 0:
|
||||
clip_grad_norm_(
|
||||
self.model.parameters(),
|
||||
self.checkpoint.model.parameters(),
|
||||
train_config.max_grad_norm
|
||||
)
|
||||
train_config.optimizer.step()
|
||||
|
|
@ -142,28 +113,11 @@ class Trainer:
|
|||
})
|
||||
#save checkpoint
|
||||
if current_iter - last_ckpt_iter >= train_config.n_iter_ckpt:
|
||||
last_ckpt_iter = self.save_checkpoint(
|
||||
loss_list,
|
||||
train_config.ckpt_dir,
|
||||
current_iter,
|
||||
last_ckpt_iter
|
||||
)
|
||||
self.save_checkpoint(loss_list, train_config)
|
||||
last_ckpt_iter = current_iter
|
||||
|
||||
if current_iter != last_ckpt_iter:
|
||||
last_ckpt_iter = self.save_checkpoint(
|
||||
loss_list,
|
||||
train_config.ckpt_dir,
|
||||
current_iter,
|
||||
last_ckpt_iter
|
||||
)
|
||||
self.save_checkpoint(loss_list, train_config)
|
||||
last_ckpt_iter = current_iter
|
||||
|
||||
self.logger.info("Training completed")
|
||||
|
||||
return Checkpoint(
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
self.config,
|
||||
loss_list,
|
||||
current_iter,
|
||||
train_config.optimizer
|
||||
)
|
||||
return self.checkpoint
|
||||
Loading…
Reference in New Issue