AstrAI:一个轻量级 LLM 训练与推理框架的技术解析
ViperEkura Lv1

引言

AstrAI 是一个完全自研的轻量级 Transformer 训练与推理框架,仅依赖 PyTorch,提供从预训练到推理服务的一站式解决方案。其 1B 参数的中英双语模型已开源在 HuggingFace。本文从模型架构、训练系统、推理引擎到分布式部署,进行完整的技术解析。


1. 项目概览

特性 说明
🚀 高性能 GQA/MLA + Paged KV Cache + 连续批处理
🔧 多训练范式 SEQ / SFT / DPO / GRPO 一站式支持
📦 轻量 仅依赖 PyTorch,无其他重型框架
🤗 HF 风格 AutoModel.from_pretrained / AutoTokenizer
🔌 双 API OpenAI / Anthropic 聊天接口兼容
🔬 研究友好 模块化设计 + 工厂模式 + MoE 支持

2. 系统架构

2.1 工厂注册表模式

核心设计是 BaseFactory + Registry,组件通过装饰器动态注册:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class BaseFactory[T]:
_registry: Registry

@classmethod
def register(cls, name: str):
def decorator(component_cls):
cls._registry.register(name, component_cls)
return component_cls
return decorator

@classmethod
def create(cls, name: str, *args, **kwargs) -> T:
# 自动过滤 kwargs 匹配 __init__ 签名
component_cls = cls._registry.get(name)
sig = inspect.signature(component_cls.__init__)
filtered = {k: v for k, v in kwargs.items() if k in sig.parameters}
return component_cls(*args, **filtered)

用法:

1
2
3
@AttnFactory.register("gqa")   class GQA(nn.Module): ...
@FFNFactory.register("moe") class DeepSeekMoE(nn.Module): ...
@StrategyFactory.register("dpo") class DPOStrategy(BaseStrategy): ...

框架通过 9 个工厂类串联:AttnFactoryFFNFactoryStrategyFactoryDatasetFactoryStorageFactorySchedulerFactoryCallbackFactoryConfigFactoryAutoModel。扩展新组件只需新建文件 + 装饰器注册,零侵入。

2.2 模块目录

1
2
3
4
5
6
7
8
9
10
11
12
astrai/
├── config/ # BaseConfig, TrainConfig
├── model/ # AutoRegressiveLM, EmbeddingEncoder
│ └── components/ # GQA, MLA, MLP, DeepSeekMoE, RoPE, RMSNorm
├── trainer/ # Trainer, Strategy, Callback, Scheduler
├── dataset/ # H5/JSON 存储,滑动窗口采样
├── inference/ # KVCache, 连续批处理, FastAPI Server
│ ├── core/ # KVCache, Executor, Scheduler, Task
│ └── api/ # OpenAI / Anthropic Protocol Handler
├── tokenize/ # AutoTokenizer, ChatTemplate
├── parallel/ # DDP / FSDP / Executor 模式
└── serialization.py # safetensors 检查点

3. 核心模型组件

3.1 模型加载

AutoModel.from_pretrained(path) 读取 config.jsonConfigFactory.load()model_type 分发 → 加载 model.safetensors_disable_random_init 上下文在加载权重时将 nn.init.* 临时替换为 no-op,避免预训练权重被覆盖。

3.2 AutoRegressiveLM

Decoder-only 自回归模型,Embedding → N × DecoderBlock → RMSNorm → LM Head:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
@AutoModel.register("autoregressive_lm")
class AutoRegressiveLM(AutoModel):
def __init__(self, config):
super().__init__(config)
self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
self.layers = nn.ModuleList([
DecoderBlock(config, layer_id) for layer_id in range(config.n_layers)
])
self.norm = RMSNorm(config.dim, config.norm_eps)
self.lm_head = Linear(config.dim, config.vocab_size)
if config.tie_weight:
self.lm_head.weight = self.embed_tokens.weight

def forward(self, input_ids, input_mask=None, paged_cache=None, position_ids=None):
x = self.embed_tokens(input_ids)
rotary_emb = self.rotary_embedding(x, position_ids)
for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, paged_cache)
hidden = self.norm(x)
return {"logits": self.lm_head(hidden), "hidden_states": hidden}

训练时 position_ids=None → RoPE 自动构建 [0..seq_len),SDPA 使用 is_causal=True。推理时由 Executor 传入 position_ids 并绑定 Paged KV Cache。

3.3 EmbeddingEncoder

@AutoModel.register("embedding") — 去掉 lm_head 的编码器版本,支持 cls / mean / last 三种 pooling 策略和 L2 归一化。load_state_dict 自动 pop lm_head.weight 从而兼容 transformer 检查点。

3.4 注意力机制

GQA(Grouped Query Attention): 将 Query 头分组,每组共享同一组 KV 头:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
@AttnFactory.register("gqa")
class GQA(nn.Module):
def __init__(self, dim, n_heads, n_kv_heads, use_qk_norm, ...):
self.head_dim = dim // n_heads
self.n_rep = n_heads // n_kv_heads
self.q_proj = Linear(dim, n_heads * self.head_dim)
self.k_proj = Linear(dim, n_kv_heads * self.head_dim) # 更少 KV 头
self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
self.o_proj = Linear(dim, dim)

def forward(self, x, rotary_emb, attn_mask, paged_cache):
q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)
if paged_cache is not None:
paged_cache.write(layer_id, k, v)
k, v = paged_cache.gather(layer_id)
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
out = F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
return self.o_proj(out)

repeat_kv 通过 expand 实现零拷贝广播。支持可选 QK Norm(RoPE 之后应用)和门控注意力(sigmoid(gate) × output)。

MLA(Multi-head Latent Attention): DeepSeek-V2 提出的低秩压缩注意力。KV 先通过 kv_a_proj 压缩到 kv_lora_rank 维,再通过 kv_b_proj 展开。Q 和 K 分离为 nope(无 RoPE)和 rope(有 RoPE)两部分。KV Cache 显存从 O(d) 降至 O(d_lora)。

3.5 FFN

MLP(SwiGLU): down(SiLU(gate(x)) × up(x)),标准门控激活。

DeepSeekMoE: 共享专家(所有 token 激活,输出取平均)+ 路由专家(Top-K 门控)。Router 是线性层 + softmax,每个 token 激活 K 个专家并按权重加权求和。

3.6 RoPE

复数乘法实现旋转位置编码。freqs_cis 作为 buffer 预先计算,apply_rotary_emb 将 Q/K 转为复数后相乘。


4. 训练系统

4.1 训练循环

Trainer 的 callback 驱动循环,梯度累积由 BaseExecutor 统一管理:

1
2
3
4
5
6
7
8
9
10
11
12
13
class Trainer:
def _trainer_loop(self, checkpoint=None):
context = TrainContextBuilder(self.train_config).build()
executor = context.executor
for epoch in range(context.epoch, context.config.n_epoch):
for batch in context.dataloader:
with executor.accumulate(context.model):
loss = context.strategy(batch)
executor.backward(loss / executor.grad_accum_steps)
if executor.sync_gradients:
context.optimizer.step()
context.optimizer.zero_grad()
context.scheduler.step()

AccumOptimizer / AccumScheduler 包装器仅在 sync_gradients=True 时执行 step(),训练代码无需手动判断累积边界。

4.2 训练策略

策略 损失函数 应用场景
SEQ 标准交叉熵 预训练
SFT 带 ignore_index=-100 的掩码交叉熵 有监督微调
DPO β × log(π_θ/π_ref) 偏好对比 偏好对齐
GRPO PPO + 组内归一化优势 + KL 惩罚 在线强化学习

DPO 和 GRPO 均使用冻结的参考模型,GRPO 每 sync_interval 步同步一次。

4.3 默认 Callback

顺序:gradient_checkpointing → checkpoint → metric_logger → progress_bar → gradient_clipping → validation。所有 callback 通过 CallbackFactory 注册。

4.4 LR Scheduler

类型 公式 特点
Cosine 线性 warmup + 余弦衰减到 min_rate 标准
SGDR 余弦退火 + 带 t_mult=2 的热重启 跳出局部最优

5. 推理系统

5.1 Paged KV Cache

六类协作:

  • Allocator: 位掩码空闲页管理 + 引用计数 + LRU 淘汰
  • PrefixCache: 滚动哈希(h = h × 31 + token_id)匹配前缀
  • PagePool: 协调 Allocator + PrefixCache
  • TaskTable: task_id → page_table + cached 计数
  • Storage: 5D 张量 (n_layers × n_pages × page_size × n_kv_heads × head_dim)
  • KvcacheView: Storage + page_table 绑定,供 attention 层 write/gather

RoPE 在写入 KV Cache 之前应用,保证位置编码不漂移。

5.2 连续批处理调度

InferenceScheduler 的 daemon 线程执行 4 阶段循环:

  1. Cleanup — 移除已完成任务,释放 KV 页
  2. Refill — 从等待队列取出,task_alloc 分配页,激活
  3. Prefill — 按 (prompt_len, start_pos) 分组,全量 forward
  4. Decode — 选最大同位置组,单 token forward

5.3 Protocol Handler(模板方法模式)

1
2
3
4
ProtocolHandler.handle()
→ build_prompt() → create_response_id()
→ stream: format_stream_start/token/end()
→ non-stream: format_non_stream_response()
  • OpenAIHandler/v1/chat/completions
  • AnthropicHandler/v1/messages

额外端点:GET /healthGET /stats

5.4 采样管线

BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategySamplingPipeline 顺序组合,每次 sample() 执行 softmax + multinomial。


6. 分布式训练

6.1 Executor 模式

ExecutorFactory.create(parallel_mode, **kwargs) 分发,所有 Executor 共享 BaseExecutor 接口:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class BaseExecutor:
def prepare(self, model, optimizer, dataloader, scheduler):
model = self._prepare_model(model)
optimizer = AccumOptimizer(optimizer, self.gradient_state)
scheduler = AccumScheduler(scheduler, self.gradient_state)
return model, optimizer, dataloader, scheduler

def accumulate(self, model):
self.gradient_state._do_sync()
if not self.gradient_state.sync_gradients:
with self._no_sync(model): yield
else: yield

@ExecutorFactory.register("ddp")
class DDPExecutor(BaseExecutor):
def _prepare_model(self, model):
return DDP(model, device_ids=[local_rank], **self._ddp_kwargs)

@ExecutorFactory.register("fsdp")
class FSDPExecutor(BaseExecutor):
def _prepare_model(self, model):
return FSDP(model, auto_wrap_policy=..., device_id=local_rank)
mode 包装方式
none 无包装
ddp DistributedDataParallel + no_sync
fsdp FullyShardedDataParallel + transformer_auto_wrap_policy + use_orig_params=True

6.2 多进程启动

spawn_parallel_fnmp.start_processes(),world_size=1 时直接执行。setup_parallel 上下文管理 init_process_group / destroy_process_group 生命周期。

6.3 检查点

safetensors 格式,rank-0 保存:meta.json + state_dict.safetensors + {key}.pt。加载时广播 metadata 到所有 rank。


7. 与同类框架对比

特性 AstrAI vLLM TGI
训练支持 ✅ SEQ/SFT/DPO/GRPO
Paged KV Cache ✅ 位掩码 + LRU
前缀缓存 ✅ 滚动哈希
MoE ✅ DeepSeekMoE
双 API ✅ OpenAI + Anthropic ✅ OpenAI ✅ OpenAI
分布式训练 ✅ DDP + FSDP + TP
依赖 仅 PyTorch

8. 总结

AstrAI 的核心优势:

  1. 工厂注册表模式 — 9 个工厂类,装饰器注册,扩展零侵入
  2. GQA + MLA + MoE — 灵活可配的注意力与 FFN 选择
  3. 四种训练策略 — SEQ → SFT → DPO → GRPO 完整管线
  4. Executor 模式 — 统一梯度累积,DDP/FSDP 一键切换
  5. Paged KV Cache — 位掩码分配 + 滚动哈希前缀 + LRU
  6. 连续批处理 — 4 阶段 daemon 调度
  7. 双 API 兼容 — OpenAI + Anthropic,模板方法模式

参考链接:

 REWARD AUTHOR