feat(khaosz/trainer): add SchedulerCallback support

ViperEkura committed 2025-09-29 13:18:44 +08:00
parent 5163d3a47a
commit 648e4e177b
3 changed files with 58 additions and 17 deletions
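This change moves learning-rate scheduling out of Trainer.train and into a dedicated SchedulerCallback: the scheduler is now built in on_train_begin (its position restored from the checkpoint's loss_list) and stepped from on_step_end, and the new callback is registered alongside the existing ones and re-exported from the package __init__.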

khaosz/trainer/__init__.py

@@ -7,13 +7,26 @@ from khaosz.trainer.strategy import (
     StrategyFactory,
     SchedulerFactory
 )
+from khaosz.trainer.callback import (
+    ProgressBarCallback,
+    CheckpointCallback,
+    TrainerCallback,
+    SchedulerCallback
+)
 
 __all__ = [
+    # strategy
     "DatasetLoader",
     "Trainer",
     "TrainConfig",
     "CosineScheduleConfig",
     "SgdrScheduleConfig",
     "StrategyFactory",
-    "SchedulerFactory"
+    "SchedulerFactory",
+
+    # callback
+    "ProgressBarCallback",
+    "CheckpointCallback",
+    "TrainerCallback",
+    "SchedulerCallback",
 ]
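With the re-export in place, the callback should be importable from the package root; a one-line sketch of the resulting public API (assuming the package layout implied by the imports above):

```python
from khaosz.trainer import Trainer, TrainConfig, SchedulerCallback
```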

khaosz/trainer/callback.py

@@ -1,7 +1,9 @@
 from tqdm import tqdm
 from khaosz.core.parameter import Checkpoint
 from torch.nn.utils import clip_grad_norm_
-from typing import cast, TYPE_CHECKING
+from torch.optim.lr_scheduler import LambdaLR
+from typing import Optional, cast, TYPE_CHECKING
+from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
 
 if TYPE_CHECKING:
     from khaosz.trainer.trainer import Trainer
@@ -130,3 +132,34 @@ class GradientClippingCallback(TrainerCallback):
             trainer.checkpoint.model.parameters(),
             trainer.train_config.max_grad_norm
         )
+
+
+class SchedulerCallback(TrainerCallback):
+    """
+    Scheduler callback for trainer.
+    """
+
+    def __init__(self, schedule_config: ScheduleConfig):
+        self.schedule_config = schedule_config
+        self.scheduler: Optional[LambdaLR] = None
+        self.current_iter = 0
+
+    def on_train_begin(self, trainer: 'Trainer', **kwargs):
+        checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
+        self.current_iter = len(checkpoint.loss_list)
+
+        lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
+            **self.schedule_config.get_kwargs()
+        )
+        self.scheduler = LambdaLR(
+            trainer.train_config.optimizer,
+            lambda_scheduler_fn,
+            last_epoch=self.current_iter - 1
+        )
+
+    def on_step_end(self, trainer: 'Trainer', **kwargs):
+        _ = trainer, kwargs
+        if self.scheduler:
+            self.scheduler.step()
+            self.current_iter += 1
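The resume logic above hinges on LambdaLR's last_epoch argument: passing the number of already-completed steps minus one re-aligns the learning-rate curve with where the checkpoint left off. A minimal self-contained sketch of that behavior, with a hypothetical warmup lambda standing in for whatever SchedulerFactory.load_schedule_fn returns:

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

def warmup_fn(step: int) -> float:
    # Hypothetical stand-in for SchedulerFactory.load_schedule_fn:
    # linear warmup over 100 steps, then a constant factor of 1.0.
    return min(1.0, (step + 1) / 100)

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Resuming a run that already completed 40 steps. Constructing LambdaLR
# with last_epoch != -1 requires every param group to carry 'initial_lr',
# which the Trainer pre-populates (see the group["initial_lr"] lines in
# the trainer.py hunk below).
current_iter = 40
for group in optimizer.param_groups:
    group.setdefault("initial_lr", group["lr"])

scheduler = LambdaLR(optimizer, warmup_fn, last_epoch=current_iter - 1)

# The learning rate now matches step 40 of the original run.
assert abs(optimizer.param_groups[0]["lr"] - 1e-3 * warmup_fn(40)) < 1e-12
```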

khaosz/trainer/trainer.py

@@ -1,12 +1,17 @@
 import os
 import torch
 from typing import Optional, List
-from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.data import DataLoader, RandomSampler
 from khaosz.core import ModelParameter, Checkpoint
-from khaosz.trainer.strategy import SchedulerFactory, TrainConfig, ScheduleConfig
-from khaosz.trainer.callback import TrainerCallback, ProgressBarCallback, CheckpointCallback, GradientClippingCallback
+from khaosz.trainer.strategy import TrainConfig, ScheduleConfig
+from khaosz.trainer.callback import (
+    TrainerCallback,
+    ProgressBarCallback,
+    CheckpointCallback,
+    GradientClippingCallback,
+    SchedulerCallback
+)
 
 
 class Trainer:
@@ -32,6 +37,7 @@ class Trainer:
             ProgressBarCallback(),
             CheckpointCallback(self.train_config.checkpoint_interval),
             GradientClippingCallback(),
+            SchedulerCallback(self.schedule_config),
         ]
 
     def _create_dataloader(self) -> DataLoader:
@@ -73,16 +79,6 @@ class Trainer:
             if "initial_lr" not in group:
                 group["initial_lr"] = group["lr"]
 
-        lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
-            **self.schedule_config.get_kwargs()
-        )
-        scheduler = LambdaLR(
-            self.train_config.optimizer,
-            lambda_scheduler_fn,
-            last_epoch=current_iter - 1 if train_checkpoint else -1
-        )
-
        reamining_steps = self.train_config.n_epoch - current_iter
        total_steps = len(self.train_config.dataset) // self.train_config.batch_size
        remaining_epochs = (reamining_steps + total_steps - 1) // total_steps
@@ -114,7 +110,6 @@ class Trainer:
                 self._call_callbacks('on_step_end', current_iter=current_iter)
                 current_iter += 1
-                scheduler.step()
 
             self._call_callbacks('on_epoch_end', epoch=epoch, loss_list=self.checkpoint.loss_list)
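For context, the hooks above go through Trainer._call_callbacks, whose body is outside this diff; a plausible sketch, assuming it simply dispatches by hook name to every registered callback:

```python
from typing import Any, List

class TrainerCallback:
    """Hooks default to no-ops so subclasses override only what they need."""
    def on_train_begin(self, trainer: "Trainer", **kwargs: Any) -> None: ...
    def on_step_end(self, trainer: "Trainer", **kwargs: Any) -> None: ...
    def on_epoch_end(self, trainer: "Trainer", **kwargs: Any) -> None: ...

class Trainer:
    def __init__(self, callbacks: List[TrainerCallback]) -> None:
        self.callbacks = callbacks

    def _call_callbacks(self, hook_name: str, **kwargs: Any) -> None:
        # Look the hook up by name on each callback and invoke it, mirroring
        # calls such as self._call_callbacks('on_step_end', current_iter=...).
        for callback in self.callbacks:
            getattr(callback, hook_name)(self, **kwargs)

# Usage: a toy callback that logs the step counter.
class PrintStepCallback(TrainerCallback):
    def on_step_end(self, trainer: "Trainer", **kwargs: Any) -> None:
        print("step", kwargs.get("current_iter"))

Trainer([PrintStepCallback()])._call_callbacks("on_step_end", current_iter=0)  # step 0
```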