feat(parallel): add distributed training config and parallel utility support
parent d31137a2db
commit d52685facd
@@ -9,7 +9,6 @@ if TYPE_CHECKING:
@dataclass
class TrainConfig:

    strategy: "BaseStrategy" = field(
        default=None,
        metadata={"help": "Training strategy."}
@@ -54,6 +53,8 @@ class TrainConfig:
        default=1.0,
        metadata={"help": "Maximum gradient norm."}
    )

    # dataloader setting
    random_seed: int = field(
        default=3407,
        metadata={"help": "Random seed."}
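The field(default=..., metadata={"help": ...}) pattern used throughout TrainConfig makes the help strings machine-readable. A minimal sketch of how they could drive a CLI, assuming only the dataclass pattern shown above; the argparse wiring itself is illustrative, not part of this commit:

# Illustrative sketch: harvest each field's metadata "help" string to
# build an argparse CLI for TrainConfig. Not part of this commit.
import argparse
from dataclasses import fields, MISSING

def build_parser(config_cls) -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    for f in fields(config_cls):
        # fields without a plain default (e.g. default_factory) fall back to None
        default = None if f.default is MISSING else f.default
        parser.add_argument(
            f"--{f.name}",
            default=default,
            help=f.metadata.get("help", ""),
        )
    return parser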
@@ -70,3 +71,9 @@ class TrainConfig:
        default=False,
        metadata={"help": "Pin memory for dataloader."}
    )

    # distributed training
    nprocs: int = field(
        default=1,
        metadata={"help": "Number of processes for distributed training."}
    )
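The new nprocs field is what later gates multi-process behavior (see the TrainContextBuilder hunk below). A hedged sketch of how a launcher might branch on it; spawn_parallel_fn's exact signature is not shown in this diff, so the call below is an assumption:

# Hedged launcher sketch. spawn_parallel_fn's signature is assumed
# from wrapper_spawn_func further down; train_fn is a placeholder.
from khaosz.parallel import spawn_parallel_fn

def launch(train_config, train_fn, **train_kwargs):
    if train_config.nprocs > 1:
        # one process per rank, each entering setup_parallel()
        spawn_parallel_fn(train_fn, train_config.nprocs, **train_kwargs)
    else:
        train_fn(**train_kwargs)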
@@ -0,0 +1,29 @@
from khaosz.parallel.utils import (
    get_world_size,
    get_rank,
    get_device_count,
    get_current_device,
    get_available_backend,
    setup_parallel,
    only_main_procs,
    spawn_parallel_fn
)

from khaosz.parallel.module import (
    RowParallelLinear,
    ColumnParallelLinear
)

__all__ = [
    "get_world_size",
    "get_rank",
    "get_device_count",
    "get_current_device",
    "get_available_backend",
    "setup_parallel",
    "only_main_procs",
    "spawn_parallel_fn",

    "RowParallelLinear",
    "ColumnParallelLinear"
]
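With the package __init__ above, callers can import everything from khaosz.parallel directly. A small usage sketch; the logging helper and checkpoint path are illustrative assumptions, while only_main_procs behaves as the context-manager hunk below shows:

# Usage sketch for the re-exported API.
import torch
from khaosz.parallel import get_rank, get_world_size, only_main_procs

def log_step(loss: float):
    print(f"[rank {get_rank()}/{get_world_size()}] loss={loss:.4f}")

def save_checkpoint(model, path="ckpt.pt"):
    # barriers on entry/exit keep other ranks in sync while rank 0 saves
    with only_main_procs() as is_main:
        if is_main:
            torch.save(model.state_dict(), path)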
@@ -33,6 +33,17 @@ def get_available_backend():
    else:
        return "gloo"

def get_world_size() -> int:
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    else:
        return 1

def get_rank() -> int:
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    else:
        return 0

@contextmanager
def setup_parallel(
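Only the gloo fallback of get_available_backend is visible in this hunk; the branch above it presumably selects a GPU backend. A hedged reconstruction for context (the nccl branch is a guess, not shown in the diff):

# Assumed shape of get_available_backend(); only the gloo fallback
# appears in the hunk above, the CUDA/nccl branch is an assumption.
import torch

def get_available_backend() -> str:
    if torch.cuda.is_available():
        return "nccl"
    else:
        return "gloo"

Note that get_world_size() and get_rank() degrade gracefully to 1 and 0 when torch.distributed is unavailable or uninitialized, so the same code path runs unchanged in single-process mode without extra guards.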
@@ -76,6 +87,21 @@ def setup_parallel(
    if dist.is_initialized():
        dist.destroy_process_group()

@contextmanager
def only_main_procs(main_process_rank=0, block=True):
    is_main_proc = (get_rank() == main_process_rank)

    if dist.is_initialized() and block:
        dist.barrier()

    try:
        yield is_main_proc

    finally:
        if dist.is_initialized() and block:
            dist.barrier()


def wrapper_spawn_func(rank, world_size, func, kwargs_dict):
    with setup_parallel(rank, world_size):
        func(**kwargs_dict)
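wrapper_spawn_func has the (rank, *args) shape that torch.multiprocessing.spawn expects, since spawn passes the process index as the first positional argument. A sketch of how spawn_parallel_fn might wire it up; its real signature is not in this diff:

# Hedged sketch of spawn_parallel_fn. mp.spawn invokes
# wrapper_spawn_func(rank, world_size, func, kwargs) once per process,
# and each process enters setup_parallel(rank, world_size).
import torch.multiprocessing as mp

def spawn_parallel_fn(func, world_size, **kwargs):
    mp.spawn(
        wrapper_spawn_func,
        args=(world_size, func, kwargs),
        nprocs=world_size,
        join=True,
    )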
@@ -5,6 +5,7 @@ from torch.utils.data import DataLoader
from khaosz.config import Checkpoint
from khaosz.data import ResumableDistributedSampler
from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory
from khaosz.parallel.utils import get_world_size, get_rank

if TYPE_CHECKING:
    from khaosz.trainer.trainer import Trainer
@@ -20,6 +21,9 @@ class TrainContext:
    batch_iter: int = field(default=0)
    loss: float = field(default=0.0)

    world_size: int = field(default=1)
    rank: int = field(default=0)

    def asdict(self) -> dict:
        return {field.name: getattr(self, field.name)
                for field in fields(self)}
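Because asdict() is built from dataclasses.fields, every field added to TrainContext (including the new world_size and rank) shows up in the serialized dict automatically, which is convenient for logging or checkpoint metadata. A sketch, assuming all fields carry defaults like the ones shown:

# Sketch: fields of TrainContext beyond those shown above are omitted
# and assumed to have defaults.
ctx = TrainContext()
assert ctx.asdict()["world_size"] == 1
assert ctx.asdict()["rank"] == 0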
@@ -102,4 +106,9 @@ class TrainContextBuilder:
        return self

    def build(self) -> TrainContext:

        if self.trainer.train_config.nprocs > 1:
            self._context.world_size = get_world_size()
            self._context.rank = get_rank()

        return self._context
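build() only queries the process group when nprocs > 1, so single-process runs keep the defaults (world_size=1, rank=0) without ever touching torch.distributed. A hedged usage sketch; the builder's constructor and any setter chain are assumptions beyond the lines above:

# Hedged sketch: TrainContextBuilder's constructor is assumed to take
# the trainer whose train_config.nprocs gates the branch above.
context = TrainContextBuilder(trainer).build()
if context.rank == 0:
    print(f"training with {context.world_size} process(es)")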