AstrAI/khaosz/parallel/setup.py

105 lines
2.6 KiB
Python

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from typing import Callable
from functools import wraps
from contextlib import contextmanager
from khaosz.parallel.device import device_strategy_registry
def get_current_device():
return device_strategy_registry.get_current_device()
def get_world_size() -> int:
if dist.is_available() and dist.is_initialized():
return dist.get_world_size()
else:
return 1
def get_rank() -> int:
if dist.is_available() and dist.is_initialized():
return dist.get_rank()
else:
return 0
@contextmanager
def setup_parallel(
rank: int,
world_size: int,
backend: str = "nccl",
master_addr: str = "localhost",
master_port: str = "29500"
):
if dist.is_available() and dist.is_initialized():
yield dist.group.WORLD
return
if world_size <= 1:
yield None
return
os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = master_port
os.environ['RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['LOCAL_RANK'] = str(rank)
dist.init_process_group(
backend=backend,
init_method=f"tcp://{master_addr}:{master_port}",
rank=rank,
world_size=world_size
)
try:
if backend == "nccl" and torch.cuda.is_available():
torch.cuda.set_device(rank)
elif backend == "ccl" and hasattr(torch, 'xpu') and torch.xpu.is_available():
torch.xpu.set_device(rank)
yield dist.group.WORLD
finally:
if dist.is_initialized():
dist.destroy_process_group()
def only_on_rank(rank, sync=False):
"""
decorator to run a function only on a specific rank.
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if get_rank() == rank:
return func(*args, **kwargs)
if sync:
dist.barrier()
return wrapper
return decorator
def wrapper_spawn_func(rank, world_size, backend, func, kwargs):
with setup_parallel(rank, world_size, backend):
func(**kwargs)
def spawn_parallel_fn(func: Callable, world_size: int, backend: str, **kwargs):
if world_size == 1:
func(**kwargs)
return
# clear environment variables
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK']:
if key in os.environ:
del os.environ[key]
mp.spawn(
wrapper_spawn_func,
nprocs=world_size,
args=(world_size, backend, func, kwargs),
join=True
)