引言
在 LLM (大语言模型) 蓬勃发展的今天,如何高效地训练和部署模型成为每一个 AI 工程师面临的挑战。 AstrAI 是一个完全自研的轻量级 Transformer 训练与推理框架,旨在为开发者和研究团队提供高性能、易用的训练推理解决方案。
本文将深入解析 AstrAI 的架构设计、核心特性以及实际应用场景。
1. 项目概览
AstrAI 是一个支持多种训练范式的轻量级框架:
| 🚀 高性能 |
优化的训练与推理流程,支持高效并行 |
| 🔧 多训练范式 |
支持 SEQ / SFT / DPO / GRPO 训练 |
| 💡 易用 |
简洁 API设计,丰富示例 |
| 🤗 HF 兼容 |
类HF api 设计, 未来将支持更多模型 |
GitHub: ViperEkura/AstrAI
2. 核心架构
2.1 模块化设计
AstrAI 采用高度模块化的设计,主要包含以下核心模块:
1 2 3 4 5 6 7 8 9
| astrai/ ├── model/ # 模型定义 (Transformer, GQA, MLA) ├── trainer/ # 训练器与策略 ├── dataset/ # 数据集加载 ├── inference/ # 推理引擎与 Server ├── tokenize/ # 分词器与 Chat 模板 ├── parallel/ # 分布式训练 ├── config/ # 配置管理 └── serialization # 检查点保存/加载
|
2.2 核心模型组件
AstrAI 实现了两种主流的注意力机制:
GQA (Grouped Query Attention)
1 2 3 4 5 6 7
| class GQA(nn.Module): def __init__(self, dim, n_heads, n_kv_heads, ...): self.q_proj = Linear(dim, n_heads * self.head_dim) self.k_proj = Linear(dim, n_kv_heads * self.head_dim) self.v_proj = Linear(dim, n_kv_heads * self.head_dim) self.o_proj = Linear(dim, dim)
|
关键特性: - 支持 Query 分组机制,减少 KV 头数量 - 支持 QK Norm 归一化 - 支持门控注意力 (Gated Attention) - KV Cache 优化
数学原理: GQA 将 Query 头分成 ngroups = nheads/nKV 组,每组共享同一个 Key 和 Value:
其中 Q ∈ ℝbatch × seq × nheads × d,K, V ∈ ℝbatch × seq × nKV × d,通过 repeat_kv 操作将 nKV 扩展到 nheads。
MLA (Multi-Head Latent Attention)
1 2 3 4 5 6
| class MLA(nn.Module): def __init__(self, dim, n_heads, n_kv_heads, kv_lora_rank, ...): self.kv_a_proj = Linear(dim, kv_lora_rank) self.kv_norm = RMSNorm(kv_lora_rank) self.kv_b_proj = Linear(kv_lora_rank, n_heads * head_dim ...)
|
关键特性: - 压缩 KV 表示,降低显存占用 - 分离 NoPE 和 RoPE 位置编码 - LoRA 风格的 KV 压缩
数学原理: MLA 使用低秩分解压缩 Key 和 Value:
Kc = WK x, Vc = WV x
Kdec = Kc WKVT, Vdec = Vc WKVT
其中压缩后的维度为:dlora 远小于原始的 nKV × d。这使得 KV Cache 显存占用从:
𝒪(nlayers ⋅ nKV ⋅ d ⋅ L)
降低到:
𝒪(nlayers ⋅ dlora ⋅ L)
2.3 RoPE 位置编码
1 2 3 4 5 6 7 8 9 10
| def get_rotary_emb( dim: int, max_len: int, base: float = 10000, ) -> Tuple[Tensor, Tensor]: theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim) t = torch.arange(0, max_len, dtype=torch.float64) freqs = torch.outer(t, theta) return torch.cos(freqs).float(), torch.sin(freqs).float()
|
数学原理: RoPE 通过旋转位置编码实现位置感知,其核心公式为:
其中 θi = base − 2i/d,d 为隐藏维度。应用旋转后的 Query 和 Key 注意力计算保持相对位置信息:
3. 训练系统
3.1 支持的训练范式
AstrAI 支持四种训练方式,通过工厂模式灵活扩展:
| SEQ |
下一个 token 预测 |
label_smoothing |
| SFT |
有监督微调 |
loss_mask |
| DPO |
直接偏好优化 |
beta (默认 0.1) |
| GRPO |
群体相对策略优化 |
clip_eps, kl_coef |
3.2 训练策略实现
SEQ 策略 (下一个 token 预测)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| @StrategyFactory.register("seq") class SEQStrategy(BaseStrategy): def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs): super().__init__(model, device, **kwargs) self.label_smoothing = label_smoothing
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: batch = move_to_device(batch, self.device) input_ids, target_ids = batch["input_ids"], batch["target_ids"] logits = self.model(input_ids=input_ids)["logits"] loss = F.cross_entropy( input=logits.flatten(0, 1).float(), target=target_ids.flatten(), label_smoothing=self.label_smoothing, ) return loss
|
数学原理: SEQ 使用标准的交叉熵损失进行下一个 token 预测:
当启用 label smoothing (默认为 0) 时,损失函数变为:
其中 α 为平滑系数,V 为词表大小。
SFT 策略 (带 Mask 的监督微调)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
| @StrategyFactory.register("sft") class SFTStrategy(BaseStrategy): def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs): super().__init__(model, device, **kwargs) self.label_smoothing = label_smoothing
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: batch = move_to_device(batch, self.device) input_ids, target_ids, loss_mask = ( batch["input_ids"], batch["target_ids"], batch["loss_mask"], ) ignore_index = -100 logits = self.model(input_ids=input_ids)["logits"] target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index) loss = F.cross_entropy( input=logits.flatten(0, 1).float(), target=target_ids.flatten(), ignore_index=ignore_index, label_smoothing=self.label_smoothing, ) return loss
|
数学原理: SFT 只对 loss_mask 为 True 的 token 计算损失:
其中 ℳ 是需要计算损失的位置集合。通过 mask 机制,可以仅训练特定 token(如 assistant 回复部分),避免学习到 prompt 或用户输入。
DPO 策略 (直接偏好优化)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
| @StrategyFactory.register("dpo") class DPOStrategy(BaseStrategy): def __init__( self, model: nn.Module, device: str, beta: float = 0.1, reduction: str = "mean", **kwargs, ): super().__init__(model, device, **kwargs) self.ref_model = create_ref_model(model) self.beta = beta self.reduction = reduction
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: batch = move_to_device(batch, self.device) chosen_ids, rejected_ids = batch["chosen"], batch["rejected"] chosen_mask, rejected_mask = batch["chosen_mask"], batch["rejected_mask"] concat_ids = torch.cat([chosen_ids, rejected_ids], dim=0) concat_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
log_pi = get_logprobs(self.model, concat_ids, concat_mask, self.reduction)
with torch.no_grad(): log_ref = get_logprobs( self.ref_model, concat_ids, concat_mask, self.reduction )
log_pi_chosen = log_pi[: chosen_ids.shape[0]] log_pi_rejected = log_pi[chosen_ids.shape[0] :] log_ref_chosen = log_ref[: chosen_ids.shape[0]] log_ref_rejected = log_ref[chosen_ids.shape[0] :]
pi_log_ratio = log_pi_chosen - log_pi_rejected ref_log_ratio = log_ref_chosen - log_ref_rejected
ratio_diff = pi_log_ratio - ref_log_ratio dpo_loss = -F.logsigmoid(self.beta * ratio_diff).mean() return dpo_loss
|
数学原理: DPO (Direct Preference Optimization) 的目标函数直接最大化偏好数据的 log-likelihood:
其中 σ 是 sigmoid 函数,β 是温度参数(代码中默认为 0.1),yw 是偏好选项,yl 是被拒绝选项。代码实现中:
1 2
| dpo_loss = -F.logsigmoid(self.beta * ratio_diff).mean()
|
3.3 数据集设计
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
| @DatasetFactory.register("sft") class SFTDataset(BaseDataset): def __init__(self, window_size: int, stride: int): super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index): begin_idx, end_idx = self.get_index(index) x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long) y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to( dtype=torch.long ) loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to( dtype=torch.bool ) return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
特点: - 多Segment 异步加载 - HDF5 高效存储 - 窗口滑动采样
窗口滑动采样数学原理: 数据集使用滑动窗口生成训练样本:
begini = min (i ⋅ stride, total − 1 − window_size)
endi = min (begini + window_size, total − 1)
采样数量为:
这种采样方式确保数据覆盖最大化,同时保持样本间的时间连续性。
4. 推理系统
4.1 推理引擎
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
| class InferenceEngine: def __init__( self, model: nn.Module, tokenizer: AutoTokenizer, max_batch_size: int = 1, max_seq_len: Optional[int] = None, ): self.model = model self.tokenizer = tokenizer
try: first_param = next(model.parameters()) device = first_param.device dtype = first_param.dtype except StopIteration: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32
self.scheduler = InferenceScheduler( model=self.model, tokenizer=self.tokenizer, max_batch_size=max_batch_size, max_seq_len=max_seq_len, device=device, dtype=dtype, )
self.kv_cache = self.scheduler.kv_cache self.seq_mask = self.scheduler.seq_mask self.scheduler.start()
|
核心能力: - 连续批处理 (Continuous Batching) - 流式输出 - KV Cache 管理
连续批处理数学原理: 传统静态批处理需要等待整个批次完成才返回结果,而连续批处理允许在批次中动态添加新请求:
其中 Ti 是第 i 个请求的 token 数,tprefill 是预填充时间。通过持续批处理,GPU 利用率可提升 2-3 倍。
KV Cache 优化将显存占用从 𝒪(B ⋅ L ⋅ d) 降低到 𝒪(nlayers ⋅ nKV ⋅ L ⋅ d),但通过 GQA/MLA 进一步优化后可降至 𝒪(nlayers ⋅ dlora ⋅ L):
𝒪(nlayers ⋅ dlora ⋅ L)
其中 B 是 batch size,L 是序列长度。
4.2 FastAPI Server
AstrAI 提供 OpenAI 兼容的 API 接口:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
| @app.post("/v1/chat/completions", response_model=CompletionResponse) async def chat_completion(request: ChatCompletionRequest): """OpenAI-compatible chat completion endpoint.""" if _engine is None: raise HTTPException(status_code=503, detail="Engine not initialized")
prompt = convert_messages_to_prompt(request.messages, engine=_engine)
if request.stream: generator = _engine.generate( prompt=prompt, stream=True, max_tokens=request.max_tokens, temperature=request.temperature, top_p=request.top_p, top_k=request.top_k, )
def generate_stream(): for token in generator: if token == "[DONE]": break yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n" yield "data: [DONE]\n\n"
return StreamingResponse( generate_stream(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, ) else: result = _engine.generate( prompt=prompt, stream=False, max_tokens=request.max_tokens, temperature=request.temperature, top_p=request.top_p, top_k=request.top_k, )
import time resp = CompletionResponse( id=f"chatcmpl-{int(time.time())}", created=int(time.time()), choices=[ { "index": 0, "message": {"role": "assistant", "content": result}, "finish_reason": "stop", } ], ) return resp
|
API 端点: - /health - 健康检查 - /stats - 统计信息 - /v1/chat/completions - OpenAI 兼容聊天接口 - /generate - 简单生成接口
5. 分布式训练
5.1 并行设置
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
| def spawn_parallel_fn( func: Callable, world_size: int, backend: str = "nccl", master_addr: str = "localhost", master_port: str = "29500", device_type: str = "cuda", device_ids: Optional[List[int]] = None, **kwargs, ): for key in [ "MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE", "LOCAL_RANK", "LOCAL_DEVICE", ]: if key in os.environ: del os.environ[key]
if world_size == 1: device_ids = device_ids or [0] device_id = torch.device(device_type, device_ids[0]) os.environ["LOCAL_DEVICE"] = str(device_id) func(**kwargs) return
wrapper_spawn_func_args = ( world_size, backend, master_addr, master_port, device_type, device_ids, func, kwargs, )
mp.spawn( wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True )
|
支持: - NCCL 后端 (GPU) - GLOO 后端 (CPU) - CCL 后端 (Intel GPU)
分布式训练通信复杂度: 在 DDP 训练中,每个 GPU 计算梯度后需要 All-Reduce 同步:
其中 n 是 world_size。通过梯度压缩和异步通信可以进一步减少通信开销。
5.2 检查点管理
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
| class Checkpoint: def __init__( self, state_dict: Dict[str, Any], epoch: int = 0, iteration: int = 0, ): self.state_dict = state_dict self.epoch = epoch self.iteration = iteration
def save(self, save_dir: str) -> None: save_path = Path(save_dir) save_path.mkdir(parents=True, exist_ok=True)
rank = get_rank() if rank == 0: meta = { "epoch": self.epoch, "iteration": self.iteration, } with open(save_path / "meta.json", "w") as f: json.dump(meta, f, indent=2)
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
@classmethod def load(cls, save_dir: str) -> "Checkpoint": rank = get_rank() save_path = Path(save_dir)
meta = {} if rank == 0: with open(Path(save_dir) / "meta.json", "r") as f: meta = json.load(f)
if dist.is_initialized(): meta_list = [meta] dist.broadcast_object_list(meta_list, src=0) meta = meta_list[0]
state_dict = st.load_file(save_path / "state_dict.safetensors")
return cls( state_dict=state_dict, epoch=meta["epoch"], iteration=meta["iteration"], )
|
分布式检查点保存原理: - 仅 rank 0 进程负责写入文件系统,避免竞争 - 元数据 (epoch, iteration) 通过 dist.broadcast_object_list 广播到所有进程 - 使用 safetensors 格式支持内存映射,实现零拷贝加载
6. 快速开始
6.1 安装
1 2 3
| git clone https://github.com/ViperEkura/AstrAI.git cd AstrAI pip install -e .
|
6.2 训练模型
1 2 3 4
| python scripts/tools/train.py \ --train_type=seq \ --data_root_path=/path/to/dataset \ --param_path=/path/to/params
|
6.3 启动推理服务
1
| python scripts/tools/server.py --param_path=/path/to/params
|
7. 与同类框架对比
| 训练支持 |
✅ SEQ/SFT/DPO/GRPO |
❌ 仅推理 |
❌ 仅推理 |
| 推理支持 |
✅ |
✅ TP/PP |
✅ TP/PP |
| 分布式训练 |
✅ DDP |
❌ |
❌ |
| 连续批处理 |
✅ |
✅ |
✅ |
| 依赖简洁 |
✅ 仅 PyTorch |
中 |
大 |
8. 未来规划
9. 总结
AstrAI 是一个完全自研的轻量级 Transformer 框架,具有以下优势:
- 模块化设计 - 清晰的分层架构,易于扩展
- 多样化的训练支持 - SEQ/SFT/DPO/GRPO 一站式解决方案
- 高性能推理 - 连续批处理 + KV Cache 优化
- 简洁易用 - 友好的 API 与丰富的示例
- HF 兼容 - 无缝对接 HuggingFace 生态
欢迎 Star、Fork 和贡献!
参考链接: