feat(parallel): 重构并重命名并行工具函数以提升灵活性
This commit is contained in:
parent
c86e573195
commit
530fb50352
|
|
@ -5,7 +5,8 @@ from khaosz.parallel.utils import (
|
|||
get_current_device,
|
||||
get_available_backend,
|
||||
setup_parallel,
|
||||
only_main_procs,
|
||||
only_on_rank,
|
||||
run_on_rank,
|
||||
spawn_parallel_fn
|
||||
)
|
||||
|
||||
|
|
@ -21,7 +22,8 @@ __all__ = [
|
|||
"get_current_device",
|
||||
"get_available_backend",
|
||||
"setup_parallel",
|
||||
"only_main_procs",
|
||||
"only_on_rank",
|
||||
"run_on_rank",
|
||||
"spawn_parallel_fn",
|
||||
|
||||
"RowParallelLinear",
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ import os
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from functools import wraps
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
|
|
@ -88,19 +90,38 @@ def setup_parallel(
|
|||
dist.destroy_process_group()
|
||||
|
||||
@contextmanager
def run_on_rank(rank=0, sync_before=True, sync_after=True):
    """Yield whether the calling process is the selected rank.

    The managed body still executes on every process; this context manager
    only reports, through the yielded boolean, whether ``get_rank()`` equals
    ``rank``. Callers are expected to branch on that flag themselves.

    Args:
        rank: Rank to compare the current process against.
        sync_before: When true and a process group is initialized, call
            ``dist.barrier()`` before entering the body.
        sync_after: When true and a process group is initialized, call
            ``dist.barrier()`` when leaving the body, including when the
            body raised.

    Yields:
        ``True`` when ``get_rank() == rank``, otherwise ``False``.
    """
    on_target_rank = get_rank() == rank

    # The initialization state is queried separately on entry and on exit,
    # because the body itself may create or tear down the process group.
    if dist.is_initialized():
        if sync_before:
            dist.barrier()

    try:
        yield on_target_rank
    finally:
        if dist.is_initialized():
            if sync_after:
                dist.barrier()
|
||||
|
||||
def only_on_rank(rank=0):
    """Build a decorator that executes the wrapped function on one rank only.

    Args:
        rank: Rank on which the decorated function is allowed to run.

    Returns:
        A decorator. The decorated function forwards its arguments to the
        original and returns its result when ``get_rank() == rank``; on any
        other rank the original is not called and ``None`` is returned.
        No barrier or other synchronization is performed here.
    """

    def decorator(func):
        @wraps(func)
        def rank_gated(*args, **kwargs):
            # The rank is looked up on every call rather than once at
            # decoration time.
            return func(*args, **kwargs) if get_rank() == rank else None

        return rank_gated

    return decorator
|
||||
|
||||
def wrapper_spawn_func(rank, world_size, func, kwargs_dict):
|
||||
with setup_parallel(rank, world_size):
|
||||
|
|
|
|||
Loading…
Reference in New Issue