105 lines
2.6 KiB
Python
105 lines
2.6 KiB
Python
|
|
import os
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.multiprocessing as mp
|
|
|
|
from typing import Callable
|
|
from functools import wraps
|
|
from contextlib import contextmanager
|
|
from khaosz.parallel.device import device_strategy_registry
|
|
|
|
def get_current_device():
    """Return whatever ``device_strategy_registry.get_current_device()`` reports."""
    current_device = device_strategy_registry.get_current_device()
    return current_device
|
|
|
|
def get_world_size() -> int:
    """Return the world size reported by ``torch.distributed``.

    Returns:
        ``dist.get_world_size()`` when ``torch.distributed`` is available and
        a process group has been initialized; otherwise ``1``.
    """
    # Guard clause: without an initialized process group we report a
    # single process.
    if not (dist.is_available() and dist.is_initialized()):
        return 1
    return dist.get_world_size()
|
|
|
|
def get_rank() -> int:
    """Return the rank reported by ``torch.distributed``.

    Returns:
        ``dist.get_rank()`` when ``torch.distributed`` is available and a
        process group has been initialized; otherwise ``0``.
    """
    distributed_active = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if distributed_active else 0
|
|
|
|
@contextmanager
def setup_parallel(
    rank: int,
    world_size: int,
    backend: str = "nccl",
    master_addr: str = "localhost",
    master_port: str = "29500"
):
    """Context manager that provides a process group for the ``with`` body.

    Three cases are handled, in this order:

    1. A process group is already initialized: yield ``dist.group.WORLD``
       and leave that group untouched on exit.
    2. ``world_size <= 1``: yield ``None`` without initializing anything.
    3. Otherwise: export the rendezvous environment variables, call
       ``dist.init_process_group`` over ``tcp://master_addr:master_port``,
       select the accelerator for this rank, yield ``dist.group.WORLD`` and
       destroy the process group on exit.

    Args:
        rank: Rank of the calling process; also written to ``LOCAL_RANK``
            and used as the accelerator index.
        world_size: Total number of processes.
        backend: Backend name passed to ``dist.init_process_group``.
        master_addr: Rendezvous host.
        master_port: Rendezvous port, as a string.

    Yields:
        ``dist.group.WORLD``, or ``None`` when ``world_size <= 1`` and no
        process group was already initialized.
    """
    if dist.is_available() and dist.is_initialized():
        # Somebody else owns the process group: hand it out, never tear it down.
        yield dist.group.WORLD
    elif world_size <= 1:
        yield None
    else:
        rendezvous_env = (
            ('MASTER_ADDR', master_addr),
            ('MASTER_PORT', master_port),
            ('RANK', str(rank)),
            ('WORLD_SIZE', str(world_size)),
            ('LOCAL_RANK', str(rank)),
        )
        for env_name, env_value in rendezvous_env:
            os.environ[env_name] = env_value

        dist.init_process_group(
            backend=backend,
            init_method=f"tcp://{master_addr}:{master_port}",
            rank=rank,
            world_size=world_size
        )

        try:
            # Bind this process to the accelerator whose index equals its rank.
            if backend == "nccl":
                if torch.cuda.is_available():
                    torch.cuda.set_device(rank)
            elif backend == "ccl":
                if hasattr(torch, 'xpu') and torch.xpu.is_available():
                    torch.xpu.set_device(rank)

            yield dist.group.WORLD
        finally:
            # Only tear down if the group still exists at exit time.
            if dist.is_initialized():
                dist.destroy_process_group()
|
|
|
|
def only_on_rank(rank, sync=False):
    """Decorator that runs the wrapped function only on a specific rank.

    Args:
        rank: The rank (as returned by ``get_rank()``) on which the wrapped
            function is executed.
        sync: When true, every rank calls ``dist.barrier()`` after the
            target rank has finished running the function.

    Returns:
        A decorator. The decorated function returns the wrapped function's
        result on the target rank and ``None`` on every other rank.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            result = None
            if get_rank() == rank:
                result = func(*args, **kwargs)

            # The barrier is a collective: the target rank must enter it too.
            # Returning straight from the branch above would leave the other
            # ranks waiting forever. It is skipped when no process group is
            # initialized, where calling it would raise.
            if sync and dist.is_available() and dist.is_initialized():
                dist.barrier()

            return result

        return wrapper

    return decorator
|
|
|
|
def wrapper_spawn_func(rank, world_size, backend, func, kwargs):
    """Run ``func(**kwargs)`` inside ``setup_parallel(rank, world_size, backend)``.

    ``kwargs`` is a plain dict (not ``**kwargs``); it is unpacked into
    keyword arguments when ``func`` is called.
    """
    parallel_context = setup_parallel(rank, world_size, backend)
    with parallel_context:
        func(**kwargs)
|
|
|
|
def spawn_parallel_fn(func: Callable, world_size: int, backend: str, **kwargs):
    """Run ``func(**kwargs)`` directly or across ``world_size`` spawned processes.

    Args:
        func: Callable invoked with ``**kwargs``.
        world_size: Number of processes. When exactly ``1`` the function is
            called in the current process and nothing is spawned.
        backend: Backend name forwarded to ``wrapper_spawn_func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        ``None`` in both cases; the return value of ``func`` is discarded.
    """
    if world_size == 1:
        func(**kwargs)
    else:
        # Clear environment variables before spawning.
        rendezvous_keys = ('MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK')
        for env_key in rendezvous_keys:
            os.environ.pop(env_key, None)

        # ``args`` omits the rank: per the torch.multiprocessing.spawn docs,
        # the process index is passed as the first positional argument.
        mp.spawn(
            wrapper_spawn_func,
            nprocs=world_size,
            args=(world_size, backend, func, kwargs),
            join=True
        )