feat(khaosz/trainer): 优化训练器回调机制与数据采样逻辑

This commit is contained in:
ViperEkura 2025-09-30 16:33:18 +08:00
parent e0e9942e4a
commit 315ce1990a
5 changed files with 118 additions and 85 deletions

View File

@@ -8,6 +8,7 @@ from khaosz.trainer.strategy (
SchedulerFactory
)
from khaosz.trainer.trainer_callback import (
TrainerCallback,
ProgressBarCallback,
CheckpointCallback,
TrainerCallback,
@@ -25,6 +26,7 @@ __all__ = [
"SchedulerFactory",
# callback
"TrainerCallback",
"ProgressBarCallback",
"CheckpointCallback",
"TrainerCallback",

View File

@@ -53,7 +53,8 @@ def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
attention_mask = seq_mask & causal_mask
return attention_mask
# fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
return attention_mask.unsqueeze(0)
class BaseSegmentFetcher:
@@ -118,9 +119,12 @@ class BaseDataset(Dataset, ABC):
self.device = device
def save(self, save_path: str):
keys = list(self.segments.keys())
if not keys:
return
first_item = self.segments[keys[0]]
segment_size = len(first_item)
keys = list(self.segments.keys())
for i in range(segment_size):
formated_segment = {key: self.segments[key][i] for key in keys}
@@ -272,7 +276,7 @@ class RandomSampler(Sampler[int]):
self.data_source = data_source
self.seed = seed
self.epoch = 0
self.current_index = 0
self.current_iter = 0
self._indices = None
if generator is None:
@@ -291,20 +295,21 @@ class RandomSampler(Sampler[int]):
if self._indices is None:
self._generate_indices()
for i in range(self.current_index, n):
for i in range(self.current_iter, n):
yield self._indices[i]
self.current_iter += 1
self.epoch += 1
self.current_index = 0
self._indices = None
def __len__(self):
return len(self.data_source) - self.current_index
n = len(self.data_source)
return n - self.current_iter % n
def state_dict(self):
return {
'epoch': self.epoch,
'current_index': self.current_index,
'current_iter': self.current_iter,
'seed': self.seed,
'generator_state': self.generator.get_state() if self.generator else None,
'indices': self._indices
@@ -312,7 +317,7 @@ class RandomSampler(Sampler[int]):
def load_state_dict(self, state_dict):
self.epoch = state_dict['epoch']
self.current_index = state_dict['current_index']
self.current_iter = state_dict['current_iter']
self.seed = state_dict['seed']
if self.generator and state_dict['generator_state'] is not None:

View File

@@ -159,12 +159,7 @@ class StrategyFactory:
def load(model, train_type, **kwargs):
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
"seq": lambda: SeqStrategy(model),
"sft": lambda: SftStrategy(
model,
kwargs.get("bos_token_id"),
kwargs.get("eos_token_id"),
kwargs.get("multi_turn")
),
"sft": lambda: SftStrategy(model),
"dpo": lambda: DpoStrategy(
model,
kwargs.get("pad_token_id"),

View File

@@ -1,6 +1,4 @@
import torch
import itertools
from typing import Optional, List
from typing import Optional, List, cast
from torch.utils.data import DataLoader
from khaosz.core import ModelParameter, Checkpoint
@@ -23,11 +21,7 @@ class Trainer:
schedule_config: ScheduleConfig,
callbacks: Optional[List[TrainerCallback]] = None
):
self.checkpoint = Checkpoint(
model=parameter.model,
tokenizer=parameter.tokenizer,
config=parameter.config,
)
self.parameter = parameter
self.train_config = train_config
self.schedule_config = schedule_config
self.callbacks = callbacks or self._get_default_callbacks()
@@ -40,24 +34,50 @@ class Trainer:
SchedulerCallback(self.schedule_config),
]
def _create_dataloader(self, start_index: int = 0) -> DataLoader:
def _set_train_kwargs(self, kwargs: dict):
used_epochs = 0
used_iters = 0
seed = self.train_config.random_seed
generator = torch.Generator().manual_seed(seed)
sampler = RandomSampler(
self.train_config.dataset,
generator=generator,
seed=seed
sampler = RandomSampler(data_source=self.train_config.dataset, seed=seed)
optim = self.train_config.optimizer
checkpoint = cast(Checkpoint, kwargs.get('checkpoint', None))
if checkpoint is None:
checkpoint = Checkpoint(
model=self.parameter.model,
tokenizer=self.parameter.tokenizer,
config=self.parameter.config,
sampler_state=None,
optim_state=None,
loss_list=[]
)
sampler_state = checkpoint.sampler_state
optim_state = checkpoint.optim_state
if sampler_state:
sampler.load_state_dict(sampler_state)
used_epochs = sampler_state.get('epoch', 0)
used_iters = sampler_state.get('iter', 0)
if optim_state:
optim.load_state_dict(optim_state)
checkpoint.optim_state = optim.state_dict()
checkpoint.sampler_state = sampler.state_dict()
dataloader = DataLoader(
self.train_config.dataset,
batch_size=self.train_config.batch_size,
sampler=sampler
)
if start_index > 0:
dataloader = itertools.islice(dataloader, start_index, None)
return dataloader
kwargs["dataloader"] = dataloader
kwargs["optimizer"] = self.train_config.optimizer
kwargs["epoch"] = used_epochs
kwargs["current_iter"] = used_iters
kwargs["sampler"] = sampler
kwargs["checkpoint"] = checkpoint
def _call_callbacks(self, method_name: str, **kwargs):
for callback in self.callbacks:
@@ -67,53 +87,54 @@ class Trainer:
def train(
self,
train_checkpoint: Optional[Checkpoint] = None
checkpoint: Optional[Checkpoint] = None
) -> Checkpoint:
if train_checkpoint:
self.checkpoint = train_checkpoint
self.train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
else:
self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
current_iter = len(self.checkpoint.loss_list)
total_steps_per_epoch = len(self.train_config.dataset) // self.train_config.batch_size
total_reamining_steps = total_steps_per_epoch * self.train_config.n_epoch - current_iter
current_epochs = total_reamining_steps // total_steps_per_epoch
current_steps = total_reamining_steps % total_steps_per_epoch
# train
self._call_callbacks('on_train_begin', checkpoint=self.checkpoint)
self.checkpoint.model.train()
train_kwargs = {
'checkpoint': checkpoint,
'dataloader': None,
'optimizer': None,
'sampler': None,
'epoch': 0,
'current_iter': 0,
'loss': 0.0,
}
self._set_train_kwargs(train_kwargs)
self._call_callbacks('on_train_begin', **train_kwargs)
dataloader = train_kwargs['dataloader']
checkpoint = train_kwargs['checkpoint']
start_epoch = train_kwargs['epoch']
try:
for epoch in range(current_epochs):
self.parameter.model.train()
for epoch in range(start_epoch, self.train_config.n_epoch):
# epoch
self._call_callbacks('on_epoch_begin', epoch=epoch)
dataloader = self._create_dataloader(start_index=current_steps)
train_kwargs["epoch"] = epoch
self._call_callbacks('on_epoch_begin', **train_kwargs)
for batch in dataloader:
# batch
self._call_callbacks('on_batch_begin', batch=batch)
self._call_callbacks('on_batch_begin', **train_kwargs)
loss = self.train_config.strategy(batch)
self.checkpoint.loss_list.append(loss.item())
loss.backward()
self._call_callbacks('on_batch_end', batch=batch, loss=loss.item(), current_iter=current_iter)
train_kwargs["loss"] = loss.item()
self._call_callbacks('on_batch_end', **train_kwargs)
if current_iter % self.train_config.accumulation_steps == 0:
if train_kwargs["current_iter"] % self.train_config.accumulation_steps == 0:
# step
self._call_callbacks('on_step_begin', current_iter=current_iter)
self._call_callbacks('on_step_begin', **train_kwargs)
self.train_config.optimizer.step()
self.train_config.optimizer.zero_grad()
self._call_callbacks('on_step_end', current_iter=current_iter)
self._call_callbacks('on_step_end', **train_kwargs)
current_iter += 1
train_kwargs["current_iter"] += 1
self._call_callbacks('on_epoch_end', epoch=epoch, loss_list=self.checkpoint.loss_list)
self._call_callbacks('on_epoch_end', **train_kwargs)
except Exception as e:
raise e
finally:
self._call_callbacks('on_train_end', checkpoint=self.checkpoint)
return self.checkpoint
self._call_callbacks('on_train_end', **train_kwargs)
return checkpoint

View File

@@ -1,9 +1,12 @@
import os
import torch.optim as optim
from tqdm import tqdm
from khaosz.core.parameter import Checkpoint
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR
from typing import Optional, cast, TYPE_CHECKING
from khaosz.core.parameter import Checkpoint
from khaosz.trainer.data_util import RandomSampler
from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
if TYPE_CHECKING:
@@ -76,7 +79,7 @@ class ProgressBarCallback(TrainerCallback):
def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
epoch = kwargs.get('epoch')
dataloader = trainer._create_dataloader()
dataloader = kwargs.get('dataloader')
self.progress_bar = tqdm(
dataloader,
desc=f"Epoch {epoch+1}/{trainer.train_config.n_epoch}",
@@ -106,28 +109,36 @@ class CheckpointCallback(TrainerCallback):
self.last_ckpt_iter = 0
@staticmethod
def _save_checkpoint(trainer: 'Trainer'):
current_iter = len(trainer.checkpoint.loss_list)
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}")
trainer.checkpoint.optim_state = trainer.train_config.optimizer.state_dict()
trainer.checkpoint.save(save_path)
def on_train_begin(self, trainer: 'Trainer', **kwargs):
_ = trainer
def _save_checkpoint(trainer: 'Trainer', **kwargs):
current_iter = kwargs.get('current_iter')
random_sampler = cast(RandomSampler, kwargs.get('sampler'))
optimizer = cast(optim.Optimizer, kwargs.get('optimizer'))
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
self.last_ckpt_iter = len(checkpoint.loss_list)
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}")
checkpoint.sampler_state = random_sampler.state_dict()
checkpoint.optim_state = optimizer.state_dict()
checkpoint.sampler_state['epoch'] = kwargs.get('epoch', 0)
checkpoint.sampler_state['current_iter'] = kwargs.get('current_iter', 0)
checkpoint.save(save_path)
def on_batch_end(self, trainer: 'Trainer', **kwargs):
current_iter = kwargs.get('current_iter')
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
loss = kwargs.get('loss')
checkpoint.loss_list.append(loss)
if current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
CheckpointCallback._save_checkpoint(trainer)
CheckpointCallback._save_checkpoint(trainer, **kwargs)
self.last_ckpt_iter = current_iter
def on_train_end(self, trainer: 'Trainer', **kwargs):
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
current_iter = len(checkpoint.loss_list)
current_iter = kwargs.get('current_iter')
if current_iter != self.last_ckpt_iter:
CheckpointCallback._save_checkpoint(trainer)
CheckpointCallback._save_checkpoint(trainer, **kwargs)
self.last_ckpt_iter = current_iter
class GradientClippingCallback(TrainerCallback):
@@ -137,7 +148,7 @@ class GradientClippingCallback(TrainerCallback):
def on_step_begin(self, trainer: 'Trainer', **kwargs):
_ = kwargs
clip_grad_norm_(
trainer.checkpoint.model.parameters(),
trainer.parameter.model.parameters(),
trainer.train_config.max_grad_norm
)
@@ -152,8 +163,7 @@ class SchedulerCallback(TrainerCallback):
self.current_iter = 0
def on_train_begin(self, trainer: 'Trainer', **kwargs):
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
self.current_iter = len(checkpoint.loss_list)
self.current_iter = kwargs.get('current_iter')
for group in trainer.train_config.optimizer.param_groups:
if "initial_lr" not in group: