feat(parallel): refactor and rename parallel utility functions for greater flexibility
parent c86e573195
commit 530fb50352
@@ -5,7 +5,8 @@ from khaosz.parallel.utils import (
     get_current_device,
     get_available_backend,
     setup_parallel,
-    only_main_procs,
+    only_on_rank,
+    run_on_rank,
     spawn_parallel_fn
 )
@@ -21,7 +22,8 @@ __all__ = [
     "get_current_device",
     "get_available_backend",
     "setup_parallel",
-    "only_main_procs",
+    "only_on_rank",
+    "run_on_rank",
    "spawn_parallel_fn",

    "RowParallelLinear",
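For orientation, a minimal sketch of how the renamed surface might be imported after this change; the names come from the hunks above, and the assumption that the package re-exports them at `khaosz.parallel` follows from the `__all__` edit:

```python
# minimal import sketch, assuming khaosz.parallel re-exports the names above
from khaosz.parallel import (
    get_current_device,
    get_available_backend,
    setup_parallel,
    only_on_rank,       # new decorator
    run_on_rank,        # context manager, renamed from only_main_procs
    spawn_parallel_fn,
)
```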
@@ -2,6 +2,8 @@ import os
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 
+from functools import wraps
 from contextlib import contextmanager
 
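The new `functools.wraps` import serves the decorator added in the next hunk. As a quick aside, this is what `wraps` preserves (a generic illustration, not code from this commit):

```python
from functools import wraps

def passthrough(func):
    @wraps(func)  # copies __name__, __doc__, __module__ etc. onto the wrapper
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@passthrough
def greet():
    """Say hello."""
    print("hello")

assert greet.__name__ == "greet"      # preserved by @wraps
assert greet.__doc__ == "Say hello."  # likewise
```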
@@ -88,19 +90,38 @@ def setup_parallel(
         dist.destroy_process_group()
 
 @contextmanager
-def only_main_procs(main_process_rank=0, block=True):
-    is_main_proc = (get_rank() == main_process_rank)
+def run_on_rank(rank=0, sync_before=True, sync_after=True):
+    """
+    context manager to run a function only on a specific rank.
+    """
+    is_main_proc = (get_rank() == rank)
 
-    if dist.is_initialized() and block:
+    if dist.is_initialized() and sync_before:
         dist.barrier()
 
     try:
         yield is_main_proc
 
     finally:
-        if dist.is_initialized() and block:
+        if dist.is_initialized() and sync_after:
             dist.barrier()
 
+def only_on_rank(rank=0):
+    """
+    decorator to run a function only on a specific rank.
+    """
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if get_rank() == rank:
+                return func(*args, **kwargs)
+            else:
+                return None
+        return wrapper
+    return decorator
+
 def wrapper_spawn_func(rank, world_size, func, kwargs_dict):
     with setup_parallel(rank, world_size):
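To close the loop, a hedged usage sketch of the two utilities as defined in this hunk; `only_on_rank` and `run_on_rank` come from the diff, while the worker body and the `save_checkpoint` helper are hypothetical stand-ins:

```python
# hypothetical usage sketch; only the two decorated/context utilities
# are from this commit, the rest is illustrative
from khaosz.parallel import only_on_rank, run_on_rank

@only_on_rank(rank=0)
def log_progress(step):
    # runs on rank 0 only; other ranks receive None from the wrapper
    print(f"step {step} done")

def save_checkpoint():
    # stand-in for real checkpoint logic
    print("checkpoint saved")

def train_step(step):
    log_progress(step)
    # all ranks enter the context (hitting the sync_before/sync_after
    # barriers); only rank 0 sees is_main == True
    with run_on_rank(rank=0) as is_main:
        if is_main:
            save_checkpoint()
```

Note the division of labor the rename encodes: `run_on_rank` yields a boolean so that every rank still reaches the surrounding barriers, while `only_on_rank` simply returns None on non-matching ranks.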