feat(parallel): 重构并重命名并行工具函数以提升灵活性

This commit is contained in:
ViperEkura 2025-12-10 14:43:35 +08:00
parent c86e573195
commit 530fb50352
2 changed files with 29 additions and 6 deletions

View File

@ -5,7 +5,8 @@ from khaosz.parallel.utils import (
get_current_device,
get_available_backend,
setup_parallel,
only_main_procs,
only_on_rank,
run_on_rank,
spawn_parallel_fn
)
@ -21,7 +22,8 @@ __all__ = [
"get_current_device",
"get_available_backend",
"setup_parallel",
"only_main_procs",
"only_on_rank",
"run_on_rank",
"spawn_parallel_fn",
"RowParallelLinear",

View File

@ -2,6 +2,8 @@ import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from functools import wraps
from contextlib import contextmanager
@ -88,19 +90,38 @@ def setup_parallel(
dist.destroy_process_group()
@contextmanager
def only_main_procs(main_process_rank=0, block=True):
is_main_proc = (get_rank() == main_process_rank)
def run_on_rank(rank=0, sync_before=True, sync_after=True):
"""
context manager to run a function only on a specific rank.
"""
is_main_proc = (get_rank() == rank)
if dist.is_initialized() and block:
if dist.is_initialized() and sync_before:
dist.barrier()
try:
yield is_main_proc
finally:
if dist.is_initialized() and block:
if dist.is_initialized() and sync_after:
dist.barrier()
def only_on_rank(rank=0):
"""
decorator to run a function only on a specific rank.
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if get_rank() == rank:
return func(*args, **kwargs)
else:
return None
return wrapper
return decorator
def wrapper_spawn_func(rank, world_size, func, kwargs_dict):
with setup_parallel(rank, world_size):