# AstrAI/khaosz/trainer/trainer.py

import os
import torch
import logging
from typing import Optional, Tuple
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, RandomSampler
from tqdm import tqdm

from khaosz.core import ModelParameter, Checkpoint
from khaosz.trainer.strategy import SchedulerFactory, StrategyFactory, TrainConfig, ScheduleConfig


class Trainer:
    def __init__(
        self,
        parameter: ModelParameter,
        log_path: str = "./train_log.log"
    ):
        # log to a file via the root logger; note that constructing several
        # Trainers adds one file handler each time
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        handler = logging.FileHandler(log_path)
        handler.setLevel(logging.INFO)
        handler.setFormatter(logging.Formatter('%(asctime)s: %(message)s'))
        logger.addHandler(handler)
        logger.info("initializing trainer ...")

        self.logger = logger
        self.model = parameter.model
        self.tokenizer = parameter.tokenizer
        self.config = parameter.config
    def save_checkpoint(
        self,
        loss_list: list,
        ckpt_dir: str,
        current_iter: int,
        last_ckpt_iter: int
    ) -> int:
        save_path = os.path.join(ckpt_dir, f"iter_{current_iter}")
        Checkpoint(
            self.model,
            self.tokenizer,
            self.config,
            loss_list,
            current_iter
        ).save(save_path)

        # log the mean loss over the iterations since the last checkpoint
        diff_iter = current_iter - last_ckpt_iter
        avg_loss = sum(loss_list[last_ckpt_iter:current_iter]) / diff_iter
        self.logger.info(f"iter: {current_iter} loss: {avg_loss:.4f}")

        return current_iter
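
    # e.g. with ckpt_dir == "checkpoints" and current_iter == 1000, the
    # checkpoint above is written to "checkpoints/iter_1000"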
    def load_checkpoint(self, train_checkpoint: Checkpoint) -> Tuple[list, int]:
        # restore model, tokenizer and config state from the checkpoint
        self.model = train_checkpoint.model
        self.tokenizer = train_checkpoint.tokenizer
        self.config = train_checkpoint.config

        loss_list = train_checkpoint.loss_list
        last_ckpt_iter = train_checkpoint.current_iter
        return loss_list, last_ckpt_iter
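
    # A worked example of the resume bookkeeping (illustrative numbers, not
    # khaosz defaults): if a checkpoint was written at iter 3000,
    # load_checkpoint returns last_ckpt_iter == 3000, train() resumes at
    # current_iter == 3001, and with n_iter_ckpt == 1000 the next periodic
    # checkpoint is written at iter 4000.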
    def train(
        self,
        train_config: TrainConfig,
        schedule_config: ScheduleConfig,
        train_checkpoint: Optional[Checkpoint] = None
    ) -> Checkpoint:
        assert schedule_config.schedule_type in ["cosine", "sgdr"], \
            f"unknown schedule type: {schedule_config.schedule_type}"
        assert train_config.train_type in ["seq", "sft", "dpo"], \
            f"unknown train type: {train_config.train_type}"

        if train_checkpoint:
            loss_list, last_ckpt_iter = self.load_checkpoint(train_checkpoint)
            current_iter = train_checkpoint.current_iter + 1
            self.logger.info(f"Resuming training from checkpoint: iter {current_iter}")
        else:
            current_iter = 0
            last_ckpt_iter = 0
            loss_list = []

        lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
            **schedule_config.get_kwargs()
        )
        strategy = StrategyFactory.load(
            self.model,
            train_config.train_type,
            self.tokenizer.pad_id,
            train_config.dpo_beta
        )
        # note: LambdaLR with last_epoch != -1 expects 'initial_lr' in each
        # param group, which PyTorch records when a scheduler is first built
        scheduler = LambdaLR(
            train_config.optimizer,
            lambda_scheduler_fn,
            last_epoch=current_iter - 1 if train_checkpoint else -1
        )

        # seeded sampler so the shuffling order is reproducible across runs
        seed = train_config.random_seed
        generator = torch.Generator().manual_seed(seed)
        sampler = RandomSampler(train_config.dataset, generator=generator)

        # skip epochs already covered by the resumed iteration counter
        iters_per_epoch = len(train_config.dataset) // train_config.batch_size
        remaining_epochs = train_config.n_epoch - current_iter // iters_per_epoch

        self.logger.info(f"Starting {train_config.train_type.upper()} training for {train_config.n_epoch} epochs")
        self.logger.info(f"Checkpoint interval: {train_config.n_iter_ckpt} iterations")
        for epoch in range(remaining_epochs):
            self.model.train()
            dataloader = DataLoader(
                train_config.dataset,
                batch_size=train_config.batch_size,
                sampler=sampler
            )
            progress_bar = tqdm(
                dataloader,
                desc=f"Epoch {epoch + 1}/{train_config.n_epoch}",
                dynamic_ncols=True
            )
            for batch in progress_bar:
                # forward pass through the training strategy (seq / sft / dpo)
                loss = strategy(batch)
                loss_list.append(loss.item())
                # backward: gradients accumulate until the optimizer steps
                loss.backward()
                # optimizer step every n_iter_step batches
                if current_iter % train_config.n_iter_step == 0:
                    clip_grad_norm_(
                        self.model.parameters(),
                        train_config.max_grad_norm
                    )
                    train_config.optimizer.step()
                    train_config.optimizer.zero_grad()

                current_iter += 1
                scheduler.step()
                progress_bar.set_postfix({
                    "loss": f"{loss.item():.4f}",
                    "lr": f"{train_config.optimizer.param_groups[0]['lr']:.2e}"
                })
                # save a checkpoint every n_iter_ckpt iterations
                if current_iter - last_ckpt_iter >= train_config.n_iter_ckpt:
                    last_ckpt_iter = self.save_checkpoint(
                        loss_list,
                        train_config.ckpt_dir,
                        current_iter,
                        last_ckpt_iter
                    )

        # final checkpoint for any iterations since the last periodic save
        if current_iter != last_ckpt_iter:
            last_ckpt_iter = self.save_checkpoint(
                loss_list,
                train_config.ckpt_dir,
                current_iter,
                last_ckpt_iter
            )
self.logger.info("Training completed")
return Checkpoint(
self.model,
self.tokenizer,
self.config,
loss_list,
current_iter,
train_config.optimizer
)
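

# Minimal usage sketch (illustrative only): assumes TrainConfig and
# ScheduleConfig instances whose fields (dataset, optimizer, ckpt_dir, ...)
# match the attributes accessed above, and a ModelParameter built elsewhere
# in khaosz.
#
#   parameter: ModelParameter = ...   # built elsewhere in khaosz
#   trainer = Trainer(parameter, log_path="./train_log.log")
#   ckpt = trainer.train(train_config, schedule_config)
#   # resume later by passing the returned (or a reloaded) checkpoint:
#   trainer.train(train_config, schedule_config, train_checkpoint=ckpt)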