feat(khaosz/trainer): 引入 TrainContext 和 TrainContextBuilder 优化训练上下文管理 (introduce TrainContext and TrainContextBuilder to streamline training-context management)
This commit is contained in:
parent
6e1a497c04
commit
240ee00221
|
|
@ -0,0 +1,91 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Self, TYPE_CHECKING
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from khaosz.core.parameter import Checkpoint
|
||||||
|
from khaosz.trainer.data_util import RandomSampler
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from khaosz.trainer.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TrainContext:
    """Mutable state shared across one training run.

    TrainContextBuilder fills the fields in incrementally, so every
    object-valued field starts out as None and is only guaranteed to be
    populated after the corresponding ``with_*`` step has run.
    """

    # Batched iterator over the training dataset (None until with_dataloader()).
    dataloader: Optional[DataLoader]
    # Optimizer driving parameter updates (None until with_optimizer()).
    optimizer: Optional[Optimizer]
    # Resumable sampler; carries the epoch/iteration position (None until with_sampler()).
    sampler: Optional[RandomSampler]
    # Epoch to start (or resume) from.
    epoch: int
    # Global step counter across the whole run.
    current_iter: int
    # Most recent batch loss value.
    loss: float
    # Snapshot of model/tokenizer/config plus optimizer and sampler state
    # (None until with_checkpoint()).
    checkpoint: Optional[Checkpoint]
|
||||||
|
|
||||||
|
|
||||||
|
class TrainContextBuilder:
    """Fluent builder that assembles a TrainContext from a Trainer.

    Intended call order: with_checkpoint() -> with_sampler() ->
    with_optimizer() -> with_dataloader() -> build(); the sampler and
    optimizer steps restore state from the checkpoint, and the dataloader
    step consumes the sampler.
    """

    def __init__(self, trainer: 'Trainer'):
        self.trainer = trainer
        # Start from an empty context; each with_* step fills one part in.
        self._context = TrainContext(
            dataloader=None,
            optimizer=None,
            sampler=None,
            epoch=0,
            current_iter=0,
            loss=0.0,
            checkpoint=None,
        )

    def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
        """Attach the given checkpoint, or create a fresh empty one when absent."""
        if checkpoint is None:
            parameter = self.trainer.parameter
            checkpoint = Checkpoint(
                model=parameter.model,
                tokenizer=parameter.tokenizer,
                config=parameter.config,
                sampler_state=None,
                optim_state=None,
                loss_list=[],
            )
        self._context.checkpoint = checkpoint
        return self

    def with_sampler(self) -> Self:
        """Create the RandomSampler, restoring its position from the checkpoint."""
        cfg = self.trainer.train_config
        sampler = RandomSampler(
            data_source=cfg.dataset,
            seed=cfg.random_seed,
        )

        ckpt = self._context.checkpoint
        if ckpt and ckpt.sampler_state:
            sampler.load_state_dict(ckpt.sampler_state)

        self._context.sampler = sampler
        self._context.epoch = sampler.epoch
        self._context.current_iter = sampler.current_iter

        # Keep the checkpoint in sync with the (possibly restored) sampler.
        if ckpt:
            ckpt.sampler_state = sampler.state_dict()

        return self

    def with_optimizer(self) -> Self:
        """Wire up the configured optimizer, restoring its state if available."""
        optimizer = self.trainer.train_config.optimizer

        ckpt = self._context.checkpoint
        if ckpt and ckpt.optim_state:
            optimizer.load_state_dict(ckpt.optim_state)

        self._context.optimizer = optimizer

        # Snapshot the (possibly restored) optimizer state back into the checkpoint.
        if ckpt:
            ckpt.optim_state = optimizer.state_dict()

        return self

    def with_dataloader(self) -> Self:
        """Build the DataLoader; call with_sampler() first so it is used here."""
        cfg = self.trainer.train_config
        self._context.dataloader = DataLoader(
            cfg.dataset,
            batch_size=cfg.batch_size,
            sampler=self._context.sampler,
        )
        return self

    def build(self) -> TrainContext:
        """Return the assembled context."""
        return self._context
|
||||||
|
|
@ -1,9 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, List, cast
|
from typing import Optional, List
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from khaosz.core import ModelParameter, Checkpoint
|
from khaosz.core import ModelParameter, Checkpoint
|
||||||
from khaosz.trainer.data_util import RandomSampler
|
|
||||||
from khaosz.trainer.strategy import TrainConfig, ScheduleConfig
|
from khaosz.trainer.strategy import TrainConfig, ScheduleConfig
|
||||||
from khaosz.trainer.trainer_callback import (
|
from khaosz.trainer.trainer_callback import (
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
|
|
@ -12,6 +10,7 @@ from khaosz.trainer.trainer_callback import (
|
||||||
GradientClippingCallback,
|
GradientClippingCallback,
|
||||||
SchedulerCallback
|
SchedulerCallback
|
||||||
)
|
)
|
||||||
|
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -36,105 +35,61 @@ class Trainer:
|
||||||
SchedulerCallback(self.schedule_config),
|
SchedulerCallback(self.schedule_config),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _build_train_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
    """Assemble the TrainContext for a run, resuming from `checkpoint` if given.

    Build order matters: the checkpoint is attached before the sampler and
    optimizer (so saved state can be restored), and the sampler before the
    dataloader (so the DataLoader actually uses it).
    """
    return (TrainContextBuilder(self)
            .with_checkpoint(checkpoint)
            .with_sampler()
            .with_optimizer()
            .with_dataloader()
            .build())

def _call_callbacks(self, method_name: str, context: TrainContext):
    """Invoke `method_name` on every registered callback that defines it.

    The context is flattened into keyword arguments so existing callback
    signatures (which accept the fields as **kwargs) keep working unchanged.
    """
    kwargs = {
        'dataloader': context.dataloader,
        'optimizer': context.optimizer,
        'sampler': context.sampler,
        'epoch': context.epoch,
        'current_iter': context.current_iter,
        'loss': context.loss,
        'checkpoint': context.checkpoint
    }
    for callback in self.callbacks:
        method = getattr(callback, method_name, None)
        if method:
            method(self, **kwargs)
|
||||||
|
|
||||||
def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
    """Run the training loop, optionally resuming from `checkpoint`.

    Returns the checkpoint held by the context (updated in place by the
    builder with restored optimizer/sampler state). Any exception raised
    during training is logged and re-raised; the 'on_train_end' callbacks
    always fire via the ``finally`` block.
    """
    context = self._build_train_context(checkpoint)
    self._call_callbacks('on_train_begin', context)
    try:
        self.parameter.model.train()
        # Resume from the epoch the sampler was restored to.
        for epoch in range(context.epoch, self.train_config.n_epoch):
            context.epoch = epoch
            self._call_callbacks('on_epoch_begin', context)
            for batch in context.dataloader:
                # Step/zero the optimizer every `accumulation_steps` batches.
                # NOTE(review): this fires at the START of the window
                # (including current_iter == 0, before any backward()) —
                # confirm that ordering is intended.
                if context.current_iter % self.train_config.accumulation_steps == 0:
                    self._call_callbacks('on_step_begin', context)
                    self.train_config.optimizer.step()
                    self.train_config.optimizer.zero_grad()
                    self._call_callbacks('on_step_end', context)
                self._call_callbacks('on_batch_begin', context)
                loss = self.train_config.strategy(batch)
                context.loss = loss.item()
                context.current_iter += 1
                loss.backward()
                self._call_callbacks('on_batch_end', context)
            self._call_callbacks('on_epoch_end', context)
    except Exception as e:
        logger.error(f"Training failed: {str(e)}", exc_info=True)
        raise
    finally:
        self._call_callbacks('on_train_end', context)
    return context.checkpoint
|
||||||
Loading…
Reference in New Issue