feat(parallel): 改进设备策略注册表与并行设置功能

This commit is contained in:
ViperEkura 2025-12-19 15:25:31 +08:00
parent 3ac38a7ebc
commit eab7a51bb6
3 changed files with 77 additions and 23 deletions

4
.gitignore vendored
View File

@ -10,4 +10,6 @@ params/*
# build file
build
*.egg-info
*.egg-info
*.ipynb

View File

@ -2,7 +2,7 @@ import os
import torch
import torch.distributed as dist
from dataclasses import dataclass
from typing import Callable, List
from typing import Callable, List, Optional
@dataclass
@ -27,11 +27,21 @@ class DeviceStrategyRegistry:
A registry for device strategies that automatically selects the best available device.
And allows overriding the device backend via environment variable.
"""
_instance: Optional["DeviceStrategyRegistry"] = None
_initialized: bool = False
def __new__(cls):
    """Singleton constructor: every instantiation yields the one shared instance."""
    inst = cls._instance
    if inst is None:
        inst = super().__new__(cls)
        cls._instance = inst
    return inst
def __init__(self) -> None:
if self._initialized:
return
self._strategies: List[DeviceStrategy] = []
# Register default strategies
self.register(DeviceStrategy(
name="cuda",
priority=100,
@ -59,27 +69,32 @@ class DeviceStrategyRegistry:
is_available=lambda: True,
make_device=lambda _: torch.device("cpu")
))
self._initialized = True
def register(self, strategy: DeviceStrategy):
    """Register a device strategy.

    Registration is append-only; priority ordering is resolved at lookup
    time (in get_current_device), so insertion order does not matter.
    """
    self._strategies.append(strategy)
def get_current_device(self) -> torch.device:
    """Return the best available device for the current process.

    The TORCH_DEVICE_OVERRIDE environment variable, when set, forces the
    device type (useful for debugging); the device index still follows the
    process's local rank so multi-process runs keep one device per process.

    Returns:
        torch.device: device from the highest-priority available strategy.

    Raises:
        RuntimeError: if no registered backend (not even CPU) is available.
    """
    override = os.getenv("TORCH_DEVICE_OVERRIDE")
    # Device index: prefer LOCAL_RANK (exported by setup_parallel / torchrun),
    # fall back to the global rank; 0 for single-process runs.
    rank = 0
    if dist.is_available() and dist.is_initialized():
        # Fix: os.environ values are strings, but torch.device(type, index)
        # requires an int index; also avoid KeyError when LOCAL_RANK is unset.
        rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
    if override:
        # Respect an explicit index already present in the override
        # (e.g. "cuda:1") — torch.device rejects (type-with-index, index).
        return torch.device(override) if ":" in override else torch.device(override, rank)
    # Highest-priority available strategy wins.
    for strategy in sorted(self._strategies, key=lambda s: -s.priority):
        if strategy.is_available():
            return strategy.make_device(rank)
    raise RuntimeError("No device backend is available, including CPU.")
# Shared module-level registry (DeviceStrategyRegistry is a singleton, so
# repeated construction elsewhere still yields this same instance).
device_registry = DeviceStrategyRegistry()
# Backward-compatible alias for callers importing the pre-rename name.
device_strategy_registry = device_registry

View File

@ -1,16 +1,16 @@
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from typing import Callable
from functools import wraps
from contextlib import contextmanager
from khaosz.parallel.device import device_strategy_registry
from typing import Callable, List, Optional
from khaosz.parallel.device import device_registry
def get_current_device() -> torch.device:
    """Return the best available torch.device for this process via the shared registry."""
    # Fix: drop the stale pre-rename `device_strategy_registry` return that
    # would raise NameError and made this line unreachable.
    return device_registry.get_current_device()
def get_world_size() -> int:
if dist.is_available() and dist.is_initialized():
@ -30,7 +30,8 @@ def setup_parallel(
world_size: int,
backend: str = "nccl",
master_addr: str = "localhost",
master_port: str = "29500"
master_port: str = "29500",
avail_ids: Optional[List[int]] = None
):
if dist.is_available() and dist.is_initialized():
@ -41,9 +42,13 @@ def setup_parallel(
yield None
return
if avail_ids is None:
avail_ids = [i for i in range(world_size)]
rank = avail_ids[rank % len(avail_ids)]
os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = master_port
os.environ['RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['LOCAL_RANK'] = str(rank)
@ -82,11 +87,40 @@ def only_on_rank(rank, sync=False):
return decorator
def wrapper_spawn_func(
    rank: int,
    world_size: int,
    backend: str,
    master_addr: str,
    master_port: str,
    avail_ids: List[int],
    func: Callable,
    kwargs: dict,
):
    """Per-worker entry point invoked by mp.spawn.

    Sets up the distributed process group via setup_parallel, then calls
    *func* with *kwargs*. Any exception is reported with the failing rank
    before being re-raised so mp.spawn still surfaces the failure.

    Note: the stale pre-refactor 3-parameter definition is removed; only
    this full-signature version remains.
    """
    try:
        with setup_parallel(
            rank=rank,
            world_size=world_size,
            backend=backend,
            master_addr=master_addr,
            master_port=master_port,
            avail_ids=avail_ids,
        ):
            func(**kwargs)
    except Exception as e:
        # Include the rank so interleaved multi-process logs stay attributable.
        print(f"Error in rank {rank}: {e}")
        raise
def spawn_parallel_fn(func: Callable, world_size: int, backend: str, **kwargs):
def spawn_parallel_fn(
func: Callable,
world_size: int,
backend: str = "nccl",
master_addr: str = "localhost",
master_port: str = "29500",
avail_ids: Optional[List[int]] = None,
**kwargs
):
if world_size == 1:
func(**kwargs)
@ -96,10 +130,13 @@ def spawn_parallel_fn(func: Callable, world_size: int, backend: str, **kwargs):
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK']:
if key in os.environ:
del os.environ[key]
wrapper_spawn_func_args = (world_size, backend,
master_addr, master_port, avail_ids, func, kwargs)
mp.spawn(
wrapper_spawn_func,
nprocs=world_size,
args=(world_size, backend, func, kwargs),
wrapper_spawn_func,
nprocs=world_size,
args=wrapper_spawn_func_args,
join=True
)