diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md new file mode 100644 index 0000000..6ba9c74 --- /dev/null +++ b/assets/docs/dataflow.md @@ -0,0 +1,205 @@ +# KHAOSZ 数据流文档 + +本文档描述 KHAOSZ 项目(一个自回归 Transformer 语言模型的训练与推理框架)的数据流。涵盖从原始数据到模型训练、推理的完整流程。 + +## 概述 + +KHAOSZ 采用模块化设计,主要组件包括: +- **数据模块** (`khaosz/data/`): 数据集、采样器、分词器、序列化工具 +- **模型模块** (`khaosz/model/`): Transformer 模型及其子模块 +- **训练模块** (`khaosz/trainer/`): 训练器、训练上下文、策略、调度器 +- **推理模块** (`khaosz/inference/`): 生成核心、KV 缓存管理、流式生成 +- **配置模块** (`khaosz/config/`): 模型、训练、调度等配置 +- **并行模块** (`khaosz/parallel/`): 分布式训练支持 + +数据流总体可分为 **训练数据流** 与 **推理数据流** 两条主线。 + +## 数据流图 + +```mermaid +flowchart LR + subgraph A[数据准备] + direction TB + A1[原始文本] --> A2[BBPE 分词器] + A2 --> A3[序列化为 .h5 文件] + A3 --> A4[数据集加载
BaseDataset] + A4 --> A5[可恢复分布式采样器
ResumableDistributedSampler] + A5 --> A6[DataLoader 批量加载] + end + + subgraph B[训练循环] + direction TB + B1[批次数据] --> B2[训练策略
BaseStrategy] + B2 --> B3[Transformer 模型] + B3 --> B4[输出 logits] + B4 --> B5[损失计算] + B5 --> B6[反向传播] + B6 --> B7[优化器更新] + B7 --> B8[学习率调度器] + B8 --> B9[检查点保存] + end + + subgraph C[推理生成] + direction TB + C1[检查点加载] --> C2[推理模型加载] + C2 --> C3[生成核心
GeneratorCore] + C3 --> C4[采样策略
温度/top‑k/top‑p] + C4 --> C5[生成下一个 token] + C5 --> C6[KV 缓存更新] + C6 --> C7{是否达到最大长度?} + C7 -->|否| C5 + C7 -->|是| C8[输出生成文本] + end + + A --> B + B --> C +``` + +## 各模块详细说明 + +### 1. 数据模块 + +#### 1.1 分词器 (`tokenizer.py`) +- 基于 Byte‑Level BPE (BBPE) 实现 +- 支持特殊 token:``, ``, ``, `<|im_start|>`, `<|im_end|>` +- 提供 `encode`/`decode` 方法,将文本与 token ID 相互转换 +- 训练时从语料库学习词汇表,保存为 `.json` 文件 + +#### 1.2 序列化 (`serialization.py`) +- **`save_h5`**: 将多个张量按组保存为 HDF5 文件(`.h5`),每个键对应一个张量列表 +- **`load_h5`**: 加载 `.h5` 文件,返回 `Dict[str, List[Tensor]]`,支持共享内存 (`share_memory=True`) +- **`Checkpoint` 类**: 封装模型状态字典、训练轮次、迭代次数,支持 safetensors 格式保存与加载 + +#### 1.3 数据集 (`dataset.py`) +- **`BaseDataset`**: 抽象基类,定义窗口采样、步长等通用逻辑 +- **`BaseSegmentFetcher`** 与 **`MultiSegmentFetcher`**: 高效地从多个分段中获取指定索引范围的数据 +- **`DatasetFactory`**: 工厂模式,支持动态注册数据集类型(`seq`, `sft`, `dpo`, `grpo`) +- 数据集加载后通过 `MultiSegmentFetcher` 管理多个数据键(如 `"sequence"`, `"mask"`) + +#### 1.4 采样器 (`sampler.py`) +- **`ResumableDistributedSampler`**: 支持分布式训练的可恢复采样器 +- 记录当前 epoch 和迭代位置,便于从断点继续训练 +- 支持 shuffle 与 drop_last 选项 + +### 2. 模型模块 + +#### 2.1 Transformer (`transformer.py`) +- 核心自回归解码器架构 +- 包含嵌入层、多层 `DecoderBlock`、RMSNorm 和线性输出头 +- 支持权重绑定 (`tie_weight=True`) 以减小参数量 +- 使用 Rotary Position Embedding (RoPE) 注入位置信息 + +#### 2.2 子模块 (`module.py`) +- **`RotaryEmbedding`**: 生成 RoPE 的 cos/sin 缓存 +- **`DecoderBlock`**: 包含多头注意力(支持 GQA)、前馈网络(FFN)、残差连接 +- **`RMSNorm`**: 层归一化变体 +- **`Linear`**, **`Embedding`**: 自定义线性层与嵌入层,支持并行化包装 + +### 3. 训练模块 + +#### 3.1 训练上下文 (`train_context.py`) +- **`TrainContext`**: 数据类,封装训练所需的所有组件(模型、优化器、数据加载器、策略等) +- **`TrainContextBuilder`**: 构建器模式,逐步组装训练上下文,支持从检查点恢复 + +#### 3.2 训练器 (`trainer.py`) +- **`Trainer`**: 主训练循环,管理回调函数(进度条、检查点、指标记录、梯度裁剪、调度器) +- 支持分布式训练(通过 `spawn_parallel_fn` 启动多进程) +- 训练步骤包括: + 1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. 前向/损失计算 → 5. `on_batch_end` → 6. 梯度累积 → 7. `on_step_begin` → 8. 优化器更新 → 9. `on_step_end` → 10. `on_epoch_end` + +#### 3.3 策略 (`strategy.py`) +- **`BaseStrategy`**: 定义训练策略接口(如 `SeqStrategy`, `SFTStrategy`, `DPOStrategy`) +- 策略接收批次数据,执行模型前向传播、损失计算,返回 loss 张量 +- 由 `StrategyFactory` 根据配置动态创建 + +#### 3.4 调度器 (`schedule.py`) +- **`BaseScheduler`**: 抽象基类,定义学习率调度接口 +- **`SchedulerFactory`**: 工厂模式,支持注册多种调度器(如 `cosine`, `sgdr`) +- 调度器根据配置自动创建,并与优化器绑定 + +### 4. 推理模块 + +#### 4.1 生成核心 (`core.py`) +- **`GeneratorCore`**: 提供 `generate_iterator` 方法,执行单步生成 +- 应用采样策略(温度、top‑k、top‑p)对 logits 进行筛选 +- 支持 KV 缓存以加速自回归生成 + +#### 4.2 KV 缓存管理 (`core.py`) +- **`KVCacheManager`**: 管理每层的 K 和 V 缓存,支持批量生成与长度扩展 +- 缓存形状为 `[batch_size, n_kv_heads, seq_len, head_dim]` + +#### 4.3 生成器 (`generator.py`) +- **`GenerationRequest`**: 封装生成请求参数(top_k, top_p, temperature, max_len, query, history 等) +- **`build_prompt`**: 将查询与历史记录转换为 ChatML 格式的提示字符串 +- **`pad_sequence`**: 对输入 ID 进行填充,使其长度一致 +- 提供流式与非流式生成接口 + +## 训练数据流详细步骤 + +1. **数据准备** + - 原始文本经过 BBPE 分词器转换为 token ID 序列 + - 将 token ID 序列(可能带有掩码、标签等)按组保存为 `.h5` 文件 + - 文件可包含多个分段,每个分段对应一个张量 + +2. **数据集加载** + - `BaseDataset` 的 `load` 方法调用 `load_h5`,得到 `segments` 字典 + - 创建 `MultiSegmentFetcher` 管理多个键的数据 + - 计算总样本数,并根据窗口大小、步长确定每个样本的起始/结束索引 + +3. **采样与批量加载** + - `ResumableDistributedSampler` 根据当前 epoch 和迭代位置生成索引序列 + - `DataLoader` 使用采样器获取索引,调用数据集的 `__getitem__` 获取实际数据 + - 批量数据形状为 `[batch_size, window_size]`(或根据具体数据集类型变化) + +4. **策略前向与损失计算** + - 批次数据传入策略(如 `SeqStrategy`) + - 策略内部调用 `Transformer` 模型,得到 logits + - 根据任务类型计算交叉熵损失(或 DPO 损失等) + - 返回 loss 张量 + +5. **反向传播与优化** + - 损失除以累积步数进行归一化,然后执行 `loss.backward()` + - 每累积 `accumulation_steps` 个批次后,执行优化器 `step()` 和 `zero_grad()` + - 学习率调度器在每个 step 后更新学习率 + +6. **检查点保存** + - `CheckpointCallback` 按设定的间隔保存检查点 + - 检查点包含模型状态字典、当前 epoch、iteration 等元数据 + - 使用 safetensors 格式保存,确保安全与效率 + +## 推理数据流详细步骤 + +1. **模型加载** + - 从检查点加载 `Transformer` 模型与分词器 + - 模型设置为评估模式 (`model.eval()`),启用推理模式 (`torch.inference_mode`) + +2. **提示构建与编码** + - 用户查询与历史记录通过 `build_prompt` 转换为 ChatML 格式字符串 + - 分词器将提示字符串编码为 token ID 序列 `input_ids` + - 若为批量生成,使用 `pad_sequence` 进行填充 + +3. **自回归生成循环** + - 初始化 KV 缓存(可选) + - 循环直到生成 `max_len` 个 token 或遇到停止 token: + - 将当前 `input_ids`(或缓存后的新 token)输入模型,得到 `logits` + - 对 `logits` 应用 `apply_sampling_strategies`(温度、top‑k、top‑p) + - 从处理后的分布中采样得到下一个 token ID + - 将新 token 追加到 `input_ids`,同时更新 KV 缓存 + - 若为流式生成,每生成一个 token 立即 yield 给调用方 + +4. **解码与输出** + - 将生成的 token ID 序列通过分词器解码为文本 + - 去除特殊 token,返回纯文本响应 + +## 检查点与序列化 + +- **训练检查点**:保存模型参数、优化器状态、调度器状态、当前 epoch 与 iteration +- **模型参数**:支持 safetensors 格式,加载时自动处理权重绑定等特殊逻辑 +- **数据集序列化**:HDF5 格式支持高效随机读取与共享内存,适合大规模预训练数据 + +## 总结 + +KHAOSZ 的数据流设计体现了模块化、可扩展、可恢复的特点。训练数据流通过分块加载、可恢复采样、梯度累积等机制支持大规模分布式训练;推理数据流则利用 KV 缓存、采样策略实现高效的文本生成。各模块之间通过清晰的接口耦合,便于定制与扩展。 + +> 文档更新时间:2026‑03‑30 +> 对应代码版本:参考 `pyproject.toml` 中定义的版本号 \ No newline at end of file diff --git a/assets/docs/architecture.md b/assets/docs/design.md similarity index 100% rename from assets/docs/architecture.md rename to assets/docs/design.md diff --git a/assets/docs/introduction.md b/assets/docs/introduction.md index 7333403..4b678f5 100644 --- a/assets/docs/introduction.md +++ b/assets/docs/introduction.md @@ -86,4 +86,33 @@ $$ q_i = R_i W_q x_i $$ $$ k_j = R_j W_k x_j $$ $$ q_i^T k_j = (R_i W_q x_i)^T( R_j W_k x_j) = x_i^T W_q^T R_{i-j} W_k x_j $$ -其中的 $R_{i-j}$ 控制了模型的不同token 在不同相对距离上注意力的衰减,在 $i - j$ 绝对值越大的时候, 衰减的程度越强, 通过这种方式能让模型学习到相对位置关系, 从而使得模型可以扩展和适应长序列 \ No newline at end of file +其中的 $R_{i-j}$ 控制了模型的不同token 在不同相对距离上注意力的衰减,在 $i - j$ 绝对值越大的时候, 衰减的程度越强, 通过这种方式能让模型学习到相对位置关系, 从而使得模型可以扩展和适应长序列 + + +## kv_cache 实现 + +根据注意力的计算公式 + +$$ +\begin{align*} +o_i &= \sum_j s_{ij} v_{j} \newline +s_{ij} &= \text{softmax}\left( \frac{q_{i} k_{j}}{\sqrt{d_k}} \right) +\end{align*} +$$ + +由于模型是自回归模型, 我们只用求序列最后一个部分,也就是说 $ i $ 的下标是确定的, 是序列最后一个元素, 我们求的是 $o_{n} $ + +$$ +\begin{align*} +o_n &= \sum_j s_{j}v_{j} \newline +s_j &= \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}} \right) +\end{align*} +$$ + +如果我们把式子展开 + +$$ +o_n = \sum_j \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}}\right)v_{j} +$$ + +以上表达式只有k和v存在长度下标, 而 $q$ 没有, 所以计算过程中 $q$ 的输入是确定的上次输入的最后一个token, 而 $k, v$ 是需要对不同长度的部分进行缓存的,同时缓存的时候应该注意位置编码的计算应该在kvcache的计算之前进行,否则会存在位置编码的计算错误 \ No newline at end of file diff --git a/assets/docs/kvcache.md b/assets/docs/kvcache.md deleted file mode 100644 index 63a5511..0000000 --- a/assets/docs/kvcache.md +++ /dev/null @@ -1,27 +0,0 @@ -## kv_cache 实现 - -根据注意力的计算公式 - -$$ -\begin{align*} -o_i &= \sum_j s_{ij} v_{j} \newline -s_{ij} &= \text{softmax}\left( \frac{q_{i} k_{j}}{\sqrt{d_k}} \right) -\end{align*} -$$ - -由于模型是自回归模型, 我们只用求序列最后一个部分,也就是说 $ i $ 的下标是确定的, 是序列最后一个元素, 我们求的是 $o_{n} $ - -$$ -\begin{align*} -o_n &= \sum_j s_{j}v_{j} \newline -s_j &= \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}} \right) -\end{align*} -$$ - -如果我们把式子展开 - -$$ -o_n = \sum_j \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}}\right)v_{j} -$$ - -以上表达式只有k和v存在长度下标, 而 $q$ 没有, 所以计算过程中 $q$ 的输入是确定的上次输入的最后一个token, 而 $k, v$ 是需要对不同长度的部分进行缓存的,同时缓存的时候应该注意位置编码的计算应该在kvcache的计算之前进行,否则会存在位置编码的计算错误 \ No newline at end of file diff --git a/khaosz/inference/core.py b/khaosz/inference/core.py index d2fbcc7..dc37229 100644 --- a/khaosz/inference/core.py +++ b/khaosz/inference/core.py @@ -191,7 +191,7 @@ class EmbeddingEncoderCore: sentence_embs: List[Tensor] = [] for i in range(len(batch_ids)): indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i] - if indices is not None: + if indices: sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size] length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1] emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size] @@ -228,11 +228,11 @@ class KVCacheManager: self._initialize() def _initialize(self): - k_cache = torch.zeros( + k_cache = torch.empty( (self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim), device=self.device, dtype=self.dtype ) - v_cache = torch.zeros( + v_cache = torch.empty( (self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim), device=self.device, dtype=self.dtype ) diff --git a/khaosz/model/module.py b/khaosz/model/module.py index 14d5dc6..ca14b40 100644 --- a/khaosz/model/module.py +++ b/khaosz/model/module.py @@ -93,7 +93,7 @@ class RotaryEmbedding(nn.Module): seq_len = x.size(1) if self.max_len_cached < seq_len + start_pos: - self._set_rotary_buffer(seq_len) + self._set_rotary_buffer(seq_len + start_pos) cos = self.cos_cached[start_pos : start_pos + seq_len] sin = self.sin_cached[start_pos : start_pos + seq_len] @@ -237,6 +237,7 @@ class MLA(nn.Module): use_gated_attention: bool, layer_id: int ): + super().__init__() self.dim = dim self.n_heads = n_heads self.n_kv_heads = n_kv_heads diff --git a/khaosz/parallel/setup.py b/khaosz/parallel/setup.py index c1c6686..452e30f 100644 --- a/khaosz/parallel/setup.py +++ b/khaosz/parallel/setup.py @@ -82,9 +82,12 @@ def only_on_rank(rank, sync=False): @wraps(func) def wrapper(*args, **kwargs): if get_rank() == rank: - return func(*args, **kwargs) - if sync: + ret_args = func(*args, **kwargs) + + if sync and dist.is_available() and dist.is_initialized(): dist.barrier() + + return ret_args return wrapper diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index 6b12b0d..986b72a 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -74,19 +74,16 @@ class SchedulerCallback(TrainCallback): Scheduler callback for trainer. """ def __init__(self): - self.scheduler: LRScheduler = None + pass def on_train_begin(self, context: TrainContext): for group in context.optimizer.param_groups: if "initial_lr" not in group: group["initial_lr"] = group["lr"] - - self.scheduler = context.scheduler def on_batch_end(self, context: TrainContext): - _ = context - if self.scheduler: - self.scheduler.step() + if context.scheduler: + context.scheduler.step() class CheckpointCallback(TrainCallback): diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py index f2172ad..0f7e08c 100644 --- a/khaosz/trainer/train_context.py +++ b/khaosz/trainer/train_context.py @@ -87,7 +87,7 @@ class TrainContextBuilder: return self def with_strategy(self) -> Self: - self._context.strategy = StrategyFactory.load( + self._context.strategy = StrategyFactory.create( model=self._context.model, train_type=self.config.strategy, device=get_current_device(), diff --git a/pyproject.toml b/pyproject.toml index e64e83d..f2b5e2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,10 +15,13 @@ dependencies = [ "tqdm==4.67.1", "safetensors==0.5.3", "huggingface-hub==0.34.3", - "pytest==9.0.2" ] + +[project.optional-dependencies] +dev = ["pytest==9.0.2"] + keywords = ["nlp", "datasets", "language-models", "machine-learning"] -license = { text = "GPL-3.0" } +license = ["GPL-3.0"] classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: GPL-3.0",