feat(parallel): 改进设备策略注册表与并行设置功能

This commit is contained in:
ViperEkura 2025-12-19 15:25:31 +08:00
parent 3ac38a7ebc
commit eab7a51bb6
3 changed files with 77 additions and 23 deletions

2
.gitignore vendored
View File

@@ -11,3 +11,5 @@ params/*
# build file # build file
build build
*.egg-info *.egg-info
*.ipynb

View File

@@ -2,7 +2,7 @@ import os
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, List from typing import Callable, List, Optional
@dataclass @dataclass
@@ -28,10 +28,20 @@ class DeviceStrategyRegistry:
And allows overriding the device backend via environment variable. And allows overriding the device backend via environment variable.
""" """
_instance: Optional["DeviceStrategyRegistry"] = None
_initialized: bool = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self) -> None: def __init__(self) -> None:
if self._initialized:
return
self._strategies: List[DeviceStrategy] = [] self._strategies: List[DeviceStrategy] = []
# Register default strategies
self.register(DeviceStrategy( self.register(DeviceStrategy(
name="cuda", name="cuda",
priority=100, priority=100,
@@ -60,26 +70,31 @@ class DeviceStrategyRegistry:
make_device=lambda _: torch.device("cpu") make_device=lambda _: torch.device("cpu")
)) ))
self._initialized = True
def register(self, strategy: DeviceStrategy): def register(self, strategy: DeviceStrategy):
self._strategies.append(strategy) self._strategies.append(strategy)
def get_current_device(self) -> torch.device: def get_current_device(self) -> torch.device:
"""Return the best available device for the current process.""" """Return the best available device for the current process."""
# Allow environment override (for debugging)
override = os.getenv("TORCH_DEVICE_OVERRIDE") override = os.getenv("TORCH_DEVICE_OVERRIDE")
if override:
return torch.device(override)
sorted_strategies = sorted(self._strategies, key=lambda s: -s.priority) sorted_strategies = sorted(self._strategies, key=lambda s: -s.priority)
rank = 0 rank = 0
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
rank = dist.get_rank() rank = os.environ["LOCAL_RANK"]
if override:
return torch.device(override, rank)
for strategy in sorted_strategies: for strategy in sorted_strategies:
if strategy.is_available(): if strategy.is_available():
return strategy.make_device(rank) return strategy.make_device(rank)
raise RuntimeError("No device backend is available, including CPU.") raise RuntimeError("No device backend is available, including CPU.")
device_strategy_registry = DeviceStrategyRegistry()
device_registry = DeviceStrategyRegistry()

View File

@@ -1,16 +1,16 @@
import os import os
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
from typing import Callable
from functools import wraps from functools import wraps
from contextlib import contextmanager from contextlib import contextmanager
from khaosz.parallel.device import device_strategy_registry from typing import Callable, List, Optional
from khaosz.parallel.device import device_registry
def get_current_device(): def get_current_device():
return device_strategy_registry.get_current_device() return device_registry.get_current_device()
def get_world_size() -> int: def get_world_size() -> int:
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
@@ -30,7 +30,8 @@ def setup_parallel(
world_size: int, world_size: int,
backend: str = "nccl", backend: str = "nccl",
master_addr: str = "localhost", master_addr: str = "localhost",
master_port: str = "29500" master_port: str = "29500",
avail_ids: Optional[List[int]] = None
): ):
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
@@ -41,9 +42,13 @@ def setup_parallel(
yield None yield None
return return
if avail_ids is None:
avail_ids = [i for i in range(world_size)]
rank = avail_ids[rank % len(avail_ids)]
os.environ['MASTER_ADDR'] = master_addr os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = master_port os.environ['MASTER_PORT'] = master_port
os.environ['RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size) os.environ['WORLD_SIZE'] = str(world_size)
os.environ['LOCAL_RANK'] = str(rank) os.environ['LOCAL_RANK'] = str(rank)
@@ -82,11 +87,40 @@ def only_on_rank(rank, sync=False):
return decorator return decorator
def wrapper_spawn_func(rank, world_size, backend, func, kwargs): def wrapper_spawn_func(
with setup_parallel(rank, world_size, backend): rank: int,
world_size: int,
backend: str,
master_addr: str,
master_port: str,
avail_ids: List[int],
func: Callable,
kwargs: dict
):
try:
with setup_parallel(
rank=rank,
world_size=world_size,
backend=backend,
master_addr=master_addr,
master_port=master_port,
avail_ids=avail_ids
):
func(**kwargs) func(**kwargs)
def spawn_parallel_fn(func: Callable, world_size: int, backend: str, **kwargs): except Exception as e:
print(f"Error in rank {rank}: {e}")
raise
def spawn_parallel_fn(
func: Callable,
world_size: int,
backend: str = "nccl",
master_addr: str = "localhost",
master_port: str = "29500",
avail_ids: Optional[List[int]] = None,
**kwargs
):
if world_size == 1: if world_size == 1:
func(**kwargs) func(**kwargs)
@@ -97,9 +131,12 @@ def spawn_parallel_fn(func: Callable, world_size: int, backend: str, **kwargs):
if key in os.environ: if key in os.environ:
del os.environ[key] del os.environ[key]
wrapper_spawn_func_args = (world_size, backend,
master_addr, master_port, avail_ids, func, kwargs)
mp.spawn( mp.spawn(
wrapper_spawn_func, wrapper_spawn_func,
nprocs=world_size, nprocs=world_size,
args=(world_size, backend, func, kwargs), args=wrapper_spawn_func_args,
join=True join=True
) )