refactor(parallel): restructure the parallel module
This commit is contained in:
parent a30ddca517
commit d882f65579

@@ -90,7 +90,7 @@ class TrainConfig:
     )
 
     # others
-    kwargs: dict = field(
+    extra_kwargs: dict = field(
         default_factory=dict,
         metadata={"help": "Other arguments."}
     )

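The rename from `kwargs` to `extra_kwargs` keeps the pass-through field from being confused with Python's conventional `**kwargs`, and the later hunks forward it as `**self.config.extra_kwargs`. A minimal sketch of the field's semantics using a stand-in dataclass (only the renamed field is real; the rest of TrainConfig is not shown in this diff):

    from dataclasses import dataclass, field

    # Stand-in for TrainConfig: demonstrates the default_factory dict field.
    @dataclass
    class DemoConfig:
        extra_kwargs: dict = field(
            default_factory=dict,
            metadata={"help": "Other arguments."}
        )

    cfg = DemoConfig(extra_kwargs={"dpo_beta": 0.1})
    print(cfg.extra_kwargs)           # {'dpo_beta': 0.1}
    print(DemoConfig().extra_kwargs)  # {} (each instance gets a fresh dict)
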
@@ -1,12 +1,10 @@
-from khaosz.parallel.utils import (
+from khaosz.parallel.setup import (
     get_world_size,
     get_rank,
-    get_device_count,
     get_current_device,
-    get_available_backend,
-    setup_parallel,
     only_on_rank,
-    run_on_rank,
+    setup_parallel,
     spawn_parallel_fn
 )
 

@@ -18,12 +16,10 @@ from khaosz.parallel.module import (
 __all__ = [
     "get_world_size",
     "get_rank",
-    "get_device_count",
     "get_current_device",
-    "get_available_backend",
-    "setup_parallel",
     "only_on_rank",
-    "run_on_rank",
+    "setup_parallel",
     "spawn_parallel_fn",
 
     "RowParallelLinear",

@@ -3,38 +3,11 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 
+from typing import Callable
 from functools import wraps
 from contextlib import contextmanager
 
 
-def get_device_count() -> int:
-    if torch.cuda.is_available():
-        return torch.cuda.device_count()
-    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
-        return torch.xpu.device_count()
-    elif hasattr(torch, 'mps') and torch.mps.is_available():
-        return 1
-    else:
-        return 1
-
-def get_current_device() -> torch.device:
-    if torch.cuda.is_available():
-        return torch.device(f"cuda:{torch.cuda.current_device()}")
-    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
-        return torch.device(f"xpu:{torch.xpu.current_device()}")
-    elif hasattr(torch, 'mps') and torch.mps.is_available():
-        return torch.device("mps")
-    else:
-        return torch.device("cpu")
-
-def get_available_backend():
-    if torch.cuda.is_available():
-        return "nccl"
-    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
-        return "ccl"  # Intel XPU use ccl
-    else:
-        return "gloo"
-
 def get_world_size() -> int:
     if dist.is_available() and dist.is_initialized():
         return dist.get_world_size()

@@ -47,10 +20,21 @@ def get_rank() -> int:
     else:
         return 0
 
+def get_current_device():
+    if torch.cuda.is_available():
+        return torch.device(f"cuda:{torch.cuda.current_device()}")
+    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
+        return torch.device(f"xpu:{torch.xpu.current_device()}")
+    elif hasattr(torch, 'mps') and torch.mps.is_available():
+        return torch.device("mps")
+    else:
+        return torch.device("cpu")
+
 @contextmanager
 def setup_parallel(
-    rank: int = 0,
-    world_size: int = 1,
+    rank: int,
+    world_size: int,
+    backend: str = "nccl",
     master_addr: str = "localhost",
     master_port: str = "29500"
 ):

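With the new signature, `rank` and `world_size` are required and the caller picks the backend instead of relying on the removed `get_available_backend` auto-detection; rendezvous now goes over `tcp://{master_addr}:{master_port}` rather than `env://`. A usage sketch, assuming the public imports from `khaosz.parallel`; "gloo" is chosen so it runs on a CPU-only machine:

    from khaosz.parallel import setup_parallel, get_rank, get_world_size

    # Single-process group over tcp://localhost:29500 (the defaults).
    with setup_parallel(rank=0, world_size=1, backend="gloo"):
        print(get_rank(), get_world_size())  # 0 1
    # Leaving the context tears the process group down.
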
@@ -69,11 +53,9 @@ def setup_parallel(
     os.environ['WORLD_SIZE'] = str(world_size)
     os.environ['LOCAL_RANK'] = str(rank)
 
-    backend = get_available_backend()
-
     dist.init_process_group(
         backend=backend,
-        init_method="env://",
+        init_method=f"tcp://{master_addr}:{master_port}",
         rank=rank,
         world_size=world_size
     )

@@ -89,24 +71,7 @@ def setup_parallel(
     if dist.is_initialized():
         dist.destroy_process_group()
 
-@contextmanager
-def run_on_rank(rank=0, sync_before=True, sync_after=True):
-    """
-    context manager to run a function only on a specific rank.
-    """
-    is_main_proc = (get_rank() == rank)
-
-    if dist.is_initialized() and sync_before:
-        dist.barrier()
-
-    try:
-        yield is_main_proc
-
-    finally:
-        if dist.is_initialized() and sync_after:
-            dist.barrier()
-
-def only_on_rank(rank=0):
+def only_on_rank(rank, sync=False):
     """
     decorator to run a function only on a specific rank.
     """

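`only_on_rank` now absorbs the removed `run_on_rank` context manager: it remains a decorator, and the new `sync` flag makes the non-selected ranks wait at a barrier instead of returning at once. Note that in this version only the non-selected ranks reach `dist.barrier()`, so `sync=True` needs a matching barrier on the selected rank to synchronize fully. A sketch with the default `sync=False`:

    from khaosz.parallel import only_on_rank

    # Runs only on rank 0; every other rank gets None back.
    @only_on_rank(0)
    def report(msg: str):
        print(f"main process: {msg}")
        return msg

    report("hello")  # "hello" on rank 0, None elsewhere
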
@@ -116,36 +81,31 @@ def only_on_rank(rank=0):
         def wrapper(*args, **kwargs):
             if get_rank() == rank:
                 return func(*args, **kwargs)
-            else:
-                return None
+            if sync:
+                dist.barrier()
 
         return wrapper
 
     return decorator
 
-def wrapper_spawn_func(rank, world_size, func, kwargs_dict):
-    with setup_parallel(rank, world_size):
-        func(**kwargs_dict)
+def wrapper_spawn_func(rank, world_size, backend, func, kwargs):
+    with setup_parallel(rank, world_size, backend):
+        func(**kwargs)
 
-def spawn_parallel_fn(func, world_size=None, **kwargs):
-    if world_size is None:
-        world_size = get_device_count()
-
-    if world_size < 1:
-        raise ValueError("world_size must be greater than 0")
-
-    device_count = get_device_count()
-    if world_size > device_count:
-        raise ValueError(f"world_size ({world_size}) exceeds available devices ({device_count})")
-
+def spawn_parallel_fn(func: Callable, world_size: int, backend: str, **kwargs):
     if world_size == 1:
         func(**kwargs)
         return
 
+    # clear environment variables
+    for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK']:
+        if key in os.environ:
+            del os.environ[key]
+
     mp.spawn(
         wrapper_spawn_func,
         nprocs=world_size,
-        args=(world_size, func, kwargs),
+        args=(world_size, backend, func, kwargs),
         join=True
     )

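`spawn_parallel_fn` no longer infers `world_size` from a device count, takes the backend explicitly, and clears stale rendezvous variables before spawning. A usage sketch; the worker must be a module-level function so `mp.spawn` can pickle it:

    from khaosz.parallel import spawn_parallel_fn, get_rank

    def worker(msg: str):
        # Executed once per spawned process, inside setup_parallel(...).
        print(f"rank {get_rank()}: {msg}")

    if __name__ == "__main__":
        spawn_parallel_fn(worker, world_size=2, backend="gloo", msg="hi")
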
@@ -8,6 +8,7 @@ from torch.nn.utils import clip_grad_norm_
 from torch.optim.lr_scheduler import LRScheduler
 from typing import List, Optional, Protocol, TYPE_CHECKING
 
+from khaosz.parallel import only_on_rank
 from khaosz.trainer.metric_util import (
     grad_max,
     grad_min,

@@ -96,6 +97,7 @@ class CheckpointCallback(TrainCallback):
         self.save_dir = save_dir
         self.last_ckpt_iter = 0
 
+    @only_on_rank(0)
     def _save_checkpoint(self, context: 'TrainContext'):
         save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}")
         context.checkpoint = Checkpoint(

@@ -127,6 +129,7 @@ class ProgressBarCallback(TrainCallback):
         self.num_epoch = num_epoch
         self.progress_bar: tqdm = None
 
+    @only_on_rank(0)
     def on_epoch_begin(self, context: 'TrainContext'):
         self.progress_bar = tqdm(
             context.dataloader,

@@ -134,6 +137,7 @@ class ProgressBarCallback(TrainCallback):
             dynamic_ncols=True
         )
 
+    @only_on_rank(0)
     def on_batch_end(self, context: 'TrainContext'):
         self.progress_bar.set_postfix({
             "loss": f"{context.loss:.4f}",

@@ -141,6 +145,7 @@ class ProgressBarCallback(TrainCallback):
         })
         self.progress_bar.update(1)
 
+    @only_on_rank(0)
     def on_epoch_end(self, context: 'TrainContext'):
         _ = context
         if self.progress_bar:

@@ -219,7 +224,8 @@ class StepMonitorCallback(TrainCallback):
                 json.dump(log_data, f, indent=4)
         except Exception:
             raise
 
+    @only_on_rank(0)
     def on_step_end(self, context: 'TrainContext'):
         if self.step_num % self.log_interval == 0:
             self._handle_log(context)

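The four hunks above gate every side-effecting callback hook behind `@only_on_rank(0)`, so under multi-process training only the main process saves checkpoints, drives the tqdm bar, and writes step logs; on other ranks the decorated methods are skipped and return None. Any custom callback can use the same pattern, as in this sketch (class name and hook body are illustrative):

    from khaosz.parallel import only_on_rank
    from khaosz.trainer.train_callback import TrainCallback

    class PrintCallback(TrainCallback):
        @only_on_rank(0)
        def on_step_end(self, context: 'TrainContext'):
            print(f"step loss: {context.loss:.4f}")
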
@@ -7,7 +7,7 @@ from khaosz.data import ResumableDistributedSampler
 from khaosz.trainer.checkpoint import Checkpoint
 from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
 from khaosz.config.train_config import TrainConfig
-from khaosz.parallel.utils import get_current_device, get_world_size, get_rank
+from khaosz.parallel.setup import get_current_device, get_world_size, get_rank
 
 from dataclasses import dataclass, field
 from typing import Optional, Self

@@ -85,7 +85,7 @@ class TrainContextBuilder:
             model=self.config.model,
             train_type=self.config.strategy,
             device=device,
-            **self.config.kwargs
+            **self.config.extra_kwargs
         )
         return self

@@ -8,7 +8,8 @@ from khaosz.trainer.train_callback import (
     GradientClippingCallback,
     SchedulerCallback
 )
-from khaosz.trainer.train_context import TrainContext, TrainContextBuilder, Checkpoint
+from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
+from khaosz.trainer.checkpoint import Checkpoint
 
 logger = logging.getLogger(__name__)

@@ -0,0 +1,39 @@
+import torch
+import torch.distributed as dist
+
+from khaosz.parallel import (
+    get_rank,
+    only_on_rank,
+    spawn_parallel_fn
+)
+
+@only_on_rank(0)
+def _test_only_on_rank_helper():
+    return True
+
+def only_on_rank():
+    result = _test_only_on_rank_helper()
+    if get_rank() == 0:
+        assert result is True
+    else:
+        assert result is None
+
+def all_reduce():
+    x = torch.tensor([get_rank()], dtype=torch.int)
+    dist.all_reduce(x, op=dist.ReduceOp.SUM)
+    expected_sum = sum(range(dist.get_world_size()))
+    assert x.item() == expected_sum
+
+def test_spawn_only_on_rank():
+    spawn_parallel_fn(
+        only_on_rank,
+        world_size=2,
+        backend="gloo"
+    )
+
+def test_spawn_all_reduce():
+    spawn_parallel_fn(
+        all_reduce,
+        world_size=2,
+        backend="gloo"
+    )

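The new test module exercises the refactored API end to end: each test spawns two CPU processes over the gloo backend, checks that the `@only_on_rank(0)` helper returns True on rank 0 and None elsewhere, and verifies that an all_reduce over the ranks sums to 1 (0 + 1). The module-level `only_on_rank` function shadows the imported decorator, which still works because the decorator was already applied to `_test_only_on_rank_helper`, though a distinct name would read more clearly. Assuming a standard pytest layout (the path is illustrative), the tests run with:

    pytest tests/test_parallel.py -q
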
@@ -1,11 +1,14 @@
 import os
 import argparse
 import torch
+import torch.nn as nn
+import torch.distributed.fsdp as fsdp
 
 from torch.optim import AdamW
 from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
 from khaosz.trainer import Trainer, SchedulerFactory
 from khaosz.data import DatasetLoader
+from khaosz.parallel import get_current_device, spawn_parallel_fn
 
 
 def parse_args() -> argparse.Namespace:

@@ -37,10 +40,26 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
     parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
 
+    parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
+
     args = parser.parse_args()
 
     return args
 
+def fsdp_wrap(model: nn.Module):
+
+    fsdp_model = fsdp.FullyShardedDataParallel(
+        model,
+        sharding_strategy=fsdp.ShardingStrategy.SHARD_GRAD_OP,
+        mixed_precision=fsdp.MixedPrecision(
+            param_dtype=torch.bfloat16,
+            reduce_dtype=torch.bfloat16,
+            buffer_dtype=torch.bfloat16,
+        ),
+        backward_prefetch=fsdp.BackwardPrefetch.BACKWARD_PRE
+    )
+    return fsdp_model
+
 def train(
     train_type: str,
     param_path: str,

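`fsdp_wrap` chooses `SHARD_GRAD_OP`, which shards gradients and optimizer state while re-materializing full parameters for compute (roughly ZeRO-2), keeps params, reductions, and buffers in bfloat16, and issues the next parameter all-gather before the current backward computation (`BACKWARD_PRE`). FSDP needs an initialized process group, so the wrapper is only valid inside the `setup_parallel` context that `spawn_parallel_fn` gives each worker. A minimal sketch under that assumption (CPU plus gloo is for illustration only; the script itself targets nccl on GPUs):

    import torch.nn as nn
    from khaosz.parallel import setup_parallel

    with setup_parallel(rank=0, world_size=1, backend="gloo"):
        wrapped = fsdp_wrap(nn.Linear(8, 8))  # fsdp_wrap as defined above
        print(type(wrapped).__name__)  # FullyShardedDataParallel
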
|
|
@ -65,6 +84,7 @@ def train(
|
||||||
pin_memory: bool,
|
pin_memory: bool,
|
||||||
window_size: int,
|
window_size: int,
|
||||||
stride: int,
|
stride: int,
|
||||||
|
nprocs: int
|
||||||
):
|
):
|
||||||
assert train_type in ["seq", "sft", "dpo"]
|
assert train_type in ["seq", "sft", "dpo"]
|
||||||
assert os.path.exists(param_path)
|
assert os.path.exists(param_path)
|
||||||
|
|
@@ -76,8 +96,8 @@ def train(
         window_size = parameter.config.m_len
 
     model = parameter.model
-    device = torch.device("cuda")
-    model = model.to(device=device, dtype=torch.bfloat16)
+    current_device = get_current_device()
+    model = fsdp_wrap(model.to(device=current_device, dtype=torch.bfloat16))
 
     kwargs = {
         "dpo_beta": dpo_beta,

@@ -90,8 +110,7 @@ def train(
         train_type=train_type,
         load_path=data_root_path,
         window_size=window_size,
-        stride=stride,
-        **kwargs
+        stride=stride
     )
 
     param_groups = [

@@ -107,7 +126,7 @@ def train(
 
     schedule_config = CosineScheduleConfig(
         warmup_steps=warmup_steps,
-        total_steps=len(dataset) * n_epoch // batch_size,
+        total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
     )
 
     scheduler = SchedulerFactory.load(optimizer, schedule_config)

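Dividing by `batch_size * nprocs` accounts for data parallelism: the distributed sampler splits the dataset across `nprocs` workers, so each worker takes roughly `len(dataset) // (batch_size * nprocs)` optimizer steps per epoch and the cosine schedule must span that reduced count. A worked example with illustrative numbers:

    # 10_000 samples, 3 epochs, batch_size 8, nprocs 2:
    total_steps = 10_000 * 3 // (8 * 2)
    assert total_steps == 1875  # versus 3750 on a single worker
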
@@ -128,7 +147,9 @@ def train(
         max_grad_norm=max_grad_norm,
         random_seed=random_seed,
         num_workers=num_workers,
-        pin_memory=pin_memory
+        pin_memory=pin_memory,
+        nprocs=nprocs,
+        extra_kwargs=kwargs,
     )
 
     trainer = Trainer(train_config)

@@ -137,4 +158,10 @@ def train(
 
 if __name__ == "__main__":
     args = parse_args()
-    train(**vars(args))
+
+    spawn_parallel_fn(
+        train,
+        world_size=args.nprocs,
+        backend="nccl",
+        **vars(args)
+    )

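The entrypoint now hands `train` to `spawn_parallel_fn`, which runs one copy per process inside its own `setup_parallel` context; every parsed CLI argument is forwarded to each worker, and `--nprocs` both sizes the process group and reaches `train` through its new parameter. An illustrative launch (the script name and the remaining flags are project-specific):

    python train.py --nprocs 2 --start_epoch 0 --start_batch 0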