205 lines
9.3 KiB
Markdown
205 lines
9.3 KiB
Markdown
# KHAOSZ 数据流文档
|
||
|
||
本文档描述 KHAOSZ 项目(一个自回归 Transformer 语言模型的训练与推理框架)的数据流。涵盖从原始数据到模型训练、推理的完整流程。
|
||
|
||
## 概述
|
||
|
||
KHAOSZ 采用模块化设计,主要组件包括:
|
||
- **数据模块** (`khaosz/data/`): 数据集、采样器、分词器、序列化工具
|
||
- **模型模块** (`khaosz/model/`): Transformer 模型及其子模块
|
||
- **训练模块** (`khaosz/trainer/`): 训练器、训练上下文、策略、调度器
|
||
- **推理模块** (`khaosz/inference/`): 生成核心、KV 缓存管理、流式生成
|
||
- **配置模块** (`khaosz/config/`): 模型、训练、调度等配置
|
||
- **并行模块** (`khaosz/parallel/`): 分布式训练支持
|
||
|
||
数据流总体可分为 **训练数据流** 与 **推理数据流** 两条主线。
|
||
|
||
## 数据流图
|
||
|
||
```mermaid
|
||
flowchart LR
|
||
subgraph A[数据准备]
|
||
direction TB
|
||
A1[原始文本] --> A2[BBPE 分词器]
|
||
A2 --> A3[序列化为 .h5 文件]
|
||
A3 --> A4[数据集加载<br/>BaseDataset]
|
||
A4 --> A5[可恢复分布式采样器<br/>ResumableDistributedSampler]
|
||
A5 --> A6[DataLoader 批量加载]
|
||
end
|
||
|
||
subgraph B[训练循环]
|
||
direction TB
|
||
B1[批次数据] --> B2[训练策略<br/>BaseStrategy]
|
||
B2 --> B3[Transformer 模型]
|
||
B3 --> B4[输出 logits]
|
||
B4 --> B5[损失计算]
|
||
B5 --> B6[反向传播]
|
||
B6 --> B7[优化器更新]
|
||
B7 --> B8[学习率调度器]
|
||
B8 --> B9[检查点保存]
|
||
end
|
||
|
||
subgraph C[推理生成]
|
||
direction TB
|
||
C1[检查点加载] --> C2[推理模型加载]
|
||
C2 --> C3[生成核心<br/>GeneratorCore]
|
||
C3 --> C4[采样策略<br/>温度/top‑k/top‑p]
|
||
C4 --> C5[生成下一个 token]
|
||
C5 --> C6[KV 缓存更新]
|
||
C6 --> C7{是否达到最大长度?}
|
||
C7 -->|否| C5
|
||
C7 -->|是| C8[输出生成文本]
|
||
end
|
||
|
||
A --> B
|
||
B --> C
|
||
```
|
||
|
||
## 各模块详细说明
|
||
|
||
### 1. 数据模块
|
||
|
||
#### 1.1 分词器 (`tokenizer.py`)
|
||
- 基于 Byte‑Level BPE (BBPE) 实现
|
||
- 支持特殊 token:`<bos>`, `<eos>`, `<pad>`, `<|im_start|>`, `<|im_end|>`
|
||
- 提供 `encode`/`decode` 方法,将文本与 token ID 相互转换
|
||
- 训练时从语料库学习词汇表,保存为 `.json` 文件
|
||
|
||
#### 1.2 序列化 (`serialization.py`)
|
||
- **`save_h5`**: 将多个张量按组保存为 HDF5 文件(`.h5`),每个键对应一个张量列表
|
||
- **`load_h5`**: 加载 `.h5` 文件,返回 `Dict[str, List[Tensor]]`,支持共享内存 (`share_memory=True`)
|
||
- **`Checkpoint` 类**: 封装模型状态字典、训练轮次、迭代次数,支持 safetensors 格式保存与加载
|
||
|
||
#### 1.3 数据集 (`dataset.py`)
|
||
- **`BaseDataset`**: 抽象基类,定义窗口采样、步长等通用逻辑
|
||
- **`BaseSegmentFetcher`** 与 **`MultiSegmentFetcher`**: 高效地从多个分段中获取指定索引范围的数据
|
||
- **`DatasetFactory`**: 工厂模式,支持动态注册数据集类型(`seq`, `sft`, `dpo`, `grpo`)
|
||
- 数据集加载后通过 `MultiSegmentFetcher` 管理多个数据键(如 `"sequence"`, `"mask"`)
|
||
|
||
#### 1.4 采样器 (`sampler.py`)
|
||
- **`ResumableDistributedSampler`**: 支持分布式训练的可恢复采样器
|
||
- 记录当前 epoch 和迭代位置,便于从断点继续训练
|
||
- 支持 shuffle 与 drop_last 选项
|
||
|
||
### 2. 模型模块
|
||
|
||
#### 2.1 Transformer (`transformer.py`)
|
||
- 核心自回归解码器架构
|
||
- 包含嵌入层、多层 `DecoderBlock`、RMSNorm 和线性输出头
|
||
- 支持权重绑定 (`tie_weight=True`) 以减小参数量
|
||
- 使用 Rotary Position Embedding (RoPE) 注入位置信息
|
||
|
||
#### 2.2 子模块 (`module.py`)
|
||
- **`RotaryEmbedding`**: 生成 RoPE 的 cos/sin 缓存
|
||
- **`DecoderBlock`**: 包含多头注意力(支持 GQA)、前馈网络(FFN)、残差连接
|
||
- **`RMSNorm`**: 层归一化变体
|
||
- **`Linear`**, **`Embedding`**: 自定义线性层与嵌入层,支持并行化包装
|
||
|
||
### 3. 训练模块
|
||
|
||
#### 3.1 训练上下文 (`train_context.py`)
|
||
- **`TrainContext`**: 数据类,封装训练所需的所有组件(模型、优化器、数据加载器、策略等)
|
||
- **`TrainContextBuilder`**: 构建器模式,逐步组装训练上下文,支持从检查点恢复
|
||
|
||
#### 3.2 训练器 (`trainer.py`)
|
||
- **`Trainer`**: 主训练循环,管理回调函数(进度条、检查点、指标记录、梯度裁剪、调度器)
|
||
- 支持分布式训练(通过 `spawn_parallel_fn` 启动多进程)
|
||
- 训练步骤包括:
|
||
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. 前向/损失计算 → 5. `on_batch_end` → 6. 梯度累积 → 7. `on_step_begin` → 8. 优化器更新 → 9. `on_step_end` → 10. `on_epoch_end`
|
||
|
||
#### 3.3 策略 (`strategy.py`)
|
||
- **`BaseStrategy`**: 定义训练策略接口(如 `SeqStrategy`, `SFTStrategy`, `DPOStrategy`)
|
||
- 策略接收批次数据,执行模型前向传播、损失计算,返回 loss 张量
|
||
- 由 `StrategyFactory` 根据配置动态创建
|
||
|
||
#### 3.4 调度器 (`schedule.py`)
|
||
- **`BaseScheduler`**: 抽象基类,定义学习率调度接口
|
||
- **`SchedulerFactory`**: 工厂模式,支持注册多种调度器(如 `cosine`, `sgdr`)
|
||
- 调度器根据配置自动创建,并与优化器绑定
|
||
|
||
### 4. 推理模块
|
||
|
||
#### 4.1 生成核心 (`core.py`)
|
||
- **`GeneratorCore`**: 提供 `generate_iterator` 方法,执行单步生成
|
||
- 应用采样策略(温度、top‑k、top‑p)对 logits 进行筛选
|
||
- 支持 KV 缓存以加速自回归生成
|
||
|
||
#### 4.2 KV 缓存管理 (`core.py`)
|
||
- **`KVCacheManager`**: 管理每层的 K 和 V 缓存,支持批量生成与长度扩展
|
||
- 缓存形状为 `[batch_size, n_kv_heads, seq_len, head_dim]`
|
||
|
||
#### 4.3 生成器 (`generator.py`)
|
||
- **`GenerationRequest`**: 封装生成请求参数(top_k, top_p, temperature, max_len, query, history 等)
|
||
- **`build_prompt`**: 将查询与历史记录转换为 ChatML 格式的提示字符串
|
||
- **`pad_sequence`**: 对输入 ID 进行填充,使其长度一致
|
||
- 提供流式与非流式生成接口
|
||
|
||
## 训练数据流详细步骤
|
||
|
||
1. **数据准备**
|
||
- 原始文本经过 BBPE 分词器转换为 token ID 序列
|
||
- 将 token ID 序列(可能带有掩码、标签等)按组保存为 `.h5` 文件
|
||
- 文件可包含多个分段,每个分段对应一个张量
|
||
|
||
2. **数据集加载**
|
||
- `BaseDataset` 的 `load` 方法调用 `load_h5`,得到 `segments` 字典
|
||
- 创建 `MultiSegmentFetcher` 管理多个键的数据
|
||
- 计算总样本数,并根据窗口大小、步长确定每个样本的起始/结束索引
|
||
|
||
3. **采样与批量加载**
|
||
- `ResumableDistributedSampler` 根据当前 epoch 和迭代位置生成索引序列
|
||
- `DataLoader` 使用采样器获取索引,调用数据集的 `__getitem__` 获取实际数据
|
||
- 批量数据形状为 `[batch_size, window_size]`(或根据具体数据集类型变化)
|
||
|
||
4. **策略前向与损失计算**
|
||
- 批次数据传入策略(如 `SeqStrategy`)
|
||
- 策略内部调用 `Transformer` 模型,得到 logits
|
||
- 根据任务类型计算交叉熵损失(或 DPO 损失等)
|
||
- 返回 loss 张量
|
||
|
||
5. **反向传播与优化**
|
||
- 损失除以累积步数进行归一化,然后执行 `loss.backward()`
|
||
- 每累积 `accumulation_steps` 个批次后,执行优化器 `step()` 和 `zero_grad()`
|
||
- 学习率调度器在每个 step 后更新学习率
|
||
|
||
6. **检查点保存**
|
||
- `CheckpointCallback` 按设定的间隔保存检查点
|
||
- 检查点包含模型状态字典、当前 epoch、iteration 等元数据
|
||
- 使用 safetensors 格式保存,确保安全与效率
|
||
|
||
## 推理数据流详细步骤
|
||
|
||
1. **模型加载**
|
||
- 从检查点加载 `Transformer` 模型与分词器
|
||
- 模型设置为评估模式 (`model.eval()`),启用推理模式 (`torch.inference_mode`)
|
||
|
||
2. **提示构建与编码**
|
||
- 用户查询与历史记录通过 `build_prompt` 转换为 ChatML 格式字符串
|
||
- 分词器将提示字符串编码为 token ID 序列 `input_ids`
|
||
- 若为批量生成,使用 `pad_sequence` 进行填充
|
||
|
||
3. **自回归生成循环**
|
||
- 初始化 KV 缓存(可选)
|
||
- 循环直到生成 `max_len` 个 token 或遇到停止 token:
|
||
- 将当前 `input_ids`(或缓存后的新 token)输入模型,得到 `logits`
|
||
- 对 `logits` 应用 `apply_sampling_strategies`(温度、top‑k、top‑p)
|
||
- 从处理后的分布中采样得到下一个 token ID
|
||
- 将新 token 追加到 `input_ids`,同时更新 KV 缓存
|
||
- 若为流式生成,每生成一个 token 立即 yield 给调用方
|
||
|
||
4. **解码与输出**
|
||
- 将生成的 token ID 序列通过分词器解码为文本
|
||
- 去除特殊 token,返回纯文本响应
|
||
|
||
## 检查点与序列化
|
||
|
||
- **训练检查点**:保存模型参数、优化器状态、调度器状态、当前 epoch 与 iteration
|
||
- **模型参数**:支持 safetensors 格式,加载时自动处理权重绑定等特殊逻辑
|
||
- **数据集序列化**:HDF5 格式支持高效随机读取与共享内存,适合大规模预训练数据
|
||
|
||
## 总结
|
||
|
||
KHAOSZ 的数据流设计体现了模块化、可扩展、可恢复的特点。训练数据流通过分块加载、可恢复采样、梯度累积等机制支持大规模分布式训练;推理数据流则利用 KV 缓存、采样策略实现高效的文本生成。各模块之间通过清晰的接口耦合,便于定制与扩展。
|
||
|
||
> 文档更新时间:2026‑03‑30
|
||
> 对应代码版本:参考 `pyproject.toml` 中定义的版本号 |