feat(parallel): refactor and rename parallel utility functions for greater flexibility

ViperEkura 2025-12-10 14:43:35 +08:00
parent c86e573195
commit 530fb50352
2 changed files with 29 additions and 6 deletions

@@ -5,7 +5,8 @@ from khaosz.parallel.utils import (
     get_current_device,
     get_available_backend,
     setup_parallel,
-    only_main_procs,
+    only_on_rank,
+    run_on_rank,
     spawn_parallel_fn
 )
@@ -21,7 +22,8 @@ __all__ = [
     "get_current_device",
     "get_available_backend",
     "setup_parallel",
-    "only_main_procs",
+    "only_on_rank",
+    "run_on_rank",
     "spawn_parallel_fn",
     "RowParallelLinear",

@@ -2,6 +2,8 @@ import os
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+from functools import wraps
 from contextlib import contextmanager
@@ -88,19 +90,38 @@ def setup_parallel(
     dist.destroy_process_group()
 @contextmanager
-def only_main_procs(main_process_rank=0, block=True):
-    is_main_proc = (get_rank() == main_process_rank)
-    if dist.is_initialized() and block:
+def run_on_rank(rank=0, sync_before=True, sync_after=True):
+    """
+    Context manager to run a code block only on a specific rank.
+    """
+    is_main_proc = (get_rank() == rank)
+    if dist.is_initialized() and sync_before:
         dist.barrier()
     try:
         yield is_main_proc
     finally:
-        if dist.is_initialized() and block:
+        if dist.is_initialized() and sync_after:
             dist.barrier()
+def only_on_rank(rank=0):
+    """
+    Decorator to run a function only on a specific rank.
+    """
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if get_rank() == rank:
+                return func(*args, **kwargs)
+            else:
+                return None
+        return wrapper
+    return decorator
 def wrapper_spawn_func(rank, world_size, func, kwargs_dict):
     with setup_parallel(rank, world_size):
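
For reference, a minimal usage sketch of the two renamed helpers, assuming they are imported from khaosz.parallel as exported in the __init__.py hunk above; save_checkpoint and log_step are hypothetical caller-side functions, not part of this commit:

import torch
from khaosz.parallel import run_on_rank, only_on_rank

@only_on_rank(rank=0)
def log_step(step, loss):
    # Hypothetical helper: the wrapped body executes on rank 0 only;
    # every other rank skips the call and receives None.
    print(f"step {step}: loss {loss:.4f}")

def save_checkpoint(model, path):
    # Hypothetical helper: all ranks synchronize at the entry/exit
    # barriers (sync_before / sync_after), but only rank 0 writes the file.
    with run_on_rank(rank=0, sync_before=True, sync_after=True) as is_main:
        if is_main:
            torch.save(model.state_dict(), path)

Compared with the old only_main_procs(block=...), splitting the single flag into sync_before and sync_after lets callers skip either barrier independently.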