feat(khaosz/trainer): 添加SchedulerCallback功能
This commit is contained in:
parent
5163d3a47a
commit
648e4e177b
|
|
@ -7,13 +7,26 @@ from khaosz.trainer.strategy import (
|
|||
StrategyFactory,
|
||||
SchedulerFactory
|
||||
)
|
||||
from khaosz.trainer.callback import (
|
||||
ProgressBarCallback,
|
||||
CheckpointCallback,
|
||||
TrainerCallback,
|
||||
SchedulerCallback
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# strategy
|
||||
"DatasetLoader",
|
||||
"Trainer",
|
||||
"TrainConfig",
|
||||
"CosineScheduleConfig",
|
||||
"SgdrScheduleConfig",
|
||||
"StrategyFactory",
|
||||
"SchedulerFactory"
|
||||
"SchedulerFactory",
|
||||
|
||||
# callback
|
||||
"ProgressBarCallback",
|
||||
"CheckpointCallback",
|
||||
"TrainerCallback",
|
||||
"SchedulerCallback",
|
||||
]
|
||||
|
|
@ -1,7 +1,9 @@
|
|||
from tqdm import tqdm
|
||||
from khaosz.core.parameter import Checkpoint
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from typing import cast, TYPE_CHECKING
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from typing import Optional, cast, TYPE_CHECKING
|
||||
from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from khaosz.trainer.trainer import Trainer
|
||||
|
|
@ -129,4 +131,35 @@ class GradientClippingCallback(TrainerCallback):
|
|||
clip_grad_norm_(
|
||||
trainer.checkpoint.model.parameters(),
|
||||
trainer.train_config.max_grad_norm
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class SchedulerCallback(TrainerCallback):
|
||||
"""
|
||||
Scheduler callback for trainer.
|
||||
"""
|
||||
def __init__(self, schedule_config: ScheduleConfig):
|
||||
self.schedule_config = schedule_config
|
||||
self.scheduler: Optional[LambdaLR] = None
|
||||
self.current_iter = 0
|
||||
|
||||
def on_train_begin(self, trainer: 'Trainer', **kwargs):
|
||||
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
|
||||
self.current_iter = len(checkpoint.loss_list)
|
||||
|
||||
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
|
||||
**self.schedule_config.get_kwargs()
|
||||
)
|
||||
|
||||
self.scheduler = LambdaLR(
|
||||
trainer.train_config.optimizer,
|
||||
lambda_scheduler_fn,
|
||||
last_epoch=self.current_iter - 1
|
||||
)
|
||||
|
||||
def on_step_end(self, trainer: 'Trainer', **kwargs):
|
||||
_ = trainer, kwargs
|
||||
|
||||
if self.scheduler:
|
||||
self.scheduler.step()
|
||||
self.current_iter += 1
|
||||
|
|
|
|||
|
|
@ -1,12 +1,17 @@
|
|||
import os
|
||||
import torch
|
||||
from typing import Optional, List
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
|
||||
from khaosz.core import ModelParameter, Checkpoint
|
||||
from khaosz.trainer.strategy import SchedulerFactory, TrainConfig, ScheduleConfig
|
||||
from khaosz.trainer.callback import TrainerCallback, ProgressBarCallback, CheckpointCallback, GradientClippingCallback
|
||||
from khaosz.trainer.strategy import TrainConfig, ScheduleConfig
|
||||
from khaosz.trainer.callback import (
|
||||
TrainerCallback,
|
||||
ProgressBarCallback,
|
||||
CheckpointCallback,
|
||||
GradientClippingCallback,
|
||||
SchedulerCallback
|
||||
)
|
||||
|
||||
|
||||
class Trainer:
|
||||
|
|
@ -32,6 +37,7 @@ class Trainer:
|
|||
ProgressBarCallback(),
|
||||
CheckpointCallback(self.train_config.checkpoint_interval),
|
||||
GradientClippingCallback(),
|
||||
SchedulerCallback(self.schedule_config),
|
||||
]
|
||||
|
||||
def _create_dataloader(self) -> DataLoader:
|
||||
|
|
@ -72,16 +78,6 @@ class Trainer:
|
|||
for group in self.train_config.optimizer.param_groups:
|
||||
if "initial_lr" not in group:
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
|
||||
**self.schedule_config.get_kwargs()
|
||||
)
|
||||
|
||||
scheduler = LambdaLR(
|
||||
self.train_config.optimizer,
|
||||
lambda_scheduler_fn,
|
||||
last_epoch=current_iter - 1 if train_checkpoint else -1
|
||||
)
|
||||
|
||||
reamining_steps = self.train_config.n_epoch - current_iter
|
||||
total_steps = len(self.train_config.dataset) // self.train_config.batch_size
|
||||
|
|
@ -114,7 +110,6 @@ class Trainer:
|
|||
self._call_callbacks('on_step_end', current_iter=current_iter)
|
||||
|
||||
current_iter += 1
|
||||
scheduler.step()
|
||||
|
||||
self._call_callbacks('on_epoch_end', epoch=epoch, loss_list=self.checkpoint.loss_list)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue