feat(parallel/device): introduce a device strategy registry to support multiple backends

ViperEkura 2025-12-15 13:58:59 +08:00
parent 831933fb66
commit 3ac38a7ebc
2 changed files with 89 additions and 10 deletions

khaosz/parallel/device.py (new file, 85 lines)

@@ -0,0 +1,85 @@
import os
import torch
import torch.distributed as dist
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class DeviceStrategy:
    """
    A class representing a device backend strategy.

    Attributes:
        name: Name of the device backend (e.g., 'cuda', 'xpu').
        priority: Higher number means higher priority.
        is_available: A callable that returns True if the backend is usable.
        make_device: A callable that maps a rank (int) to a torch.device.
    """
    name: str
    priority: int
    is_available: Callable[[], bool]
    make_device: Callable[[int], torch.device]


class DeviceStrategyRegistry:
    """
    A registry of device strategies that automatically selects the best
    available device. The chosen backend can be overridden via the
    TORCH_DEVICE_OVERRIDE environment variable.
    """

    def __init__(self) -> None:
        self._strategies: List[DeviceStrategy] = []
        # Register the default strategies. The hasattr guards keep this module
        # importable on PyTorch builds that ship without xpu/mps support.
        self.register(DeviceStrategy(
            name="cuda",
            priority=100,
            is_available=torch.cuda.is_available,
            make_device=lambda rank: torch.device(f"cuda:{rank}")
        ))
        self.register(DeviceStrategy(
            name="xpu",
            priority=90,
            is_available=lambda: hasattr(torch, "xpu") and torch.xpu.is_available(),
            make_device=lambda rank: torch.device(f"xpu:{rank}")
        ))
        self.register(DeviceStrategy(
            name="mps",
            priority=80,
            is_available=lambda: hasattr(torch, "mps") and torch.mps.is_available(),
            make_device=lambda _: torch.device("mps")  # MPS ignores rank
        ))
        self.register(DeviceStrategy(
            name="cpu",
            priority=0,
            is_available=lambda: True,
            make_device=lambda _: torch.device("cpu")
        ))

    def register(self, strategy: DeviceStrategy) -> None:
        self._strategies.append(strategy)

    def get_current_device(self) -> torch.device:
        """Return the best available device for the current process."""
        # Allow an environment override (useful for debugging).
        override = os.getenv("TORCH_DEVICE_OVERRIDE")
        if override:
            return torch.device(override)

        rank = 0
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()

        # Walk the strategies from highest to lowest priority.
        for strategy in sorted(self._strategies, key=lambda s: -s.priority):
            if strategy.is_available():
                return strategy.make_device(rank)
        raise RuntimeError("No device backend is available, including CPU.")


device_strategy_registry = DeviceStrategyRegistry()
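
For reference, a minimal usage sketch of the new registry; the "npu" backend below is a hypothetical example, not part of this commit:

import os
import torch
from khaosz.parallel.device import DeviceStrategy, device_strategy_registry

# Resolve the best backend for this process (cuda > xpu > mps > cpu).
device = device_strategy_registry.get_current_device()

# Force a specific backend, e.g. to debug CPU-only behavior on a GPU box.
os.environ["TORCH_DEVICE_OVERRIDE"] = "cpu"
assert device_strategy_registry.get_current_device() == torch.device("cpu")
del os.environ["TORCH_DEVICE_OVERRIDE"]

# Third-party backends plug in without touching device.py. "npu" is purely
# illustrative; a real strategy needs a real availability check.
device_strategy_registry.register(DeviceStrategy(
    name="npu",
    priority=95,
    is_available=lambda: hasattr(torch, "npu") and torch.npu.is_available(),
    make_device=lambda rank: torch.device(f"npu:{rank}"),
))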

Second changed file:

@@ -1,3 +1,4 @@
 import os
 import torch
 import torch.distributed as dist
@@ -6,7 +7,10 @@ import torch.multiprocessing as mp
 from typing import Callable
 from functools import wraps
 from contextlib import contextmanager
+from khaosz.parallel.device import device_strategy_registry
 
+def get_current_device():
+    return device_strategy_registry.get_current_device()
 
 def get_world_size() -> int:
     if dist.is_available() and dist.is_initialized():
@@ -20,16 +24,6 @@ def get_rank() -> int:
     else:
         return 0
 
-def get_current_device():
-    if torch.cuda.is_available():
-        return torch.device(f"cuda:{torch.cuda.current_device()}")
-    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
-        return torch.device(f"xpu:{torch.xpu.current_device()}")
-    elif hasattr(torch, 'mps') and torch.mps.is_available():
-        return torch.device("mps")
-    else:
-        return torch.device("cpu")
-
 @contextmanager
 def setup_parallel(
     rank: int,
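
Review note: the replaced helper derived the CUDA index from torch.cuda.current_device(), whereas the registry maps the global distributed rank to the index. A hedged sketch of the single-node pattern this implies; pin_process_device is a hypothetical helper, and multi-node jobs would usually prefer a local-rank value instead:

import torch
from khaosz.parallel.device import device_strategy_registry

def pin_process_device() -> torch.device:
    # Resolve this process's device and, on CUDA, make it the current
    # device so later allocations land on it.
    device = device_strategy_registry.get_current_device()
    if device.type == "cuda":
        torch.cuda.set_device(device)
    return device

# e.g. after dist.init_process_group(...): rank 0 -> cuda:0, rank 1 -> cuda:1, ...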