# KHAOSZ 数据流文档 本文档描述 KHAOSZ 项目(一个自回归 Transformer 语言模型的训练与推理框架)的数据流。涵盖从原始数据到模型训练、推理的完整流程。 ## 概述 KHAOSZ 采用模块化设计,主要组件包括: - **数据模块** (`khaosz/data/`): 数据集、采样器、分词器、序列化工具 - **模型模块** (`khaosz/model/`): Transformer 模型及其子模块 - **训练模块** (`khaosz/trainer/`): 训练器、训练上下文、策略、调度器 - **推理模块** (`khaosz/inference/`): 生成核心、KV 缓存管理、流式生成 - **配置模块** (`khaosz/config/`): 模型、训练、调度等配置 - **并行模块** (`khaosz/parallel/`): 分布式训练支持 数据流总体可分为 **训练数据流** 与 **推理数据流** 两条主线。 ## 数据流图 ```mermaid flowchart LR subgraph A[数据准备] direction TB A1[原始文本] --> A2[BBPE 分词器] A2 --> A3[序列化为 .h5 文件] A3 --> A4[数据集加载
BaseDataset] A4 --> A5[可恢复分布式采样器
ResumableDistributedSampler] A5 --> A6[DataLoader 批量加载] end subgraph B[训练循环] direction TB B1[批次数据] --> B2[训练策略
BaseStrategy] B2 --> B3[Transformer 模型] B3 --> B4[输出 logits] B4 --> B5[损失计算] B5 --> B6[反向传播] B6 --> B7[优化器更新] B7 --> B8[学习率调度器] B8 --> B9[检查点保存] end subgraph C[推理生成] direction TB C1[检查点加载] --> C2[推理模型加载] C2 --> C3[生成核心
GeneratorCore] C3 --> C4[采样策略
温度/top‑k/top‑p] C4 --> C5[生成下一个 token] C5 --> C6[KV 缓存更新] C6 --> C7{是否达到最大长度?} C7 -->|否| C5 C7 -->|是| C8[输出生成文本] end A --> B B --> C ``` ## 各模块详细说明 ### 1. 数据模块 #### 1.1 分词器 (`tokenizer.py`) - 基于 Byte‑Level BPE (BBPE) 实现 - 支持特殊 token:``, ``, ``, `<|im_start|>`, `<|im_end|>` - 提供 `encode`/`decode` 方法,将文本与 token ID 相互转换 - 训练时从语料库学习词汇表,保存为 `.json` 文件 #### 1.2 序列化 (`serialization.py`) - **`save_h5`**: 将多个张量按组保存为 HDF5 文件(`.h5`),每个键对应一个张量列表 - **`load_h5`**: 加载 `.h5` 文件,返回 `Dict[str, List[Tensor]]`,支持共享内存 (`share_memory=True`) - **`Checkpoint` 类**: 封装模型状态字典、训练轮次、迭代次数,支持 safetensors 格式保存与加载 #### 1.3 数据集 (`dataset.py`) - **`BaseDataset`**: 抽象基类,定义窗口采样、步长等通用逻辑 - **`BaseSegmentFetcher`** 与 **`MultiSegmentFetcher`**: 高效地从多个分段中获取指定索引范围的数据 - **`DatasetFactory`**: 工厂模式,支持动态注册数据集类型(`seq`, `sft`, `dpo`, `grpo`) - 数据集加载后通过 `MultiSegmentFetcher` 管理多个数据键(如 `"sequence"`, `"mask"`) #### 1.4 采样器 (`sampler.py`) - **`ResumableDistributedSampler`**: 支持分布式训练的可恢复采样器 - 记录当前 epoch 和迭代位置,便于从断点继续训练 - 支持 shuffle 与 drop_last 选项 ### 2. 模型模块 #### 2.1 Transformer (`transformer.py`) - 核心自回归解码器架构 - 包含嵌入层、多层 `DecoderBlock`、RMSNorm 和线性输出头 - 支持权重绑定 (`tie_weight=True`) 以减小参数量 - 使用 Rotary Position Embedding (RoPE) 注入位置信息 #### 2.2 子模块 (`module.py`) - **`RotaryEmbedding`**: 生成 RoPE 的 cos/sin 缓存 - **`DecoderBlock`**: 包含多头注意力(支持 GQA)、前馈网络(FFN)、残差连接 - **`RMSNorm`**: 层归一化变体 - **`Linear`**, **`Embedding`**: 自定义线性层与嵌入层,支持并行化包装 ### 3. 训练模块 #### 3.1 训练上下文 (`train_context.py`) - **`TrainContext`**: 数据类,封装训练所需的所有组件(模型、优化器、数据加载器、策略等) - **`TrainContextBuilder`**: 构建器模式,逐步组装训练上下文,支持从检查点恢复 #### 3.2 训练器 (`trainer.py`) - **`Trainer`**: 主训练循环,管理回调函数(进度条、检查点、指标记录、梯度裁剪、调度器) - 支持分布式训练(通过 `spawn_parallel_fn` 启动多进程) - 训练步骤包括: 1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. 前向/损失计算 → 5. `on_batch_end` → 6. 梯度累积 → 7. `on_step_begin` → 8. 优化器更新 → 9. `on_step_end` → 10. `on_epoch_end` #### 3.3 策略 (`strategy.py`) - **`BaseStrategy`**: 定义训练策略接口(如 `SeqStrategy`, `SFTStrategy`, `DPOStrategy`) - 策略接收批次数据,执行模型前向传播、损失计算,返回 loss 张量 - 由 `StrategyFactory` 根据配置动态创建 #### 3.4 调度器 (`schedule.py`) - **`BaseScheduler`**: 抽象基类,定义学习率调度接口 - **`SchedulerFactory`**: 工厂模式,支持注册多种调度器(如 `cosine`, `sgdr`) - 调度器根据配置自动创建,并与优化器绑定 ### 4. 推理模块 #### 4.1 生成核心 (`core.py`) - **`GeneratorCore`**: 提供 `generate_iterator` 方法,执行单步生成 - 应用采样策略(温度、top‑k、top‑p)对 logits 进行筛选 - 支持 KV 缓存以加速自回归生成 #### 4.2 KV 缓存管理 (`core.py`) - **`KVCacheManager`**: 管理每层的 K 和 V 缓存,支持批量生成与长度扩展 - 缓存形状为 `[batch_size, n_kv_heads, seq_len, head_dim]` #### 4.3 生成器 (`generator.py`) - **`GenerationRequest`**: 封装生成请求参数(top_k, top_p, temperature, max_len, query, history 等) - **`build_prompt`**: 将查询与历史记录转换为 ChatML 格式的提示字符串 - **`pad_sequence`**: 对输入 ID 进行填充,使其长度一致 - 提供流式与非流式生成接口 ## 训练数据流详细步骤 1. **数据准备** - 原始文本经过 BBPE 分词器转换为 token ID 序列 - 将 token ID 序列(可能带有掩码、标签等)按组保存为 `.h5` 文件 - 文件可包含多个分段,每个分段对应一个张量 2. **数据集加载** - `BaseDataset` 的 `load` 方法调用 `load_h5`,得到 `segments` 字典 - 创建 `MultiSegmentFetcher` 管理多个键的数据 - 计算总样本数,并根据窗口大小、步长确定每个样本的起始/结束索引 3. **采样与批量加载** - `ResumableDistributedSampler` 根据当前 epoch 和迭代位置生成索引序列 - `DataLoader` 使用采样器获取索引,调用数据集的 `__getitem__` 获取实际数据 - 批量数据形状为 `[batch_size, window_size]`(或根据具体数据集类型变化) 4. **策略前向与损失计算** - 批次数据传入策略(如 `SeqStrategy`) - 策略内部调用 `Transformer` 模型,得到 logits - 根据任务类型计算交叉熵损失(或 DPO 损失等) - 返回 loss 张量 5. **反向传播与优化** - 损失除以累积步数进行归一化,然后执行 `loss.backward()` - 每累积 `accumulation_steps` 个批次后,执行优化器 `step()` 和 `zero_grad()` - 学习率调度器在每个 step 后更新学习率 6. **检查点保存** - `CheckpointCallback` 按设定的间隔保存检查点 - 检查点包含模型状态字典、当前 epoch、iteration 等元数据 - 使用 safetensors 格式保存,确保安全与效率 ## 推理数据流详细步骤 1. **模型加载** - 从检查点加载 `Transformer` 模型与分词器 - 模型设置为评估模式 (`model.eval()`),启用推理模式 (`torch.inference_mode`) 2. **提示构建与编码** - 用户查询与历史记录通过 `build_prompt` 转换为 ChatML 格式字符串 - 分词器将提示字符串编码为 token ID 序列 `input_ids` - 若为批量生成,使用 `pad_sequence` 进行填充 3. **自回归生成循环** - 初始化 KV 缓存(可选) - 循环直到生成 `max_len` 个 token 或遇到停止 token: - 将当前 `input_ids`(或缓存后的新 token)输入模型,得到 `logits` - 对 `logits` 应用 `apply_sampling_strategies`(温度、top‑k、top‑p) - 从处理后的分布中采样得到下一个 token ID - 将新 token 追加到 `input_ids`,同时更新 KV 缓存 - 若为流式生成,每生成一个 token 立即 yield 给调用方 4. **解码与输出** - 将生成的 token ID 序列通过分词器解码为文本 - 去除特殊 token,返回纯文本响应 ## 检查点与序列化 - **训练检查点**:保存模型参数、优化器状态、调度器状态、当前 epoch 与 iteration - **模型参数**:支持 safetensors 格式,加载时自动处理权重绑定等特殊逻辑 - **数据集序列化**:HDF5 格式支持高效随机读取与共享内存,适合大规模预训练数据 ## 总结 KHAOSZ 的数据流设计体现了模块化、可扩展、可恢复的特点。训练数据流通过分块加载、可恢复采样、梯度累积等机制支持大规模分布式训练;推理数据流则利用 KV 缓存、采样策略实现高效的文本生成。各模块之间通过清晰的接口耦合,便于定制与扩展。 > 文档更新时间:2026‑03‑30 > 对应代码版本:参考 `pyproject.toml` 中定义的版本号