feat(parallel/device): 引入设备策略注册机制以支持多种后端
This commit is contained in:
parent
831933fb66
commit
3ac38a7ebc
|
|
@ -0,0 +1,85 @@
|
|||
import os
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceStrategy:
|
||||
"""
|
||||
A class representing a device strategy.
|
||||
|
||||
Attributes:
|
||||
name: Name of the device backend (e.g., 'cuda', 'xpu').
|
||||
priority: Higher number means higher priority.
|
||||
is_available: A callable that returns True if the device is available.
|
||||
make_device: A callable that takes a rank (int) and returns a torch.device.
|
||||
"""
|
||||
name: str
|
||||
priority: int
|
||||
is_available: Callable[[], bool]
|
||||
make_device: Callable[[int], torch.device]
|
||||
|
||||
|
||||
class DeviceStrategyRegistry:
|
||||
"""
|
||||
A registry for device strategies that automatically selects the best available device.
|
||||
And allows overriding the device backend via environment variable.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._strategies: List[DeviceStrategy] = []
|
||||
|
||||
# Register default strategies
|
||||
self.register(DeviceStrategy(
|
||||
name="cuda",
|
||||
priority=100,
|
||||
is_available=torch.cuda.is_available,
|
||||
make_device=lambda rank: torch.device(f"cuda:{rank}")
|
||||
))
|
||||
|
||||
self.register(DeviceStrategy(
|
||||
name="xpu",
|
||||
priority=90,
|
||||
is_available=torch.xpu.is_available,
|
||||
make_device=lambda rank: torch.device(f"xpu:{rank}")
|
||||
))
|
||||
|
||||
self.register(DeviceStrategy(
|
||||
name="mps",
|
||||
priority=80,
|
||||
is_available=torch.mps.is_available,
|
||||
make_device=lambda _: torch.device("mps") # MPS ignores rank
|
||||
))
|
||||
|
||||
self.register(DeviceStrategy(
|
||||
name="cpu",
|
||||
priority=0,
|
||||
is_available=lambda: True,
|
||||
make_device=lambda _: torch.device("cpu")
|
||||
))
|
||||
|
||||
def register(self, strategy: DeviceStrategy):
|
||||
self._strategies.append(strategy)
|
||||
|
||||
def get_current_device(self) -> torch.device:
|
||||
"""Return the best available device for the current process."""
|
||||
# Allow environment override (for debugging)
|
||||
override = os.getenv("TORCH_DEVICE_OVERRIDE")
|
||||
if override:
|
||||
return torch.device(override)
|
||||
|
||||
sorted_strategies = sorted(self._strategies, key=lambda s: -s.priority)
|
||||
|
||||
rank = 0
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
rank = dist.get_rank()
|
||||
|
||||
for strategy in sorted_strategies:
|
||||
if strategy.is_available():
|
||||
return strategy.make_device(rank)
|
||||
|
||||
raise RuntimeError("No device backend is available, including CPU.")
|
||||
|
||||
device_strategy_registry = DeviceStrategyRegistry()
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
|
||||
import os
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -6,7 +7,10 @@ import torch.multiprocessing as mp
|
|||
from typing import Callable
|
||||
from functools import wraps
|
||||
from contextlib import contextmanager
|
||||
from khaosz.parallel.device import device_strategy_registry
|
||||
|
||||
def get_current_device():
|
||||
return device_strategy_registry.get_current_device()
|
||||
|
||||
def get_world_size() -> int:
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
|
|
@ -20,16 +24,6 @@ def get_rank() -> int:
|
|||
else:
|
||||
return 0
|
||||
|
||||
def get_current_device():
|
||||
if torch.cuda.is_available():
|
||||
return torch.device(f"cuda:{torch.cuda.current_device()}")
|
||||
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
|
||||
return torch.device(f"xpu:{torch.xpu.current_device()}")
|
||||
elif hasattr(torch, 'mps') and torch.mps.is_available():
|
||||
return torch.device("mps")
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
@contextmanager
|
||||
def setup_parallel(
|
||||
rank: int,
|
||||
|
|
|
|||
Loading…
Reference in New Issue