feat(trainer): 支持分布式训练配置与检查点加载优化
This commit is contained in:
parent
eab7a51bb6
commit
573f041c51
|
|
@ -4,7 +4,7 @@ from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -88,6 +88,22 @@ class TrainConfig:
|
||||||
default=1,
|
default=1,
|
||||||
metadata={"help": "Number of processes for distributed training."}
|
metadata={"help": "Number of processes for distributed training."}
|
||||||
)
|
)
|
||||||
|
backend: str = field(
|
||||||
|
default="nccl",
|
||||||
|
metadata={"help": "Distributed training backend."}
|
||||||
|
)
|
||||||
|
master_addr: str = field(
|
||||||
|
default="localhost",
|
||||||
|
metadata={"help": "Master address for distributed training."}
|
||||||
|
)
|
||||||
|
master_port: str = field(
|
||||||
|
default="29500",
|
||||||
|
metadata={"help": "Master port for distributed training."}
|
||||||
|
)
|
||||||
|
parallel_fn: Optional[Callable] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Parallel function for training."}
|
||||||
|
)
|
||||||
|
|
||||||
# others
|
# others
|
||||||
extra_kwargs: dict = field(
|
extra_kwargs: dict = field(
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from khaosz.trainer.trainer import Trainer
|
from khaosz.trainer.trainer import Trainer
|
||||||
|
from khaosz.trainer.checkpoint import Checkpoint
|
||||||
from khaosz.trainer.strategy import StrategyFactory
|
from khaosz.trainer.strategy import StrategyFactory
|
||||||
from khaosz.trainer.schedule import SchedulerFactory
|
from khaosz.trainer.schedule import SchedulerFactory
|
||||||
|
|
||||||
|
|
@ -15,6 +16,9 @@ __all__ = [
|
||||||
# trainer
|
# trainer
|
||||||
"Trainer",
|
"Trainer",
|
||||||
|
|
||||||
|
# checkpoint
|
||||||
|
"Checkpoint",
|
||||||
|
|
||||||
# factory
|
# factory
|
||||||
"StrategyFactory",
|
"StrategyFactory",
|
||||||
"SchedulerFactory",
|
"SchedulerFactory",
|
||||||
|
|
|
||||||
|
|
@ -38,19 +38,23 @@ class Checkpoint:
|
||||||
if save_metric_plot and self.metrics:
|
if save_metric_plot and self.metrics:
|
||||||
self._plot_metrics()
|
self._plot_metrics()
|
||||||
|
|
||||||
def load(self, save_dir: str) -> "Checkpoint":
|
@classmethod
|
||||||
if not os.path.exists(save_dir):
|
def load(cls, save_dir: str) -> "Checkpoint":
|
||||||
raise FileNotFoundError(f"Checkpoint directory {save_dir} does not exist.")
|
checkpoint_path = os.path.join(save_dir, "train_state.pkl")
|
||||||
|
|
||||||
with open(os.path.join(save_dir, "train_state.pkl"), "rb") as f:
|
if not os.path.exists(checkpoint_path):
|
||||||
|
raise FileNotFoundError(f"Checkpoint file {checkpoint_path} does not exist.")
|
||||||
|
|
||||||
|
with open(checkpoint_path, "rb") as f:
|
||||||
train_state = pkl.load(f)
|
train_state = pkl.load(f)
|
||||||
self.epoch = train_state["epoch"]
|
|
||||||
self.iteration = train_state["iteration"]
|
|
||||||
self.metrics = train_state["metrics"]
|
|
||||||
self.optimizer_state = train_state["optimizer_state"]
|
|
||||||
self.scheduler_state = train_state["scheduler_state"]
|
|
||||||
|
|
||||||
return self
|
return cls(
|
||||||
|
optimizer_state=train_state["optimizer_state"],
|
||||||
|
scheduler_state=train_state["scheduler_state"],
|
||||||
|
epoch=train_state["epoch"],
|
||||||
|
iteration=train_state["iteration"],
|
||||||
|
metrics=train_state["metrics"]
|
||||||
|
)
|
||||||
|
|
||||||
def _plot_metrics(self):
|
def _plot_metrics(self):
|
||||||
for metric_name, metric_value in self.metrics.items():
|
for metric_name, metric_value in self.metrics.items():
|
||||||
|
|
|
||||||
|
|
@ -99,7 +99,7 @@ class CheckpointCallback(TrainCallback):
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
def _save_checkpoint(self, context: 'TrainContext'):
|
def _save_checkpoint(self, context: 'TrainContext'):
|
||||||
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}")
|
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
|
||||||
context.checkpoint = Checkpoint(
|
context.checkpoint = Checkpoint(
|
||||||
context.optimizer.state_dict(),
|
context.optimizer.state_dict(),
|
||||||
context.scheduler.state_dict(),
|
context.scheduler.state_dict(),
|
||||||
|
|
|
||||||
|
|
@ -89,5 +89,12 @@ class TrainContextBuilder:
|
||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def with_parallel_fn(self) -> Self:
|
||||||
|
fn = self.config.parallel_fn
|
||||||
|
if fn is not None:
|
||||||
|
self._context.model = fn(self._context.model)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
def build(self) -> TrainContext:
|
def build(self) -> TrainContext:
|
||||||
return self._context
|
return self._context
|
||||||
|
|
@ -10,6 +10,7 @@ from khaosz.trainer.train_callback import (
|
||||||
)
|
)
|
||||||
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
|
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
|
||||||
from khaosz.trainer.checkpoint import Checkpoint
|
from khaosz.trainer.checkpoint import Checkpoint
|
||||||
|
from khaosz.parallel.setup import spawn_parallel_fn
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -37,6 +38,7 @@ class Trainer:
|
||||||
.with_checkpoint(checkpoint)
|
.with_checkpoint(checkpoint)
|
||||||
.with_dataloader()
|
.with_dataloader()
|
||||||
.with_strategy()
|
.with_strategy()
|
||||||
|
.with_parallel_fn()
|
||||||
.build())
|
.build())
|
||||||
|
|
||||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||||
|
|
@ -44,8 +46,19 @@ class Trainer:
|
||||||
method = getattr(callback, method_name, None)
|
method = getattr(callback, method_name, None)
|
||||||
if method:
|
if method:
|
||||||
method(context)
|
method(context)
|
||||||
|
|
||||||
|
def train(self, checkpoint: Optional[Checkpoint] = None):
|
||||||
|
config = self.train_config
|
||||||
|
spawn_parallel_fn(
|
||||||
|
self._train_impl,
|
||||||
|
backend=config.backend,
|
||||||
|
world_size=config.nprocs,
|
||||||
|
master_addr=config.master_addr,
|
||||||
|
master_port=config.master_port,
|
||||||
|
checkpoint=checkpoint
|
||||||
|
)
|
||||||
|
|
||||||
def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
||||||
context = self._build_context(checkpoint)
|
context = self._build_context(checkpoint)
|
||||||
self._call_callbacks('on_train_begin', context)
|
self._call_callbacks('on_train_begin', context)
|
||||||
|
|
||||||
|
|
@ -84,5 +97,4 @@ class Trainer:
|
||||||
self._call_callbacks('on_error', context)
|
self._call_callbacks('on_error', context)
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
self._call_callbacks('on_train_end', context)
|
self._call_callbacks('on_train_end', context)
|
||||||
return context.checkpoint
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from khaosz.config import *
|
from khaosz.config import *
|
||||||
|
|
@ -31,10 +32,14 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
checkpoint = None
|
checkpoint = None
|
||||||
try:
|
try:
|
||||||
checkpoint = trainer.train()
|
checkpoint = trainer.train()
|
||||||
assert checkpoint.iteration == 2
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# Handle any exceptions
|
# Handle any exceptions
|
||||||
pass
|
pass
|
||||||
|
|
||||||
checkpoint = trainer.train(checkpoint)
|
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
|
||||||
|
checkpoint = Checkpoint.load(load_dir)
|
||||||
|
trainer.train(checkpoint)
|
||||||
|
|
||||||
|
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
|
||||||
|
checkpoint = Checkpoint.load(load_dir)
|
||||||
assert checkpoint.iteration == 10
|
assert checkpoint.iteration == 10
|
||||||
|
|
@ -8,7 +8,6 @@ from torch.optim import AdamW
|
||||||
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
||||||
from khaosz.trainer import Trainer, SchedulerFactory
|
from khaosz.trainer import Trainer, SchedulerFactory
|
||||||
from khaosz.data import DatasetLoader
|
from khaosz.data import DatasetLoader
|
||||||
from khaosz.parallel import get_current_device, spawn_parallel_fn
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
|
|
@ -96,8 +95,6 @@ def train(
|
||||||
window_size = parameter.config.m_len
|
window_size = parameter.config.m_len
|
||||||
|
|
||||||
model = parameter.model
|
model = parameter.model
|
||||||
current_device = get_current_device()
|
|
||||||
model = fsdp_wrap(model.to(device=current_device, dtype=torch.bfloat16))
|
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"dpo_beta": dpo_beta,
|
"dpo_beta": dpo_beta,
|
||||||
|
|
@ -150,6 +147,7 @@ def train(
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
nprocs=nprocs,
|
nprocs=nprocs,
|
||||||
extra_kwargs=kwargs,
|
extra_kwargs=kwargs,
|
||||||
|
parallel_fn=fsdp_wrap
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
|
|
@ -158,10 +156,4 @@ def train(
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
train(**vars(args))
|
||||||
spawn_parallel_fn(
|
|
||||||
train,
|
|
||||||
world_size=args.nprocs,
|
|
||||||
backend="nccl",
|
|
||||||
**vars(args)
|
|
||||||
)
|
|
||||||
Loading…
Reference in New Issue