引言
AstrAI 是一个完全自研的轻量级 Transformer 训练与推理框架,仅依赖 PyTorch,提供从预训练到推理服务的一站式解决方案。其 1B 参数的中英双语模型已开源在 HuggingFace。本文从模型架构、训练系统、推理引擎到分布式部署,进行完整的技术解析。
1. 项目概览
| 特性 | 说明 |
|---|---|
| 🚀 高性能 | GQA/MLA + Paged KV Cache + 连续批处理 |
| 🔧 多训练范式 | SEQ / SFT / DPO / GRPO 一站式支持 |
| 📦 轻量 | 仅依赖 PyTorch,无其他重型框架 |
| 🤗 HF 风格 | AutoModel.from_pretrained / AutoTokenizer |
| 🔌 双 API | OpenAI / Anthropic 聊天接口兼容 |
| 🔬 研究友好 | 模块化设计 + 工厂模式 + MoE 支持 |
2. 系统架构
2.1 工厂注册表模式
核心设计是 BaseFactory + Registry,组件通过装饰器动态注册:
1 | class BaseFactory[T]: |
用法:
1 |
框架通过 9 个工厂类串联:AttnFactory、FFNFactory、StrategyFactory、DatasetFactory、StorageFactory、SchedulerFactory、CallbackFactory、ConfigFactory、AutoModel。扩展新组件只需新建文件 + 装饰器注册,零侵入。
2.2 模块目录
1 | astrai/ |
3. 核心模型组件
3.1 模型加载
AutoModel.from_pretrained(path) 读取 config.json → ConfigFactory.load() 按 model_type 分发 → 加载 model.safetensors。_disable_random_init 上下文在加载权重时将 nn.init.* 临时替换为 no-op,避免预训练权重被覆盖。
3.2 AutoRegressiveLM
Decoder-only 自回归模型,Embedding → N × DecoderBlock → RMSNorm → LM Head:
1 |
|
训练时 position_ids=None → RoPE 自动构建 [0..seq_len),SDPA 使用 is_causal=True。推理时由 Executor 传入 position_ids 并绑定 Paged KV Cache。
3.3 EmbeddingEncoder
@AutoModel.register("embedding") — 去掉 lm_head 的编码器版本,支持 cls / mean / last 三种 pooling 策略和 L2 归一化。load_state_dict 自动 pop lm_head.weight 从而兼容 transformer 检查点。
3.4 注意力机制
GQA(Grouped Query Attention): 将 Query 头分组,每组共享同一组 KV 头:
1 |
|
repeat_kv 通过 expand 实现零拷贝广播。支持可选 QK Norm(RoPE 之后应用)和门控注意力(sigmoid(gate) × output)。
MLA(Multi-head Latent Attention): DeepSeek-V2 提出的低秩压缩注意力。KV 先通过 kv_a_proj 压缩到 kv_lora_rank 维,再通过 kv_b_proj 展开。Q 和 K 分离为 nope(无 RoPE)和 rope(有 RoPE)两部分。KV Cache 显存从 O(d) 降至 O(d_lora)。
3.5 FFN
MLP(SwiGLU): down(SiLU(gate(x)) × up(x)),标准门控激活。
DeepSeekMoE: 共享专家(所有 token 激活,输出取平均)+ 路由专家(Top-K 门控)。Router 是线性层 + softmax,每个 token 激活 K 个专家并按权重加权求和。
3.6 RoPE
复数乘法实现旋转位置编码。freqs_cis 作为 buffer 预先计算,apply_rotary_emb 将 Q/K 转为复数后相乘。
4. 训练系统
4.1 训练循环
Trainer 的 callback 驱动循环,梯度累积由 BaseExecutor 统一管理:
1 | class Trainer: |
AccumOptimizer / AccumScheduler 包装器仅在 sync_gradients=True 时执行 step(),训练代码无需手动判断累积边界。
4.2 训练策略
| 策略 | 损失函数 | 应用场景 |
|---|---|---|
| SEQ | 标准交叉熵 | 预训练 |
| SFT | 带 ignore_index=-100 的掩码交叉熵 | 有监督微调 |
| DPO | β × log(π_θ/π_ref) 偏好对比 | 偏好对齐 |
| GRPO | PPO + 组内归一化优势 + KL 惩罚 | 在线强化学习 |
DPO 和 GRPO 均使用冻结的参考模型,GRPO 每 sync_interval 步同步一次。
4.3 默认 Callback
顺序:gradient_checkpointing → checkpoint → metric_logger → progress_bar → gradient_clipping → validation。所有 callback 通过 CallbackFactory 注册。
4.4 LR Scheduler
| 类型 | 公式 | 特点 |
|---|---|---|
| Cosine | 线性 warmup + 余弦衰减到 min_rate | 标准 |
| SGDR | 余弦退火 + 带 t_mult=2 的热重启 | 跳出局部最优 |
5. 推理系统
5.1 Paged KV Cache
六类协作:
- Allocator: 位掩码空闲页管理 + 引用计数 + LRU 淘汰
- PrefixCache: 滚动哈希(
h = h × 31 + token_id)匹配前缀 - PagePool: 协调 Allocator + PrefixCache
- TaskTable: task_id → page_table + cached 计数
- Storage: 5D 张量
(n_layers × n_pages × page_size × n_kv_heads × head_dim) - KvcacheView: Storage + page_table 绑定,供 attention 层 write/gather
RoPE 在写入 KV Cache 之前应用,保证位置编码不漂移。
5.2 连续批处理调度
InferenceScheduler 的 daemon 线程执行 4 阶段循环:
- Cleanup — 移除已完成任务,释放 KV 页
- Refill — 从等待队列取出,task_alloc 分配页,激活
- Prefill — 按
(prompt_len, start_pos)分组,全量 forward - Decode — 选最大同位置组,单 token forward
5.3 Protocol Handler(模板方法模式)
1 | ProtocolHandler.handle() |
- OpenAIHandler →
/v1/chat/completions - AnthropicHandler →
/v1/messages
额外端点:GET /health、GET /stats。
5.4 采样管线
BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy,SamplingPipeline 顺序组合,每次 sample() 执行 softmax + multinomial。
6. 分布式训练
6.1 Executor 模式
ExecutorFactory.create(parallel_mode, **kwargs) 分发,所有 Executor 共享 BaseExecutor 接口:
1 | class BaseExecutor: |
| mode | 包装方式 |
|---|---|
none |
无包装 |
ddp |
DistributedDataParallel + no_sync |
fsdp |
FullyShardedDataParallel + transformer_auto_wrap_policy + use_orig_params=True |
6.2 多进程启动
spawn_parallel_fn → mp.start_processes(),world_size=1 时直接执行。setup_parallel 上下文管理 init_process_group / destroy_process_group 生命周期。
6.3 检查点
safetensors 格式,rank-0 保存:meta.json + state_dict.safetensors + {key}.pt。加载时广播 metadata 到所有 rank。
7. 与同类框架对比
| 特性 | AstrAI | vLLM | TGI |
|---|---|---|---|
| 训练支持 | ✅ SEQ/SFT/DPO/GRPO | ❌ | ❌ |
| Paged KV Cache | ✅ 位掩码 + LRU | ✅ | ✅ |
| 前缀缓存 | ✅ 滚动哈希 | ✅ | ❌ |
| MoE | ✅ DeepSeekMoE | ✅ | ✅ |
| 双 API | ✅ OpenAI + Anthropic | ✅ OpenAI | ✅ OpenAI |
| 分布式训练 | ✅ DDP + FSDP + TP | ❌ | ❌ |
| 依赖 | 仅 PyTorch | 中 | 大 |
8. 总结
AstrAI 的核心优势:
- 工厂注册表模式 — 9 个工厂类,装饰器注册,扩展零侵入
- GQA + MLA + MoE — 灵活可配的注意力与 FFN 选择
- 四种训练策略 — SEQ → SFT → DPO → GRPO 完整管线
- Executor 模式 — 统一梯度累积,DDP/FSDP 一键切换
- Paged KV Cache — 位掩码分配 + 滚动哈希前缀 + LRU
- 连续批处理 — 4 阶段 daemon 调度
- 双 API 兼容 — OpenAI + Anthropic,模板方法模式
参考链接: