feat(train): 支持分布式训练的优化器与调度器工厂配置

This commit is contained in:
ViperEkura 2025-12-22 20:41:03 +08:00
parent 7623b1e5fd
commit cfa3cf7daa
3 changed files with 52 additions and 22 deletions

View File

@ -1,4 +1,4 @@
from torch import nn
import torch.nn as nn
from torch.utils.data import Dataset
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
@ -30,6 +30,7 @@ class TrainConfig:
default=None,
metadata={"help": "Scheduler for training."}
)
n_epoch: int = field(
default=1,
metadata={"help": "Number of epochs for training."}
@ -104,7 +105,15 @@ class TrainConfig:
default=None,
metadata={"help": "Parallel function for training."}
)
optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field(
default=None,
metadata={"help": "Optimizer factory for training."}
)
scheduler_factory: Optional[Callable[[Optimizer], LRScheduler]] = field(
default=None,
metadata={"help": "Scheduler factory for training."}
)
# others
extra_kwargs: dict = field(
default_factory=dict,
@ -115,7 +124,17 @@ class TrainConfig:
self.validate()
def validate(self):
    """Check that this TrainConfig is runnable before training starts.

    Raises:
        ValueError: if a required field is missing, or if the
            optimizer/scheduler configuration does not match the
            process count (factories for distributed runs, concrete
            instances for single-process runs).
    """
    # These fields must always be supplied, regardless of process count.
    required_fields = ["model", "strategy", "dataset"]
    for field_name in required_fields:
        if getattr(self, field_name) is None:
            raise ValueError(f"{field_name} is required.")
    # Distributed runs rebuild optimizer/scheduler per process via factories;
    # single-process runs take ready-made instances instead.
    factory_case = all([self.optimizer_factory, self.scheduler_factory])
    argument_case = all([self.optimizer, self.scheduler])
    # Normalize nprocs so zero/negative values mean single-process.
    self.nprocs = max(self.nprocs, 1)
    if self.nprocs > 1 and not factory_case:
        raise ValueError("Distributed training requires optimizer and scheduler factories.")
    elif self.nprocs == 1 and not argument_case:
        raise ValueError("Single process training requires optimizer and scheduler arguments.")

View File

@ -80,19 +80,27 @@ class TrainContextBuilder:
return self
def with_strategy(self) -> Self:
    """Build the training strategy for this process and attach it to the context.

    Returns:
        Self, so builder calls can be chained.
    """
    # Resolve the device at call time so each spawned worker process
    # picks up its own current device rather than a value captured earlier.
    self._context.strategy = StrategyFactory.load(
        model=self.config.model,
        train_type=self.config.strategy,
        device=get_current_device(),
        **self.config.extra_kwargs
    )
    return self
def with_parallel_fn(self) -> Self:
    """Wrap the model for distributed training and build its optimizer/scheduler.

    Only runs in the multi-process case (nprocs > 1); single-process
    configs are expected to carry a ready-made optimizer and scheduler
    instead (enforced by TrainConfig.validate).

    Returns:
        Self, so builder calls can be chained.
    """
    if self.config.nprocs > 1:
        parallel_fn = self.config.parallel_fn
        optimizer_fn = self.config.optimizer_factory
        scheduler_fn = self.config.scheduler_factory
        # Wrap the model first so the optimizer is created over the
        # (possibly sharded/wrapped) parameters, not the originals.
        self._context.model = parallel_fn(self._context.model)
        self._context.optimizer = optimizer_fn(self._context.model.parameters())
        self._context.scheduler = scheduler_fn(self._context.optimizer)
    return self

View File

@ -2,9 +2,10 @@ import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed.fsdp as fsdp
from torch.optim import AdamW
from functools import partial
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
from khaosz.trainer import Trainer, SchedulerFactory
from khaosz.data import DatasetLoader
@ -26,7 +27,6 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Beta values for AdamW optimizer.")
parser.add_argument("--adamw_beta2", type=float, default=0.95, help="Beta values for AdamW optimizer.")
parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.")
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory", help="Disable pin memory")
@ -59,6 +59,12 @@ def fsdp_wrap(model: nn.Module):
)
return fsdp_model
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
    """Build an AdamW optimizer over all parameters of *model*.

    Extra keyword arguments (lr, betas, weight_decay, ...) are forwarded
    to torch.optim.AdamW unchanged.
    """
    params = model.parameters()
    return optim.AdamW(params, **kwargs)
def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
    """Build an LR scheduler for *optimizer* by delegating to SchedulerFactory.

    Extra keyword arguments (e.g. schedule_config) are forwarded
    to SchedulerFactory.load unchanged.
    """
    scheduler = SchedulerFactory.load(optimizer, **kwargs)
    return scheduler
def train(
train_type: str,
param_path: str,
@ -77,7 +83,6 @@ def train(
adamw_beta2: float,
adamw_weight_decay: float,
max_grad_norm: float,
embdeding_lr_rate: int,
random_seed: int,
num_workers: int,
pin_memory: bool,
@ -110,23 +115,19 @@ def train(
stride=stride
)
param_groups = [
{"params": [p for n, p in model.named_parameters() if "embed" in n], "lr": max_lr * embdeding_lr_rate},
{"params": [p for n, p in model.named_parameters() if "embed" not in n], "lr": max_lr}
]
optimizer = AdamW(
param_groups,
betas=(adamw_beta1, adamw_beta2),
weight_decay=adamw_weight_decay
)
schedule_config = CosineScheduleConfig(
warmup_steps=warmup_steps,
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
)
scheduler = SchedulerFactory.load(optimizer, schedule_config)
optimizer_fn = partial(create_optimizer, **{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay})
scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config})
optimizer, scheduler = None, None
if nprocs == 1:
optimizer = optimizer_fn(model.parameters())
scheduler = scheduler_fn(optimizer)
train_config = TrainConfig(
model=model,
@ -146,6 +147,8 @@ def train(
num_workers=num_workers,
pin_memory=pin_memory,
nprocs=nprocs,
optimizer_factory=optimizer_fn,
scheduler_factory=scheduler_fn,
extra_kwargs=kwargs,
parallel_fn=fsdp_wrap
)