From cfa3cf7daa8a6ae8344b1dbde2a1d7aba46d75dc Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 22 Dec 2025 20:41:03 +0800 Subject: [PATCH] =?UTF-8?q?feat(train):=20=E6=94=AF=E6=8C=81=E5=88=86?= =?UTF-8?q?=E5=B8=83=E5=BC=8F=E8=AE=AD=E7=BB=83=E7=9A=84=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E5=99=A8=E4=B8=8E=E8=B0=83=E5=BA=A6=E5=99=A8=E5=B7=A5=E5=8E=82?= =?UTF-8?q?=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/config/train_config.py | 25 ++++++++++++++++++++++--- khaosz/trainer/train_context.py | 16 ++++++++++++---- tools/train.py | 33 ++++++++++++++++++--------------- 3 files changed, 52 insertions(+), 22 deletions(-) diff --git a/khaosz/config/train_config.py b/khaosz/config/train_config.py index 434172a..d41d01b 100644 --- a/khaosz/config/train_config.py +++ b/khaosz/config/train_config.py @@ -1,4 +1,4 @@ -from torch import nn +import torch.nn as nn from torch.utils.data import Dataset from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler @@ -30,6 +30,7 @@ class TrainConfig: default=None, metadata={"help": "Scheduler for training."} ) + n_epoch: int = field( default=1, metadata={"help": "Number of epochs for training."} @@ -104,7 +105,15 @@ class TrainConfig: default=None, metadata={"help": "Parallel function for training."} ) - + optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field( + default=None, + metadata={"help": "Optimizer factory for training."} + ) + scheduler_factory: Optional[Callable[[Optimizer], LRScheduler]] = field( + default=None, + metadata={"help": "Scheduler factory for training."} + ) + # others extra_kwargs: dict = field( default_factory=dict, @@ -115,7 +124,17 @@ class TrainConfig: self.validate() def validate(self): - required_fields = ["model", "strategy", "dataset", "optimizer", "scheduler"] + required_fields = ["model", "strategy", "dataset"] + for field_name in required_fields: if getattr(self, 
field_name) is None: raise ValueError(f"{field_name} is required.") + + factory_case = all([self.optimizer_factory, self.scheduler_factory]) + argument_case = all([self.optimizer, self.scheduler]) + self.nprocs = max(self.nprocs, 1) + + if self.nprocs > 1 and not factory_case: + raise ValueError("Distributed training requires optimizer and scheduler factories.") + elif self.nprocs == 1 and not argument_case: + raise ValueError("Single process training requires optimizer and scheduler arguments.") diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py index 40c5e10..ab1184a 100644 --- a/khaosz/trainer/train_context.py +++ b/khaosz/trainer/train_context.py @@ -80,19 +80,27 @@ class TrainContextBuilder: return self def with_strategy(self) -> Self: - device = get_current_device() self._context.strategy = StrategyFactory.load( model=self.config.model, train_type=self.config.strategy, - device=device, + device=get_current_device(), **self.config.extra_kwargs ) return self def with_parallel_fn(self) -> Self: - fn = self.config.parallel_fn - if fn is not None: + device = get_current_device() + self._context.model = self._context.model.to(device=device) + + if self.config.nprocs > 1: + + fn = self.config.parallel_fn + optimizer_fn = self.config.optimizer_factory + scheduler_fn = self.config.scheduler_factory + self._context.model = fn(self._context.model) + self._context.optimizer = optimizer_fn(self._context.model) + self._context.scheduler = scheduler_fn(self._context.optimizer) return self diff --git a/tools/train.py b/tools/train.py index 358b12f..a7621d2 100644 --- a/tools/train.py +++ b/tools/train.py @@ -2,9 +2,10 @@ import os import argparse import torch import torch.nn as nn +import torch.optim as optim import torch.distributed.fsdp as fsdp -from torch.optim import AdamW +from functools import partial from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig from khaosz.trainer import Trainer, SchedulerFactory 
from khaosz.data import DatasetLoader @@ -26,7 +27,6 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Beta values for AdamW optimizer.") parser.add_argument("--adamw_beta2", type=float, default=0.95, help="Beta values for AdamW optimizer.") parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.") - parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.") parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.") parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.") parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory", help="Disable pin memory") @@ -59,6 +59,12 @@ def fsdp_wrap(model: nn.Module): ) return fsdp_model +def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer: + return optim.AdamW(model.parameters(), **kwargs) + +def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler: + return SchedulerFactory.load(optimizer, **kwargs) + def train( train_type: str, param_path: str, @@ -77,7 +83,6 @@ def train( adamw_beta2: float, adamw_weight_decay: float, max_grad_norm: float, - embdeding_lr_rate: int, random_seed: int, num_workers: int, pin_memory: bool, @@ -110,23 +115,19 @@ def train( stride=stride ) - param_groups = [ - {"params": [p for n, p in model.named_parameters() if "embed" in n], "lr": max_lr * embdeding_lr_rate}, - {"params": [p for n, p in model.named_parameters() if "embed" not in n], "lr": max_lr} - ] - - optimizer = AdamW( - param_groups, - betas=(adamw_beta1, adamw_beta2), - weight_decay=adamw_weight_decay - ) - schedule_config = CosineScheduleConfig( warmup_steps=warmup_steps, total_steps=len(dataset) * n_epoch // (batch_size * nprocs), ) - scheduler = SchedulerFactory.load(optimizer, 
schedule_config) + + optimizer_fn = partial(create_optimizer, **{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay}) + scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config}) + optimizer, scheduler = None, None + + if nprocs == 1: + optimizer = optimizer_fn(model) + scheduler = scheduler_fn(optimizer) train_config = TrainConfig( model=model, @@ -146,6 +147,8 @@ def train( num_workers=num_workers, pin_memory=pin_memory, nprocs=nprocs, + optimizer_factory=optimizer_fn, + scheduler_factory=scheduler_fn, extra_kwargs=kwargs, parallel_fn=fsdp_wrap )