feat(trainer): 支持分布式训练配置与检查点加载优化
This commit is contained in:
parent
eab7a51bb6
commit
573f041c51
|
|
@ -4,7 +4,7 @@ from torch.optim import Optimizer
|
|||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -88,6 +88,22 @@ class TrainConfig:
|
|||
default=1,
|
||||
metadata={"help": "Number of processes for distributed training."}
|
||||
)
|
||||
backend: str = field(
|
||||
default="nccl",
|
||||
metadata={"help": "Distributed training backend."}
|
||||
)
|
||||
master_addr: str = field(
|
||||
default="localhost",
|
||||
metadata={"help": "Master address for distributed training."}
|
||||
)
|
||||
master_port: str = field(
|
||||
default="29500",
|
||||
metadata={"help": "Master port for distributed training."}
|
||||
)
|
||||
parallel_fn: Optional[Callable] = field(
|
||||
default=None,
|
||||
metadata={"help": "Parallel function for training."}
|
||||
)
|
||||
|
||||
# others
|
||||
extra_kwargs: dict = field(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from khaosz.trainer.trainer import Trainer
|
||||
from khaosz.trainer.checkpoint import Checkpoint
|
||||
from khaosz.trainer.strategy import StrategyFactory
|
||||
from khaosz.trainer.schedule import SchedulerFactory
|
||||
|
||||
|
|
@ -15,6 +16,9 @@ __all__ = [
|
|||
# trainer
|
||||
"Trainer",
|
||||
|
||||
# checkpoint
|
||||
"Checkpoint",
|
||||
|
||||
# factory
|
||||
"StrategyFactory",
|
||||
"SchedulerFactory",
|
||||
|
|
|
|||
|
|
@ -38,19 +38,23 @@ class Checkpoint:
|
|||
if save_metric_plot and self.metrics:
|
||||
self._plot_metrics()
|
||||
|
||||
def load(self, save_dir: str) -> "Checkpoint":
|
||||
if not os.path.exists(save_dir):
|
||||
raise FileNotFoundError(f"Checkpoint directory {save_dir} does not exist.")
|
||||
@classmethod
|
||||
def load(cls, save_dir: str) -> "Checkpoint":
|
||||
checkpoint_path = os.path.join(save_dir, "train_state.pkl")
|
||||
|
||||
with open(os.path.join(save_dir, "train_state.pkl"), "rb") as f:
|
||||
if not os.path.exists(checkpoint_path):
|
||||
raise FileNotFoundError(f"Checkpoint file {checkpoint_path} does not exist.")
|
||||
|
||||
with open(checkpoint_path, "rb") as f:
|
||||
train_state = pkl.load(f)
|
||||
self.epoch = train_state["epoch"]
|
||||
self.iteration = train_state["iteration"]
|
||||
self.metrics = train_state["metrics"]
|
||||
self.optimizer_state = train_state["optimizer_state"]
|
||||
self.scheduler_state = train_state["scheduler_state"]
|
||||
|
||||
return self
|
||||
return cls(
|
||||
optimizer_state=train_state["optimizer_state"],
|
||||
scheduler_state=train_state["scheduler_state"],
|
||||
epoch=train_state["epoch"],
|
||||
iteration=train_state["iteration"],
|
||||
metrics=train_state["metrics"]
|
||||
)
|
||||
|
||||
def _plot_metrics(self):
|
||||
for metric_name, metric_value in self.metrics.items():
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ class CheckpointCallback(TrainCallback):
|
|||
|
||||
@only_on_rank(0)
|
||||
def _save_checkpoint(self, context: 'TrainContext'):
|
||||
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}")
|
||||
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
|
||||
context.checkpoint = Checkpoint(
|
||||
context.optimizer.state_dict(),
|
||||
context.scheduler.state_dict(),
|
||||
|
|
|
|||
|
|
@ -89,5 +89,12 @@ class TrainContextBuilder:
|
|||
)
|
||||
return self
|
||||
|
||||
def with_parallel_fn(self) -> Self:
|
||||
fn = self.config.parallel_fn
|
||||
if fn is not None:
|
||||
self._context.model = fn(self._context.model)
|
||||
|
||||
return self
|
||||
|
||||
def build(self) -> TrainContext:
|
||||
return self._context
|
||||
|
|
@ -10,6 +10,7 @@ from khaosz.trainer.train_callback import (
|
|||
)
|
||||
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
|
||||
from khaosz.trainer.checkpoint import Checkpoint
|
||||
from khaosz.parallel.setup import spawn_parallel_fn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -37,6 +38,7 @@ class Trainer:
|
|||
.with_checkpoint(checkpoint)
|
||||
.with_dataloader()
|
||||
.with_strategy()
|
||||
.with_parallel_fn()
|
||||
.build())
|
||||
|
||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||
|
|
@ -44,8 +46,19 @@ class Trainer:
|
|||
method = getattr(callback, method_name, None)
|
||||
if method:
|
||||
method(context)
|
||||
|
||||
def train(self, checkpoint: Optional[Checkpoint] = None):
|
||||
config = self.train_config
|
||||
spawn_parallel_fn(
|
||||
self._train_impl,
|
||||
backend=config.backend,
|
||||
world_size=config.nprocs,
|
||||
master_addr=config.master_addr,
|
||||
master_port=config.master_port,
|
||||
checkpoint=checkpoint
|
||||
)
|
||||
|
||||
def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
||||
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
||||
context = self._build_context(checkpoint)
|
||||
self._call_callbacks('on_train_begin', context)
|
||||
|
||||
|
|
@ -84,5 +97,4 @@ class Trainer:
|
|||
self._call_callbacks('on_error', context)
|
||||
raise
|
||||
finally:
|
||||
self._call_callbacks('on_train_end', context)
|
||||
return context.checkpoint
|
||||
self._call_callbacks('on_train_end', context)
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from khaosz.config import *
|
||||
|
|
@ -31,10 +32,14 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
|||
checkpoint = None
|
||||
try:
|
||||
checkpoint = trainer.train()
|
||||
assert checkpoint.iteration == 2
|
||||
except Exception:
|
||||
# Handle any exceptions
|
||||
pass
|
||||
|
||||
checkpoint = trainer.train(checkpoint)
|
||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
|
||||
checkpoint = Checkpoint.load(load_dir)
|
||||
trainer.train(checkpoint)
|
||||
|
||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
|
||||
checkpoint = Checkpoint.load(load_dir)
|
||||
assert checkpoint.iteration == 10
|
||||
|
|
@ -8,7 +8,6 @@ from torch.optim import AdamW
|
|||
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
||||
from khaosz.trainer import Trainer, SchedulerFactory
|
||||
from khaosz.data import DatasetLoader
|
||||
from khaosz.parallel import get_current_device, spawn_parallel_fn
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
|
|
@ -96,8 +95,6 @@ def train(
|
|||
window_size = parameter.config.m_len
|
||||
|
||||
model = parameter.model
|
||||
current_device = get_current_device()
|
||||
model = fsdp_wrap(model.to(device=current_device, dtype=torch.bfloat16))
|
||||
|
||||
kwargs = {
|
||||
"dpo_beta": dpo_beta,
|
||||
|
|
@ -150,6 +147,7 @@ def train(
|
|||
pin_memory=pin_memory,
|
||||
nprocs=nprocs,
|
||||
extra_kwargs=kwargs,
|
||||
parallel_fn=fsdp_wrap
|
||||
)
|
||||
|
||||
trainer = Trainer(train_config)
|
||||
|
|
@ -158,10 +156,4 @@ def train(
|
|||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
spawn_parallel_fn(
|
||||
train,
|
||||
world_size=args.nprocs,
|
||||
backend="nccl",
|
||||
**vars(args)
|
||||
)
|
||||
train(**vars(args))
|
||||
Loading…
Reference in New Issue