diff --git a/README.md b/README.md
index e519abe..23d3bbb 100644
--- a/README.md
+++ b/README.md
@@ -1,286 +1,147 @@
-
-
-
+
+

+
+
KHAOSZ
-
-
KHAOSZ
+
+
+ A lightweight Transformer training & inference framework
+
-
English Version
+## 📖 Table of Contents | 目录
-A training and inference framework for autoregressive Transformer language models.
+
-**Model Download Options (choose one):**
+| English | 中文 |
+|---------|------|
+| [Installation](#installation) | [安装](#安装) |
+| [Quick Start](#quick-start) | [快速开始](#快速开始) |
+| [Documentation](#documentation) | [文档](#文档) |
+| [License](#license) | [许可证](#许可证) |
-1. Visit [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) and check **Files and versions**
-2. Run `scripts/download.py` to download model parameters
+
-**Demo Video:** [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd)
+---
-For training data sources, please refer to the **Model Card** section on the HuggingFace download page.
+
+## English
-**License:** The code follows the GPL-3.0 license. Please provide attribution when using it.
+### Features
-- **📊 Device Selection:** Uses CUDA for training by default
-- **🌐 Performance Optimization:** Enable `dtype=torch.bfloat16` to accelerate training and reduce memory usage. Ensure your hardware supports this feature
-- **🤖 Language Support:** The model supports training in Chinese and English. Since the BBPE tokenizer hasn't been trained on multilingual text, OOV (Out-of-Vocabulary) issues are minimal for Chinese and English, but may exist for other languages
+- 🚀 **High Performance**: Optimized for both training and inference
+- 🔧 **Flexible**: Support for seq/sft/dpo training
+- 💡 **Easy to Use**: Simple API with comprehensive examples
+- 📦 **Lightweight**: Minimal dependencies
-
-### 📌 Training Guide
-
-To train this Transformer model, follow these steps:
-
-**(1). Prepare the Dataset:**
-
-Place the dataset in the specified root directory. This system uses the BBPE tokenizer for tokenization and requires training with pre-tokenized segments (stored as *.h5 format files).
-
-**(2). Install Dependencies:**
+### Installation
```bash
+git clone https://github.com/username/khaosz.git
+cd khaosz
pip install -e .
```
-**(3). Run the Training Script:**
+### Quick Start
```bash
-python train.py \
---train_type=train_type[seq, sft, dpo] \
---data_root_path=/path/to/dataset \
---param_path=/path/to/param_path \
---n_epoch=5 \
---batch_size=8 \
---max_lr=2e-4 \
---ckpt_interval=10000 \
---ckpt_dir=checkpoints
+# Train
+python tools/train.py \
+ --train_type=seq \
+ --data_root_path=/path/to/dataset \
+ --param_path=/path/to/param_path
+
+# Generate
+python tools/generate.py --param_path=/path/to/param_path
```
-**Parameter Explanation:**
-- `--train_type`: Training type (seq, sft, dpo)
-- `--data_root_path`: Dataset root directory
-- `--param_path`: Path to model training parameters
-- `--n_epoch`: Total number of training epochs
-- `--batch_size`: Batch size
-- `--accumulation_steps`: Number of batches per training step
-- `--warmup_steps`: Warmup steps
-- `--max_lr`: Maximum learning rate (using warmup + cosine decay)
-- `--ckpt_interval`: Checkpoint saving interval
-- `--ckpt_dir`: Checkpoint saving directory
-- `--resume_dir`: Resume training from specified path
-
-
-
-### 👉 Usage Guide
-
-**(1). Chat with the Model:**
-
-Open `chat.py` or use the streaming/non-streaming interfaces:
-
-**Streaming Output:**
-```python
-import torch
-from khaosz import Khaosz
-
-model_dir = "your_model_parameter_dir"
-model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
-history = []
-
-while True:
- query = input(">> ")
- if query == "!exit":
- break
-
- response_size = 0
- for response, history in model.stream_generate(
- query=query,
- history=history,
- temperature=0.85,
- top_p=0.95,
- top_k=50
- ):
- print(response[response_size:], end="")
- response_size = len(response)
-```
-
-**Non-streaming Output:**
-```python
-import torch
-from khaosz import Khaosz
-
-model_dir = "your_model_parameter_dir"
-model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
-history = []
-
-while True:
- query = input(">> ")
- if query == "!exit":
- break
-
- response = model.generate(
- query=query,
- history=history,
- temperature=0.85,
- top_p=0.95,
- top_k=50
- )
- print(response)
-```
-
-**(2). Retrieval-Augmented Generation (RAG):**
-
-```python
-import torch
-from khaosz import Khaosz
-
-model_dir = "your_model_parameter_dir"
-model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
-
-retrieved_content = model.retrieve_generate(
- query=query,
- retrieve_top_k=5,
- temperature=0.6,
- top_k=30,
- top_p=0.95
-)
-print(retrieved_content)
-```
-
-
中文版本
-这是一个支持基于自回归模式的 Transfomer 语言模型训练以及推理框架
-
-**模型下载选项(任选其一):**
-
-1. 访问 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 查看 **Files and versions**
-2. 运行 `scripts/download.py` 下载模型参数
-
-**演示视频:** [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd)
-
-训练数据来源请参见 HuggingFace 下载页面中的 **Model Card** 部分。
-
-**许可证:** 代码遵循 GPL-3.0 协议,使用时请注明出处。
-
-- **📊 设备选择:** 默认使用 CUDA 进行训练
-- **🌐 性能优化:** 启用 `dtype=torch.bfloat16` 以加速训练并减少内存占用,请确保硬件支持该特性
-- **🤖 语言支持:** 模型支持中文和英文训练。由于 BBPE 分词器未使用多语言文本训练,因此中英文的 OOV(未登录词)问题较少,其他语言可能存在 OOV 问题
-
-
-### 📌 训练指南
-
-要训练该 Transformer 模型,请按照以下步骤操作:
-
-**(1). 准备数据集:**
-
-将数据集放置在指定的根目录下, 本系统采用 BBPE 分词器进行分词,并且要求使用已经经过分词的 token 分段训练(分段存储为 *.h5 格式)
-
-**(2). 安装依赖:**
+### Demo
```bash
+# download model parameters before running the demos
+python demo/download.py
+
+# run demo
+python demo/stream_chat.py
+python demo/generate_batch.py
+python demo/generate_ar.py
+```
+
+- [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd)
+
+
+### License
+
+GPL-3.0
+
+---
+
+
+## 中文
+
+### 特性
+
+- 🚀 **高性能**: 训练与推理双向优化
+- 🔧 **灵活**: 支持 seq/sft/dpo 多种训练方式
+- 💡 **易用**: 简洁的 API 与丰富的示例
+- 📦 **轻量**: 依赖少,部署简单
+
+### 安装
+
+```bash
+git clone https://github.com/username/khaosz.git
+cd khaosz
pip install -e .
```
-**(3). 运行训练脚本:**
+### 快速开始
```bash
-python train.py \
---train_type=train_type[seq, sft, dpo] \
---data_root_path=/path/to/dataset \
---param_path=/path/to/param_path \
---n_epoch=5 \
---batch_size=8 \
---max_lr=2e-4 \
---ckpt_interval=10000 \
---ckpt_dir=checkpoints
+# 训练
+python tools/train.py \
+ --train_type=seq \
+ --data_root_path=/path/to/dataset \
+ --param_path=/path/to/param_path
+
+# 生成
+python tools/generate.py --param_path=/path/to/param_path
```
-**参数说明:**
-- `--train_type`: 训练类型(seq, sft, dpo)
-- `--data_root_path`: 数据集根目录
-- `--param_path`: 模型训练参数路径
-- `--n_epoch`: 总训练轮数
-- `--batch_size`: 批量大小
-- `--accumulation_steps`: 每个训练步骤的 batch 数量
-- `--warmup_steps`: 预热步数(warmup steps)
-- `--max_lr`: 最大学习率(使用预热 + 余弦衰减)
-- `--ckpt_interval`: 检查点保存间隔
-- `--ckpt_dir`: 检查点保存目录
-- `--resume_dir`: 从指定路径恢复训练
+### 演示
+```bash
+# 使用前先下载模型
+python demo/download.py
-
-### 👉 使用指南
-
-**(1). 与模型对话:**
-
-打开 `chat.py` 或使用流式/非流式接口:
-
-**流式输出:**
-```python
-import torch
-from khaosz import Khaosz
-
-model_dir = "your_model_parameter_dir"
-model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
-history = []
-
-while True:
- query = input(">> ")
- if query == "!exit":
- break
-
- response_size = 0
- for response, history in model.stream_generate(
- query=query,
- history=history,
- temperature=0.85,
- top_p=0.95,
- top_k=50
- ):
- print(response[response_size:], end="")
- response_size = len(response)
+# 运行示例
+python demo/stream_chat.py
+python demo/generate_batch.py
+python demo/generate_ar.py
```
-**非流式输出:**
-```python
-import torch
-from khaosz import Khaosz
+- [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd)
-model_dir = "your_model_parameter_dir"
-model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
-history = []
+### 许可证
-while True:
- query = input(">> ")
- if query == "!exit":
- break
-
- response = model.generate(
- query=query,
- history=history,
- temperature=0.85,
- top_p=0.95,
- top_k=50
- )
- print(response)
-```
+GPL-3.0
-**(2). 基于检索的生成(RAG):**
+---
-```python
-import torch
-from khaosz import Khaosz
+
+## 📚 Documentation | 文档
-model_dir = "your_model_parameter_dir"
-model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+| Document | 说明 |
+|----------|------|
+| [参数说明](assets/docs/params.md) | Training & inference parameters |
+| [设计文档](assets/docs/design.md) | Framework design |
+| [数据流程](assets/docs/dataflow.md) | Data processing pipeline |
+| [模型介绍](assets/docs/introduction.md) | Model architecture |
-retrieved_content = model.retrieve_generate(
- query=query,
- retrieve_top_k=5,
- temperature=0.6,
- top_k=30,
- top_p=0.95
-)
-print(retrieved_content)
-```
\ No newline at end of file
+### Download | 下载
+
+- [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ)
+- `python demo/download.py`
\ No newline at end of file
diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md
index 6ba9c74..639b96b 100644
--- a/assets/docs/dataflow.md
+++ b/assets/docs/dataflow.md
@@ -1,205 +1,205 @@
-# KHAOSZ 数据流文档
+# KHAOSZ Data Flow Documentation
-本文档描述 KHAOSZ 项目(一个自回归 Transformer 语言模型的训练与推理框架)的数据流。涵盖从原始数据到模型训练、推理的完整流程。
+This document describes the data flow of the KHAOSZ project (a training and inference framework for autoregressive Transformer language models). It covers the complete flow from raw data to model training and inference.
-## 概述
+## Overview
-KHAOSZ 采用模块化设计,主要组件包括:
-- **数据模块** (`khaosz/data/`): 数据集、采样器、分词器、序列化工具
-- **模型模块** (`khaosz/model/`): Transformer 模型及其子模块
-- **训练模块** (`khaosz/trainer/`): 训练器、训练上下文、策略、调度器
-- **推理模块** (`khaosz/inference/`): 生成核心、KV 缓存管理、流式生成
-- **配置模块** (`khaosz/config/`): 模型、训练、调度等配置
-- **并行模块** (`khaosz/parallel/`): 分布式训练支持
+KHAOSZ adopts a modular design with the following main components:
+- **Data Module** (`khaosz/data/`): Dataset, sampler, tokenizer, serialization tools
+- **Model Module** (`khaosz/model/`): Transformer model and its submodules
+- **Training Module** (`khaosz/trainer/`): Trainer, training context, strategies, schedulers
+- **Inference Module** (`khaosz/inference/`): Generation core, KV cache management, streaming generation
+- **Config Module** (`khaosz/config/`): Model, training, scheduler, and other configurations
+- **Parallel Module** (`khaosz/parallel/`): Distributed training support
-数据流总体可分为 **训练数据流** 与 **推理数据流** 两条主线。
+The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
-## 数据流图
+## Data Flow Diagram
```mermaid
flowchart LR
- subgraph A[数据准备]
+ subgraph A[Data Preparation]
direction TB
- A1[原始文本] --> A2[BBPE 分词器]
- A2 --> A3[序列化为 .h5 文件]
-        A3 --> A4[数据集加载<br/>BaseDataset]
-        A4 --> A5[可恢复分布式采样器<br/>ResumableDistributedSampler]
- A5 --> A6[DataLoader 批量加载]
+ A1[Raw Text] --> A2[BBPE Tokenizer]
+ A2 --> A3[Serialize to .h5 files]
+        A3 --> A4[Dataset Loading<br/>BaseDataset]
+        A4 --> A5[Resumable Distributed Sampler<br/>ResumableDistributedSampler]
+ A5 --> A6[DataLoader Batch Loading]
end
- subgraph B[训练循环]
+ subgraph B[Training Loop]
direction TB
-        B1[批次数据] --> B2[训练策略<br/>BaseStrategy]
- B2 --> B3[Transformer 模型]
- B3 --> B4[输出 logits]
- B4 --> B5[损失计算]
- B5 --> B6[反向传播]
- B6 --> B7[优化器更新]
- B7 --> B8[学习率调度器]
- B8 --> B9[检查点保存]
+        B1[Batch Data] --> B2[Training Strategy<br/>BaseStrategy]
+ B2 --> B3[Transformer Model]
+ B3 --> B4[Output logits]
+ B4 --> B5[Loss Calculation]
+ B5 --> B6[Backpropagation]
+ B6 --> B7[Optimizer Update]
+ B7 --> B8[Learning Rate Scheduler]
+ B8 --> B9[Checkpoint Save]
end
- subgraph C[推理生成]
+ subgraph C[Inference Generation]
direction TB
- C1[检查点加载] --> C2[推理模型加载]
-        C2 --> C3[生成核心<br/>GeneratorCore]
-        C3 --> C4[采样策略<br/>温度/top-k/top-p]
- C4 --> C5[生成下一个 token]
- C5 --> C6[KV 缓存更新]
- C6 --> C7{是否达到最大长度?}
- C7 -->|否| C5
- C7 -->|是| C8[输出生成文本]
+ C1[Checkpoint Loading] --> C2[Inference Model Loading]
+        C2 --> C3[Generation Core<br/>GeneratorCore]
+        C3 --> C4[Sampling Strategy<br/>Temperature/top-k/top-p]
+ C4 --> C5[Generate Next Token]
+ C5 --> C6[KV Cache Update]
+ C6 --> C7{Max Length Reached?}
+ C7 -->|No| C5
+ C7 -->|Yes| C8[Output Generated Text]
end
A --> B
B --> C
```
-## 各模块详细说明
+## Detailed Module Descriptions
-### 1. 数据模块
+### 1. Data Module
-#### 1.1 分词器 (`tokenizer.py`)
-- 基于 Byte‑Level BPE (BBPE) 实现
-- 支持特殊 token:``, ``, ``, `<|im_start|>`, `<|im_end|>`
-- 提供 `encode`/`decode` 方法,将文本与 token ID 相互转换
-- 训练时从语料库学习词汇表,保存为 `.json` 文件
+#### 1.1 Tokenizer (`tokenizer.py`)
+- Implemented based on Byte-Level BPE (BBPE)
+- Supports special tokens: ``, ``, ``, `<|im_start|>`, `<|im_end|>`
+- Provides `encode`/`decode` methods for mutual conversion between text and token IDs
+- Learns vocabulary from corpus during training, saved as `.json` files
-#### 1.2 序列化 (`serialization.py`)
-- **`save_h5`**: 将多个张量按组保存为 HDF5 文件(`.h5`),每个键对应一个张量列表
-- **`load_h5`**: 加载 `.h5` 文件,返回 `Dict[str, List[Tensor]]`,支持共享内存 (`share_memory=True`)
-- **`Checkpoint` 类**: 封装模型状态字典、训练轮次、迭代次数,支持 safetensors 格式保存与加载
+#### 1.2 Serialization (`serialization.py`)
+- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
+- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
+- **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
-#### 1.3 数据集 (`dataset.py`)
-- **`BaseDataset`**: 抽象基类,定义窗口采样、步长等通用逻辑
-- **`BaseSegmentFetcher`** 与 **`MultiSegmentFetcher`**: 高效地从多个分段中获取指定索引范围的数据
-- **`DatasetFactory`**: 工厂模式,支持动态注册数据集类型(`seq`, `sft`, `dpo`, `grpo`)
-- 数据集加载后通过 `MultiSegmentFetcher` 管理多个数据键(如 `"sequence"`, `"mask"`)
+#### 1.3 Dataset (`dataset.py`)
+- **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
+- **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
+- **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
+- After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
-#### 1.4 采样器 (`sampler.py`)
-- **`ResumableDistributedSampler`**: 支持分布式训练的可恢复采样器
-- 记录当前 epoch 和迭代位置,便于从断点继续训练
-- 支持 shuffle 与 drop_last 选项
+#### 1.4 Sampler (`sampler.py`)
+- **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
+- Records the current epoch and iteration position so training can resume from the point of interruption
+- Supports shuffle and drop_last options
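
The resume mechanism can be sketched torch-free (a hypothetical minimal sampler; the real `ResumableDistributedSampler` additionally handles ranks, padding, and `drop_last`):

```python
import random

class ResumableSampler:
    """Yields a deterministic shuffled index order per epoch and can
    resume mid-epoch by skipping already-consumed indices."""

    def __init__(self, num_samples: int, seed: int = 0, shuffle: bool = True):
        self.num_samples = num_samples
        self.seed = seed
        self.shuffle = shuffle
        self.epoch = 0
        self.start_index = 0  # position within the epoch to resume from

    def set_state(self, epoch: int, start_index: int) -> None:
        self.epoch = epoch
        self.start_index = start_index

    def __iter__(self):
        indices = list(range(self.num_samples))
        if self.shuffle:
            # seed with (seed, epoch) so a restart replays the same order
            random.Random(self.seed + self.epoch).shuffle(indices)
        yield from indices[self.start_index:]

sampler = ResumableSampler(8, seed=42)
full_epoch = list(sampler)                    # a fresh, uninterrupted epoch

resumed = ResumableSampler(8, seed=42)
resumed.set_state(epoch=0, start_index=3)     # pretend 3 indices were consumed
assert list(resumed) == full_epoch[3:]        # the resumed run continues the same order
```

Because the shuffle is seeded from `(seed, epoch)`, every process and every restart reconstructs the identical permutation, so only the scalar `start_index` needs to be checkpointed.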
-### 2. 模型模块
+### 2. Model Module
#### 2.1 Transformer (`transformer.py`)
-- 核心自回归解码器架构
-- 包含嵌入层、多层 `DecoderBlock`、RMSNorm 和线性输出头
-- 支持权重绑定 (`tie_weight=True`) 以减小参数量
-- 使用 Rotary Position Embedding (RoPE) 注入位置信息
+- Core autoregressive decoder architecture
+- Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
+- Supports weight tying (`tie_weight=True`) to reduce parameter count
+- Uses Rotary Position Embedding (RoPE) to inject position information
-#### 2.2 子模块 (`module.py`)
-- **`RotaryEmbedding`**: 生成 RoPE 的 cos/sin 缓存
-- **`DecoderBlock`**: 包含多头注意力(支持 GQA)、前馈网络(FFN)、残差连接
-- **`RMSNorm`**: 层归一化变体
-- **`Linear`**, **`Embedding`**: 自定义线性层与嵌入层,支持并行化包装
+#### 2.2 Submodules (`module.py`)
+- **`RotaryEmbedding`**: Generates RoPE cos/sin cache
+- **`DecoderBlock`**: Contains multi-head attention (supports GQA), feedforward network (FFN), residual connections
+- **`RMSNorm`**: Layer normalization variant
+- **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
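
For illustration, the RMSNorm computation can be written out in plain Python (a sketch of the math only; the actual `RMSNorm` module operates on tensors):

```python
import math

def rms_norm(x, gain=None, eps=1e-6):
    """RMSNorm: scale x by the reciprocal of its root-mean-square.
    Unlike LayerNorm there is no mean subtraction and no bias term."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    g = gain if gain is not None else [1.0] * len(x)
    return [v / rms * gi for v, gi in zip(x, g)]

y = rms_norm([3.0, 4.0])
# the normalized vector has (approximately) unit root-mean-square
assert abs(math.sqrt(sum(v * v for v in y) / len(y)) - 1.0) < 1e-3
```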
-### 3. 训练模块
+### 3. Training Module
-#### 3.1 训练上下文 (`train_context.py`)
-- **`TrainContext`**: 数据类,封装训练所需的所有组件(模型、优化器、数据加载器、策略等)
-- **`TrainContextBuilder`**: 构建器模式,逐步组装训练上下文,支持从检查点恢复
+#### 3.1 Training Context (`train_context.py`)
+- **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.)
+- **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint
-#### 3.2 训练器 (`trainer.py`)
-- **`Trainer`**: 主训练循环,管理回调函数(进度条、检查点、指标记录、梯度裁剪、调度器)
-- 支持分布式训练(通过 `spawn_parallel_fn` 启动多进程)
-- 训练步骤包括:
- 1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. 前向/损失计算 → 5. `on_batch_end` → 6. 梯度累积 → 7. `on_step_begin` → 8. 优化器更新 → 9. `on_step_end` → 10. `on_epoch_end`
+#### 3.2 Trainer (`trainer.py`)
+- **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler)
+- Supports distributed training (multiple processes are launched via `spawn_parallel_fn`)
+- Training steps include:
+ 1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
-#### 3.3 策略 (`strategy.py`)
-- **`BaseStrategy`**: 定义训练策略接口(如 `SeqStrategy`, `SFTStrategy`, `DPOStrategy`)
-- 策略接收批次数据,执行模型前向传播、损失计算,返回 loss 张量
-- 由 `StrategyFactory` 根据配置动态创建
+#### 3.3 Strategy (`strategy.py`)
+- **`BaseStrategy`**: Defines training strategy interface (such as `SeqStrategy`, `SFTStrategy`, `DPOStrategy`)
+- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
+- Created dynamically by `StrategyFactory` according to configuration
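
As an illustration of what a DPO-style strategy computes, here is the standard DPO objective over scalar log-probabilities (a sketch; the actual `DPOStrategy` works on batched tensors and may differ in details):

```python
import math

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """DPO: -log sigmoid(beta * (chosen log-ratio - rejected log-ratio))."""
    chosen_ratio = policy_chosen_logp - ref_chosen_logp
    rejected_ratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_ratio - rejected_ratio)
    return -math.log(1.0 / (1.0 + math.exp(-margin)))  # -log sigmoid(margin)

base = dpo_loss(-10.0, -12.0, -11.0, -11.0)   # policy prefers the chosen answer
worse = dpo_loss(-12.0, -10.0, -11.0, -11.0)  # policy prefers the rejected answer
assert base < worse                           # preferring chosen lowers the loss
assert base < math.log(2.0) < worse           # log 2 is the zero-margin loss
```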
-#### 3.4 调度器 (`schedule.py`)
-- **`BaseScheduler`**: 抽象基类,定义学习率调度接口
-- **`SchedulerFactory`**: 工厂模式,支持注册多种调度器(如 `cosine`, `sgdr`)
-- 调度器根据配置自动创建,并与优化器绑定
+#### 3.4 Scheduler (`schedule.py`)
+- **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
+- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers (such as `cosine`, `sgdr`)
+- Scheduler is automatically created according to configuration and bound to optimizer
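
A common shape for a warmup + cosine schedule, sketched in plain Python (assumed formula; the registered `cosine` scheduler's exact definition may differ):

```python
import math

def warmup_cosine_lr(step, max_lr, warmup_steps, total_steps, min_lr=0.0):
    """Linear warmup to max_lr, then cosine decay toward min_lr."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

lrs = [warmup_cosine_lr(s, max_lr=2e-4, warmup_steps=10, total_steps=100)
       for s in range(100)]
assert abs(lrs[9] - 2e-4) < 1e-12   # warmup peaks at max_lr
assert lrs[0] < lrs[5] < lrs[9]     # monotonically increasing warmup
assert lrs[99] < lrs[50] < lrs[10]  # cosine decay afterwards
```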
-### 4. 推理模块
+### 4. Inference Module
-#### 4.1 生成核心 (`core.py`)
-- **`GeneratorCore`**: 提供 `generate_iterator` 方法,执行单步生成
-- 应用采样策略(温度、top‑k、top‑p)对 logits 进行筛选
-- 支持 KV 缓存以加速自回归生成
+#### 4.1 Generation Core (`core.py`)
+- **`GeneratorCore`**: Provides the `generate_iterator` method, which performs single-step generation
+- Applies sampling strategies (temperature, top-k, top-p) to filter the logits
+- Supports a KV cache to accelerate autoregressive generation
-#### 4.2 KV 缓存管理 (`core.py`)
-- **`KVCacheManager`**: 管理每层的 K 和 V 缓存,支持批量生成与长度扩展
-- 缓存形状为 `[batch_size, n_kv_heads, seq_len, head_dim]`
+#### 4.2 KV Cache Management (`core.py`)
+- **`KVCacheManager`**: Manages K and V cache for each layer, supports batch generation and length extension
+- Cache shape is `[batch_size, n_kv_heads, seq_len, head_dim]`
-#### 4.3 生成器 (`generator.py`)
-- **`GenerationRequest`**: 封装生成请求参数(top_k, top_p, temperature, max_len, query, history 等)
-- **`build_prompt`**: 将查询与历史记录转换为 ChatML 格式的提示字符串
-- **`pad_sequence`**: 对输入 ID 进行填充,使其长度一致
-- 提供流式与非流式生成接口
+#### 4.3 Generator (`generator.py`)
+- **`GenerationRequest`**: Encapsulates generation request parameters (top_k, top_p, temperature, max_len, query, history, etc.)
+- **`build_prompt`**: Converts the query and history into a ChatML-format prompt string
+- **`pad_sequence`**: Pads input IDs to consistent length
+- Provides streaming and non-streaming generation interfaces
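
ChatML-style prompt assembly can be sketched as follows (illustrative only; `build_chatml_prompt` is a hypothetical helper, and the exact template emitted by `build_prompt` may differ):

```python
def build_chatml_prompt(query, history):
    """Render (user, assistant) turns plus the new query as a ChatML string,
    ending with an open assistant turn for the model to complete."""
    parts = []
    for user_msg, assistant_msg in history:
        parts.append(f"<|im_start|>user\n{user_msg}<|im_end|>\n")
        parts.append(f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n")
    parts.append(f"<|im_start|>user\n{query}<|im_end|>\n")
    parts.append("<|im_start|>assistant\n")  # model continues from here
    return "".join(parts)

prompt = build_chatml_prompt("How are you?", [("Hi", "Hello!")])
assert prompt.endswith("<|im_start|>assistant\n")
assert prompt.count("<|im_start|>") == 4  # two history turns + query + open turn
```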
-## 训练数据流详细步骤
+## Training Data Flow - Detailed Steps
-1. **数据准备**
- - 原始文本经过 BBPE 分词器转换为 token ID 序列
- - 将 token ID 序列(可能带有掩码、标签等)按组保存为 `.h5` 文件
- - 文件可包含多个分段,每个分段对应一个张量
+1. **Data Preparation**
+   - Raw text is converted to token ID sequences by the BBPE tokenizer
+   - Token ID sequences (possibly with masks, labels, etc.) are saved in groups as `.h5` files
+   - Files can contain multiple segments, each corresponding to a tensor
-2. **数据集加载**
- - `BaseDataset` 的 `load` 方法调用 `load_h5`,得到 `segments` 字典
- - 创建 `MultiSegmentFetcher` 管理多个键的数据
- - 计算总样本数,并根据窗口大小、步长确定每个样本的起始/结束索引
+2. **Dataset Loading**
+   - `BaseDataset`'s `load` method calls `load_h5` to obtain the `segments` dictionary
+   - A `MultiSegmentFetcher` is created to manage the data for each key
+   - The total sample count is computed, and each sample's start/end indices are determined from the window size and stride
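
The window/stride bookkeeping can be sketched like this (assumed semantics; `BaseDataset` may handle edge cases differently):

```python
def window_spans(total_len, window_size, stride):
    """Return (start, end) index pairs for every full window over a segment."""
    if total_len < window_size:
        return []
    starts = range(0, total_len - window_size + 1, stride)
    return [(s, s + window_size) for s in starts]

spans = window_spans(total_len=10, window_size=4, stride=3)
assert spans == [(0, 4), (3, 7), (6, 10)]
assert len(spans) == (10 - 4) // 3 + 1  # sample-count formula
```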
-3. **采样与批量加载**
- - `ResumableDistributedSampler` 根据当前 epoch 和迭代位置生成索引序列
- - `DataLoader` 使用采样器获取索引,调用数据集的 `__getitem__` 获取实际数据
- - 批量数据形状为 `[batch_size, window_size]`(或根据具体数据集类型变化)
+3. **Sampling and Batch Loading**
+   - `ResumableDistributedSampler` generates the index sequence from the current epoch and iteration position
+   - `DataLoader` takes indices from the sampler and calls the dataset's `__getitem__` to fetch the actual data
+   - Batch data has shape `[batch_size, window_size]` (or varies with the specific dataset type)
-4. **策略前向与损失计算**
- - 批次数据传入策略(如 `SeqStrategy`)
- - 策略内部调用 `Transformer` 模型,得到 logits
- - 根据任务类型计算交叉熵损失(或 DPO 损失等)
- - 返回 loss 张量
+   - Batch data is passed to the strategy (such as `SeqStrategy`)
+   - The strategy calls the `Transformer` model to obtain logits
+   - Cross-entropy loss (or DPO loss, etc.) is computed according to the task type
+   - The loss tensor is returned
-5. **反向传播与优化**
- - 损失除以累积步数进行归一化,然后执行 `loss.backward()`
- - 每累积 `accumulation_steps` 个批次后,执行优化器 `step()` 和 `zero_grad()`
- - 学习率调度器在每个 step 后更新学习率
+5. **Backpropagation and Optimization**
+   - The loss is normalized by dividing by the accumulation steps, then `loss.backward()` is executed
+   - After `accumulation_steps` batches have been accumulated, the optimizer's `step()` and `zero_grad()` are executed
+   - The learning rate scheduler updates the learning rate after each optimizer step
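
Numerically, the normalization in step 5 makes the accumulated gradient equal the mean over the micro-batches, which a scalar sketch shows:

```python
# per-micro-batch gradients of the *unnormalized* loss (toy numbers)
micro_grads = [0.4, 0.1, 0.7, 0.2]
accumulation_steps = len(micro_grads)

# dividing each loss by accumulation_steps scales its gradient the same way,
# so summing over the accumulation window reproduces the mean gradient
accumulated = sum(g / accumulation_steps for g in micro_grads)
mean_grad = sum(micro_grads) / len(micro_grads)
assert abs(accumulated - mean_grad) < 1e-12
```

This is why gradient accumulation with `batch_size * accumulation_steps` samples behaves like one large batch of that size.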
-6. **检查点保存**
- - `CheckpointCallback` 按设定的间隔保存检查点
- - 检查点包含模型状态字典、当前 epoch、iteration 等元数据
- - 使用 safetensors 格式保存,确保安全与效率
+6. **Checkpoint Saving**
+ - `CheckpointCallback` saves checkpoints at set intervals
+ - Checkpoints contain model state dict, current epoch, iteration, and other metadata
+ - Saved in safetensors format, ensuring safety and efficiency
-## 推理数据流详细步骤
+## Inference Data Flow - Detailed Steps
-1. **模型加载**
- - 从检查点加载 `Transformer` 模型与分词器
- - 模型设置为评估模式 (`model.eval()`),启用推理模式 (`torch.inference_mode`)
+1. **Model Loading**
+ - Load `Transformer` model and tokenizer from checkpoint
+ - Set model to evaluation mode (`model.eval()`), enable inference mode (`torch.inference_mode`)
-2. **提示构建与编码**
- - 用户查询与历史记录通过 `build_prompt` 转换为 ChatML 格式字符串
- - 分词器将提示字符串编码为 token ID 序列 `input_ids`
- - 若为批量生成,使用 `pad_sequence` 进行填充
+   - The user query and history are converted to a ChatML-format string by `build_prompt`
+   - The tokenizer encodes the prompt string into the token ID sequence `input_ids`
+   - For batch generation, `pad_sequence` pads inputs to a consistent length
-3. **自回归生成循环**
- - 初始化 KV 缓存(可选)
- - 循环直到生成 `max_len` 个 token 或遇到停止 token:
- - 将当前 `input_ids`(或缓存后的新 token)输入模型,得到 `logits`
- - 对 `logits` 应用 `apply_sampling_strategies`(温度、top‑k、top‑p)
- - 从处理后的分布中采样得到下一个 token ID
- - 将新 token 追加到 `input_ids`,同时更新 KV 缓存
- - 若为流式生成,每生成一个 token 立即 yield 给调用方
+3. **Autoregressive Generation Loop**
+ - Initialize KV cache (optional)
+   - Loop until `max_len` tokens have been generated or a stop token is produced:
+     - Feed the current `input_ids` (or, with KV caching, only the newest token) to the model to obtain `logits`
+     - Apply `apply_sampling_strategies` (temperature, top-k, top-p) to the `logits`
+     - Sample the next token ID from the filtered distribution
+     - Append the new token to `input_ids` and update the KV cache
+   - For streaming generation, each token is yielded to the caller as soon as it is produced
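
The sampling chain in this loop can be sketched torch-free (assumed order of temperature, then top-k, then top-p; the real `apply_sampling_strategies` may differ):

```python
import math, random

def sample_next_token(logits, temperature=1.0, top_k=0, top_p=1.0, rng=None):
    """Filter a list of logits with temperature/top-k/top-p, then sample an index."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    probs = [math.exp(l - m) for l in scaled]          # stable softmax
    total = sum(probs)
    probs = [p / total for p in probs]

    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]                          # keep the k most likely tokens
    kept, cum = [], 0.0
    for i in order:                                    # smallest nucleus with mass >= top_p
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    weights = [probs[i] / mass for i in kept]          # renormalize the survivors
    return (rng or random).choices(kept, weights=weights)[0]

logits = [2.0, 1.0, 0.5, -1.0]
assert sample_next_token(logits, top_k=1) == 0     # greedy when one token survives
assert sample_next_token(logits, top_p=1e-9) == 0  # a tiny nucleus keeps only the argmax
```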
-4. **解码与输出**
- - 将生成的 token ID 序列通过分词器解码为文本
- - 去除特殊 token,返回纯文本响应
+4. **Decoding and Output**
+   - The generated token ID sequence is decoded back to text by the tokenizer
+   - Special tokens are removed and the plain-text response is returned
-## 检查点与序列化
+## Checkpoint and Serialization
-- **训练检查点**:保存模型参数、优化器状态、调度器状态、当前 epoch 与 iteration
-- **模型参数**:支持 safetensors 格式,加载时自动处理权重绑定等特殊逻辑
-- **数据集序列化**:HDF5 格式支持高效随机读取与共享内存,适合大规模预训练数据
+- **Training Checkpoint**: Saves model parameters, optimizer state, scheduler state, current epoch and iteration
+- **Model Parameters**: Supports safetensors format, automatically handles special logic like weight tying during loading
+- **Dataset Serialization**: HDF5 format supports efficient random access and shared memory, suitable for large-scale pre-training data
-## 总结
+## Summary
-KHAOSZ 的数据流设计体现了模块化、可扩展、可恢复的特点。训练数据流通过分块加载、可恢复采样、梯度累积等机制支持大规模分布式训练;推理数据流则利用 KV 缓存、采样策略实现高效的文本生成。各模块之间通过清晰的接口耦合,便于定制与扩展。
+KHAOSZ's data flow is designed to be modular, extensible, and resumable. The training data flow supports large-scale distributed training through chunked loading, resumable sampling, and gradient accumulation; the inference data flow achieves efficient text generation through KV caching and sampling strategies. Clear interfaces between modules make customization and extension straightforward.
-> 文档更新时间:2026‑03‑30
-> 对应代码版本:参考 `pyproject.toml` 中定义的版本号
\ No newline at end of file
+> Document Update Time: 2026-03-30
+> Corresponding Code Version: Refer to version number defined in `pyproject.toml`
\ No newline at end of file
diff --git a/assets/docs/design.md b/assets/docs/design.md
index 2f6c658..ba7bc61 100644
--- a/assets/docs/design.md
+++ b/assets/docs/design.md
@@ -1,16 +1,16 @@
-## 1. 为什么我要做这个项目?
+## 1. Why I Created This Project
-现在市面上有很多大模型,比如GPT、LLaMA这些,动不动就是几十亿甚至上千亿参数。但说实话,这些模型对硬件要求太高了,普通开发者根本玩不起。我就想:**能不能做一个既好用又能在普通电脑上跑起来的模型呢?** 这其实也是目前大部分人的期望, 能有一个可以本地部署的ai小型项目,实现完全私有化并且有一定的智能能力。
+There are many large language models on the market today, such as GPT and LLaMA, with tens of billions or even hundreds of billions of parameters. Honestly, their hardware requirements put them far out of reach for ordinary developers. So I wondered: **can we build a model that is both useful and runs on an ordinary computer?** This is also what most people currently hope for: a small, locally deployable AI project that keeps data fully private while retaining a reasonable level of intelligence.
-于是就有了这个KHAOSZ项目,1B参数,中英双语,支持对话、文本生成、RAG检索,而且训练代码都是开源的!
+Thus the KHAOSZ project was born: 1B parameters, Chinese-English bilingual, supporting dialogue, text generation, and RAG retrieval, with fully open-source training code!
-## 2. 系统架构
+## 2. System Architecture
-系统分为以下板块
+The system is divided into the following modules:
```mermaid
graph LR
- %% 样式定义
+ %% Style definitions
classDef config fill:#e1f5fe,stroke:#01579b;
classDef trainer fill:#f3e5f5,stroke:#4a148c;
classDef data fill:#e8f5e8,stroke:#1b5e20;
@@ -18,16 +18,16 @@ graph LR
classDef inference fill:#fce4ec,stroke:#880e4f;
classDef parallel fill:#e0f2f1,stroke:#004d40;
- %% 配置模块
- subgraph Config["Config(配置模块)"]
+ %% Config module
+ subgraph Config["Config"]
C1[model_config.py]
C2[train_config.py]
C3[scheduler_config.py]
end
class Config config;
- %% 训练器模块
- subgraph Trainer["Trainer(训练器模块)"]
+ %% Trainer module
+ subgraph Trainer["Trainer"]
T1[trainer.py]
T2[train_content.py]
T3[schedule.py]
@@ -36,8 +36,8 @@ graph LR
end
class Trainer trainer;
- %% 数据模块
- subgraph Data["Data(数据模块)"]
+ %% Data module
+ subgraph Data["Data"]
D1[dataset.py]
D2[sampler.py]
D3[mmap.py]
@@ -46,175 +46,159 @@ graph LR
end
class Data data;
- %% 模型模块
- subgraph Model["Model(模型模块)"]
+ %% Model module
+ subgraph Model["Model"]
M1[transformer.py]
M2[module.py]
end
class Model model;
- %% 推理模块
- subgraph Inference["Inference(推理模块)"]
+ %% Inference module
+ subgraph Inference["Inference"]
I1[generator.py]
I2[core.py]
end
class Inference inference;
- %% 并行模块
- subgraph Parallel["Parallel(并行模块)"]
+ %% Parallel module
+ subgraph Parallel["Parallel"]
P1[setup.py]
P2[module.py]
end
class Parallel parallel;
- %% 配置依赖
+ %% Config dependencies
C2 -.-> T1
C1 -.-> M1
C3 -.-> T3
- %% 训练器内部依赖
+ %% Trainer internal dependencies
T1 --> T5
T1 --> T2
T2 --> T3
T2 --> T4
- %% 数据流
+ %% Data flow
D1 --> D2
D1 --> D3
D1 --> D4
D1 --> D5
- %% 模型依赖
+ %% Model dependencies
M1 --> M2
- %% 推理依赖
+ %% Inference dependencies
I1 --> I2
- %% 跨模块依赖
+ %% Cross-module dependencies
T2 -.-> M1
I1 -.-> M1
T2 -.-> D1
T1 -.-> P1
```
+### 1. Configuration Management (/config/)
+- **Model Configuration**: Defines model structure parameters (such as layers, heads, dimensions, etc.), managed uniformly through `ModelConfig`.
+- **Training Configuration**: Sets training parameters (such as batch size, training stages PT/SFT/DPO, optimizers, etc.), loaded by `TrainConfig`.
+- **Scheduler Configuration**: Controls learning rate strategies (such as cosine annealing) and training progress.
-### 1. 配置管理(/config/)
-- **模型配置**:定义模型结构参数(如层数、头数、维度等),通过 `ModelConfig` 统一管理。
-- **训练配置**:设置训练参数(如批次大小、训练阶段 PT/SFT/DPO、优化器等),由 `TrainConfig` 加载。
-- **调度配置**:控制学习率策略(如余弦退火)和训练进度。
+### 2. Hardware and Parallelism (/parallel/)
+- **Distributed Initialization**: Initializes multi-GPU/multi-machine training environments through the `setup_parallel` function according to configuration.
-### 2. 硬件与并行(/parallel/)
-- **分布式初始化**:通过 `setup_parallel` 函数,根据配置初始化多卡/多机训练环境。
+### 3. Data Processing (/data/)
+- **Efficient Loading**: Uses memory mapping (mmap) technology to load massive corpora, avoiding memory overflow and achieving zero-copy reading.
-### 3. 数据处理(/data/)
-- **高效加载**:使用内存映射(mmap)技术加载超大语料,避免内存溢出,实现零拷贝读取。
+### 4. Model and Training (/model/, /trainer/)
+- **Unified Model Architecture**: Based on Transformer, supporting flexible configuration of different scales (such as 7B, 13B).
+- **Strategy-based Trainer**: `Trainer` automatically switches training strategies according to training stages (PT/SFT/DPO), reusing the same training loop.
+- **Training Context Management**: Unifies management of model, optimizer, scheduler, and metrics, supporting seamless multi-stage transitions.
-### 4. 模型与训练(/model/, /trainer/)
-- **统一模型架构**:基于 Transformer,支持灵活配置不同规模(如7B、13B)。
-- **策略化训练器**:`Trainer` 根据训练阶段(PT/SFT/DPO)自动切换训练策略,复用同一训练循环。
-- **训练上下文管理**:统一管理模型、优化器、调度器和指标,支持多阶段无缝衔接。
+### 5. Inference Service (/inference/, /utils/)
+- **Unified Generation Interface**: Provides synchronous, batch, and streaming generation methods, adapting to all training stages.
+- **KV Cache Optimization**: Caches Key/Value pairs during autoregressive generation, exploiting high-speed on-chip memory on NVIDIA GPUs.
+- **RAG Support**: Combines retriever and embedding models to inject relevant information from external knowledge bases, improving answer quality.
+- **Intelligent Text Segmentation**:
+ - **Structure-first Segmentation**: Splits by titles, paragraphs, etc.;
+ - **Semantic Segmentation**: Based on sentence embedding similarity, ensuring fragment semantic completeness and improving fine-tuning effects.
-### 5. 推理服务(/inference/, /utils/)
-- **统一生成接口**:提供同步、批量、流式生成方法,适配所有训练阶段。
-- **KV缓存优化**:在自回归生成中缓存 Key/Value,昇腾XPU下利用高速片上内存加速。
-- **RAG支持**:结合检索器和嵌入模型,从外部知识库注入相关信息,提升回答质量。
-- **智能文本分割**:
- - **结构优先分割**:按标题、段落等切分;
- - **语义分割**:基于句子嵌入相似度,确保片段语义完整,提升微调效果。
+## 3. Training Process
+Training a large language model (LLM) typically includes three stages: **Pre-training (PT)**, **Supervised Fine-Tuning (SFT)**, and **Reinforcement Learning from Human Feedback (RLHF)**. This system is designed to support the full pipeline seamlessly: modular strategies switch efficiently between training stages and manage their state, so the model's capabilities evolve step by step from general language understanding to human-preference-aligned dialogue and instruction following.
-## 3. 训练流程
+### **2.1 Pre-training Stage**
-常见大语言模型(Large Language Model, LLM)的训练流程通常包含三个阶段:**预训练(Pre-training, PT)**、**监督微调(Supervised Fine-Tuning, SFT)** 以及 **基于人类反馈的强化学习(Reinforcement Learning from Human Feedback, RLHF)**。本系统设计支持全流程无缝衔接,通过模块化策略实现不同训练阶段的高效切换与状态管理,确保模型能力从通用语言理解逐步对齐至符合人类偏好的对话与指令执行。
+The pre-training stage builds the model's foundational language capabilities and general knowledge representation. This stage performs self-supervised learning on large-scale, unlabeled corpora (typically hundreds of GB to several TB of text). The model architecture is a standard Transformer decoder, trained with a causal language modeling objective, so the model learns vocabulary, grammar, semantics, and the world knowledge embedded in text.
-### **2.1 预训练阶段**
-
-预训练阶段旨在构建模型的基础语言能力与通用知识表示。该阶段在大规模、无标注的语料库(通常涵盖数百GB至数TB的文本数据)上进行自监督学习。模型架构基于标准的Transformer Decoder,通过掩码语言建模(如因果语言建模)目标进行训练,使模型能够学习词汇、语法、语义及蕴含于文本中的世界知识。
-
-**核心公式:因果语言建模(Causal Language Modeling)**
+**Core Formula: Causal Language Modeling**
$$
L_{\text{PT}} = - \sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
$$
-**符号说明:**
+**Symbol Description:**
+- $T$: sequence length
+- $x_t$: the $t$-th token in the sequence
+- $x_{\lt t}$: all tokens preceding position $t$
+- $\theta$: model parameters
-- $T$:序列长度
-- $x_t$:序列中第 $ t $ 个词元(token)
-- $x_{\lt t}$:序列中前 $t-1$ 个词元
-- $\theta$:模型参数
-
-#### 1. 自回归生成
-
-自回归模型根据已有的 token 序列预测下一个 token 的概率分布,例如:
-
-```
-["你好", ",", "今天", "天气"]
--> {"token": "不错", "probability": 0.4}
--> {"token": "晴朗", "probability": 0.2}
--> ......
-```
-
-这里,“不错”和“晴朗”是两个可能跟随在“天气”之后的tokens,并且给出了每个token成为下一个token的可能性大小。
-
-之后,我们通过采样(通过top_k, top_p, temperature参数调整采样后的结果)得到下一个token并且将下一个token加入序列作为输入
-
-```
-["你好", ",", "今天", "天气", "不错"]
-```
-
-之后都是在重复这个流程, 直到遇到控制流程结束的token(<|end_of_seqence|>)模型停止处理(一般模型都会设置控制token, 不然模型会一直输出到显存爆炸)。
+#### 1. Autoregressive Generation
+
+Given the current token sequence, the model outputs a probability distribution over the next token, for example:
+
+```
+["你好", ",", "今天", "天气"]
+-> {"token": "不错", "probability": 0.4}
+-> {"token": "晴朗", "probability": 0.2}
+-> ......
+```
+
+Here, "不错" and "晴朗" are two tokens that may follow "天气", each listed with its probability of being the next token.
+
+We then sample the next token (shaping the result with the top_k, top_p, and temperature parameters) and append it to the sequence as the new input:
+
+```
+["你好", ",", "今天", "天气", "不错"]
+```
+
+This process repeats until a control token such as <|end_of_sequence|> is generated, at which point the model stops (models generally define such stop tokens; otherwise generation would continue until GPU memory is exhausted).
-
-
-
-
-
-#### 2. 因果掩码
-
-transformer 中采用注意力机制,输入的形状一般为[bsz, seq_len], 输出为[bsz, seq_len,n_dim], 为了实现预测下一个token, 模型的输入和输出必须错开来一个位置。模型预测的target必须错开一个位置, 在训练的时候我们也采用错开一个位置的方法
+#### 2. Causal Mask
+
+The Transformer relies on an attention mechanism: the input has shape [bsz, seq_len] and the output has shape [bsz, seq_len, n_dim]. To predict the next token, the model's inputs and targets must be offset by one position, and training uses the same shift:
```
sequence : [[1, 2, 3, 4, 5, 6]]
@@ -52,18 +28,14 @@ input_ids: [[1, 2, 3, 4, 5]]
target_ids: [[2, 3, 4, 5, 6]]
```
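The same shift feeds directly into the pre-training loss $L_{\text{PT}}$. A minimal PyTorch sketch, with random logits standing in for a real model's output (the vocabulary size of 10 is assumed purely for illustration):

```python
import torch
import torch.nn.functional as F

sequence = torch.tensor([[1, 2, 3, 4, 5, 6]])
input_ids = sequence[:, :-1]    # [[1, 2, 3, 4, 5]]  fed to the model
target_ids = sequence[:, 1:]    # [[2, 3, 4, 5, 6]]  what it must predict

# Stand-in for model(input_ids): logits of shape [bsz, seq_len, vocab_size].
logits = torch.randn(1, input_ids.shape[1], 10)

# L_PT = -sum_t log P(x_t | x_<t), i.e. cross-entropy against the shifted targets.
loss = F.cross_entropy(logits.reshape(-1, 10), target_ids.reshape(-1))
```

Each position t in `logits` is scored against the token at position t+1 of the original sequence, which is exactly the causal language modeling objective above.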
-
-
-注意力得分计算的公式为
-
+The attention score calculation formula is:
$$ s_{ij} = softmax(\frac{q_i^Tk_j}{\sqrt{d_k}}) $$
$$ s_{ij} := s_{ij} + mask_{ij} $$
+Here, the attention score measures how strongly the model attends to token $j$ when processing token $i$, based on their similarity.
-其中注意力得分代表了模型对两个token之间相似程度的关注程度
-
-对于decoder only结构的模型, 为了防止模型从未来的位置偷到信息, 在注意力的计算过程中需要增加掩码,我们需要在注意力得分计算之前应用一个掩码。这个掩码通常是一个下三角矩阵,对于长度为n的序列,它的形状是[n, n]。下面以一个长度为5的序列为例,展示如何创建这样的因果掩码矩阵:
+For decoder-only models, a mask must be applied before the attention scores are computed, to prevent the model from "stealing" information from future positions. This mask is typically lower triangular; for a sequence of length n, its shape is [n, n]. Below is an example causal mask for a sequence of length 5:
```
[[0, -inf, -inf, -inf, -inf],
@@ -73,25 +45,21 @@ $$ s_{ij} := s_{ij} + mask_{ij} $$
[0, 0, 0, 0, 0]]
```
-在这个矩阵中,0表示可以注意到的位置,而-inf表示应该被掩盖(即不应注意到)的位置。因为这个句子保证了注意力得分中 $j > i$ 的部分通过softmax 之后由`inf` 变成0, 也就是模型不能看到未来的信息
+In this matrix, 0 marks positions that can be attended to, while -inf marks positions that must be masked out (i.e., not attended to). The mask ensures that after the softmax, every attention score with $j > i$ (set to `-inf` by the mask) becomes 0, so the model cannot see future information.
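The mask and its effect can be verified in a few lines of PyTorch (a sketch matching the 5-token example above):

```python
import torch

n = 5
# 0 on and below the diagonal, -inf strictly above: the matrix shown above.
mask = torch.full((n, n), float("-inf")).triu(diagonal=1)

# After adding the mask, softmax assigns exactly zero weight to future positions.
scores = torch.randn(n, n) + mask
weights = torch.softmax(scores, dim=-1)
```

Since exp(-inf) = 0, the masked entries contribute nothing to the softmax normalization, so each row's weights sum to 1 over past and current positions only.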
+#### 3. Rotary Position Embedding
-
-#### 3. 旋转位置编码
-
-旋转位置编码(Rotary Position Embedding, RoPE)是一种为了解决Transformer模型中缺乏对序列位置信息直接建模的问题而设计的位置编码方法。与传统的位置编码(如正弦和余弦函数的位置编码)不同,RoPE通过将位置信息直接嵌入到查询(Query, Q)和键(Key, K)向量中来实现,使得模型能够更自然地处理序列中的相对位置关系。
-
+Rotary Position Embedding (RoPE) is a position encoding method designed to address the Transformer's lack of direct modeling of sequence position information. Unlike traditional position encodings (such as sinusoidal encodings), RoPE embeds position information directly into the Query (Q) and Key (K) vectors, letting the model handle relative positions in a sequence more naturally.
$$ q_i = R_i W_q x_i $$
$$ k_j = R_j W_k x_j $$
$$ q_i^T k_j = (R_i W_q x_i)^T( R_j W_k x_j) = x_i^T W_q^T R_{i-j} W_k x_j $$
-其中的 $R_{i-j}$ 控制了模型的不同token 在不同相对距离上注意力的衰减,在 $i - j$ 绝对值越大的时候, 衰减的程度越强, 通过这种方式能让模型学习到相对位置关系, 从而使得模型可以扩展和适应长序列
+Here $R_{i-j}$ governs how attention between tokens decays with relative distance: the larger $|i - j|$, the stronger the attenuation. This lets the model learn relative position relationships and extrapolate to longer sequences.
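A compact sketch of the rotation (the half-split dimension pairing used here is one common layout; the framework's actual implementation may pair dimensions differently). The key property is that $q_i^T k_j$ depends only on the offset $i - j$:

```python
import torch

def rope(x, pos, base=10000.0):
    # Rotate feature pairs of x by position-dependent angles (half-split pairing).
    half = x.shape[-1] // 2
    freqs = base ** (-torch.arange(half, dtype=x.dtype) / half)
    angles = pos * freqs                       # one rotation angle per 2D plane
    cos, sin = torch.cos(angles), torch.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
```

Because each 2D plane is rotated by pos * freq, the inner product of a query rotated at position i and a key rotated at position j depends only on the angle difference (i - j) * freq: shifting both positions by the same amount leaves the score unchanged.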
+## KV Cache Implementation
-## kv_cache 实现
-
-根据注意力的计算公式
+According to the attention calculation formula:
$$
\begin{align*}
@@ -100,7 +68,7 @@ s_{ij} &= \text{softmax}\left( \frac{q_{i} k_{j}}{\sqrt{d_k}} \right)
\end{align*}
$$
-由于模型是自回归模型, 我们只用求序列最后一个部分,也就是说 $ i $ 的下标是确定的, 是序列最后一个元素, 我们求的是 $o_{n} $
+Since the model is an autoregressive model, we only need to calculate for the last part of the sequence, meaning the index $i$ is fixed as the last element of the sequence, and we compute $o_{n}$:
$$
\begin{align*}
@@ -109,10 +77,10 @@ s_j &= \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}} \right)
\end{align*}
$$
-如果我们把式子展开
+If we expand the expression:
$$
o_n = \sum_j \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}}\right)v_{j}
$$
-以上表达式只有k和v存在长度下标, 而 $q$ 没有, 所以计算过程中 $q$ 的输入是确定的上次输入的最后一个token, 而 $k, v$ 是需要对不同长度的部分进行缓存的,同时缓存的时候应该注意位置编码的计算应该在kvcache的计算之前进行,否则会存在位置编码的计算错误
\ No newline at end of file
+In the above expression, only $k$ and $v$ carry the sequence index, while $q$ appears only as $q_n$. During generation, $q$ therefore needs to be computed only for the newest token, whereas the $k$ and $v$ of all previous positions must be cached. Note that rotary position encoding must be applied before entries are written to the KV cache; otherwise the cached keys carry incorrect positional information.
\ No newline at end of file
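The cached decoding step described above can be sketched for a single attention head (simplified: no batching and no head dimension; `attend_step` is an illustrative name, not this framework's API):

```python
import torch

def attend_step(q_t, k_t, v_t, cache):
    # q_t, k_t, v_t: (d,) projections of the newest token, RoPE already applied.
    cache["k"].append(k_t)
    cache["v"].append(v_t)
    K = torch.stack(cache["k"])               # (t, d): keys of all tokens so far
    V = torch.stack(cache["v"])               # (t, d): values of all tokens so far
    scores = K @ q_t / K.shape[-1] ** 0.5     # q_n . k_j / sqrt(d_k) for every j
    return torch.softmax(scores, dim=0) @ V   # o_n = sum_j s_j v_j
```

Running this step token by token reproduces exactly the last-position output of full causal attention, while avoiding recomputation of earlier keys and values.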
diff --git a/assets/docs/params.md b/assets/docs/params.md
new file mode 100644
index 0000000..965812e
--- /dev/null
+++ b/assets/docs/params.md
@@ -0,0 +1,115 @@
+# Parameter Documentation
+
+## Training Parameters
+
+### Basic Parameters
+
+| Parameter | Description | Default Value |
+|-----------|-------------|---------------|
+| `--train_type` | Training type (seq, sft, dpo) | required |
+| `--data_root_path` | Dataset root directory | required |
+| `--param_path` | Model parameters or checkpoint path | required |
+| `--n_epoch` | Total training epochs | 1 |
+| `--batch_size` | Batch size | 1 |
+| `--accumulation_steps` | Gradient accumulation steps | 1 |
+
+### Learning Rate Scheduling
+
+| Parameter | Description | Default Value |
+|-----------|-------------|---------------|
+| `--warmup_steps` | Warmup steps | 1000 |
+| `--max_lr` | Maximum learning rate (warmup + cosine decay) | 3e-4 |
+| `--max_grad_norm` | Maximum gradient norm | 1.0 |
+
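The schedule implied by `--warmup_steps` and `--max_lr` is linear warmup followed by cosine decay. A sketch of one common formulation (the trainer's exact decay horizon and floor are not documented here, so `total_steps` and the zero floor are assumptions):

```python
import math

def lr_at(step, warmup_steps=1000, max_lr=3e-4, total_steps=10000):
    # Linear warmup from 0 to max_lr, then cosine decay from max_lr down to 0.
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * max_lr * (1.0 + math.cos(math.pi * progress))
```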
+### Checkpoint
+
+| Parameter | Description | Default Value |
+|-----------|-------------|---------------|
+| `--ckpt_interval` | Checkpoint save interval (iterations) | 5000 |
+| `--ckpt_dir` | Checkpoint save directory | checkpoint |
+| `--resume_dir` | Resume training from specified path | - |
+
+### Optimizer Parameters
+
+| Parameter | Description | Default Value |
+|-----------|-------------|---------------|
+| `--adamw_beta1` | AdamW beta1 | 0.9 |
+| `--adamw_beta2` | AdamW beta2 | 0.95 |
+| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
+
+### Data Loading
+
+| Parameter | Description | Default Value |
+|-----------|-------------|---------------|
+| `--random_seed` | Random seed | 3407 |
+| `--num_workers` | DataLoader workers | 4 |
+| `--no_pin_memory` | Disable pin_memory | - |
+
+### Distributed Training
+
+| Parameter | Description | Default Value |
+|-----------|-------------|---------------|
+| `--nprocs` | Number of GPUs | 1 |
+| `--device_type` | Device type (cuda/cpu) | cuda |
+
+### Other Parameters
+
+| Parameter | Description | Default Value |
+|-----------|-------------|---------------|
+| `--window_size` | Maximum input sequence length | model config max_len |
+| `--stride` | Input sequence stride | - |
+| `--dpo_beta` | DPO beta value | 0.1 |
+| `--label_smoothing` | Label smoothing parameter | 0.1 |
+| `--start_epoch` | Starting epoch | 0 |
+| `--start_batch` | Starting batch | 0 |
+
+---
+
+## Generation Parameters
+
+### GenerationRequest Parameters
+
+| Parameter | Description | Default Value |
+|-----------|-------------|---------------|
+| `query` | Input text or text list | required |
+| `history` | Conversation history | None |
+| `system_prompt` | System prompt | None |
+| `temperature` | Sampling temperature (higher = more random) | required |
+| `top_p` | Nucleus sampling threshold | required |
+| `top_k` | Top-k sampling count | required |
+| `max_len` | Maximum generation length | model config max_len |
+| `stream` | Whether to stream output | False |
+
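How `temperature`, `top_k`, and `top_p` typically interact during sampling can be sketched as follows (an illustration of the standard technique, not this framework's exact implementation):

```python
import torch

def sample_next(logits, temperature, top_p, top_k):
    # Temperature rescales logits: >1 flattens the distribution, <1 sharpens it.
    probs = torch.softmax(logits / temperature, dim=-1)
    # Top-k: keep only the k most likely tokens (topk returns them sorted).
    probs, indices = torch.topk(probs, top_k)
    # Top-p (nucleus): keep tokens until cumulative mass reaches top_p.
    cumulative = torch.cumsum(probs, dim=-1)
    keep = cumulative - probs < top_p         # mass *before* each token < top_p
    probs = probs * keep
    choice = torch.multinomial(probs / probs.sum(), num_samples=1)
    return indices[choice].item()
```

Lower `temperature`, smaller `top_k`, or smaller `top_p` all make the output more deterministic; the surviving probabilities are renormalized before sampling.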
+### Usage Example
+
+```python
+import torch
+from khaosz.config.param_config import ModelParameter
+from khaosz.inference.generator import StreamGenerator, GenerationRequest
+
+# Load model
+param = ModelParameter.load("your_model_dir")
+param.to(device="cuda", dtype=torch.bfloat16)
+
+# Create generator
+generator = StreamGenerator(param)
+
+# Build request
+request = GenerationRequest(
+ query="Hello",
+ history=[],
+ temperature=0.8,
+ top_p=0.95,
+ top_k=50,
+)
+
+# Generate
+response = generator.generate(request)
+```
+
+### Three Types of Generators
+
+| Generator | Usage |
+|-----------|-------|
+| `StreamGenerator` | Streaming output, yields tokens one by one |
+| `LoopGenerator` | Non-streaming output, returns at once |
+| `BatchGenerator` | Batch generation, processes multiple queries simultaneously |
\ No newline at end of file
diff --git a/demo/download.py b/demo/download.py
index 670e9aa..8cb9052 100644
--- a/demo/download.py
+++ b/demo/download.py
@@ -1,13 +1,12 @@
-import os
+from pathlib import Path
from huggingface_hub import snapshot_download
-
-PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-
+PROJECT_ROOT = Path(__file__).parent.parent
+PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
if __name__ == "__main__":
snapshot_download(
repo_id="ViperEk/KHAOSZ",
- local_dir=os.path.join(PROJECT_ROOT, "params"),
+ local_dir=PARAMETER_ROOT,
force_download=True,
)
diff --git a/demo/generate_ar.py b/demo/generate_ar.py
index 1a6d078..87033d0 100644
--- a/demo/generate_ar.py
+++ b/demo/generate_ar.py
@@ -1,20 +1,19 @@
-import os
import torch
+from pathlib import Path
from khaosz.config.param_config import ModelParameter
from khaosz.inference.core import disable_random_init
-from khaosz.inference.generator import LoopGenerator, GenerationRequest
+from khaosz.inference.generator import GeneratorFactory, GenerationRequest
-
-PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+PROJECT_ROOT = Path(__file__).parent.parent
+PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def generate_text():
with disable_random_init():
- model_dir = os.path.join(PROJECT_ROOT, "params")
- param = ModelParameter.load(model_dir)
+ param = ModelParameter.load(PARAMETER_ROOT)
+ param.to(device="cuda", dtype=torch.bfloat16)
- param.to(device="cuda", dtype=torch.bfloat16)
query = input(">> ")
request = GenerationRequest(
@@ -26,7 +25,7 @@ def generate_text():
history=None,
system_prompt=None,
)
- generator = LoopGenerator(param)
+ generator = GeneratorFactory.create(param, request)
response = generator.generate(request)
print(response)
diff --git a/demo/generate_batch.py b/demo/generate_batch.py
index 234f5f4..0bf7645 100644
--- a/demo/generate_batch.py
+++ b/demo/generate_batch.py
@@ -1,19 +1,19 @@
-import os
import torch
+from pathlib import Path
from khaosz.config.param_config import ModelParameter
from khaosz.inference.core import disable_random_init
-from khaosz.inference.generator import BatchGenerator, GenerationRequest
+from khaosz.inference.generator import GeneratorFactory, GenerationRequest
-PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+PROJECT_ROOT = Path(__file__).parent.parent
+PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def batch_generate():
- with disable_random_init():
- model_dir = os.path.join(PROJECT_ROOT, "params")
- param = ModelParameter.load(model_dir)
- param.to(device="cuda", dtype=torch.bfloat16)
- generator = BatchGenerator(param)
+ with disable_random_init():
+ param = ModelParameter.load(PARAMETER_ROOT)
+ param.to(device="cuda", dtype=torch.bfloat16)
+
inputs = [
"你好",
"请问什么是人工智能",
@@ -31,6 +31,7 @@ def batch_generate():
history=None,
system_prompt=None,
)
+ generator = GeneratorFactory.create(param, request)
responses = generator.generate(request)
for q, r in zip(inputs, responses):
diff --git a/demo/stream_chat.py b/demo/stream_chat.py
index c2bb322..d119685 100644
--- a/demo/stream_chat.py
+++ b/demo/stream_chat.py
@@ -1,21 +1,18 @@
-import os
import torch
+from pathlib import Path
from khaosz.config.param_config import ModelParameter
from khaosz.inference.core import disable_random_init
-from khaosz.inference.generator import StreamGenerator, GenerationRequest
+from khaosz.inference.generator import GeneratorFactory, GenerationRequest
-
-PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+PROJECT_ROOT = Path(__file__).parent.parent
+PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def chat():
with disable_random_init():
- model_dir = os.path.join(PROJECT_ROOT, "params")
- param = ModelParameter.load(model_dir)
-
- param.to(device="cuda", dtype=torch.bfloat16)
- generator = StreamGenerator(param)
+ param = ModelParameter.load(PARAMETER_ROOT)
+ param.to(device="cuda", dtype=torch.bfloat16)
history = []
while True:
@@ -32,6 +29,7 @@ def chat():
history=history,
system_prompt=None,
)
+ generator = GeneratorFactory.create(param, request)
response_size = 0
full_response = ""
diff --git a/khaosz/data/dataset.py b/khaosz/data/dataset.py
index bca8d2c..b696644 100644
--- a/khaosz/data/dataset.py
+++ b/khaosz/data/dataset.py
@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
from torch import Tensor
from torch.utils.data import Dataset
from khaosz.data.serialization import load_h5
-from typing import Callable, List, Dict, Literal, Optional, Union
+from typing import List, Dict, Optional, Union
class BaseSegmentFetcher:
diff --git a/khaosz/data/serialization.py b/khaosz/data/serialization.py
index c234bfe..af6c859 100644
--- a/khaosz/data/serialization.py
+++ b/khaosz/data/serialization.py
@@ -75,7 +75,7 @@ class Checkpoint:
with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2)
- st.save_file(self.state_dict, save_path / f"state_dict.safetensors")
+ st.save_file(self.state_dict, save_path / "state_dict.safetensors")
@classmethod
def load(
@@ -96,7 +96,7 @@ class Checkpoint:
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
- state_dict = st.load_file(save_path / f"state_dict.safetensors")
+ state_dict = st.load_file(save_path / "state_dict.safetensors")
return cls(
state_dict=state_dict,
diff --git a/khaosz/inference/generator.py b/khaosz/inference/generator.py
index 78a4357..534ac23 100644
--- a/khaosz/inference/generator.py
+++ b/khaosz/inference/generator.py
@@ -219,7 +219,7 @@ class BatchGenerator(GeneratorCore):
ids_list[i].append(token)
c_ids += 1
- is_active = not token in self.tokenizer.stop_ids
+ is_active = token not in self.tokenizer.stop_ids
activate_task_mask[i] = is_active
active_mask.append(is_active)
diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py
index db877be..1f959b3 100644
--- a/khaosz/trainer/strategy.py
+++ b/khaosz/trainer/strategy.py
@@ -7,7 +7,7 @@ import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch import Tensor
-from typing import Any, Callable, Dict, Union, Optional
+from typing import Any, Callable, Dict, Union
from abc import ABC, abstractmethod
diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py
index 0b1a94a..08de779 100644
--- a/khaosz/trainer/train_callback.py
+++ b/khaosz/trainer/train_callback.py
@@ -6,7 +6,6 @@ import torch.nn as nn
from pathlib import Path
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_
-from torch.optim.lr_scheduler import LRScheduler
from typing import Callable, List, Optional, Protocol
from khaosz.parallel import only_on_rank