refactor(parallel): 优化并行设备指定方法

This commit is contained in:
ViperEkura 2025-12-26 20:54:33 +08:00
parent cfa3cf7daa
commit fd7ee2895a
6 changed files with 65 additions and 130 deletions

View File

@ -4,7 +4,7 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Callable, Optional from typing import Callable, List, Optional
@dataclass @dataclass
@ -101,7 +101,7 @@ class TrainConfig:
default="29500", default="29500",
metadata={"help": "Master port for distributed training."} metadata={"help": "Master port for distributed training."}
) )
parallel_fn: Optional[Callable] = field( parallel_wrapper: Optional[Callable] = field(
default=None, default=None,
metadata={"help": "Parallel function for training."} metadata={"help": "Parallel function for training."}
) )
@ -115,6 +115,14 @@ class TrainConfig:
) )
# others # others
device_ids: Optional[List[int]] = field(
default=None,
metadata={"help": "Device ids for distributed training."}
)
device_type: str = field(
default="cuda",
metadata={"help": "Device type for distributed training."}
)
extra_kwargs: dict = field( extra_kwargs: dict = field(
default_factory=dict, default_factory=dict,
metadata={"help": "Other arguments."} metadata={"help": "Other arguments."}
@ -138,3 +146,5 @@ class TrainConfig:
raise ValueError("Distributed training requires optimizer and scheduler factories.") raise ValueError("Distributed training requires optimizer and scheduler factories.")
elif self.nprocs == 1 and not argument_case: elif self.nprocs == 1 and not argument_case:
raise ValueError("Single process training requires optimizer and scheduler arguments.") raise ValueError("Single process training requires optimizer and scheduler arguments.")

View File

@ -1,100 +0,0 @@
import os
import torch
import torch.distributed as dist
from dataclasses import dataclass
from typing import Callable, List, Optional
@dataclass
class DeviceStrategy:
    """Describes one device backend and how to build its torch.device.

    Attributes:
        name: Backend identifier such as 'cuda' or 'xpu'.
        priority: Selection weight; larger values are preferred.
        is_available: Zero-argument callable reporting backend availability.
        make_device: Maps a process rank (int) to the concrete torch.device.
    """
    name: str
    priority: int
    is_available: Callable[[], bool]
    make_device: Callable[[int], torch.device]
class DeviceStrategyRegistry:
    """
    Singleton registry of device strategies that selects the best available
    backend for the current process.

    Strategies are tried in descending priority order. The environment
    variable ``TORCH_DEVICE_OVERRIDE`` short-circuits selection with an
    explicit device type (e.g. ``"cpu"``).
    """

    _instance: Optional["DeviceStrategyRegistry"] = None
    _initialized: bool = False

    def __new__(cls):
        # Classic singleton: every call returns the same shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # __init__ runs on every DeviceStrategyRegistry() call; guard so the
        # built-in strategies are only registered once.
        if self._initialized:
            return
        self._strategies: List[DeviceStrategy] = []
        self.register(DeviceStrategy(
            name="cuda",
            priority=100,
            is_available=torch.cuda.is_available,
            make_device=lambda rank: torch.device(f"cuda:{rank}")
        ))
        self.register(DeviceStrategy(
            name="xpu",
            priority=90,
            # Guard with hasattr: torch.xpu is absent on older torch builds
            # (the sibling parallel module uses the same hasattr guard).
            is_available=lambda: hasattr(torch, "xpu") and torch.xpu.is_available(),
            make_device=lambda rank: torch.device(f"xpu:{rank}")
        ))
        self.register(DeviceStrategy(
            name="mps",
            priority=80,
            is_available=lambda: hasattr(torch, "mps") and torch.mps.is_available(),
            make_device=lambda _: torch.device("mps")  # MPS ignores rank
        ))
        self.register(DeviceStrategy(
            name="cpu",
            priority=0,
            is_available=lambda: True,
            make_device=lambda _: torch.device("cpu")
        ))
        self._initialized = True

    def register(self, strategy: DeviceStrategy) -> None:
        """Add a strategy to the registry."""
        self._strategies.append(strategy)

    def get_current_device(self) -> torch.device:
        """Return the best available device for the current process.

        Raises:
            RuntimeError: if no registered backend is available (CPU included).
        """
        override = os.getenv("TORCH_DEVICE_OVERRIDE")
        sorted_strategies = sorted(self._strategies, key=lambda s: -s.priority)
        rank = 0
        if dist.is_available() and dist.is_initialized():
            # BUG FIX: os.environ values are strings, but torch.device(type, index)
            # requires an int index — convert explicitly.
            rank = int(os.environ["LOCAL_RANK"])
        if override:
            return torch.device(override, rank)
        for strategy in sorted_strategies:
            if strategy.is_available():
                return strategy.make_device(rank)
        raise RuntimeError("No device backend is available, including CPU.")
# Module-level singleton instance; importers share this one registry.
device_registry = DeviceStrategyRegistry()

View File

@ -6,11 +6,10 @@ import torch.multiprocessing as mp
from functools import wraps from functools import wraps
from contextlib import contextmanager from contextlib import contextmanager
from typing import Callable, List, Optional from typing import Callable, List, Optional
from khaosz.parallel.device import device_registry
def get_current_device(): def get_current_device():
return device_registry.get_current_device() return os.environ["LOCAL_DEVICE"]
def get_world_size() -> int: def get_world_size() -> int:
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
@ -31,7 +30,8 @@ def setup_parallel(
backend: str = "nccl", backend: str = "nccl",
master_addr: str = "localhost", master_addr: str = "localhost",
master_port: str = "29500", master_port: str = "29500",
avail_ids: Optional[List[int]] = None device_type: str = "cuda",
device_ids: Optional[List[int]] = None
): ):
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
@ -42,28 +42,31 @@ def setup_parallel(
yield None yield None
return return
if avail_ids is None: if device_ids is None:
avail_ids = [i for i in range(world_size)] device_ids = [i for i in range(world_size)]
rank = avail_ids[rank % len(avail_ids)] rank = device_ids[rank % len(device_ids)]
device_id = torch.device(device_type, device_ids[rank])
os.environ['MASTER_ADDR'] = master_addr os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = master_port os.environ['MASTER_PORT'] = master_port
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['LOCAL_RANK'] = str(rank) os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ["LOCAL_DEVICE"] = str(device_id)
dist.init_process_group( dist.init_process_group(
backend=backend,
init_method=f"tcp://{master_addr}:{master_port}",
rank=rank, rank=rank,
world_size=world_size world_size=world_size,
backend=backend,
device_id=device_id
) )
try: try:
if backend == "nccl" and torch.cuda.is_available(): if backend == "nccl" and torch.cuda.is_available():
torch.cuda.set_device(rank) torch.cuda.set_device(device_id)
elif backend == "ccl" and hasattr(torch, 'xpu') and torch.xpu.is_available(): elif backend == "ccl" and hasattr(torch, 'xpu') and torch.xpu.is_available():
torch.xpu.set_device(rank) torch.xpu.set_device(device_id)
yield dist.group.WORLD yield dist.group.WORLD
finally: finally:
@ -92,8 +95,9 @@ def wrapper_spawn_func(
world_size: int, world_size: int,
backend: str, backend: str,
master_addr: str, master_addr: str,
master_port: str, master_port: str,
avail_ids: List[int], device_type: str,
device_ids: List[int],
func: Callable, func: Callable,
kwargs: dict kwargs: dict
): ):
@ -104,7 +108,8 @@ def wrapper_spawn_func(
backend=backend, backend=backend,
master_addr=master_addr, master_addr=master_addr,
master_port=master_port, master_port=master_port,
avail_ids=avail_ids device_type=device_type,
device_ids=device_ids
): ):
func(**kwargs) func(**kwargs)
@ -118,22 +123,26 @@ def spawn_parallel_fn(
backend: str = "nccl", backend: str = "nccl",
master_addr: str = "localhost", master_addr: str = "localhost",
master_port: str = "29500", master_port: str = "29500",
avail_ids: Optional[List[int]] = None, device_type: str = "cuda",
device_ids: Optional[List[int]] = None,
**kwargs **kwargs
): ):
# clear environment variables
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK', 'LOCAL_DEVICE']:
if key in os.environ:
del os.environ[key]
if world_size == 1: if world_size == 1:
device_ids = device_ids or [0]
deice_id = torch.device(device_type, device_ids[0])
os.environ["LOCAL_DEVICE"] = str(deice_id)
func(**kwargs) func(**kwargs)
return return
# clear environment variables wrapper_spawn_func_args = (world_size, backend, master_addr, master_port,
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK']: device_type, device_ids, func, kwargs)
if key in os.environ:
del os.environ[key]
wrapper_spawn_func_args = (world_size, backend,
master_addr, master_port, avail_ids, func, kwargs)
mp.spawn( mp.spawn(
wrapper_spawn_func, wrapper_spawn_func,
nprocs=world_size, nprocs=world_size,

View File

@ -88,13 +88,13 @@ class TrainContextBuilder:
) )
return self return self
def with_parallel_fn(self) -> Self: def with_parallel(self) -> Self:
device = get_current_device() device = get_current_device()
self._context.model = self._context.model.to(device=device) self._context.model = self._context.model.to(device=device)
if self.config.nprocs > 1: if self.config.nprocs > 1:
fn = self.config.parallel_fn fn = self.config.parallel_wrapper
optimizer_fn = self.config.optimizer_factory optimizer_fn = self.config.optimizer_factory
scheduler_fn = self.config.scheduler_factory scheduler_fn = self.config.scheduler_factory

View File

@ -38,7 +38,7 @@ class Trainer:
.with_checkpoint(checkpoint) .with_checkpoint(checkpoint)
.with_dataloader() .with_dataloader()
.with_strategy() .with_strategy()
.with_parallel_fn() .with_parallel()
.build()) .build())
def _call_callbacks(self, method_name: str, context: TrainContext): def _call_callbacks(self, method_name: str, context: TrainContext):

View File

@ -5,6 +5,7 @@ import torch.nn as nn
import torch.optim as optim import torch.optim as optim
import torch.distributed.fsdp as fsdp import torch.distributed.fsdp as fsdp
from typing import List, Optional
from functools import partial from functools import partial
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
from khaosz.trainer import Trainer, SchedulerFactory from khaosz.trainer import Trainer, SchedulerFactory
@ -12,6 +13,15 @@ from khaosz.data import DatasetLoader
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:
def parse_device_ids(s: Optional[str]) -> Optional[List[int]]:
if s is None or s.strip() == "":
return None
try:
return [int(x.strip()) for x in s.split(",") if x.strip()]
except ValueError as e:
raise argparse.ArgumentTypeError(f"Invalid device_ids format: {s}. Expected comma-separated integers like '0,1,2'.")
parser = argparse.ArgumentParser(description="Train the Transformer model.") parser = argparse.ArgumentParser(description="Train the Transformer model.")
parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.") parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
@ -40,6 +50,8 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.") parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.") parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
parser.add_argument("--device_ids", type=parse_device_ids, default=None, help="Device IDs to use.")
parser.add_argument("--device_type", type=str, default="cuda", help="Device type to use.")
args = parser.parse_args() args = parser.parse_args()
@ -88,7 +100,9 @@ def train(
pin_memory: bool, pin_memory: bool,
window_size: int, window_size: int,
stride: int, stride: int,
nprocs: int nprocs: int,
device_ids: List[int],
device_type: str,
): ):
assert train_type in ["seq", "sft", "dpo"] assert train_type in ["seq", "sft", "dpo"]
assert os.path.exists(param_path) assert os.path.exists(param_path)
@ -147,10 +161,12 @@ def train(
num_workers=num_workers, num_workers=num_workers,
pin_memory=pin_memory, pin_memory=pin_memory,
nprocs=nprocs, nprocs=nprocs,
parallel_wrapper=fsdp_wrap,
optimizer_factory=optimizer_fn, optimizer_factory=optimizer_fn,
scheduler_factory=scheduler_fn, scheduler_factory=scheduler_fn,
device_ids=device_ids,
device_type=device_type,
extra_kwargs=kwargs, extra_kwargs=kwargs,
parallel_fn=fsdp_wrap
) )
trainer = Trainer(train_config) trainer = Trainer(train_config)