feat(khaosz): optimize model parameter save and load logic
parent 4fcdc87c95
commit f25a249291
@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.optim as optim
 from dataclasses import dataclass, field
-from typing import Optional, Self, Union
+from typing import Any, Dict, List, Optional, Self, Union
 from pathlib import Path

 from khaosz.core.tokenizer import BpeTokenizer

@@ -84,11 +84,9 @@ class ModelParameter(BaseModelIO):
         )

     def save(self, save_dir: Union[str, Path]):
-        """Save model parameters."""
         self.save_components(save_dir)

     def load(self, load_dir: Union[str, Path]) -> Self:
-        """Load model parameters."""
         return self.load_components(load_dir)

@@ -108,26 +106,16 @@ class Checkpoint(BaseModelIO):
         default_factory=TransformerConfig,
         metadata={"help": "Transformer model configuration."}
     )
-    loss_list: list[float] = field(
-        default_factory=list,
-        metadata={"help": "List of training losses."}
-    )
-    current_iter: int = field(
-        default=0,
-        metadata={"help": "Current training iteration."}
-    )
-    optimizer: Optional[optim.Optimizer] = field(
+    optim_state: Dict[str, Any] = field(
         default=None,
         metadata={"help": "Optimizer state."}
     )
-
-    def __post_init__(self):
-        # Ensure current_iter matches loss list length if not explicitly set
-        if self.current_iter == 0 and self.loss_list:
-            self.current_iter = len(self.loss_list)
+    loss_list: List[float] = field(
+        default_factory=list,
+        metadata={"help": "List of training losses."}
+    )

     def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
-        """Get file paths for training-specific files."""
         paths = self._get_file_paths(directory)
         paths.update({
             "loss_list": paths["model"].parent / "loss.pkl",
@@ -137,7 +125,6 @@ class Checkpoint(BaseModelIO):
         return paths

     def save_training_state(self, save_dir: Union[str, Path]):
-        """Save training-specific state."""
         paths = self._get_training_paths(save_dir)

         # Save loss plot
@@ -148,25 +135,21 @@ class Checkpoint(BaseModelIO):
             pkl.dump(self.loss_list, f)

         # Save optimizer state
-        if self.optimizer is not None:
-            with open(str(paths["optimizer"]), "wb") as f:
-                pkl.dump(self.optimizer.state_dict(), f)
+        with open(str(paths["optimizer"]), "wb") as f:
+            pkl.dump(self.optim_state, f)

     def load_training_state(self, load_dir: Union[str, Path]) -> Self:
-        """Load training-specific state."""
         paths = self._get_training_paths(load_dir)

         # Load loss list
         if paths["loss_list"].exists():
             with open(str(paths["loss_list"]), "rb") as f:
                 self.loss_list = pkl.load(f)
-                self.current_iter = len(self.loss_list)

         # Load optimizer state
-        if paths["optimizer"].exists() and self.optimizer is not None:
+        if paths["optimizer"].exists():
             with open(str(paths["optimizer"]), "rb") as f:
-                optim_state = pkl.load(f)
-                self.optimizer.load_state_dict(optim_state)
+                self.optim_state = pkl.load(f)

         return self

@@ -175,9 +158,11 @@ class Checkpoint(BaseModelIO):
         if not self.loss_list:
             return

+        current_iter = len(self.loss_list)
+
         plt.figure(figsize=(10, 6))
         plt.plot(self.loss_list)
-        plt.title(f"Training Loss - Iteration {self.current_iter}")
+        plt.title(f"Training Loss - Iteration {current_iter}")
         plt.xlabel("Batch")
         plt.ylabel("Loss")
         plt.grid(True)
@@ -224,7 +209,7 @@ class ParameterLoader:
         tokenizer: BpeTokenizer,
         config: TransformerConfig,
         loss_list: Optional[list[float]] = None,
-        optimizer: Optional[optim.Optimizer] = None
+        optimizer: Optional[optim.Optimizer] = None,
     ) -> Checkpoint:
         """Convenience method to create a training checkpoint."""
         return Checkpoint(
@@ -232,7 +217,7 @@ class ParameterLoader:
             tokenizer=tokenizer,
             config=config,
             loss_list=loss_list or [],
-            optimizer=optimizer
+            optimizer_state=optimizer
         )

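The parameter-module diff above replaces the live `optimizer` object held by `Checkpoint` with a picklable `optim_state` dict and drops the separate `current_iter` field. Below is a minimal, self-contained sketch of that pattern in plain PyTorch, outside the khaosz classes; the model, file name, and hyperparameters are made up for illustration.

import pickle as pkl
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 4)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Save: persist only the optimizer state dict, as the new save_training_state does.
with open("optimizer.pkl", "wb") as f:
    pkl.dump(optimizer.state_dict(), f)

# Load: read the dict back and push it into a freshly constructed optimizer,
# which mirrors what Trainer.train now does with train_config.optimizer.
with open("optimizer.pkl", "rb") as f:
    optim_state = pkl.load(f)
optimizer.load_state_dict(optim_state)

Keeping only the state dict avoids pickling a live optimizer that still references model parameters, and lets the caller decide how the optimizer is rebuilt on resume.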
@@ -1,8 +1,7 @@
 import os
 import torch
-import logging

-from typing import Tuple
+from typing import Optional
 from torch.nn.utils import clip_grad_norm_
 from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.data import DataLoader, RandomSampler
@@ -15,82 +14,57 @@ from khaosz.trainer.strategy import SchedulerFactory, StrategyFactory, TrainConf
 class Trainer:
     def __init__(
         self,
-        parameter: ModelParameter,
-        log_path: str="./train_log.log"
+        parameter: ModelParameter
     ):
-        logger = logging.getLogger()
-        logger.setLevel(level = logging.INFO)
-        handler = logging.FileHandler(log_path)
-        handler.setLevel(logging.INFO)
-        handler.setFormatter(logging.Formatter('%(asctime)s: %(message)s'))
-        logger.addHandler(handler)
-        logger.info("initializing trainer ...")
-
-        self.logger = logger
-        self.model = parameter.model
-        self.tokenizer = parameter.tokenizer
-        self.config = parameter.config
+        self.checkpoint = Checkpoint(
+            model=parameter.model,
+            tokenizer=parameter.tokenizer,
+            config=parameter.config,
+        )

     def save_checkpoint(
         self,
         loss_list: list,
-        ckpt_dir: str,
-        current_iter: int,
-        last_ckpt_iter: int
+        train_config: TrainConfig
     ):
-        save_path = os.path.join(ckpt_dir, f"iter_{current_iter}")
-        Checkpoint(
-            self.model,
-            self.tokenizer,
-            self.config,
-            loss_list,
-            current_iter
-        ).save(save_path)
-
-        diff_iter = current_iter - last_ckpt_iter
-        avg_loss = sum(loss_list[last_ckpt_iter:current_iter]) / diff_iter
-        self.logger.info(f"iter: {current_iter} loss: {avg_loss}")
-
-        return current_iter
-
-    def load_checkpoint(self, train_checkpoint: Checkpoint) -> Tuple[list, int]:
-        self.model = train_checkpoint.model
-        self.tokenizer = train_checkpoint.tokenizer
-        self.config = train_checkpoint.config
-        loss_list = train_checkpoint.loss_list
-        last_ckpt_iter = train_checkpoint.current_iter
-
-        return loss_list, last_ckpt_iter
+        current_iter = len(loss_list)
+        save_path = os.path.join(train_config.ckpt_dir, f"iter_{current_iter}")
+        self.checkpoint.loss_list = loss_list
+        self.checkpoint.optim_state = train_config.optimizer.state_dict()
+        self.checkpoint.save(save_path)

     def train(
         self,
         train_config: TrainConfig,
         schedule_config: ScheduleConfig,
-        train_checkpoint: Checkpoint = None
-    ):
+        train_checkpoint: Optional[Checkpoint] = None
+    ) -> Checkpoint:
         assert schedule_config.schedule_type in ["cosine", "sgdr"]
         assert train_config.train_type in ["seq", "sft", "dpo"]

         if train_checkpoint:
-            loss_list, last_ckpt_iter = self.load_checkpoint(train_checkpoint)
-            current_iter = train_checkpoint.current_iter + 1
-            self.logger.info(f"Resuming training from checkpoint: iter {current_iter}")
-        else:
-            current_iter = 0
-            last_ckpt_iter = 0
-            loss_list = []
+            self.checkpoint = train_checkpoint
+            train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
+
+        loss_list = self.checkpoint.loss_list
+        current_iter = len(self.checkpoint.loss_list)
+        last_ckpt_iter = current_iter

         lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
             **schedule_config.get_kwargs()
         )

+        strategy_kwargs = {
+            "bos_token_id": self.checkpoint.tokenizer.bos_id,
+            "eos_token_id": self.checkpoint.tokenizer.eos_id,
+            "pad_token_id": self.checkpoint.tokenizer.pad_id,
+            "dpo_beta": train_config.dpo_beta
+        }
+
         strategy = StrategyFactory.load(
-            self.model,
-            train_type=train_config.train_type,
-            bos_token_id=self.tokenizer.bos_id,
-            eos_token_id=self.tokenizer.eos_id,
-            pad_token_id=self.tokenizer.pad_id,
-            dpo_beta=train_config.dpo_beta
+            self.checkpoint.model,
+            train_config.train_type,
+            **strategy_kwargs
         )

         scheduler = LambdaLR(
@@ -104,11 +78,8 @@ class Trainer:
         sampler = RandomSampler(train_config.dataset, generator=generator)
         remaining_epochs = train_config.n_epoch - current_iter // (len(train_config.dataset) // train_config.batch_size)

-        self.logger.info(f"Starting {train_config.train_type.upper()} training for {train_config.n_epoch} epochs")
-        self.logger.info(f"Checkpoint interval: {train_config.n_iter_ckpt} iterations")
-
         for epoch in range(remaining_epochs):
-            self.model.train()
+            self.checkpoint.model.train()
             dataloader = DataLoader(
                 train_config.dataset,
                 batch_size=train_config.batch_size,
@@ -128,7 +99,7 @@ class Trainer:
                 #step
                 if current_iter % train_config.n_iter_step == 0:
                     clip_grad_norm_(
-                        self.model.parameters(),
+                        self.checkpoint.model.parameters(),
                         train_config.max_grad_norm
                     )
                     train_config.optimizer.step()
@@ -142,28 +113,11 @@ class Trainer:
                 })
                 #save checkpotint
                 if current_iter - last_ckpt_iter >= train_config.n_iter_ckpt:
-                    last_ckpt_iter = self.save_checkpoint(
-                        loss_list,
-                        train_config.ckpt_dir,
-                        current_iter,
-                        last_ckpt_iter
-                    )
+                    self.save_checkpoint(loss_list, train_config)
+                    last_ckpt_iter = current_iter

         if current_iter != last_ckpt_iter:
-            last_ckpt_iter = self.save_checkpoint(
-                loss_list,
-                train_config.ckpt_dir,
-                current_iter,
-                last_ckpt_iter
-            )
+            self.save_checkpoint(loss_list, train_config)
+            last_ckpt_iter = current_iter

-        self.logger.info("Training completed")
-
-        return Checkpoint(
-            self.model,
-            self.tokenizer,
-            self.config,
-            loss_list,
-            current_iter,
-            train_config.optimizer
-        )
+        return self.checkpoint
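With the `current_iter` field gone, the trainer derives the iteration count from the length of the loss history and uses it both to decide when to checkpoint and to name the checkpoint directory (`iter_{current_iter}`). A runnable sketch of that bookkeeping is below; the loss values and config numbers are stand-ins for fields the real code reads from the checkpoint and `TrainConfig`.

import os

loss_list = [2.4, 2.1, 1.9]        # stands in for checkpoint.loss_list
current_iter = len(loss_list)      # derived; no separate current_iter field anymore
last_ckpt_iter = current_iter

n_iter_ckpt = 2                    # stands in for train_config.n_iter_ckpt
ckpt_dir = "checkpoints"           # stands in for train_config.ckpt_dir

for step_loss in [1.8, 1.7, 1.6, 1.5]:   # stands in for the batch loop
    loss_list.append(step_loss)
    current_iter += 1
    if current_iter - last_ckpt_iter >= n_iter_ckpt:
        # Trainer.save_checkpoint would persist the Checkpoint at this path.
        save_path = os.path.join(ckpt_dir, f"iter_{current_iter}")
        print("checkpoint ->", save_path)
        last_ckpt_iter = current_iter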