feat(train): support optimizer and scheduler factory configuration for distributed training

ViperEkura 2025-12-22 20:41:03 +08:00
parent 7623b1e5fd
commit cfa3cf7daa
3 changed files with 52 additions and 22 deletions

View File

@@ -1,4 +1,4 @@
-from torch import nn
+import torch.nn as nn
 from torch.utils.data import Dataset
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
@@ -30,6 +30,7 @@ class TrainConfig:
         default=None,
         metadata={"help": "Scheduler for training."}
     )
     n_epoch: int = field(
         default=1,
         metadata={"help": "Number of epochs for training."}
@@ -104,7 +105,15 @@ class TrainConfig:
         default=None,
         metadata={"help": "Parallel function for training."}
     )
+    optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field(
+        default=None,
+        metadata={"help": "Optimizer factory for training."}
+    )
+    scheduler_factory: Optional[Callable[[Optimizer], LRScheduler]] = field(
+        default=None,
+        metadata={"help": "Scheduler factory for training."}
+    )
     # others
     extra_kwargs: dict = field(
         default_factory=dict,
@@ -115,7 +124,17 @@ class TrainConfig:
         self.validate()

     def validate(self):
-        required_fields = ["model", "strategy", "dataset", "optimizer", "scheduler"]
+        required_fields = ["model", "strategy", "dataset"]
         for field_name in required_fields:
             if getattr(self, field_name) is None:
                 raise ValueError(f"{field_name} is required.")
+
+        factory_case = all([self.optimizer_factory, self.scheduler_factory])
+        argument_case = all([self.optimizer, self.scheduler])
+        self.nprocs = max(self.nprocs, 1)
+
+        if self.nprocs > 1 and not factory_case:
+            raise ValueError("Distributed training requires optimizer and scheduler factories.")
+        elif self.nprocs == 1 and not argument_case:
+            raise ValueError("Single process training requires optimizer and scheduler arguments.")

View File

@@ -80,19 +80,27 @@ class TrainContextBuilder:
         return self

     def with_strategy(self) -> Self:
-        device = get_current_device()
         self._context.strategy = StrategyFactory.load(
             model=self.config.model,
             train_type=self.config.strategy,
-            device=device,
+            device=get_current_device(),
             **self.config.extra_kwargs
         )
         return self

     def with_parallel_fn(self) -> Self:
-        fn = self.config.parallel_fn
-        if fn is not None:
+        device = get_current_device()
+        self._context.model = self._context.model.to(device=device)
+
+        if self.config.nprocs > 1:
+            fn = self.config.parallel_fn
+            optimizer_fn = self.config.optimizer_factory
+            scheduler_fn = self.config.scheduler_factory
+
             self._context.model = fn(self._context.model)
+            self._context.optimizer = optimizer_fn(self._context.model)
+            self._context.scheduler = scheduler_fn(self._context.optimizer)
+
         return self
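For multi-process runs the builder now moves the model to the current device, wraps it with parallel_fn, and only then asks the factories for the optimizer and scheduler, so the optimizer is constructed over the wrapped module's parameters. A rough sketch of that ordering with a no-op stand-in for the wrapper (wrap_model, the Linear model, and the hyperparameters are placeholders, not repository code):

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

def wrap_model(model: nn.Module) -> nn.Module:
    # Stand-in for config.parallel_fn (fsdp_wrap in the script below);
    # a real run would return the FSDP-wrapped module here.
    return model

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(8, 8).to(device)                    # 1) move to the current device
model = wrap_model(model)                             # 2) wrap when nprocs > 1
optimizer = AdamW(model.parameters(), lr=1e-4)        # 3) optimizer from the wrapped module
scheduler = CosineAnnealingLR(optimizer, T_max=100)   # 4) scheduler on that optimizer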

View File

@@ -2,9 +2,10 @@ import os
 import argparse
 import torch
 import torch.nn as nn
+import torch.optim as optim
 import torch.distributed.fsdp as fsdp
-from torch.optim import AdamW
+from functools import partial
 from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
 from khaosz.trainer import Trainer, SchedulerFactory
 from khaosz.data import DatasetLoader
@@ -26,7 +27,6 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Beta values for AdamW optimizer.")
     parser.add_argument("--adamw_beta2", type=float, default=0.95, help="Beta values for AdamW optimizer.")
     parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
-    parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.")
     parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
     parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
     parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory", help="Disable pin memory")
@@ -59,6 +59,12 @@ def fsdp_wrap(model: nn.Module):
     )
     return fsdp_model

+def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
+    return optim.AdamW(model.parameters(), **kwargs)
+
+def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
+    return SchedulerFactory.load(optimizer, **kwargs)
+
 def train(
     train_type: str,
     param_path: str,
@@ -77,7 +83,6 @@ def train(
     adamw_beta2: float,
     adamw_weight_decay: float,
     max_grad_norm: float,
-    embdeding_lr_rate: int,
     random_seed: int,
     num_workers: int,
     pin_memory: bool,
@@ -110,23 +115,19 @@ def train(
         stride=stride
     )

-    param_groups = [
-        {"params": [p for n, p in model.named_parameters() if "embed" in n], "lr": max_lr * embdeding_lr_rate},
-        {"params": [p for n, p in model.named_parameters() if "embed" not in n], "lr": max_lr}
-    ]
-    optimizer = AdamW(
-        param_groups,
-        betas=(adamw_beta1, adamw_beta2),
-        weight_decay=adamw_weight_decay
-    )
     schedule_config = CosineScheduleConfig(
         warmup_steps=warmup_steps,
         total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
     )
-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+
+    optimizer_fn = partial(create_optimizer, **{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay})
+    scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config})
+
+    optimizer, scheduler = None, None
+    if nprocs == 1:
+        optimizer = optimizer_fn(model)
+        scheduler = scheduler_fn(optimizer)

     train_config = TrainConfig(
         model=model,
@@ -146,6 +147,8 @@ def train(
         num_workers=num_workers,
         pin_memory=pin_memory,
         nprocs=nprocs,
+        optimizer_factory=optimizer_fn,
+        scheduler_factory=scheduler_fn,
         extra_kwargs=kwargs,
         parallel_fn=fsdp_wrap
     )
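
The script turns create_optimizer and create_scheduler into the single-argument factories TrainConfig expects by pre-binding the hyperparameters with functools.partial. A self-contained sketch of that pattern (CosineAnnealingLR stands in for SchedulerFactory.load, whose signature is not shown here, and the hyperparameter values are illustrative):

from functools import partial

import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, LRScheduler

def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
    # Same shape as the helper in the script: the partial below supplies the kwargs.
    return optim.AdamW(model.parameters(), **kwargs)

def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> LRScheduler:
    # Placeholder for SchedulerFactory.load(optimizer, ...).
    return CosineAnnealingLR(optimizer, **kwargs)

optimizer_fn = partial(create_optimizer, lr=1e-4, betas=(0.9, 0.95), weight_decay=0.01)
scheduler_fn = partial(create_scheduler, T_max=1000)

model = nn.Linear(16, 16)
optimizer = optimizer_fn(model)      # factory call shape: model -> optimizer
scheduler = scheduler_fn(optimizer)  # factory call shape: optimizer -> scheduler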