diff --git a/khaosz/config/train_config.py b/khaosz/config/train_config.py
index 28214da..49b0254 100644
--- a/khaosz/config/train_config.py
+++ b/khaosz/config/train_config.py
@@ -9,7 +9,6 @@ if TYPE_CHECKING:
 
 @dataclass
 class TrainConfig:
-
     strategy: "BaseStrategy" = field(
         default=None,
         metadata={"help": "Training strategy."}
@@ -54,6 +53,8 @@ class TrainConfig:
         default=1.0,
         metadata={"help": "Maximum gradient norm."}
     )
+
+    # dataloader settings
     random_seed: int = field(
         default=3407,
         metadata={"help": "Random seed."}
@@ -69,4 +70,10 @@ class TrainConfig:
     pin_memory: bool = field(
         default=False,
         metadata={"help": "Pin memory for dataloader."}
+    )
+
+    # distributed training
+    nprocs: int = field(
+        default=1,
+        metadata={"help": "Number of processes for distributed training."}
     )
\ No newline at end of file
diff --git a/khaosz/parallel/__init__.py b/khaosz/parallel/__init__.py
new file mode 100644
index 0000000..8bee6f2
--- /dev/null
+++ b/khaosz/parallel/__init__.py
@@ -0,0 +1,29 @@
+from khaosz.parallel.utils import (
+    get_world_size,
+    get_rank,
+    get_device_count,
+    get_current_device,
+    get_available_backend,
+    setup_parallel,
+    only_main_procs,
+    spawn_parallel_fn
+)
+
+from khaosz.parallel.module import (
+    RowParallelLinear,
+    ColumnParallelLinear
+)
+
+__all__ = [
+    "get_world_size",
+    "get_rank",
+    "get_device_count",
+    "get_current_device",
+    "get_available_backend",
+    "setup_parallel",
+    "only_main_procs",
+    "spawn_parallel_fn",
+
+    "RowParallelLinear",
+    "ColumnParallelLinear"
+]
diff --git a/khaosz/parallel/utils.py b/khaosz/parallel/utils.py
index bf54549..a18c21b 100644
--- a/khaosz/parallel/utils.py
+++ b/khaosz/parallel/utils.py
@@ -33,6 +33,17 @@ def get_available_backend():
     else:
         return "gloo"
 
+def get_world_size() -> int:
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_world_size()
+    else:
+        return 1
+
+def get_rank() -> int:
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_rank()
+    else:
+        return 0
 
 @contextmanager
 def setup_parallel(
@@ -76,6 +87,21 @@ def setup_parallel(
         if dist.is_initialized():
             dist.destroy_process_group()
 
+@contextmanager
+def only_main_procs(main_process_rank=0, block=True):
+    is_main_proc = (get_rank() == main_process_rank)
+
+    if dist.is_initialized() and block:
+        dist.barrier()
+
+    try:
+        yield is_main_proc
+
+    finally:
+        if dist.is_initialized() and block:
+            dist.barrier()
+
+
 def wrapper_spawn_func(rank, world_size, func, kwargs_dict):
     with setup_parallel(rank, world_size):
         func(**kwargs_dict)
diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py
index 77a8199..db36b5b 100644
--- a/khaosz/trainer/train_context.py
+++ b/khaosz/trainer/train_context.py
@@ -5,6 +5,7 @@ from torch.utils.data import DataLoader
 from khaosz.config import Checkpoint
 from khaosz.data import ResumableDistributedSampler
 from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory
+from khaosz.parallel.utils import get_world_size, get_rank
 
 if TYPE_CHECKING:
     from khaosz.trainer.trainer import Trainer
@@ -20,6 +21,9 @@ class TrainContext:
     batch_iter: int = field(default=0)
     loss: float = field(default=0.0)
 
+    world_size: int = field(default=1)
+    rank: int = field(default=0)
+
     def asdict(self) -> dict:
         return {field.name: getattr(self, field.name) for field in fields(self)}
 
@@ -102,4 +106,9 @@
         return self
 
     def build(self) -> TrainContext:
+
+        if self.trainer.train_config.nprocs > 1:
+            self._context.world_size = get_world_size()
+            self._context.rank = get_rank()
+
         return self._context
\ No newline at end of file
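
Usage sketch (not part of the patch): how the new helpers are expected to compose. Only the APIs added above are used; `spawn_parallel_fn` is exported by the new `khaosz/parallel/__init__.py` but its definition is not shown in this diff, so any call to it below is an assumption and left commented out.

    # usage_sketch.py -- illustration only, not part of the patch
    from khaosz.parallel import get_rank, get_world_size, only_main_procs

    def train_worker():
        # Inside a process spawned through wrapper_spawn_func, setup_parallel
        # has already initialized the process group, so these report the real
        # topology; in a plain single-process run they fall back to 0 and 1.
        rank, world_size = get_rank(), get_world_size()
        print(f"rank {rank} of {world_size}")

        # Do rank-0-only work (logging, checkpointing). With block=True the
        # barriers before and after the body keep the other ranks in step.
        with only_main_procs() as is_main:
            if is_main:
                print("saving checkpoint on the main process")

    if __name__ == "__main__":
        # Single-process fallback; with nprocs > 1 the trainer would instead
        # launch workers via spawn_parallel_fn (signature assumed, not shown).
        train_worker()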