feat(train): 支持分布式训练的优化器与调度器工厂配置
This commit is contained in:
parent
7623b1e5fd
commit
cfa3cf7daa
|
|
@ -1,4 +1,4 @@
|
||||||
from torch import nn
|
import torch.nn as nn
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
|
@ -30,6 +30,7 @@ class TrainConfig:
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Scheduler for training."}
|
metadata={"help": "Scheduler for training."}
|
||||||
)
|
)
|
||||||
|
|
||||||
n_epoch: int = field(
|
n_epoch: int = field(
|
||||||
default=1,
|
default=1,
|
||||||
metadata={"help": "Number of epochs for training."}
|
metadata={"help": "Number of epochs for training."}
|
||||||
|
|
@ -104,6 +105,14 @@ class TrainConfig:
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Parallel function for training."}
|
metadata={"help": "Parallel function for training."}
|
||||||
)
|
)
|
||||||
|
optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Optimizer factory for training."}
|
||||||
|
)
|
||||||
|
scheduler_factory: Optional[Callable[[Optimizer], LRScheduler]] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Scheduler factory for training."}
|
||||||
|
)
|
||||||
|
|
||||||
# others
|
# others
|
||||||
extra_kwargs: dict = field(
|
extra_kwargs: dict = field(
|
||||||
|
|
@ -115,7 +124,17 @@ class TrainConfig:
|
||||||
self.validate()
|
self.validate()
|
||||||
|
|
||||||
def validate(self):
    """Check that the configuration is complete for the selected process count.

    ``model``, ``strategy`` and ``dataset`` must always be set. ``nprocs`` is
    clamped to at least 1 (the attribute is rewritten in place). With more
    than one process, both ``optimizer_factory`` and ``scheduler_factory``
    must be set; with exactly one process, both ``optimizer`` and
    ``scheduler`` must be set.

    Raises:
        ValueError: If a required field is ``None``, or the optimizer and
            scheduler inputs do not match the process count.
    """
    for field_name in ("model", "strategy", "dataset"):
        if getattr(self, field_name) is None:
            raise ValueError(f"{field_name} is required.")

    # Compare against None explicitly, like the required-field loop above:
    # a truthiness test would reject a supplied object that happens to be
    # falsy (e.g. one defining __bool__ or __len__).
    has_factories = (
        self.optimizer_factory is not None
        and self.scheduler_factory is not None
    )
    has_instances = self.optimizer is not None and self.scheduler is not None

    # Zero or negative process counts are treated as single-process training.
    self.nprocs = max(self.nprocs, 1)

    if self.nprocs > 1:
        if not has_factories:
            raise ValueError("Distributed training requires optimizer and scheduler factories.")
    elif not has_instances:
        raise ValueError("Single process training requires optimizer and scheduler arguments.")
|
||||||
|
|
|
||||||
|
|
@ -80,19 +80,27 @@ class TrainContextBuilder:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_strategy(self) -> Self:
    """Create the training strategy and store it on the context.

    The strategy is loaded through ``StrategyFactory.load`` using the
    configured model, the configured strategy name (passed as
    ``train_type``), the device reported by ``get_current_device()``, and
    any ``extra_kwargs`` from the config.

    Returns:
        This builder, to allow call chaining.
    """
    config = self.config
    strategy = StrategyFactory.load(
        model=config.model,
        train_type=config.strategy,
        device=get_current_device(),
        **config.extra_kwargs,
    )
    self._context.strategy = strategy
    return self
|
||||||
|
|
||||||
def with_parallel_fn(self) -> Self:
    """Move the model to the current device and prepare distributed objects.

    The context model is always moved to the device returned by
    ``get_current_device()``. When ``config.nprocs > 1`` the model is passed
    through ``config.parallel_fn`` (if one is configured), and the optimizer
    and scheduler are then built from ``config.optimizer_factory`` and
    ``config.scheduler_factory``, so they are created from the model as it
    exists after wrapping.

    Returns:
        This builder, to allow call chaining.
    """
    device = get_current_device()
    self._context.model = self._context.model.to(device=device)

    if self.config.nprocs > 1:
        parallel_fn = self.config.parallel_fn
        # parallel_fn may be unset (None); skip wrapping instead of
        # failing with "'NoneType' object is not callable".
        if parallel_fn is not None:
            self._context.model = parallel_fn(self._context.model)

        # The optimizer factory is declared as Callable[[nn.Module], Optimizer],
        # so hand it the (possibly wrapped) module itself rather than
        # ``model.parameters()``; a factory written to that contract would
        # otherwise call ``.parameters()`` on a parameter iterator.
        self._context.optimizer = self.config.optimizer_factory(self._context.model)
        self._context.scheduler = self.config.scheduler_factory(self._context.optimizer)

    return self
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,10 @@ import os
|
||||||
import argparse
|
import argparse
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
import torch.distributed.fsdp as fsdp
|
import torch.distributed.fsdp as fsdp
|
||||||
|
|
||||||
from torch.optim import AdamW
|
from functools import partial
|
||||||
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
||||||
from khaosz.trainer import Trainer, SchedulerFactory
|
from khaosz.trainer import Trainer, SchedulerFactory
|
||||||
from khaosz.data import DatasetLoader
|
from khaosz.data import DatasetLoader
|
||||||
|
|
@ -26,7 +27,6 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Beta values for AdamW optimizer.")
|
parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Beta values for AdamW optimizer.")
|
||||||
parser.add_argument("--adamw_beta2", type=float, default=0.95, help="Beta values for AdamW optimizer.")
|
parser.add_argument("--adamw_beta2", type=float, default=0.95, help="Beta values for AdamW optimizer.")
|
||||||
parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
|
parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
|
||||||
parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.")
|
|
||||||
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
||||||
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
|
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
|
||||||
parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory", help="Disable pin memory")
|
parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory", help="Disable pin memory")
|
||||||
|
|
@ -59,6 +59,12 @@ def fsdp_wrap(model: nn.Module):
|
||||||
)
|
)
|
||||||
return fsdp_model
|
return fsdp_model
|
||||||
|
|
||||||
|
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
    """Build an AdamW optimizer for ``model``.

    Args:
        model: The module whose parameters are optimized. An iterable of
            parameters (or of parameter-group dicts) is also accepted and is
            forwarded to AdamW unchanged.
        **kwargs: Forwarded to ``optim.AdamW`` (for example ``lr``,
            ``betas``, ``weight_decay``).

    Returns:
        The constructed ``optim.AdamW`` instance.
    """
    # This script also calls the factory with ``model.parameters()``; calling
    # ``.parameters()`` on that iterator would raise AttributeError, so only
    # unpack the parameters when an actual module was given.
    params = model.parameters() if isinstance(model, nn.Module) else model
    return optim.AdamW(params, **kwargs)
|
||||||
|
|
||||||
|
def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
    """Build a learning-rate scheduler for ``optimizer``.

    Args:
        optimizer: The optimizer the scheduler will drive.
        **kwargs: Forwarded unchanged to ``SchedulerFactory.load``.

    Returns:
        Whatever ``SchedulerFactory.load`` returns for these arguments.
    """
    scheduler = SchedulerFactory.load(optimizer, **kwargs)
    return scheduler
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
train_type: str,
|
train_type: str,
|
||||||
param_path: str,
|
param_path: str,
|
||||||
|
|
@ -77,7 +83,6 @@ def train(
|
||||||
adamw_beta2: float,
|
adamw_beta2: float,
|
||||||
adamw_weight_decay: float,
|
adamw_weight_decay: float,
|
||||||
max_grad_norm: float,
|
max_grad_norm: float,
|
||||||
embdeding_lr_rate: int,
|
|
||||||
random_seed: int,
|
random_seed: int,
|
||||||
num_workers: int,
|
num_workers: int,
|
||||||
pin_memory: bool,
|
pin_memory: bool,
|
||||||
|
|
@ -110,23 +115,19 @@ def train(
|
||||||
stride=stride
|
stride=stride
|
||||||
)
|
)
|
||||||
|
|
||||||
param_groups = [
|
|
||||||
{"params": [p for n, p in model.named_parameters() if "embed" in n], "lr": max_lr * embdeding_lr_rate},
|
|
||||||
{"params": [p for n, p in model.named_parameters() if "embed" not in n], "lr": max_lr}
|
|
||||||
]
|
|
||||||
|
|
||||||
optimizer = AdamW(
|
|
||||||
param_groups,
|
|
||||||
betas=(adamw_beta1, adamw_beta2),
|
|
||||||
weight_decay=adamw_weight_decay
|
|
||||||
)
|
|
||||||
|
|
||||||
schedule_config = CosineScheduleConfig(
|
schedule_config = CosineScheduleConfig(
|
||||||
warmup_steps=warmup_steps,
|
warmup_steps=warmup_steps,
|
||||||
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
|
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
|
||||||
)
|
)
|
||||||
|
|
||||||
scheduler = SchedulerFactory.load(optimizer, schedule_config)
|
|
||||||
|
optimizer_fn = partial(create_optimizer, **{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay})
|
||||||
|
scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config})
|
||||||
|
optimizer, scheduler = None, None
|
||||||
|
|
||||||
|
if nprocs == 1:
|
||||||
|
optimizer = optimizer_fn(model.parameters())
|
||||||
|
scheduler = scheduler_fn(optimizer)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
@ -146,6 +147,8 @@ def train(
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
nprocs=nprocs,
|
nprocs=nprocs,
|
||||||
|
optimizer_factory=optimizer_fn,
|
||||||
|
scheduler_factory=scheduler_fn,
|
||||||
extra_kwargs=kwargs,
|
extra_kwargs=kwargs,
|
||||||
parallel_fn=fsdp_wrap
|
parallel_fn=fsdp_wrap
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue