From fd7ee2895aad69a13ab25791f7e6a8c8fa9710b4 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 26 Dec 2025 20:54:33 +0800 Subject: [PATCH] =?UTF-8?q?refactor(parallel):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E5=B9=B6=E8=A1=8C=E8=AE=BE=E5=A4=87=E6=8C=87=E5=AE=9A=E6=96=B9?= =?UTF-8?q?=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/config/train_config.py | 14 ++++- khaosz/parallel/device.py | 100 -------------------------------- khaosz/parallel/setup.py | 55 ++++++++++-------- khaosz/trainer/train_context.py | 4 +- khaosz/trainer/trainer.py | 2 +- tools/train.py | 20 ++++++- 6 files changed, 65 insertions(+), 130 deletions(-) delete mode 100644 khaosz/parallel/device.py diff --git a/khaosz/config/train_config.py b/khaosz/config/train_config.py index d41d01b..cb4a8ac 100644 --- a/khaosz/config/train_config.py +++ b/khaosz/config/train_config.py @@ -4,7 +4,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from dataclasses import dataclass, field -from typing import Callable, Optional +from typing import Callable, List, Optional @@ -101,7 +101,7 @@ class TrainConfig: default="29500", metadata={"help": "Master port for distributed training."} ) - parallel_fn: Optional[Callable] = field( + parallel_wrapper: Optional[Callable] = field( default=None, metadata={"help": "Parallel function for training."} ) @@ -115,6 +115,14 @@ ) # others + device_ids: Optional[List[int]] = field( + default=None, + metadata={"help": "Device ids for distributed training."} + ) + device_type: str = field( + default="cuda", + metadata={"help": "Device type for distributed training."} + ) extra_kwargs: dict = field( default_factory=dict, metadata={"help": "Other arguments."} @@ -138,3 +146,5 @@ raise ValueError("Distributed training requires optimizer and scheduler factories.") elif self.nprocs == 1 and not argument_case: raise 
ValueError("Single process training requires optimizer and scheduler arguments.") + + \ No newline at end of file diff --git a/khaosz/parallel/device.py b/khaosz/parallel/device.py deleted file mode 100644 index d568f67..0000000 --- a/khaosz/parallel/device.py +++ /dev/null @@ -1,100 +0,0 @@ -import os -import torch -import torch.distributed as dist -from dataclasses import dataclass -from typing import Callable, List, Optional - - -@dataclass -class DeviceStrategy: - """ - A class representing a device strategy. - - Attributes: - name: Name of the device backend (e.g., 'cuda', 'xpu'). - priority: Higher number means higher priority. - is_available: A callable that returns True if the device is available. - make_device: A callable that takes a rank (int) and returns a torch.device. - """ - name: str - priority: int - is_available: Callable[[], bool] - make_device: Callable[[int], torch.device] - - -class DeviceStrategyRegistry: - """ - A registry for device strategies that automatically selects the best available device. - And allows overriding the device backend via environment variable. 
- """ - - _instance: Optional["DeviceStrategyRegistry"] = None - _initialized: bool = False - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - def __init__(self) -> None: - if self._initialized: - return - - self._strategies: List[DeviceStrategy] = [] - - self.register(DeviceStrategy( - name="cuda", - priority=100, - is_available=torch.cuda.is_available, - make_device=lambda rank: torch.device(f"cuda:{rank}") - )) - - self.register(DeviceStrategy( - name="xpu", - priority=90, - is_available=torch.xpu.is_available, - make_device=lambda rank: torch.device(f"xpu:{rank}") - )) - - self.register(DeviceStrategy( - name="mps", - priority=80, - is_available=torch.mps.is_available, - make_device=lambda _: torch.device("mps") # MPS ignores rank - )) - - self.register(DeviceStrategy( - name="cpu", - priority=0, - is_available=lambda: True, - make_device=lambda _: torch.device("cpu") - )) - - self._initialized = True - - def register(self, strategy: DeviceStrategy): - self._strategies.append(strategy) - - def get_current_device(self) -> torch.device: - """Return the best available device for the current process.""" - override = os.getenv("TORCH_DEVICE_OVERRIDE") - sorted_strategies = sorted(self._strategies, key=lambda s: -s.priority) - - rank = 0 - - if dist.is_available() and dist.is_initialized(): - rank = os.environ["LOCAL_RANK"] - - if override: - return torch.device(override, rank) - - - for strategy in sorted_strategies: - if strategy.is_available(): - - return strategy.make_device(rank) - - raise RuntimeError("No device backend is available, including CPU.") - - -device_registry = DeviceStrategyRegistry() \ No newline at end of file diff --git a/khaosz/parallel/setup.py b/khaosz/parallel/setup.py index 386a6dd..932dc38 100644 --- a/khaosz/parallel/setup.py +++ b/khaosz/parallel/setup.py @@ -6,11 +6,10 @@ import torch.multiprocessing as mp from functools import wraps from contextlib import contextmanager 
from typing import Callable, List, Optional -from khaosz.parallel.device import device_registry def get_current_device(): - return device_registry.get_current_device() + return os.environ["LOCAL_DEVICE"] def get_world_size() -> int: if dist.is_available() and dist.is_initialized(): @@ -31,7 +30,8 @@ def setup_parallel( backend: str = "nccl", master_addr: str = "localhost", master_port: str = "29500", - avail_ids: Optional[List[int]] = None + device_type: str = "cuda", + device_ids: Optional[List[int]] = None ): if dist.is_available() and dist.is_initialized(): @@ -42,28 +42,31 @@ def setup_parallel( yield None return - if avail_ids is None: - avail_ids = [i for i in range(world_size)] + if device_ids is None: + device_ids = [i for i in range(world_size)] - rank = avail_ids[rank % len(avail_ids)] + rank = device_ids[rank % len(device_ids)] + device_id = torch.device(device_type, device_ids[rank]) os.environ['MASTER_ADDR'] = master_addr os.environ['MASTER_PORT'] = master_port - os.environ['WORLD_SIZE'] = str(world_size) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ["LOCAL_DEVICE"] = str(device_id) dist.init_process_group( - backend=backend, - init_method=f"tcp://{master_addr}:{master_port}", rank=rank, - world_size=world_size + world_size=world_size, + backend=backend, + device_id=device_id ) try: if backend == "nccl" and torch.cuda.is_available(): - torch.cuda.set_device(rank) + torch.cuda.set_device(device_id) elif backend == "ccl" and hasattr(torch, 'xpu') and torch.xpu.is_available(): - torch.xpu.set_device(rank) + torch.xpu.set_device(device_id) yield dist.group.WORLD finally: @@ -92,8 +95,9 @@ def wrapper_spawn_func( world_size: int, backend: str, master_addr: str, - master_port: str, - avail_ids: List[int], + master_port: str, + device_type: str, + device_ids: List[int], func: Callable, kwargs: dict ): @@ -104,7 +108,8 @@ def wrapper_spawn_func( backend=backend, master_addr=master_addr, 
master_port=master_port, - avail_ids=avail_ids + device_type=device_type, + device_ids=device_ids ): func(**kwargs) @@ -118,22 +123,26 @@ def spawn_parallel_fn( backend: str = "nccl", master_addr: str = "localhost", master_port: str = "29500", - avail_ids: Optional[List[int]] = None, + device_type: str = "cuda", + device_ids: Optional[List[int]] = None, **kwargs ): + # clear environment variables + for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK', 'LOCAL_DEVICE']: + if key in os.environ: + del os.environ[key] if world_size == 1: + device_ids = device_ids or [0] + deice_id = torch.device(device_type, device_ids[0]) + os.environ["LOCAL_DEVICE"] = str(deice_id) + func(**kwargs) return - # clear environment variables - for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK']: - if key in os.environ: - del os.environ[key] + wrapper_spawn_func_args = (world_size, backend, master_addr, master_port, + device_type, device_ids, func, kwargs) - wrapper_spawn_func_args = (world_size, backend, - master_addr, master_port, avail_ids, func, kwargs) - mp.spawn( wrapper_spawn_func, nprocs=world_size, diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py index ab1184a..ee3cea5 100644 --- a/khaosz/trainer/train_context.py +++ b/khaosz/trainer/train_context.py @@ -88,13 +88,13 @@ class TrainContextBuilder: ) return self - def with_parallel_fn(self) -> Self: + def with_parallel(self) -> Self: device = get_current_device() self._context.model = self._context.model.to(device=device) if self.config.nprocs > 1: - fn = self.config.parallel_fn + fn = self.config.parallel_wrapper optimizer_fn = self.config.optimizer_factory scheduler_fn = self.config.scheduler_factory diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index e409845..dec2d0c 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -38,7 +38,7 @@ class Trainer: .with_checkpoint(checkpoint) .with_dataloader() 
.with_strategy() - .with_parallel_fn() + .with_parallel() .build()) def _call_callbacks(self, method_name: str, context: TrainContext): diff --git a/tools/train.py b/tools/train.py index a7621d2..a352b3c 100644 --- a/tools/train.py +++ b/tools/train.py @@ -5,6 +5,7 @@ import torch.nn as nn import torch.optim as optim import torch.distributed.fsdp as fsdp +from typing import List, Optional from functools import partial from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig from khaosz.trainer import Trainer, SchedulerFactory @@ -12,6 +13,15 @@ from khaosz.data import DatasetLoader def parse_args() -> argparse.Namespace: + def parse_device_ids(s: Optional[str]) -> Optional[List[int]]: + if s is None or s.strip() == "": + return None + try: + return [int(x.strip()) for x in s.split(",") if x.strip()] + except ValueError as e: + raise argparse.ArgumentTypeError(f"Invalid device_ids format: {s}. Expected comma-separated integers like '0,1,2'.") + + parser = argparse.ArgumentParser(description="Train the Transformer model.") parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.") @@ -40,6 +50,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.") parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.") + parser.add_argument("--device_ids", type=parse_device_ids, default=None, help="Device IDs to use.") + parser.add_argument("--device_type", type=str, default="cuda", help="Device type to use.") args = parser.parse_args() @@ -88,7 +100,9 @@ def train( pin_memory: bool, window_size: int, stride: int, - nprocs: int + nprocs: int, + device_ids: List[int], + device_type: str, ): assert train_type in ["seq", "sft", "dpo"] assert os.path.exists(param_path) @@ -147,10 +161,12 @@ def train( num_workers=num_workers, pin_memory=pin_memory, nprocs=nprocs, + parallel_wrapper=fsdp_wrap, optimizer_factory=optimizer_fn, 
scheduler_factory=scheduler_fn, + device_ids=device_ids, + device_type=device_type, extra_kwargs=kwargs, - parallel_fn=fsdp_wrap ) trainer = Trainer(train_config)