diff --git a/.gitignore b/.gitignore index 53ab585..e13969b 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,6 @@ params/* # build file build -*.egg-info \ No newline at end of file +*.egg-info + +*.ipynb \ No newline at end of file diff --git a/khaosz/parallel/device.py b/khaosz/parallel/device.py index 21876f0..d568f67 100644 --- a/khaosz/parallel/device.py +++ b/khaosz/parallel/device.py @@ -2,7 +2,7 @@ import os import torch import torch.distributed as dist from dataclasses import dataclass -from typing import Callable, List +from typing import Callable, List, Optional @dataclass @@ -27,11 +27,21 @@ class DeviceStrategyRegistry: A registry for device strategies that automatically selects the best available device. And allows overriding the device backend via environment variable. """ + + _instance: Optional["DeviceStrategyRegistry"] = None + _initialized: bool = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance def __init__(self) -> None: + if self._initialized: + return + self._strategies: List[DeviceStrategy] = [] - # Register default strategies self.register(DeviceStrategy( name="cuda", priority=100, @@ -59,27 +69,32 @@ class DeviceStrategyRegistry: is_available=lambda: True, make_device=lambda _: torch.device("cpu") )) + + self._initialized = True def register(self, strategy: DeviceStrategy): self._strategies.append(strategy) def get_current_device(self) -> torch.device: """Return the best available device for the current process.""" - # Allow environment override (for debugging) override = os.getenv("TORCH_DEVICE_OVERRIDE") - if override: - return torch.device(override) - sorted_strategies = sorted(self._strategies, key=lambda s: -s.priority) rank = 0 - if dist.is_available() and dist.is_initialized(): - rank = dist.get_rank() + if dist.is_available() and dist.is_initialized(): + rank = os.environ["LOCAL_RANK"] + + if override: + return torch.device(override, rank) + + for strategy in sorted_strategies: if strategy.is_available(): + return strategy.make_device(rank) raise RuntimeError("No device backend is available, including CPU.") -device_strategy_registry = DeviceStrategyRegistry() \ No newline at end of file + +device_registry = DeviceStrategyRegistry() \ No newline at end of file diff --git a/khaosz/parallel/setup.py b/khaosz/parallel/setup.py index 875639f..386a6dd 100644 --- a/khaosz/parallel/setup.py +++ b/khaosz/parallel/setup.py @@ -1,16 +1,16 @@ - import os import torch import torch.distributed as dist import torch.multiprocessing as mp -from typing import Callable from functools import wraps from contextlib import contextmanager -from khaosz.parallel.device import device_strategy_registry +from typing import Callable, List, Optional +from khaosz.parallel.device import device_registry + def get_current_device(): - return device_strategy_registry.get_current_device() + return device_registry.get_current_device() def get_world_size() -> int: if dist.is_available() and dist.is_initialized(): @@ -30,7 +30,8 @@ def setup_parallel( world_size: int, backend: str = "nccl", master_addr: str = "localhost", - master_port: str = "29500" + master_port: str = "29500", + avail_ids: Optional[List[int]] = None ): if dist.is_available() and dist.is_initialized(): @@ -41,9 +42,13 @@ def setup_parallel( yield None return + if avail_ids is None: + avail_ids = [i for i in range(world_size)] + + rank = avail_ids[rank % len(avail_ids)] + os.environ['MASTER_ADDR'] = master_addr os.environ['MASTER_PORT'] = master_port - os.environ['RANK'] = str(rank) os.environ['WORLD_SIZE'] = str(world_size) os.environ['LOCAL_RANK'] = str(rank) @@ -82,11 +87,40 @@ def only_on_rank(rank, sync=False): return decorator -def wrapper_spawn_func(rank, world_size, backend, func, kwargs): - with setup_parallel(rank, world_size, backend): - func(**kwargs) +def wrapper_spawn_func( + rank: int, + world_size: int, + backend: str, + master_addr: str, + master_port: str, + avail_ids: List[int], + func: Callable, + kwargs: dict +): + try: + with setup_parallel( + rank=rank, + world_size=world_size, + backend=backend, + master_addr=master_addr, + master_port=master_port, + avail_ids=avail_ids + ): + func(**kwargs) + + except Exception as e: + print(f"Error in rank {rank}: {e}") + raise -def spawn_parallel_fn(func: Callable, world_size: int, backend: str, **kwargs): +def spawn_parallel_fn( + func: Callable, + world_size: int, + backend: str = "nccl", + master_addr: str = "localhost", + master_port: str = "29500", + avail_ids: Optional[List[int]] = None, + **kwargs +): if world_size == 1: func(**kwargs) @@ -96,10 +130,13 @@ def spawn_parallel_fn(func: Callable, world_size: int, backend: str, **kwargs): for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK']: if key in os.environ: del os.environ[key] + + wrapper_spawn_func_args = (world_size, backend, + master_addr, master_port, avail_ids, func, kwargs) mp.spawn( - wrapper_spawn_func, - nprocs=world_size, - args=(world_size, backend, func, kwargs), + wrapper_spawn_func, + nprocs=world_size, + args=wrapper_spawn_func_args, join=True ) \ No newline at end of file