feat(khaosz/trainer): 优化训练器回调机制与数据采样逻辑

This commit is contained in:
ViperEkura 2025-09-30 16:33:18 +08:00
parent e0e9942e4a
commit 315ce1990a
5 changed files with 118 additions and 85 deletions

View File

@ -8,6 +8,7 @@ from khaosz.trainer.strategy import (
SchedulerFactory SchedulerFactory
) )
from khaosz.trainer.trainer_callback import ( from khaosz.trainer.trainer_callback import (
TrainerCallback,
ProgressBarCallback, ProgressBarCallback,
CheckpointCallback, CheckpointCallback,
TrainerCallback, TrainerCallback,
@ -25,6 +26,7 @@ __all__ = [
"SchedulerFactory", "SchedulerFactory",
# callback # callback
"TrainerCallback",
"ProgressBarCallback", "ProgressBarCallback",
"CheckpointCallback", "CheckpointCallback",
"TrainerCallback", "TrainerCallback",

View File

@ -53,7 +53,8 @@ def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool() causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
attention_mask = seq_mask & causal_mask attention_mask = seq_mask & causal_mask
return attention_mask # fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
return attention_mask.unsqueeze(0)
class BaseSegmentFetcher: class BaseSegmentFetcher:
@ -118,9 +119,12 @@ class BaseDataset(Dataset, ABC):
self.device = device self.device = device
def save(self, save_path: str): def save(self, save_path: str):
keys = list(self.segments.keys())
if not keys:
return
first_item = self.segments[keys[0]] first_item = self.segments[keys[0]]
segment_size = len(first_item) segment_size = len(first_item)
keys = list(self.segments.keys())
for i in range(segment_size): for i in range(segment_size):
formated_segment = {key: self.segments[key][i] for key in keys} formated_segment = {key: self.segments[key][i] for key in keys}
@ -272,7 +276,7 @@ class RandomSampler(Sampler[int]):
self.data_source = data_source self.data_source = data_source
self.seed = seed self.seed = seed
self.epoch = 0 self.epoch = 0
self.current_index = 0 self.current_iter = 0
self._indices = None self._indices = None
if generator is None: if generator is None:
@ -291,20 +295,21 @@ class RandomSampler(Sampler[int]):
if self._indices is None: if self._indices is None:
self._generate_indices() self._generate_indices()
for i in range(self.current_index, n): for i in range(self.current_iter, n):
yield self._indices[i] yield self._indices[i]
self.current_iter += 1
self.epoch += 1 self.epoch += 1
self.current_index = 0
self._indices = None self._indices = None
def __len__(self): def __len__(self):
return len(self.data_source) - self.current_index n = len(self.data_source)
return n - self.current_iter % n
def state_dict(self): def state_dict(self):
return { return {
'epoch': self.epoch, 'epoch': self.epoch,
'current_index': self.current_index, 'current_iter': self.current_iter,
'seed': self.seed, 'seed': self.seed,
'generator_state': self.generator.get_state() if self.generator else None, 'generator_state': self.generator.get_state() if self.generator else None,
'indices': self._indices 'indices': self._indices
@ -312,7 +317,7 @@ class RandomSampler(Sampler[int]):
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
self.epoch = state_dict['epoch'] self.epoch = state_dict['epoch']
self.current_index = state_dict['current_index'] self.current_iter = state_dict['current_iter']
self.seed = state_dict['seed'] self.seed = state_dict['seed']
if self.generator and state_dict['generator_state'] is not None: if self.generator and state_dict['generator_state'] is not None:

View File

@ -159,12 +159,7 @@ class StrategyFactory:
def load(model, train_type, **kwargs): def load(model, train_type, **kwargs):
train_strategy: Dict[str, Callable[[], BaseStrategy]] = { train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
"seq": lambda: SeqStrategy(model), "seq": lambda: SeqStrategy(model),
"sft": lambda: SftStrategy( "sft": lambda: SftStrategy(model),
model,
kwargs.get("bos_token_id"),
kwargs.get("eos_token_id"),
kwargs.get("multi_turn")
),
"dpo": lambda: DpoStrategy( "dpo": lambda: DpoStrategy(
model, model,
kwargs.get("pad_token_id"), kwargs.get("pad_token_id"),

View File

@ -1,6 +1,4 @@
import torch from typing import Optional, List, cast
import itertools
from typing import Optional, List
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from khaosz.core import ModelParameter, Checkpoint from khaosz.core import ModelParameter, Checkpoint
@ -23,11 +21,7 @@ class Trainer:
schedule_config: ScheduleConfig, schedule_config: ScheduleConfig,
callbacks: Optional[List[TrainerCallback]] = None callbacks: Optional[List[TrainerCallback]] = None
): ):
self.checkpoint = Checkpoint( self.parameter = parameter
model=parameter.model,
tokenizer=parameter.tokenizer,
config=parameter.config,
)
self.train_config = train_config self.train_config = train_config
self.schedule_config = schedule_config self.schedule_config = schedule_config
self.callbacks = callbacks or self._get_default_callbacks() self.callbacks = callbacks or self._get_default_callbacks()
@ -40,24 +34,50 @@ class Trainer:
SchedulerCallback(self.schedule_config), SchedulerCallback(self.schedule_config),
] ]
def _create_dataloader(self, start_index: int = 0) -> DataLoader: def _set_train_kwargs(self, kwargs: dict):
used_epochs = 0
used_iters = 0
seed = self.train_config.random_seed seed = self.train_config.random_seed
generator = torch.Generator().manual_seed(seed) sampler = RandomSampler(data_source=self.train_config.dataset, seed=seed)
sampler = RandomSampler( optim = self.train_config.optimizer
self.train_config.dataset, checkpoint = cast(Checkpoint, kwargs.get('checkpoint', None))
generator=generator,
seed=seed if checkpoint is None:
checkpoint = Checkpoint(
model=self.parameter.model,
tokenizer=self.parameter.tokenizer,
config=self.parameter.config,
sampler_state=None,
optim_state=None,
loss_list=[]
) )
sampler_state = checkpoint.sampler_state
optim_state = checkpoint.optim_state
if sampler_state:
sampler.load_state_dict(sampler_state)
used_epochs = sampler_state.get('epoch', 0)
used_iters = sampler_state.get('iter', 0)
if optim_state:
optim.load_state_dict(optim_state)
checkpoint.optim_state = optim.state_dict()
checkpoint.sampler_state = sampler.state_dict()
dataloader = DataLoader( dataloader = DataLoader(
self.train_config.dataset, self.train_config.dataset,
batch_size=self.train_config.batch_size, batch_size=self.train_config.batch_size,
sampler=sampler sampler=sampler
) )
if start_index > 0: kwargs["dataloader"] = dataloader
dataloader = itertools.islice(dataloader, start_index, None) kwargs["optimizer"] = self.train_config.optimizer
kwargs["epoch"] = used_epochs
return dataloader kwargs["current_iter"] = used_iters
kwargs["sampler"] = sampler
kwargs["checkpoint"] = checkpoint
def _call_callbacks(self, method_name: str, **kwargs): def _call_callbacks(self, method_name: str, **kwargs):
for callback in self.callbacks: for callback in self.callbacks:
@ -67,53 +87,54 @@ class Trainer:
def train( def train(
self, self,
train_checkpoint: Optional[Checkpoint] = None checkpoint: Optional[Checkpoint] = None
) -> Checkpoint: ) -> Checkpoint:
if train_checkpoint:
self.checkpoint = train_checkpoint
self.train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
else:
self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
current_iter = len(self.checkpoint.loss_list)
total_steps_per_epoch = len(self.train_config.dataset) // self.train_config.batch_size
total_reamining_steps = total_steps_per_epoch * self.train_config.n_epoch - current_iter
current_epochs = total_reamining_steps // total_steps_per_epoch
current_steps = total_reamining_steps % total_steps_per_epoch
# train # train
self._call_callbacks('on_train_begin', checkpoint=self.checkpoint) train_kwargs = {
self.checkpoint.model.train() 'checkpoint': checkpoint,
'dataloader': None,
'optimizer': None,
'sampler': None,
'epoch': 0,
'current_iter': 0,
'loss': 0.0,
}
self._set_train_kwargs(train_kwargs)
self._call_callbacks('on_train_begin', **train_kwargs)
dataloader = train_kwargs['dataloader']
checkpoint = train_kwargs['checkpoint']
start_epoch = train_kwargs['epoch']
try: try:
for epoch in range(current_epochs): self.parameter.model.train()
for epoch in range(start_epoch, self.train_config.n_epoch):
# epoch # epoch
self._call_callbacks('on_epoch_begin', epoch=epoch) train_kwargs["epoch"] = epoch
dataloader = self._create_dataloader(start_index=current_steps) self._call_callbacks('on_epoch_begin', **train_kwargs)
for batch in dataloader: for batch in dataloader:
# batch # batch
self._call_callbacks('on_batch_begin', batch=batch) self._call_callbacks('on_batch_begin', **train_kwargs)
loss = self.train_config.strategy(batch) loss = self.train_config.strategy(batch)
self.checkpoint.loss_list.append(loss.item())
loss.backward() loss.backward()
self._call_callbacks('on_batch_end', batch=batch, loss=loss.item(), current_iter=current_iter) train_kwargs["loss"] = loss.item()
self._call_callbacks('on_batch_end', **train_kwargs)
if current_iter % self.train_config.accumulation_steps == 0: if train_kwargs["current_iter"] % self.train_config.accumulation_steps == 0:
# step # step
self._call_callbacks('on_step_begin', current_iter=current_iter) self._call_callbacks('on_step_begin', **train_kwargs)
self.train_config.optimizer.step() self.train_config.optimizer.step()
self.train_config.optimizer.zero_grad() self.train_config.optimizer.zero_grad()
self._call_callbacks('on_step_end', current_iter=current_iter) self._call_callbacks('on_step_end', **train_kwargs)
current_iter += 1 train_kwargs["current_iter"] += 1
self._call_callbacks('on_epoch_end', epoch=epoch, loss_list=self.checkpoint.loss_list) self._call_callbacks('on_epoch_end', **train_kwargs)
except Exception as e: except Exception as e:
raise e raise e
finally: finally:
self._call_callbacks('on_train_end', checkpoint=self.checkpoint) self._call_callbacks('on_train_end', **train_kwargs)
return self.checkpoint return checkpoint

View File

@ -1,9 +1,12 @@
import os import os
import torch.optim as optim
from tqdm import tqdm from tqdm import tqdm
from khaosz.core.parameter import Checkpoint
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR from torch.optim.lr_scheduler import LambdaLR
from typing import Optional, cast, TYPE_CHECKING from typing import Optional, cast, TYPE_CHECKING
from khaosz.core.parameter import Checkpoint
from khaosz.trainer.data_util import RandomSampler
from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
if TYPE_CHECKING: if TYPE_CHECKING:
@ -76,7 +79,7 @@ class ProgressBarCallback(TrainerCallback):
def on_epoch_begin(self, trainer: 'Trainer', **kwargs): def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
epoch = kwargs.get('epoch') epoch = kwargs.get('epoch')
dataloader = trainer._create_dataloader() dataloader = kwargs.get('dataloader')
self.progress_bar = tqdm( self.progress_bar = tqdm(
dataloader, dataloader,
desc=f"Epoch {epoch+1}/{trainer.train_config.n_epoch}", desc=f"Epoch {epoch+1}/{trainer.train_config.n_epoch}",
@ -106,28 +109,36 @@ class CheckpointCallback(TrainerCallback):
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
@staticmethod @staticmethod
def _save_checkpoint(trainer: 'Trainer'): def _save_checkpoint(trainer: 'Trainer', **kwargs):
current_iter = len(trainer.checkpoint.loss_list) current_iter = kwargs.get('current_iter')
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}") random_sampler = cast(RandomSampler, kwargs.get('sampler'))
trainer.checkpoint.optim_state = trainer.train_config.optimizer.state_dict() optimizer = cast(optim.Optimizer, kwargs.get('optimizer'))
trainer.checkpoint.save(save_path)
def on_train_begin(self, trainer: 'Trainer', **kwargs):
_ = trainer
checkpoint = cast(Checkpoint, kwargs.get('checkpoint')) checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
self.last_ckpt_iter = len(checkpoint.loss_list)
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}")
checkpoint.sampler_state = random_sampler.state_dict()
checkpoint.optim_state = optimizer.state_dict()
checkpoint.sampler_state['epoch'] = kwargs.get('epoch', 0)
checkpoint.sampler_state['current_iter'] = kwargs.get('current_iter', 0)
checkpoint.save(save_path)
def on_batch_end(self, trainer: 'Trainer', **kwargs): def on_batch_end(self, trainer: 'Trainer', **kwargs):
current_iter = kwargs.get('current_iter') current_iter = kwargs.get('current_iter')
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
loss = kwargs.get('loss')
checkpoint.loss_list.append(loss)
if current_iter - self.last_ckpt_iter >= self.checkpoint_interval: if current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
CheckpointCallback._save_checkpoint(trainer) CheckpointCallback._save_checkpoint(trainer, **kwargs)
self.last_ckpt_iter = current_iter self.last_ckpt_iter = current_iter
def on_train_end(self, trainer: 'Trainer', **kwargs): def on_train_end(self, trainer: 'Trainer', **kwargs):
checkpoint = cast(Checkpoint, kwargs.get('checkpoint')) current_iter = kwargs.get('current_iter')
current_iter = len(checkpoint.loss_list)
if current_iter != self.last_ckpt_iter: if current_iter != self.last_ckpt_iter:
CheckpointCallback._save_checkpoint(trainer) CheckpointCallback._save_checkpoint(trainer, **kwargs)
self.last_ckpt_iter = current_iter
class GradientClippingCallback(TrainerCallback): class GradientClippingCallback(TrainerCallback):
@ -137,7 +148,7 @@ class GradientClippingCallback(TrainerCallback):
def on_step_begin(self, trainer: 'Trainer', **kwargs): def on_step_begin(self, trainer: 'Trainer', **kwargs):
_ = kwargs _ = kwargs
clip_grad_norm_( clip_grad_norm_(
trainer.checkpoint.model.parameters(), trainer.parameter.model.parameters(),
trainer.train_config.max_grad_norm trainer.train_config.max_grad_norm
) )
@ -152,8 +163,7 @@ class SchedulerCallback(TrainerCallback):
self.current_iter = 0 self.current_iter = 0
def on_train_begin(self, trainer: 'Trainer', **kwargs): def on_train_begin(self, trainer: 'Trainer', **kwargs):
checkpoint = cast(Checkpoint, kwargs.get('checkpoint')) self.current_iter = kwargs.get('current_iter')
self.current_iter = len(checkpoint.loss_list)
for group in trainer.train_config.optimizer.param_groups: for group in trainer.train_config.optimizer.param_groups:
if "initial_lr" not in group: if "initial_lr" not in group: