From 530fb5035212fdbc090ebd13d84127ea6344ea20 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 10 Dec 2025 14:43:35 +0800 Subject: [PATCH] =?UTF-8?q?feat(parallel):=20=E9=87=8D=E6=9E=84=E5=B9=B6?= =?UTF-8?q?=E9=87=8D=E5=91=BD=E5=90=8D=E5=B9=B6=E8=A1=8C=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E4=BB=A5=E6=8F=90=E5=8D=87=E7=81=B5=E6=B4=BB?= =?UTF-8?q?=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/parallel/__init__.py | 6 ++++-- khaosz/parallel/utils.py | 29 +++++++++++++++++++++++++---- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/khaosz/parallel/__init__.py b/khaosz/parallel/__init__.py index 8bee6f2..ad8b5e9 100644 --- a/khaosz/parallel/__init__.py +++ b/khaosz/parallel/__init__.py @@ -5,7 +5,8 @@ from khaosz.parallel.utils import ( get_current_device, get_available_backend, setup_parallel, - only_main_procs, + only_on_rank, + run_on_rank, spawn_parallel_fn ) @@ -21,7 +22,8 @@ __all__ = [ "get_current_device", "get_available_backend", "setup_parallel", - "only_main_procs", + "only_on_rank", + "run_on_rank", "spawn_parallel_fn", "RowParallelLinear", diff --git a/khaosz/parallel/utils.py b/khaosz/parallel/utils.py index a18c21b..44d0cd9 100644 --- a/khaosz/parallel/utils.py +++ b/khaosz/parallel/utils.py @@ -2,6 +2,8 @@ import os import torch import torch.distributed as dist import torch.multiprocessing as mp + +from functools import wraps from contextlib import contextmanager @@ -88,19 +90,38 @@ def setup_parallel( dist.destroy_process_group() @contextmanager -def only_main_procs(main_process_rank=0, block=True): - is_main_proc = (get_rank() == main_process_rank) +def run_on_rank(rank=0, sync_before=True, sync_after=True): + """ + context manager to run a function only on a specific rank. + """ + is_main_proc = (get_rank() == rank) - if dist.is_initialized() and block: + if dist.is_initialized() and sync_before: dist.barrier() try: yield is_main_proc finally: - if dist.is_initialized() and block: + if dist.is_initialized() and sync_after: dist.barrier() +def only_on_rank(rank=0): + """ + decorator to run a function only on a specific rank. + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if get_rank() == rank: + return func(*args, **kwargs) + else: + return None + + return wrapper + + return decorator def wrapper_spawn_func(rank, world_size, func, kwargs_dict): with setup_parallel(rank, world_size):