feat(parallel): 添加分布式训练配置与并行工具支持

This commit is contained in:
ViperEkura 2025-12-05 13:52:17 +08:00
parent d31137a2db
commit d52685facd
4 changed files with 72 additions and 1 deletions

View File

@@ -9,7 +9,6 @@ if TYPE_CHECKING:
@dataclass @dataclass
class TrainConfig: class TrainConfig:
strategy: "BaseStrategy" = field( strategy: "BaseStrategy" = field(
default=None, default=None,
metadata={"help": "Training strategy."} metadata={"help": "Training strategy."}
@@ -54,6 +53,8 @@ class TrainConfig:
default=1.0, default=1.0,
metadata={"help": "Maximum gradient norm."} metadata={"help": "Maximum gradient norm."}
) )
# dataloader setting
random_seed: int = field( random_seed: int = field(
default=3407, default=3407,
metadata={"help": "Random seed."} metadata={"help": "Random seed."}
@@ -69,4 +70,10 @@ class TrainConfig:
pin_memory: bool = field( pin_memory: bool = field(
default=False, default=False,
metadata={"help": "Pin memory for dataloader."} metadata={"help": "Pin memory for dataloader."}
)
# distributed training
nprocs: int = field(
default=1,
metadata={"help": "Number of processes for distributed training."}
) )

View File

@@ -0,0 +1,29 @@
from khaosz.parallel.utils import (
get_world_size,
get_rank,
get_device_count,
get_current_device,
get_available_backend,
setup_parallel,
only_main_procs,
spawn_parallel_fn
)
from khaosz.parallel.module import (
RowParallelLinear,
ColumnParallelLinear
)
__all__ = [
"get_world_size",
"get_rank",
"get_device_count",
"get_current_device",
"get_available_backend",
"setup_parallel",
"only_main_procs",
"spawn_parallel_fn",
"RowParallelLinear",
"ColumnParallelLinear"
]

View File

@@ -33,6 +33,17 @@ def get_available_backend():
else: else:
return "gloo" return "gloo"
def get_world_size() -> int:
if dist.is_available() and dist.is_initialized():
return dist.get_world_size()
else:
return 1
def get_rank() -> int:
if dist.is_available() and dist.is_initialized():
return dist.get_rank()
else:
return 0
@contextmanager @contextmanager
def setup_parallel( def setup_parallel(
@@ -76,6 +87,21 @@ def setup_parallel(
if dist.is_initialized(): if dist.is_initialized():
dist.destroy_process_group() dist.destroy_process_group()
@contextmanager
def only_main_procs(main_process_rank=0, block=True):
is_main_proc = (get_rank() == main_process_rank)
if dist.is_initialized() and block:
dist.barrier()
try:
yield is_main_proc
finally:
if dist.is_initialized() and block:
dist.barrier()
def wrapper_spawn_func(rank, world_size, func, kwargs_dict): def wrapper_spawn_func(rank, world_size, func, kwargs_dict):
with setup_parallel(rank, world_size): with setup_parallel(rank, world_size):
func(**kwargs_dict) func(**kwargs_dict)

View File

@@ -5,6 +5,7 @@ from torch.utils.data import DataLoader
from khaosz.config import Checkpoint from khaosz.config import Checkpoint
from khaosz.data import ResumableDistributedSampler from khaosz.data import ResumableDistributedSampler
from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory
from khaosz.parallel.utils import get_world_size, get_rank
if TYPE_CHECKING: if TYPE_CHECKING:
from khaosz.trainer.trainer import Trainer from khaosz.trainer.trainer import Trainer
@@ -20,6 +21,9 @@ class TrainContext:
batch_iter: int = field(default=0) batch_iter: int = field(default=0)
loss: float = field(default=0.0) loss: float = field(default=0.0)
wolrd_size: int = field(default=1)
rank: int = field(default=0)
def asdict(self) -> dict: def asdict(self) -> dict:
return {field.name: getattr(self, field.name) return {field.name: getattr(self, field.name)
for field in fields(self)} for field in fields(self)}
@@ -102,4 +106,9 @@ class TrainContextBuilder:
return self return self
def build(self) -> TrainContext: def build(self) -> TrainContext:
if self.trainer.train_config.nprocs > 1:
self._context.wolrd_size = get_world_size()
self._context.rank = get_rank()
return self._context return self._context