refactor: change the parameter-passing scheme
parent a33d086883
commit b17cc6a6fb
@@ -22,15 +22,14 @@ class TrainConfig:
         default=None,
         metadata={"help": "Dataset for training."}
     )
-    optimizer: Optimizer = field(
+    optimizer_fn: Callable[[nn.Module], Optimizer] = field(
         default=None,
-        metadata={"help": "Optimizer for training."}
+        metadata={"help": "Optimizer factory for training."}
     )
-    scheduler: LRScheduler = field(
+    scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
         default=None,
-        metadata={"help": "Scheduler for training."}
+        metadata={"help": "Scheduler factory for training."}
     )
 
     n_epoch: int = field(
         default=1,
         metadata={"help": "Number of epochs for training."}
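
For context, a minimal sketch of what the new factory-style fields look like in practice. The helper names below are illustrative only; the repository builds its factories from create_optimizer / create_scheduler with functools.partial, as the train() hunk further down shows.

    from functools import partial
    import torch.optim as optim

    def build_adamw(params, lr: float, weight_decay: float) -> optim.Optimizer:
        # Stand-in for create_optimizer: returns a ready-to-use optimizer.
        return optim.AdamW(params, lr=lr, weight_decay=weight_decay)

    def build_cosine(optimizer: optim.Optimizer, t_max: int) -> optim.lr_scheduler.LRScheduler:
        # Stand-in for create_scheduler: returns a scheduler bound to the optimizer.
        return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)

    optimizer_fn = partial(build_adamw, lr=1e-4, weight_decay=0.1)   # plays the role of optimizer_fn
    scheduler_fn = partial(build_cosine, t_max=1_000)                # plays the role of scheduler_fn
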
@@ -105,20 +104,11 @@ class TrainConfig:
         default=None,
         metadata={"help": "Parallel function for training."}
     )
-    state_dict_wrapper: Optional[Callable] = field(
+    state_dict_fn: Optional[Callable] = field(
         default=None,
         metadata={"help": "Parallel function for state dict saving."}
     )
-
-    optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field(
-        default=None,
-        metadata={"help": "Optimizer factory for training."}
-    )
-    scheduler_factory: Optional[Callable[[Optimizer], LRScheduler]] = field(
-        default=None,
-        metadata={"help": "Scheduler factory for training."}
-    )
 
     # others
     device_ids: Optional[List[int]] = field(
         default=None,
@@ -137,19 +127,10 @@ class TrainConfig:
         self.validate()
 
     def validate(self):
-        required_fields = ["model", "strategy", "dataset"]
+        required_fields = ["model", "strategy", "dataset", "optimizer_fn", "scheduler_fn"]
 
         for field_name in required_fields:
             if getattr(self, field_name) is None:
                 raise ValueError(f"{field_name} is required.")
-
-        factory_case = all([self.optimizer_factory, self.scheduler_factory])
-        argument_case = all([self.optimizer, self.scheduler])
         self.nprocs = max(self.nprocs, 1)
-
-        if self.nprocs > 1 and not factory_case:
-            raise ValueError("Distributed training requires optimizer and scheduler factories.")
-        elif self.nprocs == 1 and not argument_case:
-            raise ValueError("Single process training requires optimizer and scheduler arguments.")
@@ -36,8 +36,6 @@ class TrainContextBuilder:
         self.config = config
         self._context = TrainContext(
             model=config.model,
-            optimizer=config.optimizer,
-            scheduler=config.scheduler,
             world_size=get_world_size(),
             rank=get_rank(),
         )
@@ -46,20 +44,17 @@ class TrainContextBuilder:
         self._context.model = self._context.model.to(device=device)
 
         if self.config.nprocs > 1:
             fn = self.config.parallel_wrapper
-            optimizer_fn = self.config.optimizer_factory
-            scheduler_fn = self.config.scheduler_factory
-
             self._context.model = fn(self._context.model)
-            self._context.optimizer = optimizer_fn(self._context.model.parameters())
-            self._context.scheduler = scheduler_fn(self._context.optimizer)
 
+        self._context.optimizer = self.config.optimizer_fn(self._context.model.parameters())
+        self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
 
     def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
         if checkpoint is None:
             checkpoint = Checkpoint(
-                optimizer_state_dict=self.config.optimizer.state_dict(),
-                scheduler_state_dict=self.config.scheduler.state_dict() if self.config.scheduler is not None else None,
+                optimizer_state_dict=self._context.optimizer.state_dict(),
+                scheduler_state_dict=self._context.scheduler.state_dict(),
             )
         else:
             # resume from the assigned checkpoint or assigned iteration
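
The apparent motivation, inferred from the change rather than stated in the commit message: when the model is wrapped for distributed training, the optimizer and scheduler have to be created from the wrapped module's parameters, so the builder now always calls the factories after wrapping and the default checkpoint reads state from the live context objects. A small single-process illustration of that ordering, with an identity wrapper standing in for parallel_wrapper:

    import torch.nn as nn
    import torch.optim as optim

    def wrap(module: nn.Module) -> nn.Module:
        # Stand-in for config.parallel_wrapper (FSDP in the real code).
        return module

    model = wrap(nn.Linear(4, 4))                          # wrap first
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)   # then build the optimizer from the wrapped parameters
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)
    state = {
        "optimizer_state_dict": optimizer.state_dict(),    # read from the live objects, as in with_checkpoint
        "scheduler_state_dict": scheduler.state_dict(),
    }
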
@@ -102,6 +97,5 @@ class TrainContextBuilder:
         )
         return self
 
-
     def build(self) -> TrainContext:
         return self._context
@@ -5,6 +5,8 @@ import torch.nn as nn
 import torch.optim as optim
 import torch.distributed.fsdp as fsdp
 
+from torch.distributed.fsdp.api import StateDictType
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from typing import List, Optional
 from functools import partial
 from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
@@ -77,6 +79,13 @@ def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
 def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
     return SchedulerFactory.load(optimizer, **kwargs)
 
+def prepare_checkpoint(model: nn.Module, optimizer: optim.Optimizer) -> tuple:
+    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
+        model_state_dict = model.state_dict()
+        optim_state_dict = FSDP.optim_state_dict(model, optimizer)
+
+    return model_state_dict, optim_state_dict
+
 def train(
     train_type: str,
     param_path: str,
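
A hedged sketch of how the new state_dict_fn hook might be consumed when writing a checkpoint; the rank-0 save and the file layout are assumptions, not part of this diff, and prepare_checkpoint is the function added in the hunk above.

    import torch
    import torch.distributed as dist

    def save_full_checkpoint(model, optimizer, path: str) -> None:
        # prepare_checkpoint gathers full (unsharded) state dicts under FSDP.
        model_sd, optim_sd = prepare_checkpoint(model, optimizer)
        if not dist.is_initialized() or dist.get_rank() == 0:
            torch.save({"model": model_sd, "optimizer": optim_sd}, path)
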
@@ -134,21 +143,17 @@ def train(
         total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
     )
 
-
-    optimizer_fn = partial(create_optimizer, **{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay})
-    scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config})
-    optimizer, scheduler = None, None
-
-    if nprocs == 1:
-        optimizer = optimizer_fn(model.parameters())
-        scheduler = scheduler_fn(optimizer)
+    optimizer_fn = partial(create_optimizer,
+        **{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay})
+    scheduler_fn = partial(create_scheduler,
+        **{"schedule_config": schedule_config})
 
     train_config = TrainConfig(
         model=model,
         strategy=train_type,
         dataset=dataset,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        optimizer_fn=optimizer_fn,
+        scheduler_fn=scheduler_fn,
         checkpoint_dir=checkpoint_dir,
         n_epoch=n_epoch,
         batch_size=batch_size,
@@ -162,8 +167,7 @@ def train(
         pin_memory=pin_memory,
         nprocs=nprocs,
         parallel_wrapper=fsdp_wrap,
-        optimizer_factory=optimizer_fn,
-        scheduler_factory=scheduler_fn,
+        state_dict_fn=prepare_checkpoint,
         device_ids=device_ids,
         device_type=device_type,
         extra_kwargs=kwargs,
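
The fsdp_wrap passed as parallel_wrapper is not shown in this diff; purely as an illustration, a wrapper of the expected shape (nn.Module in, wrapped nn.Module out) would look roughly like this. The real fsdp_wrap may configure sharding strategy, auto-wrap policies, and mixed precision differently.

    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def wrap_with_fsdp(model: nn.Module) -> nn.Module:
        # Illustrative only; must be called inside an initialized process group.
        return FSDP(model)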