refactor(paralell): 优化并行设备指定方法
This commit is contained in:
parent
cfa3cf7daa
commit
fd7ee2895a
|
|
@ -4,7 +4,7 @@ from torch.optim import Optimizer
|
|||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -101,7 +101,7 @@ class TrainConfig:
|
|||
default="29500",
|
||||
metadata={"help": "Master port for distributed training."}
|
||||
)
|
||||
parallel_fn: Optional[Callable] = field(
|
||||
parallel_wrapper: Optional[Callable] = field(
|
||||
default=None,
|
||||
metadata={"help": "Parallel function for training."}
|
||||
)
|
||||
|
|
@ -115,6 +115,14 @@ class TrainConfig:
|
|||
)
|
||||
|
||||
# others
|
||||
device_ids: Optional[List[int]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Device ids for distributed training."}
|
||||
)
|
||||
device_type: str = field(
|
||||
default="cuda",
|
||||
metadata={"help": "Device type for distributed training."}
|
||||
)
|
||||
extra_kwargs: dict = field(
|
||||
default_factory=dict,
|
||||
metadata={"help": "Other arguments."}
|
||||
|
|
@ -138,3 +146,5 @@ class TrainConfig:
|
|||
raise ValueError("Distributed training requires optimizer and scheduler factories.")
|
||||
elif self.nprocs == 1 and not argument_case:
|
||||
raise ValueError("Single process training requires optimizer and scheduler arguments.")
|
||||
|
||||
|
||||
|
|
@ -1,100 +0,0 @@
|
|||
import os
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceStrategy:
|
||||
"""
|
||||
A class representing a device strategy.
|
||||
|
||||
Attributes:
|
||||
name: Name of the device backend (e.g., 'cuda', 'xpu').
|
||||
priority: Higher number means higher priority.
|
||||
is_available: A callable that returns True if the device is available.
|
||||
make_device: A callable that takes a rank (int) and returns a torch.device.
|
||||
"""
|
||||
name: str
|
||||
priority: int
|
||||
is_available: Callable[[], bool]
|
||||
make_device: Callable[[int], torch.device]
|
||||
|
||||
|
||||
class DeviceStrategyRegistry:
|
||||
"""
|
||||
A registry for device strategies that automatically selects the best available device.
|
||||
And allows overriding the device backend via environment variable.
|
||||
"""
|
||||
|
||||
_instance: Optional["DeviceStrategyRegistry"] = None
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._strategies: List[DeviceStrategy] = []
|
||||
|
||||
self.register(DeviceStrategy(
|
||||
name="cuda",
|
||||
priority=100,
|
||||
is_available=torch.cuda.is_available,
|
||||
make_device=lambda rank: torch.device(f"cuda:{rank}")
|
||||
))
|
||||
|
||||
self.register(DeviceStrategy(
|
||||
name="xpu",
|
||||
priority=90,
|
||||
is_available=torch.xpu.is_available,
|
||||
make_device=lambda rank: torch.device(f"xpu:{rank}")
|
||||
))
|
||||
|
||||
self.register(DeviceStrategy(
|
||||
name="mps",
|
||||
priority=80,
|
||||
is_available=torch.mps.is_available,
|
||||
make_device=lambda _: torch.device("mps") # MPS ignores rank
|
||||
))
|
||||
|
||||
self.register(DeviceStrategy(
|
||||
name="cpu",
|
||||
priority=0,
|
||||
is_available=lambda: True,
|
||||
make_device=lambda _: torch.device("cpu")
|
||||
))
|
||||
|
||||
self._initialized = True
|
||||
|
||||
def register(self, strategy: DeviceStrategy):
|
||||
self._strategies.append(strategy)
|
||||
|
||||
def get_current_device(self) -> torch.device:
|
||||
"""Return the best available device for the current process."""
|
||||
override = os.getenv("TORCH_DEVICE_OVERRIDE")
|
||||
sorted_strategies = sorted(self._strategies, key=lambda s: -s.priority)
|
||||
|
||||
rank = 0
|
||||
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
rank = os.environ["LOCAL_RANK"]
|
||||
|
||||
if override:
|
||||
return torch.device(override, rank)
|
||||
|
||||
|
||||
for strategy in sorted_strategies:
|
||||
if strategy.is_available():
|
||||
|
||||
return strategy.make_device(rank)
|
||||
|
||||
raise RuntimeError("No device backend is available, including CPU.")
|
||||
|
||||
|
||||
device_registry = DeviceStrategyRegistry()
|
||||
|
|
@ -6,11 +6,10 @@ import torch.multiprocessing as mp
|
|||
from functools import wraps
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, List, Optional
|
||||
from khaosz.parallel.device import device_registry
|
||||
|
||||
|
||||
def get_current_device():
|
||||
return device_registry.get_current_device()
|
||||
return os.environ["LOCAL_DEVICE"]
|
||||
|
||||
def get_world_size() -> int:
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
|
|
@ -31,7 +30,8 @@ def setup_parallel(
|
|||
backend: str = "nccl",
|
||||
master_addr: str = "localhost",
|
||||
master_port: str = "29500",
|
||||
avail_ids: Optional[List[int]] = None
|
||||
device_type: str = "cuda",
|
||||
device_ids: Optional[List[int]] = None
|
||||
):
|
||||
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
|
|
@ -42,28 +42,31 @@ def setup_parallel(
|
|||
yield None
|
||||
return
|
||||
|
||||
if avail_ids is None:
|
||||
avail_ids = [i for i in range(world_size)]
|
||||
if device_ids is None:
|
||||
device_ids = [i for i in range(world_size)]
|
||||
|
||||
rank = avail_ids[rank % len(avail_ids)]
|
||||
rank = device_ids[rank % len(device_ids)]
|
||||
device_id = torch.device(device_type, device_ids[rank])
|
||||
|
||||
os.environ['MASTER_ADDR'] = master_addr
|
||||
os.environ['MASTER_PORT'] = master_port
|
||||
os.environ['WORLD_SIZE'] = str(world_size)
|
||||
|
||||
os.environ['LOCAL_RANK'] = str(rank)
|
||||
os.environ['WORLD_SIZE'] = str(world_size)
|
||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||
|
||||
dist.init_process_group(
|
||||
backend=backend,
|
||||
init_method=f"tcp://{master_addr}:{master_port}",
|
||||
rank=rank,
|
||||
world_size=world_size
|
||||
world_size=world_size,
|
||||
backend=backend,
|
||||
device_id=device_id
|
||||
)
|
||||
|
||||
try:
|
||||
if backend == "nccl" and torch.cuda.is_available():
|
||||
torch.cuda.set_device(rank)
|
||||
torch.cuda.set_device(device_id)
|
||||
elif backend == "ccl" and hasattr(torch, 'xpu') and torch.xpu.is_available():
|
||||
torch.xpu.set_device(rank)
|
||||
torch.xpu.set_device(device_id)
|
||||
|
||||
yield dist.group.WORLD
|
||||
finally:
|
||||
|
|
@ -93,7 +96,8 @@ def wrapper_spawn_func(
|
|||
backend: str,
|
||||
master_addr: str,
|
||||
master_port: str,
|
||||
avail_ids: List[int],
|
||||
device_type: str,
|
||||
device_ids: List[int],
|
||||
func: Callable,
|
||||
kwargs: dict
|
||||
):
|
||||
|
|
@ -104,7 +108,8 @@ def wrapper_spawn_func(
|
|||
backend=backend,
|
||||
master_addr=master_addr,
|
||||
master_port=master_port,
|
||||
avail_ids=avail_ids
|
||||
device_type=device_type,
|
||||
device_ids=device_ids
|
||||
):
|
||||
func(**kwargs)
|
||||
|
||||
|
|
@ -118,21 +123,25 @@ def spawn_parallel_fn(
|
|||
backend: str = "nccl",
|
||||
master_addr: str = "localhost",
|
||||
master_port: str = "29500",
|
||||
avail_ids: Optional[List[int]] = None,
|
||||
device_type: str = "cuda",
|
||||
device_ids: Optional[List[int]] = None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
if world_size == 1:
|
||||
func(**kwargs)
|
||||
return
|
||||
|
||||
# clear environment variables
|
||||
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK']:
|
||||
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK', 'LOCAL_DEVICE']:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
wrapper_spawn_func_args = (world_size, backend,
|
||||
master_addr, master_port, avail_ids, func, kwargs)
|
||||
if world_size == 1:
|
||||
device_ids = device_ids or [0]
|
||||
deice_id = torch.device(device_type, device_ids[0])
|
||||
os.environ["LOCAL_DEVICE"] = str(deice_id)
|
||||
|
||||
func(**kwargs)
|
||||
return
|
||||
|
||||
wrapper_spawn_func_args = (world_size, backend, master_addr, master_port,
|
||||
device_type, device_ids, func, kwargs)
|
||||
|
||||
mp.spawn(
|
||||
wrapper_spawn_func,
|
||||
|
|
|
|||
|
|
@ -88,13 +88,13 @@ class TrainContextBuilder:
|
|||
)
|
||||
return self
|
||||
|
||||
def with_parallel_fn(self) -> Self:
|
||||
def with_parallel(self) -> Self:
|
||||
device = get_current_device()
|
||||
self._context.model = self._context.model.to(device=device)
|
||||
|
||||
if self.config.nprocs > 1:
|
||||
|
||||
fn = self.config.parallel_fn
|
||||
fn = self.config.parallel_wrapper
|
||||
optimizer_fn = self.config.optimizer_factory
|
||||
scheduler_fn = self.config.scheduler_factory
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ class Trainer:
|
|||
.with_checkpoint(checkpoint)
|
||||
.with_dataloader()
|
||||
.with_strategy()
|
||||
.with_parallel_fn()
|
||||
.with_parallel()
|
||||
.build())
|
||||
|
||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import torch.nn as nn
|
|||
import torch.optim as optim
|
||||
import torch.distributed.fsdp as fsdp
|
||||
|
||||
from typing import List, Optional
|
||||
from functools import partial
|
||||
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
||||
from khaosz.trainer import Trainer, SchedulerFactory
|
||||
|
|
@ -12,6 +13,15 @@ from khaosz.data import DatasetLoader
|
|||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
def parse_device_ids(s: Optional[str]) -> Optional[List[int]]:
|
||||
if s is None or s.strip() == "":
|
||||
return None
|
||||
try:
|
||||
return [int(x.strip()) for x in s.split(",") if x.strip()]
|
||||
except ValueError as e:
|
||||
raise argparse.ArgumentTypeError(f"Invalid device_ids format: {s}. Expected comma-separated integers like '0,1,2'.")
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
||||
|
||||
parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
|
||||
|
|
@ -40,6 +50,8 @@ def parse_args() -> argparse.Namespace:
|
|||
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
||||
|
||||
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
||||
parser.add_argument("--device_ids", type=parse_device_ids, default=None, help="Device IDs to use.")
|
||||
parser.add_argument("--device_type", type=str, default="cuda", help="Device type to use.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -88,7 +100,9 @@ def train(
|
|||
pin_memory: bool,
|
||||
window_size: int,
|
||||
stride: int,
|
||||
nprocs: int
|
||||
nprocs: int,
|
||||
device_ids: List[int],
|
||||
device_type: str,
|
||||
):
|
||||
assert train_type in ["seq", "sft", "dpo"]
|
||||
assert os.path.exists(param_path)
|
||||
|
|
@ -147,10 +161,12 @@ def train(
|
|||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
nprocs=nprocs,
|
||||
parallel_wrapper=fsdp_wrap,
|
||||
optimizer_factory=optimizer_fn,
|
||||
scheduler_factory=scheduler_fn,
|
||||
device_ids=device_ids,
|
||||
device_type=device_type,
|
||||
extra_kwargs=kwargs,
|
||||
parallel_fn=fsdp_wrap
|
||||
)
|
||||
|
||||
trainer = Trainer(train_config)
|
||||
|
|
|
|||
Loading…
Reference in New Issue