22 lines
432 B
Python
22 lines
432 B
Python
from astrai.parallel.setup import (
|
|
get_world_size,
|
|
get_rank,
|
|
get_current_device,
|
|
only_on_rank,
|
|
setup_parallel,
|
|
spawn_parallel_fn,
|
|
)
|
|
|
|
from astrai.parallel.module import RowParallelLinear, ColumnParallelLinear
|
|
|
|
__all__ = [
|
|
"get_world_size",
|
|
"get_rank",
|
|
"get_current_device",
|
|
"only_on_rank",
|
|
"setup_parallel",
|
|
"spawn_parallel_fn",
|
|
"RowParallelLinear",
|
|
"ColumnParallelLinear",
|
|
]
|