diff --git a/khaosz/config/train_config.py b/khaosz/config/train_config.py
index 434172a..d41d01b 100644
--- a/khaosz/config/train_config.py
+++ b/khaosz/config/train_config.py
@@ -1,4 +1,4 @@
-from torch import nn
+import torch.nn as nn
 from torch.utils.data import Dataset
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
@@ -30,6 +30,7 @@ class TrainConfig:
         default=None,
         metadata={"help": "Scheduler for training."}
     )
+
     n_epoch: int = field(
         default=1,
         metadata={"help": "Number of epochs for training."}
@@ -104,7 +105,15 @@ class TrainConfig:
         default=None,
         metadata={"help": "Parallel function for training."}
     )
-
+    optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field(
+        default=None,
+        metadata={"help": "Optimizer factory for training."}
+    )
+    scheduler_factory: Optional[Callable[[Optimizer], LRScheduler]] = field(
+        default=None,
+        metadata={"help": "Scheduler factory for training."}
+    )
+
     # others
     extra_kwargs: dict = field(
         default_factory=dict,
@@ -115,7 +124,17 @@ class TrainConfig:
         self.validate()

     def validate(self):
-        required_fields = ["model", "strategy", "dataset", "optimizer", "scheduler"]
+        required_fields = ["model", "strategy", "dataset"]
+
         for field_name in required_fields:
             if getattr(self, field_name) is None:
                 raise ValueError(f"{field_name} is required.")
+
+        factory_case = all([self.optimizer_factory, self.scheduler_factory])
+        argument_case = all([self.optimizer, self.scheduler])
+        self.nprocs = max(self.nprocs, 1)
+
+        if self.nprocs > 1 and not factory_case:
+            raise ValueError("Distributed training requires optimizer and scheduler factories.")
+        elif self.nprocs == 1 and not argument_case:
+            raise ValueError("Single process training requires optimizer and scheduler arguments.")
diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py
index 40c5e10..ab1184a 100644
--- a/khaosz/trainer/train_context.py
+++ b/khaosz/trainer/train_context.py
@@ -80,19 +80,27 @@ class TrainContextBuilder:
         return self

     def with_strategy(self) -> Self:
-        device = get_current_device()
         self._context.strategy = StrategyFactory.load(
             model=self.config.model,
             train_type=self.config.strategy,
-            device=device,
+            device=get_current_device(),
             **self.config.extra_kwargs
         )
         return self

     def with_parallel_fn(self) -> Self:
-        fn = self.config.parallel_fn
-        if fn is not None:
+        device = get_current_device()
+        self._context.model = self._context.model.to(device=device)
+
+        if self.config.nprocs > 1:
+            fn = self.config.parallel_fn
+            optimizer_fn = self.config.optimizer_factory
+            scheduler_fn = self.config.scheduler_factory
+
             self._context.model = fn(self._context.model)
+            # Pass the wrapped module itself: the factories are typed
+            # Callable[[nn.Module], Optimizer] in train_config.py.
+            self._context.optimizer = optimizer_fn(self._context.model)
+            self._context.scheduler = scheduler_fn(self._context.optimizer)
         return self
diff --git a/tools/train.py b/tools/train.py
index 358b12f..a7621d2 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -2,9 +2,10 @@ import os
 import argparse
 import torch
 import torch.nn as nn
+import torch.optim as optim
 import torch.distributed.fsdp as fsdp

-from torch.optim import AdamW
+from functools import partial
 from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
 from khaosz.trainer import Trainer, SchedulerFactory
 from khaosz.data import DatasetLoader
@@ -26,7 +27,6 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Beta values for AdamW optimizer.")
     parser.add_argument("--adamw_beta2", type=float, default=0.95, help="Beta values for AdamW optimizer.")
     parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
-    parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.")
     parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
     parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
     parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory", help="Disable pin memory")
@@ -59,6 +59,12 @@ def fsdp_wrap(model: nn.Module):
     )
     return fsdp_model

+def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
+    return optim.AdamW(model.parameters(), **kwargs)
+
+def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
+    return SchedulerFactory.load(optimizer, **kwargs)
+
 def train(
     train_type: str,
     param_path: str,
@@ -77,7 +83,6 @@ def train(
     adamw_beta1: float,
     adamw_beta2: float,
     adamw_weight_decay: float,
     max_grad_norm: float,
-    embdeding_lr_rate: int,
     random_seed: int,
     num_workers: int,
     pin_memory: bool,
@@ -110,23 +115,19 @@ def train(
         stride=stride
     )

-    param_groups = [
-        {"params": [p for n, p in model.named_parameters() if "embed" in n], "lr": max_lr * embdeding_lr_rate},
-        {"params": [p for n, p in model.named_parameters() if "embed" not in n], "lr": max_lr}
-    ]
-
-    optimizer = AdamW(
-        param_groups,
-        betas=(adamw_beta1, adamw_beta2),
-        weight_decay=adamw_weight_decay
-    )
-
     schedule_config = CosineScheduleConfig(
         warmup_steps=warmup_steps,
         total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
     )
-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+
+    optimizer_fn = partial(create_optimizer, lr=max_lr, betas=(adamw_beta1, adamw_beta2), weight_decay=adamw_weight_decay)
+    scheduler_fn = partial(create_scheduler, schedule_config=schedule_config)
+    optimizer, scheduler = None, None
+
+    if nprocs == 1:
+        # create_optimizer expects the module, not a parameter iterator.
+        optimizer = optimizer_fn(model)
+        scheduler = scheduler_fn(optimizer)

     train_config = TrainConfig(
         model=model,
@@ -146,6 +147,8 @@ def train(
         num_workers=num_workers,
         pin_memory=pin_memory,
         nprocs=nprocs,
+        optimizer_factory=optimizer_fn,
+        scheduler_factory=scheduler_fn,
         extra_kwargs=kwargs,
         parallel_fn=fsdp_wrap
     )
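
A minimal usage sketch of the contract these changes introduce (not part of the patch): it uses only TrainConfig fields visible in the diff above, and the factories follow the Callable[[nn.Module], Optimizer] and Callable[[Optimizer], LRScheduler] signatures declared in train_config.py. The toy model, dataset, strategy string, and parallel_fn here are placeholders, and CosineAnnealingLR stands in for SchedulerFactory.load.

# Usage sketch only; illustrates the nprocs-dependent validation in TrainConfig.validate().
from functools import partial

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset

from khaosz.config import TrainConfig


class ToyDataset(Dataset):
    # Placeholder dataset; any khaosz-compatible Dataset works here.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.zeros(16)


def make_optimizer(model: nn.Module, lr: float) -> optim.Optimizer:
    # Factory receives the (possibly parallel-wrapped) module, matching
    # Callable[[nn.Module], Optimizer] from train_config.py.
    return optim.AdamW(model.parameters(), lr=lr)


def make_scheduler(optimizer: optim.Optimizer) -> optim.lr_scheduler.LRScheduler:
    # Stand-in scheduler factory; tools/train.py uses SchedulerFactory.load instead.
    return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)


model = nn.Linear(16, 16)   # placeholder model
dataset = ToyDataset()
strategy = "sft"            # placeholder; must be a train_type known to StrategyFactory

# Single-process run (nprocs == 1): concrete optimizer/scheduler arguments are required,
# otherwise validate() raises.
optimizer = make_optimizer(model, lr=1e-4)
single_cfg = TrainConfig(
    model=model,
    strategy=strategy,
    dataset=dataset,
    optimizer=optimizer,
    scheduler=make_scheduler(optimizer),
    nprocs=1,
)

# Distributed run (nprocs > 1): factories are required instead, so each rank can build
# its optimizer and scheduler after parallel_fn has wrapped the model.
dist_cfg = TrainConfig(
    model=model,
    strategy=strategy,
    dataset=dataset,
    optimizer_factory=partial(make_optimizer, lr=1e-4),
    scheduler_factory=make_scheduler,
    parallel_fn=lambda m: m,   # placeholder; tools/train.py passes an FSDP wrapper here
    nprocs=4,
)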