diff --git a/khaosz/parallel/__init__.py b/khaosz/parallel/__init__.py
index 8bee6f2..ad8b5e9 100644
--- a/khaosz/parallel/__init__.py
+++ b/khaosz/parallel/__init__.py
@@ -5,7 +5,8 @@ from khaosz.parallel.utils import (
     get_current_device,
     get_available_backend,
     setup_parallel,
-    only_main_procs,
+    only_on_rank,
+    run_on_rank,
     spawn_parallel_fn
 )
 
@@ -21,7 +22,8 @@ __all__ = [
     "get_current_device",
     "get_available_backend",
     "setup_parallel",
-    "only_main_procs",
+    "only_on_rank",
+    "run_on_rank",
     "spawn_parallel_fn",
 
     "RowParallelLinear",
diff --git a/khaosz/parallel/utils.py b/khaosz/parallel/utils.py
index a18c21b..44d0cd9 100644
--- a/khaosz/parallel/utils.py
+++ b/khaosz/parallel/utils.py
@@ -2,6 +2,8 @@ import os
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+
+from functools import wraps
 from contextlib import contextmanager
 
 
@@ -88,19 +90,38 @@ def setup_parallel(
         dist.destroy_process_group()
 
 @contextmanager
-def only_main_procs(main_process_rank=0, block=True):
-    is_main_proc = (get_rank() == main_process_rank)
+def run_on_rank(rank=0, sync_before=True, sync_after=True):
+    """
+    Context manager that yields True only on the given rank, with optional barriers on entry and exit.
+    """
+    is_main_proc = (get_rank() == rank)
 
-    if dist.is_initialized() and block:
+    if dist.is_initialized() and sync_before:
         dist.barrier()
 
     try:
         yield is_main_proc
     finally:
-        if dist.is_initialized() and block:
+        if dist.is_initialized() and sync_after:
             dist.barrier()
 
+def only_on_rank(rank=0):
+    """
+    Decorator that runs the wrapped function only on the given rank; other ranks return None.
+    """
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if get_rank() == rank:
+                return func(*args, **kwargs)
+            else:
+                return None
+
+        return wrapper
+
+    return decorator
 
 def wrapper_spawn_func(rank, world_size, func, kwargs_dict):
     with setup_parallel(rank, world_size):
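
The patch replaces the `only_main_procs` context manager with `run_on_rank`, splitting the single `block` flag into independent `sync_before`/`sync_after` barriers, and adds an `only_on_rank` decorator for gating whole functions. A minimal usage sketch of the two new helpers, assuming only the `setup_parallel` signature visible in the hunks above; the worker body, model, and checkpoint path are illustrative, not part of the patch:

```python
import torch
from khaosz.parallel import setup_parallel, run_on_rank, only_on_rank

@only_on_rank(0)
def log_line(msg: str):
    # Body executes on rank 0 only; the wrapper returns None on all other ranks.
    print(msg)

def worker(rank: int, world_size: int):
    # setup_parallel is used as a context manager, as in wrapper_spawn_func above.
    with setup_parallel(rank, world_size):
        model = torch.nn.Linear(8, 8)  # illustrative placeholder
        log_line(f"started with world_size={world_size}")

        # Every rank hits a barrier on entry and exit (sync_before / sync_after
        # default to True); the yielded flag tells this rank whether it is rank 0.
        with run_on_rank(0) as is_rank0:
            if is_rank0:
                torch.save(model.state_dict(), "checkpoint.pt")  # illustrative path
```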