feat(trainer): 重构训练配置与策略工厂引入

This commit is contained in:
ViperEkura 2025-09-28 21:39:48 +08:00
parent 2dc7b5bda8
commit fa43ed2943
6 changed files with 188 additions and 106 deletions

View File

@ -18,9 +18,13 @@ from khaosz.core.generator import (
RetrievalGenerator, RetrievalGenerator,
EmbeddingEncoder EmbeddingEncoder
) )
from khaosz.trainer.trainer import Trainer from khaosz.trainer import (
from khaosz.trainer.dataset import SeqDataset, SftDataset, DpoDataset, BaseDataset Trainer,
DatasetLoader,
TrainConfig,
StrategyFactory,
SchedulerFactory
)
__all__ = [ __all__ = [
# model # model
@ -40,10 +44,10 @@ __all__ = [
# trainer # trainer
"Trainer", "Trainer",
"SeqDataset", "DatasetLoader",
"SftDataset", "TrainConfig",
"DpoDataset", "StrategyFactory",
"BaseDataset", "SchedulerFactory",
# utils # utils
"Retriever", "Retriever",

View File

@ -1,6 +1,12 @@
from khaosz.trainer.dataset import DatasetLoader from khaosz.trainer.dataset import DatasetLoader
from khaosz.trainer.trainer import Trainer from khaosz.trainer.trainer import Trainer
from khaosz.trainer.strategy import TrainConfig, CosineScheduleConfig, SgdrScheduleConfig from khaosz.trainer.strategy import (
TrainConfig,
CosineScheduleConfig,
SgdrScheduleConfig,
StrategyFactory,
SchedulerFactory
)
__all__ = [ __all__ = [
"DatasetLoader", "DatasetLoader",
@ -8,4 +14,6 @@ __all__ = [
"TrainConfig", "TrainConfig",
"CosineScheduleConfig", "CosineScheduleConfig",
"SgdrScheduleConfig", "SgdrScheduleConfig",
"StrategyFactory",
"SchedulerFactory"
] ]

View File

@ -257,7 +257,7 @@ class DatasetLoader:
bos_token_id=kwargs.get("bos_token_id"), bos_token_id=kwargs.get("bos_token_id"),
eos_token_id=kwargs.get("eos_token_id"), eos_token_id=kwargs.get("eos_token_id"),
user_token_id=kwargs.get("user_token_id"), user_token_id=kwargs.get("user_token_id"),
multi_turn=kwargs.get("multi_turn", False) multi_turn=kwargs.get("multi_turn")
), ),
"dpo": lambda m_len, device: DpoDataset(m_len, device=device), "dpo": lambda m_len, device: DpoDataset(m_len, device=device),
} }

View File

@ -177,9 +177,10 @@ class StrategyFactory:
@dataclass @dataclass
class TrainConfig: class TrainConfig:
train_type: str = field(
default_factory=["seq", "sft", "dpo"], strategy: BaseStrategy = field(
metadata={"help": "Type of training."} default=None,
metadata={"help": "Training strategy."}
) )
dataset: Dataset = field( dataset: Dataset = field(
default=None, default=None,
@ -217,10 +218,6 @@ class TrainConfig:
default=3407, default=3407,
metadata={"help": "Random seed."} metadata={"help": "Random seed."}
) )
dpo_beta: float = field(
default=0.1,
metadata={"help": "DPO beta."}
)
def get_kwargs(self)-> Dict[str, Any]: def get_kwargs(self)-> Dict[str, Any]:
config_dict = asdict(self) config_dict = asdict(self)
@ -228,117 +225,191 @@ class TrainConfig:
@dataclass @dataclass
class ScheduleConfig: class ScheduleConfig(ABC):
schedule_type: str = field( schedule_type: str = field(
default_factory=["cosine", "sgdr"], default="cosine",
metadata={"help": "Type of learning rate schedule."} metadata={
"help": "Type of learning rate schedule.",
"choices": ["cosine", "sgdr"]
}
) )
warning_step: int = field( warmup_steps: int = field(
default=1000, default=1000,
metadata= {"help": "Warning up step."} metadata={"help": "Number of warmup steps."}
) )
min_rate: float = field(
default=0.05,
metadata={"help": "Minimum learning rate multiplier."}
)
@abstractmethod @abstractmethod
def get_kwargs(self) -> Dict[str, Any]: def get_kwargs(self) -> Dict[str, Any]:
raise NotImplementedError raise NotImplementedError
def validate(self) -> None:
"""Validate configuration parameters."""
if self.warmup_steps < 0:
raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}")
if not 0 <= self.min_rate <= 1:
raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
@dataclass @dataclass
class CosineScheduleConfig(ScheduleConfig): class CosineScheduleConfig(ScheduleConfig):
total_iters: int = field( total_steps: int = field( # 更准确的命名
default=None, default=None,
metadata={"help": "Total iterations for cosine schedule."} metadata={"help": "Total training steps for cosine schedule."}
)
min_rate: float = field(
default=0.05,
metadata={"help": "Minimum rate for cosine schedule."}
) )
schedule_type: Literal["cosine"] = "cosine" schedule_type: Literal["cosine"] = "cosine"
def get_kwargs(self) -> Dict[str, Any]: def get_kwargs(self) -> Dict[str, Any]:
if self.total_steps is None:
raise ValueError("total_steps must be specified for cosine schedule")
return { return {
"schedule_type": self.schedule_type, "schedule_type": self.schedule_type,
"warning_step": self.warning_step, "warmup_steps": self.warmup_steps,
"lr_decay_iters": self.total_iters - self.warning_step, "lr_decay_steps": self.total_steps - self.warmup_steps,
"min_rate": self.min_rate "min_rate": self.min_rate
} }
def validate(self) -> None:
super().validate()
if self.total_steps is not None and self.total_steps <= self.warmup_steps:
raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")
@dataclass @dataclass
class SgdrScheduleConfig(ScheduleConfig): class SgdrScheduleConfig(ScheduleConfig):
cycle_length: int = field( cycle_length: int = field(
default=1000, default=1000,
metadata={"help": "Cycle length for sgdr schedule."} metadata={"help": "Length of the first cycle in steps."}
) )
min_rate: float = field( t_mult: int = field(
default=0.05,
metadata={"help": "Minimum rate for sgdr schedule."}
)
T_mult: int = field(
default=2, default=2,
metadata={"help": "T_mult for sgdr schedule."} metadata={"help": "Multiplier for cycle length growth."}
) )
schedule_type: Literal["sgdr"] = "sgdr" schedule_type: Literal["sgdr"] = "sgdr"
def get_kwargs(self) -> Dict[str, Any]: def get_kwargs(self) -> Dict[str, Any]:
return { return {
"schedule_type": self.schedule_type, "schedule_type": self.schedule_type,
"warning_step": self.warning_step, "warmup_steps": self.warmup_steps,
"cycle_length": self.cycle_length, "cycle_length": self.cycle_length,
"min_rate": self.min_rate, "min_rate": self.min_rate,
"T_mult": self.T_mult "t_mult": self.t_mult
} }
def validate(self) -> None:
super().validate()
if self.cycle_length <= 0:
raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
if self.t_mult < 1:
raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
class SchedulerFactory: class SchedulerFactory:
"""Factory for creating learning rate schedule functions."""
@staticmethod @staticmethod
def get_sgdr_schedule( def get_sgdr_schedule(
warning_step: int, warmup_steps: int,
cycle_length: int, cycle_length: int,
min_rate: float = 0.1, min_rate: float = 0.05,
T_mult: int = 2 t_mult: int = 2
) -> Callable[[int], float]: ) -> Callable[[int], float]:
"""
Create SGDR (Stochastic Gradient Descent with Warm Restarts) schedule.
def sgdr_schedule(now_iter: int) -> float: Args:
if now_iter < warning_step: warmup_steps: Number of warmup steps
return max(min_rate, now_iter / warning_step) cycle_length: Length of the first cycle
min_rate: Minimum learning rate multiplier
t_mult: Cycle length multiplier
adjusted_iter = now_iter - warning_step Returns:
total_cycles, current_cycle = 0, 0 Schedule function that takes current step and returns LR multiplier
while adjusted_iter >= cycle_length * (T_mult ** total_cycles): """
current_cycle += 1
total_cycles += 1
cycle_start = sum(cycle_length * (T_mult ** i) for i in range(current_cycle)) def sgdr_schedule(current_step: int) -> float:
cycle_pos = adjusted_iter - cycle_start # Warmup phase
if current_step < warmup_steps:
return max(min_rate, current_step / warmup_steps)
cycle_length_current = cycle_length * (T_mult ** current_cycle) # SGDR phase
return max(min_rate, 0.5 * (1 + math.cos(math.pi * cycle_pos / cycle_length_current))) steps_since_warmup = current_step - warmup_steps
# Find current cycle and position within cycle
cycle_start = 0
current_cycle_length = cycle_length
cycle_index = 0
while steps_since_warmup >= cycle_start + current_cycle_length:
cycle_start += current_cycle_length
current_cycle_length *= t_mult
cycle_index += 1
position_in_cycle = steps_since_warmup - cycle_start
progress = position_in_cycle / current_cycle_length
# Cosine annealing within cycle
return max(min_rate, 0.5 * (1 + math.cos(math.pi * progress)))
return sgdr_schedule return sgdr_schedule
@staticmethod @staticmethod
def get_cosine_warmup_schedule( def get_cosine_schedule(
warning_step: int, warmup_steps: int,
lr_decay_iters: int, lr_decay_steps: int,
min_rate: float = 0.1 min_rate: float = 0.05
) -> Callable[[int], float]: ) -> Callable[[int], float]:
"""
Create cosine decay schedule with warmup.
def cosine_warmup_schedule(now_iter: int) -> float: Args:
if now_iter <= warning_step: warmup_steps: Number of warmup steps
return max(min_rate, now_iter / warning_step) lr_decay_steps: Number of steps for cosine decay after warmup
min_rate: Minimum learning rate multiplier
Returns:
Schedule function that takes current step and returns LR multiplier
"""
def cosine_schedule(current_step: int) -> float:
if current_step < warmup_steps:
# Linear warmup
return max(min_rate, current_step / warmup_steps)
else: else:
rate = (now_iter - warning_step) / (lr_decay_iters - warning_step) # Cosine decay
return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * rate))) decay_progress = (current_step - warmup_steps) / lr_decay_steps
decay_progress = min(decay_progress, 1.0) # Clamp at 1.0
return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * decay_progress)))
return cosine_warmup_schedule return cosine_schedule
@staticmethod @staticmethod
def load_schedule_fn(**kwargs): def create_schedule(config: ScheduleConfig) -> Callable[[int], float]:
strategy = kwargs.pop("schedule_type") """
if strategy == "cosine": Create schedule from configuration.
return SchedulerFactory.get_cosine_warmup_schedule(**kwargs)
elif strategy == "sgdr": Args:
config: Schedule configuration instance
Returns:
Schedule function
"""
config.validate()
kwargs = config.get_kwargs()
return SchedulerFactory.load_schedule_fn(**kwargs)
@staticmethod
def load_schedule_fn(**kwargs) -> Callable[[int], float]:
schedule_type = kwargs.pop("schedule_type")
if schedule_type == "cosine":
return SchedulerFactory.get_cosine_schedule(**kwargs)
elif schedule_type == "sgdr":
return SchedulerFactory.get_sgdr_schedule(**kwargs) return SchedulerFactory.get_sgdr_schedule(**kwargs)
else: else:
raise ValueError(f"Invalid schedule type: {strategy}") raise ValueError(f"Unsupported schedule type: {schedule_type}")

View File

@ -8,19 +8,23 @@ from torch.utils.data import DataLoader, RandomSampler
from tqdm import tqdm from tqdm import tqdm
from khaosz.core import ModelParameter, Checkpoint from khaosz.core import ModelParameter, Checkpoint
from khaosz.trainer.strategy import SchedulerFactory, StrategyFactory, TrainConfig, ScheduleConfig from khaosz.trainer.strategy import SchedulerFactory, TrainConfig, ScheduleConfig
class Trainer: class Trainer:
def __init__( def __init__(
self, self,
parameter: ModelParameter parameter: ModelParameter,
train_config: TrainConfig,
schedule_config: ScheduleConfig
): ):
self.checkpoint = Checkpoint( self.checkpoint = Checkpoint(
model=parameter.model, model=parameter.model,
tokenizer=parameter.tokenizer, tokenizer=parameter.tokenizer,
config=parameter.config, config=parameter.config,
) )
self.train_config = train_config
self.schedule_config = schedule_config
def save_checkpoint( def save_checkpoint(
self, self,
@ -35,12 +39,11 @@ class Trainer:
def train( def train(
self, self,
train_config: TrainConfig,
schedule_config: ScheduleConfig,
train_checkpoint: Optional[Checkpoint] = None train_checkpoint: Optional[Checkpoint] = None
) -> Checkpoint: ) -> Checkpoint:
train_config = self.train_config
schedule_config = self.schedule_config
assert schedule_config.schedule_type in ["cosine", "sgdr"] assert schedule_config.schedule_type in ["cosine", "sgdr"]
assert train_config.train_type in ["seq", "sft", "dpo"]
if train_checkpoint: if train_checkpoint:
self.checkpoint = train_checkpoint self.checkpoint = train_checkpoint
@ -60,19 +63,6 @@ class Trainer:
**schedule_config.get_kwargs() **schedule_config.get_kwargs()
) )
strategy_kwargs = {
"bos_token_id": self.checkpoint.tokenizer.bos_id,
"eos_token_id": self.checkpoint.tokenizer.eos_id,
"pad_token_id": self.checkpoint.tokenizer.pad_id,
"dpo_beta": train_config.dpo_beta
}
strategy = StrategyFactory.load(
self.checkpoint.model,
train_config.train_type,
**strategy_kwargs
)
scheduler = LambdaLR( scheduler = LambdaLR(
train_config.optimizer, train_config.optimizer,
lambda_scheduler_fn, lambda_scheduler_fn,
@ -98,7 +88,7 @@ class Trainer:
) )
for batch in progress_bar: for batch in progress_bar:
#forward #forward
loss = strategy(batch) loss = train_config.strategy(batch)
loss_list.append(loss.item()) loss_list.append(loss.item())
#backward #backward
loss.backward() loss.backward()

View File

@ -5,6 +5,7 @@ import torch
from torch.optim import AdamW from torch.optim import AdamW
from khaosz.core import ParameterLoader from khaosz.core import ParameterLoader
from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig
from khaosz.trainer import StrategyFactory
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__)) PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
@ -46,19 +47,26 @@ def train(
cache_files = get_files(data_root_path) cache_files = get_files(data_root_path)
dataset_kwargs = { strategy_kwargs = {
"multi_turn": multi_turn, "multi_turn": multi_turn,
"bos_token_id": parameter.tokenizer.bos_id, "bos_token_id": parameter.tokenizer.bos_id,
"eos_token_id": parameter.tokenizer.eos_id, "eos_token_id": parameter.tokenizer.eos_id,
"user_token_id":parameter.tokenizer.encode("<|user|>")[0] "user_token_id":parameter.tokenizer.encode("<|user|>")[0],
"dpo_beta": dpo_beta
} }
strategy = StrategyFactory.load(
model,
train_type
**strategy_kwargs
)
dataset = DatasetLoader.load( dataset = DatasetLoader.load(
train_type=train_type, train_type=train_type,
load_path=cache_files, load_path=cache_files,
max_len=parameter.config.m_len, max_len=parameter.config.m_len,
device=device, device=device,
dataset_kwargs=dataset_kwargs dataset_kwargs=strategy_kwargs
) )
param_groups = [ param_groups = [
@ -73,7 +81,7 @@ def train(
) )
train_config = TrainConfig( train_config = TrainConfig(
train_type=train_type, strategy=strategy,
dataset=dataset, dataset=dataset,
optimizer=optim, optimizer=optim,
ckpt_dir=ckpt_dir, ckpt_dir=ckpt_dir,
@ -83,7 +91,6 @@ def train(
n_iter_step=n_iter_step, n_iter_step=n_iter_step,
max_grad_norm=max_grad_norm, max_grad_norm=max_grad_norm,
random_seed=random_seed, random_seed=random_seed,
dpo_beta=dpo_beta
) )
schedule_config = CosineScheduleConfig( schedule_config = CosineScheduleConfig(
@ -91,11 +98,13 @@ def train(
total_iters=len(dataset) * n_epoch // batch_size, total_iters=len(dataset) * n_epoch // batch_size,
) )
trainer = Trainer(parameter) trainer = Trainer(
trainer.train( parameter=parameter,
train_config=train_config, train_config=train_config,
schedule_config=schedule_config, schedule_config=schedule_config,
) )
trainer.train()
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train the Transformer model.") parser = argparse.ArgumentParser(description="Train the Transformer model.")