refactor(parallel): refactor the parallel module

ViperEkura 2025-12-13 22:16:17 +08:00
parent a30ddca517
commit d882f65579
8 changed files with 120 additions and 91 deletions

View File

@@ -90,7 +90,7 @@ class TrainConfig:
)
# others
kwargs: dict = field(
extra_kwargs: dict = field(
default_factory=dict,
metadata={"help": "Other arguments."}
)

View File

@@ -1,12 +1,10 @@
from khaosz.parallel.utils import (
from khaosz.parallel.setup import (
get_world_size,
get_rank,
get_device_count,
get_current_device,
get_available_backend,
setup_parallel,
only_on_rank,
run_on_rank,
setup_parallel,
spawn_parallel_fn
)
@@ -18,12 +16,10 @@ from khaosz.parallel.module import (
__all__ = [
"get_world_size",
"get_rank",
"get_device_count",
"get_current_device",
"get_available_backend",
"setup_parallel",
"only_on_rank",
"run_on_rank",
"setup_parallel",
"spawn_parallel_fn",
"RowParallelLinear",

View File

@@ -3,38 +3,11 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from typing import Callable
from functools import wraps
from contextlib import contextmanager
def get_device_count() -> int:
if torch.cuda.is_available():
return torch.cuda.device_count()
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
return torch.xpu.device_count()
elif hasattr(torch, 'mps') and torch.mps.is_available():
return 1
else:
return 1
def get_current_device() -> torch.device:
if torch.cuda.is_available():
return torch.device(f"cuda:{torch.cuda.current_device()}")
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
return torch.device(f"xpu:{torch.xpu.current_device()}")
elif hasattr(torch, 'mps') and torch.mps.is_available():
return torch.device("mps")
else:
return torch.device("cpu")
def get_available_backend():
if torch.cuda.is_available():
return "nccl"
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
return "ccl" # Intel XPU use ccl
else:
return "gloo"
def get_world_size() -> int:
if dist.is_available() and dist.is_initialized():
return dist.get_world_size()
@@ -47,10 +20,21 @@ def get_rank() -> int:
else:
return 0
def get_current_device():
if torch.cuda.is_available():
return torch.device(f"cuda:{torch.cuda.current_device()}")
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
return torch.device(f"xpu:{torch.xpu.current_device()}")
elif hasattr(torch, 'mps') and torch.mps.is_available():
return torch.device("mps")
else:
return torch.device("cpu")
@contextmanager
def setup_parallel(
rank: int = 0,
world_size: int = 1,
rank: int,
world_size: int,
backend: str = "nccl",
master_addr: str = "localhost",
master_port: str = "29500"
):
@@ -69,11 +53,9 @@ def setup_parallel(
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['LOCAL_RANK'] = str(rank)
backend = get_available_backend()
dist.init_process_group(
backend=backend,
init_method="env://",
init_method=f"tcp://{master_addr}:{master_port}",
rank=rank,
world_size=world_size
)
@@ -89,24 +71,7 @@ def setup_parallel(
if dist.is_initialized():
dist.destroy_process_group()
@contextmanager
def run_on_rank(rank=0, sync_before=True, sync_after=True):
"""
context manager to run a function only on a specific rank.
"""
is_main_proc = (get_rank() == rank)
if dist.is_initialized() and sync_before:
dist.barrier()
try:
yield is_main_proc
finally:
if dist.is_initialized() and sync_after:
dist.barrier()
def only_on_rank(rank=0):
def only_on_rank(rank, sync=False):
"""
decorator to run a function only on a specific rank.
"""
@@ -116,36 +81,31 @@ def only_on_rank(rank=0):
def wrapper(*args, **kwargs):
if get_rank() == rank:
return func(*args, **kwargs)
else:
return None
if sync:
dist.barrier()
return wrapper
return decorator
def wrapper_spawn_func(rank, world_size, func, kwargs_dict):
with setup_parallel(rank, world_size):
func(**kwargs_dict)
def wrapper_spawn_func(rank, world_size, backend, func, kwargs):
with setup_parallel(rank, world_size, backend):
func(**kwargs)
def spawn_parallel_fn(func, world_size=None, **kwargs):
if world_size is None:
world_size = get_device_count()
if world_size < 1:
raise ValueError("world_size must be greater than 0")
device_count = get_device_count()
if world_size > device_count:
raise ValueError(f"world_size ({world_size}) exceeds available devices ({device_count})")
def spawn_parallel_fn(func: Callable, world_size: int, backend: str, **kwargs):
if world_size == 1:
func(**kwargs)
return
# clear environment variables
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK']:
if key in os.environ:
del os.environ[key]
mp.spawn(
wrapper_spawn_func,
nprocs=world_size,
args=(world_size, func, kwargs),
args=(world_size, backend, func, kwargs),
join=True
)
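For reference, a minimal sketch of driving the refactored helpers directly; the demo function and its message argument are illustrative and not part of this commit:

import torch.distributed as dist
from khaosz.parallel import setup_parallel, spawn_parallel_fn, get_rank

def demo(message: str):
    # each spawned worker enters setup_parallel before this body runs
    print(f"rank {get_rank()}: {message}")
    dist.barrier()

if __name__ == "__main__":
    # two CPU processes over the gloo backend; extra kwargs are forwarded to demo
    spawn_parallel_fn(demo, world_size=2, backend="gloo", message="hello")

    # the context manager can also be entered on its own for a single-rank group
    with setup_parallel(rank=0, world_size=1, backend="gloo"):
        assert get_rank() == 0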

View File

@@ -8,6 +8,7 @@ from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LRScheduler
from typing import List, Optional, Protocol, TYPE_CHECKING
from khaosz.parallel import only_on_rank
from khaosz.trainer.metric_util import (
grad_max,
grad_min,
@@ -96,6 +97,7 @@ class CheckpointCallback(TrainCallback):
self.save_dir = save_dir
self.last_ckpt_iter = 0
@only_on_rank(0)
def _save_checkpoint(self, context: 'TrainContext'):
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}")
context.checkpoint = Checkpoint(
@@ -127,6 +129,7 @@ class ProgressBarCallback(TrainCallback):
self.num_epoch = num_epoch
self.progress_bar: tqdm = None
@only_on_rank(0)
def on_epoch_begin(self, context: 'TrainContext'):
self.progress_bar = tqdm(
context.dataloader,
@@ -134,6 +137,7 @@ class ProgressBarCallback(TrainCallback):
dynamic_ncols=True
)
@only_on_rank(0)
def on_batch_end(self, context: 'TrainContext'):
self.progress_bar.set_postfix({
"loss": f"{context.loss:.4f}",
@@ -141,6 +145,7 @@ class ProgressBarCallback(TrainCallback):
})
self.progress_bar.update(1)
@only_on_rank(0)
def on_epoch_end(self, context: 'TrainContext'):
_ = context
if self.progress_bar:
@@ -220,6 +225,7 @@ class StepMonitorCallback(TrainCallback):
except Exception:
raise
@only_on_rank(0)
def on_step_end(self, context: 'TrainContext'):
if self.step_num % self.log_interval == 0:
self._handle_log(context)

View File

@@ -7,7 +7,7 @@ from khaosz.data import ResumableDistributedSampler
from khaosz.trainer.checkpoint import Checkpoint
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
from khaosz.config.train_config import TrainConfig
from khaosz.parallel.utils import get_current_device, get_world_size, get_rank
from khaosz.parallel.setup import get_current_device, get_world_size, get_rank
from dataclasses import dataclass, field
from typing import Optional, Self
@@ -85,7 +85,7 @@ class TrainContextBuilder:
model=self.config.model,
train_type=self.config.strategy,
device=device,
**self.config.kwargs
**self.config.extra_kwargs
)
return self

View File

@@ -8,7 +8,8 @@ from khaosz.trainer.train_callback import (
GradientClippingCallback,
SchedulerCallback
)
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder, Checkpoint
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
from khaosz.trainer.checkpoint import Checkpoint
logger = logging.getLogger(__name__)

tests/test_parallel.py (new file, 39 lines)
View File

@@ -0,0 +1,39 @@
import torch
import torch.distributed as dist
from khaosz.parallel import (
get_rank,
only_on_rank,
spawn_parallel_fn
)
@only_on_rank(0)
def _test_only_on_rank_helper():
return True
def only_on_rank():
result = _test_only_on_rank_helper()
if get_rank() == 0:
assert result is True
else:
assert result is None
def all_reduce():
x = torch.tensor([get_rank()], dtype=torch.int)
dist.all_reduce(x, op=dist.ReduceOp.SUM)
expected_sum = sum(range(dist.get_world_size()))
assert x.item() == expected_sum
def test_spawn_only_on_rank():
spawn_parallel_fn(
only_on_rank,
world_size=2,
backend="gloo"
)
def test_spawn_all_reduce():
spawn_parallel_fn(
all_reduce,
world_size=2,
backend="gloo"
)
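The same spawn pattern extends to other collectives; a hypothetical broadcast check in the same style (reusing the imports above, not part of this commit) could look like:

def broadcast_value():
    x = torch.tensor([42 if get_rank() == 0 else 0], dtype=torch.int)
    dist.broadcast(x, src=0)
    assert x.item() == 42

def test_spawn_broadcast():
    spawn_parallel_fn(broadcast_value, world_size=2, backend="gloo")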

View File

@@ -1,11 +1,14 @@
import os
import argparse
import torch
import torch.nn as nn
import torch.distributed.fsdp as fsdp
from torch.optim import AdamW
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
from khaosz.trainer import Trainer, SchedulerFactory
from khaosz.data import DatasetLoader
from khaosz.parallel import get_current_device, spawn_parallel_fn
def parse_args() -> argparse.Namespace:
@@ -37,10 +40,26 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
args = parser.parse_args()
return args
def fsdp_wrap(model: nn.Module):
fsdp_model = fsdp.FullyShardedDataParallel(
model,
sharding_strategy=fsdp.ShardingStrategy.SHARD_GRAD_OP,
mixed_precision=fsdp.MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
backward_prefetch=fsdp.BackwardPrefetch.BACKWARD_PRE
)
return fsdp_model
def train(
train_type: str,
param_path: str,
@@ -65,6 +84,7 @@ def train(
pin_memory: bool,
window_size: int,
stride: int,
nprocs: int
):
assert train_type in ["seq", "sft", "dpo"]
assert os.path.exists(param_path)
@@ -76,8 +96,8 @@ def train(
window_size = parameter.config.m_len
model = parameter.model
device = torch.device("cuda")
model = model.to(device=device, dtype=torch.bfloat16)
current_device = get_current_device()
model = fsdp_wrap(model.to(device=current_device, dtype=torch.bfloat16))
kwargs = {
"dpo_beta": dpo_beta,
@@ -90,8 +110,7 @@ def train(
train_type=train_type,
load_path=data_root_path,
window_size=window_size,
stride=stride,
**kwargs
stride=stride
)
param_groups = [
@@ -107,7 +126,7 @@ def train(
schedule_config = CosineScheduleConfig(
warmup_steps=warmup_steps,
total_steps=len(dataset) * n_epoch // batch_size,
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
)
scheduler = SchedulerFactory.load(optimizer, schedule_config)
@@ -128,7 +147,9 @@ def train(
max_grad_norm=max_grad_norm,
random_seed=random_seed,
num_workers=num_workers,
pin_memory=pin_memory
pin_memory=pin_memory,
nprocs=nprocs,
extra_kwargs=kwargs,
)
trainer = Trainer(train_config)
@@ -137,4 +158,10 @@ def train(
if __name__ == "__main__":
args = parse_args()
train(**vars(args))
spawn_parallel_fn(
train,
world_size=args.nprocs,
backend="nccl",
**vars(args)
)