feat(paralell): 添加分布式训练配置与并行工具支持
This commit is contained in:
parent
d31137a2db
commit
d52685facd
|
|
@ -9,7 +9,6 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainConfig:
|
class TrainConfig:
|
||||||
|
|
||||||
strategy: "BaseStrategy" = field(
|
strategy: "BaseStrategy" = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Training strategy."}
|
metadata={"help": "Training strategy."}
|
||||||
|
|
@ -54,6 +53,8 @@ class TrainConfig:
|
||||||
default=1.0,
|
default=1.0,
|
||||||
metadata={"help": "Maximum gradient norm."}
|
metadata={"help": "Maximum gradient norm."}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# dataloader setting
|
||||||
random_seed: int = field(
|
random_seed: int = field(
|
||||||
default=3407,
|
default=3407,
|
||||||
metadata={"help": "Random seed."}
|
metadata={"help": "Random seed."}
|
||||||
|
|
@ -70,3 +71,9 @@ class TrainConfig:
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Pin memory for dataloader."}
|
metadata={"help": "Pin memory for dataloader."}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# distributed training
|
||||||
|
nprocs: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Number of processes for distributed training."}
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,29 @@
|
||||||
|
from khaosz.parallel.utils import (
|
||||||
|
get_world_size,
|
||||||
|
get_rank,
|
||||||
|
get_device_count,
|
||||||
|
get_current_device,
|
||||||
|
get_available_backend,
|
||||||
|
setup_parallel,
|
||||||
|
only_main_procs,
|
||||||
|
spawn_parallel_fn
|
||||||
|
)
|
||||||
|
|
||||||
|
from khaosz.parallel.module import (
|
||||||
|
RowParallelLinear,
|
||||||
|
ColumnParallelLinear
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_world_size",
|
||||||
|
"get_rank",
|
||||||
|
"get_device_count",
|
||||||
|
"get_current_device",
|
||||||
|
"get_available_backend",
|
||||||
|
"setup_parallel",
|
||||||
|
"only_main_procs",
|
||||||
|
"spawn_parallel_fn",
|
||||||
|
|
||||||
|
"RowParallelLinear",
|
||||||
|
"ColumnParallelLinear"
|
||||||
|
]
|
||||||
|
|
@ -33,6 +33,17 @@ def get_available_backend():
|
||||||
else:
|
else:
|
||||||
return "gloo"
|
return "gloo"
|
||||||
|
|
||||||
|
def get_world_size() -> int:
|
||||||
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
return dist.get_world_size()
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def get_rank() -> int:
|
||||||
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
return dist.get_rank()
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def setup_parallel(
|
def setup_parallel(
|
||||||
|
|
@ -76,6 +87,21 @@ def setup_parallel(
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def only_main_procs(main_process_rank=0, block=True):
|
||||||
|
is_main_proc = (get_rank() == main_process_rank)
|
||||||
|
|
||||||
|
if dist.is_initialized() and block:
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield is_main_proc
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if dist.is_initialized() and block:
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
|
||||||
def wrapper_spawn_func(rank, world_size, func, kwargs_dict):
|
def wrapper_spawn_func(rank, world_size, func, kwargs_dict):
|
||||||
with setup_parallel(rank, world_size):
|
with setup_parallel(rank, world_size):
|
||||||
func(**kwargs_dict)
|
func(**kwargs_dict)
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from torch.utils.data import DataLoader
|
||||||
from khaosz.config import Checkpoint
|
from khaosz.config import Checkpoint
|
||||||
from khaosz.data import ResumableDistributedSampler
|
from khaosz.data import ResumableDistributedSampler
|
||||||
from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory
|
from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory
|
||||||
|
from khaosz.parallel.utils import get_world_size, get_rank
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from khaosz.trainer.trainer import Trainer
|
from khaosz.trainer.trainer import Trainer
|
||||||
|
|
@ -20,6 +21,9 @@ class TrainContext:
|
||||||
batch_iter: int = field(default=0)
|
batch_iter: int = field(default=0)
|
||||||
loss: float = field(default=0.0)
|
loss: float = field(default=0.0)
|
||||||
|
|
||||||
|
wolrd_size: int = field(default=1)
|
||||||
|
rank: int = field(default=0)
|
||||||
|
|
||||||
def asdict(self) -> dict:
|
def asdict(self) -> dict:
|
||||||
return {field.name: getattr(self, field.name)
|
return {field.name: getattr(self, field.name)
|
||||||
for field in fields(self)}
|
for field in fields(self)}
|
||||||
|
|
@ -102,4 +106,9 @@ class TrainContextBuilder:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def build(self) -> TrainContext:
|
def build(self) -> TrainContext:
|
||||||
|
|
||||||
|
if self.trainer.train_config.nprocs > 1:
|
||||||
|
self._context.wolrd_size = get_world_size()
|
||||||
|
self._context.rank = get_rank()
|
||||||
|
|
||||||
return self._context
|
return self._context
|
||||||
Loading…
Reference in New Issue