fix: 修复部分已知问题
This commit is contained in:
parent
3e33c14376
commit
0e7fc623b4
|
|
@ -0,0 +1,205 @@
|
|||
# KHAOSZ 数据流文档
|
||||
|
||||
本文档描述 KHAOSZ 项目(一个自回归 Transformer 语言模型的训练与推理框架)的数据流。涵盖从原始数据到模型训练、推理的完整流程。
|
||||
|
||||
## 概述
|
||||
|
||||
KHAOSZ 采用模块化设计,主要组件包括:
|
||||
- **数据模块** (`khaosz/data/`): 数据集、采样器、分词器、序列化工具
|
||||
- **模型模块** (`khaosz/model/`): Transformer 模型及其子模块
|
||||
- **训练模块** (`khaosz/trainer/`): 训练器、训练上下文、策略、调度器
|
||||
- **推理模块** (`khaosz/inference/`): 生成核心、KV 缓存管理、流式生成
|
||||
- **配置模块** (`khaosz/config/`): 模型、训练、调度等配置
|
||||
- **并行模块** (`khaosz/parallel/`): 分布式训练支持
|
||||
|
||||
数据流总体可分为 **训练数据流** 与 **推理数据流** 两条主线。
|
||||
|
||||
## 数据流图
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
subgraph A[数据准备]
|
||||
direction TB
|
||||
A1[原始文本] --> A2[BBPE 分词器]
|
||||
A2 --> A3[序列化为 .h5 文件]
|
||||
A3 --> A4[数据集加载<br/>BaseDataset]
|
||||
A4 --> A5[可恢复分布式采样器<br/>ResumableDistributedSampler]
|
||||
A5 --> A6[DataLoader 批量加载]
|
||||
end
|
||||
|
||||
subgraph B[训练循环]
|
||||
direction TB
|
||||
B1[批次数据] --> B2[训练策略<br/>BaseStrategy]
|
||||
B2 --> B3[Transformer 模型]
|
||||
B3 --> B4[输出 logits]
|
||||
B4 --> B5[损失计算]
|
||||
B5 --> B6[反向传播]
|
||||
B6 --> B7[优化器更新]
|
||||
B7 --> B8[学习率调度器]
|
||||
B8 --> B9[检查点保存]
|
||||
end
|
||||
|
||||
subgraph C[推理生成]
|
||||
direction TB
|
||||
C1[检查点加载] --> C2[推理模型加载]
|
||||
C2 --> C3[生成核心<br/>GeneratorCore]
|
||||
C3 --> C4[采样策略<br/>温度/top‑k/top‑p]
|
||||
C4 --> C5[生成下一个 token]
|
||||
C5 --> C6[KV 缓存更新]
|
||||
C6 --> C7{是否达到最大长度?}
|
||||
C7 -->|否| C5
|
||||
C7 -->|是| C8[输出生成文本]
|
||||
end
|
||||
|
||||
A --> B
|
||||
B --> C
|
||||
```
|
||||
|
||||
## 各模块详细说明
|
||||
|
||||
### 1. 数据模块
|
||||
|
||||
#### 1.1 分词器 (`tokenizer.py`)
|
||||
- 基于 Byte‑Level BPE (BBPE) 实现
|
||||
- 支持特殊 token:`<bos>`, `<eos>`, `<pad>`, `<|im_start|>`, `<|im_end|>`
|
||||
- 提供 `encode`/`decode` 方法,将文本与 token ID 相互转换
|
||||
- 训练时从语料库学习词汇表,保存为 `.json` 文件
|
||||
|
||||
#### 1.2 序列化 (`serialization.py`)
|
||||
- **`save_h5`**: 将多个张量按组保存为 HDF5 文件(`.h5`),每个键对应一个张量列表
|
||||
- **`load_h5`**: 加载 `.h5` 文件,返回 `Dict[str, List[Tensor]]`,支持共享内存 (`share_memory=True`)
|
||||
- **`Checkpoint` 类**: 封装模型状态字典、训练轮次、迭代次数,支持 safetensors 格式保存与加载
|
||||
|
||||
#### 1.3 数据集 (`dataset.py`)
|
||||
- **`BaseDataset`**: 抽象基类,定义窗口采样、步长等通用逻辑
|
||||
- **`BaseSegmentFetcher`** 与 **`MultiSegmentFetcher`**: 高效地从多个分段中获取指定索引范围的数据
|
||||
- **`DatasetFactory`**: 工厂模式,支持动态注册数据集类型(`seq`, `sft`, `dpo`, `grpo`)
|
||||
- 数据集加载后通过 `MultiSegmentFetcher` 管理多个数据键(如 `"sequence"`, `"mask"`)
|
||||
|
||||
#### 1.4 采样器 (`sampler.py`)
|
||||
- **`ResumableDistributedSampler`**: 支持分布式训练的可恢复采样器
|
||||
- 记录当前 epoch 和迭代位置,便于从断点继续训练
|
||||
- 支持 shuffle 与 drop_last 选项
|
||||
|
||||
### 2. 模型模块
|
||||
|
||||
#### 2.1 Transformer (`transformer.py`)
|
||||
- 核心自回归解码器架构
|
||||
- 包含嵌入层、多层 `DecoderBlock`、RMSNorm 和线性输出头
|
||||
- 支持权重绑定 (`tie_weight=True`) 以减小参数量
|
||||
- 使用 Rotary Position Embedding (RoPE) 注入位置信息
|
||||
|
||||
#### 2.2 子模块 (`module.py`)
|
||||
- **`RotaryEmbedding`**: 生成 RoPE 的 cos/sin 缓存
|
||||
- **`DecoderBlock`**: 包含多头注意力(支持 GQA)、前馈网络(FFN)、残差连接
|
||||
- **`RMSNorm`**: 层归一化变体
|
||||
- **`Linear`**, **`Embedding`**: 自定义线性层与嵌入层,支持并行化包装
|
||||
|
||||
### 3. 训练模块
|
||||
|
||||
#### 3.1 训练上下文 (`train_context.py`)
|
||||
- **`TrainContext`**: 数据类,封装训练所需的所有组件(模型、优化器、数据加载器、策略等)
|
||||
- **`TrainContextBuilder`**: 构建器模式,逐步组装训练上下文,支持从检查点恢复
|
||||
|
||||
#### 3.2 训练器 (`trainer.py`)
|
||||
- **`Trainer`**: 主训练循环,管理回调函数(进度条、检查点、指标记录、梯度裁剪、调度器)
|
||||
- 支持分布式训练(通过 `spawn_parallel_fn` 启动多进程)
|
||||
- 训练步骤包括:
|
||||
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. 前向/损失计算 → 5. `on_batch_end` → 6. 梯度累积 → 7. `on_step_begin` → 8. 优化器更新 → 9. `on_step_end` → 10. `on_epoch_end`
|
||||
|
||||
#### 3.3 策略 (`strategy.py`)
|
||||
- **`BaseStrategy`**: 定义训练策略接口(如 `SeqStrategy`, `SFTStrategy`, `DPOStrategy`)
|
||||
- 策略接收批次数据,执行模型前向传播、损失计算,返回 loss 张量
|
||||
- 由 `StrategyFactory` 根据配置动态创建
|
||||
|
||||
#### 3.4 调度器 (`schedule.py`)
|
||||
- **`BaseScheduler`**: 抽象基类,定义学习率调度接口
|
||||
- **`SchedulerFactory`**: 工厂模式,支持注册多种调度器(如 `cosine`, `sgdr`)
|
||||
- 调度器根据配置自动创建,并与优化器绑定
|
||||
|
||||
### 4. 推理模块
|
||||
|
||||
#### 4.1 生成核心 (`core.py`)
|
||||
- **`GeneratorCore`**: 提供 `generate_iterator` 方法,执行单步生成
|
||||
- 应用采样策略(温度、top‑k、top‑p)对 logits 进行筛选
|
||||
- 支持 KV 缓存以加速自回归生成
|
||||
|
||||
#### 4.2 KV 缓存管理 (`core.py`)
|
||||
- **`KVCacheManager`**: 管理每层的 K 和 V 缓存,支持批量生成与长度扩展
|
||||
- 缓存形状为 `[batch_size, n_kv_heads, seq_len, head_dim]`
|
||||
|
||||
#### 4.3 生成器 (`generator.py`)
|
||||
- **`GenerationRequest`**: 封装生成请求参数(top_k, top_p, temperature, max_len, query, history 等)
|
||||
- **`build_prompt`**: 将查询与历史记录转换为 ChatML 格式的提示字符串
|
||||
- **`pad_sequence`**: 对输入 ID 进行填充,使其长度一致
|
||||
- 提供流式与非流式生成接口
|
||||
|
||||
## 训练数据流详细步骤
|
||||
|
||||
1. **数据准备**
|
||||
- 原始文本经过 BBPE 分词器转换为 token ID 序列
|
||||
- 将 token ID 序列(可能带有掩码、标签等)按组保存为 `.h5` 文件
|
||||
- 文件可包含多个分段,每个分段对应一个张量
|
||||
|
||||
2. **数据集加载**
|
||||
- `BaseDataset` 的 `load` 方法调用 `load_h5`,得到 `segments` 字典
|
||||
- 创建 `MultiSegmentFetcher` 管理多个键的数据
|
||||
- 计算总样本数,并根据窗口大小、步长确定每个样本的起始/结束索引
|
||||
|
||||
3. **采样与批量加载**
|
||||
- `ResumableDistributedSampler` 根据当前 epoch 和迭代位置生成索引序列
|
||||
- `DataLoader` 使用采样器获取索引,调用数据集的 `__getitem__` 获取实际数据
|
||||
- 批量数据形状为 `[batch_size, window_size]`(或根据具体数据集类型变化)
|
||||
|
||||
4. **策略前向与损失计算**
|
||||
- 批次数据传入策略(如 `SeqStrategy`)
|
||||
- 策略内部调用 `Transformer` 模型,得到 logits
|
||||
- 根据任务类型计算交叉熵损失(或 DPO 损失等)
|
||||
- 返回 loss 张量
|
||||
|
||||
5. **反向传播与优化**
|
||||
- 损失除以累积步数进行归一化,然后执行 `loss.backward()`
|
||||
- 每累积 `accumulation_steps` 个批次后,执行优化器 `step()` 和 `zero_grad()`
|
||||
- 学习率调度器在每个 step 后更新学习率
|
||||
|
||||
6. **检查点保存**
|
||||
- `CheckpointCallback` 按设定的间隔保存检查点
|
||||
- 检查点包含模型状态字典、当前 epoch、iteration 等元数据
|
||||
- 使用 safetensors 格式保存,确保安全与效率
|
||||
|
||||
## 推理数据流详细步骤
|
||||
|
||||
1. **模型加载**
|
||||
- 从检查点加载 `Transformer` 模型与分词器
|
||||
- 模型设置为评估模式 (`model.eval()`),启用推理模式 (`torch.inference_mode`)
|
||||
|
||||
2. **提示构建与编码**
|
||||
- 用户查询与历史记录通过 `build_prompt` 转换为 ChatML 格式字符串
|
||||
- 分词器将提示字符串编码为 token ID 序列 `input_ids`
|
||||
- 若为批量生成,使用 `pad_sequence` 进行填充
|
||||
|
||||
3. **自回归生成循环**
|
||||
- 初始化 KV 缓存(可选)
|
||||
- 循环直到生成 `max_len` 个 token 或遇到停止 token:
|
||||
- 将当前 `input_ids`(或缓存后的新 token)输入模型,得到 `logits`
|
||||
- 对 `logits` 应用 `apply_sampling_strategies`(温度、top‑k、top‑p)
|
||||
- 从处理后的分布中采样得到下一个 token ID
|
||||
- 将新 token 追加到 `input_ids`,同时更新 KV 缓存
|
||||
- 若为流式生成,每生成一个 token 立即 yield 给调用方
|
||||
|
||||
4. **解码与输出**
|
||||
- 将生成的 token ID 序列通过分词器解码为文本
|
||||
- 去除特殊 token,返回纯文本响应
|
||||
|
||||
## 检查点与序列化
|
||||
|
||||
- **训练检查点**:保存模型参数、优化器状态、调度器状态、当前 epoch 与 iteration
|
||||
- **模型参数**:支持 safetensors 格式,加载时自动处理权重绑定等特殊逻辑
|
||||
- **数据集序列化**:HDF5 格式支持高效随机读取与共享内存,适合大规模预训练数据
|
||||
|
||||
## 总结
|
||||
|
||||
KHAOSZ 的数据流设计体现了模块化、可扩展、可恢复的特点。训练数据流通过分块加载、可恢复采样、梯度累积等机制支持大规模分布式训练;推理数据流则利用 KV 缓存、采样策略实现高效的文本生成。各模块之间通过清晰的接口耦合,便于定制与扩展。
|
||||
|
||||
> 文档更新时间:2026‑03‑30
|
||||
> 对应代码版本:参考 `pyproject.toml` 中定义的版本号
|
||||
|
|
@ -87,3 +87,32 @@ $$ k_j = R_j W_k x_j $$
|
|||
$$ q_i^T k_j = (R_i W_q x_i)^T( R_j W_k x_j) = x_i^T W_q^T R_{i-j} W_k x_j $$
|
||||
|
||||
其中的 $R_{i-j}$ 控制了模型的不同token 在不同相对距离上注意力的衰减,在 $i - j$ 绝对值越大的时候, 衰减的程度越强, 通过这种方式能让模型学习到相对位置关系, 从而使得模型可以扩展和适应长序列
|
||||
|
||||
|
||||
## kv_cache 实现
|
||||
|
||||
根据注意力的计算公式
|
||||
|
||||
$$
|
||||
\begin{align*}
|
||||
o_i &= \sum_j s_{ij} v_{j} \newline
|
||||
s_{ij} &= \text{softmax}\left( \frac{q_{i} k_{j}}{\sqrt{d_k}} \right)
|
||||
\end{align*}
|
||||
$$
|
||||
|
||||
由于模型是自回归模型, 我们只用求序列最后一个部分,也就是说 $ i $ 的下标是确定的, 是序列最后一个元素, 我们求的是 $o_{n} $
|
||||
|
||||
$$
|
||||
\begin{align*}
|
||||
o_n &= \sum_j s_{j}v_{j} \newline
|
||||
s_j &= \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}} \right)
|
||||
\end{align*}
|
||||
$$
|
||||
|
||||
如果我们把式子展开
|
||||
|
||||
$$
|
||||
o_n = \sum_j \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}}\right)v_{j}
|
||||
$$
|
||||
|
||||
以上表达式只有k和v存在长度下标, 而 $q$ 没有, 所以计算过程中 $q$ 的输入是确定的上次输入的最后一个token, 而 $k, v$ 是需要对不同长度的部分进行缓存的,同时缓存的时候应该注意位置编码的计算应该在kvcache的计算之前进行,否则会存在位置编码的计算错误
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
## kv_cache 实现
|
||||
|
||||
根据注意力的计算公式
|
||||
|
||||
$$
|
||||
\begin{align*}
|
||||
o_i &= \sum_j s_{ij} v_{j} \newline
|
||||
s_{ij} &= \text{softmax}\left( \frac{q_{i} k_{j}}{\sqrt{d_k}} \right)
|
||||
\end{align*}
|
||||
$$
|
||||
|
||||
由于模型是自回归模型, 我们只用求序列最后一个部分,也就是说 $ i $ 的下标是确定的, 是序列最后一个元素, 我们求的是 $o_{n} $
|
||||
|
||||
$$
|
||||
\begin{align*}
|
||||
o_n &= \sum_j s_{j}v_{j} \newline
|
||||
s_j &= \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}} \right)
|
||||
\end{align*}
|
||||
$$
|
||||
|
||||
如果我们把式子展开
|
||||
|
||||
$$
|
||||
o_n = \sum_j \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}}\right)v_{j}
|
||||
$$
|
||||
|
||||
以上表达式只有k和v存在长度下标, 而 $q$ 没有, 所以计算过程中 $q$ 的输入是确定的上次输入的最后一个token, 而 $k, v$ 是需要对不同长度的部分进行缓存的,同时缓存的时候应该注意位置编码的计算应该在kvcache的计算之前进行,否则会存在位置编码的计算错误
|
||||
|
|
@ -191,7 +191,7 @@ class EmbeddingEncoderCore:
|
|||
sentence_embs: List[Tensor] = []
|
||||
for i in range(len(batch_ids)):
|
||||
indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
|
||||
if indices is not None:
|
||||
if indices:
|
||||
sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size]
|
||||
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1]
|
||||
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
|
||||
|
|
@ -228,11 +228,11 @@ class KVCacheManager:
|
|||
self._initialize()
|
||||
|
||||
def _initialize(self):
|
||||
k_cache = torch.zeros(
|
||||
k_cache = torch.empty(
|
||||
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
v_cache = torch.empty(
|
||||
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ class RotaryEmbedding(nn.Module):
|
|||
seq_len = x.size(1)
|
||||
|
||||
if self.max_len_cached < seq_len + start_pos:
|
||||
self._set_rotary_buffer(seq_len)
|
||||
self._set_rotary_buffer(seq_len + start_pos)
|
||||
|
||||
cos = self.cos_cached[start_pos : start_pos + seq_len]
|
||||
sin = self.sin_cached[start_pos : start_pos + seq_len]
|
||||
|
|
@ -237,6 +237,7 @@ class MLA(nn.Module):
|
|||
use_gated_attention: bool,
|
||||
layer_id: int
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
self.n_kv_heads = n_kv_heads
|
||||
|
|
|
|||
|
|
@ -82,10 +82,13 @@ def only_on_rank(rank, sync=False):
|
|||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if get_rank() == rank:
|
||||
return func(*args, **kwargs)
|
||||
if sync:
|
||||
ret_args = func(*args, **kwargs)
|
||||
|
||||
if sync and dist.is_available() and dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
return ret_args
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
|
|
|||
|
|
@ -74,19 +74,16 @@ class SchedulerCallback(TrainCallback):
|
|||
Scheduler callback for trainer.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.scheduler: LRScheduler = None
|
||||
pass
|
||||
|
||||
def on_train_begin(self, context: TrainContext):
|
||||
for group in context.optimizer.param_groups:
|
||||
if "initial_lr" not in group:
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
self.scheduler = context.scheduler
|
||||
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
_ = context
|
||||
if self.scheduler:
|
||||
self.scheduler.step()
|
||||
if context.scheduler:
|
||||
context.scheduler.step()
|
||||
|
||||
|
||||
class CheckpointCallback(TrainCallback):
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ class TrainContextBuilder:
|
|||
return self
|
||||
|
||||
def with_strategy(self) -> Self:
|
||||
self._context.strategy = StrategyFactory.load(
|
||||
self._context.strategy = StrategyFactory.create(
|
||||
model=self._context.model,
|
||||
train_type=self.config.strategy,
|
||||
device=get_current_device(),
|
||||
|
|
|
|||
|
|
@ -15,10 +15,13 @@ dependencies = [
|
|||
"tqdm==4.67.1",
|
||||
"safetensors==0.5.3",
|
||||
"huggingface-hub==0.34.3",
|
||||
"pytest==9.0.2"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = ["pytest==9.0.2"]
|
||||
|
||||
keywords = ["nlp", "datasets", "language-models", "machine-learning"]
|
||||
license = { text = "GPL-3.0" }
|
||||
license = ["GPL-3.0"]
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: GPL-3.0",
|
||||
|
|
|
|||
Loading…
Reference in New Issue