feat(khaosz/trainer): add SchedulerCallback support

ViperEkura 2025-09-29 13:18:44 +08:00
parent 5163d3a47a
commit 648e4e177b
3 changed files with 58 additions and 17 deletions

khaosz/trainer/__init__.py

@@ -7,13 +7,26 @@ from khaosz.trainer.strategy import (
     StrategyFactory,
     SchedulerFactory
 )
+from khaosz.trainer.callback import (
+    ProgressBarCallback,
+    CheckpointCallback,
+    TrainerCallback,
+    SchedulerCallback
+)
 
 __all__ = [
     # strategy
     "DatasetLoader",
     "Trainer",
     "TrainConfig",
     "CosineScheduleConfig",
     "SgdrScheduleConfig",
     "StrategyFactory",
-    "SchedulerFactory"
+    "SchedulerFactory",
+    # callback
+    "ProgressBarCallback",
+    "CheckpointCallback",
+    "TrainerCallback",
+    "SchedulerCallback",
 ]
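
Note: since these names are re-exported from khaosz/trainer/__init__.py, downstream code can now pull the callback types from the package root. A minimal sketch, using only names from the __all__ list above:

    # callbacks and configs are importable alongside Trainer itself
    from khaosz.trainer import Trainer, TrainConfig, SchedulerCallback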

khaosz/trainer/callback.py

@@ -1,7 +1,9 @@
 from tqdm import tqdm
+from khaosz.core.parameter import Checkpoint
 from torch.nn.utils import clip_grad_norm_
-from typing import cast, TYPE_CHECKING
+from torch.optim.lr_scheduler import LambdaLR
+from typing import Optional, cast, TYPE_CHECKING
 from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
 
 if TYPE_CHECKING:
     from khaosz.trainer.trainer import Trainer
@@ -129,4 +131,35 @@ class GradientClippingCallback(TrainerCallback):
         clip_grad_norm_(
             trainer.checkpoint.model.parameters(),
             trainer.train_config.max_grad_norm
-        )
+        )
+
+
+class SchedulerCallback(TrainerCallback):
+    """
+    Scheduler callback for trainer.
+    """
+
+    def __init__(self, schedule_config: ScheduleConfig):
+        self.schedule_config = schedule_config
+        self.scheduler: Optional[LambdaLR] = None
+        self.current_iter = 0
+
+    def on_train_begin(self, trainer: 'Trainer', **kwargs):
+        checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
+        self.current_iter = len(checkpoint.loss_list)
+        lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
+            **self.schedule_config.get_kwargs()
+        )
+        self.scheduler = LambdaLR(
+            trainer.train_config.optimizer,
+            lambda_scheduler_fn,
+            last_epoch=self.current_iter - 1
+        )
+
+    def on_step_end(self, trainer: 'Trainer', **kwargs):
+        _ = trainer, kwargs
+        if self.scheduler:
+            self.scheduler.step()
+            self.current_iter += 1
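
Note: the callback builds its LambdaLR lazily in on_train_begin and seeds last_epoch with len(checkpoint.loss_list) - 1, so a run restored from a checkpoint continues the LR curve instead of restarting it. A hedged sketch of the lifecycle as the trainer drives it (schedule_config and trainer are assumed to already exist; this mirrors the hooks above, not a separate public API):

    scheduler_cb = SchedulerCallback(schedule_config)

    # on_train_begin constructs the LambdaLR, resuming from the
    # checkpoint's loss history (apparently one entry per optimizer step)
    scheduler_cb.on_train_begin(trainer, checkpoint=trainer.checkpoint)

    # each completed step then advances the schedule by one tick
    scheduler_cb.on_step_end(trainer)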

khaosz/trainer/trainer.py

@@ -1,12 +1,17 @@
 import os
 import torch
 
 from typing import Optional, List
-from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.data import DataLoader, RandomSampler
 from khaosz.core import ModelParameter, Checkpoint
-from khaosz.trainer.strategy import SchedulerFactory, TrainConfig, ScheduleConfig
-from khaosz.trainer.callback import TrainerCallback, ProgressBarCallback, CheckpointCallback, GradientClippingCallback
+from khaosz.trainer.strategy import TrainConfig, ScheduleConfig
+from khaosz.trainer.callback import (
+    TrainerCallback,
+    ProgressBarCallback,
+    CheckpointCallback,
+    GradientClippingCallback,
+    SchedulerCallback
+)
 
 
 class Trainer:
@@ -32,6 +37,7 @@ class Trainer:
             ProgressBarCallback(),
             CheckpointCallback(self.train_config.checkpoint_interval),
             GradientClippingCallback(),
+            SchedulerCallback(self.schedule_config),
         ]
 
     def _create_dataloader(self) -> DataLoader:
@@ -72,16 +78,6 @@ class Trainer:
         for group in self.train_config.optimizer.param_groups:
             if "initial_lr" not in group:
                 group["initial_lr"] = group["lr"]
-        lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
-            **self.schedule_config.get_kwargs()
-        )
-        scheduler = LambdaLR(
-            self.train_config.optimizer,
-            lambda_scheduler_fn,
-            last_epoch=current_iter - 1 if train_checkpoint else -1
-        )
-
         remaining_steps = self.train_config.n_epoch - current_iter
         total_steps = len(self.train_config.dataset) // self.train_config.batch_size
@@ -114,7 +110,6 @@ class Trainer:
                 self._call_callbacks('on_step_end', current_iter=current_iter)
                 current_iter += 1
-                scheduler.step()
 
             self._call_callbacks('on_epoch_end', epoch=epoch, loss_list=self.checkpoint.loss_list)
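
Note: the per-step scheduler.step() can be deleted here because _call_callbacks('on_step_end', ...) now reaches SchedulerCallback.on_step_end, which steps the scheduler itself. A minimal sketch of that dispatch, assuming _call_callbacks simply forwards the hook to every registered callback (the method body is not part of this diff, and the self.callbacks attribute name is a guess):

    def _call_callbacks(self, hook_name: str, **kwargs):
        # assumed dispatcher: invoke the named hook on each callback,
        # so 'on_step_end' also advances SchedulerCallback's LambdaLR
        for callback in self.callbacks:
            getattr(callback, hook_name)(self, **kwargs)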