feat(parallel/device): introduce a device strategy registry to support multiple backends

ViperEkura 2025-12-15 13:58:59 +08:00
parent 831933fb66
commit 3ac38a7ebc
2 changed files with 89 additions and 10 deletions

khaosz/parallel/device.py (new file, 85 lines)

@@ -0,0 +1,85 @@
import os
import torch
import torch.distributed as dist
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class DeviceStrategy:
    """
    A device backend strategy.

    Attributes:
        name: Name of the device backend (e.g., 'cuda', 'xpu').
        priority: Strategies with a higher value are preferred.
        is_available: A callable that returns True if the backend is available.
        make_device: A callable that takes a rank (int) and returns a torch.device.
    """
    name: str
    priority: int
    is_available: Callable[[], bool]
    make_device: Callable[[int], torch.device]


class DeviceStrategyRegistry:
    """
    A registry of device strategies that automatically selects the best
    available device, and allows overriding the backend via the
    TORCH_DEVICE_OVERRIDE environment variable.
    """

    def __init__(self) -> None:
        self._strategies: List[DeviceStrategy] = []
        # Register default strategies. The hasattr guards keep older
        # PyTorch builds (without torch.xpu / torch.mps) from raising
        # AttributeError at import time.
        self.register(DeviceStrategy(
            name="cuda",
            priority=100,
            is_available=torch.cuda.is_available,
            make_device=lambda rank: torch.device(f"cuda:{rank}")
        ))
        self.register(DeviceStrategy(
            name="xpu",
            priority=90,
            is_available=lambda: hasattr(torch, "xpu") and torch.xpu.is_available(),
            make_device=lambda rank: torch.device(f"xpu:{rank}")
        ))
        self.register(DeviceStrategy(
            name="mps",
            priority=80,
            is_available=lambda: hasattr(torch, "mps") and torch.mps.is_available(),
            make_device=lambda _: torch.device("mps")  # MPS exposes a single device; rank is ignored
        ))
        self.register(DeviceStrategy(
            name="cpu",
            priority=0,
            is_available=lambda: True,
            make_device=lambda _: torch.device("cpu")
        ))

    def register(self, strategy: DeviceStrategy) -> None:
        self._strategies.append(strategy)

    def get_current_device(self) -> torch.device:
        """Return the best available device for the current process."""
        # Allow an environment override (useful for debugging).
        override = os.getenv("TORCH_DEVICE_OVERRIDE")
        if override:
            return torch.device(override)
        # Use the distributed rank as the device index when running under
        # torch.distributed; this assumes one process per device.
        rank = 0
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
        sorted_strategies = sorted(self._strategies, key=lambda s: -s.priority)
        for strategy in sorted_strategies:
            if strategy.is_available():
                return strategy.make_device(rank)
        raise RuntimeError("No device backend is available, including CPU.")


device_strategy_registry = DeviceStrategyRegistry()
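
A quick usage sketch of the registry (the "npu" backend below is hypothetical, shown only to illustrate register() and priority ordering; it is not part of this commit):

    import torch
    from khaosz.parallel.device import DeviceStrategy, device_strategy_registry

    # Register an extra backend; when several backends are available,
    # the one with the highest priority wins.
    device_strategy_registry.register(DeviceStrategy(
        name="npu",  # hypothetical backend name
        priority=95,
        is_available=lambda: hasattr(torch, "npu") and torch.npu.is_available(),
        make_device=lambda rank: torch.device(f"npu:{rank}"),
    ))

    device = device_strategy_registry.get_current_device()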

(second changed file; path not shown in this view)

@@ -1,3 +1,4 @@
+import os
 import torch
 import torch.distributed as dist
@@ -6,7 +7,10 @@ import torch.multiprocessing as mp
 from typing import Callable
 from functools import wraps
 from contextlib import contextmanager
+from khaosz.parallel.device import device_strategy_registry

+def get_current_device():
+    return device_strategy_registry.get_current_device()

 def get_world_size() -> int:
     if dist.is_available() and dist.is_initialized():
@@ -20,16 +24,6 @@ def get_rank() -> int:
     else:
         return 0

-def get_current_device():
-    if torch.cuda.is_available():
-        return torch.device(f"cuda:{torch.cuda.current_device()}")
-    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
-        return torch.device(f"xpu:{torch.xpu.current_device()}")
-    elif hasattr(torch, 'mps') and torch.mps.is_available():
-        return torch.device("mps")
-    else:
-        return torch.device("cpu")

 @contextmanager
 def setup_parallel(
     rank: int,
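
A minimal sketch of the override path (TORCH_DEVICE_OVERRIDE is read on every call, so it can be set from the shell or in-process before calling the helper):

    import os
    os.environ["TORCH_DEVICE_OVERRIDE"] = "cpu"  # force CPU, e.g. for debugging

    from khaosz.parallel.device import device_strategy_registry
    print(device_strategy_registry.get_current_device())  # cpu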