diff --git a/khaosz/config/train_config.py b/khaosz/config/train_config.py
index 4d899c6..21a6ae4 100644
--- a/khaosz/config/train_config.py
+++ b/khaosz/config/train_config.py
@@ -90,7 +90,7 @@ class TrainConfig:
     )
 
     # others
-    kwargs: dict = field(
+    extra_kwargs: dict = field(
         default_factory=dict,
         metadata={"help": "Other arguments."}
     )
diff --git a/khaosz/parallel/__init__.py b/khaosz/parallel/__init__.py
index ad8b5e9..00254d4 100644
--- a/khaosz/parallel/__init__.py
+++ b/khaosz/parallel/__init__.py
@@ -1,12 +1,10 @@
-from khaosz.parallel.utils import (
+from khaosz.parallel.setup import (
     get_world_size,
-    get_rank,
-    get_device_count,
-    get_current_device,
-    get_available_backend,
-    setup_parallel,
-    run_on_rank,
+    get_rank,
+    get_current_device,
+    only_on_rank,
+    setup_parallel,
     spawn_parallel_fn
 )
 
 
@@ -18,12 +16,10 @@ from khaosz.parallel.module import (
 __all__ = [
     "get_world_size",
     "get_rank",
-    "get_device_count",
     "get_current_device",
-    "get_available_backend",
-    "setup_parallel",
-    "run_on_rank",
+    "only_on_rank",
+    "setup_parallel",
     "spawn_parallel_fn",
 
     "RowParallelLinear",
diff --git a/khaosz/parallel/utils.py b/khaosz/parallel/setup.py
similarity index 56%
rename from khaosz/parallel/utils.py
rename to khaosz/parallel/setup.py
index 44d0cd9..9cf7099 100644
--- a/khaosz/parallel/utils.py
+++ b/khaosz/parallel/setup.py
@@ -3,38 +3,11 @@
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 
+from typing import Callable
 from functools import wraps
 from contextlib import contextmanager
 
-def get_device_count() -> int:
-    if torch.cuda.is_available():
-        return torch.cuda.device_count()
-    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
-        return torch.xpu.device_count()
-    elif hasattr(torch, 'mps') and torch.mps.is_available():
-        return 1
-    else:
-        return 1
-
-def get_current_device() -> torch.device:
-    if torch.cuda.is_available():
-        return torch.device(f"cuda:{torch.cuda.current_device()}")
-    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
-        return torch.device(f"xpu:{torch.xpu.current_device()}")
-    elif hasattr(torch, 'mps') and torch.mps.is_available():
-        return torch.device("mps")
-    else:
-        return torch.device("cpu")
-
-def get_available_backend():
-    if torch.cuda.is_available():
-        return "nccl"
-    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
-        return "ccl" # Intel XPU use ccl
-    else:
-        return "gloo"
-
 def get_world_size() -> int:
     if dist.is_available() and dist.is_initialized():
         return dist.get_world_size()
@@ -47,10 +20,21 @@
 def get_rank() -> int:
     if dist.is_available() and dist.is_initialized():
         return dist.get_rank()
     else:
         return 0
 
+def get_current_device() -> torch.device:
+    if torch.cuda.is_available():
+        return torch.device(f"cuda:{torch.cuda.current_device()}")
+    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
+        return torch.device(f"xpu:{torch.xpu.current_device()}")
+    elif hasattr(torch, 'mps') and torch.mps.is_available():
+        return torch.device("mps")
+    else:
+        return torch.device("cpu")
+
 @contextmanager
 def setup_parallel(
-    rank: int = 0,
-    world_size: int = 1,
+    rank: int,
+    world_size: int,
+    backend: str = "nccl",
     master_addr: str = "localhost",
     master_port: str = "29500"
 ):
@@ -69,11 +53,9 @@
     os.environ['WORLD_SIZE'] = str(world_size)
     os.environ['LOCAL_RANK'] = str(rank)
 
-    backend = get_available_backend()
-
     dist.init_process_group(
         backend=backend,
-        init_method="env://",
+        init_method=f"tcp://{master_addr}:{master_port}",
         rank=rank,
         world_size=world_size
     )
@@ -89,24 +71,7 @@
         if dist.is_initialized():
             dist.destroy_process_group()
 
-@contextmanager
-def run_on_rank(rank=0, sync_before=True, sync_after=True):
-    """
-    context manager to run a function only on a specific rank.
-    """
-    is_main_proc = (get_rank() == rank)
-
-    if dist.is_initialized() and sync_before:
-        dist.barrier()
-
-    try:
-        yield is_main_proc
-
-    finally:
-        if dist.is_initialized() and sync_after:
-            dist.barrier()
-
-def only_on_rank(rank=0):
+def only_on_rank(rank, sync=False):
     """
    decorator to run a function only on a specific rank.
     """
@@ -116,36 +81,31 @@
         def wrapper(*args, **kwargs):
+            result = None
             if get_rank() == rank:
-                return func(*args, **kwargs)
-            else:
-                return None
+                result = func(*args, **kwargs)
+            # barrier on every rank, not only the skipped ones, so sync=True cannot deadlock
+            if sync and dist.is_initialized():
+                dist.barrier()
+            return result
         return wrapper
     return decorator
 
-def wrapper_spawn_func(rank, world_size, func, kwargs_dict):
-    with setup_parallel(rank, world_size):
-        func(**kwargs_dict)
+def wrapper_spawn_func(rank, world_size, backend, func, kwargs):
+    with setup_parallel(rank, world_size, backend):
+        func(**kwargs)
 
-def spawn_parallel_fn(func, world_size=None, **kwargs):
-
-    if world_size is None:
-        world_size = get_device_count()
-
-    if world_size < 1:
-        raise ValueError("world_size must be greater than 0")
-
-    device_count = get_device_count()
-    if world_size > device_count:
-        raise ValueError(f"world_size ({world_size}) exceeds available devices ({device_count})")
+def spawn_parallel_fn(func: Callable, world_size: int, backend: str, **kwargs):
 
     if world_size == 1:
         func(**kwargs)
         return
+
+    # clear stale rendezvous variables so spawned processes start from a clean environment
+    for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK']:
+        if key in os.environ:
+            del os.environ[key]
 
     mp.spawn(
         wrapper_spawn_func,
         nprocs=world_size,
         args=(world_size, backend, func, kwargs),
         join=True
     )
\ No newline at end of file
diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py
index dd31629..a06bf1d 100644
--- a/khaosz/trainer/train_callback.py
+++ b/khaosz/trainer/train_callback.py
@@ -8,6 +8,7 @@ from torch.nn.utils import clip_grad_norm_
 from torch.optim.lr_scheduler import LRScheduler
 from typing import List, Optional, Protocol, TYPE_CHECKING
 
+from khaosz.parallel import only_on_rank
 from khaosz.trainer.metric_util import (
     grad_max,
     grad_min,
@@ -96,6 +97,7 @@ class CheckpointCallback(TrainCallback):
         self.save_dir = save_dir
         self.last_ckpt_iter = 0
 
+    @only_on_rank(0)
     def _save_checkpoint(self, context: 'TrainContext'):
         save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}")
         context.checkpoint = Checkpoint(
@@ -127,6 +129,7 @@ class ProgressBarCallback(TrainCallback):
         self.num_epoch = num_epoch
         self.progress_bar: tqdm = None
 
+    @only_on_rank(0)
     def on_epoch_begin(self, context: 'TrainContext'):
         self.progress_bar = tqdm(
             context.dataloader,
@@ -134,6 +137,7 @@
             dynamic_ncols=True
         )
 
+    @only_on_rank(0)
     def on_batch_end(self, context: 'TrainContext'):
         self.progress_bar.set_postfix({
             "loss": f"{context.loss:.4f}",
@@ -141,6 +145,7 @@
         })
         self.progress_bar.update(1)
 
+    @only_on_rank(0)
     def on_epoch_end(self, context: 'TrainContext'):
         _ = context
         if self.progress_bar:
@@ -219,7 +224,8 @@ class StepMonitorCallback(TrainCallback):
                 json.dump(log_data, f, indent=4)
         except Exception:
             raise
-    
+
+    @only_on_rank(0)
     def on_step_end(self, context: 'TrainContext'):
         if self.step_num % self.log_interval == 0:
             self._handle_log(context)
diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py
index a941583..d95e3f5 100644
--- a/khaosz/trainer/train_context.py
+++ b/khaosz/trainer/train_context.py
@@ -7,7 +7,7 @@
 from khaosz.data import ResumableDistributedSampler
 from khaosz.trainer.checkpoint import Checkpoint
 from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
 from khaosz.config.train_config import TrainConfig
-from khaosz.parallel.utils import get_current_device, get_world_size, get_rank
+from khaosz.parallel.setup import get_current_device, get_world_size, get_rank
 from dataclasses import dataclass, field
 from typing import Optional, Self
@@ -85,7 +85,7 @@ class TrainContextBuilder:
             model=self.config.model,
             train_type=self.config.strategy,
             device=device,
-            **self.config.kwargs
+            **self.config.extra_kwargs
         )
 
         return self
diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index 8456ad8..2ca75e8 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -8,7 +8,8 @@ from khaosz.trainer.train_callback import (
     GradientClippingCallback,
     SchedulerCallback
 )
 
-from khaosz.trainer.train_context import TrainContext, TrainContextBuilder, Checkpoint
+from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
+from khaosz.trainer.checkpoint import Checkpoint
 
 logger = logging.getLogger(__name__)
diff --git a/tests/test_parallel.py b/tests/test_parallel.py
new file mode 100644
index 0000000..714cc74
--- /dev/null
+++ b/tests/test_parallel.py
@@ -0,0 +1,39 @@
+import torch
+import torch.distributed as dist
+
+from khaosz.parallel import (
+    get_rank,
+    only_on_rank,
+    spawn_parallel_fn
+)
+
+@only_on_rank(0)
+def _test_only_on_rank_helper():
+    return True
+
+def _check_only_on_rank():
+    result = _test_only_on_rank_helper()
+    if get_rank() == 0:
+        assert result is True
+    else:
+        assert result is None
+
+def all_reduce():
+    x = torch.tensor([get_rank()], dtype=torch.int)
+    dist.all_reduce(x, op=dist.ReduceOp.SUM)
+    expected_sum = sum(range(dist.get_world_size()))
+    assert x.item() == expected_sum
+
+def test_spawn_only_on_rank():
+    spawn_parallel_fn(
+        _check_only_on_rank,
+        world_size=2,
+        backend="gloo"
+    )
+
+def test_spawn_all_reduce():
+    spawn_parallel_fn(
+        all_reduce,
+        world_size=2,
+        backend="gloo"
+    )
\ No newline at end of file
diff --git a/tools/train.py b/tools/train.py
index 697b7e9..e8cf86b 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -1,11 +1,14 @@
 import os
 import argparse
 import torch
+import torch.nn as nn
+import torch.distributed.fsdp as fsdp
 from torch.optim import AdamW
 
 from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
 from khaosz.trainer import Trainer, SchedulerFactory
 from khaosz.data import DatasetLoader
+from khaosz.parallel import get_current_device, spawn_parallel_fn
 
 
 def parse_args() -> argparse.Namespace:
@@ -37,10 +40,26 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
     parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
 
+    parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
+
     args = parser.parse_args()
     return args
 
 
+def fsdp_wrap(model: nn.Module):
+
+    fsdp_model = fsdp.FullyShardedDataParallel(
+        model,
+        sharding_strategy=fsdp.ShardingStrategy.SHARD_GRAD_OP,
+        mixed_precision=fsdp.MixedPrecision(
+            param_dtype=torch.bfloat16,
+            reduce_dtype=torch.bfloat16,
+            buffer_dtype=torch.bfloat16,
+        ),
+        backward_prefetch=fsdp.BackwardPrefetch.BACKWARD_PRE
+    )
+    return fsdp_model
+
 def train(
     train_type: str,
     param_path: str,
@@ -65,6 +84,7 @@ def train(
     pin_memory: bool,
     window_size: int,
     stride: int,
+    nprocs: int
 ):
     assert train_type in ["seq", "sft", "dpo"]
     assert os.path.exists(param_path)
@@ -76,8 +96,8 @@ def train(
         window_size = parameter.config.m_len
 
     model = parameter.model
-    device = torch.device("cuda")
-    model = model.to(device=device, dtype=torch.bfloat16)
+    current_device = get_current_device()
+    model = fsdp_wrap(model.to(device=current_device, dtype=torch.bfloat16))
 
     kwargs = {
         "dpo_beta": dpo_beta,
@@ -90,8 +110,7 @@ def train(
         train_type=train_type,
         load_path=data_root_path,
         window_size=window_size,
-        stride=stride,
-        **kwargs
+        stride=stride
     )
 
     param_groups = [
@@ -107,7 +126,7 @@ def train(
 
     schedule_config = CosineScheduleConfig(
         warmup_steps=warmup_steps,
-        total_steps=len(dataset) * n_epoch // batch_size,
+        total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
     )
     scheduler = SchedulerFactory.load(optimizer, schedule_config)
 
@@ -128,7 +147,9 @@ def train(
         max_grad_norm=max_grad_norm,
         random_seed=random_seed,
         num_workers=num_workers,
-        pin_memory=pin_memory
+        pin_memory=pin_memory,
+        nprocs=nprocs,
+        extra_kwargs=kwargs,
     )
 
     trainer = Trainer(train_config)
@@ -137,4 +158,10 @@
 
 if __name__ == "__main__":
     args = parse_args()
-    train(**vars(args))
\ No newline at end of file
+
+    spawn_parallel_fn(
+        train,
+        world_size=args.nprocs,
+        backend="nccl",
+        **vars(args)
+    )
\ No newline at end of file
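
Usage note (illustrative sketch, not part of the patch): after this change, spawn_parallel_fn(func, world_size, backend, **kwargs) forwards kwargs to func and runs each spawned process inside setup_parallel(). The _worker function and its "message" argument below are invented for the example.

    import torch.distributed as dist
    from khaosz.parallel import get_rank, spawn_parallel_fn

    def _worker(message: str):
        # runs once per spawned process; the process group is already initialized
        print(f"rank {get_rank()}: {message}")
        dist.barrier()

    if __name__ == "__main__":
        # "gloo" works on CPU-only machines; use "nccl" when each rank owns a GPU
        spawn_parallel_fn(_worker, world_size=2, backend="gloo", message="hello")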