refactor(parallel): 重构parallel模块

This commit is contained in:
ViperEkura 2025-12-13 22:16:17 +08:00
parent a30ddca517
commit d882f65579
8 changed files with 120 additions and 91 deletions

View File

@ -90,7 +90,7 @@ class TrainConfig:
) )
# others # others
kwargs: dict = field( extra_kwargs: dict = field(
default_factory=dict, default_factory=dict,
metadata={"help": "Other arguments."} metadata={"help": "Other arguments."}
) )

View File

@ -1,12 +1,10 @@
from khaosz.parallel.utils import ( from khaosz.parallel.setup import (
get_world_size, get_world_size,
get_rank, get_rank,
get_device_count, get_current_device,
get_current_device,
get_available_backend,
setup_parallel,
only_on_rank, only_on_rank,
run_on_rank, setup_parallel,
spawn_parallel_fn spawn_parallel_fn
) )
@ -18,12 +16,10 @@ from khaosz.parallel.module import (
__all__ = [ __all__ = [
"get_world_size", "get_world_size",
"get_rank", "get_rank",
"get_device_count",
"get_current_device", "get_current_device",
"get_available_backend",
"setup_parallel",
"only_on_rank", "only_on_rank",
"run_on_rank", "setup_parallel",
"spawn_parallel_fn", "spawn_parallel_fn",
"RowParallelLinear", "RowParallelLinear",

View File

@ -3,38 +3,11 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
from typing import Callable
from functools import wraps from functools import wraps
from contextlib import contextmanager from contextlib import contextmanager
def get_device_count() -> int:
if torch.cuda.is_available():
return torch.cuda.device_count()
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
return torch.xpu.device_count()
elif hasattr(torch, 'mps') and torch.mps.is_available():
return 1
else:
return 1
def get_current_device() -> torch.device:
if torch.cuda.is_available():
return torch.device(f"cuda:{torch.cuda.current_device()}")
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
return torch.device(f"xpu:{torch.xpu.current_device()}")
elif hasattr(torch, 'mps') and torch.mps.is_available():
return torch.device("mps")
else:
return torch.device("cpu")
def get_available_backend():
if torch.cuda.is_available():
return "nccl"
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
return "ccl" # Intel XPU use ccl
else:
return "gloo"
def get_world_size() -> int: def get_world_size() -> int:
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
return dist.get_world_size() return dist.get_world_size()
@ -47,10 +20,21 @@ def get_rank() -> int:
else: else:
return 0 return 0
def get_current_device():
if torch.cuda.is_available():
return torch.device(f"cuda:{torch.cuda.current_device()}")
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
return torch.device(f"xpu:{torch.xpu.current_device()}")
elif hasattr(torch, 'mps') and torch.mps.is_available():
return torch.device("mps")
else:
return torch.device("cpu")
@contextmanager @contextmanager
def setup_parallel( def setup_parallel(
rank: int = 0, rank: int,
world_size: int = 1, world_size: int,
backend: str = "nccl",
master_addr: str = "localhost", master_addr: str = "localhost",
master_port: str = "29500" master_port: str = "29500"
): ):
@ -69,11 +53,9 @@ def setup_parallel(
os.environ['WORLD_SIZE'] = str(world_size) os.environ['WORLD_SIZE'] = str(world_size)
os.environ['LOCAL_RANK'] = str(rank) os.environ['LOCAL_RANK'] = str(rank)
backend = get_available_backend()
dist.init_process_group( dist.init_process_group(
backend=backend, backend=backend,
init_method="env://", init_method=f"tcp://{master_addr}:{master_port}",
rank=rank, rank=rank,
world_size=world_size world_size=world_size
) )
@ -89,24 +71,7 @@ def setup_parallel(
if dist.is_initialized(): if dist.is_initialized():
dist.destroy_process_group() dist.destroy_process_group()
@contextmanager def only_on_rank(rank, sync=False):
def run_on_rank(rank=0, sync_before=True, sync_after=True):
"""
context manager to run a function only on a specific rank.
"""
is_main_proc = (get_rank() == rank)
if dist.is_initialized() and sync_before:
dist.barrier()
try:
yield is_main_proc
finally:
if dist.is_initialized() and sync_after:
dist.barrier()
def only_on_rank(rank=0):
""" """
decorator to run a function only on a specific rank. decorator to run a function only on a specific rank.
""" """
@ -116,36 +81,31 @@ def only_on_rank(rank=0):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if get_rank() == rank: if get_rank() == rank:
return func(*args, **kwargs) return func(*args, **kwargs)
else: if sync:
return None dist.barrier()
return wrapper return wrapper
return decorator return decorator
def wrapper_spawn_func(rank, world_size, func, kwargs_dict): def wrapper_spawn_func(rank, world_size, backend, func, kwargs):
with setup_parallel(rank, world_size): with setup_parallel(rank, world_size, backend):
func(**kwargs_dict) func(**kwargs)
def spawn_parallel_fn(func, world_size=None, **kwargs): def spawn_parallel_fn(func: Callable, world_size: int, backend: str, **kwargs):
if world_size is None:
world_size = get_device_count()
if world_size < 1:
raise ValueError("world_size must be greater than 0")
device_count = get_device_count()
if world_size > device_count:
raise ValueError(f"world_size ({world_size}) exceeds available devices ({device_count})")
if world_size == 1: if world_size == 1:
func(**kwargs) func(**kwargs)
return return
# clear environment variables
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK']:
if key in os.environ:
del os.environ[key]
mp.spawn( mp.spawn(
wrapper_spawn_func, wrapper_spawn_func,
nprocs=world_size, nprocs=world_size,
args=(world_size, func, kwargs), args=(world_size, backend, func, kwargs),
join=True join=True
) )

View File

@ -8,6 +8,7 @@ from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
from typing import List, Optional, Protocol, TYPE_CHECKING from typing import List, Optional, Protocol, TYPE_CHECKING
from khaosz.parallel import only_on_rank
from khaosz.trainer.metric_util import ( from khaosz.trainer.metric_util import (
grad_max, grad_max,
grad_min, grad_min,
@ -96,6 +97,7 @@ class CheckpointCallback(TrainCallback):
self.save_dir = save_dir self.save_dir = save_dir
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
@only_on_rank(0)
def _save_checkpoint(self, context: 'TrainContext'): def _save_checkpoint(self, context: 'TrainContext'):
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}") save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}")
context.checkpoint = Checkpoint( context.checkpoint = Checkpoint(
@ -127,6 +129,7 @@ class ProgressBarCallback(TrainCallback):
self.num_epoch = num_epoch self.num_epoch = num_epoch
self.progress_bar: tqdm = None self.progress_bar: tqdm = None
@only_on_rank(0)
def on_epoch_begin(self, context: 'TrainContext'): def on_epoch_begin(self, context: 'TrainContext'):
self.progress_bar = tqdm( self.progress_bar = tqdm(
context.dataloader, context.dataloader,
@ -134,6 +137,7 @@ class ProgressBarCallback(TrainCallback):
dynamic_ncols=True dynamic_ncols=True
) )
@only_on_rank(0)
def on_batch_end(self, context: 'TrainContext'): def on_batch_end(self, context: 'TrainContext'):
self.progress_bar.set_postfix({ self.progress_bar.set_postfix({
"loss": f"{context.loss:.4f}", "loss": f"{context.loss:.4f}",
@ -141,6 +145,7 @@ class ProgressBarCallback(TrainCallback):
}) })
self.progress_bar.update(1) self.progress_bar.update(1)
@only_on_rank(0)
def on_epoch_end(self, context: 'TrainContext'): def on_epoch_end(self, context: 'TrainContext'):
_ = context _ = context
if self.progress_bar: if self.progress_bar:
@ -219,7 +224,8 @@ class StepMonitorCallback(TrainCallback):
json.dump(log_data, f, indent=4) json.dump(log_data, f, indent=4)
except Exception: except Exception:
raise raise
@only_on_rank(0)
def on_step_end(self, context: 'TrainContext'): def on_step_end(self, context: 'TrainContext'):
if self.step_num % self.log_interval == 0: if self.step_num % self.log_interval == 0:
self._handle_log(context) self._handle_log(context)

View File

@ -7,7 +7,7 @@ from khaosz.data import ResumableDistributedSampler
from khaosz.trainer.checkpoint import Checkpoint from khaosz.trainer.checkpoint import Checkpoint
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
from khaosz.config.train_config import TrainConfig from khaosz.config.train_config import TrainConfig
from khaosz.parallel.utils import get_current_device, get_world_size, get_rank from khaosz.parallel.setup import get_current_device, get_world_size, get_rank
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, Self from typing import Optional, Self
@ -85,7 +85,7 @@ class TrainContextBuilder:
model=self.config.model, model=self.config.model,
train_type=self.config.strategy, train_type=self.config.strategy,
device=device, device=device,
**self.config.kwargs **self.config.extra_kwargs
) )
return self return self

View File

@ -8,7 +8,8 @@ from khaosz.trainer.train_callback import (
GradientClippingCallback, GradientClippingCallback,
SchedulerCallback SchedulerCallback
) )
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder, Checkpoint from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
from khaosz.trainer.checkpoint import Checkpoint
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

39
tests/test_parallel.py Normal file
View File

@ -0,0 +1,39 @@
import torch
import torch.distributed as dist
from khaosz.parallel import (
get_rank,
only_on_rank,
spawn_parallel_fn
)
@only_on_rank(0)
def _test_only_on_rank_helper():
return True
def only_on_rank():
result = _test_only_on_rank_helper()
if get_rank() == 0:
assert result is True
else:
assert result is None
def all_reduce():
x = torch.tensor([get_rank()], dtype=torch.int)
dist.all_reduce(x, op=dist.ReduceOp.SUM)
expected_sum = sum(range(dist.get_world_size()))
assert x.item() == expected_sum
def test_spawn_only_on_rank():
spawn_parallel_fn(
only_on_rank,
world_size=2,
backend="gloo"
)
def test_spawn_all_reduce():
spawn_parallel_fn(
all_reduce,
world_size=2,
backend="gloo"
)

View File

@ -1,11 +1,14 @@
import os import os
import argparse import argparse
import torch import torch
import torch.nn as nn
import torch.distributed.fsdp as fsdp
from torch.optim import AdamW from torch.optim import AdamW
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
from khaosz.trainer import Trainer, SchedulerFactory from khaosz.trainer import Trainer, SchedulerFactory
from khaosz.data import DatasetLoader from khaosz.data import DatasetLoader
from khaosz.parallel import get_current_device, spawn_parallel_fn
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:
@ -37,10 +40,26 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.") parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.") parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
args = parser.parse_args() args = parser.parse_args()
return args return args
def fsdp_wrap(model: nn.Module):
fsdp_model = fsdp.FullyShardedDataParallel(
model,
sharding_strategy=fsdp.ShardingStrategy.SHARD_GRAD_OP,
mixed_precision=fsdp.MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
backward_prefetch=fsdp.BackwardPrefetch.BACKWARD_PRE
)
return fsdp_model
def train( def train(
train_type: str, train_type: str,
param_path: str, param_path: str,
@ -65,6 +84,7 @@ def train(
pin_memory: bool, pin_memory: bool,
window_size: int, window_size: int,
stride: int, stride: int,
nprocs: int
): ):
assert train_type in ["seq", "sft", "dpo"] assert train_type in ["seq", "sft", "dpo"]
assert os.path.exists(param_path) assert os.path.exists(param_path)
@ -76,8 +96,8 @@ def train(
window_size = parameter.config.m_len window_size = parameter.config.m_len
model = parameter.model model = parameter.model
device = torch.device("cuda") current_device = get_current_device()
model = model.to(device=device, dtype=torch.bfloat16) model = fsdp_wrap(model.to(device=current_device, dtype=torch.bfloat16))
kwargs = { kwargs = {
"dpo_beta": dpo_beta, "dpo_beta": dpo_beta,
@ -90,8 +110,7 @@ def train(
train_type=train_type, train_type=train_type,
load_path=data_root_path, load_path=data_root_path,
window_size=window_size, window_size=window_size,
stride=stride, stride=stride
**kwargs
) )
param_groups = [ param_groups = [
@ -107,7 +126,7 @@ def train(
schedule_config = CosineScheduleConfig( schedule_config = CosineScheduleConfig(
warmup_steps=warmup_steps, warmup_steps=warmup_steps,
total_steps=len(dataset) * n_epoch // batch_size, total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
) )
scheduler = SchedulerFactory.load(optimizer, schedule_config) scheduler = SchedulerFactory.load(optimizer, schedule_config)
@ -128,7 +147,9 @@ def train(
max_grad_norm=max_grad_norm, max_grad_norm=max_grad_norm,
random_seed=random_seed, random_seed=random_seed,
num_workers=num_workers, num_workers=num_workers,
pin_memory=pin_memory pin_memory=pin_memory,
nprocs=nprocs,
extra_kwargs=kwargs,
) )
trainer = Trainer(train_config) trainer = Trainer(train_config)
@ -137,4 +158,10 @@ def train(
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
train(**vars(args))
spawn_parallel_fn(
train,
world_size=args.nprocs,
backend="nccl",
**vars(args)
)