feat(trainer): 引入训练器回调机制并重构训练流程
This commit is contained in:
parent
92999fa9f6
commit
816bc78894
|
|
@ -1,7 +1,8 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
from abc import abstractmethod
|
||||||
from typing import Optional
|
from typing import Optional, List, override
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from torch.optim.lr_scheduler import LambdaLR
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
from torch.utils.data import DataLoader, RandomSampler
|
from torch.utils.data import DataLoader, RandomSampler
|
||||||
|
|
@ -11,12 +12,106 @@ from khaosz.core import ModelParameter, Checkpoint
|
||||||
from khaosz.trainer.strategy import SchedulerFactory, TrainConfig, ScheduleConfig
|
from khaosz.trainer.strategy import SchedulerFactory, TrainConfig, ScheduleConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TrainerCallback:
    """Base class for trainer lifecycle hooks.

    Subclasses override only the hooks they care about; every hook is a
    no-op by default.  The original decorated each hook with
    ``@abstractmethod``, but the class does not inherit ``abc.ABC`` so the
    decorator had no effect — and had it been enforced, concrete callbacks
    that override only a subset of hooks (progress bar, checkpointing,
    gradient clipping) could not have been instantiated.  Plain no-op
    methods express the intended optional-override contract.
    """

    def on_train_begin(self, trainer: 'Trainer', **kwargs):
        """Called once before the first epoch starts."""
        pass

    def on_train_end(self, trainer: 'Trainer', **kwargs):
        """Called once after training finishes (including on error paths)."""
        pass

    def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
        """Called at the start of each epoch."""
        pass

    def on_epoch_end(self, trainer: 'Trainer', **kwargs):
        """Called at the end of each epoch."""
        pass

    def on_batch_begin(self, trainer: 'Trainer', **kwargs):
        """Called before each batch's forward pass."""
        pass

    def on_batch_end(self, trainer: 'Trainer', **kwargs):
        """Called after each batch's backward pass."""
        pass

    def on_step_begin(self, trainer: 'Trainer', **kwargs):
        """Called just before an optimizer step."""
        pass

    def on_step_end(self, trainer: 'Trainer', **kwargs):
        """Called just after an optimizer step."""
        pass
||||||
|
class ProgressBarCallback(TrainerCallback):
    """Displays a tqdm progress bar for each training epoch."""

    def __init__(self):
        # The bar is (re)created at every epoch start; None outside an epoch.
        self.progress_bar: Optional[tqdm] = None

    def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
        """Open a fresh bar sized to the number of batches in this epoch."""
        epoch = kwargs.get('epoch')
        # Only the batch count is needed here.  The original passed the
        # dataloader itself to tqdm but never iterated the wrapper — the bar
        # is advanced manually via update(1) in on_batch_end — which kept an
        # unused second DataLoader alive.  Passing total= yields the same bar
        # without holding the loader.
        num_batches = len(trainer._create_dataloader())
        self.progress_bar = tqdm(
            total=num_batches,
            desc=f"Epoch {epoch+1}/{trainer.train_config.n_epoch}",
            dynamic_ncols=True
        )

    def on_batch_end(self, trainer: 'Trainer', **kwargs):
        """Advance the bar by one batch and refresh the loss/lr readouts."""
        loss = kwargs.get('loss')
        self.progress_bar.set_postfix({
            "loss": f"{loss:.4f}",
            "lr": f"{trainer.train_config.optimizer.param_groups[0]['lr']:.2e}"
        })
        self.progress_bar.update(1)

    def on_epoch_end(self, trainer: 'Trainer', **kwargs):
        """Close the bar so the terminal line is released."""
        if self.progress_bar:
            self.progress_bar.close()
||||||
|
class CheckpointCallback(TrainerCallback):
    """Persists a training checkpoint every ``checkpoint_interval`` iterations."""

    def __init__(self, checkpoint_interval: int):
        # Minimum number of iterations between two consecutive saves.
        self.checkpoint_interval = checkpoint_interval
        # Iteration count at the time of the most recent save.
        self.last_ckpt_iter = 0

    def on_train_begin(self, trainer: 'Trainer', **kwargs):
        """Resume-aware start: count from the losses already recorded."""
        restored = kwargs.get('checkpoint')
        self.last_ckpt_iter = len(restored.loss_list)

    def on_batch_end(self, trainer: 'Trainer', **kwargs):
        """Save whenever enough iterations have elapsed since the last save."""
        current_iter = kwargs.get('current_iter')
        elapsed = current_iter - self.last_ckpt_iter
        if elapsed >= self.checkpoint_interval:
            trainer._save_checkpoint()
            self.last_ckpt_iter = current_iter

    def on_train_end(self, trainer: 'Trainer', **kwargs):
        """Final save, unless the last batch already triggered one."""
        final_state = kwargs.get('checkpoint')
        final_iter = len(final_state.loss_list)
        if final_iter != self.last_ckpt_iter:
            trainer._save_checkpoint()
||||||
|
class GradientClippingCallback(TrainerCallback):
    """Clips the model's gradient norm right before each optimizer step."""

    def on_step_begin(self, trainer: 'Trainer', **kwargs):
        # Rescale gradients in place so their global norm does not exceed
        # the configured maximum.
        params = trainer.checkpoint.model.parameters()
        max_norm = trainer.train_config.max_grad_norm
        clip_grad_norm_(params, max_norm)
||||||
class Trainer:
|
class Trainer:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parameter: ModelParameter,
|
parameter: ModelParameter,
|
||||||
train_config: TrainConfig,
|
train_config: TrainConfig,
|
||||||
schedule_config: ScheduleConfig
|
schedule_config: ScheduleConfig,
|
||||||
|
callbacks: Optional[List[TrainerCallback]] = None
|
||||||
):
|
):
|
||||||
self.checkpoint = Checkpoint(
|
self.checkpoint = Checkpoint(
|
||||||
model=parameter.model,
|
model=parameter.model,
|
||||||
|
|
@ -26,16 +121,37 @@ class Trainer:
|
||||||
self.train_config = train_config
|
self.train_config = train_config
|
||||||
self.schedule_config = schedule_config
|
self.schedule_config = schedule_config
|
||||||
|
|
||||||
def save_checkpoint(
|
self.callbacks = callbacks or self._get_default_callbacks()
|
||||||
self,
|
|
||||||
loss_list: list,
|
def _get_default_callbacks(self) -> List[TrainerCallback]:
|
||||||
):
|
return [
|
||||||
current_iter = len(loss_list)
|
ProgressBarCallback(),
|
||||||
|
CheckpointCallback(self.train_config.checkpoint_interval),
|
||||||
|
GradientClippingCallback(),
|
||||||
|
]
|
||||||
|
|
||||||
|
def _create_dataloader(self) -> DataLoader:
|
||||||
|
seed = self.train_config.random_seed
|
||||||
|
generator = torch.Generator().manual_seed(seed)
|
||||||
|
sampler = RandomSampler(self.train_config.dataset, generator=generator)
|
||||||
|
return DataLoader(
|
||||||
|
self.train_config.dataset,
|
||||||
|
batch_size=self.train_config.batch_size,
|
||||||
|
sampler=sampler
|
||||||
|
)
|
||||||
|
|
||||||
|
def _save_checkpoint(self):
|
||||||
|
current_iter = len(self.checkpoint.loss_list)
|
||||||
save_path = os.path.join(self.train_config.checkpoint_dir, f"iter_{current_iter}")
|
save_path = os.path.join(self.train_config.checkpoint_dir, f"iter_{current_iter}")
|
||||||
self.checkpoint.loss_list = loss_list
|
|
||||||
self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
|
self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
|
||||||
self.checkpoint.save(save_path)
|
self.checkpoint.save(save_path)
|
||||||
|
|
||||||
|
def _call_callbacks(self, method_name: str, **kwargs):
|
||||||
|
for callback in self.callbacks:
|
||||||
|
method = getattr(callback, method_name, None)
|
||||||
|
if method:
|
||||||
|
method(self, **kwargs)
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self,
|
self,
|
||||||
train_checkpoint: Optional[Checkpoint] = None
|
train_checkpoint: Optional[Checkpoint] = None
|
||||||
|
|
@ -47,16 +163,13 @@ class Trainer:
|
||||||
self.train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
|
self.train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
|
||||||
|
|
||||||
self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
|
self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
|
||||||
loss_list = self.checkpoint.loss_list
|
|
||||||
current_iter = len(self.checkpoint.loss_list)
|
current_iter = len(self.checkpoint.loss_list)
|
||||||
last_ckpt_iter = current_iter
|
|
||||||
|
|
||||||
for group in self.train_config.optimizer.param_groups:
|
for group in self.train_config.optimizer.param_groups:
|
||||||
if "initial_lr" not in group:
|
if "initial_lr" not in group:
|
||||||
group["initial_lr"] = group["lr"]
|
group["initial_lr"] = group["lr"]
|
||||||
|
|
||||||
|
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
|
||||||
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
|
|
||||||
**self.schedule_config.get_kwargs()
|
**self.schedule_config.get_kwargs()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -66,52 +179,42 @@ class Trainer:
|
||||||
last_epoch=current_iter - 1 if train_checkpoint else -1
|
last_epoch=current_iter - 1 if train_checkpoint else -1
|
||||||
)
|
)
|
||||||
|
|
||||||
seed = self.train_config.random_seed
|
reamining_steps = self.train_config.n_epoch - current_iter
|
||||||
generator = torch.Generator().manual_seed(seed)
|
total_steps = len(self.train_config.dataset) // self.train_config.batch_size
|
||||||
sampler = RandomSampler(self.train_config.dataset, generator=generator)
|
remaining_epochs = (reamining_steps + total_steps - 1) // total_steps
|
||||||
remaining_epochs = self.train_config.n_epoch - current_iter // (
|
# train
|
||||||
len(self.train_config.dataset) // self.train_config.batch_size)
|
self._call_callbacks('on_train_begin', checkpoint=self.checkpoint)
|
||||||
|
|
||||||
for epoch in range(remaining_epochs):
|
try:
|
||||||
self.checkpoint.model.train()
|
for epoch in range(remaining_epochs):
|
||||||
dataloader = DataLoader(
|
self.checkpoint.model.train()
|
||||||
self.train_config.dataset,
|
|
||||||
batch_size=self.train_config.batch_size,
|
|
||||||
sampler=sampler
|
|
||||||
)
|
|
||||||
progress_bar = tqdm(
|
|
||||||
dataloader,
|
|
||||||
desc=f"Epoch {epoch+1}/{self.train_config.n_epoch}",
|
|
||||||
dynamic_ncols=True
|
|
||||||
)
|
|
||||||
for batch in progress_bar:
|
|
||||||
#forward
|
|
||||||
loss = self.train_config.strategy(batch)
|
|
||||||
loss_list.append(loss.item())
|
|
||||||
#backward
|
|
||||||
loss.backward()
|
|
||||||
#step
|
|
||||||
if current_iter % self.train_config.accumulation_steps == 0:
|
|
||||||
clip_grad_norm_(
|
|
||||||
self.checkpoint.model.parameters(),
|
|
||||||
self.train_config.max_grad_norm
|
|
||||||
)
|
|
||||||
self.train_config.optimizer.step()
|
|
||||||
self.train_config.optimizer.zero_grad()
|
|
||||||
|
|
||||||
current_iter += 1
|
# epoch
|
||||||
scheduler.step()
|
self._call_callbacks('on_epoch_begin', epoch=epoch)
|
||||||
progress_bar.set_postfix({
|
|
||||||
"loss": f"{loss.item():.4f}",
|
|
||||||
"lr": f"{self.train_config.optimizer.param_groups[0]['lr']:.2e}"
|
|
||||||
})
|
|
||||||
#save checkpotint
|
|
||||||
if current_iter - last_ckpt_iter >= self.train_config.checkpoint_interval:
|
|
||||||
self.save_checkpoint(loss_list)
|
|
||||||
last_ckpt_iter = current_iter
|
|
||||||
|
|
||||||
if current_iter != last_ckpt_iter:
|
dataloader = self._create_dataloader()
|
||||||
self.save_checkpoint(loss_list)
|
|
||||||
last_ckpt_iter = current_iter
|
for batch in dataloader:
|
||||||
|
# batch
|
||||||
|
self._call_callbacks('on_batch_begin', batch=batch)
|
||||||
|
loss = self.train_config.strategy(batch)
|
||||||
|
self.checkpoint.loss_list.append(loss.item())
|
||||||
|
loss.backward()
|
||||||
|
self._call_callbacks('on_batch_end', batch=batch, loss=loss.item(), current_iter=current_iter)
|
||||||
|
|
||||||
|
if current_iter % self.train_config.accumulation_steps == 0:
|
||||||
|
# step
|
||||||
|
self._call_callbacks('on_step_begin', current_iter=current_iter)
|
||||||
|
self.train_config.optimizer.step()
|
||||||
|
self.train_config.optimizer.zero_grad()
|
||||||
|
self._call_callbacks('on_step_end', current_iter=current_iter)
|
||||||
|
|
||||||
|
current_iter += 1
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
|
self._call_callbacks('on_epoch_end', epoch=epoch, loss_list=self.checkpoint.loss_list)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self._call_callbacks('on_train_end', checkpoint=self.checkpoint)
|
||||||
|
|
||||||
return self.checkpoint
|
return self.checkpoint
|
||||||
Loading…
Reference in New Issue