diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md
new file mode 100644
index 0000000..6ba9c74
--- /dev/null
+++ b/assets/docs/dataflow.md
@@ -0,0 +1,205 @@
+# KHAOSZ 数据流文档
+
+本文档描述 KHAOSZ 项目(一个自回归 Transformer 语言模型的训练与推理框架)的数据流。涵盖从原始数据到模型训练、推理的完整流程。
+
+## 概述
+
+KHAOSZ 采用模块化设计,主要组件包括:
+- **数据模块** (`khaosz/data/`): 数据集、采样器、分词器、序列化工具
+- **模型模块** (`khaosz/model/`): Transformer 模型及其子模块
+- **训练模块** (`khaosz/trainer/`): 训练器、训练上下文、策略、调度器
+- **推理模块** (`khaosz/inference/`): 生成核心、KV 缓存管理、流式生成
+- **配置模块** (`khaosz/config/`): 模型、训练、调度等配置
+- **并行模块** (`khaosz/parallel/`): 分布式训练支持
+
+数据流总体可分为 **训练数据流** 与 **推理数据流** 两条主线。
+
+## 数据流图
+
+```mermaid
+flowchart LR
+ subgraph A[数据准备]
+ direction TB
+ A1[原始文本] --> A2[BBPE 分词器]
+ A2 --> A3[序列化为 .h5 文件]
+ A3 --> A4[数据集加载
BaseDataset]
+ A4 --> A5[可恢复分布式采样器
ResumableDistributedSampler]
+ A5 --> A6[DataLoader 批量加载]
+ end
+
+ subgraph B[训练循环]
+ direction TB
+ B1[批次数据] --> B2[训练策略
BaseStrategy]
+ B2 --> B3[Transformer 模型]
+ B3 --> B4[输出 logits]
+ B4 --> B5[损失计算]
+ B5 --> B6[反向传播]
+ B6 --> B7[优化器更新]
+ B7 --> B8[学习率调度器]
+ B8 --> B9[检查点保存]
+ end
+
+ subgraph C[推理生成]
+ direction TB
+ C1[检查点加载] --> C2[推理模型加载]
+ C2 --> C3[生成核心
GeneratorCore]
+ C3 --> C4[采样策略
温度/top‑k/top‑p]
+ C4 --> C5[生成下一个 token]
+ C5 --> C6[KV 缓存更新]
+ C6 --> C7{是否达到最大长度?}
+ C7 -->|否| C5
+ C7 -->|是| C8[输出生成文本]
+ end
+
+ A --> B
+ B --> C
+```
+
+## 各模块详细说明
+
+### 1. 数据模块
+
+#### 1.1 分词器 (`tokenizer.py`)
+- 基于 Byte‑Level BPE (BBPE) 实现
+- 支持特殊 token:``, ``, ``, `<|im_start|>`, `<|im_end|>`
+- 提供 `encode`/`decode` 方法,将文本与 token ID 相互转换
+- 训练时从语料库学习词汇表,保存为 `.json` 文件
+
+#### 1.2 序列化 (`serialization.py`)
+- **`save_h5`**: 将多个张量按组保存为 HDF5 文件(`.h5`),每个键对应一个张量列表
+- **`load_h5`**: 加载 `.h5` 文件,返回 `Dict[str, List[Tensor]]`,支持共享内存 (`share_memory=True`)
+- **`Checkpoint` 类**: 封装模型状态字典、训练轮次、迭代次数,支持 safetensors 格式保存与加载
+
+#### 1.3 数据集 (`dataset.py`)
+- **`BaseDataset`**: 抽象基类,定义窗口采样、步长等通用逻辑
+- **`BaseSegmentFetcher`** 与 **`MultiSegmentFetcher`**: 高效地从多个分段中获取指定索引范围的数据
+- **`DatasetFactory`**: 工厂模式,支持动态注册数据集类型(`seq`, `sft`, `dpo`, `grpo`)
+- 数据集加载后通过 `MultiSegmentFetcher` 管理多个数据键(如 `"sequence"`, `"mask"`)
+
+#### 1.4 采样器 (`sampler.py`)
+- **`ResumableDistributedSampler`**: 支持分布式训练的可恢复采样器
+- 记录当前 epoch 和迭代位置,便于从断点继续训练
+- 支持 shuffle 与 drop_last 选项
+
+### 2. 模型模块
+
+#### 2.1 Transformer (`transformer.py`)
+- 核心自回归解码器架构
+- 包含嵌入层、多层 `DecoderBlock`、RMSNorm 和线性输出头
+- 支持权重绑定 (`tie_weight=True`) 以减小参数量
+- 使用 Rotary Position Embedding (RoPE) 注入位置信息
+
+#### 2.2 子模块 (`module.py`)
+- **`RotaryEmbedding`**: 生成 RoPE 的 cos/sin 缓存
+- **`DecoderBlock`**: 包含多头注意力(支持 GQA)、前馈网络(FFN)、残差连接
+- **`RMSNorm`**: 层归一化变体
+- **`Linear`**, **`Embedding`**: 自定义线性层与嵌入层,支持并行化包装
+
+### 3. 训练模块
+
+#### 3.1 训练上下文 (`train_context.py`)
+- **`TrainContext`**: 数据类,封装训练所需的所有组件(模型、优化器、数据加载器、策略等)
+- **`TrainContextBuilder`**: 构建器模式,逐步组装训练上下文,支持从检查点恢复
+
+#### 3.2 训练器 (`trainer.py`)
+- **`Trainer`**: 主训练循环,管理回调函数(进度条、检查点、指标记录、梯度裁剪、调度器)
+- 支持分布式训练(通过 `spawn_parallel_fn` 启动多进程)
+- 训练步骤包括:
+ 1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. 前向/损失计算 → 5. `on_batch_end` → 6. 梯度累积 → 7. `on_step_begin` → 8. 优化器更新 → 9. `on_step_end` → 10. `on_epoch_end`
+
+#### 3.3 策略 (`strategy.py`)
+- **`BaseStrategy`**: 定义训练策略接口(如 `SeqStrategy`, `SFTStrategy`, `DPOStrategy`)
+- 策略接收批次数据,执行模型前向传播、损失计算,返回 loss 张量
+- 由 `StrategyFactory` 根据配置动态创建
+
+#### 3.4 调度器 (`schedule.py`)
+- **`BaseScheduler`**: 抽象基类,定义学习率调度接口
+- **`SchedulerFactory`**: 工厂模式,支持注册多种调度器(如 `cosine`, `sgdr`)
+- 调度器根据配置自动创建,并与优化器绑定
+
+### 4. 推理模块
+
+#### 4.1 生成核心 (`core.py`)
+- **`GeneratorCore`**: 提供 `generate_iterator` 方法,执行单步生成
+- 应用采样策略(温度、top‑k、top‑p)对 logits 进行筛选
+- 支持 KV 缓存以加速自回归生成
+
+#### 4.2 KV 缓存管理 (`core.py`)
+- **`KVCacheManager`**: 管理每层的 K 和 V 缓存,支持批量生成与长度扩展
+- 缓存形状为 `[batch_size, n_kv_heads, seq_len, head_dim]`
+
+#### 4.3 生成器 (`generator.py`)
+- **`GenerationRequest`**: 封装生成请求参数(top_k, top_p, temperature, max_len, query, history 等)
+- **`build_prompt`**: 将查询与历史记录转换为 ChatML 格式的提示字符串
+- **`pad_sequence`**: 对输入 ID 进行填充,使其长度一致
+- 提供流式与非流式生成接口
+
+## 训练数据流详细步骤
+
+1. **数据准备**
+ - 原始文本经过 BBPE 分词器转换为 token ID 序列
+ - 将 token ID 序列(可能带有掩码、标签等)按组保存为 `.h5` 文件
+ - 文件可包含多个分段,每个分段对应一个张量
+
+2. **数据集加载**
+ - `BaseDataset` 的 `load` 方法调用 `load_h5`,得到 `segments` 字典
+ - 创建 `MultiSegmentFetcher` 管理多个键的数据
+ - 计算总样本数,并根据窗口大小、步长确定每个样本的起始/结束索引
+
+3. **采样与批量加载**
+ - `ResumableDistributedSampler` 根据当前 epoch 和迭代位置生成索引序列
+ - `DataLoader` 使用采样器获取索引,调用数据集的 `__getitem__` 获取实际数据
+ - 批量数据形状为 `[batch_size, window_size]`(或根据具体数据集类型变化)
+
+4. **策略前向与损失计算**
+ - 批次数据传入策略(如 `SeqStrategy`)
+ - 策略内部调用 `Transformer` 模型,得到 logits
+ - 根据任务类型计算交叉熵损失(或 DPO 损失等)
+ - 返回 loss 张量
+
+5. **反向传播与优化**
+ - 损失除以累积步数进行归一化,然后执行 `loss.backward()`
+ - 每累积 `accumulation_steps` 个批次后,执行优化器 `step()` 和 `zero_grad()`
+ - 学习率调度器在每个 step 后更新学习率
+
+6. **检查点保存**
+ - `CheckpointCallback` 按设定的间隔保存检查点
+ - 检查点包含模型状态字典、当前 epoch、iteration 等元数据
+ - 使用 safetensors 格式保存,确保安全与效率
+
+## 推理数据流详细步骤
+
+1. **模型加载**
+ - 从检查点加载 `Transformer` 模型与分词器
+ - 模型设置为评估模式 (`model.eval()`),启用推理模式 (`torch.inference_mode`)
+
+2. **提示构建与编码**
+ - 用户查询与历史记录通过 `build_prompt` 转换为 ChatML 格式字符串
+ - 分词器将提示字符串编码为 token ID 序列 `input_ids`
+ - 若为批量生成,使用 `pad_sequence` 进行填充
+
+3. **自回归生成循环**
+ - 初始化 KV 缓存(可选)
+ - 循环直到生成 `max_len` 个 token 或遇到停止 token:
+ - 将当前 `input_ids`(或缓存后的新 token)输入模型,得到 `logits`
+ - 对 `logits` 应用 `apply_sampling_strategies`(温度、top‑k、top‑p)
+ - 从处理后的分布中采样得到下一个 token ID
+ - 将新 token 追加到 `input_ids`,同时更新 KV 缓存
+ - 若为流式生成,每生成一个 token 立即 yield 给调用方
+
+4. **解码与输出**
+ - 将生成的 token ID 序列通过分词器解码为文本
+ - 去除特殊 token,返回纯文本响应
+
+## 检查点与序列化
+
+- **训练检查点**:保存模型参数、优化器状态、调度器状态、当前 epoch 与 iteration
+- **模型参数**:支持 safetensors 格式,加载时自动处理权重绑定等特殊逻辑
+- **数据集序列化**:HDF5 格式支持高效随机读取与共享内存,适合大规模预训练数据
+
+## 总结
+
+KHAOSZ 的数据流设计体现了模块化、可扩展、可恢复的特点。训练数据流通过分块加载、可恢复采样、梯度累积等机制支持大规模分布式训练;推理数据流则利用 KV 缓存、采样策略实现高效的文本生成。各模块之间通过清晰的接口耦合,便于定制与扩展。
+
+> 文档更新时间:2026‑03‑30
+> 对应代码版本:参考 `pyproject.toml` 中定义的版本号
\ No newline at end of file
diff --git a/assets/docs/architecture.md b/assets/docs/design.md
similarity index 100%
rename from assets/docs/architecture.md
rename to assets/docs/design.md
diff --git a/assets/docs/introduction.md b/assets/docs/introduction.md
index 7333403..4b678f5 100644
--- a/assets/docs/introduction.md
+++ b/assets/docs/introduction.md
@@ -86,4 +86,33 @@ $$ q_i = R_i W_q x_i $$
$$ k_j = R_j W_k x_j $$
$$ q_i^T k_j = (R_i W_q x_i)^T( R_j W_k x_j) = x_i^T W_q^T R_{i-j} W_k x_j $$
-其中的 $R_{i-j}$ 控制了模型的不同token 在不同相对距离上注意力的衰减,在 $i - j$ 绝对值越大的时候, 衰减的程度越强, 通过这种方式能让模型学习到相对位置关系, 从而使得模型可以扩展和适应长序列
\ No newline at end of file
+其中的 $R_{i-j}$ 控制了模型的不同token 在不同相对距离上注意力的衰减,在 $i - j$ 绝对值越大的时候, 衰减的程度越强, 通过这种方式能让模型学习到相对位置关系, 从而使得模型可以扩展和适应长序列
+
+
+## kv_cache 实现
+
+根据注意力的计算公式
+
+$$
+\begin{align*}
+o_i &= \sum_j s_{ij} v_{j} \newline
+s_{ij} &= \text{softmax}\left( \frac{q_{i} k_{j}}{\sqrt{d_k}} \right)
+\end{align*}
+$$
+
+由于模型是自回归模型, 我们只用求序列最后一个部分,也就是说 $ i $ 的下标是确定的, 是序列最后一个元素, 我们求的是 $o_{n} $
+
+$$
+\begin{align*}
+o_n &= \sum_j s_{j}v_{j} \newline
+s_j &= \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}} \right)
+\end{align*}
+$$
+
+如果我们把式子展开
+
+$$
+o_n = \sum_j \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}}\right)v_{j}
+$$
+
+以上表达式只有k和v存在长度下标, 而 $q$ 没有, 所以计算过程中 $q$ 的输入是确定的上次输入的最后一个token, 而 $k, v$ 是需要对不同长度的部分进行缓存的,同时缓存的时候应该注意位置编码的计算应该在kvcache的计算之前进行,否则会存在位置编码的计算错误
\ No newline at end of file
diff --git a/assets/docs/kvcache.md b/assets/docs/kvcache.md
deleted file mode 100644
index 63a5511..0000000
--- a/assets/docs/kvcache.md
+++ /dev/null
@@ -1,27 +0,0 @@
-## kv_cache 实现
-
-根据注意力的计算公式
-
-$$
-\begin{align*}
-o_i &= \sum_j s_{ij} v_{j} \newline
-s_{ij} &= \text{softmax}\left( \frac{q_{i} k_{j}}{\sqrt{d_k}} \right)
-\end{align*}
-$$
-
-由于模型是自回归模型, 我们只用求序列最后一个部分,也就是说 $ i $ 的下标是确定的, 是序列最后一个元素, 我们求的是 $o_{n} $
-
-$$
-\begin{align*}
-o_n &= \sum_j s_{j}v_{j} \newline
-s_j &= \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}} \right)
-\end{align*}
-$$
-
-如果我们把式子展开
-
-$$
-o_n = \sum_j \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}}\right)v_{j}
-$$
-
-以上表达式只有k和v存在长度下标, 而 $q$ 没有, 所以计算过程中 $q$ 的输入是确定的上次输入的最后一个token, 而 $k, v$ 是需要对不同长度的部分进行缓存的,同时缓存的时候应该注意位置编码的计算应该在kvcache的计算之前进行,否则会存在位置编码的计算错误
\ No newline at end of file
diff --git a/khaosz/inference/core.py b/khaosz/inference/core.py
index d2fbcc7..dc37229 100644
--- a/khaosz/inference/core.py
+++ b/khaosz/inference/core.py
@@ -191,7 +191,7 @@ class EmbeddingEncoderCore:
sentence_embs: List[Tensor] = []
for i in range(len(batch_ids)):
indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
- if indices is not None:
+ if indices:
sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size]
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1]
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
@@ -228,11 +228,11 @@ class KVCacheManager:
self._initialize()
def _initialize(self):
- k_cache = torch.zeros(
+ k_cache = torch.empty(
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
device=self.device, dtype=self.dtype
)
- v_cache = torch.zeros(
+ v_cache = torch.empty(
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
device=self.device, dtype=self.dtype
)
diff --git a/khaosz/model/module.py b/khaosz/model/module.py
index 14d5dc6..ca14b40 100644
--- a/khaosz/model/module.py
+++ b/khaosz/model/module.py
@@ -93,7 +93,7 @@ class RotaryEmbedding(nn.Module):
seq_len = x.size(1)
if self.max_len_cached < seq_len + start_pos:
- self._set_rotary_buffer(seq_len)
+ self._set_rotary_buffer(seq_len + start_pos)
cos = self.cos_cached[start_pos : start_pos + seq_len]
sin = self.sin_cached[start_pos : start_pos + seq_len]
@@ -237,6 +237,7 @@ class MLA(nn.Module):
use_gated_attention: bool,
layer_id: int
):
+ super().__init__()
self.dim = dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
diff --git a/khaosz/parallel/setup.py b/khaosz/parallel/setup.py
index c1c6686..452e30f 100644
--- a/khaosz/parallel/setup.py
+++ b/khaosz/parallel/setup.py
@@ -82,9 +82,12 @@ def only_on_rank(rank, sync=False):
@wraps(func)
def wrapper(*args, **kwargs):
if get_rank() == rank:
- return func(*args, **kwargs)
- if sync:
+ ret_args = func(*args, **kwargs)
+
+ if sync and dist.is_available() and dist.is_initialized():
dist.barrier()
+
+ return ret_args
return wrapper
diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py
index 6b12b0d..986b72a 100644
--- a/khaosz/trainer/train_callback.py
+++ b/khaosz/trainer/train_callback.py
@@ -74,19 +74,16 @@ class SchedulerCallback(TrainCallback):
Scheduler callback for trainer.
"""
def __init__(self):
- self.scheduler: LRScheduler = None
+ pass
def on_train_begin(self, context: TrainContext):
for group in context.optimizer.param_groups:
if "initial_lr" not in group:
group["initial_lr"] = group["lr"]
-
- self.scheduler = context.scheduler
def on_batch_end(self, context: TrainContext):
- _ = context
- if self.scheduler:
- self.scheduler.step()
+ if context.scheduler:
+ context.scheduler.step()
class CheckpointCallback(TrainCallback):
diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py
index f2172ad..0f7e08c 100644
--- a/khaosz/trainer/train_context.py
+++ b/khaosz/trainer/train_context.py
@@ -87,7 +87,7 @@ class TrainContextBuilder:
return self
def with_strategy(self) -> Self:
- self._context.strategy = StrategyFactory.load(
+ self._context.strategy = StrategyFactory.create(
model=self._context.model,
train_type=self.config.strategy,
device=get_current_device(),
diff --git a/pyproject.toml b/pyproject.toml
index e64e83d..f2b5e2c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,10 +15,13 @@ dependencies = [
"tqdm==4.67.1",
"safetensors==0.5.3",
"huggingface-hub==0.34.3",
- "pytest==9.0.2"
]
+
+[project.optional-dependencies]
+dev = ["pytest==9.0.2"]
+
keywords = ["nlp", "datasets", "language-models", "machine-learning"]
-license = { text = "GPL-3.0" }
+license = ["GPL-3.0"]
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: GPL-3.0",