AstrAI/astrai/model/__init__.py

23 lines
341 B
Python

from astrai.model.module import (
GQA,
MLP,
DecoderBlock,
Linear,
RMSNorm,
)
from astrai.model.transformer import Transformer
from astrai.model.automodel import AutoModel
__all__ = [
# Modules
"Linear",
"RMSNorm",
"MLP",
"GQA",
"DecoderBlock",
# Models
"Transformer",
"AutoModel",
]