diff --git a/khaosz/config/train_config.py b/khaosz/config/train_config.py
index 44f20c4..b5e2903 100644
--- a/khaosz/config/train_config.py
+++ b/khaosz/config/train_config.py
@@ -22,15 +22,14 @@ class TrainConfig:
         default=None,
         metadata={"help": "Dataset for training."}
     )
-    optimizer: Optimizer = field(
+    optimizer_fn: Callable[[nn.Module], Optimizer] = field(
         default=None,
-        metadata={"help": "Optimizer for training."}
+        metadata={"help": "Optimizer factory for training."}
     )
-    scheduler: LRScheduler = field(
+    scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
         default=None,
-        metadata={"help": "Scheduler for training."}
+        metadata={"help": "Scheduler factory for training."}
     )
-
     n_epoch: int = field(
         default=1,
         metadata={"help": "Number of epochs for training."}
@@ -105,19 +104,10 @@
         default=None,
         metadata={"help": "Parallel function for training."}
     )
-    state_dict_wrapper: Optional[Callable] = field(
+    state_dict_fn: Optional[Callable] = field(
         default=None,
         metadata={"help": "Parallel function for state dict saving."}
     )
-
-    optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field(
-        default=None,
-        metadata={"help": "Optimizer factory for training."}
-    )
-    scheduler_factory: Optional[Callable[[Optimizer], LRScheduler]] = field(
-        default=None,
-        metadata={"help": "Scheduler factory for training."}
-    )
 
     # others
     device_ids: Optional[List[int]] = field(
@@ -137,19 +127,10 @@
         self.validate()
 
     def validate(self):
-        required_fields = ["model", "strategy", "dataset"]
+        required_fields = ["model", "strategy", "dataset", "optimizer_fn", "scheduler_fn"]
         for field_name in required_fields:
            if getattr(self, field_name) is None:
                raise ValueError(f"{field_name} is required.")
 
-        factory_case = all([self.optimizer_factory, self.scheduler_factory])
-        argument_case = all([self.optimizer, self.scheduler])
-        self.nprocs = max(self.nprocs, 1)
-
-        if self.nprocs > 1 and not factory_case:
-            raise ValueError("Distributed training requires optimizer and scheduler factories.")
-        elif self.nprocs == 1 and not argument_case:
-            raise ValueError("Single process training requires optimizer and scheduler arguments.")
-        
\ No newline at end of file
diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py
index 4792e2a..e4e4dcb 100644
--- a/khaosz/trainer/train_context.py
+++ b/khaosz/trainer/train_context.py
@@ -36,8 +36,6 @@ class TrainContextBuilder:
         self.config = config
         self._context = TrainContext(
             model=config.model,
-            optimizer=config.optimizer,
-            scheduler=config.scheduler,
             world_size=get_world_size(),
             rank=get_rank(),
         )
@@ -46,20 +44,17 @@
         self._context.model = self._context.model.to(device=device)
 
         if self.config.nprocs > 1:
-            fn = self.config.parallel_wrapper
-            optimizer_fn = self.config.optimizer_factory
-            scheduler_fn = self.config.scheduler_factory
-            self._context.model = fn(self._context.model)
-            self._context.optimizer = optimizer_fn(self._context.model.parameters())
-            self._context.scheduler = scheduler_fn(self._context.optimizer)
-
+            self._context.model = self.config.parallel_wrapper(self._context.model)
+
+        self._context.optimizer = self.config.optimizer_fn(self._context.model.parameters())
+        self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
+
     def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
         if checkpoint is None:
             checkpoint = Checkpoint(
-                optimizer_state_dict=self.config.optimizer.state_dict(),
-                scheduler_state_dict=self.config.scheduler.state_dict() if self.config.scheduler is not None else None,
+                optimizer_state_dict=self._context.optimizer.state_dict(),
+                scheduler_state_dict=self._context.scheduler.state_dict(),
             )
         else:
             # resume from the assigned checkpoint or assigned iteration
@@ -102,6 +97,5 @@
         )
         return self
-
 
     def build(self) -> TrainContext:
         return self._context
\ No newline at end of file
diff --git a/tools/train.py b/tools/train.py
index c3b600f..dc65172 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -5,6 +5,8 @@
 import torch.nn as nn
 import torch.optim as optim
 import torch.distributed.fsdp as fsdp
+from torch.distributed.fsdp.api import StateDictType
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from typing import List, Optional
 from functools import partial
 from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
@@ -77,6 +79,13 @@ def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
 def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
     return SchedulerFactory.load(optimizer, **kwargs)
 
+def prepare_checkpoint(model: nn.Module, optimizer: optim.Optimizer) -> tuple:
+    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
+        model_state_dict = model.state_dict()
+        optim_state_dict = FSDP.optim_state_dict(model, optimizer)
+
+    return model_state_dict, optim_state_dict
+
 def train(
     train_type: str,
     param_path: str,
@@ -133,22 +142,18 @@
         warmup_steps=warmup_steps,
         total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
     )
-
-    optimizer_fn = partial(create_optimizer, **{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay})
-    scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config})
-    optimizer, scheduler = None, None
-
-    if nprocs == 1:
-        optimizer = optimizer_fn(model.parameters())
-        scheduler = scheduler_fn(optimizer)
+    optimizer_fn = partial(create_optimizer,
+        **{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay})
+    scheduler_fn = partial(create_scheduler,
+        **{"schedule_config": schedule_config})
 
 
     train_config = TrainConfig(
         model=model,
         strategy=train_type,
         dataset=dataset,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        optimizer_fn=optimizer_fn,
+        scheduler_fn=scheduler_fn,
         checkpoint_dir=checkpoint_dir,
         n_epoch=n_epoch,
         batch_size=batch_size,
@@ -162,8 +167,7 @@
         pin_memory=pin_memory,
         nprocs=nprocs,
         parallel_wrapper=fsdp_wrap,
-        optimizer_factory=optimizer_fn,
-        scheduler_factory=scheduler_fn,
+        state_dict_fn=prepare_checkpoint,
         device_ids=device_ids,
         device_type=device_type,
         extra_kwargs=kwargs,
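
Note for reviewers: a minimal, self-contained sketch of how the new optimizer_fn / scheduler_fn contract composes. The stand-in factory bodies and hyperparameters below are illustrative placeholders, not repo API; in tools/train.py the real factories are built with functools.partial around OptimizerFactory/SchedulerFactory helpers, as shown in the diff.

# Usage sketch (not part of the patch); placeholder model and hyperparameters.
from functools import partial

import torch.nn as nn
import torch.optim as optim

def create_optimizer(params, **kwargs) -> optim.Optimizer:
    # stand-in for the repo's optimizer factory
    return optim.AdamW(params, **kwargs)

def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
    # stand-in for SchedulerFactory.load
    return optim.lr_scheduler.CosineAnnealingLR(optimizer, **kwargs)

optimizer_fn = partial(create_optimizer, lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1)
scheduler_fn = partial(create_scheduler, T_max=1000)

# TrainContextBuilder.with_device now runs one sequence for any nprocs:
# (optionally) wrap the model, then build optimizer and scheduler from the
# wrapped module's parameters.
model = nn.Linear(16, 16)                      # placeholder module
optimizer = optimizer_fn(model.parameters())   # factory called after wrapping
scheduler = scheduler_fn(optimizer)

Deferring construction this way is what makes the single API work under FSDP: an optimizer created before wrapping would hold references to the original, unsharded parameters, which is why the old optimizer=/scheduler= arguments could not serve the nprocs > 1 path.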
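
On the checkpoint side, prepare_checkpoint covers only saving. A load-side counterpart might look roughly like the sketch below; restore_checkpoint is a hypothetical helper, and it assumes PyTorch >= 2.1, where FSDP.optim_state_dict_to_load takes (model, optim, optim_state_dict) (earlier 2.0 releases used a different argument order).

# Hypothetical load-side mirror of prepare_checkpoint (assumes PyTorch >= 2.1).
import torch.nn as nn
import torch.optim as optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import StateDictType

def restore_checkpoint(model: nn.Module, optimizer: optim.Optimizer,
                       model_state_dict: dict, optim_state_dict: dict) -> None:
    # Enter the same FULL_STATE_DICT context used at save time so the
    # unsharded dicts are re-sharded onto each rank correctly.
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
        model.load_state_dict(model_state_dict)
        sharded_osd = FSDP.optim_state_dict_to_load(model, optimizer, optim_state_dict)
    optimizer.load_state_dict(sharded_osd)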