AstrAI/tests/test_parallel.py

39 lines
812 B
Python

import torch
import torch.distributed as dist
from khaosz.parallel import (
get_rank,
only_on_rank,
spawn_parallel_fn
)
@only_on_rank(0)
def _test_only_on_rank_helper():
return True
def only_on_rank():
result = _test_only_on_rank_helper()
if get_rank() == 0:
assert result is True
else:
assert result is None
def all_reduce():
x = torch.tensor([get_rank()], dtype=torch.int)
dist.all_reduce(x, op=dist.ReduceOp.SUM)
expected_sum = sum(range(dist.get_world_size()))
assert x.item() == expected_sum
def test_spawn_only_on_rank():
spawn_parallel_fn(
only_on_rank,
world_size=2,
backend="gloo"
)
def test_spawn_all_reduce():
spawn_parallel_fn(
all_reduce,
world_size=2,
backend="gloo"
)