AstrAI:一个轻量级 LLM 训练与推理框架的技术解析
ViperEkura Lv1

引言

在 LLM (大语言模型) 蓬勃发展的今天,如何高效地训练和部署模型成为每一个 AI 工程师面临的挑战。 AstrAI 是一个完全自研的轻量级 Transformer 训练与推理框架,旨在为开发者和研究团队提供高性能、易用的训练推理解决方案。

本文将深入解析 AstrAI 的架构设计、核心特性以及实际应用场景。


1. 项目概览

AstrAI 是一个支持多种训练范式的轻量级框架:

特性 说明
🚀 高性能 优化的训练与推理流程,支持高效并行
🔧 多训练范式 支持 SEQ / SFT / DPO / GRPO 训练
💡 易用 简洁 API设计,丰富示例
🤗 HF 兼容 类HF api 设计, 未来将支持更多模型

GitHub: ViperEkura/AstrAI


2. 核心架构

2.1 模块化设计

AstrAI 采用高度模块化的设计,主要包含以下核心模块:

1
2
3
4
5
6
7
8
9
astrai/
├── model/ # 模型定义 (Transformer, GQA, MLA)
├── trainer/ # 训练器与策略
├── dataset/ # 数据集加载
├── inference/ # 推理引擎与 Server
├── tokenize/ # 分词器与 Chat 模板
├── parallel/ # 分布式训练
├── config/ # 配置管理
└── serialization # 检查点保存/加载

2.2 核心模型组件

AstrAI 实现了两种主流的注意力机制:

GQA (Grouped Query Attention)

1
2
3
4
5
6
7
# 来自 astrai/model/module.py:141-231
class GQA(nn.Module):
def __init__(self, dim, n_heads, n_kv_heads, ...):
self.q_proj = Linear(dim, n_heads * self.head_dim)
self.k_proj = Linear(dim, n_kv_heads * self.head_dim)
self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
self.o_proj = Linear(dim, dim)

关键特性: - 支持 Query 分组机制,减少 KV 头数量 - 支持 QK Norm 归一化 - 支持门控注意力 (Gated Attention) - KV Cache 优化

数学原理: GQA 将 Query 头分成 ngroups = nheads/nKV 组,每组共享同一个 Key 和 Value:



其中 Q ∈ ℝbatch × seq × nheads × dK, V ∈ ℝbatch × seq × nKV × d,通过 repeat_kv 操作将 nKV 扩展到 nheads

MLA (Multi-Head Latent Attention)

1
2
3
4
5
6
# 来自 astrai/model/module.py:233-327
class MLA(nn.Module):
def __init__(self, dim, n_heads, n_kv_heads, kv_lora_rank, ...):
self.kv_a_proj = Linear(dim, kv_lora_rank)
self.kv_norm = RMSNorm(kv_lora_rank)
self.kv_b_proj = Linear(kv_lora_rank, n_heads * head_dim ...)

关键特性: - 压缩 KV 表示,降低显存占用 - 分离 NoPE 和 RoPE 位置编码 - LoRA 风格的 KV 压缩

数学原理: MLA 使用低秩分解压缩 Key 和 Value:


Kc = WKx,  Vc = WVx


Kdec = KcWKVT,  Vdec = VcWKVT

其中压缩后的维度为:dlora 远小于原始的 nKV × d。这使得 KV Cache 显存占用从:


𝒪(nlayers ⋅ nKV ⋅ d ⋅ L)

降低到:


𝒪(nlayers ⋅ dlora ⋅ L)

2.3 RoPE 位置编码

1
2
3
4
5
6
7
8
9
10
# 来自 astrai/model/module.py:29-48
def get_rotary_emb(
dim: int,
max_len: int,
base: float = 10000,
) -> Tuple[Tensor, Tensor]:
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim)
t = torch.arange(0, max_len, dtype=torch.float64)
freqs = torch.outer(t, theta)
return torch.cos(freqs).float(), torch.sin(freqs).float()

数学原理: RoPE 通过旋转位置编码实现位置感知,其核心公式为:



其中 θi = base − 2i/dd 为隐藏维度。应用旋转后的 Query 和 Key 注意力计算保持相对位置信息:




3. 训练系统

3.1 支持的训练范式

AstrAI 支持四种训练方式,通过工厂模式灵活扩展:

训练类型 用途 关键参数
SEQ 下一个 token 预测 label_smoothing
SFT 有监督微调 loss_mask
DPO 直接偏好优化 beta (默认 0.1)
GRPO 群体相对策略优化 clip_eps, kl_coef

3.2 训练策略实现

SEQ 策略 (下一个 token 预测)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 来自 astrai/trainer/strategy.py:156-178
@StrategyFactory.register("seq")
class SEQStrategy(BaseStrategy):
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
super().__init__(model, device, **kwargs)
self.label_smoothing = label_smoothing

def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
logits = self.model(input_ids=input_ids)["logits"]
loss = F.cross_entropy(
input=logits.flatten(0, 1).float(),
target=target_ids.flatten(),
label_smoothing=self.label_smoothing,
)
return loss

数学原理: SEQ 使用标准的交叉熵损失进行下一个 token 预测:



当启用 label smoothing (默认为 0) 时,损失函数变为:



其中 α 为平滑系数,V 为词表大小。

SFT 策略 (带 Mask 的监督微调)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# 来自 astrai/trainer/strategy.py:181-211
@StrategyFactory.register("sft")
class SFTStrategy(BaseStrategy):
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
super().__init__(model, device, **kwargs)
self.label_smoothing = label_smoothing

def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
input_ids, target_ids, loss_mask = (
batch["input_ids"],
batch["target_ids"],
batch["loss_mask"],
)
ignore_index = -100
logits = self.model(input_ids=input_ids)["logits"]
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
loss = F.cross_entropy(
input=logits.flatten(0, 1).float(),
target=target_ids.flatten(),
ignore_index=ignore_index,
label_smoothing=self.label_smoothing,
)
return loss

数学原理: SFT 只对 loss_mask 为 True 的 token 计算损失:



其中 是需要计算损失的位置集合。通过 mask 机制,可以仅训练特定 token(如 assistant 回复部分),避免学习到 prompt 或用户输入。

DPO 策略 (直接偏好优化)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# 来自 astrai/trainer/strategy.py:214-261
@StrategyFactory.register("dpo")
class DPOStrategy(BaseStrategy):
def __init__(
self,
model: nn.Module,
device: str,
beta: float = 0.1,
reduction: str = "mean",
**kwargs,
):
super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model(model)
self.beta = beta
self.reduction = reduction

def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
chosen_mask, rejected_mask = batch["chosen_mask"], batch["rejected_mask"]
concat_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
concat_mask = torch.cat([chosen_mask, rejected_mask], dim=0)

log_pi = get_logprobs(self.model, concat_ids, concat_mask, self.reduction)

with torch.no_grad():
log_ref = get_logprobs(
self.ref_model, concat_ids, concat_mask, self.reduction
)

log_pi_chosen = log_pi[: chosen_ids.shape[0]]
log_pi_rejected = log_pi[chosen_ids.shape[0] :]
log_ref_chosen = log_ref[: chosen_ids.shape[0]]
log_ref_rejected = log_ref[chosen_ids.shape[0] :]

pi_log_ratio = log_pi_chosen - log_pi_rejected
ref_log_ratio = log_ref_chosen - log_ref_rejected

ratio_diff = pi_log_ratio - ref_log_ratio
dpo_loss = -F.logsigmoid(self.beta * ratio_diff).mean()
return dpo_loss

数学原理: DPO (Direct Preference Optimization) 的目标函数直接最大化偏好数据的 log-likelihood:



其中 σ 是 sigmoid 函数,β 是温度参数(代码中默认为 0.1),yw 是偏好选项,yl 是被拒绝选项。代码实现中:

1
2
# ratio_diff = (log_pi_chosen - log_pi_rejected) - (log_ref_chosen - log_ref_rejected)
dpo_loss = -F.logsigmoid(self.beta * ratio_diff).mean()

3.3 数据集设计

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 来自 astrai/dataset/dataset.py:261-282
@DatasetFactory.register("sft")
class SFTDataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)

def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)

def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index)
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
dtype=torch.long
)
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
dtype=torch.bool
)
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}

特点: - 多Segment 异步加载 - HDF5 高效存储 - 窗口滑动采样

窗口滑动采样数学原理: 数据集使用滑动窗口生成训练样本:


begini = min (i ⋅ stride, total − 1 − window_size)

endi = min (begini + window_size, total − 1)

采样数量为:



这种采样方式确保数据覆盖最大化,同时保持样本间的时间连续性。


4. 推理系统

4.1 推理引擎

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# 来自 astrai/inference/engine.py:106-150
class InferenceEngine:
def __init__(
self,
model: nn.Module,
tokenizer: AutoTokenizer,
max_batch_size: int = 1,
max_seq_len: Optional[int] = None,
):
self.model = model
self.tokenizer = tokenizer

# Get device and dtype from model parameters
try:
first_param = next(model.parameters())
device = first_param.device
dtype = first_param.dtype
except StopIteration:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

self.scheduler = InferenceScheduler(
model=self.model,
tokenizer=self.tokenizer,
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
device=device,
dtype=dtype,
)

self.kv_cache = self.scheduler.kv_cache
self.seq_mask = self.scheduler.seq_mask
self.scheduler.start()

核心能力: - 连续批处理 (Continuous Batching) - 流式输出 - KV Cache 管理

连续批处理数学原理: 传统静态批处理需要等待整个批次完成才返回结果,而连续批处理允许在批次中动态添加新请求:



其中 Ti 是第 i 个请求的 token 数,tprefill 是预填充时间。通过持续批处理,GPU 利用率可提升 2-3 倍。

KV Cache 优化将显存占用从 𝒪(B ⋅ L ⋅ d) 降低到 𝒪(nlayers ⋅ nKV ⋅ L ⋅ d),但通过 GQA/MLA 进一步优化后可降至 𝒪(nlayers ⋅ dlora ⋅ L)


𝒪(nlayers ⋅ dlora ⋅ L)

其中 B 是 batch size,L 是序列长度。

4.2 FastAPI Server

AstrAI 提供 OpenAI 兼容的 API 接口:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# 来自 astrai/inference/server.py:226-286
@app.post("/v1/chat/completions", response_model=CompletionResponse)
async def chat_completion(request: ChatCompletionRequest):
"""OpenAI-compatible chat completion endpoint."""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")

# Convert messages to prompt using engine's tokenizer
prompt = convert_messages_to_prompt(request.messages, engine=_engine)

if request.stream:
# Streaming response (use synchronous generator)
generator = _engine.generate(
prompt=prompt,
stream=True,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)

def generate_stream():
for token in generator:
if token == "[DONE]":
break
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
yield "data: [DONE]\n\n"

return StreamingResponse(
generate_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
else:
# Non-streaming response
result = _engine.generate(
prompt=prompt,
stream=False,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)

# Build OpenAI-style response
import time
resp = CompletionResponse(
id=f"chatcmpl-{int(time.time())}",
created=int(time.time()),
choices=[
{
"index": 0,
"message": {"role": "assistant", "content": result},
"finish_reason": "stop",
}
],
)
return resp

API 端点: - /health - 健康检查 - /stats - 统计信息 - /v1/chat/completions - OpenAI 兼容聊天接口 - /generate - 简单生成接口


5. 分布式训练

5.1 并行设置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# 来自 astrai/parallel/setup.py:127-170
def spawn_parallel_fn(
func: Callable,
world_size: int,
backend: str = "nccl",
master_addr: str = "localhost",
master_port: str = "29500",
device_type: str = "cuda",
device_ids: Optional[List[int]] = None,
**kwargs,
):
# clear environment variables
for key in [
"MASTER_ADDR",
"MASTER_PORT",
"RANK",
"WORLD_SIZE",
"LOCAL_RANK",
"LOCAL_DEVICE",
]:
if key in os.environ:
del os.environ[key]

if world_size == 1:
device_ids = device_ids or [0]
device_id = torch.device(device_type, device_ids[0])
os.environ["LOCAL_DEVICE"] = str(device_id)
func(**kwargs)
return

wrapper_spawn_func_args = (
world_size,
backend,
master_addr,
master_port,
device_type,
device_ids,
func,
kwargs,
)

mp.spawn(
wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True
)

支持: - NCCL 后端 (GPU) - GLOO 后端 (CPU) - CCL 后端 (Intel GPU)

分布式训练通信复杂度: 在 DDP 训练中,每个 GPU 计算梯度后需要 All-Reduce 同步:



其中 n 是 world_size。通过梯度压缩和异步通信可以进一步减少通信开销。

5.2 检查点管理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# 来自 astrai/serialization.py:51-106
class Checkpoint:
def __init__(
self,
state_dict: Dict[str, Any],
epoch: int = 0,
iteration: int = 0,
):
self.state_dict = state_dict
self.epoch = epoch
self.iteration = iteration

def save(self, save_dir: str) -> None:
save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True)

rank = get_rank()
if rank == 0:
meta = {
"epoch": self.epoch,
"iteration": self.iteration,
}
with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2)

st.save_file(self.state_dict, save_path / "state_dict.safetensors")

@classmethod
def load(cls, save_dir: str) -> "Checkpoint":
rank = get_rank()
save_path = Path(save_dir)

meta = {}
if rank == 0:
with open(Path(save_dir) / "meta.json", "r") as f:
meta = json.load(f)

if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]

state_dict = st.load_file(save_path / "state_dict.safetensors")

return cls(
state_dict=state_dict,
epoch=meta["epoch"],
iteration=meta["iteration"],
)

分布式检查点保存原理: - 仅 rank 0 进程负责写入文件系统,避免竞争 - 元数据 (epoch, iteration) 通过 dist.broadcast_object_list 广播到所有进程 - 使用 safetensors 格式支持内存映射,实现零拷贝加载


6. 快速开始

6.1 安装

1
2
3
git clone https://github.com/ViperEkura/AstrAI.git
cd AstrAI
pip install -e .

6.2 训练模型

1
2
3
4
python scripts/tools/train.py \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/params

6.3 启动推理服务

1
python scripts/tools/server.py --param_path=/path/to/params

7. 与同类框架对比

特性 AstrAI vLLM TGI
训练支持 ✅ SEQ/SFT/DPO/GRPO ❌ 仅推理 ❌ 仅推理
推理支持 ✅ TP/PP ✅ TP/PP
分布式训练 ✅ DDP
连续批处理
依赖简洁 ✅ 仅 PyTorch

8. 未来规划

  • 完善 API 文档
  • 支持更多模型架构 (Yi, DeepSeek, Mistral)
  • 与 LangChain/LlamaIndex 集成
  • 性能优化与基准测试
  • 2.0 大版本发布

9. 总结

AstrAI 是一个完全自研的轻量级 Transformer 框架,具有以下优势:

  1. 模块化设计 - 清晰的分层架构,易于扩展
  2. 多样化的训练支持 - SEQ/SFT/DPO/GRPO 一站式解决方案
  3. 高性能推理 - 连续批处理 + KV Cache 优化
  4. 简洁易用 - 友好的 API 与丰富的示例
  5. HF 兼容 - 无缝对接 HuggingFace 生态

欢迎 Star、Fork 和贡献!

参考链接:

 REWARD AUTHOR