feat(khaosz/trainer): 添加SchedulerCallback功能
This commit is contained in:
parent
5163d3a47a
commit
648e4e177b
|
|
@ -7,13 +7,26 @@ from khaosz.trainer.strategy import (
|
||||||
StrategyFactory,
|
StrategyFactory,
|
||||||
SchedulerFactory
|
SchedulerFactory
|
||||||
)
|
)
|
||||||
|
from khaosz.trainer.callback import (
|
||||||
|
ProgressBarCallback,
|
||||||
|
CheckpointCallback,
|
||||||
|
TrainerCallback,
|
||||||
|
SchedulerCallback
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# strategy
|
||||||
"DatasetLoader",
|
"DatasetLoader",
|
||||||
"Trainer",
|
"Trainer",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
"CosineScheduleConfig",
|
"CosineScheduleConfig",
|
||||||
"SgdrScheduleConfig",
|
"SgdrScheduleConfig",
|
||||||
"StrategyFactory",
|
"StrategyFactory",
|
||||||
"SchedulerFactory"
|
"SchedulerFactory",
|
||||||
|
|
||||||
|
# callback
|
||||||
|
"ProgressBarCallback",
|
||||||
|
"CheckpointCallback",
|
||||||
|
"TrainerCallback",
|
||||||
|
"SchedulerCallback",
|
||||||
]
|
]
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from khaosz.core.parameter import Checkpoint
|
from khaosz.core.parameter import Checkpoint
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from typing import cast, TYPE_CHECKING
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
|
from typing import Optional, cast, TYPE_CHECKING
|
||||||
|
from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from khaosz.trainer.trainer import Trainer
|
from khaosz.trainer.trainer import Trainer
|
||||||
|
|
@ -130,3 +132,34 @@ class GradientClippingCallback(TrainerCallback):
|
||||||
trainer.checkpoint.model.parameters(),
|
trainer.checkpoint.model.parameters(),
|
||||||
trainer.train_config.max_grad_norm
|
trainer.train_config.max_grad_norm
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerCallback(TrainerCallback):
    """
    Learning-rate scheduler callback for the trainer.

    Builds a ``LambdaLR`` scheduler from a ``ScheduleConfig`` when training
    begins and advances it by one step after every optimizer step, resuming
    the schedule from the iteration count recorded in the checkpoint.
    """

    def __init__(self, schedule_config: ScheduleConfig):
        # Config consumed lazily in on_train_begin to build the schedule fn.
        self.schedule_config = schedule_config
        # Created in on_train_begin once the trainer's optimizer is available.
        self.scheduler: Optional[LambdaLR] = None
        # Number of optimizer steps already taken (restored from checkpoint).
        self.current_iter = 0

    def on_train_begin(self, trainer: 'Trainer', **kwargs):
        """Construct the LambdaLR scheduler, resuming from checkpoint progress."""
        checkpoint = cast(Optional[Checkpoint], kwargs.get('checkpoint'))
        # Fix: the original unconditionally read checkpoint.loss_list, so a
        # missing 'checkpoint' kwarg crashed with AttributeError on None
        # (the cast is a runtime no-op). Start from iteration 0 instead.
        self.current_iter = len(checkpoint.loss_list) if checkpoint is not None else 0

        lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
            **self.schedule_config.get_kwargs()
        )

        # last_epoch = current_iter - 1 makes the scheduler's next step land
        # on current_iter (resume); with current_iter == 0 this is -1, the
        # LambdaLR convention for a fresh start.
        self.scheduler = LambdaLR(
            trainer.train_config.optimizer,
            lambda_scheduler_fn,
            last_epoch=self.current_iter - 1
        )

    def on_step_end(self, trainer: 'Trainer', **kwargs):
        """Advance the learning-rate schedule after each optimizer step."""
        _ = trainer, kwargs  # unused; kept to satisfy the callback interface

        if self.scheduler:
            self.scheduler.step()
            self.current_iter += 1
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,17 @@
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from torch.optim.lr_scheduler import LambdaLR
|
|
||||||
from torch.utils.data import DataLoader, RandomSampler
|
from torch.utils.data import DataLoader, RandomSampler
|
||||||
|
|
||||||
from khaosz.core import ModelParameter, Checkpoint
|
from khaosz.core import ModelParameter, Checkpoint
|
||||||
from khaosz.trainer.strategy import SchedulerFactory, TrainConfig, ScheduleConfig
|
from khaosz.trainer.strategy import TrainConfig, ScheduleConfig
|
||||||
from khaosz.trainer.callback import TrainerCallback, ProgressBarCallback, CheckpointCallback, GradientClippingCallback
|
from khaosz.trainer.callback import (
|
||||||
|
TrainerCallback,
|
||||||
|
ProgressBarCallback,
|
||||||
|
CheckpointCallback,
|
||||||
|
GradientClippingCallback,
|
||||||
|
SchedulerCallback
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
|
|
@ -32,6 +37,7 @@ class Trainer:
|
||||||
ProgressBarCallback(),
|
ProgressBarCallback(),
|
||||||
CheckpointCallback(self.train_config.checkpoint_interval),
|
CheckpointCallback(self.train_config.checkpoint_interval),
|
||||||
GradientClippingCallback(),
|
GradientClippingCallback(),
|
||||||
|
SchedulerCallback(self.schedule_config),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _create_dataloader(self) -> DataLoader:
|
def _create_dataloader(self) -> DataLoader:
|
||||||
|
|
@ -73,16 +79,6 @@ class Trainer:
|
||||||
if "initial_lr" not in group:
|
if "initial_lr" not in group:
|
||||||
group["initial_lr"] = group["lr"]
|
group["initial_lr"] = group["lr"]
|
||||||
|
|
||||||
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
|
|
||||||
**self.schedule_config.get_kwargs()
|
|
||||||
)
|
|
||||||
|
|
||||||
scheduler = LambdaLR(
|
|
||||||
self.train_config.optimizer,
|
|
||||||
lambda_scheduler_fn,
|
|
||||||
last_epoch=current_iter - 1 if train_checkpoint else -1
|
|
||||||
)
|
|
||||||
|
|
||||||
reamining_steps = self.train_config.n_epoch - current_iter
|
reamining_steps = self.train_config.n_epoch - current_iter
|
||||||
total_steps = len(self.train_config.dataset) // self.train_config.batch_size
|
total_steps = len(self.train_config.dataset) // self.train_config.batch_size
|
||||||
remaining_epochs = (reamining_steps + total_steps - 1) // total_steps
|
remaining_epochs = (reamining_steps + total_steps - 1) // total_steps
|
||||||
|
|
@ -114,7 +110,6 @@ class Trainer:
|
||||||
self._call_callbacks('on_step_end', current_iter=current_iter)
|
self._call_callbacks('on_step_end', current_iter=current_iter)
|
||||||
|
|
||||||
current_iter += 1
|
current_iter += 1
|
||||||
scheduler.step()
|
|
||||||
|
|
||||||
self._call_callbacks('on_epoch_end', epoch=epoch, loss_list=self.checkpoint.loss_list)
|
self._call_callbacks('on_epoch_end', epoch=epoch, loss_list=self.checkpoint.loss_list)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue