refactor(trainer): 重构学习率调度器实现并分离配置与工厂逻辑

This commit is contained in:
ViperEkura 2025-10-18 16:42:37 +08:00
parent c51b203fde
commit b67bc9865d
9 changed files with 299 additions and 217 deletions

View File

@ -1,12 +1,18 @@
from khaosz.config.model_config import TransformerConfig from khaosz.config.model_config import TransformerConfig
from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SgdrScheduleConfig
from khaosz.config.train_config import TrainConfig from khaosz.config.train_config import TrainConfig
__all__ = [ __all__ = [
"BaseModelIO", "BaseModelIO",
"ModelParameter", "ModelParameter",
"Checkpoint", "Checkpoint",
"ParameterLoader", "ParameterLoader",
"TransformerConfig", "TransformerConfig",
"TrainConfig" "TrainConfig",
"ScheduleConfig",
"CosineScheduleConfig",
"SgdrScheduleConfig",
] ]

View File

@ -0,0 +1,87 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Literal, Optional
@dataclass
class ScheduleConfig(ABC):
    """Abstract base configuration for learning-rate schedules.

    Concrete subclasses add schedule-specific fields and implement
    get_kwargs(), which packages the config into keyword arguments for the
    scheduler factory.
    """
    schedule_type: str = field(
        default="cosine",
        metadata={
            "help": "Type of learning rate schedule.",
            "choices": ["cosine", "sgdr"]
        }
    )
    warmup_steps: int = field(
        default=1000,
        metadata={"help": "Number of warmup steps."}
    )
    min_rate: float = field(
        default=0.05,
        metadata={"help": "Minimum learning rate multiplier."}
    )

    @abstractmethod
    def get_kwargs(self) -> Dict[str, Any]:
        """Return keyword arguments describing this schedule.

        The returned dict always contains "schedule_type"; the remaining
        keys match the corresponding scheduler's constructor parameters.
        """
        raise NotImplementedError

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            ValueError: if warmup_steps is negative or min_rate is outside [0, 1].
        """
        if self.warmup_steps < 0:
            raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}")
        if not 0 <= self.min_rate <= 1:
            raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")


@dataclass
class CosineScheduleConfig(ScheduleConfig):
    """Cosine-decay schedule with linear warmup."""
    # Optional: there is no sensible universal default, so get_kwargs()
    # fails fast when total_steps was never provided.
    total_steps: Optional[int] = field(
        default=None,
        metadata={"help": "Total training steps for cosine schedule."}
    )
    schedule_type: Literal["cosine"] = "cosine"

    def get_kwargs(self) -> Dict[str, Any]:
        """Return factory kwargs; decay length is total_steps - warmup_steps.

        Raises:
            ValueError: if total_steps is unset, or the configuration is
                otherwise invalid (see validate()).
        """
        if self.total_steps is None:
            raise ValueError("total_steps must be specified for cosine schedule")
        # Fail fast: total_steps <= warmup_steps would otherwise hand a
        # negative decay length to the scheduler downstream.
        self.validate()
        return {
            "schedule_type": self.schedule_type,
            "warmup_steps": self.warmup_steps,
            "lr_decay_steps": self.total_steps - self.warmup_steps,
            "min_rate": self.min_rate
        }

    def validate(self) -> None:
        """Validate; additionally requires total_steps > warmup_steps when set."""
        super().validate()
        if self.total_steps is not None and self.total_steps <= self.warmup_steps:
            raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")


@dataclass
class SgdrScheduleConfig(ScheduleConfig):
    """SGDR (warm-restart) schedule configuration."""
    cycle_length: int = field(
        default=1000,
        metadata={"help": "Length of the first cycle in steps."}
    )
    t_mult: int = field(
        default=2,
        metadata={"help": "Multiplier for cycle length growth."}
    )
    schedule_type: Literal["sgdr"] = "sgdr"

    def get_kwargs(self) -> Dict[str, Any]:
        """Return factory kwargs for the SGDR scheduler.

        Raises:
            ValueError: if the configuration is invalid (see validate()).
        """
        # Fail fast so an invalid cycle_length/t_mult never reaches the scheduler.
        self.validate()
        return {
            "schedule_type": self.schedule_type,
            "warmup_steps": self.warmup_steps,
            "cycle_length": self.cycle_length,
            "min_rate": self.min_rate,
            "t_mult": self.t_mult
        }

    def validate(self) -> None:
        """Validate; additionally requires cycle_length > 0 and t_mult >= 1."""
        super().validate()
        if self.cycle_length <= 0:
            raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
        if self.t_mult < 1:
            raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")

View File

@ -1,10 +1,7 @@
from khaosz.trainer.trainer import Trainer from khaosz.trainer.trainer import Trainer
from khaosz.trainer.strategy import ( from khaosz.trainer.strategy import StrategyFactory
CosineScheduleConfig, from khaosz.trainer.schedule import SchedulerFactory
SgdrScheduleConfig,
StrategyFactory,
SchedulerFactory
)
from khaosz.trainer.train_callback import ( from khaosz.trainer.train_callback import (
TrainCallback, TrainCallback,
ProgressBarCallback, ProgressBarCallback,
@ -15,16 +12,18 @@ from khaosz.trainer.train_callback import (
) )
__all__ = [ __all__ = [
# trainer
"Trainer", "Trainer",
# factory
"StrategyFactory", "StrategyFactory",
"CosineScheduleConfig",
"SgdrScheduleConfig",
"SchedulerFactory", "SchedulerFactory",
# callback # callback
"TrainCallback", "TrainCallback",
"ProgressBarCallback", "ProgressBarCallback",
"CheckpointCallback", "CheckpointCallback",
"TrainCallback",
"SchedulerCallback", "SchedulerCallback",
"StepMonitorCallback" "StepMonitorCallback"
] ]

164
khaosz/trainer/schedule.py Normal file
View File

@ -0,0 +1,164 @@
import math
from abc import abstractmethod, ABC
from typing import Any, Dict, List
from torch.optim.lr_scheduler import LRScheduler
from khaosz.config.schedule_config import ScheduleConfig
class BaseScheduler(LRScheduler, ABC):
    """Abstract base class for learning-rate schedulers in this package.

    Subclasses implement get_lr(). Construction and state (de)serialization
    are inherited from torch's LRScheduler, which already round-trips every
    non-optimizer instance attribute through state_dict()/load_state_dict(),
    so the previous pass-through overrides were removed as no-ops.
    """

    @abstractmethod
    def get_lr(self) -> List[float]:
        """Return the learning rate for each of the optimizer's param groups."""
        raise NotImplementedError


class CosineScheduler(BaseScheduler):
    """Cosine decay with linear warmup, as a PyTorch LRScheduler.

    During warmup the LR multiplier grows linearly toward 1 (floored at
    min_rate); afterwards it follows a half cosine from 1 down to min_rate
    over lr_decay_steps and then stays at min_rate.
    """

    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        lr_decay_steps: int,
        min_rate: float = 0.05,
        last_epoch: int = -1
    ):
        """
        Args:
            optimizer: Wrapped optimizer.
            warmup_steps: Number of linear warmup steps.
            lr_decay_steps: Number of cosine-decay steps after warmup.
            min_rate: Lower bound on the LR multiplier, expected in [0, 1].
            last_epoch: Index of the last step; -1 starts fresh.
        """
        self.warmup_steps = warmup_steps
        self.lr_decay_steps = lr_decay_steps
        self.min_rate = min_rate
        self.total_steps = warmup_steps + lr_decay_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        # Linear warmup, floored at min_rate so the LR never starts at 0.
        if self.last_epoch < self.warmup_steps:
            warmup_factor = max(self.min_rate, self.last_epoch / self.warmup_steps)
            return [base_lr * warmup_factor for base_lr in self.base_lrs]
        # Cosine decay; progress is clamped at 1.0 so the factor settles
        # at min_rate once decay is complete.
        decay_progress = (self.last_epoch - self.warmup_steps) / self.lr_decay_steps
        decay_progress = min(decay_progress, 1.0)
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_progress))
        decay_factor = max(self.min_rate, cosine_decay)
        return [base_lr * decay_factor for base_lr in self.base_lrs]

    def state_dict(self):
        """Return scheduler state, including the schedule hyperparameters."""
        state = super().state_dict()
        state.update({
            'warmup_steps': self.warmup_steps,
            'lr_decay_steps': self.lr_decay_steps,
            'min_rate': self.min_rate,
            'total_steps': self.total_steps,
        })
        return state

    def load_state_dict(self, state_dict):
        """Restore scheduler state without mutating the caller's dict."""
        # Copy first: the previous implementation popped keys out of the
        # caller's mapping, corrupting a checkpoint dict that is re-saved
        # after loading (as TrainContextBuilder.with_scheduler does).
        state_dict = dict(state_dict)
        self.warmup_steps = state_dict.pop('warmup_steps')
        self.lr_decay_steps = state_dict.pop('lr_decay_steps')
        self.min_rate = state_dict.pop('min_rate')
        self.total_steps = state_dict.pop('total_steps')
        super().load_state_dict(state_dict)


class SGDRScheduler(BaseScheduler):
    """SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler.

    After linear warmup, cosine-anneals from 1 down to min_rate within each
    cycle; each restart begins a new cycle whose length is the previous
    one multiplied by t_mult.
    """

    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        cycle_length: int,
        min_rate: float = 0.05,
        t_mult: int = 2,
        last_epoch: int = -1,
    ):
        """
        Args:
            optimizer: Wrapped optimizer.
            warmup_steps: Number of linear warmup steps.
            cycle_length: Length of the first cycle in steps.
            min_rate: Lower bound on the LR multiplier, expected in [0, 1].
            t_mult: Multiplier applied to the cycle length at each restart.
            last_epoch: Index of the last step; -1 starts fresh.
        """
        self.warmup_steps = warmup_steps
        self.cycle_length = cycle_length
        self.min_rate = min_rate
        self.t_mult = t_mult
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Linear warmup, floored at min_rate.
        if self.last_epoch < self.warmup_steps:
            warmup_factor = max(self.min_rate, self.last_epoch / self.warmup_steps)
            return [base_lr * warmup_factor for base_lr in self.base_lrs]
        # SGDR phase.
        steps_since_warmup = self.last_epoch - self.warmup_steps
        # 1. Locate the current cycle and the position within it; cycle
        #    lengths grow geometrically by t_mult.
        current_cycle_length = self.cycle_length
        total_cycles_length = 0
        cycle_num = 0
        while total_cycles_length + current_cycle_length <= steps_since_warmup:
            total_cycles_length += current_cycle_length
            current_cycle_length *= self.t_mult
            cycle_num += 1
        steps_in_cycle = steps_since_warmup - total_cycles_length
        # 2. Cosine annealing within the current cycle, from 1 to min_rate.
        cosine_factor = 0.5 * (1 + math.cos(math.pi * steps_in_cycle / current_cycle_length))
        learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor
        return [base_lr * learning_rate_factor for base_lr in self.base_lrs]

    def state_dict(self):
        """Return scheduler state, including the schedule hyperparameters."""
        state = super().state_dict()
        state.update({
            'warmup_steps': self.warmup_steps,
            'cycle_length': self.cycle_length,
            'min_rate': self.min_rate,
            't_mult': self.t_mult
        })
        return state

    def load_state_dict(self, state_dict):
        """Restore scheduler state without mutating the caller's dict."""
        # Copy first — see CosineScheduler.load_state_dict for rationale.
        state_dict = dict(state_dict)
        self.warmup_steps = state_dict.pop('warmup_steps')
        self.cycle_length = state_dict.pop('cycle_length')
        self.min_rate = state_dict.pop('min_rate')
        self.t_mult = state_dict.pop('t_mult')
        super().load_state_dict(state_dict)


class SchedulerFactory:
    """Factory class for creating learning rate schedulers."""

    @staticmethod
    def load_scheduler(optimizer, scedule_config: "ScheduleConfig") -> BaseScheduler:
        """Build the scheduler described by the config.

        Args:
            optimizer: Optimizer whose learning rate will be scheduled.
            scedule_config: Schedule configuration. (Parameter name keeps
                the historical misspelling for keyword-call compatibility.)

        Returns:
            A CosineScheduler or SGDRScheduler wrapping the optimizer.

        Raises:
            ValueError: if the config is invalid or names an unknown type.
        """
        # Validation previously happened in SchedulerCallback and was lost
        # in the refactor; keep it at the factory boundary so invalid
        # configs fail fast.
        scedule_config.validate()
        kwargs = scedule_config.get_kwargs()
        schedule_type = kwargs.pop("schedule_type")
        if schedule_type == "cosine":
            return CosineScheduler(optimizer, **kwargs)
        if schedule_type == "sgdr":
            return SGDRScheduler(optimizer, **kwargs)
        raise ValueError(f"Unsupported schedule type: {schedule_type}")

View File

@ -1,13 +1,11 @@
import copy import copy
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from typing import Any, Literal, Tuple, Callable, Dict, Union from typing import Any, Tuple, Callable, Dict, Union
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field
def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int): def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int):
@ -167,181 +165,4 @@ class StrategyFactory:
) )
} }
strategy = train_strategy[train_type]() strategy = train_strategy[train_type]()
return strategy return strategy
@dataclass
class ScheduleConfig(ABC):
schedule_type: str = field(
default="cosine",
metadata={
"help": "Type of learning rate schedule.",
"choices": ["cosine", "sgdr"]
}
)
warmup_steps: int = field(
default=1000,
metadata={"help": "Number of warmup steps."}
)
min_rate: float = field(
default=0.05,
metadata={"help": "Minimum learning rate multiplier."}
)
@abstractmethod
def get_kwargs(self) -> Dict[str, Any]:
raise NotImplementedError
def validate(self) -> None:
"""Validate configuration parameters."""
if self.warmup_steps < 0:
raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}")
if not 0 <= self.min_rate <= 1:
raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
@dataclass
class CosineScheduleConfig(ScheduleConfig):
total_steps: int = field(
default=None,
metadata={"help": "Total training steps for cosine schedule."}
)
schedule_type: Literal["cosine"] = "cosine"
def get_kwargs(self) -> Dict[str, Any]:
if self.total_steps is None:
raise ValueError("total_steps must be specified for cosine schedule")
return {
"schedule_type": self.schedule_type,
"warmup_steps": self.warmup_steps,
"lr_decay_steps": self.total_steps - self.warmup_steps,
"min_rate": self.min_rate
}
def validate(self) -> None:
super().validate()
if self.total_steps is not None and self.total_steps <= self.warmup_steps:
raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")
@dataclass
class SgdrScheduleConfig(ScheduleConfig):
cycle_length: int = field(
default=1000,
metadata={"help": "Length of the first cycle in steps."}
)
t_mult: int = field(
default=2,
metadata={"help": "Multiplier for cycle length growth."}
)
schedule_type: Literal["sgdr"] = "sgdr"
def get_kwargs(self) -> Dict[str, Any]:
return {
"schedule_type": self.schedule_type,
"warmup_steps": self.warmup_steps,
"cycle_length": self.cycle_length,
"min_rate": self.min_rate,
"t_mult": self.t_mult
}
def validate(self) -> None:
super().validate()
if self.cycle_length <= 0:
raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
if self.t_mult < 1:
raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
class SchedulerFactory:
"""Factory for creating learning rate schedule functions."""
@staticmethod
def get_sgdr_schedule(
warmup_steps: int,
cycle_length: int,
min_rate: float = 0.05,
t_mult: int = 2
) -> Callable[[int], float]:
"""
Create SGDR (Stochastic Gradient Descent with Warm Restarts) schedule.
Args:
warmup_steps: Number of warmup steps
cycle_length: Length of the first cycle
min_rate: Minimum learning rate multiplier
t_mult: Cycle length multiplier
Returns:
Schedule function that takes current step and returns LR multiplier
"""
def sgdr_schedule(current_step: int) -> float:
# Warmup phase
if current_step < warmup_steps:
return max(min_rate, current_step / warmup_steps)
# SGDR phase
steps_since_warmup = current_step - warmup_steps
# Find current cycle and position within cycle
cycle_start = 0
current_cycle_length = cycle_length
cycle_index = 0
while steps_since_warmup >= cycle_start + current_cycle_length:
cycle_start += current_cycle_length
current_cycle_length *= t_mult
cycle_index += 1
position_in_cycle = steps_since_warmup - cycle_start
progress = position_in_cycle / current_cycle_length
# Cosine annealing within cycle
return max(min_rate, 0.5 * (1 + math.cos(math.pi * progress)))
return sgdr_schedule
@staticmethod
def get_cosine_schedule(
warmup_steps: int,
lr_decay_steps: int,
min_rate: float = 0.05
) -> Callable[[int], float]:
"""
Create cosine decay schedule with warmup.
Args:
warmup_steps: Number of warmup steps
lr_decay_steps: Number of steps for cosine decay after warmup
min_rate: Minimum learning rate multiplier
Returns:
Schedule function that takes current step and returns LR multiplier
"""
def cosine_schedule(current_step: int) -> float:
if current_step < warmup_steps:
# Linear warmup
return max(min_rate, current_step / warmup_steps)
else:
# Cosine decay
decay_progress = (current_step - warmup_steps) / lr_decay_steps
decay_progress = min(decay_progress, 1.0) # Clamp at 1.0
return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * decay_progress)))
return cosine_schedule
@staticmethod
def load_schedule_fn(scedule_config: ScheduleConfig) -> Callable[[int], float]:
kwargs = scedule_config.get_kwargs()
schedule_type = kwargs.pop("schedule_type")
if schedule_type == "cosine":
return SchedulerFactory.get_cosine_schedule(**kwargs)
elif schedule_type == "sgdr":
return SchedulerFactory.get_sgdr_schedule(**kwargs)
else:
raise ValueError(f"Unsupported schedule type: {schedule_type}")

View File

@ -8,7 +8,7 @@ from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR from torch.optim.lr_scheduler import LambdaLR
from typing import List, Optional, Protocol, TYPE_CHECKING from typing import List, Optional, Protocol, TYPE_CHECKING
from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory from khaosz.config import ScheduleConfig
from khaosz.trainer.metric_util import ( from khaosz.trainer.metric_util import (
grad_max, grad_max,
grad_min, grad_min,
@ -78,17 +78,8 @@ class SchedulerCallback(TrainCallback):
for group in trainer.train_config.optimizer.param_groups: for group in trainer.train_config.optimizer.param_groups:
if "initial_lr" not in group: if "initial_lr" not in group:
group["initial_lr"] = group["lr"] group["initial_lr"] = group["lr"]
self.schedule_config.validate()
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
self.schedule_config
)
self.scheduler = LambdaLR( self.scheduler = context.scheduler
trainer.train_config.optimizer,
lambda_scheduler_fn,
last_epoch=context.current_iter - 1
)
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'): def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
_ = trainer, context _ = trainer, context
@ -106,8 +97,8 @@ class CheckpointCallback(TrainCallback):
def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'): def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.current_iter}") save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.current_iter}")
# context.checkpoint.scheduler_state = context.sampler.state_dict()
context.checkpoint.optimizer_state = context.optimizer.state_dict() context.checkpoint.optimizer_state = context.optimizer.state_dict()
context.checkpoint.scheduler_state = context.scheduler.state_dict()
context.checkpoint.save(save_path) context.checkpoint.save(save_path)
self.last_ckpt_iter = context.current_iter self.last_ckpt_iter = context.current_iter

View File

@ -1,10 +1,10 @@
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field, fields
from typing import Optional, Self, TYPE_CHECKING from typing import Optional, Self, TYPE_CHECKING
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from khaosz.config.param_config import Checkpoint from khaosz.config import Checkpoint
from khaosz.data.data_util import ResumeableRandomSampler from khaosz.data import ResumeableRandomSampler
from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory
if TYPE_CHECKING: if TYPE_CHECKING:
from khaosz.trainer.trainer import Trainer from khaosz.trainer.trainer import Trainer
@ -14,7 +14,7 @@ if TYPE_CHECKING:
class TrainContext: class TrainContext:
dataloader: DataLoader = field(default=None) dataloader: DataLoader = field(default=None)
optimizer: Optimizer = field(default=None) optimizer: Optimizer = field(default=None)
scheduler: LRScheduler = field(default=None) scheduler: BaseScheduler = field(default=None)
checkpoint: Checkpoint = field(default=None) checkpoint: Checkpoint = field(default=None)
epoch: int = field(default=0) epoch: int = field(default=0)
current_iter: int = field(default=0) current_iter: int = field(default=0)
@ -54,8 +54,20 @@ class TrainContextBuilder:
return self return self
def with_scheduler(self) -> Self: def with_scheduler(self) -> Self:
return self # NOTE(review): does the builder call order matter here? confirm
optimizer = self.trainer.train_config.optimizer
schedule_config = self.trainer.schedule_config
scheduler = SchedulerFactory.load_scheduler(optimizer, schedule_config)
if self._context.checkpoint and self._context.checkpoint.scheduler_state:
scheduler.load_state_dict(self._context.checkpoint.scheduler_state)
self._context.scheduler = scheduler
if self._context.checkpoint:
self._context.checkpoint.scheduler_state = scheduler.state_dict()
return self
def with_dataloader(self) -> Self: def with_dataloader(self) -> Self:
resumeable_sampler = ResumeableRandomSampler( resumeable_sampler = ResumeableRandomSampler(

View File

@ -1,9 +1,11 @@
import logging import logging
from typing import Optional, List from typing import Optional, List
from khaosz.config import (
from khaosz.config import ModelParameter, Checkpoint ModelParameter,
from khaosz.trainer.strategy import ScheduleConfig Checkpoint,
from khaosz.config.train_config import TrainConfig ScheduleConfig,
TrainConfig
)
from khaosz.trainer.train_callback import ( from khaosz.trainer.train_callback import (
TrainCallback, TrainCallback,
ProgressBarCallback, ProgressBarCallback,
@ -15,6 +17,7 @@ from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Trainer: class Trainer:
def __init__( def __init__(
self, self,
@ -36,7 +39,7 @@ class Trainer:
SchedulerCallback(self.schedule_config), SchedulerCallback(self.schedule_config),
] ]
def _build_train_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return (TrainContextBuilder(self) return (TrainContextBuilder(self)
.with_checkpoint(checkpoint) .with_checkpoint(checkpoint)
.with_optimizer() .with_optimizer()
@ -51,8 +54,7 @@ class Trainer:
method(self, context) method(self, context)
def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint: def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
context = self._build_train_context(checkpoint) context = self._build_context(checkpoint)
self._call_callbacks('on_train_begin', context) self._call_callbacks('on_train_begin', context)
try: try:

View File

@ -3,9 +3,9 @@ import argparse
import torch import torch
from torch.optim import AdamW from torch.optim import AdamW
from khaosz.core import ParameterLoader, Checkpoint from khaosz.config import ParameterLoader, Checkpoint, TrainConfig, CosineScheduleConfig
from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig from khaosz.trainer import Trainer, StrategyFactory
from khaosz.trainer import StrategyFactory from khaosz.data import DatasetLoader
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__)) PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))