feat(parallel): 改进设备策略注册表与并行设置功能
This commit is contained in:
parent
3ac38a7ebc
commit
eab7a51bb6
|
|
@ -11,3 +11,5 @@ params/*
|
||||||
# build file
|
# build file
|
||||||
build
|
build
|
||||||
*.egg-info
|
*.egg-info
|
||||||
|
|
||||||
|
*.ipynb
|
||||||
|
|
@ -2,7 +2,7 @@ import os
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Callable, List
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -28,10 +28,20 @@ class DeviceStrategyRegistry:
|
||||||
And allows overriding the device backend via environment variable.
|
And allows overriding the device backend via environment variable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_instance: Optional["DeviceStrategyRegistry"] = None
|
||||||
|
_initialized: bool = False
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
self._strategies: List[DeviceStrategy] = []
|
self._strategies: List[DeviceStrategy] = []
|
||||||
|
|
||||||
# Register default strategies
|
|
||||||
self.register(DeviceStrategy(
|
self.register(DeviceStrategy(
|
||||||
name="cuda",
|
name="cuda",
|
||||||
priority=100,
|
priority=100,
|
||||||
|
|
@ -60,26 +70,31 @@ class DeviceStrategyRegistry:
|
||||||
make_device=lambda _: torch.device("cpu")
|
make_device=lambda _: torch.device("cpu")
|
||||||
))
|
))
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
def register(self, strategy: DeviceStrategy):
|
def register(self, strategy: DeviceStrategy):
|
||||||
self._strategies.append(strategy)
|
self._strategies.append(strategy)
|
||||||
|
|
||||||
def get_current_device(self) -> torch.device:
|
def get_current_device(self) -> torch.device:
|
||||||
"""Return the best available device for the current process."""
|
"""Return the best available device for the current process."""
|
||||||
# Allow environment override (for debugging)
|
|
||||||
override = os.getenv("TORCH_DEVICE_OVERRIDE")
|
override = os.getenv("TORCH_DEVICE_OVERRIDE")
|
||||||
if override:
|
|
||||||
return torch.device(override)
|
|
||||||
|
|
||||||
sorted_strategies = sorted(self._strategies, key=lambda s: -s.priority)
|
sorted_strategies = sorted(self._strategies, key=lambda s: -s.priority)
|
||||||
|
|
||||||
rank = 0
|
rank = 0
|
||||||
|
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
rank = dist.get_rank()
|
rank = os.environ["LOCAL_RANK"]
|
||||||
|
|
||||||
|
if override:
|
||||||
|
return torch.device(override, rank)
|
||||||
|
|
||||||
|
|
||||||
for strategy in sorted_strategies:
|
for strategy in sorted_strategies:
|
||||||
if strategy.is_available():
|
if strategy.is_available():
|
||||||
|
|
||||||
return strategy.make_device(rank)
|
return strategy.make_device(rank)
|
||||||
|
|
||||||
raise RuntimeError("No device backend is available, including CPU.")
|
raise RuntimeError("No device backend is available, including CPU.")
|
||||||
|
|
||||||
device_strategy_registry = DeviceStrategyRegistry()
|
|
||||||
|
device_registry = DeviceStrategyRegistry()
|
||||||
|
|
@ -1,16 +1,16 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
from typing import Callable
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from khaosz.parallel.device import device_strategy_registry
|
from typing import Callable, List, Optional
|
||||||
|
from khaosz.parallel.device import device_registry
|
||||||
|
|
||||||
|
|
||||||
def get_current_device():
|
def get_current_device():
|
||||||
return device_strategy_registry.get_current_device()
|
return device_registry.get_current_device()
|
||||||
|
|
||||||
def get_world_size() -> int:
|
def get_world_size() -> int:
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
|
@ -30,7 +30,8 @@ def setup_parallel(
|
||||||
world_size: int,
|
world_size: int,
|
||||||
backend: str = "nccl",
|
backend: str = "nccl",
|
||||||
master_addr: str = "localhost",
|
master_addr: str = "localhost",
|
||||||
master_port: str = "29500"
|
master_port: str = "29500",
|
||||||
|
avail_ids: Optional[List[int]] = None
|
||||||
):
|
):
|
||||||
|
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
|
@ -41,9 +42,13 @@ def setup_parallel(
|
||||||
yield None
|
yield None
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if avail_ids is None:
|
||||||
|
avail_ids = [i for i in range(world_size)]
|
||||||
|
|
||||||
|
rank = avail_ids[rank % len(avail_ids)]
|
||||||
|
|
||||||
os.environ['MASTER_ADDR'] = master_addr
|
os.environ['MASTER_ADDR'] = master_addr
|
||||||
os.environ['MASTER_PORT'] = master_port
|
os.environ['MASTER_PORT'] = master_port
|
||||||
os.environ['RANK'] = str(rank)
|
|
||||||
os.environ['WORLD_SIZE'] = str(world_size)
|
os.environ['WORLD_SIZE'] = str(world_size)
|
||||||
os.environ['LOCAL_RANK'] = str(rank)
|
os.environ['LOCAL_RANK'] = str(rank)
|
||||||
|
|
||||||
|
|
@ -82,11 +87,40 @@ def only_on_rank(rank, sync=False):
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def wrapper_spawn_func(rank, world_size, backend, func, kwargs):
|
def wrapper_spawn_func(
|
||||||
with setup_parallel(rank, world_size, backend):
|
rank: int,
|
||||||
|
world_size: int,
|
||||||
|
backend: str,
|
||||||
|
master_addr: str,
|
||||||
|
master_port: str,
|
||||||
|
avail_ids: List[int],
|
||||||
|
func: Callable,
|
||||||
|
kwargs: dict
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
with setup_parallel(
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
backend=backend,
|
||||||
|
master_addr=master_addr,
|
||||||
|
master_port=master_port,
|
||||||
|
avail_ids=avail_ids
|
||||||
|
):
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
|
|
||||||
def spawn_parallel_fn(func: Callable, world_size: int, backend: str, **kwargs):
|
except Exception as e:
|
||||||
|
print(f"Error in rank {rank}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def spawn_parallel_fn(
|
||||||
|
func: Callable,
|
||||||
|
world_size: int,
|
||||||
|
backend: str = "nccl",
|
||||||
|
master_addr: str = "localhost",
|
||||||
|
master_port: str = "29500",
|
||||||
|
avail_ids: Optional[List[int]] = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
|
||||||
if world_size == 1:
|
if world_size == 1:
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
|
|
@ -97,9 +131,12 @@ def spawn_parallel_fn(func: Callable, world_size: int, backend: str, **kwargs):
|
||||||
if key in os.environ:
|
if key in os.environ:
|
||||||
del os.environ[key]
|
del os.environ[key]
|
||||||
|
|
||||||
|
wrapper_spawn_func_args = (world_size, backend,
|
||||||
|
master_addr, master_port, avail_ids, func, kwargs)
|
||||||
|
|
||||||
mp.spawn(
|
mp.spawn(
|
||||||
wrapper_spawn_func,
|
wrapper_spawn_func,
|
||||||
nprocs=world_size,
|
nprocs=world_size,
|
||||||
args=(world_size, backend, func, kwargs),
|
args=wrapper_spawn_func_args,
|
||||||
join=True
|
join=True
|
||||||
)
|
)
|
||||||
Loading…
Reference in New Issue