feat(khaosz/parallel): 添加对多种设备后端的支持并优化并行初始化逻辑
This commit is contained in:
parent
08c5a52dc8
commit
6270415590
|
|
@ -2,53 +2,18 @@ import os
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@contextmanager
def setup_parallel(
    rank: int = 0,
    world_size: int = 1,
    master_addr: str = "localhost",
    master_port: str = "29500"
):
    """Provide the default process group for the duration of a ``with`` block.

    If a default process group is already initialised it is yielded unchanged
    and left running on exit. Otherwise the rendezvous environment variables
    are exported, a new group is initialised (``nccl`` when CUDA is available,
    ``gloo`` otherwise) and the group is destroyed again when the block exits.

    Args:
        rank: Rank of the calling process; also exported as ``LOCAL_RANK``
            and used as the CUDA device index.
        world_size: Total number of participating processes.
        master_addr: Value exported as ``MASTER_ADDR``.
        master_port: Value exported as ``MASTER_PORT``.

    Yields:
        ``dist.group.WORLD``.
    """
    if dist.is_available() and dist.is_initialized():
        yield dist.group.WORLD
        # Stop here: falling through would re-initialise a group that is
        # already running and make the generator yield a second time.
        return

    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['LOCAL_RANK'] = str(rank)

    cuda_available = torch.cuda.is_available()

    dist.init_process_group(
        backend="nccl" if cuda_available else "gloo",
        init_method="env://",
        rank=rank,
        world_size=world_size
    )

    try:
        # Only bind a CUDA device when CUDA exists; the gloo fallback runs on
        # hosts where torch.cuda.set_device would raise. Doing this inside
        # the try guarantees the group is torn down if device binding fails.
        if cuda_available:
            torch.cuda.set_device(rank)
        yield dist.group.WORLD
    finally:
        dist.destroy_process_group()
|
||||
|
||||
|
||||
def wrapper_func(rank, world_size, func, config_pack):
    """Run ``func(**config_pack)`` inside ``setup_parallel(rank, world_size)``.

    Args:
        rank: Rank forwarded to ``setup_parallel``.
        world_size: World size forwarded to ``setup_parallel``.
        func: Callable to execute once the parallel context is active.
        config_pack: Mapping expanded into keyword arguments for ``func``.
    """
    parallel_ctx = setup_parallel(rank, world_size)
    with parallel_ctx:
        func(**config_pack)
|
||||
|
||||
|
||||
def spawn_parallel_fn(func, world_size, kwargs_dict):
    """Launch ``world_size`` processes that each run ``wrapper_func``.

    Args:
        func: Callable handed to ``wrapper_func`` in every process.
        world_size: Number of processes to spawn; also passed to each worker.
        kwargs_dict: Keyword arguments handed to ``wrapper_func`` for ``func``.
    """
    # mp.spawn supplies the process index as the first positional argument,
    # so only the remaining wrapper_func parameters are listed here.
    worker_args = (world_size, func, kwargs_dict)
    mp.spawn(wrapper_func, args=worker_args, nprocs=world_size, join=True)
|
||||
|
||||
def get_device_count() -> int:
    """Return how many devices are available for parallel work.

    CUDA and XPU report their own device counts. Every other case (including
    MPS and plain CPU) is treated as a single device.

    Returns:
        The CUDA or XPU device count when that backend is available,
        otherwise ``1``.
    """
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    if hasattr(torch, 'xpu') and torch.xpu.is_available():
        return torch.xpu.device_count()
    # MPS and CPU both count as one device, so no MPS probe is needed here.
    # NOTE(review): the previous probe used torch.mps.is_available(), which
    # appears to be missing from some torch releases that do ship torch.mps.
    return 1
|
||||
|
||||
def get_current_device() -> torch.device:
|
||||
if torch.cuda.is_available():
|
||||
|
|
@ -58,4 +23,82 @@ def get_current_device() -> torch.device:
|
|||
elif hasattr(torch, 'mps') and torch.mps.is_available():
|
||||
return torch.device("mps")
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
return torch.device("cpu")
|
||||
|
||||
def get_available_backend():
    """Pick the ``torch.distributed`` backend name for the current hardware.

    Returns:
        ``"nccl"`` when CUDA is available, ``"ccl"`` when an XPU is
        available, and ``"gloo"`` in every other case.
    """
    if torch.cuda.is_available():
        return "nccl"

    # Intel XPU uses ccl; anything else falls back to gloo.
    xpu_ready = hasattr(torch, 'xpu') and torch.xpu.is_available()
    return "ccl" if xpu_ready else "gloo"
|
||||
|
||||
|
||||
@contextmanager
def setup_parallel(
    rank: int = 0,
    world_size: int = 1,
    master_addr: str = "localhost",
    master_port: str = "29500"
):
    """Provide a process group for the duration of a ``with`` block.

    Three cases are handled in order:

    * a default process group is already initialised: it is yielded and left
      untouched on exit;
    * ``world_size <= 1``: ``None`` is yielded and nothing is initialised;
    * otherwise the rendezvous environment variables are exported, a group is
      initialised with the backend from ``get_available_backend()``, the
      device matching ``rank`` is selected, and the group is destroyed on
      exit.

    Args:
        rank: Rank of the calling process; also exported as ``LOCAL_RANK``
            and used as the device index.
        world_size: Total number of participating processes.
        master_addr: Value exported as ``MASTER_ADDR``.
        master_port: Value exported as ``MASTER_PORT``.

    Yields:
        ``dist.group.WORLD``, or ``None`` when ``world_size <= 1`` and no
        group was already initialised.
    """
    already_running = dist.is_available() and dist.is_initialized()
    if already_running:
        yield dist.group.WORLD
        return

    if world_size <= 1:
        yield None
        return

    # Rendezvous settings read by the "env://" init method.
    os.environ.update({
        'MASTER_ADDR': master_addr,
        'MASTER_PORT': master_port,
        'RANK': str(rank),
        'WORLD_SIZE': str(world_size),
        'LOCAL_RANK': str(rank),
    })

    backend = get_available_backend()
    dist.init_process_group(
        world_size=world_size,
        rank=rank,
        init_method="env://",
        backend=backend
    )

    try:
        # Bind this process to the device whose index equals its rank.
        if backend == "nccl":
            if torch.cuda.is_available():
                torch.cuda.set_device(rank)
        elif backend == "ccl":
            if hasattr(torch, 'xpu') and torch.xpu.is_available():
                torch.xpu.set_device(rank)

        yield dist.group.WORLD
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()
|
||||
|
||||
def wrapper_spawn_func(rank, world_size, func, kwargs_dict):
    """Run ``func(**kwargs_dict)`` inside ``setup_parallel(rank, world_size)``.

    Args:
        rank: Rank forwarded to ``setup_parallel``.
        world_size: World size forwarded to ``setup_parallel``.
        func: Callable to execute once the parallel context is active.
        kwargs_dict: Mapping expanded into keyword arguments for ``func``.
    """
    parallel_ctx = setup_parallel(rank, world_size)
    with parallel_ctx:
        func(**kwargs_dict)
|
||||
|
||||
def spawn_parallel_fn(func, world_size=None, **kwargs):
    """Run ``func(**kwargs)`` on ``world_size`` processes.

    Args:
        func: Callable to execute.
        world_size: Number of processes; defaults to ``get_device_count()``
            when ``None``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Raises:
        ValueError: If ``world_size`` is below 1 or larger than
            ``get_device_count()``.
    """
    # Default to one worker per detected device.
    world_size = get_device_count() if world_size is None else world_size

    if world_size < 1:
        raise ValueError("world_size must be greater than 0")

    device_count = get_device_count()
    if device_count < world_size:
        raise ValueError(f"world_size ({world_size}) exceeds available devices ({device_count})")

    if world_size != 1:
        # mp.spawn supplies the process index as the first positional
        # argument of wrapper_spawn_func.
        mp.spawn(
            wrapper_spawn_func,
            args=(world_size, func, kwargs),
            nprocs=world_size,
            join=True
        )
        return

    # A single worker needs no extra process: call func directly.
    func(**kwargs)
|
||||
Loading…
Reference in New Issue