refactor: 修改参数传递方案
This commit is contained in:
parent
a33d086883
commit
b17cc6a6fb
|
|
@ -22,15 +22,14 @@ class TrainConfig:
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Dataset for training."}
|
metadata={"help": "Dataset for training."}
|
||||||
)
|
)
|
||||||
optimizer: Optimizer = field(
|
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Optimizer for training."}
|
metadata={"help": "Optimizer factory for training."}
|
||||||
)
|
)
|
||||||
scheduler: LRScheduler = field(
|
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Scheduler for training."}
|
metadata={"help": "Scheduler factory for training."}
|
||||||
)
|
)
|
||||||
|
|
||||||
n_epoch: int = field(
|
n_epoch: int = field(
|
||||||
default=1,
|
default=1,
|
||||||
metadata={"help": "Number of epochs for training."}
|
metadata={"help": "Number of epochs for training."}
|
||||||
|
|
@ -105,19 +104,10 @@ class TrainConfig:
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Parallel function for training."}
|
metadata={"help": "Parallel function for training."}
|
||||||
)
|
)
|
||||||
state_dict_wrapper: Optional[Callable] = field(
|
state_dict_fn: Optional[Callable] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Parallel function for state dict saving."}
|
metadata={"help": "Parallel function for state dict saving."}
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Optimizer factory for training."}
|
|
||||||
)
|
|
||||||
scheduler_factory: Optional[Callable[[Optimizer], LRScheduler]] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Scheduler factory for training."}
|
|
||||||
)
|
|
||||||
|
|
||||||
# others
|
# others
|
||||||
device_ids: Optional[List[int]] = field(
|
device_ids: Optional[List[int]] = field(
|
||||||
|
|
@ -137,19 +127,10 @@ class TrainConfig:
|
||||||
self.validate()
|
self.validate()
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
required_fields = ["model", "strategy", "dataset"]
|
required_fields = ["model", "strategy", "dataset", "optimizer_fn", "scheduler_fn"]
|
||||||
|
|
||||||
for field_name in required_fields:
|
for field_name in required_fields:
|
||||||
if getattr(self, field_name) is None:
|
if getattr(self, field_name) is None:
|
||||||
raise ValueError(f"{field_name} is required.")
|
raise ValueError(f"{field_name} is required.")
|
||||||
|
|
||||||
factory_case = all([self.optimizer_factory, self.scheduler_factory])
|
|
||||||
argument_case = all([self.optimizer, self.scheduler])
|
|
||||||
self.nprocs = max(self.nprocs, 1)
|
|
||||||
|
|
||||||
if self.nprocs > 1 and not factory_case:
|
|
||||||
raise ValueError("Distributed training requires optimizer and scheduler factories.")
|
|
||||||
elif self.nprocs == 1 and not argument_case:
|
|
||||||
raise ValueError("Single process training requires optimizer and scheduler arguments.")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -36,8 +36,6 @@ class TrainContextBuilder:
|
||||||
self.config = config
|
self.config = config
|
||||||
self._context = TrainContext(
|
self._context = TrainContext(
|
||||||
model=config.model,
|
model=config.model,
|
||||||
optimizer=config.optimizer,
|
|
||||||
scheduler=config.scheduler,
|
|
||||||
world_size=get_world_size(),
|
world_size=get_world_size(),
|
||||||
rank=get_rank(),
|
rank=get_rank(),
|
||||||
)
|
)
|
||||||
|
|
@ -46,20 +44,17 @@ class TrainContextBuilder:
|
||||||
self._context.model = self._context.model.to(device=device)
|
self._context.model = self._context.model.to(device=device)
|
||||||
|
|
||||||
if self.config.nprocs > 1:
|
if self.config.nprocs > 1:
|
||||||
|
|
||||||
fn = self.config.parallel_wrapper
|
fn = self.config.parallel_wrapper
|
||||||
optimizer_fn = self.config.optimizer_factory
|
|
||||||
scheduler_fn = self.config.scheduler_factory
|
|
||||||
|
|
||||||
self._context.model = fn(self._context.model)
|
self._context.model = fn(self._context.model)
|
||||||
self._context.optimizer = optimizer_fn(self._context.model.parameters())
|
|
||||||
self._context.scheduler = scheduler_fn(self._context.optimizer)
|
self._context.optimizer = self.config.optimizer_fn(self._context.model.parameters())
|
||||||
|
self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
|
||||||
|
|
||||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||||
if checkpoint is None:
|
if checkpoint is None:
|
||||||
checkpoint = Checkpoint(
|
checkpoint = Checkpoint(
|
||||||
optimizer_state_dict=self.config.optimizer.state_dict(),
|
optimizer_state_dict=self._context.optimizer.state_dict(),
|
||||||
scheduler_state_dict=self.config.scheduler.state_dict() if self.config.scheduler is not None else None,
|
scheduler_state_dict=self._context.scheduler.state_dict(),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# resume from the assigned checkpoint or assigned iteration
|
# resume from the assigned checkpoint or assigned iteration
|
||||||
|
|
@ -102,6 +97,5 @@ class TrainContextBuilder:
|
||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def build(self) -> TrainContext:
|
def build(self) -> TrainContext:
|
||||||
return self._context
|
return self._context
|
||||||
|
|
@ -5,6 +5,8 @@ import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import torch.distributed.fsdp as fsdp
|
import torch.distributed.fsdp as fsdp
|
||||||
|
|
||||||
|
from torch.distributed.fsdp.api import StateDictType
|
||||||
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
||||||
|
|
@ -77,6 +79,13 @@ def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
||||||
def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
|
def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
|
||||||
return SchedulerFactory.load(optimizer, **kwargs)
|
return SchedulerFactory.load(optimizer, **kwargs)
|
||||||
|
|
||||||
|
def prepare_checkpoint(model: nn.Module, optimizer: optim.Optimizer) -> tuple[dict, dict]:
    """Collect the model and optimizer state dicts for checkpointing.

    Both state dicts are gathered while the FSDP state-dict type of
    ``model`` is set to ``StateDictType.FULL_STATE_DICT``.

    Args:
        model: Module whose ``state_dict()`` is collected.
            NOTE(review): presumably an FSDP-wrapped module -- confirm
            against the caller that passes ``parallel_wrapper``.
        optimizer: Optimizer handed to ``FSDP.optim_state_dict``.

    Returns:
        A ``(model_state_dict, optim_state_dict)`` pair. The previous
        ``-> dict`` annotation did not match this two-element return.
    """
    # The optimizer state dict is requested inside the same context as the
    # model state dict so both are produced under the same state-dict type.
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
        model_state_dict = model.state_dict()
        optim_state_dict = FSDP.optim_state_dict(model, optimizer)

    return model_state_dict, optim_state_dict
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
train_type: str,
|
train_type: str,
|
||||||
param_path: str,
|
param_path: str,
|
||||||
|
|
@ -133,22 +142,18 @@ def train(
|
||||||
warmup_steps=warmup_steps,
|
warmup_steps=warmup_steps,
|
||||||
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
|
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
optimizer_fn = partial(create_optimizer, **{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay})
|
optimizer_fn = partial(create_optimizer,
|
||||||
scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config})
|
**{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay})
|
||||||
optimizer, scheduler = None, None
|
scheduler_fn = partial(create_scheduler,
|
||||||
|
**{"schedule_config": schedule_config})
|
||||||
if nprocs == 1:
|
|
||||||
optimizer = optimizer_fn(model.parameters())
|
|
||||||
scheduler = scheduler_fn(optimizer)
|
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
model=model,
|
model=model,
|
||||||
strategy=train_type,
|
strategy=train_type,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
optimizer=optimizer,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler=scheduler,
|
scheduler_fn=scheduler_fn,
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
n_epoch=n_epoch,
|
n_epoch=n_epoch,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
|
@ -162,8 +167,7 @@ def train(
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
nprocs=nprocs,
|
nprocs=nprocs,
|
||||||
parallel_wrapper=fsdp_wrap,
|
parallel_wrapper=fsdp_wrap,
|
||||||
optimizer_factory=optimizer_fn,
|
state_dict_fn=prepare_checkpoint,
|
||||||
scheduler_factory=scheduler_fn,
|
|
||||||
device_ids=device_ids,
|
device_ids=device_ids,
|
||||||
device_type=device_type,
|
device_type=device_type,
|
||||||
extra_kwargs=kwargs,
|
extra_kwargs=kwargs,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue