feat(khaosz/trainer): 优化训练器回调机制与数据采样逻辑
This commit is contained in:
parent
e0e9942e4a
commit
315ce1990a
|
|
@ -8,6 +8,7 @@ from khaosz.trainer.strategy import (
|
||||||
SchedulerFactory
|
SchedulerFactory
|
||||||
)
|
)
|
||||||
from khaosz.trainer.trainer_callback import (
|
from khaosz.trainer.trainer_callback import (
|
||||||
|
TrainerCallback,
|
||||||
ProgressBarCallback,
|
ProgressBarCallback,
|
||||||
CheckpointCallback,
|
CheckpointCallback,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
|
|
@ -25,6 +26,7 @@ __all__ = [
|
||||||
"SchedulerFactory",
|
"SchedulerFactory",
|
||||||
|
|
||||||
# callback
|
# callback
|
||||||
|
"TrainerCallback",
|
||||||
"ProgressBarCallback",
|
"ProgressBarCallback",
|
||||||
"CheckpointCallback",
|
"CheckpointCallback",
|
||||||
"TrainerCallback",
|
"TrainerCallback",
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,8 @@ def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool
|
||||||
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
|
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
|
||||||
attention_mask = seq_mask & causal_mask
|
attention_mask = seq_mask & causal_mask
|
||||||
|
|
||||||
return attention_mask
|
# fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
|
||||||
|
return attention_mask.unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
class BaseSegmentFetcher:
|
class BaseSegmentFetcher:
|
||||||
|
|
@ -118,9 +119,12 @@ class BaseDataset(Dataset, ABC):
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def save(self, save_path: str):
|
def save(self, save_path: str):
|
||||||
|
keys = list(self.segments.keys())
|
||||||
|
if not keys:
|
||||||
|
return
|
||||||
|
|
||||||
first_item = self.segments[keys[0]]
|
first_item = self.segments[keys[0]]
|
||||||
segment_size = len(first_item)
|
segment_size = len(first_item)
|
||||||
keys = list(self.segments.keys())
|
|
||||||
|
|
||||||
for i in range(segment_size):
|
for i in range(segment_size):
|
||||||
formated_segment = {key: self.segments[key][i] for key in keys}
|
formated_segment = {key: self.segments[key][i] for key in keys}
|
||||||
|
|
@ -272,7 +276,7 @@ class RandomSampler(Sampler[int]):
|
||||||
self.data_source = data_source
|
self.data_source = data_source
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.epoch = 0
|
self.epoch = 0
|
||||||
self.current_index = 0
|
self.current_iter = 0
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
if generator is None:
|
if generator is None:
|
||||||
|
|
@ -291,20 +295,21 @@ class RandomSampler(Sampler[int]):
|
||||||
if self._indices is None:
|
if self._indices is None:
|
||||||
self._generate_indices()
|
self._generate_indices()
|
||||||
|
|
||||||
for i in range(self.current_index, n):
|
for i in range(self.current_iter, n):
|
||||||
yield self._indices[i]
|
yield self._indices[i]
|
||||||
|
self.current_iter += 1
|
||||||
|
|
||||||
self.epoch += 1
|
self.epoch += 1
|
||||||
self.current_index = 0
|
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data_source) - self.current_index
|
n = len(self.data_source)
|
||||||
|
return n - self.current_iter % n
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return {
|
return {
|
||||||
'epoch': self.epoch,
|
'epoch': self.epoch,
|
||||||
'current_index': self.current_index,
|
'current_iter': self.current_iter,
|
||||||
'seed': self.seed,
|
'seed': self.seed,
|
||||||
'generator_state': self.generator.get_state() if self.generator else None,
|
'generator_state': self.generator.get_state() if self.generator else None,
|
||||||
'indices': self._indices
|
'indices': self._indices
|
||||||
|
|
@ -312,7 +317,7 @@ class RandomSampler(Sampler[int]):
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
def load_state_dict(self, state_dict):
|
||||||
self.epoch = state_dict['epoch']
|
self.epoch = state_dict['epoch']
|
||||||
self.current_index = state_dict['current_index']
|
self.current_iter = state_dict['current_iter']
|
||||||
self.seed = state_dict['seed']
|
self.seed = state_dict['seed']
|
||||||
|
|
||||||
if self.generator and state_dict['generator_state'] is not None:
|
if self.generator and state_dict['generator_state'] is not None:
|
||||||
|
|
|
||||||
|
|
@ -159,12 +159,7 @@ class StrategyFactory:
|
||||||
def load(model, train_type, **kwargs):
|
def load(model, train_type, **kwargs):
|
||||||
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
||||||
"seq": lambda: SeqStrategy(model),
|
"seq": lambda: SeqStrategy(model),
|
||||||
"sft": lambda: SftStrategy(
|
"sft": lambda: SftStrategy(model),
|
||||||
model,
|
|
||||||
kwargs.get("bos_token_id"),
|
|
||||||
kwargs.get("eos_token_id"),
|
|
||||||
kwargs.get("multi_turn")
|
|
||||||
),
|
|
||||||
"dpo": lambda: DpoStrategy(
|
"dpo": lambda: DpoStrategy(
|
||||||
model,
|
model,
|
||||||
kwargs.get("pad_token_id"),
|
kwargs.get("pad_token_id"),
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,4 @@
|
||||||
import torch
|
from typing import Optional, List, cast
|
||||||
import itertools
|
|
||||||
from typing import Optional, List
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from khaosz.core import ModelParameter, Checkpoint
|
from khaosz.core import ModelParameter, Checkpoint
|
||||||
|
|
@ -23,11 +21,7 @@ class Trainer:
|
||||||
schedule_config: ScheduleConfig,
|
schedule_config: ScheduleConfig,
|
||||||
callbacks: Optional[List[TrainerCallback]] = None
|
callbacks: Optional[List[TrainerCallback]] = None
|
||||||
):
|
):
|
||||||
self.checkpoint = Checkpoint(
|
self.parameter = parameter
|
||||||
model=parameter.model,
|
|
||||||
tokenizer=parameter.tokenizer,
|
|
||||||
config=parameter.config,
|
|
||||||
)
|
|
||||||
self.train_config = train_config
|
self.train_config = train_config
|
||||||
self.schedule_config = schedule_config
|
self.schedule_config = schedule_config
|
||||||
self.callbacks = callbacks or self._get_default_callbacks()
|
self.callbacks = callbacks or self._get_default_callbacks()
|
||||||
|
|
@ -40,24 +34,50 @@ class Trainer:
|
||||||
SchedulerCallback(self.schedule_config),
|
SchedulerCallback(self.schedule_config),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _create_dataloader(self, start_index: int = 0) -> DataLoader:
|
def _set_train_kwargs(self, kwargs: dict):
|
||||||
|
used_epochs = 0
|
||||||
|
used_iters = 0
|
||||||
seed = self.train_config.random_seed
|
seed = self.train_config.random_seed
|
||||||
generator = torch.Generator().manual_seed(seed)
|
sampler = RandomSampler(data_source=self.train_config.dataset, seed=seed)
|
||||||
sampler = RandomSampler(
|
optim = self.train_config.optimizer
|
||||||
self.train_config.dataset,
|
checkpoint = cast(Checkpoint, kwargs.get('checkpoint', None))
|
||||||
generator=generator,
|
|
||||||
seed=seed
|
if checkpoint is None:
|
||||||
|
checkpoint = Checkpoint(
|
||||||
|
model=self.parameter.model,
|
||||||
|
tokenizer=self.parameter.tokenizer,
|
||||||
|
config=self.parameter.config,
|
||||||
|
sampler_state=None,
|
||||||
|
optim_state=None,
|
||||||
|
loss_list=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sampler_state = checkpoint.sampler_state
|
||||||
|
optim_state = checkpoint.optim_state
|
||||||
|
|
||||||
|
if sampler_state:
|
||||||
|
sampler.load_state_dict(sampler_state)
|
||||||
|
used_epochs = sampler_state.get('epoch', 0)
|
||||||
|
used_iters = sampler_state.get('iter', 0)
|
||||||
|
|
||||||
|
if optim_state:
|
||||||
|
optim.load_state_dict(optim_state)
|
||||||
|
|
||||||
|
checkpoint.optim_state = optim.state_dict()
|
||||||
|
checkpoint.sampler_state = sampler.state_dict()
|
||||||
|
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
self.train_config.dataset,
|
self.train_config.dataset,
|
||||||
batch_size=self.train_config.batch_size,
|
batch_size=self.train_config.batch_size,
|
||||||
sampler=sampler
|
sampler=sampler
|
||||||
)
|
)
|
||||||
|
|
||||||
if start_index > 0:
|
kwargs["dataloader"] = dataloader
|
||||||
dataloader = itertools.islice(dataloader, start_index, None)
|
kwargs["optimizer"] = self.train_config.optimizer
|
||||||
|
kwargs["epoch"] = used_epochs
|
||||||
return dataloader
|
kwargs["current_iter"] = used_iters
|
||||||
|
kwargs["sampler"] = sampler
|
||||||
|
kwargs["checkpoint"] = checkpoint
|
||||||
|
|
||||||
def _call_callbacks(self, method_name: str, **kwargs):
|
def _call_callbacks(self, method_name: str, **kwargs):
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
|
|
@ -67,53 +87,54 @@ class Trainer:
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self,
|
self,
|
||||||
train_checkpoint: Optional[Checkpoint] = None
|
checkpoint: Optional[Checkpoint] = None
|
||||||
) -> Checkpoint:
|
) -> Checkpoint:
|
||||||
|
|
||||||
if train_checkpoint:
|
|
||||||
self.checkpoint = train_checkpoint
|
|
||||||
self.train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
|
|
||||||
else:
|
|
||||||
self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
|
|
||||||
|
|
||||||
current_iter = len(self.checkpoint.loss_list)
|
|
||||||
total_steps_per_epoch = len(self.train_config.dataset) // self.train_config.batch_size
|
|
||||||
total_reamining_steps = total_steps_per_epoch * self.train_config.n_epoch - current_iter
|
|
||||||
|
|
||||||
current_epochs = total_reamining_steps // total_steps_per_epoch
|
|
||||||
current_steps = total_reamining_steps % total_steps_per_epoch
|
|
||||||
|
|
||||||
# train
|
# train
|
||||||
self._call_callbacks('on_train_begin', checkpoint=self.checkpoint)
|
train_kwargs = {
|
||||||
self.checkpoint.model.train()
|
'checkpoint': checkpoint,
|
||||||
|
'dataloader': None,
|
||||||
|
'optimizer': None,
|
||||||
|
'sampler': None,
|
||||||
|
'epoch': 0,
|
||||||
|
'current_iter': 0,
|
||||||
|
'loss': 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
self._set_train_kwargs(train_kwargs)
|
||||||
|
self._call_callbacks('on_train_begin', **train_kwargs)
|
||||||
|
|
||||||
|
dataloader = train_kwargs['dataloader']
|
||||||
|
checkpoint = train_kwargs['checkpoint']
|
||||||
|
start_epoch = train_kwargs['epoch']
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for epoch in range(current_epochs):
|
self.parameter.model.train()
|
||||||
|
for epoch in range(start_epoch, self.train_config.n_epoch):
|
||||||
# epoch
|
# epoch
|
||||||
self._call_callbacks('on_epoch_begin', epoch=epoch)
|
train_kwargs["epoch"] = epoch
|
||||||
dataloader = self._create_dataloader(start_index=current_steps)
|
self._call_callbacks('on_epoch_begin', **train_kwargs)
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
# batch
|
# batch
|
||||||
self._call_callbacks('on_batch_begin', batch=batch)
|
self._call_callbacks('on_batch_begin', **train_kwargs)
|
||||||
loss = self.train_config.strategy(batch)
|
loss = self.train_config.strategy(batch)
|
||||||
self.checkpoint.loss_list.append(loss.item())
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
self._call_callbacks('on_batch_end', batch=batch, loss=loss.item(), current_iter=current_iter)
|
train_kwargs["loss"] = loss.item()
|
||||||
|
self._call_callbacks('on_batch_end', **train_kwargs)
|
||||||
|
|
||||||
if current_iter % self.train_config.accumulation_steps == 0:
|
if train_kwargs["current_iter"] % self.train_config.accumulation_steps == 0:
|
||||||
# step
|
# step
|
||||||
self._call_callbacks('on_step_begin', current_iter=current_iter)
|
self._call_callbacks('on_step_begin', **train_kwargs)
|
||||||
self.train_config.optimizer.step()
|
self.train_config.optimizer.step()
|
||||||
self.train_config.optimizer.zero_grad()
|
self.train_config.optimizer.zero_grad()
|
||||||
self._call_callbacks('on_step_end', current_iter=current_iter)
|
self._call_callbacks('on_step_end', **train_kwargs)
|
||||||
|
|
||||||
current_iter += 1
|
train_kwargs["current_iter"] += 1
|
||||||
|
|
||||||
self._call_callbacks('on_epoch_end', epoch=epoch, loss_list=self.checkpoint.loss_list)
|
self._call_callbacks('on_epoch_end', **train_kwargs)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
self._call_callbacks('on_train_end', checkpoint=self.checkpoint)
|
self._call_callbacks('on_train_end', **train_kwargs)
|
||||||
return self.checkpoint
|
return checkpoint
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
import os
|
import os
|
||||||
|
import torch.optim as optim
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from khaosz.core.parameter import Checkpoint
|
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from torch.optim.lr_scheduler import LambdaLR
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
from typing import Optional, cast, TYPE_CHECKING
|
from typing import Optional, cast, TYPE_CHECKING
|
||||||
|
from khaosz.core.parameter import Checkpoint
|
||||||
|
from khaosz.trainer.data_util import RandomSampler
|
||||||
from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
|
from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
@ -76,7 +79,7 @@ class ProgressBarCallback(TrainerCallback):
|
||||||
|
|
||||||
def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
|
def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
|
||||||
epoch = kwargs.get('epoch')
|
epoch = kwargs.get('epoch')
|
||||||
dataloader = trainer._create_dataloader()
|
dataloader = kwargs.get('dataloader')
|
||||||
self.progress_bar = tqdm(
|
self.progress_bar = tqdm(
|
||||||
dataloader,
|
dataloader,
|
||||||
desc=f"Epoch {epoch+1}/{trainer.train_config.n_epoch}",
|
desc=f"Epoch {epoch+1}/{trainer.train_config.n_epoch}",
|
||||||
|
|
@ -106,28 +109,36 @@ class CheckpointCallback(TrainerCallback):
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _save_checkpoint(trainer: 'Trainer'):
|
def _save_checkpoint(trainer: 'Trainer', **kwargs):
|
||||||
current_iter = len(trainer.checkpoint.loss_list)
|
current_iter = kwargs.get('current_iter')
|
||||||
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}")
|
random_sampler = cast(RandomSampler, kwargs.get('sampler'))
|
||||||
trainer.checkpoint.optim_state = trainer.train_config.optimizer.state_dict()
|
optimizer = cast(optim.Optimizer, kwargs.get('optimizer'))
|
||||||
trainer.checkpoint.save(save_path)
|
|
||||||
|
|
||||||
def on_train_begin(self, trainer: 'Trainer', **kwargs):
|
|
||||||
_ = trainer
|
|
||||||
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
|
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
|
||||||
self.last_ckpt_iter = len(checkpoint.loss_list)
|
|
||||||
|
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}")
|
||||||
|
checkpoint.sampler_state = random_sampler.state_dict()
|
||||||
|
checkpoint.optim_state = optimizer.state_dict()
|
||||||
|
|
||||||
|
checkpoint.sampler_state['epoch'] = kwargs.get('epoch', 0)
|
||||||
|
checkpoint.sampler_state['current_iter'] = kwargs.get('current_iter', 0)
|
||||||
|
|
||||||
|
checkpoint.save(save_path)
|
||||||
|
|
||||||
def on_batch_end(self, trainer: 'Trainer', **kwargs):
|
def on_batch_end(self, trainer: 'Trainer', **kwargs):
|
||||||
current_iter = kwargs.get('current_iter')
|
current_iter = kwargs.get('current_iter')
|
||||||
|
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
|
||||||
|
loss = kwargs.get('loss')
|
||||||
|
checkpoint.loss_list.append(loss)
|
||||||
|
|
||||||
if current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
|
if current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
|
||||||
CheckpointCallback._save_checkpoint(trainer)
|
CheckpointCallback._save_checkpoint(trainer, **kwargs)
|
||||||
self.last_ckpt_iter = current_iter
|
self.last_ckpt_iter = current_iter
|
||||||
|
|
||||||
def on_train_end(self, trainer: 'Trainer', **kwargs):
|
def on_train_end(self, trainer: 'Trainer', **kwargs):
|
||||||
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
|
current_iter = kwargs.get('current_iter')
|
||||||
current_iter = len(checkpoint.loss_list)
|
|
||||||
if current_iter != self.last_ckpt_iter:
|
if current_iter != self.last_ckpt_iter:
|
||||||
CheckpointCallback._save_checkpoint(trainer)
|
CheckpointCallback._save_checkpoint(trainer, **kwargs)
|
||||||
|
self.last_ckpt_iter = current_iter
|
||||||
|
|
||||||
|
|
||||||
class GradientClippingCallback(TrainerCallback):
|
class GradientClippingCallback(TrainerCallback):
|
||||||
|
|
@ -137,7 +148,7 @@ class GradientClippingCallback(TrainerCallback):
|
||||||
def on_step_begin(self, trainer: 'Trainer', **kwargs):
|
def on_step_begin(self, trainer: 'Trainer', **kwargs):
|
||||||
_ = kwargs
|
_ = kwargs
|
||||||
clip_grad_norm_(
|
clip_grad_norm_(
|
||||||
trainer.checkpoint.model.parameters(),
|
trainer.parameter.model.parameters(),
|
||||||
trainer.train_config.max_grad_norm
|
trainer.train_config.max_grad_norm
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -152,8 +163,7 @@ class SchedulerCallback(TrainerCallback):
|
||||||
self.current_iter = 0
|
self.current_iter = 0
|
||||||
|
|
||||||
def on_train_begin(self, trainer: 'Trainer', **kwargs):
|
def on_train_begin(self, trainer: 'Trainer', **kwargs):
|
||||||
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
|
self.current_iter = kwargs.get('current_iter')
|
||||||
self.current_iter = len(checkpoint.loss_list)
|
|
||||||
|
|
||||||
for group in trainer.train_config.optimizer.param_groups:
|
for group in trainer.train_config.optimizer.param_groups:
|
||||||
if "initial_lr" not in group:
|
if "initial_lr" not in group:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue