feat(khaosz/trainer): 优化训练器回调机制与数据采样逻辑
This commit is contained in:
parent
e0e9942e4a
commit
315ce1990a
|
|
@ -8,6 +8,7 @@ from khaosz.trainer.strategy import (
|
|||
SchedulerFactory
|
||||
)
|
||||
from khaosz.trainer.trainer_callback import (
|
||||
TrainerCallback,
|
||||
ProgressBarCallback,
|
||||
CheckpointCallback,
|
||||
TrainerCallback,
|
||||
|
|
@ -25,6 +26,7 @@ __all__ = [
|
|||
"SchedulerFactory",
|
||||
|
||||
# callback
|
||||
"TrainerCallback",
|
||||
"ProgressBarCallback",
|
||||
"CheckpointCallback",
|
||||
"TrainerCallback",
|
||||
|
|
|
|||
|
|
@ -53,7 +53,8 @@ def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool
|
|||
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
|
||||
attention_mask = seq_mask & causal_mask
|
||||
|
||||
return attention_mask
|
||||
# fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
|
||||
return attention_mask.unsqueeze(0)
|
||||
|
||||
|
||||
class BaseSegmentFetcher:
|
||||
|
|
@ -118,9 +119,12 @@ class BaseDataset(Dataset, ABC):
|
|||
self.device = device
|
||||
|
||||
def save(self, save_path: str):
|
||||
keys = list(self.segments.keys())
|
||||
if not keys:
|
||||
return
|
||||
|
||||
first_item = self.segments[keys[0]]
|
||||
segment_size = len(first_item)
|
||||
keys = list(self.segments.keys())
|
||||
|
||||
for i in range(segment_size):
|
||||
formated_segment = {key: self.segments[key][i] for key in keys}
|
||||
|
|
@ -272,7 +276,7 @@ class RandomSampler(Sampler[int]):
|
|||
self.data_source = data_source
|
||||
self.seed = seed
|
||||
self.epoch = 0
|
||||
self.current_index = 0
|
||||
self.current_iter = 0
|
||||
self._indices = None
|
||||
|
||||
if generator is None:
|
||||
|
|
@ -291,20 +295,21 @@ class RandomSampler(Sampler[int]):
|
|||
if self._indices is None:
|
||||
self._generate_indices()
|
||||
|
||||
for i in range(self.current_index, n):
|
||||
for i in range(self.current_iter, n):
|
||||
yield self._indices[i]
|
||||
self.current_iter += 1
|
||||
|
||||
self.epoch += 1
|
||||
self.current_index = 0
|
||||
self._indices = None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_source) - self.current_index
|
||||
n = len(self.data_source)
|
||||
return n - self.current_iter % n
|
||||
|
||||
def state_dict(self):
|
||||
return {
|
||||
'epoch': self.epoch,
|
||||
'current_index': self.current_index,
|
||||
'current_iter': self.current_iter,
|
||||
'seed': self.seed,
|
||||
'generator_state': self.generator.get_state() if self.generator else None,
|
||||
'indices': self._indices
|
||||
|
|
@ -312,7 +317,7 @@ class RandomSampler(Sampler[int]):
|
|||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.epoch = state_dict['epoch']
|
||||
self.current_index = state_dict['current_index']
|
||||
self.current_iter = state_dict['current_iter']
|
||||
self.seed = state_dict['seed']
|
||||
|
||||
if self.generator and state_dict['generator_state'] is not None:
|
||||
|
|
|
|||
|
|
@ -159,12 +159,7 @@ class StrategyFactory:
|
|||
def load(model, train_type, **kwargs):
|
||||
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
||||
"seq": lambda: SeqStrategy(model),
|
||||
"sft": lambda: SftStrategy(
|
||||
model,
|
||||
kwargs.get("bos_token_id"),
|
||||
kwargs.get("eos_token_id"),
|
||||
kwargs.get("multi_turn")
|
||||
),
|
||||
"sft": lambda: SftStrategy(model),
|
||||
"dpo": lambda: DpoStrategy(
|
||||
model,
|
||||
kwargs.get("pad_token_id"),
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
import torch
|
||||
import itertools
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, cast
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from khaosz.core import ModelParameter, Checkpoint
|
||||
|
|
@ -23,11 +21,7 @@ class Trainer:
|
|||
schedule_config: ScheduleConfig,
|
||||
callbacks: Optional[List[TrainerCallback]] = None
|
||||
):
|
||||
self.checkpoint = Checkpoint(
|
||||
model=parameter.model,
|
||||
tokenizer=parameter.tokenizer,
|
||||
config=parameter.config,
|
||||
)
|
||||
self.parameter = parameter
|
||||
self.train_config = train_config
|
||||
self.schedule_config = schedule_config
|
||||
self.callbacks = callbacks or self._get_default_callbacks()
|
||||
|
|
@ -40,24 +34,50 @@ class Trainer:
|
|||
SchedulerCallback(self.schedule_config),
|
||||
]
|
||||
|
||||
def _create_dataloader(self, start_index: int = 0) -> DataLoader:
|
||||
def _set_train_kwargs(self, kwargs: dict):
|
||||
used_epochs = 0
|
||||
used_iters = 0
|
||||
seed = self.train_config.random_seed
|
||||
generator = torch.Generator().manual_seed(seed)
|
||||
sampler = RandomSampler(
|
||||
self.train_config.dataset,
|
||||
generator=generator,
|
||||
seed=seed
|
||||
sampler = RandomSampler(data_source=self.train_config.dataset, seed=seed)
|
||||
optim = self.train_config.optimizer
|
||||
checkpoint = cast(Checkpoint, kwargs.get('checkpoint', None))
|
||||
|
||||
if checkpoint is None:
|
||||
checkpoint = Checkpoint(
|
||||
model=self.parameter.model,
|
||||
tokenizer=self.parameter.tokenizer,
|
||||
config=self.parameter.config,
|
||||
sampler_state=None,
|
||||
optim_state=None,
|
||||
loss_list=[]
|
||||
)
|
||||
|
||||
sampler_state = checkpoint.sampler_state
|
||||
optim_state = checkpoint.optim_state
|
||||
|
||||
if sampler_state:
|
||||
sampler.load_state_dict(sampler_state)
|
||||
used_epochs = sampler_state.get('epoch', 0)
|
||||
used_iters = sampler_state.get('iter', 0)
|
||||
|
||||
if optim_state:
|
||||
optim.load_state_dict(optim_state)
|
||||
|
||||
checkpoint.optim_state = optim.state_dict()
|
||||
checkpoint.sampler_state = sampler.state_dict()
|
||||
|
||||
dataloader = DataLoader(
|
||||
self.train_config.dataset,
|
||||
batch_size=self.train_config.batch_size,
|
||||
sampler=sampler
|
||||
)
|
||||
|
||||
if start_index > 0:
|
||||
dataloader = itertools.islice(dataloader, start_index, None)
|
||||
|
||||
return dataloader
|
||||
kwargs["dataloader"] = dataloader
|
||||
kwargs["optimizer"] = self.train_config.optimizer
|
||||
kwargs["epoch"] = used_epochs
|
||||
kwargs["current_iter"] = used_iters
|
||||
kwargs["sampler"] = sampler
|
||||
kwargs["checkpoint"] = checkpoint
|
||||
|
||||
def _call_callbacks(self, method_name: str, **kwargs):
|
||||
for callback in self.callbacks:
|
||||
|
|
@ -67,53 +87,54 @@ class Trainer:
|
|||
|
||||
def train(
|
||||
self,
|
||||
train_checkpoint: Optional[Checkpoint] = None
|
||||
checkpoint: Optional[Checkpoint] = None
|
||||
) -> Checkpoint:
|
||||
|
||||
if train_checkpoint:
|
||||
self.checkpoint = train_checkpoint
|
||||
self.train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
|
||||
else:
|
||||
self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
|
||||
|
||||
current_iter = len(self.checkpoint.loss_list)
|
||||
total_steps_per_epoch = len(self.train_config.dataset) // self.train_config.batch_size
|
||||
total_reamining_steps = total_steps_per_epoch * self.train_config.n_epoch - current_iter
|
||||
|
||||
current_epochs = total_reamining_steps // total_steps_per_epoch
|
||||
current_steps = total_reamining_steps % total_steps_per_epoch
|
||||
|
||||
# train
|
||||
self._call_callbacks('on_train_begin', checkpoint=self.checkpoint)
|
||||
self.checkpoint.model.train()
|
||||
train_kwargs = {
|
||||
'checkpoint': checkpoint,
|
||||
'dataloader': None,
|
||||
'optimizer': None,
|
||||
'sampler': None,
|
||||
'epoch': 0,
|
||||
'current_iter': 0,
|
||||
'loss': 0.0,
|
||||
}
|
||||
|
||||
self._set_train_kwargs(train_kwargs)
|
||||
self._call_callbacks('on_train_begin', **train_kwargs)
|
||||
|
||||
dataloader = train_kwargs['dataloader']
|
||||
checkpoint = train_kwargs['checkpoint']
|
||||
start_epoch = train_kwargs['epoch']
|
||||
|
||||
try:
|
||||
for epoch in range(current_epochs):
|
||||
self.parameter.model.train()
|
||||
for epoch in range(start_epoch, self.train_config.n_epoch):
|
||||
# epoch
|
||||
self._call_callbacks('on_epoch_begin', epoch=epoch)
|
||||
dataloader = self._create_dataloader(start_index=current_steps)
|
||||
train_kwargs["epoch"] = epoch
|
||||
self._call_callbacks('on_epoch_begin', **train_kwargs)
|
||||
for batch in dataloader:
|
||||
# batch
|
||||
self._call_callbacks('on_batch_begin', batch=batch)
|
||||
self._call_callbacks('on_batch_begin', **train_kwargs)
|
||||
loss = self.train_config.strategy(batch)
|
||||
self.checkpoint.loss_list.append(loss.item())
|
||||
loss.backward()
|
||||
self._call_callbacks('on_batch_end', batch=batch, loss=loss.item(), current_iter=current_iter)
|
||||
train_kwargs["loss"] = loss.item()
|
||||
self._call_callbacks('on_batch_end', **train_kwargs)
|
||||
|
||||
if current_iter % self.train_config.accumulation_steps == 0:
|
||||
if train_kwargs["current_iter"] % self.train_config.accumulation_steps == 0:
|
||||
# step
|
||||
self._call_callbacks('on_step_begin', current_iter=current_iter)
|
||||
self._call_callbacks('on_step_begin', **train_kwargs)
|
||||
self.train_config.optimizer.step()
|
||||
self.train_config.optimizer.zero_grad()
|
||||
self._call_callbacks('on_step_end', current_iter=current_iter)
|
||||
self._call_callbacks('on_step_end', **train_kwargs)
|
||||
|
||||
current_iter += 1
|
||||
train_kwargs["current_iter"] += 1
|
||||
|
||||
self._call_callbacks('on_epoch_end', epoch=epoch, loss_list=self.checkpoint.loss_list)
|
||||
self._call_callbacks('on_epoch_end', **train_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
finally:
|
||||
self._call_callbacks('on_train_end', checkpoint=self.checkpoint)
|
||||
return self.checkpoint
|
||||
self._call_callbacks('on_train_end', **train_kwargs)
|
||||
return checkpoint
|
||||
|
|
@ -1,9 +1,12 @@
|
|||
import os
|
||||
import torch.optim as optim
|
||||
|
||||
from tqdm import tqdm
|
||||
from khaosz.core.parameter import Checkpoint
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from typing import Optional, cast, TYPE_CHECKING
|
||||
from khaosz.core.parameter import Checkpoint
|
||||
from khaosz.trainer.data_util import RandomSampler
|
||||
from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -76,7 +79,7 @@ class ProgressBarCallback(TrainerCallback):
|
|||
|
||||
def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
|
||||
epoch = kwargs.get('epoch')
|
||||
dataloader = trainer._create_dataloader()
|
||||
dataloader = kwargs.get('dataloader')
|
||||
self.progress_bar = tqdm(
|
||||
dataloader,
|
||||
desc=f"Epoch {epoch+1}/{trainer.train_config.n_epoch}",
|
||||
|
|
@ -106,28 +109,36 @@ class CheckpointCallback(TrainerCallback):
|
|||
self.last_ckpt_iter = 0
|
||||
|
||||
@staticmethod
|
||||
def _save_checkpoint(trainer: 'Trainer'):
|
||||
current_iter = len(trainer.checkpoint.loss_list)
|
||||
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}")
|
||||
trainer.checkpoint.optim_state = trainer.train_config.optimizer.state_dict()
|
||||
trainer.checkpoint.save(save_path)
|
||||
|
||||
def on_train_begin(self, trainer: 'Trainer', **kwargs):
|
||||
_ = trainer
|
||||
def _save_checkpoint(trainer: 'Trainer', **kwargs):
|
||||
current_iter = kwargs.get('current_iter')
|
||||
random_sampler = cast(RandomSampler, kwargs.get('sampler'))
|
||||
optimizer = cast(optim.Optimizer, kwargs.get('optimizer'))
|
||||
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
|
||||
self.last_ckpt_iter = len(checkpoint.loss_list)
|
||||
|
||||
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}")
|
||||
checkpoint.sampler_state = random_sampler.state_dict()
|
||||
checkpoint.optim_state = optimizer.state_dict()
|
||||
|
||||
checkpoint.sampler_state['epoch'] = kwargs.get('epoch', 0)
|
||||
checkpoint.sampler_state['current_iter'] = kwargs.get('current_iter', 0)
|
||||
|
||||
checkpoint.save(save_path)
|
||||
|
||||
def on_batch_end(self, trainer: 'Trainer', **kwargs):
|
||||
current_iter = kwargs.get('current_iter')
|
||||
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
|
||||
loss = kwargs.get('loss')
|
||||
checkpoint.loss_list.append(loss)
|
||||
|
||||
if current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
|
||||
CheckpointCallback._save_checkpoint(trainer)
|
||||
CheckpointCallback._save_checkpoint(trainer, **kwargs)
|
||||
self.last_ckpt_iter = current_iter
|
||||
|
||||
def on_train_end(self, trainer: 'Trainer', **kwargs):
|
||||
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
|
||||
current_iter = len(checkpoint.loss_list)
|
||||
current_iter = kwargs.get('current_iter')
|
||||
if current_iter != self.last_ckpt_iter:
|
||||
CheckpointCallback._save_checkpoint(trainer)
|
||||
CheckpointCallback._save_checkpoint(trainer, **kwargs)
|
||||
self.last_ckpt_iter = current_iter
|
||||
|
||||
|
||||
class GradientClippingCallback(TrainerCallback):
|
||||
|
|
@ -137,7 +148,7 @@ class GradientClippingCallback(TrainerCallback):
|
|||
def on_step_begin(self, trainer: 'Trainer', **kwargs):
|
||||
_ = kwargs
|
||||
clip_grad_norm_(
|
||||
trainer.checkpoint.model.parameters(),
|
||||
trainer.parameter.model.parameters(),
|
||||
trainer.train_config.max_grad_norm
|
||||
)
|
||||
|
||||
|
|
@ -152,8 +163,7 @@ class SchedulerCallback(TrainerCallback):
|
|||
self.current_iter = 0
|
||||
|
||||
def on_train_begin(self, trainer: 'Trainer', **kwargs):
|
||||
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
|
||||
self.current_iter = len(checkpoint.loss_list)
|
||||
self.current_iter = kwargs.get('current_iter')
|
||||
|
||||
for group in trainer.train_config.optimizer.param_groups:
|
||||
if "initial_lr" not in group:
|
||||
|
|
|
|||
Loading…
Reference in New Issue