import os
import torch
import logging

from typing import Optional, Tuple
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, RandomSampler
from tqdm import tqdm

from khaosz.core import ModelParameter, Checkpoint
from khaosz.trainer.strategy import SchedulerFactory, StrategyFactory, TrainConfig, ScheduleConfig


class Trainer:
    def __init__(
        self,
        parameter: ModelParameter,
        log_path: str = "./train_log.log"
    ):
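        # Use a module-level logger rather than the root logger so configuring
        # the Trainer does not change logging behavior for the whole process.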
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.INFO)
        handler = logging.FileHandler(log_path)
        handler.setLevel(logging.INFO)
        handler.setFormatter(logging.Formatter('%(asctime)s: %(message)s'))
        logger.addHandler(handler)
        logger.info("initializing trainer ...")

        self.logger = logger
        self.model = parameter.model
        self.tokenizer = parameter.tokenizer
        self.config = parameter.config

    def save_checkpoint(
        self,
        loss_list: list,
        ckpt_dir: str,
        current_iter: int,
        last_ckpt_iter: int
    ) -> int:
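        # Persist model, tokenizer, config and loss history under
        # ckpt_dir/iter_<current_iter>.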
        save_path = os.path.join(ckpt_dir, f"iter_{current_iter}")
        Checkpoint(
            self.model,
            self.tokenizer,
            self.config,
            loss_list,
            current_iter
        ).save(save_path)

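        # Log the average loss over the iterations since the previous checkpoint.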
        diff_iter = current_iter - last_ckpt_iter
        avg_loss = sum(loss_list[last_ckpt_iter:current_iter]) / diff_iter
        self.logger.info(f"iter: {current_iter} loss: {avg_loss}")

        return current_iter

    def load_checkpoint(self, train_checkpoint: Checkpoint) -> Tuple[list, int]:
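        # Restore training state from a checkpoint and return the loss history
        # together with the iteration it was written at, so training can resume.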
        self.model = train_checkpoint.model
        self.tokenizer = train_checkpoint.tokenizer
        self.config = train_checkpoint.config
        loss_list = train_checkpoint.loss_list
        last_ckpt_iter = train_checkpoint.current_iter

        return loss_list, last_ckpt_iter

    def train(
        self,
        train_config: TrainConfig,
        schedule_config: ScheduleConfig,
        train_checkpoint: Optional[Checkpoint] = None
    ) -> Checkpoint:
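        # Fail fast on unknown schedule or training types.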
        assert schedule_config.schedule_type in ["cosine", "sgdr"]
        assert train_config.train_type in ["seq", "sft", "dpo"]

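        # Resume from the given checkpoint, otherwise start from scratch.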
        if train_checkpoint:
            loss_list, last_ckpt_iter = self.load_checkpoint(train_checkpoint)
            current_iter = train_checkpoint.current_iter + 1
            self.logger.info(f"Resuming training from checkpoint: iter {current_iter}")
        else:
            current_iter = 0
            last_ckpt_iter = 0
            loss_list = []

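        # Build the LR schedule function and the loss strategy
        # (seq / sft / dpo) from their factories.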
        lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
            **schedule_config.get_kwargs()
        )

        strategy = StrategyFactory.load(
            self.model,
            train_config.train_type,
            self.tokenizer.pad_id,
            train_config.dpo_beta
        )

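        # last_epoch restores the scheduler position when resuming;
        # -1 starts the schedule from the beginning.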
        scheduler = LambdaLR(
            train_config.optimizer,
            lambda_scheduler_fn,
            last_epoch=current_iter - 1 if train_checkpoint else -1
        )

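        # Seed the sampler so shuffling is reproducible, and work out how many
        # epochs remain after a resume.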
        seed = train_config.random_seed
        generator = torch.Generator().manual_seed(seed)
        sampler = RandomSampler(train_config.dataset, generator=generator)
        iters_per_epoch = len(train_config.dataset) // train_config.batch_size
        remaining_epochs = train_config.n_epoch - current_iter // iters_per_epoch

self.logger.info(f"Starting {train_config.train_type.upper()} training for {train_config.n_epoch} epochs")
|
|
self.logger.info(f"Checkpoint interval: {train_config.n_iter_ckpt} iterations")
|
|
|
|
        for epoch in range(remaining_epochs):
            self.model.train()
            dataloader = DataLoader(
                train_config.dataset,
                batch_size=train_config.batch_size,
                sampler=sampler
            )
            progress_bar = tqdm(
                dataloader,
                desc=f"Epoch {epoch+1}/{train_config.n_epoch}",
                dynamic_ncols=True
            )
            for batch in progress_bar:
                # forward pass: the strategy computes the loss for its objective
                loss = strategy(batch)
                loss_list.append(loss.item())

                # backward pass: accumulate gradients
                loss.backward()

                # optimizer step every n_iter_step iterations (gradient accumulation)
                if current_iter % train_config.n_iter_step == 0:
                    clip_grad_norm_(
                        self.model.parameters(),
                        train_config.max_grad_norm
                    )
                    train_config.optimizer.step()
                    train_config.optimizer.zero_grad()

                current_iter += 1
                scheduler.step()
                progress_bar.set_postfix({
                    "loss": f"{loss.item():.4f}",
                    "lr": f"{train_config.optimizer.param_groups[0]['lr']:.2e}"
                })

                # save a checkpoint once enough iterations have elapsed
                if current_iter - last_ckpt_iter >= train_config.n_iter_ckpt:
                    last_ckpt_iter = self.save_checkpoint(
                        loss_list,
                        train_config.ckpt_dir,
                        current_iter,
                        last_ckpt_iter
                    )

        # write a final checkpoint if the last iteration was not checkpointed
        if current_iter != last_ckpt_iter:
            last_ckpt_iter = self.save_checkpoint(
                loss_list,
                train_config.ckpt_dir,
                current_iter,
                last_ckpt_iter
            )

        self.logger.info("Training completed")

        return Checkpoint(
            self.model,
            self.tokenizer,
            self.config,
            loss_list,
            current_iter,
            train_config.optimizer
        )
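
# Usage sketch (a minimal outline, not part of the module): the constructor
# arguments for TrainConfig and ScheduleConfig below are assumptions inferred
# from the attributes accessed in train(); consult khaosz.trainer.strategy for
# the actual signatures.
#
#   parameter = ModelParameter(model, tokenizer, config)
#   trainer = Trainer(parameter, log_path="./train_log.log")
#   final_ckpt = trainer.train(
#       TrainConfig(train_type="sft", dataset=dataset, optimizer=optimizer, ...),
#       ScheduleConfig(schedule_type="cosine", ...),
#   )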