import os
import torch
import torch.distributed as dist
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class DeviceStrategy:
    """
    A class representing a device strategy.

    Attributes:
        name: Name of the device backend (e.g., 'cuda', 'xpu').
        priority: Higher number means higher priority.
        is_available: A callable that returns True if the device is available.
        make_device: A callable that takes a rank (int) and returns a torch.device.
    """
    name: str
    priority: int
    is_available: Callable[[], bool]
    make_device: Callable[[int], torch.device]


class DeviceStrategyRegistry:
    """
    A registry for device strategies that automatically selects the best
    available device, and allows overriding the device backend via the
    ``TORCH_DEVICE_OVERRIDE`` environment variable.
    """

    def __init__(self) -> None:
        self._strategies: List[DeviceStrategy] = []

        # Register default strategies. The xpu/mps availability checks are
        # wrapped in hasattr() guards: older torch builds may lack those
        # submodules entirely, and a bare `torch.xpu.is_available` reference
        # would raise AttributeError here, at registry construction (i.e. at
        # module import, since the singleton below is built eagerly). The
        # pre-refactor implementation in setup.py used the same guards.
        self.register(DeviceStrategy(
            name="cuda",
            priority=100,
            is_available=torch.cuda.is_available,
            make_device=lambda rank: torch.device(f"cuda:{rank}")
        ))

        self.register(DeviceStrategy(
            name="xpu",
            priority=90,
            is_available=lambda: hasattr(torch, "xpu") and torch.xpu.is_available(),
            make_device=lambda rank: torch.device(f"xpu:{rank}")
        ))

        self.register(DeviceStrategy(
            name="mps",
            priority=80,
            is_available=lambda: hasattr(torch, "mps") and torch.mps.is_available(),
            make_device=lambda _: torch.device("mps")  # MPS ignores rank
        ))

        self.register(DeviceStrategy(
            name="cpu",
            priority=0,
            is_available=lambda: True,
            make_device=lambda _: torch.device("cpu")
        ))

    def register(self, strategy: DeviceStrategy) -> None:
        """Add a strategy to the registry (selected later by priority)."""
        self._strategies.append(strategy)

    def get_current_device(self) -> torch.device:
        """Return the best available device for the current process.

        Resolution order:
          1. ``TORCH_DEVICE_OVERRIDE`` environment variable, if set.
          2. Registered strategies, highest ``priority`` first, first one
             whose ``is_available()`` returns True.

        Raises:
            RuntimeError: if no registered backend is available at all.
        """
        # Allow environment override (for debugging).
        override = os.getenv("TORCH_DEVICE_OVERRIDE")
        if override:
            return torch.device(override)

        # Highest priority first.
        sorted_strategies = sorted(self._strategies, key=lambda s: -s.priority)

        # NOTE(review): the *global* distributed rank is used as the device
        # index; on multi-node setups a local rank may be intended — confirm.
        rank = 0
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()

        for strategy in sorted_strategies:
            if strategy.is_available():
                return strategy.make_device(rank)

        raise RuntimeError("No device backend is available, including CPU.")


device_strategy_registry = DeviceStrategyRegistry()
device_strategy_registry.get_current_device() def get_world_size() -> int: if dist.is_available() and dist.is_initialized(): @@ -20,16 +24,6 @@ def get_rank() -> int: else: return 0 -def get_current_device(): - if torch.cuda.is_available(): - return torch.device(f"cuda:{torch.cuda.current_device()}") - elif hasattr(torch, 'xpu') and torch.xpu.is_available(): - return torch.device(f"xpu:{torch.xpu.current_device()}") - elif hasattr(torch, 'mps') and torch.mps.is_available(): - return torch.device("mps") - else: - return torch.device("cpu") - @contextmanager def setup_parallel( rank: int,