feat(train): support optimizer and scheduler factory configuration for distributed training
This commit is contained in:
parent 7623b1e5fd
commit cfa3cf7daa
@@ -1,4 +1,4 @@
-from torch import nn
+import torch.nn as nn
 from torch.utils.data import Dataset
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
@@ -30,6 +30,7 @@ class TrainConfig:
         default=None,
         metadata={"help": "Scheduler for training."}
     )

     n_epoch: int = field(
         default=1,
         metadata={"help": "Number of epochs for training."}
@@ -104,6 +105,14 @@ class TrainConfig:
         default=None,
         metadata={"help": "Parallel function for training."}
     )
+    optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field(
+        default=None,
+        metadata={"help": "Optimizer factory for training."}
+    )
+    scheduler_factory: Optional[Callable[[Optimizer], LRScheduler]] = field(
+        default=None,
+        metadata={"help": "Scheduler factory for training."}
+    )

     # others
     extra_kwargs: dict = field(
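Both factory fields are plain callables, so an ordinary function (or a functools.partial over one) with the matching signature is enough. A minimal sketch of what such callables might look like; the AdamW settings and the constant LambdaLR schedule are illustrative, not taken from this commit:

import torch.nn as nn
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

def my_optimizer_factory(model: nn.Module) -> Optimizer:
    # Built per process, after the model has been wrapped and moved.
    return AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

def my_scheduler_factory(optimizer: Optimizer) -> LRScheduler:
    # Any LRScheduler works; a constant LambdaLR keeps the sketch small.
    return LambdaLR(optimizer, lr_lambda=lambda step: 1.0)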
@@ -115,7 +124,17 @@ class TrainConfig:
         self.validate()

     def validate(self):
-        required_fields = ["model", "strategy", "dataset", "optimizer", "scheduler"]
+        required_fields = ["model", "strategy", "dataset"]

         for field_name in required_fields:
             if getattr(self, field_name) is None:
                 raise ValueError(f"{field_name} is required.")
+
+        factory_case = all([self.optimizer_factory, self.scheduler_factory])
+        argument_case = all([self.optimizer, self.scheduler])
+        self.nprocs = max(self.nprocs, 1)
+
+        if self.nprocs > 1 and not factory_case:
+            raise ValueError("Distributed training requires optimizer and scheduler factories.")
+        elif self.nprocs == 1 and not argument_case:
+            raise ValueError("Single process training requires optimizer and scheduler arguments.")
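Taken together, the new checks accept exactly two shapes of configuration. A sketch of both, assuming the model, dataset, optimizer, scheduler, factories and wrapper already exist; the strategy string and all names here are placeholders:

# Single-process run: concrete optimizer and scheduler are passed directly.
single_cfg = TrainConfig(
    model=model, strategy="sft", dataset=dataset,
    optimizer=optimizer, scheduler=scheduler,
    nprocs=1,
)

# Distributed run: only the factories are passed; each worker builds its own
# optimizer and scheduler after the model has been wrapped by parallel_fn.
dist_cfg = TrainConfig(
    model=model, strategy="sft", dataset=dataset,
    optimizer_factory=my_optimizer_factory,
    scheduler_factory=my_scheduler_factory,
    parallel_fn=fsdp_wrap,
    nprocs=4,
)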
@@ -80,19 +80,27 @@ class TrainContextBuilder:
         return self

     def with_strategy(self) -> Self:
-        device = get_current_device()
         self._context.strategy = StrategyFactory.load(
             model=self.config.model,
             train_type=self.config.strategy,
-            device=device,
+            device=get_current_device(),
             **self.config.extra_kwargs
         )
         return self

     def with_parallel_fn(self) -> Self:
-        fn = self.config.parallel_fn
-        if fn is not None:
-            device = get_current_device()
-            self._context.model = self._context.model.to(device=device)
+        if self.config.nprocs > 1:
+            fn = self.config.parallel_fn
+            optimizer_fn = self.config.optimizer_factory
+            scheduler_fn = self.config.scheduler_factory
+
+            self._context.model = fn(self._context.model)
+            self._context.optimizer = optimizer_fn(self._context.model.parameters())
+            self._context.scheduler = scheduler_fn(self._context.optimizer)

         return self
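Note the ordering inside with_parallel_fn: the model is wrapped first, and only then are the optimizer and scheduler built from it. Wrappers such as FSDP replace or shard the parameter tensors, so an optimizer created before wrapping can end up tracking stale parameter objects. A standalone sketch of the same ordering, with a toy stand-in for the wrapper:

import torch.nn as nn
from torch.optim import AdamW

def wrap(model: nn.Module) -> nn.Module:
    # Stand-in for fsdp_wrap; real wrappers may flatten or shard parameters,
    # which is exactly why the optimizer has to be created afterwards.
    return nn.Sequential(model)

model = nn.Linear(8, 8)
wrapped = wrap(model)

# Mirror with_parallel_fn: optimizer from the wrapped module, scheduler from that optimizer.
optimizer = AdamW(wrapped.parameters(), lr=1e-4)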
@@ -2,9 +2,10 @@ import os
 import argparse
 import torch
 import torch.nn as nn
+import torch.optim as optim
 import torch.distributed.fsdp as fsdp

-from torch.optim import AdamW
+from functools import partial
 from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
 from khaosz.trainer import Trainer, SchedulerFactory
 from khaosz.data import DatasetLoader
@@ -26,7 +27,6 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Beta values for AdamW optimizer.")
     parser.add_argument("--adamw_beta2", type=float, default=0.95, help="Beta values for AdamW optimizer.")
     parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
-    parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.")
     parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
     parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
     parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory", help="Disable pin memory")
@@ -59,6 +59,12 @@ def fsdp_wrap(model: nn.Module):
     )
     return fsdp_model

+def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
+    return optim.AdamW(model.parameters(), **kwargs)
+
+def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
+    return SchedulerFactory.load(optimizer, **kwargs)
+
 def train(
     train_type: str,
     param_path: str,
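Keeping these as module-level functions rather than lambdas has a practical side effect: functools.partial over a module-level function is picklable, which matters if the config ever has to cross a process boundary to spawned workers. A quick check, assuming create_optimizer is importable:

import pickle
from functools import partial

fn = partial(create_optimizer, lr=1e-4, betas=(0.9, 0.95), weight_decay=0.01)
pickle.dumps(fn)  # succeeds; an equivalent lambda would raise a PicklingError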
@@ -77,7 +83,6 @@ def train(
     adamw_beta2: float,
     adamw_weight_decay: float,
     max_grad_norm: float,
-    embdeding_lr_rate: int,
     random_seed: int,
     num_workers: int,
     pin_memory: bool,
@@ -110,23 +115,19 @@ def train(
         stride=stride
     )

-    param_groups = [
-        {"params": [p for n, p in model.named_parameters() if "embed" in n], "lr": max_lr * embdeding_lr_rate},
-        {"params": [p for n, p in model.named_parameters() if "embed" not in n], "lr": max_lr}
-    ]
-
-    optimizer = AdamW(
-        param_groups,
-        betas=(adamw_beta1, adamw_beta2),
-        weight_decay=adamw_weight_decay
-    )
-
     schedule_config = CosineScheduleConfig(
         warmup_steps=warmup_steps,
         total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
     )

-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+    optimizer_fn = partial(create_optimizer, **{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay})
+    scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config})
+    optimizer, scheduler = None, None
+
+    if nprocs == 1:
+        optimizer = optimizer_fn(model.parameters())
+        scheduler = scheduler_fn(optimizer)

     train_config = TrainConfig(
         model=model,
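The total_steps expression is unchanged here but is worth reading against the new nprocs handling: it divides the total number of samples seen over the run (len(dataset) * n_epoch) by the global batch size (batch_size * nprocs) to estimate how many optimizer steps the cosine schedule should cover. For example, with len(dataset) = 1,000,000, n_epoch = 1, batch_size = 8 and nprocs = 4, the schedule spans 1,000,000 * 1 // (8 * 4) = 31,250 steps.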
@@ -146,6 +147,8 @@ def train(
         num_workers=num_workers,
         pin_memory=pin_memory,
         nprocs=nprocs,
+        optimizer_factory=optimizer_fn,
+        scheduler_factory=scheduler_fn,
         extra_kwargs=kwargs,
         parallel_fn=fsdp_wrap
     )