9.3 KiB
9.3 KiB
KHAOSZ 数据流文档
本文档描述 KHAOSZ 项目(一个自回归 Transformer 语言模型的训练与推理框架)的数据流。涵盖从原始数据到模型训练、推理的完整流程。
概述
KHAOSZ 采用模块化设计,主要组件包括:
- 数据模块 (
khaosz/data/): 数据集、采样器、分词器、序列化工具 - 模型模块 (
khaosz/model/): Transformer 模型及其子模块 - 训练模块 (
khaosz/trainer/): 训练器、训练上下文、策略、调度器 - 推理模块 (
khaosz/inference/): 生成核心、KV 缓存管理、流式生成 - 配置模块 (
khaosz/config/): 模型、训练、调度等配置 - 并行模块 (
khaosz/parallel/): 分布式训练支持
数据流总体可分为 训练数据流 与 推理数据流 两条主线。
数据流图
flowchart LR
subgraph A[数据准备]
direction TB
A1[原始文本] --> A2[BBPE 分词器]
A2 --> A3[序列化为 .h5 文件]
A3 --> A4[数据集加载<br/>BaseDataset]
A4 --> A5[可恢复分布式采样器<br/>ResumableDistributedSampler]
A5 --> A6[DataLoader 批量加载]
end
subgraph B[训练循环]
direction TB
B1[批次数据] --> B2[训练策略<br/>BaseStrategy]
B2 --> B3[Transformer 模型]
B3 --> B4[输出 logits]
B4 --> B5[损失计算]
B5 --> B6[反向传播]
B6 --> B7[优化器更新]
B7 --> B8[学习率调度器]
B8 --> B9[检查点保存]
end
subgraph C[推理生成]
direction TB
C1[检查点加载] --> C2[推理模型加载]
C2 --> C3[生成核心<br/>GeneratorCore]
C3 --> C4[采样策略<br/>温度/top‑k/top‑p]
C4 --> C5[生成下一个 token]
C5 --> C6[KV 缓存更新]
C6 --> C7{是否达到最大长度?}
C7 -->|否| C5
C7 -->|是| C8[输出生成文本]
end
A --> B
B --> C
各模块详细说明
1. 数据模块
1.1 分词器 (tokenizer.py)
- 基于 Byte‑Level BPE (BBPE) 实现
- 支持特殊 token:
<bos>,<eos>,<pad>,<|im_start|>,<|im_end|> - 提供
encode/decode方法,将文本与 token ID 相互转换 - 训练时从语料库学习词汇表,保存为
.json文件
1.2 序列化 (serialization.py)
save_h5: 将多个张量按组保存为 HDF5 文件(.h5),每个键对应一个张量列表load_h5: 加载.h5文件,返回Dict[str, List[Tensor]],支持共享内存 (share_memory=True)Checkpoint类: 封装模型状态字典、训练轮次、迭代次数,支持 safetensors 格式保存与加载
1.3 数据集 (dataset.py)
BaseDataset: 抽象基类,定义窗口采样、步长等通用逻辑BaseSegmentFetcher与MultiSegmentFetcher: 高效地从多个分段中获取指定索引范围的数据DatasetFactory: 工厂模式,支持动态注册数据集类型(seq,sft,dpo,grpo)- 数据集加载后通过
MultiSegmentFetcher管理多个数据键(如"sequence","mask")
1.4 采样器 (sampler.py)
ResumableDistributedSampler: 支持分布式训练的可恢复采样器- 记录当前 epoch 和迭代位置,便于从断点继续训练
- 支持 shuffle 与 drop_last 选项
2. 模型模块
2.1 Transformer (transformer.py)
- 核心自回归解码器架构
- 包含嵌入层、多层
DecoderBlock、RMSNorm 和线性输出头 - 支持权重绑定 (
tie_weight=True) 以减小参数量 - 使用 Rotary Position Embedding (RoPE) 注入位置信息
2.2 子模块 (module.py)
RotaryEmbedding: 生成 RoPE 的 cos/sin 缓存DecoderBlock: 包含多头注意力(支持 GQA)、前馈网络(FFN)、残差连接RMSNorm: 层归一化变体Linear,Embedding: 自定义线性层与嵌入层,支持并行化包装
3. 训练模块
3.1 训练上下文 (train_context.py)
TrainContext: 数据类,封装训练所需的所有组件(模型、优化器、数据加载器、策略等)TrainContextBuilder: 构建器模式,逐步组装训练上下文,支持从检查点恢复
3.2 训练器 (trainer.py)
Trainer: 主训练循环,管理回调函数(进度条、检查点、指标记录、梯度裁剪、调度器)- 支持分布式训练(通过
spawn_parallel_fn启动多进程) - 训练步骤包括:
on_train_begin→ 2.on_epoch_begin→ 3.on_batch_begin→ 4. 前向/损失计算 → 5.on_batch_end→ 6. 梯度累积 → 7.on_step_begin→ 8. 优化器更新 → 9.on_step_end→ 10.on_epoch_end
3.3 策略 (strategy.py)
BaseStrategy: 定义训练策略接口(如SeqStrategy,SFTStrategy,DPOStrategy)- 策略接收批次数据,执行模型前向传播、损失计算,返回 loss 张量
- 由
StrategyFactory根据配置动态创建
3.4 调度器 (schedule.py)
BaseScheduler: 抽象基类,定义学习率调度接口SchedulerFactory: 工厂模式,支持注册多种调度器(如cosine,sgdr)- 调度器根据配置自动创建,并与优化器绑定
4. 推理模块
4.1 生成核心 (core.py)
GeneratorCore: 提供generate_iterator方法,执行单步生成- 应用采样策略(温度、top‑k、top‑p)对 logits 进行筛选
- 支持 KV 缓存以加速自回归生成
4.2 KV 缓存管理 (core.py)
KVCacheManager: 管理每层的 K 和 V 缓存,支持批量生成与长度扩展- 缓存形状为
[batch_size, n_kv_heads, seq_len, head_dim]
4.3 生成器 (generator.py)
GenerationRequest: 封装生成请求参数(top_k, top_p, temperature, max_len, query, history 等)build_prompt: 将查询与历史记录转换为 ChatML 格式的提示字符串pad_sequence: 对输入 ID 进行填充,使其长度一致- 提供流式与非流式生成接口
训练数据流详细步骤
-
数据准备
- 原始文本经过 BBPE 分词器转换为 token ID 序列
- 将 token ID 序列(可能带有掩码、标签等)按组保存为
.h5文件 - 文件可包含多个分段,每个分段对应一个张量
-
数据集加载
BaseDataset的load方法调用load_h5,得到segments字典- 创建
MultiSegmentFetcher管理多个键的数据 - 计算总样本数,并根据窗口大小、步长确定每个样本的起始/结束索引
-
采样与批量加载
ResumableDistributedSampler根据当前 epoch 和迭代位置生成索引序列DataLoader使用采样器获取索引,调用数据集的__getitem__获取实际数据- 批量数据形状为
[batch_size, window_size](或根据具体数据集类型变化)
-
策略前向与损失计算
- 批次数据传入策略(如
SeqStrategy) - 策略内部调用
Transformer模型,得到 logits - 根据任务类型计算交叉熵损失(或 DPO 损失等)
- 返回 loss 张量
- 批次数据传入策略(如
-
反向传播与优化
- 损失除以累积步数进行归一化,然后执行
loss.backward() - 每累积
accumulation_steps个批次后,执行优化器step()和zero_grad() - 学习率调度器在每个 step 后更新学习率
- 损失除以累积步数进行归一化,然后执行
-
检查点保存
CheckpointCallback按设定的间隔保存检查点- 检查点包含模型状态字典、当前 epoch、iteration 等元数据
- 使用 safetensors 格式保存,确保安全与效率
推理数据流详细步骤
-
模型加载
- 从检查点加载
Transformer模型与分词器 - 模型设置为评估模式 (
model.eval()),启用推理模式 (torch.inference_mode)
- 从检查点加载
-
提示构建与编码
- 用户查询与历史记录通过
build_prompt转换为 ChatML 格式字符串 - 分词器将提示字符串编码为 token ID 序列
input_ids - 若为批量生成,使用
pad_sequence进行填充
- 用户查询与历史记录通过
-
自回归生成循环
- 初始化 KV 缓存(可选)
- 循环直到生成
max_len个 token 或遇到停止 token:- 将当前
input_ids(或缓存后的新 token)输入模型,得到logits - 对
logits应用apply_sampling_strategies(温度、top‑k、top‑p) - 从处理后的分布中采样得到下一个 token ID
- 将新 token 追加到
input_ids,同时更新 KV 缓存 - 若为流式生成,每生成一个 token 立即 yield 给调用方
- 将当前
-
解码与输出
- 将生成的 token ID 序列通过分词器解码为文本
- 去除特殊 token,返回纯文本响应
检查点与序列化
- 训练检查点:保存模型参数、优化器状态、调度器状态、当前 epoch 与 iteration
- 模型参数:支持 safetensors 格式,加载时自动处理权重绑定等特殊逻辑
- 数据集序列化:HDF5 格式支持高效随机读取与共享内存,适合大规模预训练数据
总结
KHAOSZ 的数据流设计体现了模块化、可扩展、可恢复的特点。训练数据流通过分块加载、可恢复采样、梯度累积等机制支持大规模分布式训练;推理数据流则利用 KV 缓存、采样策略实现高效的文本生成。各模块之间通过清晰的接口耦合,便于定制与扩展。
文档更新时间:2026‑03‑30
对应代码版本:参考pyproject.toml中定义的版本号