refactor: 修改参数传递方案 (revise the parameter-passing scheme: replace optimizer/scheduler instances with required factory callables)

This commit is contained in:
ViperEkura 2026-02-28 18:09:00 +08:00
parent a33d086883
commit b17cc6a6fb
3 changed files with 28 additions and 49 deletions

View File

@ -22,15 +22,14 @@ class TrainConfig:
default=None, default=None,
metadata={"help": "Dataset for training."} metadata={"help": "Dataset for training."}
) )
optimizer: Optimizer = field( optimizer_fn: Callable[[nn.Module], Optimizer] = field(
default=None, default=None,
metadata={"help": "Optimizer for training."} metadata={"help": "Optimizer factory for training."}
) )
scheduler: LRScheduler = field( scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
default=None, default=None,
metadata={"help": "Scheduler for training."} metadata={"help": "Scheduler factory for training."}
) )
n_epoch: int = field( n_epoch: int = field(
default=1, default=1,
metadata={"help": "Number of epochs for training."} metadata={"help": "Number of epochs for training."}
@ -105,19 +104,10 @@ class TrainConfig:
default=None, default=None,
metadata={"help": "Parallel function for training."} metadata={"help": "Parallel function for training."}
) )
state_dict_wrapper: Optional[Callable] = field( state_dict_fn: Optional[Callable] = field(
default=None, default=None,
metadata={"help": "Parallel function for state dict saving."} metadata={"help": "Parallel function for state dict saving."}
) )
optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field(
default=None,
metadata={"help": "Optimizer factory for training."}
)
scheduler_factory: Optional[Callable[[Optimizer], LRScheduler]] = field(
default=None,
metadata={"help": "Scheduler factory for training."}
)
# others # others
device_ids: Optional[List[int]] = field( device_ids: Optional[List[int]] = field(
@ -137,19 +127,10 @@ class TrainConfig:
self.validate() self.validate()
def validate(self): def validate(self):
required_fields = ["model", "strategy", "dataset"] required_fields = ["model", "strategy", "dataset", "optimizer_fn", "scheduler_fn"]
for field_name in required_fields: for field_name in required_fields:
if getattr(self, field_name) is None: if getattr(self, field_name) is None:
raise ValueError(f"{field_name} is required.") raise ValueError(f"{field_name} is required.")
factory_case = all([self.optimizer_factory, self.scheduler_factory])
argument_case = all([self.optimizer, self.scheduler])
self.nprocs = max(self.nprocs, 1)
if self.nprocs > 1 and not factory_case:
raise ValueError("Distributed training requires optimizer and scheduler factories.")
elif self.nprocs == 1 and not argument_case:
raise ValueError("Single process training requires optimizer and scheduler arguments.")

View File

@ -36,8 +36,6 @@ class TrainContextBuilder:
self.config = config self.config = config
self._context = TrainContext( self._context = TrainContext(
model=config.model, model=config.model,
optimizer=config.optimizer,
scheduler=config.scheduler,
world_size=get_world_size(), world_size=get_world_size(),
rank=get_rank(), rank=get_rank(),
) )
@ -46,20 +44,17 @@ class TrainContextBuilder:
self._context.model = self._context.model.to(device=device) self._context.model = self._context.model.to(device=device)
if self.config.nprocs > 1: if self.config.nprocs > 1:
fn = self.config.parallel_wrapper fn = self.config.parallel_wrapper
optimizer_fn = self.config.optimizer_factory
scheduler_fn = self.config.scheduler_factory
self._context.model = fn(self._context.model) self._context.model = fn(self._context.model)
self._context.optimizer = optimizer_fn(self._context.model.parameters())
self._context.scheduler = scheduler_fn(self._context.optimizer) self._context.optimizer = self.config.optimizer_fn(self._context.model.parameters())
self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
if checkpoint is None: if checkpoint is None:
checkpoint = Checkpoint( checkpoint = Checkpoint(
optimizer_state_dict=self.config.optimizer.state_dict(), optimizer_state_dict=self._context.optimizer.state_dict(),
scheduler_state_dict=self.config.scheduler.state_dict() if self.config.scheduler is not None else None, scheduler_state_dict=self._context.scheduler.state_dict(),
) )
else: else:
# resume from the assigned checkpoint or assigned iteration # resume from the assigned checkpoint or assigned iteration
@ -102,6 +97,5 @@ class TrainContextBuilder:
) )
return self return self
def build(self) -> TrainContext: def build(self) -> TrainContext:
return self._context return self._context

View File

@ -5,6 +5,8 @@ import torch.nn as nn
import torch.optim as optim import torch.optim as optim
import torch.distributed.fsdp as fsdp import torch.distributed.fsdp as fsdp
from torch.distributed.fsdp.api import StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from typing import List, Optional from typing import List, Optional
from functools import partial from functools import partial
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
@ -77,6 +79,13 @@ def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler: def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
return SchedulerFactory.load(optimizer, **kwargs) return SchedulerFactory.load(optimizer, **kwargs)
def prepare_checkpoint(model: nn.Module, optimizer: optim.Optimizer) -> tuple[dict, dict]:
    """Gather full (unsharded) model and optimizer state dicts for checkpointing.

    Args:
        model: The module to snapshot — presumably FSDP-wrapped, since it is
            passed to ``FSDP.state_dict_type`` / ``FSDP.optim_state_dict``.
        optimizer: The optimizer whose (possibly sharded) state is gathered.

    Returns:
        A ``(model_state_dict, optim_state_dict)`` tuple of full state dicts.
        (Original annotation said ``dict`` but the function returns a 2-tuple.)
    """
    # FULL_STATE_DICT consolidates sharded parameters so the resulting state
    # dicts can be saved as an ordinary, rank-agnostic checkpoint.
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
        model_state_dict = model.state_dict()
        optim_state_dict = FSDP.optim_state_dict(model, optimizer)
    return model_state_dict, optim_state_dict
def train( def train(
train_type: str, train_type: str,
param_path: str, param_path: str,
@ -133,22 +142,18 @@ def train(
warmup_steps=warmup_steps, warmup_steps=warmup_steps,
total_steps=len(dataset) * n_epoch // (batch_size * nprocs), total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
) )
optimizer_fn = partial(create_optimizer, **{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay}) optimizer_fn = partial(create_optimizer,
scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config}) **{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay})
optimizer, scheduler = None, None scheduler_fn = partial(create_scheduler,
**{"schedule_config": schedule_config})
if nprocs == 1:
optimizer = optimizer_fn(model.parameters())
scheduler = scheduler_fn(optimizer)
train_config = TrainConfig( train_config = TrainConfig(
model=model, model=model,
strategy=train_type, strategy=train_type,
dataset=dataset, dataset=dataset,
optimizer=optimizer, optimizer_fn=optimizer_fn,
scheduler=scheduler, scheduler_fn=scheduler_fn,
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
n_epoch=n_epoch, n_epoch=n_epoch,
batch_size=batch_size, batch_size=batch_size,
@ -162,8 +167,7 @@ def train(
pin_memory=pin_memory, pin_memory=pin_memory,
nprocs=nprocs, nprocs=nprocs,
parallel_wrapper=fsdp_wrap, parallel_wrapper=fsdp_wrap,
optimizer_factory=optimizer_fn, state_dict_fn=prepare_checkpoint,
scheduler_factory=scheduler_fn,
device_ids=device_ids, device_ids=device_ids,
device_type=device_type, device_type=device_type,
extra_kwargs=kwargs, extra_kwargs=kwargs,