refactor(parallel): 优化并行设备指定方法
This commit is contained in:
parent
cfa3cf7daa
commit
fd7ee2895a
|
|
@ -4,7 +4,7 @@ from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Callable, Optional
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -101,7 +101,7 @@ class TrainConfig:
|
||||||
default="29500",
|
default="29500",
|
||||||
metadata={"help": "Master port for distributed training."}
|
metadata={"help": "Master port for distributed training."}
|
||||||
)
|
)
|
||||||
parallel_fn: Optional[Callable] = field(
|
parallel_wrapper: Optional[Callable] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Parallel function for training."}
|
metadata={"help": "Parallel function for training."}
|
||||||
)
|
)
|
||||||
|
|
@ -115,6 +115,14 @@ class TrainConfig:
|
||||||
)
|
)
|
||||||
|
|
||||||
# others
|
# others
|
||||||
|
device_ids: Optional[List[int]] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Device ids for distributed training."}
|
||||||
|
)
|
||||||
|
device_type: str = field(
|
||||||
|
default="cuda",
|
||||||
|
metadata={"help": "Device type for distributed training."}
|
||||||
|
)
|
||||||
extra_kwargs: dict = field(
|
extra_kwargs: dict = field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
metadata={"help": "Other arguments."}
|
metadata={"help": "Other arguments."}
|
||||||
|
|
@ -138,3 +146,5 @@ class TrainConfig:
|
||||||
raise ValueError("Distributed training requires optimizer and scheduler factories.")
|
raise ValueError("Distributed training requires optimizer and scheduler factories.")
|
||||||
elif self.nprocs == 1 and not argument_case:
|
elif self.nprocs == 1 and not argument_case:
|
||||||
raise ValueError("Single process training requires optimizer and scheduler arguments.")
|
raise ValueError("Single process training requires optimizer and scheduler arguments.")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,100 +0,0 @@
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Callable, List, Optional
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class DeviceStrategy:
|
|
||||||
"""
|
|
||||||
A class representing a device strategy.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
name: Name of the device backend (e.g., 'cuda', 'xpu').
|
|
||||||
priority: Higher number means higher priority.
|
|
||||||
is_available: A callable that returns True if the device is available.
|
|
||||||
make_device: A callable that takes a rank (int) and returns a torch.device.
|
|
||||||
"""
|
|
||||||
name: str
|
|
||||||
priority: int
|
|
||||||
is_available: Callable[[], bool]
|
|
||||||
make_device: Callable[[int], torch.device]
|
|
||||||
|
|
||||||
|
|
||||||
class DeviceStrategyRegistry:
|
|
||||||
"""
|
|
||||||
A registry for device strategies that automatically selects the best available device.
|
|
||||||
And allows overriding the device backend via environment variable.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_instance: Optional["DeviceStrategyRegistry"] = None
|
|
||||||
_initialized: bool = False
|
|
||||||
|
|
||||||
def __new__(cls):
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super().__new__(cls)
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
if self._initialized:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._strategies: List[DeviceStrategy] = []
|
|
||||||
|
|
||||||
self.register(DeviceStrategy(
|
|
||||||
name="cuda",
|
|
||||||
priority=100,
|
|
||||||
is_available=torch.cuda.is_available,
|
|
||||||
make_device=lambda rank: torch.device(f"cuda:{rank}")
|
|
||||||
))
|
|
||||||
|
|
||||||
self.register(DeviceStrategy(
|
|
||||||
name="xpu",
|
|
||||||
priority=90,
|
|
||||||
is_available=torch.xpu.is_available,
|
|
||||||
make_device=lambda rank: torch.device(f"xpu:{rank}")
|
|
||||||
))
|
|
||||||
|
|
||||||
self.register(DeviceStrategy(
|
|
||||||
name="mps",
|
|
||||||
priority=80,
|
|
||||||
is_available=torch.mps.is_available,
|
|
||||||
make_device=lambda _: torch.device("mps") # MPS ignores rank
|
|
||||||
))
|
|
||||||
|
|
||||||
self.register(DeviceStrategy(
|
|
||||||
name="cpu",
|
|
||||||
priority=0,
|
|
||||||
is_available=lambda: True,
|
|
||||||
make_device=lambda _: torch.device("cpu")
|
|
||||||
))
|
|
||||||
|
|
||||||
self._initialized = True
|
|
||||||
|
|
||||||
def register(self, strategy: DeviceStrategy):
|
|
||||||
self._strategies.append(strategy)
|
|
||||||
|
|
||||||
def get_current_device(self) -> torch.device:
|
|
||||||
"""Return the best available device for the current process."""
|
|
||||||
override = os.getenv("TORCH_DEVICE_OVERRIDE")
|
|
||||||
sorted_strategies = sorted(self._strategies, key=lambda s: -s.priority)
|
|
||||||
|
|
||||||
rank = 0
|
|
||||||
|
|
||||||
if dist.is_available() and dist.is_initialized():
|
|
||||||
rank = os.environ["LOCAL_RANK"]
|
|
||||||
|
|
||||||
if override:
|
|
||||||
return torch.device(override, rank)
|
|
||||||
|
|
||||||
|
|
||||||
for strategy in sorted_strategies:
|
|
||||||
if strategy.is_available():
|
|
||||||
|
|
||||||
return strategy.make_device(rank)
|
|
||||||
|
|
||||||
raise RuntimeError("No device backend is available, including CPU.")
|
|
||||||
|
|
||||||
|
|
||||||
device_registry = DeviceStrategyRegistry()
|
|
||||||
|
|
@ -6,11 +6,10 @@ import torch.multiprocessing as mp
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Callable, List, Optional
|
from typing import Callable, List, Optional
|
||||||
from khaosz.parallel.device import device_registry
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_device():
|
def get_current_device():
|
||||||
return device_registry.get_current_device()
|
return os.environ["LOCAL_DEVICE"]
|
||||||
|
|
||||||
def get_world_size() -> int:
|
def get_world_size() -> int:
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
|
@ -31,7 +30,8 @@ def setup_parallel(
|
||||||
backend: str = "nccl",
|
backend: str = "nccl",
|
||||||
master_addr: str = "localhost",
|
master_addr: str = "localhost",
|
||||||
master_port: str = "29500",
|
master_port: str = "29500",
|
||||||
avail_ids: Optional[List[int]] = None
|
device_type: str = "cuda",
|
||||||
|
device_ids: Optional[List[int]] = None
|
||||||
):
|
):
|
||||||
|
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
|
@ -42,28 +42,31 @@ def setup_parallel(
|
||||||
yield None
|
yield None
|
||||||
return
|
return
|
||||||
|
|
||||||
if avail_ids is None:
|
if device_ids is None:
|
||||||
avail_ids = [i for i in range(world_size)]
|
device_ids = [i for i in range(world_size)]
|
||||||
|
|
||||||
rank = avail_ids[rank % len(avail_ids)]
|
rank = device_ids[rank % len(device_ids)]
|
||||||
|
device_id = torch.device(device_type, device_ids[rank])
|
||||||
|
|
||||||
os.environ['MASTER_ADDR'] = master_addr
|
os.environ['MASTER_ADDR'] = master_addr
|
||||||
os.environ['MASTER_PORT'] = master_port
|
os.environ['MASTER_PORT'] = master_port
|
||||||
os.environ['WORLD_SIZE'] = str(world_size)
|
|
||||||
os.environ['LOCAL_RANK'] = str(rank)
|
os.environ['LOCAL_RANK'] = str(rank)
|
||||||
|
os.environ['WORLD_SIZE'] = str(world_size)
|
||||||
|
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||||
|
|
||||||
dist.init_process_group(
|
dist.init_process_group(
|
||||||
backend=backend,
|
|
||||||
init_method=f"tcp://{master_addr}:{master_port}",
|
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size
|
world_size=world_size,
|
||||||
|
backend=backend,
|
||||||
|
device_id=device_id
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if backend == "nccl" and torch.cuda.is_available():
|
if backend == "nccl" and torch.cuda.is_available():
|
||||||
torch.cuda.set_device(rank)
|
torch.cuda.set_device(device_id)
|
||||||
elif backend == "ccl" and hasattr(torch, 'xpu') and torch.xpu.is_available():
|
elif backend == "ccl" and hasattr(torch, 'xpu') and torch.xpu.is_available():
|
||||||
torch.xpu.set_device(rank)
|
torch.xpu.set_device(device_id)
|
||||||
|
|
||||||
yield dist.group.WORLD
|
yield dist.group.WORLD
|
||||||
finally:
|
finally:
|
||||||
|
|
@ -92,8 +95,9 @@ def wrapper_spawn_func(
|
||||||
world_size: int,
|
world_size: int,
|
||||||
backend: str,
|
backend: str,
|
||||||
master_addr: str,
|
master_addr: str,
|
||||||
master_port: str,
|
master_port: str,
|
||||||
avail_ids: List[int],
|
device_type: str,
|
||||||
|
device_ids: List[int],
|
||||||
func: Callable,
|
func: Callable,
|
||||||
kwargs: dict
|
kwargs: dict
|
||||||
):
|
):
|
||||||
|
|
@ -104,7 +108,8 @@ def wrapper_spawn_func(
|
||||||
backend=backend,
|
backend=backend,
|
||||||
master_addr=master_addr,
|
master_addr=master_addr,
|
||||||
master_port=master_port,
|
master_port=master_port,
|
||||||
avail_ids=avail_ids
|
device_type=device_type,
|
||||||
|
device_ids=device_ids
|
||||||
):
|
):
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
|
|
||||||
|
|
@ -118,22 +123,26 @@ def spawn_parallel_fn(
|
||||||
backend: str = "nccl",
|
backend: str = "nccl",
|
||||||
master_addr: str = "localhost",
|
master_addr: str = "localhost",
|
||||||
master_port: str = "29500",
|
master_port: str = "29500",
|
||||||
avail_ids: Optional[List[int]] = None,
|
device_type: str = "cuda",
|
||||||
|
device_ids: Optional[List[int]] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
# clear environment variables
|
||||||
|
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK', 'LOCAL_DEVICE']:
|
||||||
|
if key in os.environ:
|
||||||
|
del os.environ[key]
|
||||||
|
|
||||||
if world_size == 1:
|
if world_size == 1:
|
||||||
|
device_ids = device_ids or [0]
|
||||||
|
deice_id = torch.device(device_type, device_ids[0])
|
||||||
|
os.environ["LOCAL_DEVICE"] = str(deice_id)
|
||||||
|
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
return
|
return
|
||||||
|
|
||||||
# clear environment variables
|
wrapper_spawn_func_args = (world_size, backend, master_addr, master_port,
|
||||||
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK']:
|
device_type, device_ids, func, kwargs)
|
||||||
if key in os.environ:
|
|
||||||
del os.environ[key]
|
|
||||||
|
|
||||||
wrapper_spawn_func_args = (world_size, backend,
|
|
||||||
master_addr, master_port, avail_ids, func, kwargs)
|
|
||||||
|
|
||||||
mp.spawn(
|
mp.spawn(
|
||||||
wrapper_spawn_func,
|
wrapper_spawn_func,
|
||||||
nprocs=world_size,
|
nprocs=world_size,
|
||||||
|
|
|
||||||
|
|
@ -88,13 +88,13 @@ class TrainContextBuilder:
|
||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_parallel_fn(self) -> Self:
|
def with_parallel(self) -> Self:
|
||||||
device = get_current_device()
|
device = get_current_device()
|
||||||
self._context.model = self._context.model.to(device=device)
|
self._context.model = self._context.model.to(device=device)
|
||||||
|
|
||||||
if self.config.nprocs > 1:
|
if self.config.nprocs > 1:
|
||||||
|
|
||||||
fn = self.config.parallel_fn
|
fn = self.config.parallel_wrapper
|
||||||
optimizer_fn = self.config.optimizer_factory
|
optimizer_fn = self.config.optimizer_factory
|
||||||
scheduler_fn = self.config.scheduler_factory
|
scheduler_fn = self.config.scheduler_factory
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ class Trainer:
|
||||||
.with_checkpoint(checkpoint)
|
.with_checkpoint(checkpoint)
|
||||||
.with_dataloader()
|
.with_dataloader()
|
||||||
.with_strategy()
|
.with_strategy()
|
||||||
.with_parallel_fn()
|
.with_parallel()
|
||||||
.build())
|
.build())
|
||||||
|
|
||||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import torch.distributed.fsdp as fsdp
|
import torch.distributed.fsdp as fsdp
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
||||||
from khaosz.trainer import Trainer, SchedulerFactory
|
from khaosz.trainer import Trainer, SchedulerFactory
|
||||||
|
|
@ -12,6 +13,15 @@ from khaosz.data import DatasetLoader
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
|
def parse_device_ids(s: Optional[str]) -> Optional[List[int]]:
|
||||||
|
if s is None or s.strip() == "":
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return [int(x.strip()) for x in s.split(",") if x.strip()]
|
||||||
|
except ValueError as e:
|
||||||
|
raise argparse.ArgumentTypeError(f"Invalid device_ids format: {s}. Expected comma-separated integers like '0,1,2'.")
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
||||||
|
|
||||||
parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
|
parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
|
||||||
|
|
@ -40,6 +50,8 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
||||||
|
|
||||||
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
||||||
|
parser.add_argument("--device_ids", type=parse_device_ids, default=None, help="Device IDs to use.")
|
||||||
|
parser.add_argument("--device_type", type=str, default="cuda", help="Device type to use.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
@ -88,7 +100,9 @@ def train(
|
||||||
pin_memory: bool,
|
pin_memory: bool,
|
||||||
window_size: int,
|
window_size: int,
|
||||||
stride: int,
|
stride: int,
|
||||||
nprocs: int
|
nprocs: int,
|
||||||
|
device_ids: List[int],
|
||||||
|
device_type: str,
|
||||||
):
|
):
|
||||||
assert train_type in ["seq", "sft", "dpo"]
|
assert train_type in ["seq", "sft", "dpo"]
|
||||||
assert os.path.exists(param_path)
|
assert os.path.exists(param_path)
|
||||||
|
|
@ -147,10 +161,12 @@ def train(
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
nprocs=nprocs,
|
nprocs=nprocs,
|
||||||
|
parallel_wrapper=fsdp_wrap,
|
||||||
optimizer_factory=optimizer_fn,
|
optimizer_factory=optimizer_fn,
|
||||||
scheduler_factory=scheduler_fn,
|
scheduler_factory=scheduler_fn,
|
||||||
|
device_ids=device_ids,
|
||||||
|
device_type=device_type,
|
||||||
extra_kwargs=kwargs,
|
extra_kwargs=kwargs,
|
||||||
parallel_fn=fsdp_wrap
|
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue