refactor(trainer): 重构学习率调度器实现并分离配置与工厂逻辑
This commit is contained in:
parent
c51b203fde
commit
b67bc9865d
|
|
@ -1,12 +1,18 @@
|
||||||
from khaosz.config.model_config import TransformerConfig
|
from khaosz.config.model_config import TransformerConfig
|
||||||
from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
|
from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
|
||||||
|
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SgdrScheduleConfig
|
||||||
from khaosz.config.train_config import TrainConfig
|
from khaosz.config.train_config import TrainConfig
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseModelIO",
|
"BaseModelIO",
|
||||||
"ModelParameter",
|
"ModelParameter",
|
||||||
"Checkpoint",
|
"Checkpoint",
|
||||||
"ParameterLoader",
|
"ParameterLoader",
|
||||||
"TransformerConfig",
|
"TransformerConfig",
|
||||||
"TrainConfig"
|
"TrainConfig",
|
||||||
|
|
||||||
|
"ScheduleConfig",
|
||||||
|
"CosineScheduleConfig",
|
||||||
|
"SgdrScheduleConfig",
|
||||||
]
|
]
|
||||||
|
|
@ -0,0 +1,87 @@
|
||||||
|
from typing import Any, Dict, Literal, Optional
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ScheduleConfig(ABC):
    """Abstract base configuration for learning-rate schedules.

    Concrete subclasses describe one schedule family and expose the
    keyword arguments consumed by the scheduler factory via
    :meth:`get_kwargs`.
    """

    # Which schedule implementation to build; subclasses pin this
    # down with a Literal override.
    schedule_type: str = field(
        default="cosine",
        metadata={
            "help": "Type of learning rate schedule.",
            "choices": ["cosine", "sgdr"]
        }
    )
    # Steps of linear warmup before the main schedule takes over.
    warmup_steps: int = field(
        default=1000,
        metadata={"help": "Number of warmup steps."}
    )
    # Lower bound on the LR multiplier (fraction of the base LR).
    min_rate: float = field(
        default=0.05,
        metadata={"help": "Minimum learning rate multiplier."}
    )

    @abstractmethod
    def get_kwargs(self) -> Dict[str, Any]:
        """Return the kwargs used to construct the matching scheduler."""
        raise NotImplementedError

    def validate(self) -> None:
        """Validate configuration parameters."""
        if self.warmup_steps < 0:
            raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}")
        if not 0 <= self.min_rate <= 1:
            raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CosineScheduleConfig(ScheduleConfig):
    """Configuration for a warmup + cosine-decay schedule.

    ``total_steps`` covers warmup and decay together, so the decay phase
    lasts ``total_steps - warmup_steps`` steps.
    """

    # Total number of training steps. Annotated Optional because the
    # default really is None; get_kwargs() rejects an unset value.
    total_steps: Optional[int] = field(
        default=None,
        metadata={"help": "Total training steps for cosine schedule."}
    )
    schedule_type: Literal["cosine"] = "cosine"

    def get_kwargs(self) -> Dict[str, Any]:
        """Return scheduler constructor kwargs.

        Raises:
            ValueError: if ``total_steps`` was never specified.
        """
        if self.total_steps is None:
            raise ValueError("total_steps must be specified for cosine schedule")

        return {
            "schedule_type": self.schedule_type,
            "warmup_steps": self.warmup_steps,
            "lr_decay_steps": self.total_steps - self.warmup_steps,
            "min_rate": self.min_rate
        }

    def validate(self) -> None:
        """Validate base fields plus the warmup/total step relationship."""
        super().validate()
        if self.total_steps is not None and self.total_steps <= self.warmup_steps:
            raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SgdrScheduleConfig(ScheduleConfig):
    """Configuration for an SGDR (warm-restarts) schedule."""

    # Steps in the first restart cycle.
    cycle_length: int = field(
        default=1000,
        metadata={"help": "Length of the first cycle in steps."}
    )
    # Each cycle is t_mult times longer than the previous one.
    t_mult: int = field(
        default=2,
        metadata={"help": "Multiplier for cycle length growth."}
    )
    schedule_type: Literal["sgdr"] = "sgdr"

    def get_kwargs(self) -> Dict[str, Any]:
        """Return scheduler constructor kwargs."""
        return {
            "schedule_type": self.schedule_type,
            "warmup_steps": self.warmup_steps,
            "cycle_length": self.cycle_length,
            "min_rate": self.min_rate,
            "t_mult": self.t_mult
        }

    def validate(self) -> None:
        """Validate base fields plus the SGDR cycle parameters."""
        super().validate()
        if self.cycle_length <= 0:
            raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
        if self.t_mult < 1:
            raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
|
||||||
|
|
@ -1,10 +1,7 @@
|
||||||
from khaosz.trainer.trainer import Trainer
|
from khaosz.trainer.trainer import Trainer
|
||||||
from khaosz.trainer.strategy import (
|
from khaosz.trainer.strategy import StrategyFactory
|
||||||
CosineScheduleConfig,
|
from khaosz.trainer.schedule import SchedulerFactory
|
||||||
SgdrScheduleConfig,
|
|
||||||
StrategyFactory,
|
|
||||||
SchedulerFactory
|
|
||||||
)
|
|
||||||
from khaosz.trainer.train_callback import (
|
from khaosz.trainer.train_callback import (
|
||||||
TrainCallback,
|
TrainCallback,
|
||||||
ProgressBarCallback,
|
ProgressBarCallback,
|
||||||
|
|
@ -15,16 +12,18 @@ from khaosz.trainer.train_callback import (
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# trainer
|
||||||
"Trainer",
|
"Trainer",
|
||||||
|
|
||||||
|
# factory
|
||||||
"StrategyFactory",
|
"StrategyFactory",
|
||||||
"CosineScheduleConfig",
|
|
||||||
"SgdrScheduleConfig",
|
|
||||||
"SchedulerFactory",
|
"SchedulerFactory",
|
||||||
|
|
||||||
# callback
|
# callback
|
||||||
"TrainCallback",
|
"TrainCallback",
|
||||||
"ProgressBarCallback",
|
"ProgressBarCallback",
|
||||||
"CheckpointCallback",
|
"CheckpointCallback",
|
||||||
|
"TrainCallback",
|
||||||
"SchedulerCallback",
|
"SchedulerCallback",
|
||||||
"StepMonitorCallback"
|
"StepMonitorCallback"
|
||||||
]
|
]
|
||||||
|
|
@ -0,0 +1,164 @@
|
||||||
|
import math
|
||||||
|
from abc import abstractmethod, ABC
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
from khaosz.config.schedule_config import ScheduleConfig
|
||||||
|
|
||||||
|
|
||||||
|
class BaseScheduler(LRScheduler, ABC):
    """
    Base scheduler class for all other schedulers.

    Thin abstract layer over torch's ``LRScheduler``: it forces
    subclasses to implement ``get_lr`` and forwards state handling
    to the parent class unchanged.
    """

    def __init__(self, optimizer, last_epoch: int = -1):
        # LRScheduler.__init__ performs the initial step()/get_lr() call.
        super().__init__(optimizer, last_epoch)

    @abstractmethod
    def get_lr(self) -> List[float]:
        """Compute the per-parameter-group learning rates."""
        raise NotImplementedError

    def state_dict(self) -> Dict[str, Any]:
        # Delegate to LRScheduler (everything in __dict__ except the optimizer).
        return super().state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]):
        # Delegate restoration to LRScheduler.
        super().load_state_dict(state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
class CosineScheduler(BaseScheduler):
    """
    Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler.

    The LR multiplier ramps linearly over ``warmup_steps`` (floored at
    ``min_rate``), then follows a half-cosine decay toward ``min_rate``
    over ``lr_decay_steps``, after which it stays at ``min_rate``.
    """

    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        lr_decay_steps: int,
        min_rate: float = 0.05,
        last_epoch: int = -1
    ):
        # Set attributes before super().__init__: LRScheduler's __init__
        # invokes step() -> get_lr(), which reads them.
        self.warmup_steps = warmup_steps
        self.lr_decay_steps = lr_decay_steps
        self.min_rate = min_rate
        self.total_steps = warmup_steps + lr_decay_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """Return per-group learning rates for the current step."""
        # warmup: linear ramp, floored at min_rate
        if self.last_epoch < self.warmup_steps:
            warmup_factor = max(self.min_rate, self.last_epoch / self.warmup_steps)
            return [base_lr * warmup_factor for base_lr in self.base_lrs]

        # cosine decay, progress clamped at 1.0 so late steps hold min_rate
        decay_progress = (self.last_epoch - self.warmup_steps) / self.lr_decay_steps
        decay_progress = min(decay_progress, 1.0)
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_progress))
        decay_factor = max(self.min_rate, cosine_decay)
        return [base_lr * decay_factor for base_lr in self.base_lrs]

    def state_dict(self):
        """Return scheduler state including schedule hyperparameters."""
        state = super().state_dict()
        state.update({
            'warmup_steps': self.warmup_steps,
            'lr_decay_steps': self.lr_decay_steps,
            'min_rate': self.min_rate,
            'total_steps': self.total_steps,
        })
        return state

    def load_state_dict(self, state_dict):
        """Restore scheduler state.

        Operates on a shallow copy so the caller's dict is not mutated
        (the previous implementation popped keys from it in place).
        """
        state_dict = dict(state_dict)
        self.warmup_steps = state_dict.pop('warmup_steps')
        self.lr_decay_steps = state_dict.pop('lr_decay_steps')
        self.min_rate = state_dict.pop('min_rate')
        self.total_steps = state_dict.pop('total_steps')
        super().load_state_dict(state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
class SGDRScheduler(BaseScheduler):
    """
    SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler,
    with a linear warmup phase followed by cosine-annealed cycles whose
    lengths grow by ``t_mult`` after each restart.
    """

    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        cycle_length: int,
        min_rate: float = 0.05,
        t_mult: int = 2,
        last_epoch: int = -1,
    ):
        # Set attributes before super().__init__: LRScheduler's __init__
        # invokes step() -> get_lr(), which reads them.
        self.warmup_steps = warmup_steps
        self.cycle_length = cycle_length
        self.min_rate = min_rate
        self.t_mult = t_mult

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return per-group learning rates for the current step."""
        # warmup: linear ramp, floored at min_rate
        if self.last_epoch < self.warmup_steps:
            warmup_factor = max(self.min_rate, self.last_epoch / self.warmup_steps)
            return [base_lr * warmup_factor for base_lr in self.base_lrs]

        # SGDR
        steps_since_warmup = self.last_epoch - self.warmup_steps

        # 1. Calculate current cycle and position within cycle
        current_cycle_length = self.cycle_length
        total_cycles_length = 0
        cycle_num = 0

        while total_cycles_length + current_cycle_length <= steps_since_warmup:
            total_cycles_length += current_cycle_length
            current_cycle_length *= self.t_mult
            cycle_num += 1

        steps_in_cycle = steps_since_warmup - total_cycles_length

        # 2. Cosine annealing within the current cycle; the factor spans
        # [min_rate, 1] because it interpolates rather than clamps.
        cosine_factor = 0.5 * (1 + math.cos(math.pi * steps_in_cycle / current_cycle_length))
        learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor

        return [base_lr * learning_rate_factor for base_lr in self.base_lrs]

    def state_dict(self):
        """Returns the state of the scheduler as a dict."""
        state = super().state_dict()
        state.update({
            'warmup_steps': self.warmup_steps,
            'cycle_length': self.cycle_length,
            'min_rate': self.min_rate,
            't_mult': self.t_mult
        })
        return state

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Operates on a shallow copy so the caller's dict is not mutated
        (the previous implementation popped keys from it in place).
        """
        state_dict = dict(state_dict)
        self.warmup_steps = state_dict.pop('warmup_steps')
        self.cycle_length = state_dict.pop('cycle_length')
        self.min_rate = state_dict.pop('min_rate')
        self.t_mult = state_dict.pop('t_mult')
        super().load_state_dict(state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerFactory:
    """
    Factory class for creating learning rate schedulers.
    """

    @staticmethod
    def load_scheduler(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
        """Build a scheduler from its configuration.

        Fixes the previous parameter-name typo (``scedule_config``); the
        known call site passes this argument positionally.

        Args:
            optimizer: Optimizer whose learning rate will be scheduled.
            schedule_config: Config naming the schedule type and carrying
                the matching constructor kwargs.

        Raises:
            ValueError: if the configured schedule type is unknown.
        """
        kwargs = schedule_config.get_kwargs()
        schedule_type = kwargs.pop("schedule_type")

        # Dispatch table keeps the mapping in one obvious place.
        scheduler_classes = {
            "cosine": CosineScheduler,
            "sgdr": SGDRScheduler,
        }
        try:
            scheduler_cls = scheduler_classes[schedule_type]
        except KeyError:
            raise ValueError(f"Unsupported schedule type: {schedule_type}") from None
        return scheduler_cls(optimizer, **kwargs)
|
||||||
|
|
||||||
|
|
@ -1,13 +1,11 @@
|
||||||
import copy
|
import copy
|
||||||
import math
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import Any, Literal, Tuple, Callable, Dict, Union
|
from typing import Any, Tuple, Callable, Dict, Union
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
|
||||||
|
|
||||||
|
|
||||||
def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int):
|
def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int):
|
||||||
|
|
@ -167,181 +165,4 @@ class StrategyFactory:
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
strategy = train_strategy[train_type]()
|
strategy = train_strategy[train_type]()
|
||||||
return strategy
|
return strategy
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ScheduleConfig(ABC):
|
|
||||||
schedule_type: str = field(
|
|
||||||
default="cosine",
|
|
||||||
metadata={
|
|
||||||
"help": "Type of learning rate schedule.",
|
|
||||||
"choices": ["cosine", "sgdr"]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
warmup_steps: int = field(
|
|
||||||
default=1000,
|
|
||||||
metadata={"help": "Number of warmup steps."}
|
|
||||||
)
|
|
||||||
min_rate: float = field(
|
|
||||||
default=0.05,
|
|
||||||
metadata={"help": "Minimum learning rate multiplier."}
|
|
||||||
)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_kwargs(self) -> Dict[str, Any]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def validate(self) -> None:
|
|
||||||
"""Validate configuration parameters."""
|
|
||||||
if self.warmup_steps < 0:
|
|
||||||
raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}")
|
|
||||||
if not 0 <= self.min_rate <= 1:
|
|
||||||
raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CosineScheduleConfig(ScheduleConfig):
|
|
||||||
total_steps: int = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Total training steps for cosine schedule."}
|
|
||||||
)
|
|
||||||
schedule_type: Literal["cosine"] = "cosine"
|
|
||||||
|
|
||||||
def get_kwargs(self) -> Dict[str, Any]:
|
|
||||||
if self.total_steps is None:
|
|
||||||
raise ValueError("total_steps must be specified for cosine schedule")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"schedule_type": self.schedule_type,
|
|
||||||
"warmup_steps": self.warmup_steps,
|
|
||||||
"lr_decay_steps": self.total_steps - self.warmup_steps,
|
|
||||||
"min_rate": self.min_rate
|
|
||||||
}
|
|
||||||
|
|
||||||
def validate(self) -> None:
|
|
||||||
super().validate()
|
|
||||||
if self.total_steps is not None and self.total_steps <= self.warmup_steps:
|
|
||||||
raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SgdrScheduleConfig(ScheduleConfig):
|
|
||||||
cycle_length: int = field(
|
|
||||||
default=1000,
|
|
||||||
metadata={"help": "Length of the first cycle in steps."}
|
|
||||||
)
|
|
||||||
t_mult: int = field(
|
|
||||||
default=2,
|
|
||||||
metadata={"help": "Multiplier for cycle length growth."}
|
|
||||||
)
|
|
||||||
schedule_type: Literal["sgdr"] = "sgdr"
|
|
||||||
|
|
||||||
def get_kwargs(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"schedule_type": self.schedule_type,
|
|
||||||
"warmup_steps": self.warmup_steps,
|
|
||||||
"cycle_length": self.cycle_length,
|
|
||||||
"min_rate": self.min_rate,
|
|
||||||
"t_mult": self.t_mult
|
|
||||||
}
|
|
||||||
|
|
||||||
def validate(self) -> None:
|
|
||||||
super().validate()
|
|
||||||
if self.cycle_length <= 0:
|
|
||||||
raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
|
|
||||||
if self.t_mult < 1:
|
|
||||||
raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
|
|
||||||
|
|
||||||
|
|
||||||
class SchedulerFactory:
|
|
||||||
"""Factory for creating learning rate schedule functions."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_sgdr_schedule(
|
|
||||||
warmup_steps: int,
|
|
||||||
cycle_length: int,
|
|
||||||
min_rate: float = 0.05,
|
|
||||||
t_mult: int = 2
|
|
||||||
) -> Callable[[int], float]:
|
|
||||||
"""
|
|
||||||
Create SGDR (Stochastic Gradient Descent with Warm Restarts) schedule.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
warmup_steps: Number of warmup steps
|
|
||||||
cycle_length: Length of the first cycle
|
|
||||||
min_rate: Minimum learning rate multiplier
|
|
||||||
t_mult: Cycle length multiplier
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Schedule function that takes current step and returns LR multiplier
|
|
||||||
"""
|
|
||||||
|
|
||||||
def sgdr_schedule(current_step: int) -> float:
|
|
||||||
# Warmup phase
|
|
||||||
if current_step < warmup_steps:
|
|
||||||
return max(min_rate, current_step / warmup_steps)
|
|
||||||
|
|
||||||
# SGDR phase
|
|
||||||
steps_since_warmup = current_step - warmup_steps
|
|
||||||
|
|
||||||
# Find current cycle and position within cycle
|
|
||||||
cycle_start = 0
|
|
||||||
current_cycle_length = cycle_length
|
|
||||||
cycle_index = 0
|
|
||||||
|
|
||||||
while steps_since_warmup >= cycle_start + current_cycle_length:
|
|
||||||
cycle_start += current_cycle_length
|
|
||||||
current_cycle_length *= t_mult
|
|
||||||
cycle_index += 1
|
|
||||||
|
|
||||||
position_in_cycle = steps_since_warmup - cycle_start
|
|
||||||
progress = position_in_cycle / current_cycle_length
|
|
||||||
|
|
||||||
# Cosine annealing within cycle
|
|
||||||
return max(min_rate, 0.5 * (1 + math.cos(math.pi * progress)))
|
|
||||||
|
|
||||||
return sgdr_schedule
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_cosine_schedule(
|
|
||||||
warmup_steps: int,
|
|
||||||
lr_decay_steps: int,
|
|
||||||
min_rate: float = 0.05
|
|
||||||
) -> Callable[[int], float]:
|
|
||||||
"""
|
|
||||||
Create cosine decay schedule with warmup.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
warmup_steps: Number of warmup steps
|
|
||||||
lr_decay_steps: Number of steps for cosine decay after warmup
|
|
||||||
min_rate: Minimum learning rate multiplier
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Schedule function that takes current step and returns LR multiplier
|
|
||||||
"""
|
|
||||||
|
|
||||||
def cosine_schedule(current_step: int) -> float:
|
|
||||||
if current_step < warmup_steps:
|
|
||||||
# Linear warmup
|
|
||||||
return max(min_rate, current_step / warmup_steps)
|
|
||||||
else:
|
|
||||||
# Cosine decay
|
|
||||||
decay_progress = (current_step - warmup_steps) / lr_decay_steps
|
|
||||||
decay_progress = min(decay_progress, 1.0) # Clamp at 1.0
|
|
||||||
return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * decay_progress)))
|
|
||||||
|
|
||||||
return cosine_schedule
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load_schedule_fn(scedule_config: ScheduleConfig) -> Callable[[int], float]:
|
|
||||||
kwargs = scedule_config.get_kwargs()
|
|
||||||
schedule_type = kwargs.pop("schedule_type")
|
|
||||||
|
|
||||||
if schedule_type == "cosine":
|
|
||||||
return SchedulerFactory.get_cosine_schedule(**kwargs)
|
|
||||||
elif schedule_type == "sgdr":
|
|
||||||
return SchedulerFactory.get_sgdr_schedule(**kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported schedule type: {schedule_type}")
|
|
||||||
|
|
||||||
|
|
@ -8,7 +8,7 @@ from torch.nn.utils import clip_grad_norm_
|
||||||
from torch.optim.lr_scheduler import LambdaLR
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
from typing import List, Optional, Protocol, TYPE_CHECKING
|
from typing import List, Optional, Protocol, TYPE_CHECKING
|
||||||
|
|
||||||
from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
|
from khaosz.config import ScheduleConfig
|
||||||
from khaosz.trainer.metric_util import (
|
from khaosz.trainer.metric_util import (
|
||||||
grad_max,
|
grad_max,
|
||||||
grad_min,
|
grad_min,
|
||||||
|
|
@ -78,17 +78,8 @@ class SchedulerCallback(TrainCallback):
|
||||||
for group in trainer.train_config.optimizer.param_groups:
|
for group in trainer.train_config.optimizer.param_groups:
|
||||||
if "initial_lr" not in group:
|
if "initial_lr" not in group:
|
||||||
group["initial_lr"] = group["lr"]
|
group["initial_lr"] = group["lr"]
|
||||||
|
|
||||||
self.schedule_config.validate()
|
|
||||||
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
|
|
||||||
self.schedule_config
|
|
||||||
)
|
|
||||||
|
|
||||||
self.scheduler = LambdaLR(
|
self.scheduler = context.scheduler
|
||||||
trainer.train_config.optimizer,
|
|
||||||
lambda_scheduler_fn,
|
|
||||||
last_epoch=context.current_iter - 1
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||||
_ = trainer, context
|
_ = trainer, context
|
||||||
|
|
@ -106,8 +97,8 @@ class CheckpointCallback(TrainCallback):
|
||||||
|
|
||||||
def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
|
def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||||
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.current_iter}")
|
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.current_iter}")
|
||||||
# context.checkpoint.scheduler_state = context.sampler.state_dict()
|
|
||||||
context.checkpoint.optimizer_state = context.optimizer.state_dict()
|
context.checkpoint.optimizer_state = context.optimizer.state_dict()
|
||||||
|
context.checkpoint.scheduler_state = context.scheduler.state_dict()
|
||||||
context.checkpoint.save(save_path)
|
context.checkpoint.save(save_path)
|
||||||
self.last_ckpt_iter = context.current_iter
|
self.last_ckpt_iter = context.current_iter
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, field, fields
|
||||||
from typing import Optional, Self, TYPE_CHECKING
|
from typing import Optional, Self, TYPE_CHECKING
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from khaosz.config.param_config import Checkpoint
|
from khaosz.config import Checkpoint
|
||||||
from khaosz.data.data_util import ResumeableRandomSampler
|
from khaosz.data import ResumeableRandomSampler
|
||||||
|
from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from khaosz.trainer.trainer import Trainer
|
from khaosz.trainer.trainer import Trainer
|
||||||
|
|
@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
||||||
class TrainContext:
|
class TrainContext:
|
||||||
dataloader: DataLoader = field(default=None)
|
dataloader: DataLoader = field(default=None)
|
||||||
optimizer: Optimizer = field(default=None)
|
optimizer: Optimizer = field(default=None)
|
||||||
scheduler: LRScheduler = field(default=None)
|
scheduler: BaseScheduler = field(default=None)
|
||||||
checkpoint: Checkpoint = field(default=None)
|
checkpoint: Checkpoint = field(default=None)
|
||||||
epoch: int = field(default=0)
|
epoch: int = field(default=0)
|
||||||
current_iter: int = field(default=0)
|
current_iter: int = field(default=0)
|
||||||
|
|
@ -54,8 +54,20 @@ class TrainContextBuilder:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_scheduler(self) -> Self:
|
def with_scheduler(self) -> Self:
|
||||||
return self
|
# the build order has any problem ?
|
||||||
|
optimizer = self.trainer.train_config.optimizer
|
||||||
|
schedule_config = self.trainer.schedule_config
|
||||||
|
scheduler = SchedulerFactory.load_scheduler(optimizer, schedule_config)
|
||||||
|
|
||||||
|
if self._context.checkpoint and self._context.checkpoint.scheduler_state:
|
||||||
|
scheduler.load_state_dict(self._context.checkpoint.scheduler_state)
|
||||||
|
|
||||||
|
self._context.scheduler = scheduler
|
||||||
|
|
||||||
|
if self._context.checkpoint:
|
||||||
|
self._context.checkpoint.scheduler_state = scheduler.state_dict()
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
def with_dataloader(self) -> Self:
|
def with_dataloader(self) -> Self:
|
||||||
resumeable_sampler = ResumeableRandomSampler(
|
resumeable_sampler = ResumeableRandomSampler(
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
from khaosz.config import (
|
||||||
from khaosz.config import ModelParameter, Checkpoint
|
ModelParameter,
|
||||||
from khaosz.trainer.strategy import ScheduleConfig
|
Checkpoint,
|
||||||
from khaosz.config.train_config import TrainConfig
|
ScheduleConfig,
|
||||||
|
TrainConfig
|
||||||
|
)
|
||||||
from khaosz.trainer.train_callback import (
|
from khaosz.trainer.train_callback import (
|
||||||
TrainCallback,
|
TrainCallback,
|
||||||
ProgressBarCallback,
|
ProgressBarCallback,
|
||||||
|
|
@ -15,6 +17,7 @@ from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -36,7 +39,7 @@ class Trainer:
|
||||||
SchedulerCallback(self.schedule_config),
|
SchedulerCallback(self.schedule_config),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _build_train_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||||
return (TrainContextBuilder(self)
|
return (TrainContextBuilder(self)
|
||||||
.with_checkpoint(checkpoint)
|
.with_checkpoint(checkpoint)
|
||||||
.with_optimizer()
|
.with_optimizer()
|
||||||
|
|
@ -51,8 +54,7 @@ class Trainer:
|
||||||
method(self, context)
|
method(self, context)
|
||||||
|
|
||||||
def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
||||||
context = self._build_train_context(checkpoint)
|
context = self._build_context(checkpoint)
|
||||||
|
|
||||||
self._call_callbacks('on_train_begin', context)
|
self._call_callbacks('on_train_begin', context)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
6
train.py
6
train.py
|
|
@ -3,9 +3,9 @@ import argparse
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
from khaosz.core import ParameterLoader, Checkpoint
|
from khaosz.config import ParameterLoader, Checkpoint, TrainConfig, CosineScheduleConfig
|
||||||
from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig
|
from khaosz.trainer import Trainer, StrategyFactory
|
||||||
from khaosz.trainer import StrategyFactory
|
from khaosz.data import DatasetLoader
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
|
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue