feat(parallel): add distributed training config and parallel utility support

ViperEkura 2025-12-05 13:52:17 +08:00
parent d31137a2db
commit d52685facd
4 changed files with 72 additions and 1 deletion

View File

@@ -9,7 +9,6 @@ if TYPE_CHECKING:
@dataclass
class TrainConfig:
    strategy: "BaseStrategy" = field(
        default=None,
        metadata={"help": "Training strategy."}
@@ -54,6 +53,8 @@ class TrainConfig:
        default=1.0,
        metadata={"help": "Maximum gradient norm."}
    )

    # dataloader setting
    random_seed: int = field(
        default=3407,
        metadata={"help": "Random seed."}
@@ -69,4 +70,10 @@ class TrainConfig:
    pin_memory: bool = field(
        default=False,
        metadata={"help": "Pin memory for dataloader."}
    )

    # distributed training
    nprocs: int = field(
        default=1,
        metadata={"help": "Number of processes for distributed training."}
    )
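
A minimal usage sketch for the new field, assuming TrainConfig is exposed from khaosz.config (this diff only confirms Checkpoint there) and that the remaining fields keep their defaults:

from khaosz.config import TrainConfig  # import path assumed

# request 4 worker processes for distributed training; all other fields default
config = TrainConfig(nprocs=4)
assert config.nprocs == 4
assert config.random_seed == 3407  # dataloader defaults unchanged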

View File

@@ -0,0 +1,29 @@
from khaosz.parallel.utils import (
    get_world_size,
    get_rank,
    get_device_count,
    get_current_device,
    get_available_backend,
    setup_parallel,
    only_main_procs,
    spawn_parallel_fn
)

from khaosz.parallel.module import (
    RowParallelLinear,
    ColumnParallelLinear
)

__all__ = [
    "get_world_size",
    "get_rank",
    "get_device_count",
    "get_current_device",
    "get_available_backend",
    "setup_parallel",
    "only_main_procs",
    "spawn_parallel_fn",
    "RowParallelLinear",
    "ColumnParallelLinear"
]
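
A short sketch of consuming the new package-level exports outside an initialized process group; the fallback values follow from the helpers in the next file:

from khaosz.parallel import get_world_size, get_rank, get_available_backend

print(get_world_size())         # 1 when torch.distributed is not initialized
print(get_rank())               # 0 when torch.distributed is not initialized
print(get_available_backend())  # "gloo" when no accelerator backend applies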

View File

@@ -33,6 +33,17 @@ def get_available_backend():
    else:
        return "gloo"

def get_world_size() -> int:
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    else:
        return 1

def get_rank() -> int:
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    else:
        return 0

@contextmanager
def setup_parallel(
@@ -76,6 +87,21 @@ def setup_parallel(
        if dist.is_initialized():
            dist.destroy_process_group()

@contextmanager
def only_main_procs(main_process_rank=0, block=True):
    is_main_proc = (get_rank() == main_process_rank)
    if dist.is_initialized() and block:
        dist.barrier()
    try:
        yield is_main_proc
    finally:
        if dist.is_initialized() and block:
            dist.barrier()

def wrapper_spawn_func(rank, world_size, func, kwargs_dict):
    with setup_parallel(rank, world_size):
        func(**kwargs_dict)
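
A hedged sketch of how these pieces compose, assuming setup_parallel takes rank and world_size positionally (its full parameter list is elided in this hunk) and spawning one process per rank in the style of wrapper_spawn_func:

import torch.multiprocessing as mp
from khaosz.parallel import setup_parallel, only_main_procs

def worker(rank: int, world_size: int):
    # setup_parallel is assumed to create the process group for this rank
    # and destroy it on exit, as the context manager above suggests
    with setup_parallel(rank, world_size):
        with only_main_procs() as is_main:
            if is_main:
                print("only the main rank executes this block")

if __name__ == "__main__":
    world_size = 2
    # one process per rank, mirroring wrapper_spawn_func
    mp.spawn(worker, args=(world_size,), nprocs=world_size)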

View File

@@ -5,6 +5,7 @@ from torch.utils.data import DataLoader
from khaosz.config import Checkpoint
from khaosz.data import ResumableDistributedSampler
from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory
from khaosz.parallel.utils import get_world_size, get_rank

if TYPE_CHECKING:
    from khaosz.trainer.trainer import Trainer
@@ -20,6 +21,9 @@ class TrainContext:
    batch_iter: int = field(default=0)
    loss: float = field(default=0.0)
    world_size: int = field(default=1)
    rank: int = field(default=0)

    def asdict(self) -> dict:
        return {field.name: getattr(self, field.name)
                for field in fields(self)}
@@ -102,4 +106,9 @@ class TrainContextBuilder:
        return self

    def build(self) -> TrainContext:
        if self.trainer.train_config.nprocs > 1:
            self._context.world_size = get_world_size()
            self._context.rank = get_rank()
        return self._context
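
A sketch of what the built context then exposes; the trainer object and the builder's constructor argument are assumed from the type hints rather than shown in this diff:

# `trainer` is a hypothetical khaosz Trainer with train_config.nprocs already set
context = TrainContextBuilder(trainer).build()

if trainer.train_config.nprocs > 1:
    # filled in from the process group by the branch above
    print(f"rank {context.rank} of {context.world_size}")
else:
    # single-process defaults
    assert context.world_size == 1 and context.rank == 0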