AstrAI/astrai/parallel/__init__.py

21 lines
431 B
Python

from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
from astrai.parallel.setup import (
get_current_device,
get_rank,
get_world_size,
only_on_rank,
setup_parallel,
spawn_parallel_fn,
)
__all__ = [
"get_world_size",
"get_rank",
"get_current_device",
"only_on_rank",
"setup_parallel",
"spawn_parallel_fn",
"RowParallelLinear",
"ColumnParallelLinear",
]