refactor(parallel): 重构parallel模块
This commit is contained in:
parent
a30ddca517
commit
d882f65579
|
|
@ -90,7 +90,7 @@ class TrainConfig:
|
|||
)
|
||||
|
||||
# others
|
||||
kwargs: dict = field(
|
||||
extra_kwargs: dict = field(
|
||||
default_factory=dict,
|
||||
metadata={"help": "Other arguments."}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,10 @@
|
|||
from khaosz.parallel.utils import (
|
||||
from khaosz.parallel.setup import (
|
||||
get_world_size,
|
||||
get_rank,
|
||||
get_device_count,
|
||||
get_current_device,
|
||||
get_available_backend,
|
||||
setup_parallel,
|
||||
|
||||
only_on_rank,
|
||||
run_on_rank,
|
||||
setup_parallel,
|
||||
spawn_parallel_fn
|
||||
)
|
||||
|
||||
|
|
@ -18,12 +16,10 @@ from khaosz.parallel.module import (
|
|||
__all__ = [
|
||||
"get_world_size",
|
||||
"get_rank",
|
||||
"get_device_count",
|
||||
"get_current_device",
|
||||
"get_available_backend",
|
||||
"setup_parallel",
|
||||
|
||||
"only_on_rank",
|
||||
"run_on_rank",
|
||||
"setup_parallel",
|
||||
"spawn_parallel_fn",
|
||||
|
||||
"RowParallelLinear",
|
||||
|
|
|
|||
|
|
@ -3,38 +3,11 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from typing import Callable
|
||||
from functools import wraps
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
def get_device_count() -> int:
    """Return how many accelerator devices this process can use.

    Returns:
        ``torch.cuda.device_count()`` when CUDA is available, otherwise
        ``torch.xpu.device_count()`` when an XPU is available, otherwise 1.
    """
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.xpu.device_count()
    # MPS and plain CPU both count as a single device, so no separate MPS
    # probe is needed. The former ``torch.mps.is_available()`` call raised
    # AttributeError on PyTorch builds whose ``torch.mps`` lacks that function.
    return 1
|
||||
|
||||
def get_current_device() -> torch.device:
    """Return the device this process should place tensors on.

    Preference order: current CUDA device, current XPU device, MPS, CPU.

    Returns:
        A ``torch.device`` for the first available backend in that order.
    """
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device(f"xpu:{torch.xpu.current_device()}")
    # Probe MPS through torch.backends.mps: ``torch.mps.is_available`` is
    # missing on some PyTorch releases and would raise AttributeError there.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")
|
||||
|
||||
def get_available_backend():
    """Choose the distributed backend name matching the detected hardware.

    Returns:
        "nccl" when CUDA is available, "ccl" when an XPU is available,
        and "gloo" in every other case.
    """
    if torch.cuda.is_available():
        return "nccl"
    xpu_ready = hasattr(torch, 'xpu') and torch.xpu.is_available()
    # Intel XPU uses ccl; anything else falls back to gloo.
    return "ccl" if xpu_ready else "gloo"
|
||||
|
||||
def get_world_size() -> int:
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
return dist.get_world_size()
|
||||
|
|
@ -47,10 +20,21 @@ def get_rank() -> int:
|
|||
else:
|
||||
return 0
|
||||
|
||||
def get_current_device() -> torch.device:
    """Return the device this process should place tensors on.

    Preference order: current CUDA device, current XPU device, MPS, CPU.

    Returns:
        A ``torch.device`` for the first available backend in that order.
    """
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device(f"xpu:{torch.xpu.current_device()}")
    # Probe MPS through torch.backends.mps: ``torch.mps.is_available`` is
    # missing on some PyTorch releases and would raise AttributeError there.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")
|
||||
|
||||
@contextmanager
|
||||
def setup_parallel(
|
||||
rank: int = 0,
|
||||
world_size: int = 1,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
backend: str = "nccl",
|
||||
master_addr: str = "localhost",
|
||||
master_port: str = "29500"
|
||||
):
|
||||
|
|
@ -69,11 +53,9 @@ def setup_parallel(
|
|||
os.environ['WORLD_SIZE'] = str(world_size)
|
||||
os.environ['LOCAL_RANK'] = str(rank)
|
||||
|
||||
backend = get_available_backend()
|
||||
|
||||
dist.init_process_group(
|
||||
backend=backend,
|
||||
init_method="env://",
|
||||
init_method=f"tcp://{master_addr}:{master_port}",
|
||||
rank=rank,
|
||||
world_size=world_size
|
||||
)
|
||||
|
|
@ -89,24 +71,7 @@ def setup_parallel(
|
|||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
@contextmanager
def run_on_rank(rank=0, sync_before=True, sync_after=True):
    """
    Context manager that reports whether the current process is `rank`.

    Yields True when ``get_rank() == rank`` and False otherwise; the body
    of the ``with`` block runs on every rank, so callers branch on the
    yielded flag. When the process group is initialized, a barrier is
    issued before the body if `sync_before` is set and after it (even on
    error) if `sync_after` is set.
    """
    on_target_rank = get_rank() == rank

    def _barrier_if(enabled):
        # Barrier only when a process group is initialized and the flag is set.
        if dist.is_initialized() and enabled:
            dist.barrier()

    _barrier_if(sync_before)
    try:
        yield on_target_rank
    finally:
        _barrier_if(sync_after)
|
||||
|
||||
def only_on_rank(rank=0):
|
||||
def only_on_rank(rank, sync=False):
|
||||
"""
|
||||
decorator to run a function only on a specific rank.
|
||||
"""
|
||||
|
|
@ -116,36 +81,31 @@ def only_on_rank(rank=0):
|
|||
def wrapper(*args, **kwargs):
|
||||
if get_rank() == rank:
|
||||
return func(*args, **kwargs)
|
||||
else:
|
||||
return None
|
||||
if sync:
|
||||
dist.barrier()
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
def wrapper_spawn_func(rank, world_size, func, kwargs_dict):
|
||||
with setup_parallel(rank, world_size):
|
||||
func(**kwargs_dict)
|
||||
def wrapper_spawn_func(rank, world_size, backend, func, kwargs):
|
||||
with setup_parallel(rank, world_size, backend):
|
||||
func(**kwargs)
|
||||
|
||||
def spawn_parallel_fn(func, world_size=None, **kwargs):
|
||||
|
||||
if world_size is None:
|
||||
world_size = get_device_count()
|
||||
|
||||
if world_size < 1:
|
||||
raise ValueError("world_size must be greater than 0")
|
||||
|
||||
device_count = get_device_count()
|
||||
if world_size > device_count:
|
||||
raise ValueError(f"world_size ({world_size}) exceeds available devices ({device_count})")
|
||||
def spawn_parallel_fn(func: Callable, world_size: int, backend: str, **kwargs):
|
||||
|
||||
if world_size == 1:
|
||||
func(**kwargs)
|
||||
return
|
||||
|
||||
# clear environment variables
|
||||
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK']:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
mp.spawn(
|
||||
wrapper_spawn_func,
|
||||
nprocs=world_size,
|
||||
args=(world_size, func, kwargs),
|
||||
args=(world_size, backend, func, kwargs),
|
||||
join=True
|
||||
)
|
||||
|
|
@ -8,6 +8,7 @@ from torch.nn.utils import clip_grad_norm_
|
|||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from typing import List, Optional, Protocol, TYPE_CHECKING
|
||||
|
||||
from khaosz.parallel import only_on_rank
|
||||
from khaosz.trainer.metric_util import (
|
||||
grad_max,
|
||||
grad_min,
|
||||
|
|
@ -96,6 +97,7 @@ class CheckpointCallback(TrainCallback):
|
|||
self.save_dir = save_dir
|
||||
self.last_ckpt_iter = 0
|
||||
|
||||
@only_on_rank(0)
|
||||
def _save_checkpoint(self, context: 'TrainContext'):
|
||||
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}")
|
||||
context.checkpoint = Checkpoint(
|
||||
|
|
@ -127,6 +129,7 @@ class ProgressBarCallback(TrainCallback):
|
|||
self.num_epoch = num_epoch
|
||||
self.progress_bar: tqdm = None
|
||||
|
||||
@only_on_rank(0)
|
||||
def on_epoch_begin(self, context: 'TrainContext'):
|
||||
self.progress_bar = tqdm(
|
||||
context.dataloader,
|
||||
|
|
@ -134,6 +137,7 @@ class ProgressBarCallback(TrainCallback):
|
|||
dynamic_ncols=True
|
||||
)
|
||||
|
||||
@only_on_rank(0)
|
||||
def on_batch_end(self, context: 'TrainContext'):
|
||||
self.progress_bar.set_postfix({
|
||||
"loss": f"{context.loss:.4f}",
|
||||
|
|
@ -141,6 +145,7 @@ class ProgressBarCallback(TrainCallback):
|
|||
})
|
||||
self.progress_bar.update(1)
|
||||
|
||||
@only_on_rank(0)
|
||||
def on_epoch_end(self, context: 'TrainContext'):
|
||||
_ = context
|
||||
if self.progress_bar:
|
||||
|
|
@ -220,6 +225,7 @@ class StepMonitorCallback(TrainCallback):
|
|||
except Exception:
|
||||
raise
|
||||
|
||||
@only_on_rank(0)
|
||||
def on_step_end(self, context: 'TrainContext'):
|
||||
if self.step_num % self.log_interval == 0:
|
||||
self._handle_log(context)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from khaosz.data import ResumableDistributedSampler
|
|||
from khaosz.trainer.checkpoint import Checkpoint
|
||||
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
|
||||
from khaosz.config.train_config import TrainConfig
|
||||
from khaosz.parallel.utils import get_current_device, get_world_size, get_rank
|
||||
from khaosz.parallel.setup import get_current_device, get_world_size, get_rank
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Self
|
||||
|
|
@ -85,7 +85,7 @@ class TrainContextBuilder:
|
|||
model=self.config.model,
|
||||
train_type=self.config.strategy,
|
||||
device=device,
|
||||
**self.config.kwargs
|
||||
**self.config.extra_kwargs
|
||||
)
|
||||
return self
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ from khaosz.trainer.train_callback import (
|
|||
GradientClippingCallback,
|
||||
SchedulerCallback
|
||||
)
|
||||
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder, Checkpoint
|
||||
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
|
||||
from khaosz.trainer.checkpoint import Checkpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from khaosz.parallel import (
|
||||
get_rank,
|
||||
only_on_rank,
|
||||
spawn_parallel_fn
|
||||
)
|
||||
|
||||
@only_on_rank(0)
def _test_only_on_rank_helper():
    """Return True; decorated with ``only_on_rank(0)`` for the rank check."""
    return True
|
||||
|
||||
def only_on_rank():
    """Worker body: check the decorated helper's result for this rank.

    Rank 0 must get True from the helper; every other rank must get None.
    """
    outcome = _test_only_on_rank_helper()
    if get_rank() != 0:
        assert outcome is None
    else:
        assert outcome is True
|
||||
|
||||
def all_reduce():
    """Worker body: sum every rank's id with all_reduce and verify the total.

    Each process contributes a one-element int tensor holding its rank; the
    reduced value must equal 0 + 1 + ... + (world_size - 1).
    """
    rank_tensor = torch.tensor([get_rank()], dtype=torch.int)
    dist.all_reduce(rank_tensor, op=dist.ReduceOp.SUM)
    world_size = dist.get_world_size()
    assert rank_tensor.item() == sum(range(world_size))
|
||||
|
||||
def test_spawn_only_on_rank():
    """Run the ``only_on_rank`` worker in two processes using the gloo backend."""
    spawn_parallel_fn(only_on_rank, world_size=2, backend="gloo")
|
||||
|
||||
def test_spawn_all_reduce():
    """Run the ``all_reduce`` worker in two processes using the gloo backend."""
    spawn_parallel_fn(all_reduce, world_size=2, backend="gloo")
|
||||
|
|
@ -1,11 +1,14 @@
|
|||
import os
|
||||
import argparse
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed.fsdp as fsdp
|
||||
|
||||
from torch.optim import AdamW
|
||||
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
||||
from khaosz.trainer import Trainer, SchedulerFactory
|
||||
from khaosz.data import DatasetLoader
|
||||
from khaosz.parallel import get_current_device, spawn_parallel_fn
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
|
|
@ -37,10 +40,26 @@ def parse_args() -> argparse.Namespace:
|
|||
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
|
||||
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
||||
|
||||
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
def fsdp_wrap(model: nn.Module):
    """Wrap `model` in FullyShardedDataParallel.

    The wrapper is configured with the SHARD_GRAD_OP sharding strategy,
    BACKWARD_PRE backward prefetching, and a mixed-precision policy that
    sets parameter, reduction and buffer dtypes to bfloat16.

    Args:
        model: The module to wrap.

    Returns:
        The FullyShardedDataParallel instance wrapping `model`.
    """
    bf16_policy = fsdp.MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )
    return fsdp.FullyShardedDataParallel(
        model,
        sharding_strategy=fsdp.ShardingStrategy.SHARD_GRAD_OP,
        mixed_precision=bf16_policy,
        backward_prefetch=fsdp.BackwardPrefetch.BACKWARD_PRE,
    )
|
||||
|
||||
def train(
|
||||
train_type: str,
|
||||
param_path: str,
|
||||
|
|
@ -65,6 +84,7 @@ def train(
|
|||
pin_memory: bool,
|
||||
window_size: int,
|
||||
stride: int,
|
||||
nprocs: int
|
||||
):
|
||||
assert train_type in ["seq", "sft", "dpo"]
|
||||
assert os.path.exists(param_path)
|
||||
|
|
@ -76,8 +96,8 @@ def train(
|
|||
window_size = parameter.config.m_len
|
||||
|
||||
model = parameter.model
|
||||
device = torch.device("cuda")
|
||||
model = model.to(device=device, dtype=torch.bfloat16)
|
||||
current_device = get_current_device()
|
||||
model = fsdp_wrap(model.to(device=current_device, dtype=torch.bfloat16))
|
||||
|
||||
kwargs = {
|
||||
"dpo_beta": dpo_beta,
|
||||
|
|
@ -90,8 +110,7 @@ def train(
|
|||
train_type=train_type,
|
||||
load_path=data_root_path,
|
||||
window_size=window_size,
|
||||
stride=stride,
|
||||
**kwargs
|
||||
stride=stride
|
||||
)
|
||||
|
||||
param_groups = [
|
||||
|
|
@ -107,7 +126,7 @@ def train(
|
|||
|
||||
schedule_config = CosineScheduleConfig(
|
||||
warmup_steps=warmup_steps,
|
||||
total_steps=len(dataset) * n_epoch // batch_size,
|
||||
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
|
||||
)
|
||||
|
||||
scheduler = SchedulerFactory.load(optimizer, schedule_config)
|
||||
|
|
@ -128,7 +147,9 @@ def train(
|
|||
max_grad_norm=max_grad_norm,
|
||||
random_seed=random_seed,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory
|
||||
pin_memory=pin_memory,
|
||||
nprocs=nprocs,
|
||||
extra_kwargs=kwargs,
|
||||
)
|
||||
|
||||
trainer = Trainer(train_config)
|
||||
|
|
@ -137,4 +158,10 @@ def train(
|
|||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
train(**vars(args))
|
||||
|
||||
spawn_parallel_fn(
|
||||
train,
|
||||
world_size=args.nprocs,
|
||||
backend="nccl",
|
||||
**vars(args)
|
||||
)
|
||||
Loading…
Reference in New Issue