From d5cc9f065d7cf93f6cca557184baf1ae3bdbdf07 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sun, 30 Nov 2025 16:44:04 +0800
Subject: [PATCH] =?UTF-8?q?feat(khaosz/parallel):=20=E6=B7=BB=E5=8A=A0?=
 =?UTF-8?q?=E5=B9=B6=E8=A1=8C=E8=AE=AD=E7=BB=83=E8=AE=BE=E7=BD=AE=E5=8A=9F?=
 =?UTF-8?q?=E8=83=BD?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/parallel/setup_parallel.py | 74 +++++++++++++++++++++++++++++++
 1 file changed, 74 insertions(+)
 create mode 100644 khaosz/parallel/setup_parallel.py

diff --git a/khaosz/parallel/setup_parallel.py b/khaosz/parallel/setup_parallel.py
new file mode 100644
index 0000000..2749337
--- /dev/null
+++ b/khaosz/parallel/setup_parallel.py
@@ -0,0 +1,74 @@
+import os
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from contextlib import contextmanager
+
+
+@contextmanager
+def setup_parallel(
+    rank: int = 0,
+    world_size: int = 1,
+    master_addr: str = "localhost",
+    master_port: str = "29500"
+):
+    """Initialize torch.distributed for this process; tear it down on exit.
+
+    Yields the default process group (``dist.group.WORLD``).
+    """
+    # Reuse an already-initialized process group and RETURN: we must not
+    # re-initialize it, nor destroy a group this context manager does not
+    # own in the finally block below.
+    if dist.is_available() and dist.is_initialized():
+        yield dist.group.WORLD
+        return
+
+    os.environ['MASTER_ADDR'] = master_addr
+    os.environ['MASTER_PORT'] = master_port
+    os.environ['RANK'] = str(rank)
+    os.environ['WORLD_SIZE'] = str(world_size)
+    os.environ['LOCAL_RANK'] = str(rank)
+
+    dist.init_process_group(
+        backend="nccl" if torch.cuda.is_available() else "gloo",
+        init_method="env://",
+        rank=rank,
+        world_size=world_size
+    )
+    # Only bind a CUDA device when CUDA exists; on the CPU/gloo path
+    # an unconditional set_device would raise.
+    if torch.cuda.is_available():
+        torch.cuda.set_device(rank)
+
+    try:
+        yield dist.group.WORLD
+    finally:
+        dist.destroy_process_group()
+
+
+def wrapper_func(rank, world_size, func, config_pack):
+    """Per-process entry point: set up the process group, then run func."""
+    with setup_parallel(rank, world_size):
+        func(**config_pack)
+
+
+def spawn_parallel_fn(func, world_size, kwargs_dict):
+    """Spawn world_size processes, each running func(**kwargs_dict)."""
+    # mp.spawn prepends the process rank as the first argument.
+    mp.spawn(
+        wrapper_func,
+        nprocs=world_size,
+        args=(world_size, func, kwargs_dict,),
+        join=True
+    )
+
+
+def get_current_device() -> torch.device:
+    """Return the preferred accelerator device for the current process."""
+    if torch.cuda.is_available():
+        return torch.device(f"cuda:{torch.cuda.current_device()}")
+    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
+        return torch.device(f"xpu:{torch.xpu.current_device()}")
+    # torch.backends.mps.is_available() is the documented MPS check.
+    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+        return torch.device("mps")
+    else:
+        return torch.device("cpu")
\ No newline at end of file