Initial commit

ViperEkura 2025-09-27 12:02:22 +08:00
commit a4443765ee
33 changed files with 3896 additions and 0 deletions

22
.gitignore vendored Normal file

@@ -0,0 +1,22 @@
# cache
__pycache__
.pytest_cache
# params
*.safetensors
*.json
*.pkl
*.db
# train_log
*.log
# ignore file
-*
# vscode file
.vscode
# build file
build
*.egg-info

201
LICENSE Normal file

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

333
README.md Normal file

@@ -0,0 +1,333 @@
![image-20250306182014120](/assets/images/project_logo_clipped.png)
<div style="display: flex; flex-direction: column; align-items: center; justify-content: center; text-align: center; font-size: 16px; font-weight: bold; margin-top: 50px;">
<div>
<a href="#english" style="text-decoration: none; margin: 0 10px; color: blue;">English</a> |
<a href="#chinese" style="text-decoration: none; margin: 0 10px; color: blue;">中文</a>
</div>
<h1 style="margin: 20px 0 0 0; font-size: 2.5em; font-weight: bold;">KHAOSZ </h1>
</div>
<h2 id="english">English Version</h2>
This is a bilingual Chinese-English Transformer model. The repository contains the model configuration and the full training workflow; training loads the parameters defined in `param_path/config.json`. The training script `train.py` parses command-line arguments, including the dataset root directory, number of training epochs, batch size, checkpoint interval, and checkpoint directory.
**Model Download Options (Choose One):**
1. Visit [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) to access **Files and versions**
2. Run `scripts/download.py` to download parameters
**Demo Video:** [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd)
Training dataset sources are listed in the **Model Card** section of the HuggingFace download link.
**License:** The code is released under the Apache-2.0 license. Please credit the source when using it.
- **📊 Device Selection:** The code defaults to training on CUDA
- **🌐 Performance Optimization:** `dtype=torch.bfloat16` is enabled to accelerate training and reduce memory usage; make sure your hardware supports it
- **🤖 Language Support:** The model supports Chinese and English. The BBPE tokenizer was trained only on Chinese and English text, so OOV (out-of-vocabulary) issues are rare for these two languages but may occur for others
### 📌 Training Guide
To train this Transformer model, follow these steps:
**(1). Prepare Dataset:**
Place the datasets in the designated root directory. Files should be text documents in Chinese, English, or a mix of both. The format should match the model's input requirements: preferably pre-tokenized token IDs stored as a `torch.Tensor` (a `torch.Tensor` saves memory compared to a Python list, whose integers default to 64 bits).
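For illustration, a minimal pre-tokenization sketch (the corpus and output paths here are placeholders, and the exact dataset layout expected by `train.py` should be checked against the repository; `BpeTokenizer` itself is exported by the package):
```python
import torch
from khaosz import BpeTokenizer

# Hypothetical paths: substitute your own tokenizer file and corpus.
tokenizer = BpeTokenizer("your_model_parameter_dir/tokenizer.json")
with open("corpus.txt", "r", encoding="utf-8") as f:
    text = f.read()

# Encode once and store a compact integer tensor instead of a Python int list.
ids = tokenizer.encode(text)
torch.save(torch.tensor(ids, dtype=torch.int32), "dataset_root/corpus.pt")
```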
**(2). Install Dependencies:**
```bash
pip install -r requirements.txt
pip install .
```
**(3). Run Training Script:**
```bash
python train.py \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/param_path \
--n_epoch=5 \
--batch_size=8 \
--max_lr=2e-4 \
--n_iter_ckpt=10000 \
--ckpt_dir=checkpoints
```
**Parameters Explanation:**
- `--train_type`: Training type (seq, sft, dpo)
- `--data_root_path`: Root directory of the dataset
- `--param_path`: Path to the model training parameters
- `--n_epoch`: Total number of training epochs
- `--batch_size`: Batch size
- `--n_iter_step`: Number of batches per training step
- `--warning_step`: Number of warmup steps
- `--max_lr`: Maximum learning rate (using warmup + cosine decay)
- `--n_iter_ckpt`: Checkpoint saving interval
- `--ckpt_dir`: Directory to save checkpoints
- `--resume_dir`: Resume training from the specified path
Training logs will be saved in `train_log.txt`. Checkpoints will be saved in the specified directory for resuming training or evaluation.
### 👉 Usage Guide
**(1). Chatting with the Model:**
Open `chat.py` or use streaming/non-streaming interfaces:
**Streaming Output:**
```python
import torch
from khaosz import Khaosz

model_dir = "your_model_parameter_dir"
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
history = []

while True:
    query = input(">> ")
    if query == "!exit":
        break
    response_size = 0
    for response, history in model.stream_generate(
        query=query,
        history=history,
        temperature=0.85,
        top_p=0.95,
        top_k=50
    ):
        print(response[response_size:], end="")
        response_size = len(response)
```
**Non-streaming Output:**
```python
import torch
from khaosz import Khaosz

model_dir = "your_model_parameter_dir"
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
history = []

while True:
    query = input(">> ")
    if query == "!exit":
        break
    response = model.generate(
        query=query,
        history=history,
        temperature=0.85,
        top_p=0.95,
        top_k=50
    )
    print(response)
```
**(2). Retrieval-Augmented Generation (RAG):**
```python
import torch
from khaosz import Khaosz

model_dir = "your_model_parameter_dir"
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)

query = "your question"
retrieved_content = model.retrieve_generate(
    query=query,
    retrieve_top_k=5,
    temperature=0.6,
    top_k=30,
    top_p=0.95
)
print(retrieved_content)
```
### 📌 Model Specifications
This model is based on a 24-layer Transformer with parameters defined in `config.json`, totaling approximately 1.0 billion (1.0B) parameters.
**Key Design Choices:**
- Weight tying between the embedding and the final linear layer (a standard choice for small models to save parameters; see the sketch below)
- Embedding layer optimization: without weight tying, the separate input embedding and output projection matrices would consume roughly 102M parameters (~0.1B) on their own
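For intuition, here is a minimal sketch of weight tying (illustrative only; the repository keeps a single `embedding` parameter in `khaosz/core/transformer.py` rather than this exact module):
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TiedLMHead(nn.Module):
    """Toy example: one matrix serves as both input embedding and output projection."""
    def __init__(self, vocab_size: int, n_dim: int):
        super().__init__()
        self.embedding = nn.Parameter(torch.randn(vocab_size, n_dim) * 0.02)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = F.embedding(input_ids, self.embedding)  # (bsz, seq_len, n_dim)
        # ... transformer blocks would run here ...
        return F.linear(x, self.embedding)          # logits: (bsz, seq_len, vocab_size)
```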
**Limitations:**
- May struggle with complex language phenomena due to smaller parameter size
- Prone to overfitting on specialized datasets
- Limited multilingual capabilities
**Advantages:**
- Runs efficiently on lower-spec hardware
- Shorter training time compared to larger models
**Training Pipeline:**
The model has completed pre-training + SFT (Supervised Fine-Tuning) + DPO (Direct Preference Optimization) workflows. All corresponding training code is included in the repository.
<h2 id="chinese">中文版本</h2>
这是一个支持中英文双语的 Transformer 模型,能够处理两种语言。模型包含配置文件和训练流程,通过加载 `param_path/config.json` 中定义的参数完成训练。训练脚本 `train.py` 支持命令行参数解析包括数据集根目录、训练轮数epochs、批量大小batch size、检查点保存间隔、检查点目录等。
**模型下载选项(任选其一):**
1. 访问 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 查看 **Files and versions**
2. 运行 `scripts/download.py` 下载模型参数
**演示视频:** [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd)
训练数据来源请参见 HuggingFace 下载页面中的 **Model Card** 部分。
**许可证:** 代码遵循 Apache-2.0 协议,使用时请注明出处。
- **📊 设备选择:** 默认使用 CUDA 进行训练
- **🌐 性能优化:** 启用 `dtype=torch.bfloat16` 以加速训练并减少内存占用,请确保硬件支持该特性
- **🤖 语言支持:** 模型支持中文和英文训练。由于 BBPE 分词器未使用多语言文本训练,因此中英文的 OOV未登录词问题较少其他语言可能存在 OOV 问题
### 📌 训练指南
要训练该 Transformer 模型,请按照以下步骤操作:
#### **(1). 准备数据集:**
将数据集放置在指定的根目录下。文件应为包含中文、英文或混合文本的文本文档。格式应符合模型输入要求——建议使用预分词后的 `token_ids` 并以 `torch.Tensor` 格式保存(使用 `torch.Tensor` 相比 Python 列表更节省内存,列表默认为 64 位精度)。
#### **(2). 安装依赖:**
```bash
pip install -r requirements.txt
pip install .
```
#### **(3). Run the Training Script:**
```bash
python train.py \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/param_path \
--n_epoch=5 \
--batch_size=8 \
--max_lr=2e-4 \
--n_iter_ckpt=10000 \
--ckpt_dir=checkpoints
```
**Parameter Explanation:**
- `--train_type`: Training type (seq, sft, dpo)
- `--data_root_path`: Root directory of the dataset
- `--param_path`: Path to the model training parameters
- `--n_epoch`: Total number of training epochs
- `--batch_size`: Batch size
- `--n_iter_step`: Number of batches per training step
- `--warning_step`: Number of warmup steps
- `--max_lr`: Maximum learning rate (warmup + cosine decay)
- `--n_iter_ckpt`: Checkpoint saving interval
- `--ckpt_dir`: Directory to save checkpoints
- `--resume_dir`: Resume training from the specified path
Training logs are saved to `train_log.txt`. Checkpoints are saved in the specified directory for resuming training or for evaluation.
### 👉 Usage Guide
#### **(1). Chatting with the Model:**
Open `chat.py` or use the streaming/non-streaming interfaces:
**Streaming Output:**
```python
import torch
from khaosz import Khaosz

model_dir = "your_model_parameter_dir"
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
history = []

while True:
    query = input(">> ")
    if query == "!exit":
        break
    response_size = 0
    for response, history in model.stream_generate(
        query=query,
        history=history,
        temperature=0.85,
        top_p=0.95,
        top_k=50
    ):
        print(response[response_size:], end="")
        response_size = len(response)
```
**Non-streaming Output:**
```python
import torch
from khaosz import Khaosz

model_dir = "your_model_parameter_dir"
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
history = []

while True:
    query = input(">> ")
    if query == "!exit":
        break
    response = model.generate(
        query=query,
        history=history,
        temperature=0.85,
        top_p=0.95,
        top_k=50
    )
    print(response)
```
#### **(2). Retrieval-Augmented Generation (RAG):**
```python
import torch
from khaosz import Khaosz

model_dir = "your_model_parameter_dir"
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)

query = "your question"
retrieved_content = model.retrieve_generate(
    query=query,
    retrieve_top_k=5,
    temperature=0.6,
    top_k=30,
    top_p=0.95
)
print(retrieved_content)
```
### 📌 Model Specifications
The model is a 24-layer Transformer whose parameters are defined in `config.json`, totaling approximately 1.0 billion (1.0B) parameters.
**Key Design Choices:**
- Weight tying between the embedding and the final linear layer (a common way for small models to save parameters)
- Embedding layer optimization: without weight tying, the separate input embedding and output projection matrices would consume roughly 102M parameters (~0.1B) on their own
**Limitations:**
- May struggle with complex language phenomena due to the smaller parameter count
- Prone to overfitting on specialized datasets
- Limited multilingual capability
**Advantages:**
- Runs efficiently on lower-spec hardware
- Shorter training time than larger models
**Training Pipeline:**
The model has completed the full pre-training + SFT (Supervised Fine-Tuning) + DPO (Direct Preference Optimization) pipeline. All corresponding training code is included in the repository.


@@ -0,0 +1,89 @@
## Model Introduction
### 1. Model Architecture
The model uses a Transformer architecture with GQA (q_head=24, kv_head=4), which saves KV-cache memory relative to conventional MHA (although no KV cache had been implemented at this point). It is built by stacking 24 Transformer layers, for a total of about 1.0B parameters. The Transformer is an autoregressive model: it computes the probability distribution of the next token from the relationships among all preceding tokens.
![structure](../images/structure.png)
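To make the GQA head grouping concrete, here is a small standalone sketch (toy shapes assumed; the repository's `repeat_kv` in `khaosz/core/transformer.py` performs the same expansion):
```python
import torch

bsz, seq_len, head_dim = 2, 16, 64
n_head, n_kvhead = 24, 4                         # 6 query heads share each KV head
q = torch.randn(bsz, seq_len, n_head, head_dim)
k = torch.randn(bsz, seq_len, n_kvhead, head_dim)

# Repeat each KV head n_head // n_kvhead times so attention shapes line up.
n_rep = n_head // n_kvhead
k_expanded = k.repeat_interleave(n_rep, dim=2)   # (2, 16, 24, 64)
assert k_expanded.shape == q.shape
# Only the 4 original KV heads would need caching: a 6x KV-cache saving vs. MHA.
```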
What does "autoregressive" mean? After a sentence is split into tokens, the model predicts a probability distribution over the next token: given the context (the sequence of tokens seen so far), it computes the candidate next tokens and their probabilities.
#### 1. Autoregression
Suppose a sentence has been split into the following token list:
```
["你好", "" "今天", "天气"]
```
The model then predicts the next token for this sequence, typically as a probability distribution, e.g.:
```
-> {"token": "不错", "probability": 0.4}
-> {"token": "晴朗", "probability": 0.2}
-> ......
```
Here "不错" (not bad) and "晴朗" (sunny) are two tokens that might follow "天气" (weather), each with its probability of being the next token.
We then sample from this distribution (shaping it with the top_k, top_p, and temperature parameters) to pick the next token, append it to the sequence, and feed the sequence back in:
```
["你好", "" "今天", "天气", "不错"]
```
This loop repeats until a control token that ends generation, such as <|end_of_sequence|>, is produced, at which point the model stops (models generally define such control tokens; otherwise generation would run until memory is exhausted).
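A compressed sketch of this decode loop (heavily simplified: `model` is assumed to map a (1, t) id tensor to (1, t, vocab) logits, and top-k/top-p filtering is omitted):
```python
import torch

def decode(model, ids, stop_ids, max_len, temperature=0.85):
    while len(ids) < max_len:
        logits = model(torch.tensor([ids]))[0, -1]           # next-token distribution
        probs = torch.softmax(logits / temperature, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1).item()
        ids.append(next_id)                                  # feed the sample back in
        if next_id in stop_ids:                              # control token: stop
            break
    return ids
```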
#### 2. Causal Mask
The Transformer uses attention: the input has shape [bsz, seq_len] and the output has shape [bsz, seq_len, n_dim]. To predict the next token, the inputs and prediction targets must be offset by one position, so during training we shift the targets by one:
```
sequence  : [[1, 2, 3, 4, 5, 6]]
input_ids : [[1, 2, 3, 4, 5]]
target_ids: [[2, 3, 4, 5, 6]]
```
The attention scores are computed as
$$ s_{ij} = \mathrm{softmax}_j\left(\frac{q_i^T k_j}{\sqrt{d_k}} + mask_{ij}\right) $$
where the mask is added to the scaled scores before the softmax, and the score $s_{ij}$ measures how strongly the model attends to the relationship between tokens $i$ and $j$.
For a decoder-only model, the model must be kept from stealing information from future positions, so a mask is applied before the scores are normalized. The mask is usually a lower-triangular matrix; for a sequence of length n its shape is [n, n]. For a sequence of length 5, the causal mask looks like this:
```
[[0, -inf, -inf, -inf, -inf],
[0, 0, -inf, -inf, -inf],
[0, 0, 0, -inf, -inf],
[0, 0, 0, 0, -inf],
[0, 0, 0, 0, 0]]
```
In this matrix, 0 marks positions that may be attended to, and -inf marks positions that must be hidden. The mask guarantees that every score with $j > i$ becomes 0 after the softmax (since $e^{-\infty} = 0$), so the model cannot see future information.
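A minimal sketch of building this mask in PyTorch (illustrative; the repository's `create_attention_mask` additionally folds in padding masks and KV-cache offsets):
```python
import torch

seq_len = 5
# True on/below the diagonal = visible; False above = future positions.
visible = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
mask = torch.zeros(seq_len, seq_len)
mask[~visible] = float("-inf")  # added to the scores before softmax
print(mask)
```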
#### 3. Rotary Position Embedding
Rotary Position Embedding (RoPE) is a positional encoding designed to address the Transformer's lack of built-in modeling of sequence position. Unlike traditional positional encodings (such as sinusoidal ones), RoPE injects position information directly into the query (Q) and key (K) vectors, which lets the model handle relative positions in a sequence more naturally.
$$ q_i = R_i W_q x_i $$
$$ k_j = R_j W_k x_j $$
$$ q_i^T k_j = (R_i W_q x_i)^T (R_j W_k x_j) = x_i^T W_q^T R_{j-i} W_k x_j $$
Here $R_{j-i}$ depends only on the relative position, and the interaction it induces decays as the absolute distance $|i - j|$ grows: this lets the model learn relative positional relationships and extend to longer sequences.
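A standalone sketch of this complex-number formulation (CPU, toy shapes; `khaosz/core/transformer.py` implements the same idea in `get_rotary_emb` and `apply_rotary_emb`):
```python
import torch

def rotary_freqs(head_dim: int, max_len: int, base: float = 10000.0) -> torch.Tensor:
    theta = base ** (-torch.arange(0, head_dim, 2).float() / head_dim)
    t = torch.arange(max_len).float()
    return torch.polar(torch.ones(max_len, head_dim // 2), torch.outer(t, theta))

def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: (bsz, seq_len, n_heads, head_dim); rotate consecutive dim pairs as complex numbers
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    out = torch.view_as_real(x_c * freqs_cis.view(1, x.size(1), 1, -1)).flatten(3)
    return out.to(x.dtype)

q = torch.randn(1, 8, 4, 64)
assert apply_rope(q, rotary_freqs(64, 8)).shape == q.shape
```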

27
assets/docs/kvcache.md Normal file

@@ -0,0 +1,27 @@
## KV Cache Implementation
Start from the attention formula
$$
\begin{align*}
o_i &= \sum_j s_{ij} v_j \\
s_{ij} &= \text{softmax}_j\left( \frac{\sum_n q_{i,n} k_{j,n}}{\sqrt{d_k}} \right)
\end{align*}
$$
Since the model is autoregressive, we only need the output for the last position of the sequence (here $n$ indexes the head dimension, and the softmax runs over $j$). With the row index fixed to the last token $t$, we compute $o_t$:
$$
\begin{align*}
o_t &= \sum_j s_j v_j \\
s_j &= \text{softmax}_j\left(\frac{\sum_n q_{t,n} k_{j,n}}{\sqrt{d_k}} \right)
\end{align*}
$$
Expanding this expression:
$$
o_t = \sum_j \text{softmax}_j\left(\frac{\sum_n q_{t,n} k_{j,n}}{\sqrt{d_k}}\right) v_j
$$
In this expression only $k$ and $v$ carry the sequence-position index $j$; $q$ does not. So during decoding the query comes only from the last token of the previous input, while the $k$ and $v$ of all earlier positions must be cached. Note that rotary position embedding must be applied before keys are written into the KV cache; otherwise the positional encoding is computed incorrectly.
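A single-head sketch of one cached decode step (toy shapes; the repository's `KVCacheManager` pre-allocates per-layer buffers in the same spirit):
```python
import torch
import torch.nn.functional as F

head_dim, max_len = 64, 128
k_cache = torch.zeros(max_len, head_dim)
v_cache = torch.zeros(max_len, head_dim)

def decode_step(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, pos: int) -> torch.Tensor:
    # q, k, v: (head_dim,) projections of the newest token only
    k_cache[pos], v_cache[pos] = k, v            # append to the cache
    keys, vals = k_cache[:pos + 1], v_cache[:pos + 1]
    scores = keys @ q / head_dim ** 0.5          # (pos + 1,)
    return F.softmax(scores, dim=-1) @ vals      # o_t: (head_dim,)

for pos in range(4):  # toy loop: each step passes in only the newest token's projections
    o_t = decode_step(torch.randn(head_dim), torch.randn(head_dim), torch.randn(head_dim), pos)
```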

Binary file not shown (21 KiB).

Binary file not shown (11 KiB).

BIN
assets/images/structure.png Normal file

Binary file not shown (590 KiB).

101
generate.py Normal file

@@ -0,0 +1,101 @@
import os
import json
import torch
import argparse
from khaosz import Khaosz
from typing import List
from tqdm import tqdm
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
def batch_generate(
model: Khaosz,
queries: List[str],
temperature: float,
top_k: int,
top_p: float,
batch_size: int,
) -> List:
assert batch_size > 0
    # Sort by length (descending) so each batch holds similarly sized prompts,
    # and track original positions by index so duplicate queries are handled correctly.
    order = sorted(range(len(queries)), key=lambda i: len(queries[i]), reverse=True)
    sorted_queries = [queries[i] for i in order]
    responses = [None] * len(queries)
    for i in tqdm(range(0, len(sorted_queries), batch_size), desc="Generating responses"):
        batch_queries = sorted_queries[i: i + batch_size]
        batch_responses = model.batch_generate(
            queries=batch_queries,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )
        for batch_query, batch_response in zip(batch_queries, batch_responses):
            print(f"Q: {batch_query[:50]}\nR: {batch_response[:50]}")
        for offset, response in enumerate(batch_responses):
            responses[order[i + offset]] = response
    return responses
def processor(
model: Khaosz,
input_json_file: str,
output_json_file: str,
batch_size: int,
temperature: float,
top_p: float,
top_k: int,
question_key: str="question",
):
with open(input_json_file, "r", encoding='utf-8') as f:
input_dict = [json.loads(line) for line in f]
queries = [item[question_key] for item in input_dict]
    responses = batch_generate(
model=model,
queries=queries,
temperature=temperature,
top_k=top_k,
top_p=top_p,
batch_size=batch_size
)
with open(output_json_file, "w", encoding='utf-8') as f:
        json.dump(responses, f, indent=4, ensure_ascii=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
parser.add_argument("--input_json_file", type=str, required=True, help="Path to the input JSONL file.")
parser.add_argument("--output_json_file", type=str, required=True, help="Path to the output JSONL file.")
parser.add_argument("--question_key", type=str, default="question", help="Key for the question in the input JSON.")
parser.add_argument("--temperature", type=float, default=0.60, help="Temperature for generating responses.")
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p value for generating responses.")
parser.add_argument("--top_k", type=int, default=30, help="Top-k value for generating responses.")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.")
args = parser.parse_args()
model = Khaosz(args.model_dir).to(device='cuda', dtype=torch.bfloat16)
processor(
model,
input_json_file=args.input_json_file,
output_json_file=args.output_json_file,
question_key=args.question_key,
batch_size=args.batch_size,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p
)

52
khaosz/__init__.py Normal file

@@ -0,0 +1,52 @@
__version__ = "1.2.1"
__author__ = "ViperEkura"
from khaosz.model import Khaosz
from khaosz.core.transformer import Transformer, TransformerConfig
from khaosz.utils.retriever import Retriever
from khaosz.utils.splitter import (
SemanticTextSplitter,
PriorityTextSplitter
)
from khaosz.core.tokenizer import BpeTokenizer
from khaosz.core.parameter import ParameterLoader
from khaosz.core.generator import (
TextGenerator,
ChatGenerator,
StreamGenerator,
BatchGenerator,
RetrievalGenerator,
EmbeddingEncoder
)
from khaosz.trainer.trainer import Trainer
from khaosz.trainer.dataset import SeqDataset, SftDataset, DpoDataset, BaseDataset
__all__ = [
# model
"Khaosz",
# module
"Transformer",
"TransformerConfig",
"BpeTokenizer",
"ParameterLoader",
"TextGenerator",
"ChatGenerator",
"StreamGenerator",
"BatchGenerator",
"RetrievalGenerator",
"EmbeddingEncoder",
# trainer
"Trainer",
"SeqDataset",
"SftDataset",
"DpoDataset",
"BaseDataset",
# utils
"Retriever",
"SemanticTextSplitter",
"PriorityTextSplitter",
]

27
khaosz/core/__init__.py Normal file

@@ -0,0 +1,27 @@
from khaosz.core.tokenizer import BpeTokenizer
from khaosz.core.transformer import Transformer, TransformerConfig
from khaosz.core.parameter import ParameterLoader, ModelParameter, Checkpoint
from khaosz.core.generator import (
TextGenerator,
ChatGenerator,
StreamGenerator,
BatchGenerator,
RetrievalGenerator,
EmbeddingEncoder
)
__all__ = [
"Transformer",
"TransformerConfig",
"BpeTokenizer",
"ParameterLoader",
"ModelParameter",
"Checkpoint",
"TextGenerator",
"ChatGenerator",
"StreamGenerator",
"BatchGenerator",
"RetrievalGenerator",
"EmbeddingEncoder"
]

568
khaosz/core/generator.py Normal file

@@ -0,0 +1,568 @@
import torch
from torch import Tensor
from typing import List, Tuple, Union, Optional, Generator, Self
from khaosz.core.parameter import ModelParameter
def build_prompt(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
"""
Build prompt for query and history
Args:
query(str): query string
history(Optional[List[Tuple[str, str]]]): history list of query and response
Returns:
str: prompt string
"""
prompt_parts = []
if history is None:
history = []
for his_query, his_response in history:
prompt_parts.append(f"<|user|> {his_query} <|system|> <bos>{his_response}<eos>")
if query is not None:
prompt_parts.append(f"<|user|> {query} <|system|> <bos>")
return "\n".join(prompt_parts)
def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]:
"""
Pad a list of sequences to a fixed length.
Args:
ids_list (List[List[int]]): A list of sequences.
max_ids_len (int): The maximum length of sequences.
pad_id (int): The id to pad sequences.
Returns:
List[List[int]]: A list of padded sequences.
"""
new_ids_list = []
for ids in ids_list:
pad_len = max_ids_len - len(ids)
padded_seq = [pad_id] * pad_len + ids
new_ids_list.append(padded_seq)
return new_ids_list
def apply_sampling_strategies(
logits: Tensor,
temperature: float,
top_k: int,
top_p: float,
filter_value: float = -float("inf")
) -> Tensor:
"""
Apply sampling strategies to the logits tensor.
Args:
logits (Tensor): The logits tensor.
temperature (float): The temperature parameter.
top_k (int): The top-k parameter.
top_p (float): The top-p parameter.
filter_value (float, optional): The filter value. Defaults to -float("inf").
Returns:
Tensor: The sampled logits tensor.
"""
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
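        # shift the removal mask right so the first token crossing the top_p threshold is kept (at least one token always survives)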
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1,
index=sorted_indices,
src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
class KVCacheManager:
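    """Pre-allocated per-layer K/V buffers of shape (batch_size, max_len, num_heads, head_dim), plus a boolean sequence mask, for batched decoding with padding."""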
def __init__(
self,
num_layers: int,
batch_size: int,
max_len: int,
num_heads: int,
head_dim: int,
device: torch.device = "cuda",
dtype: torch.dtype = torch.bfloat16
):
self.num_layers = num_layers
self.batch_size = batch_size
self.max_len = max_len
self.num_heads = num_heads
self.head_dim = head_dim
self.device = device
self.dtype = dtype
self._kv_cache: List[Tuple[Tensor, Tensor]] = None
self._seq_mask: Tensor = None
self._initialize()
def _initialize(self):
self._kv_cache = []
for _ in range(self.num_layers):
k_cache = torch.zeros(
(self.batch_size, self.max_len, self.num_heads, self.head_dim),
device=self.device, dtype=self.dtype
)
v_cache = torch.zeros(
(self.batch_size, self.max_len, self.num_heads, self.head_dim),
device=self.device, dtype=self.dtype
)
self._kv_cache.append((k_cache, v_cache))
self._seq_mask = torch.ones(
(self.batch_size, self.max_len),
device=self.device, dtype=torch.bool
)
def update(self, active_mask: Tensor):
for i in range(self.num_layers):
k_cache, v_cache = self._kv_cache[i]
new_k_cache, new_v_cache = k_cache[active_mask], v_cache[active_mask]
self._kv_cache[i] = (new_k_cache, new_v_cache)
self._seq_mask = self._seq_mask[active_mask]
def reset(self, full_reset=False):
if full_reset:
self._kv_cache = None
self._seq_mask = None
else:
self._initialize()
def set_seq_mask(self, input_ids: Tensor, pad_id: int):
batch_size, seq_len = input_ids.shape
bool_mask = (input_ids != pad_id)
self._seq_mask[: batch_size, : seq_len] = bool_mask
def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]:
return self._kv_cache
def get_seq_mask(self) -> Tensor:
return self._seq_mask
class GeneratorCore:
def __init__(self, parameter: ModelParameter):
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
def compute_logits(
self,
input_ids: Tensor,
attn_mask: Optional[Tensor] = None,
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
start_pos: int = 0
) -> Tuple[Tensor, int]:
with torch.inference_mode():
outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
logits = outputs["logits"][:, -1, :]
cache_increase = input_ids.size(-1)
return logits, cache_increase
def to(self, *args, **kargs) -> Self:
self.model.to(*args, **kargs)
return self
class EmbeddingEncoderCore:
def __init__(self, parameter: ModelParameter):
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
with_batch = isinstance(sentence, list)
ids = self.tokenizer.encode(sentence)
batch_ids = ids if with_batch else [ids]
max_model_len = self.config.m_len
all_fragments = []
fragment_origin_idx = []
for i, seq in enumerate(batch_ids):
if len(seq) > max_model_len:
fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
all_fragments.extend(fragments)
fragment_origin_idx.extend([i] * len(fragments))
else:
all_fragments.append(seq)
fragment_origin_idx.append(i)
        # if no fragments were produced (empty input), return an empty result
if not all_fragments or not ids:
return [] if with_batch else torch.tensor([])
device = next(self.model.parameters()).device
max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
padded_ids = []
masks = []
for seq in all_fragments:
pad_len = max_len - len(seq)
padded_seq = seq + [self.tokenizer.pad_id] * pad_len
mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
padded_ids.append(padded_seq)
masks.append(mask)
input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
with torch.inference_mode():
outputs = self.model(input_tensor, seq_mask)["hidden_states"]
# [num_fragments, seq_len, hidden_size]
fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
sentence_embs: List[Tensor] = []
for i in range(len(batch_ids)):
indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
            if indices:  # an empty index list must be skipped; `is not None` was always true here
                sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1)    # (n_frags, hidden_size)
                length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1)  # (n_frags, 1)
                emb = torch.sum(sum_frags / length, dim=0)                    # (hidden_size,)
                sentence_embs.append(emb.flatten())
if with_batch:
return [emb.flatten() for emb in sentence_embs]
else:
return sentence_embs[0].flatten()
def to(self, *args, **kargs) -> Self:
self.model.to(*args, **kargs)
return self
class TextGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(
self,
query: str,
temperature: float,
top_k: int,
top_p: float,
) -> str:
assert temperature >= 0.0
assert top_k >= 0
assert top_p >= 0.0 and top_p <= 1.0
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(
num_layers=self.config.n_layer,
batch_size=1,
max_len=self.config.m_len,
num_heads=self.config.n_kvhead,
head_dim=self.config.n_dim // self.config.n_head,
device=device,
)
ids = self.tokenizer.encode(query)
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
start_cache_pos = len(ids)
cur_cache_pos = 0
self.model.eval()
while len(ids) < self.config.m_len:
kv_caches = cache_manager.get_kvcache()
logits, cache_increase = self.compute_logits(
input_ids,
kv_caches=kv_caches,
start_pos=cur_cache_pos
)
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
probs = torch.softmax(logits, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)
input_ids = next_token_id
ids.append(next_token_id.item())
cur_cache_pos += cache_increase
if next_token_id.item() in self.tokenizer.stop_ids:
break
response = self.tokenizer.decode(ids[start_cache_pos:])
return response
class ChatGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(
self,
query: str,
history: List[Tuple[str, str]],
temperature: float,
top_k: int,
top_p: float,
) -> str:
assert temperature >= 0.0
assert top_k >= 0
assert top_p >= 0.0 and top_p <= 1.0
if history is None:
history = []
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(
num_layers=self.config.n_layer,
batch_size=1,
max_len=self.config.m_len,
num_heads=self.config.n_kvhead,
head_dim=self.config.n_dim // self.config.n_head,
device=device,
)
ids = self.tokenizer.encode(build_prompt(query, history))
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
cpy_history = history.copy()
start_cache_pos = len(ids)
cur_cache_pos = 0
self.model.eval()
while len(ids) < self.config.m_len:
kv_caches = cache_manager.get_kvcache()
logits, cache_increase = self.compute_logits(
input_ids,
kv_caches=kv_caches,
start_pos=cur_cache_pos
)
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
probs = torch.softmax(logits, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)
input_ids = next_token_id
ids.append(next_token_id.item())
cur_cache_pos += cache_increase
if next_token_id.item() in self.tokenizer.stop_ids:
break
response = self.tokenizer.decode(ids[start_cache_pos:])
cpy_history.append((query, response))
return response, cpy_history
class StreamGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(
self,
query: str,
history: List[Tuple[str, str]],
temperature: float,
top_k: int,
top_p: float,
) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
assert temperature >= 0.0
assert top_k >= 0
assert top_p >= 0.0 and top_p <= 1.0
if history is None:
history = []
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(
num_layers=self.config.n_layer,
batch_size=1,
max_len=self.config.m_len,
num_heads=self.config.n_kvhead,
head_dim=self.config.n_dim // self.config.n_head,
device=device,
)
ids = self.tokenizer.encode(build_prompt(query, history))
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
cpy_history = history.copy()
start_cache_pos = len(ids)
cur_cache_pos = 0
self.model.eval()
while len(ids) < self.config.m_len:
kv_caches = cache_manager.get_kvcache()
logits, cache_increase = self.compute_logits(
input_ids,
kv_caches=kv_caches,
start_pos=cur_cache_pos
)
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
probs = torch.softmax(logits, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)
input_ids = next_token_id
ids.append(next_token_id.item())
cur_cache_pos += cache_increase
response = self.tokenizer.decode(ids[start_cache_pos:])
yield response, cpy_history + [(query, response)]
if next_token_id.item() in self.tokenizer.stop_ids:
yield response + "\n", cpy_history + [(query, response)]
break
class BatchGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(
self,
queries: List[str],
histories: List[List[Tuple[str, str]]],
temperature: float,
top_k: int,
top_p: float
) -> List[str]:
assert temperature >= 0.0
assert top_k >= 0
assert top_p >= 0.0 and top_p <= 1.0
batch_size = len(queries)
if histories is None:
histories = [[] for _ in range(batch_size)]
prompts = [build_prompt(query, history) for query, history in zip(queries, histories)]
ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
max_ids_len = max(len(ids) for ids in ids_list)
ids_list = pad_sequence(ids_list, max_ids_len, self.tokenizer.pad_id)
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(
num_layers=self.config.n_layer,
batch_size=batch_size,
max_len=self.config.m_len,
num_heads=self.config.n_kvhead,
head_dim=self.config.n_dim // self.config.n_head,
device=device,
)
input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
activate_task_mask = [True] * batch_size
start_cache_pos = max_ids_len
cur_cache_pos = 0
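        # decode loop: advance all still-active sequences one token per step; finished sequences drop out of the batch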
while max_ids_len < self.config.m_len and sum(activate_task_mask) != 0:
kv_caches = cache_manager.get_kvcache()
            attn_mask = cache_manager.get_seq_mask()
logits, cache_increase = self.compute_logits(
input_tensor,
attn_mask=attn_mask,
kv_caches=kv_caches,
start_pos=cur_cache_pos
)
cur_cache_pos += cache_increase
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
probs = torch.softmax(logits, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)
active_mask = []
c_ids = 0
for i in range(batch_size):
if activate_task_mask[i]:
token = next_token_id[c_ids, :].item()
ids_list[i].append(token)
c_ids += 1
                    is_active = token not in self.tokenizer.stop_ids
activate_task_mask[i] = is_active
active_mask.append(is_active)
active_mask = torch.tensor(active_mask, device=device, dtype=torch.bool)
cache_manager.update(active_mask)
input_tensor = next_token_id[active_mask, :]
max_ids_len += 1
responses = [str()] * batch_size
for i in range(batch_size):
responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:])
histories[i].append((queries[i], responses[i]))
return responses
class RetrievalGenerator(GeneratorCore):
def __init__(self, retriever_parameter: ModelParameter):
super().__init__(retriever_parameter)
def generate(
self,
retrieved: List[str],
query: str,
history: List[Tuple[str, str]],
temperature: float,
top_k: int,
top_p: float,
) -> str:
assert temperature >= 0.0
assert top_k >= 0
assert top_p >= 0.0 and top_p <= 1.0
if history is None:
history = []
retrieved = "\n".join([f"{idx + 1}. {key}" for idx, key in enumerate(retrieved)]) if retrieved else ""
retrieved_query = f"{retrieved}<eos>\n\n根据以上内容回答: {query}" if retrieved else query
parameter = ModelParameter(self.model, self.tokenizer, self.config)
return ChatGenerator(parameter).generate(
retrieved_query,
history,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
class EmbeddingEncoder(EmbeddingEncoderCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
return super().encode(sentence)

238
khaosz/core/parameter.py Normal file

@@ -0,0 +1,238 @@
import pickle as pkl
import matplotlib.pyplot as plt
import safetensors.torch as st
import torch.nn as nn
import torch.optim as optim
from dataclasses import dataclass, field
from typing import Optional, Self, Union
from pathlib import Path
from khaosz.core.tokenizer import BpeTokenizer
from khaosz.core.transformer import TransformerConfig, Transformer
class BaseModelIO:
"""Base class for model I/O operations."""
def __init__(
self,
model: Optional[nn.Module] = None,
tokenizer: Optional[BpeTokenizer] = None,
config: Optional[TransformerConfig] = None
):
self.model = model
self.tokenizer = tokenizer or BpeTokenizer()
self.config = config or TransformerConfig()
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
"""Get standardized file paths for model components."""
dir_path = Path(directory)
return {
"model": dir_path / "model.safetensors",
"config": dir_path / "config.json",
"tokenizer": dir_path / "tokenizer.json"
}
def save_components(self, save_dir: Union[str, Path]):
"""Save core model components."""
paths = self._get_file_paths(save_dir)
paths["model"].parent.mkdir(parents=True, exist_ok=True)
if self.model is not None:
st.save_file(self.model.state_dict(), str(paths["model"]))
self.config.save(str(paths["config"]))
self.tokenizer.save(str(paths["tokenizer"]))
def load_components(self, load_dir: Union[str, Path]) -> Self:
"""Load core model components."""
paths = self._get_file_paths(load_dir)
self.config.load(str(paths["config"]))
self.tokenizer.load(str(paths["tokenizer"]))
if paths["model"].exists():
state_dict = st.load_file(str(paths["model"]))
if self.model is None:
self.model = Transformer(self.config)
self.model.load_state_dict(state_dict)
return self
def to(self, *args, **kwargs) -> Self:
"""Move model to device."""
if self.model is not None:
self.model.to(*args, **kwargs)
return self
@dataclass
class ModelParameter(BaseModelIO):
"""Container for model parameters with serialization capabilities."""
model: Optional[nn.Module] = field(
default=None,
metadata={"help": "Transformer model."}
)
tokenizer: BpeTokenizer = field(
default_factory=BpeTokenizer,
metadata={"help": "Tokenizer for the model."}
)
config: TransformerConfig = field(
default_factory=TransformerConfig,
metadata={"help": "Transformer model configuration."}
)
def save(self, save_dir: Union[str, Path]):
"""Save model parameters."""
self.save_components(save_dir)
def load(self, load_dir: Union[str, Path]) -> Self:
"""Load model parameters."""
return self.load_components(load_dir)
@dataclass
class Checkpoint(BaseModelIO):
"""Extended model parameters with training state."""
model: Optional[nn.Module] = field(
default=None,
metadata={"help": "Transformer model."}
)
tokenizer: BpeTokenizer = field(
default_factory=BpeTokenizer,
metadata={"help": "Tokenizer for the model."}
)
config: TransformerConfig = field(
default_factory=TransformerConfig,
metadata={"help": "Transformer model configuration."}
)
loss_list: list[float] = field(
default_factory=list,
metadata={"help": "List of training losses."}
)
current_iter: int = field(
default=0,
metadata={"help": "Current training iteration."}
)
optimizer: Optional[optim.Optimizer] = field(
default=None,
metadata={"help": "Optimizer state."}
)
def __post_init__(self):
# Ensure current_iter matches loss list length if not explicitly set
if self.current_iter == 0 and self.loss_list:
self.current_iter = len(self.loss_list)
def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
"""Get file paths for training-specific files."""
paths = self._get_file_paths(directory)
paths.update({
"loss_list": paths["model"].parent / "loss.pkl",
"loss_plot": paths["model"].parent / "loss.png",
"optimizer": paths["model"].parent / "optimizer.pkl"
})
return paths
def save_training_state(self, save_dir: Union[str, Path]):
"""Save training-specific state."""
paths = self._get_training_paths(save_dir)
# Save loss plot
self._plot_loss(str(paths["loss_plot"]))
# Save loss list
with open(str(paths["loss_list"]), "wb") as f:
pkl.dump(self.loss_list, f)
# Save optimizer state
if self.optimizer is not None:
with open(str(paths["optimizer"]), "wb") as f:
pkl.dump(self.optimizer.state_dict(), f)
def load_training_state(self, load_dir: Union[str, Path]) -> Self:
"""Load training-specific state."""
paths = self._get_training_paths(load_dir)
# Load loss list
if paths["loss_list"].exists():
with open(str(paths["loss_list"]), "rb") as f:
self.loss_list = pkl.load(f)
self.current_iter = len(self.loss_list)
# Load optimizer state
if paths["optimizer"].exists() and self.optimizer is not None:
with open(str(paths["optimizer"]), "rb") as f:
optim_state = pkl.load(f)
self.optimizer.load_state_dict(optim_state)
return self
def _plot_loss(self, save_path: str):
"""Plot and save loss curve."""
if not self.loss_list:
return
plt.figure(figsize=(10, 6))
plt.plot(self.loss_list)
plt.title(f"Training Loss - Iteration {self.current_iter}")
plt.xlabel("Batch")
plt.ylabel("Loss")
plt.grid(True)
plt.savefig(save_path, dpi=300, bbox_inches="tight")
plt.close()
def save(self, save_dir: Union[str, Path]):
"""Save complete checkpoint."""
self.save_components(save_dir)
self.save_training_state(save_dir)
def load(self, load_dir: Union[str, Path]) -> Self:
"""Load complete checkpoint."""
self.load_components(load_dir)
self.load_training_state(load_dir)
return self
class ParameterLoader:
"""Factory class for loading model parameters or checkpoints."""
@staticmethod
def load(load_dir: Union[str, Path]) -> Union[ModelParameter, Checkpoint]:
"""Load either ModelParameter or Checkpoint based on directory contents."""
load_dir = Path(load_dir)
# Check for training-specific files
loss_file = load_dir / "loss.pkl"
has_training_data = loss_file.exists()
# Create appropriate instance
if has_training_data:
checkpoint = Checkpoint()
checkpoint.load(str(load_dir))
return checkpoint
else:
params = ModelParameter()
params.load(str(load_dir))
return params
@staticmethod
def create_checkpoint(
model: nn.Module,
tokenizer: BpeTokenizer,
config: TransformerConfig,
loss_list: Optional[list[float]] = None,
optimizer: Optional[optim.Optimizer] = None
) -> Checkpoint:
"""Convenience method to create a training checkpoint."""
return Checkpoint(
model=model,
tokenizer=tokenizer,
config=config,
loss_list=loss_list or [],
optimizer=optimizer
)

111
khaosz/core/tokenizer.py Normal file

@@ -0,0 +1,111 @@
from tokenizers import Tokenizer, Encoding
from tokenizers import decoders, processors, normalizers, pre_tokenizers
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from typing import List, Union
class BpeTokenizer:
def __init__(self, path=None):
self._control_tokens = ["<bos>", "<eos>", "<pad>"]
self._special_tokens = ["<|user|>", "<|system|>"]
model = BPE()
tokenizer = Tokenizer(model)
tokenizer.normalizer = normalizers.Sequence([
normalizers.NFC()
])
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
pre_tokenizers.Punctuation(behavior="isolated"),
pre_tokenizers.Metaspace(prepend_scheme="never"),
pre_tokenizers.Split(pattern=r"(\d+|[a-zA-Z]+|(?:'s|'t|'re|'ve|'m|'ll|'d))", behavior="isolated"),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
])
tokenizer.decoder = decoders.Sequence([
decoders.ByteLevel(),
decoders.Metaspace(prepend_scheme="never")
])
tokenizer.post_processor = processors.Sequence([
processors.ByteLevel(trim_offsets=False)
])
self._tokenizer = tokenizer
if path is not None:
self._tokenizer = Tokenizer.from_file(path)
def _prepare_trainer(self, vocab_size: int, min_freq: int, reserved_token_size: int) -> tuple:
assert reserved_token_size > len(self._special_tokens)
reserved_tokens = [f"<|rsv{i:02d}|>" for i in range(reserved_token_size - len(self._special_tokens))]
detail_vocab_size = vocab_size - (len(reserved_tokens) + len(self._special_tokens))
alphabet = pre_tokenizers.ByteLevel.alphabet()
min_size = len(alphabet) + len(self._control_tokens)
assert detail_vocab_size > min_size
trainer = BpeTrainer(
vocab_size=detail_vocab_size,
min_frequency=min_freq,
limit_alphabet=detail_vocab_size // 4,
max_token_length=18,
special_tokens=self._control_tokens,
show_progress=True,
initial_alphabet=alphabet,
)
return trainer, detail_vocab_size, reserved_tokens
def train(self, files, vocab_size, min_freq, reserved_token_size=100):
trainer, _, reserved_tokens = self._prepare_trainer(
vocab_size=vocab_size,
min_freq=min_freq,
reserved_token_size=reserved_token_size
)
self._tokenizer.train(files=files, trainer=trainer)
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
def train_from_iterator(self, iterator, vocab_size, min_freq, reserved_token_size=100):
trainer, _, reserved_tokens = self._prepare_trainer(
vocab_size=vocab_size,
min_freq=min_freq,
reserved_token_size=reserved_token_size
)
self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
def save(self, path):
self._tokenizer.save(path)
def load(self, path):
self._tokenizer = Tokenizer.from_file(path)
    def encode(self, tokens: Union[str, List[str]], out_ids: bool=True, add_special_tokens: bool=False) -> List:
        if isinstance(tokens, str):
            encoded: Encoding = self._tokenizer.encode(tokens, add_special_tokens=add_special_tokens)
            return encoded.ids if out_ids else encoded.tokens
        elif isinstance(tokens, list):
            encoded_list: List[Encoding] = self._tokenizer.encode_batch(tokens, add_special_tokens=add_special_tokens)
            return [encoded.ids if out_ids else encoded.tokens for encoded in encoded_list]
        raise TypeError(f"tokens must be str or List[str], got {type(tokens).__name__}")
def decode(self, tokens: List[int], skip_special_tokens: bool=True) -> str:
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def __len__(self) -> int:
return self._tokenizer.get_vocab_size()
@property
def stop_ids(self) -> List[int]:
stop_ids = []
for token in self._control_tokens:
stop_ids.append(self._tokenizer.token_to_id(token))
return stop_ids
@property
def bos_id(self) -> int:
return self._tokenizer.token_to_id("<bos>")
@property
def eos_id(self) -> int:
return self._tokenizer.token_to_id("<eos>")
@property
def pad_id(self) -> int:
return self._tokenizer.token_to_id("<pad>")

341
khaosz/core/transformer.py Normal file

@@ -0,0 +1,341 @@
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import init
from dataclasses import asdict, dataclass
from typing import List, Optional, Self, Tuple
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
"""
    Repeat each KV head n_rep times along the head dimension so grouped-query attention matches the query head count.
Args:
x (Tensor): The input tensor.
n_rep (int): The number of repetitions.
Returns:
Tensor: The repeated tensor.
"""
bs, slen, n_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_heads, n_rep, head_dim)
.reshape(bs, slen, n_heads * n_rep, head_dim)
)
def get_rotary_emb(
dim: int,
max_len: int,
base: float = 10000,
device: torch.device = "cuda",
) -> torch.Tensor:
"""
Get the rotary embedding for the given dimension and maximum length.
Args:
dim (int): The dimension of the input.
max_len (int): The maximum length of the input.
base (float, optional): The base for the frequency. Defaults to 10000.
device (torch.device, optional): The device to use. Defaults to "cuda".
Returns:
Tensor: The rotary embedding tensor.
"""
theta = base ** (-torch.arange(0, dim, 2, device=device).float() / dim)
t = torch.arange(0, max_len, device=device).float()
freqs = torch.outer(t, theta)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
"""
Apply rotary embedding to the input tensor.
Args:
x (Tensor): The input tensor.
freqs_cis (Tensor): The rotary embedding tensor.
Returns:
Tensor: The output tensor.
"""
dtype = x.dtype
seq_len = x.size(1)
x_complex = torch.view_as_complex(x.view(*x.shape[:-1], -1, 2).float())
freqs_cis = freqs_cis.reshape(1, seq_len, 1, -1)
x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3)
return x_out.to(dtype)
def create_attention_mask(
seq_mask: Tensor,
start_pos: int = 0,
seq_len: int = 0,
is_causal: bool = False,
device: torch.device = "cuda",
dtype: torch.dtype = torch.float32
) -> Tensor:
"""
Create attention mask for GQA
Args:
seq_mask (Tensor): A tensor indicating whether each position is valid or not.
start_pos (int): The starting position of the sequence.
seq_len (int): The length of the sequence.
is_causal (bool): Whether the attention is causal or not.
device (torch.device): The device to use.
Returns:
Tensor: The attention mask tensor.
"""
if start_pos != 0 and seq_mask is None:
# for single prompt chat
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
if seq_mask is None:
return None
batch_size = seq_mask.size(0)
seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
# (bsz, start_pos + seq_len)
expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len)
# (bsz, seq_len, start_pos + seq_len)
if is_causal:
causal_mask = torch.tril(
torch.ones((seq_len, start_pos + seq_len), dtype=torch.bool, device=device),
diagonal=start_pos
)
causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, start_pos + seq_len)
expanded_mask = expanded_mask & causal_mask
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
# (bsz, 1, seq_len, seq_len + start_pos)
return attention_mask
@dataclass
class TransformerConfig:
# basic config
vocab_size: Optional[int] = None
n_dim: Optional[int] = None
n_head: Optional[int] = None
n_layer: Optional[int] = None
m_len: Optional[int] = None
norm_eps: Optional[float] = None
d_ffn: Optional[int] = None
# GQA
n_kvhead: Optional[int] = None
def load(self, config_path: str) -> Self:
with open(config_path, 'r') as f:
config: dict = json.load(f)
for key, value in config.items():
if hasattr(self, key):
setattr(self, key, value)
return self
def save(self, config_path: str) -> None:
config_dict = asdict(self)
config_dict = {k: v for k, v in config_dict.items() if v is not None}
with open(config_path, 'w') as f:
json.dump(config_dict, f, indent=4)
class Linear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, bias: bool=False):
super().__init__()
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
init.normal_(self.weight, mean=0, std=0.006)
def forward(self, x: Tensor) -> Tensor:
return F.linear(x, self.weight, self.bias)
class RMSNorm(nn.Module):
def __init__(self, n_dim, norm_eps):
super().__init__()
self.weight = nn.Parameter(torch.ones(n_dim))
self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor:
dtype = x.dtype
x = x.float()
mean_square = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
norm = x * torch.rsqrt(mean_square + self.norm_eps)
norm = norm.to(dtype)
out = norm * self.weight
return out
class MLP(nn.Module):
def __init__(self, n_dim: int, d_ffn: int):
super().__init__()
self.up = Linear(n_dim, d_ffn)
self.gate = Linear(n_dim, d_ffn)
self.down = Linear(d_ffn, n_dim)
def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x))
out = self.down(gated)
return out
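# This is the SwiGLU-style gated feed-forward used in LLaMA-family models:
#   MLP(x) = W_down( W_up(x) * SiLU(W_gate(x)) )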
class GQA(nn.Module):
def __init__(
self,
n_dim: int,
n_head: int,
n_kvhead: int,
):
super().__init__()
assert n_dim % n_head == 0
assert n_head % n_kvhead == 0
self.head_dim = n_dim // n_head
self.n_dim = n_dim
self.n_heads = n_head
self.n_kvheads = n_kvhead
self.n_rep = n_head // n_kvhead
self.q_proj = Linear(n_dim, n_head * self.head_dim)
self.k_proj = Linear(n_dim, n_kvhead * self.head_dim)
self.v_proj = Linear(n_dim, n_kvhead * self.head_dim)
self.o_proj = Linear(n_dim, n_dim)
def forward(
self,
x: Tensor,
freqs_cis: Tensor,
mask: Optional[Tensor] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0
) -> Tensor:
bsz, seq_len, _ = x.size()
# x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads)
k = self._split_heads(self.k_proj(x), self.n_kvheads)
v = self._split_heads(self.v_proj(x), self.n_kvheads)
q, k = apply_rotary_emb(q, freqs_cis), apply_rotary_emb(k, freqs_cis)
if kv_cache is not None:
k_cache, v_cache = kv_cache
# copy to cache
k_cache[:bsz, start_pos:start_pos + seq_len] = k
v_cache[:bsz, start_pos:start_pos + seq_len] = v
# get cache
k = k_cache[:bsz, :start_pos + seq_len]
v = v_cache[:bsz, :start_pos + seq_len]
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
sdpa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=(mask is None)).permute(0, 2, 1, 3)
out = self.o_proj(sdpa_out.contiguous().view(bsz, seq_len, -1))
return out
def _split_heads(self, x: Tensor, n_heads) -> Tensor:
batch_size, seq_len, _ = x.shape
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
return x
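# Usage sketch (illustrative only): a prefill pass without a KV cache; with
# mask=None the attention falls back to SDPA's built-in causal masking.
#   gqa = GQA(n_dim=128, n_head=4, n_kvhead=2)   # head_dim = 128 // 4 = 32
#   x = torch.randn(2, 8, 128)                   # (bsz, seq_len, n_dim)
#   freqs_cis = get_rotary_emb(32, 64)[:8]
#   y = gqa(x, freqs_cis)                        # (2, 8, 128)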
class DecoderBlock(nn.Module):
def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps):
super().__init__()
self.attention = GQA(n_dim, n_head, n_kvhead)
self.norm_attn = RMSNorm(n_dim, norm_eps)
self.ffn = MLP(n_dim, d_ffn)
self.norm_ffn = RMSNorm(n_dim, norm_eps)
def forward(
self,
x: Tensor,
freqs_cis: Tensor,
attention_mask: Optional[Tensor] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0
) -> Tensor:
# attention
attn_output = self.attention(
self.norm_attn(x),
freqs_cis,
attention_mask,
kv_cache,
start_pos
)
x = attn_output + x
# feed forward
x = self.ffn(self.norm_ffn(x)) + x
return x
class Transformer(nn.Module):
def __init__(self, config: TransformerConfig):
super().__init__()
self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
self.layers = nn.ModuleList([
DecoderBlock(
config.n_dim,
config.n_head,
config.d_ffn,
config.n_kvhead,
config.norm_eps
)
for _ in range(config.n_layer)
])
self.norm = RMSNorm(config.n_dim, config.norm_eps)
self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
init.normal_(self.embedding, mean=0, std=0.02)
def forward(
self,
input_ids: Tensor,
seq_mask: Optional[Tensor]=None,
persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
start_pos: int = 0
) -> Tensor:
assert input_ids.ndim == 2
seq_len = input_ids.size(-1)
x = F.embedding(input_ids, self.embedding)
self.freq_cis = self.freq_cis.to(x.device)
freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
has_kvcache = persistent_key_values is not None
attn_mask = create_attention_mask(
seq_mask,
start_pos=start_pos,
seq_len=seq_len,
is_causal=has_kvcache,
device=x.device,
dtype=x.dtype
)
for i, layer in enumerate(self.layers):
kv_cache = persistent_key_values[i] if persistent_key_values else None
x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos)
hidden_states = self.norm(x)
logits = F.linear(hidden_states, self.embedding)
return {
"logits": logits,
"hidden_states": hidden_states
}
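# Usage sketch (illustrative only; config values are placeholders):
#   config = TransformerConfig(vocab_size=1000, n_dim=128, n_head=4, n_kvhead=2,
#                              n_layer=2, d_ffn=256, m_len=64, norm_eps=1e-5)
#   model = Transformer(config)
#   out = model(torch.randint(0, config.vocab_size, (4, 16)))
#   out["logits"].shape  # (4, 16, 1000); the output head is tied to the embedding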

112
khaosz/model.py Normal file
View File

@ -0,0 +1,112 @@
from torch import Tensor
from typing import List, Optional, Tuple, Generator, Union
from khaosz.core.generator import (
TextGenerator,
ChatGenerator,
StreamGenerator,
BatchGenerator,
RetrievalGenerator,
EmbeddingEncoder
)
from khaosz.core.parameter import ParameterLoader
class Khaosz:
def __init__(self, model_dir: str):
self.parameter = ParameterLoader.load(model_dir)
def to(self, *args, **kwargs):
self.parameter.to(*args, **kwargs)
return self
def generate(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> str:
generator = ChatGenerator(self.parameter)
return generator.generate(
query,
history=history,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
def batch_generate(
self,
queries: List[str],
histories: Optional[List[Tuple[str, str]]] = None,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> List[str]:
generator = BatchGenerator(self.parameter)
return generator.generate(
queries,
histories=histories,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
def stream_generate(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
stream_generator = StreamGenerator(self.parameter)
return stream_generator.generate(
query,
history=history,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
def retrieve_generate(
self,
retrieved: List[Tuple[str, float]],
query: str,
history: Optional[List[Tuple[str, str]]] = None,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> str:
generator = RetrievalGenerator(self.parameter)
return generator.generate(
retrieved,
query,
history=history,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
def text_generate(
self,
query: str,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> str:
generator = TextGenerator(self.parameter)
return generator.generate(
query,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
encoder = EmbeddingEncoder(self.parameter)
return encoder.encode(sentence)
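# Usage sketch (illustrative only; "params" stands for a directory holding the
# model.safetensors, tokenizer.json and config.json produced by ModelParameter.save):
#   model = Khaosz("params").to(device="cuda", dtype=torch.bfloat16)
#   reply = model.generate("你好", history=[], temperature=0.8, top_k=50, top_p=0.95)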

11
khaosz/trainer/__init__.py Normal file
View File

@ -0,0 +1,11 @@
from khaosz.trainer.dataset import DatasetLoader
from khaosz.trainer.trainer import Trainer
from khaosz.trainer.strategy import TrainConfig, CosineScheduleConfig, SgdrScheduleConfig
__all__ = [
"DatasetLoader",
"Trainer",
"TrainConfig",
"CosineScheduleConfig",
"SgdrScheduleConfig",
]

210
khaosz/trainer/dataset.py Normal file
View File

@ -0,0 +1,210 @@
import torch
import bisect
import pickle as pkl
from abc import ABC, abstractmethod
from torch import Tensor
from torch.utils.data import Dataset
from typing import Callable, List, Dict, Literal, Tuple, Union
MultiSeg = Dict[str, List[Tensor]]
Seg = Dict[str, Tensor]
def load_pkl_files(paths: List[str]) -> Tuple[MultiSeg, int]:
segments: MultiSeg = {}
total_samples = 0
for path in paths:
with open(path, "rb") as f:
pkl_file: Seg = pkl.load(f)
for key, value in pkl_file.items():
if key not in segments:
segments[key] = []
segments[key].append(value)
first_key = list(pkl_file.keys())[0]
total_samples += pkl_file[first_key].numel()
return segments, total_samples
class BaseSegmentFetcher:
def __init__(self, segments: List[Tensor]):
self.segments = segments
self.cum_lengths = []
total = 0
for seg in segments:
total += len(seg)
self.cum_lengths.append(total)
self.total_length = total if segments else 0
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
# locate the segments overlapping [begin_idx, end_idx); cum_lengths holds cumulative segment end offsets
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_right(self.cum_lengths, end_idx - 1)
result_segments = []
for i in range(seg_start_idx, seg_end_idx + 1):
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
start = max(begin_idx - prev_cum, 0)
end = min(end_idx - prev_cum, len(self.segments[i]))
result_segments.append(self.segments[i][start:end])
return torch.cat(result_segments, dim=0)
class MultiSegmentFetcher:
def __init__(self, multi_segments: MultiSeg):
self.multi_keys = list(multi_segments.keys())
self.multi_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in multi_segments.items()
}
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Union[Tensor, Seg]:
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.multi_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Seg]:
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
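# Usage sketch (illustrative only): a fetch window that crosses a segment
# boundary is stitched together transparently from the covered slices.
#   fetcher = BaseSegmentFetcher([torch.arange(3), torch.arange(3, 6)])
#   fetcher.fetch_data(2, 4)  # tensor([2, 3])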
class BaseDataset(Dataset, ABC):
def __init__(self, chunk_size: int, device: str):
super().__init__()
self.segments: MultiSeg = {}
self.chunk_size = chunk_size
self.total_samples = 0
self.device = device
def save(self, save_path: str):
keys = list(self.segments.keys())
first_item = self.segments[keys[0]]
segment_size = len(first_item)
for i in range(segment_size):
formatted_segment = {key: self.segments[key][i] for key in keys}
with open(f"{save_path}_{i}.pkl", "wb") as f:
pkl.dump(formatted_segment, f)
def load(self, load_path: Union[str, List[str]]):
paths = [load_path] if isinstance(load_path, str) else load_path
self.segments, self.total_samples = load_pkl_files(paths)
self.fetcher = MultiSegmentFetcher(self.segments)
@abstractmethod
def __getitem__(self, index: int):
raise NotImplementedError
def __len__(self) -> int:
assert self.total_samples // self.chunk_size > 0
return self.total_samples // self.chunk_size
class SeqDataset(BaseDataset):
def __init__(self, chunk_size, device='cuda'):
super().__init__(chunk_size, device)
self.fetcher = MultiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
def __getitem__(self, index):
begin_idx = index * self.chunk_size
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
x = self._fetch_data(begin_idx, end_idx).to(device=self.device, dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(device=self.device, dtype=torch.long)
return x, y
class SftDataset(BaseDataset):
def __init__(self, chunk_size, device='cuda'):
super().__init__(chunk_size, device)
self.fetcher = MultiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index):
begin_idx = index * self.chunk_size
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long)
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "mask").to(device=self.device, dtype=torch.bool)
return x, y, loss_mask
class DpoDataset(BaseDataset):
def __init__(self, chunk_size: int, device="cuda"):
super().__init__(chunk_size, device)
self.fetcher = MultiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int):
start_idx = index * self.chunk_size
end_idx = min(start_idx + self.chunk_size, self.total_samples - 1)
chosen = self._fetch_data(start_idx, end_idx, "chosen").to(device=self.device, dtype=torch.long)
rejected = self._fetch_data(start_idx, end_idx, "rejected").to(device=self.device, dtype=torch.long)
chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(device=self.device, dtype=torch.bool)
rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(device=self.device, dtype=torch.bool)
return chosen, rejected, chosen_mask, rejected_mask
class PpoDataset(BaseDataset):
def __init__(self, chunk_size: int, device="cuda"):
super().__init__(chunk_size, device)
self.fetcher = MultiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
begin_idx = index * self.chunk_size
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
# no trailing commas here: they would silently wrap each tensor in a 1-tuple
input_ids = self._fetch_data(begin_idx, end_idx, "input_ids").to(self.device)
actions = self._fetch_data(begin_idx, end_idx, "actions").to(self.device)
logprobs = self._fetch_data(begin_idx, end_idx, "logprobs").to(self.device)
rewards = self._fetch_data(begin_idx, end_idx, "rewards").to(self.device)
return input_ids, actions, logprobs, rewards
class DatasetLoader:
@staticmethod
def load(
train_type: Literal["seq", "sft", "dpo"],
load_path: Union[str, List[str]],
max_len: int,
device: str
) -> BaseDataset:
dataset_router: Dict[str, Callable[[int, torch.device], BaseDataset]] = {
"seq": lambda m_len, device: SeqDataset(m_len, device=device),
"sft": lambda m_len, device: SftDataset(m_len, device=device),
"dpo": lambda m_len, device: DpoDataset(m_len, device=device),
}
dataset = dataset_router[train_type](max_len, device)
dataset.load(load_path)
return dataset
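# Usage sketch (illustrative only; "data.pkl" is a placeholder path holding a
# dict such as {"sequence": LongTensor}):
#   dataset = DatasetLoader.load(train_type="seq", load_path="data.pkl",
#                                max_len=64, device="cpu")
#   x, y = dataset[0]  # y holds the next-token targets for x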

55
khaosz/trainer/mask.py Normal file
View File

@ -0,0 +1,55 @@
import torch
from abc import abstractmethod
from torch import Tensor
class MaskBuilder:
def __init__(
self,
bos_token_id: int,
eos_token_id: int,
user_token_id: int,
system_token_id: int,
):
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.user_token_id = user_token_id
self.system_token_id = system_token_id
@abstractmethod
def build(self, input_ids: Tensor) -> Tensor:
raise NotImplementedError
class LossMaskBuilder(MaskBuilder):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def build(self, input_ids: Tensor) -> Tensor:
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
is_user_token = input_ids.eq(self.user_token_id)
is_system_token = input_ids.eq(self.system_token_id)
token_markers[is_user_token] = 1
token_markers[is_system_token] = -1
cumulative_markers = torch.cumsum(token_markers, dim=-1)
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
loss_mask = cumulative_markers - min_cumulative
return loss_mask
class AttentionMaskBuilder(MaskBuilder):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def build(self, input_ids: Tensor):
bsz = input_ids.size(0)

388
khaosz/trainer/strategy.py Normal file
View File

@ -0,0 +1,388 @@
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import Dataset
from typing import Any, Literal, Tuple, Callable, Dict
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
def get_logprobs(model: nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int):
input_mask = input_ids.ne(pad_token_id)
logits = model(input_ids, input_mask)["logits"]
log_probs = torch.log_softmax(logits, dim=-1)
shifted_log_probs = log_probs[:, :-1, :]
shifted_input_ids = input_ids[:, 1:]
shifted_response_mask = mask[:, 1:]
token_logprobs = torch.gather(
shifted_log_probs,
dim=-1,
index=shifted_input_ids.unsqueeze(-1)
).squeeze(-1)
prompt_mask = input_mask[:, 1:]
valid_mask = (prompt_mask & shifted_response_mask).float()
return (token_logprobs * valid_mask).sum(dim=-1)
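# Usage sketch (illustrative only; model, input_ids and response_mask are
# hypothetical names -- response_mask selects the tokens whose log-probabilities
# are summed, e.g. a mask produced by LossMaskBuilder):
#   logp = get_logprobs(model, input_ids, response_mask, pad_token_id=0)
#   logp.shape  # (bsz,): summed log-probability of each masked response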
class MaskBuilder:
def __init__(
self,
bos_token_id: int,
eos_token_id: int,
user_token_id: int,
system_token_id: int,
):
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.user_token_id = user_token_id
self.system_token_id = system_token_id
@abstractmethod
def build(self, input_ids: Tensor) -> Tensor:
raise NotImplementedError
class LossMaskBuilder(MaskBuilder):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def build(self, input_ids: Tensor) -> Tensor:
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
is_user_token = input_ids.eq(self.user_token_id)
is_system_token = input_ids.eq(self.system_token_id)
token_markers[is_user_token] = 1
token_markers[is_system_token] = -1
cumulative_markers = torch.cumsum(token_markers, dim=-1)
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
loss_mask = cumulative_markers - min_cumulative
return loss_mask.to(dtype=torch.bool)
class AttentionMaskBuilder(MaskBuilder):
def __init__(self, multi_turn=False, **kwargs):
super().__init__(**kwargs)
self.multi_turn = multi_turn
def build(self, input_ids: Tensor):
bsz = input_ids.size(0)
def _build_batch(self, input_ids: Tensor):
is_user_token = input_ids.eq(self.user_token_id)
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
token_markers[is_user_token] = 1
cumulative_markers = torch.cumsum(token_markers, dim=-1)
class BaseStrategy(ABC):
def __init__(self, model: nn.Module):
self.model = model
@abstractmethod
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
raise NotImplementedError
def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor:
return self.compute_loss(batch)
class SeqStrategy(BaseStrategy):
def __init__(self, model):
super().__init__(model)
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
x, y = batch
B, L = x.size()
logits: Tensor = self.model(x)["logits"]
loss = F.cross_entropy(
logits.view(B * L, -1), y.flatten()
)
return loss
class SftStrategy(BaseStrategy):
def __init__(self, model):
super().__init__(model)
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
x, y, loss_mask = batch
B, L = x.size()
ignore_idx = -1
logits: Tensor = self.model(x)["logits"]
masked_y = y.masked_fill(loss_mask == 0, ignore_idx)
loss = F.cross_entropy(
logits.view(B * L, -1),
masked_y.flatten(),
ignore_index=ignore_idx
)
return loss
class DpoStrategy(BaseStrategy):
def __init__(self, model, pad_token_id, beta):
super().__init__(model)
ref_model = copy.deepcopy(self.model)
ref_model.requires_grad_(False)
ref_model.eval()
self.ref_model = ref_model
self.pad_token_id = pad_token_id
self.beta = beta
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
good_ids, bad_ids, good_mask, bad_mask = batch
log_pi_good = get_logprobs(self.model, good_ids, good_mask, self.pad_token_id)
log_pi_bad = get_logprobs(self.model, bad_ids, bad_mask, self.pad_token_id)
with torch.no_grad():
log_ref_good = get_logprobs(self.ref_model, good_ids, good_mask, self.pad_token_id)
log_ref_bad = get_logprobs(self.ref_model, bad_ids, bad_mask, self.pad_token_id)
pi_log_ratio = log_pi_good - log_pi_bad
ref_log_ratio = log_ref_good - log_ref_bad
ratio_diff = pi_log_ratio - ref_log_ratio
dpo_loss = -F.logsigmoid(self.beta * ratio_diff).mean()
return dpo_loss
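# The computation above is the standard DPO objective,
#   L = -E[ log sigmoid( beta * ((log pi(y_w) - log pi(y_l))
#                              - (log pi_ref(y_w) - log pi_ref(y_l))) ) ],
# where y_w / y_l are the chosen / rejected responses and pi_ref is the frozen
# reference copy of the policy.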
class PpoStrategy(BaseStrategy):
def __init__(self, model, pad_token_id, epsilon):
super().__init__(model)
ref_model = copy.deepcopy(self.model)
ref_model.requires_grad_(False)
ref_model.eval()
self.ref_model = ref_model
self.pad_token_id = pad_token_id
self.epsilon = epsilon
def ppo_clip_loss_masked(
self,
log_probs: Tensor,
old_log_probs: Tensor,
advantages: Tensor,
values: Tensor,
returns: Tensor,
mask: Tensor,
clip_eps: float=0.2,
):
ratio = torch.exp(log_probs - old_log_probs)
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
policy_loss = -torch.min(surr1, surr2).masked_select(mask).mean()
value_loss = F.mse_loss(values.masked_select(mask),
returns.masked_select(mask))
entropy = -(log_probs.exp() * log_probs).masked_select(mask).mean()
entropy_loss = -entropy
return policy_loss, value_loss, entropy_loss
class StrategyFactory:
@staticmethod
def load(model: nn.Module, train_type: str, pad_token_id: int, dpo_beta: float) -> BaseStrategy:
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
"seq": lambda: SeqStrategy(model),
"sft": lambda: SftStrategy(model),
"dpo": lambda: DpoStrategy(model, pad_token_id, dpo_beta)
}
strategy = train_strategy[train_type]()
return strategy
@dataclass
class TrainConfig:
train_type: str = field(
default="seq",
metadata={"help": "Type of training: one of 'seq', 'sft', or 'dpo'."}
)
dataset: Dataset = field(
default=None,
metadata={"help": "Dataset for training."}
)
optimizer: Optimizer = field(
default=None,
metadata={"help": "Optimizer for training."}
)
ckpt_dir: str = field(
default="./checkpoint",
metadata={"help": "Checkpoint directory."}
)
n_epoch: int = field(
default=1,
metadata={"help": "Number of epochs for training."}
)
batch_size: int = field(
default=4,
metadata={"help": "Batch size for training."}
)
n_iter_ckpt: int = field(
default=5000,
metadata={"help": "Number of iterations between checkpoints."}
)
n_iter_step: int = field(
default=1,
metadata={"help": "Number of iterations between steps."}
)
max_grad_norm: float = field(
default=1.0,
metadata={"help": "Maximum gradient norm."}
)
random_seed: int = field(
default=3407,
metadata={"help": "Random seed."}
)
dpo_beta: float = field(
default=0.1,
metadata={"help": "DPO beta."}
)
def get_kwargs(self) -> Dict[str, Any]:
config_dict = asdict(self)
return {k: v for k, v in config_dict.items() if v is not None}
@dataclass
class ScheduleConfig:
schedule_type: str = field(
default="cosine",
metadata={"help": "Type of learning rate schedule: 'cosine' or 'sgdr'."}
)
warning_step: int = field(
default=1000,
metadata={"help": "Number of warm-up steps before decay begins."}
)
@abstractmethod
def get_kwargs(self) -> Dict[str, Any]:
raise NotImplementedError
@dataclass
class CosineScheduleConfig(ScheduleConfig):
total_iters: int = field(
default=None,
metadata={"help": "Total iterations for cosine schedule."}
)
min_rate: float = field(
default=0.05,
metadata={"help": "Minimum rate for cosine schedule."}
)
schedule_type: Literal["cosine"] = "cosine"
def get_kwargs(self) -> Dict[str, Any]:
return {
"schedule_type": self.schedule_type,
"warning_step": self.warning_step,
"lr_decay_iters": self.total_iters - self.warning_step,
"min_rate": self.min_rate
}
@dataclass
class SgdrScheduleConfig(ScheduleConfig):
cycle_length: int = field(
default=1000,
metadata={"help": "Cycle length for sgdr schedule."}
)
min_rate: float = field(
default=0.05,
metadata={"help": "Minimum rate for sgdr schedule."}
)
T_mult: int = field(
default=2,
metadata={"help": "T_mult for sgdr schedule."}
)
schedule_type: Literal["sgdr"] = "sgdr"
def get_kwargs(self) -> Dict[str, Any]:
return {
"schedule_type": self.schedule_type,
"warning_step": self.warning_step,
"cycle_length": self.cycle_length,
"min_rate": self.min_rate,
"T_mult": self.T_mult
}
class SchedulerFactory:
@staticmethod
def get_sgdr_schedule(
warning_step: int,
cycle_length: int,
min_rate: float = 0.1,
T_mult: int = 2
) -> Callable[[int], float]:
def sgdr_schedule(now_iter: int) -> float:
if now_iter < warning_step:
return max(min_rate, now_iter / warning_step)
adjusted_iter = now_iter - warning_step
# walk through the geometrically growing restart cycles until the one
# containing adjusted_iter is found
cycle_start, current_cycle = 0, 0
while adjusted_iter >= cycle_start + cycle_length * (T_mult ** current_cycle):
cycle_start += cycle_length * (T_mult ** current_cycle)
current_cycle += 1
cycle_pos = adjusted_iter - cycle_start
cycle_length_current = cycle_length * (T_mult ** current_cycle)
return max(min_rate, 0.5 * (1 + math.cos(math.pi * cycle_pos / cycle_length_current)))
return sgdr_schedule
@staticmethod
def get_cosine_warmup_schedule(
warning_step: int,
lr_decay_iters: int,
min_rate: float = 0.1
) -> Callable[[int], float]:
def cosine_warmup_schedule(now_iter: int) -> float:
if now_iter <= warning_step:
return max(min_rate, now_iter / warning_step)
else:
rate = (now_iter - warning_step) / (lr_decay_iters - warning_step)
return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * rate)))
return cosine_warmup_schedule
@staticmethod
def load_schedule_fn(**kwargs):
strategy = kwargs.pop("schedule_type")
if strategy == "cosine":
return SchedulerFactory.get_cosine_warmup_schedule(**kwargs)
elif strategy == "sgdr":
return SchedulerFactory.get_sgdr_schedule(**kwargs)
else:
raise ValueError(f"Invalid schedule type: {strategy}")

167
khaosz/trainer/trainer.py Normal file
View File

@ -0,0 +1,167 @@
import os
import torch
import logging
from typing import Optional, Tuple
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, RandomSampler
from tqdm import tqdm
from khaosz.core import ModelParameter, Checkpoint
from khaosz.trainer.strategy import SchedulerFactory, StrategyFactory, TrainConfig, ScheduleConfig
class Trainer:
def __init__(
self,
parameter: ModelParameter,
log_path: str="./train_log.log"
):
logger = logging.getLogger()
logger.setLevel(logging.INFO)
handler = logging.FileHandler(log_path)
handler.setLevel(logging.INFO)
handler.setFormatter(logging.Formatter('%(asctime)s: %(message)s'))
logger.addHandler(handler)
logger.info("initializing trainer ...")
self.logger = logger
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
def save_checkpoint(
self,
loss_list: list,
ckpt_dir: str,
current_iter: int,
last_ckpt_iter: int
):
save_path = os.path.join(ckpt_dir, f"iter_{current_iter}")
Checkpoint(
self.model,
self.tokenizer,
self.config,
loss_list,
current_iter
).save(save_path)
diff_iter = current_iter - last_ckpt_iter
avg_loss = sum(loss_list[last_ckpt_iter:current_iter]) / diff_iter
self.logger.info(f"iter: {current_iter} loss: {avg_loss}")
return current_iter
def load_checkpoint(self, train_checkpoint: Checkpoint) -> Tuple[list, int]:
self.model = train_checkpoint.model
self.tokenizer = train_checkpoint.tokenizer
self.config = train_checkpoint.config
loss_list = train_checkpoint.loss_list
last_ckpt_iter = train_checkpoint.current_iter
return loss_list, last_ckpt_iter
def train(
self,
train_config: TrainConfig,
schedule_config: ScheduleConfig,
train_checkpoint: Optional[Checkpoint] = None
):
assert schedule_config.schedule_type in ["cosine", "sgdr"]
assert train_config.train_type in ["seq", "sft", "dpo"]
if train_checkpoint:
loss_list, last_ckpt_iter = self.load_checkpoint(train_checkpoint)
current_iter = train_checkpoint.current_iter + 1
self.logger.info(f"Resuming training from checkpoint: iter {current_iter}")
else:
current_iter = 0
last_ckpt_iter = 0
loss_list = []
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
**schedule_config.get_kwargs()
)
strategy = StrategyFactory.load(
self.model,
train_config.train_type,
self.tokenizer.pad_id,
train_config.dpo_beta
)
scheduler = LambdaLR(
train_config.optimizer,
lambda_scheduler_fn,
last_epoch=current_iter - 1 if train_checkpoint else -1
)
seed = train_config.random_seed
generator = torch.Generator().manual_seed(seed)
sampler = RandomSampler(train_config.dataset, generator=generator)
remaining_epochs = train_config.n_epoch - current_iter // (len(train_config.dataset) // train_config.batch_size)
self.logger.info(f"Starting {train_config.train_type.upper()} training for {train_config.n_epoch} epochs")
self.logger.info(f"Checkpoint interval: {train_config.n_iter_ckpt} iterations")
for epoch in range(remaining_epochs):
self.model.train()
dataloader = DataLoader(
train_config.dataset,
batch_size=train_config.batch_size,
sampler=sampler
)
progress_bar = tqdm(
dataloader,
desc=f"Epoch {epoch+1}/{train_config.n_epoch}",
dynamic_ncols=True
)
for batch in progress_bar:
# forward
loss = strategy(batch)
loss_list.append(loss.item())
# backward
loss.backward()
# step
if current_iter % train_config.n_iter_step == 0:
clip_grad_norm_(
self.model.parameters(),
train_config.max_grad_norm
)
train_config.optimizer.step()
train_config.optimizer.zero_grad()
current_iter += 1
scheduler.step()
progress_bar.set_postfix({
"loss": f"{loss.item():.4f}",
"lr": f"{train_config.optimizer.param_groups[0]['lr']:.2e}"
})
# save checkpoint
if current_iter - last_ckpt_iter >= train_config.n_iter_ckpt:
last_ckpt_iter = self.save_checkpoint(
loss_list,
train_config.ckpt_dir,
current_iter,
last_ckpt_iter
)
if current_iter != last_ckpt_iter:
last_ckpt_iter = self.save_checkpoint(
loss_list,
train_config.ckpt_dir,
current_iter,
last_ckpt_iter
)
self.logger.info("Training completed")
return Checkpoint(
self.model,
self.tokenizer,
self.config,
loss_list,
current_iter,
train_config.optimizer
)

88
khaosz/utils/retriever.py Normal file
View File

@ -0,0 +1,88 @@
import torch
import sqlite3
import numpy as np
from torch import Tensor
from typing import Dict, List, Tuple
class Retriever:
def __init__(self, db_path=None):
self.data: Dict[str, Tensor] = {}
self.embedding_cache: Tensor = None
self.is_calculated: bool = False
if db_path is not None:
self.load(db_path)
def retrieve(self, query: Tensor, top_k: int) -> List[Tuple[str, float]]:
if not self.data:
return []
query = query.flatten().unsqueeze(1) # [dim, 1]
norm_embeddings = self._embeddings.to(
device=query.device,
dtype=query.dtype
) # [n_vectors, dim]
sim_scores = torch.matmul(norm_embeddings, query).squeeze() # [n_vectors]
top_k = min(top_k, len(self.data))
indices = sim_scores.topk(top_k).indices
keys = list(self.data.keys())
return [(keys[i], sim_scores[i].item()) for i in indices]
def add_vector(self, key: str, vector_data: Tensor):
self.is_calculated = False
self.data[key] = vector_data.flatten().float().cpu()
def delete_vector(self, key: str):
self.is_calculated = False
self.data.pop(key, None)
def save(self, db_path):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
self._init_db(cursor)
cursor.execute('DELETE FROM vectors')
for item, vec in self.data.items():
vec_bytes = vec.numpy().tobytes()
cursor.execute('INSERT OR REPLACE INTO vectors (key, vector) VALUES (?, ?)',
(item, vec_bytes))
conn.commit()
conn.close()
def load(self, db_path):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
self._init_db(cursor)
cursor.execute('SELECT key, vector FROM vectors')
rows = cursor.fetchall()
self.data = {}
for row in rows:
key, vec_bytes = row
vec_numpy = np.frombuffer(vec_bytes, dtype=np.float32).copy()
vec = torch.from_numpy(vec_numpy)
self.data[key] = vec
conn.close()
def _init_db(self, cursor: sqlite3.Cursor):
# Create table if not exists (in case loading from a new database)
cursor.execute('''
CREATE TABLE IF NOT EXISTS vectors (
id INTEGER PRIMARY KEY AUTOINCREMENT,
key TEXT UNIQUE NOT NULL,
vector BLOB NOT NULL
)''')
@property
def _embeddings(self) -> Tensor:
if not self.is_calculated:
embeddings = torch.stack(list(self.data.values()))
norm_embeddings = embeddings / torch.norm(embeddings, dim=-1, keepdim=True)
self.embedding_cache = norm_embeddings
self.is_calculated = True  # mark the cache valid so it is not rebuilt on every access
return self.embedding_cache
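# Usage sketch (illustrative only): stored vectors are L2-normalized before the
# dot product, so results are ranked by cosine similarity to the query.
#   retriever = Retriever()
#   retriever.add_vector("doc-1", torch.randn(128))
#   retriever.add_vector("doc-2", torch.randn(128))
#   hits = retriever.retrieve(torch.randn(128), top_k=1)  # [(key, score)]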

127
khaosz/utils/splitter.py Normal file
View File

@ -0,0 +1,127 @@
import re
import torch
import torch.nn.functional as F
from abc import ABC, abstractmethod
from torch import Tensor
from typing import List, Callable, Optional
class BaseTextSplitter(ABC):
def __init__(
self,
max_len: int = 512,
chunk_overlap: int = 0,
):
if max_len <= 0:
raise ValueError("max_len must be > 0")
if chunk_overlap < 0:
raise ValueError("chunk_overlap must be >= 0")
self.max_len = max_len
self.chunk_overlap = chunk_overlap
@abstractmethod
def split(self, text: str, **kwargs) -> List[str]:
raise NotImplementedError
def preprocess(self, text: str) -> str:
return text.strip()
def postprocess(self, chunks: List[str]) -> List[str]:
return [chunk.strip() for chunk in chunks if chunk.strip()]
class PriorityTextSplitter(BaseTextSplitter):
def __init__(
self,
separators: List[str],
max_len: int = 512,
chunk_overlap: int = 0,
):
super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
if not separators:
raise ValueError("separators must be a non-empty list")
self.separators = separators
def split(self, text: str) -> List[str]:
text = self.preprocess(text)
for sep in self.separators:
parts = text.split(sep)
valid_parts = [p.strip() for p in parts if p.strip()]
if len(valid_parts) > 1:
return self.postprocess(valid_parts)
return [text]
class SemanticTextSplitter(BaseTextSplitter):
DEFAULT_PATTERN = r'(?<=[。!?!?])(?=(?:[^"\'‘’“”]*["\'‘’“”][^"\'‘’“”]*["\'‘’“”])*[^"\'‘’“”]*$)'
def __init__(
self,
embedding_func: Callable[[List[str]], List[Tensor]],
pattern: Optional[str] = None,
max_len: int = 512,
chunk_overlap: int = 0,
):
super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
if not callable(embedding_func):
raise TypeError("embedding_func must be callable")
self.embedding_func = embedding_func
self.pattern = pattern or SemanticTextSplitter.DEFAULT_PATTERN
def split(
self,
text: str,
threshold: float = 0.5,
window_size: int = 1,
) -> List[str]:
text = self.preprocess(text)
sentences = [s.strip() for s in re.split(self.pattern, text) if s.strip()]
if len(sentences) <= 1:
return self.postprocess(sentences)
try:
sentence_embs = self.embedding_func(sentences)
except Exception as e:
raise RuntimeError(f"Embedding generation failed: {e}")
if len(sentence_embs) != len(sentences):
raise ValueError("Embedding function must return one vector per sentence")
chunks = []
emb_tensor = torch.stack(sentence_embs) # shape: [N, D]
current_chunk: List[str] = [sentences[0]]
for i in range(1, len(sentences)):
start_prev = max(0, i - window_size)
end_prev = i
start_next = i
end_next = min(len(sentences), i + window_size)
prev_window_emb = emb_tensor[start_prev:end_prev].mean(dim=0)
next_window_emb = emb_tensor[start_next:end_next].mean(dim=0)
similarity = F.cosine_similarity(
prev_window_emb.unsqueeze(0),
next_window_emb.unsqueeze(0),
dim=1
).item()
dynamic_threshold = max(threshold * (1 - 0.03 * (end_next - start_prev)), 0.2)
if similarity < dynamic_threshold:
chunks.append(" ".join(current_chunk))
overlap_start = max(0, len(current_chunk) - self.chunk_overlap)
current_chunk = current_chunk[overlap_start:]
current_chunk.append(sentences[i])
else:
current_chunk.append(sentences[i])
if current_chunk:
chunks.append(" ".join(current_chunk))
return self.postprocess(chunks)
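# Usage sketch (illustrative only; `encode` stands for any callable returning
# one embedding Tensor per sentence, e.g. Khaosz.encode):
#   splitter = SemanticTextSplitter(embedding_func=encode, chunk_overlap=1)
#   chunks = splitter.split("First sentence! Second one? Now a new topic!", threshold=0.5, window_size=1)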

36
requirements.txt Normal file
View File

@ -0,0 +1,36 @@
# python=3.12
--extra-index-url https://download.pytorch.org/whl/cu126
certifi==2025.8.3
charset-normalizer==3.4.2
colorama==0.4.6
contourpy==1.3.3
cycler==0.12.1
filelock==3.13.1
fonttools==4.59.0
fsspec==2024.6.1
huggingface-hub==0.34.3
idna==3.10
Jinja2==3.1.6
kiwisolver==1.4.8
MarkupSafe==2.1.5
matplotlib==3.10.5
mpmath==1.3.0
networkx==3.3
numpy==2.3.2
packaging==25.0
pillow==11.3.0
pyparsing==3.2.3
python-dateutil==2.9.0.post0
PyYAML==6.0.2
requests==2.32.4
safetensors==0.5.3
setuptools==78.1.1
six==1.17.0
sympy==1.13.3
tokenizers==0.21.4
torch==2.7.1+cu126
tqdm==4.67.1
typing_extensions==4.12.2
urllib3==2.5.0
wheel==0.45.1

32
scripts/chat.py Normal file
View File

@ -0,0 +1,32 @@
import os
import torch
from khaosz import Khaosz
PROJECT_ROOT = os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))
def chat():
model_dir = os.path.join(PROJECT_ROOT, "params")
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
history = []
while True:
query = input(">> ")
if query == "!exit":
break
response_size = 0
for response, history in model.stream_generate(
query=query,
history=history,
temperature=0.7,
top_p=0.95,
top_k=30
):
print(response[response_size:], end="", flush=True)
response_size = len(response)
if __name__ == "__main__":
chat()

14
scripts/download.py Normal file
View File

@ -0,0 +1,14 @@
import os
from huggingface_hub import snapshot_download
PROJECT_ROOT = os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))
if __name__ == "__main__":
snapshot_download(
repo_id="ViperEk/KHAOSZ",
local_dir=os.path.join(PROJECT_ROOT, "params"),
force_download=True
)

27
scripts/generate_ar.py Normal file
View File

@ -0,0 +1,27 @@
import os
import torch
from khaosz import Khaosz
PROJECT_ROOT = os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))
def generate_text():
model_dir = os.path.join(PROJECT_ROOT, "params")
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
query = input(">> ")
response = model.text_generate(
query=query,
temperature=0.6,
top_p=0.95,
top_k=30
)
print(response)
if __name__ == "__main__":
generate_text()

25
scripts/generate_batch.py Normal file
View File

@ -0,0 +1,25 @@
import os
import torch
from khaosz import Khaosz
PROJECT_ROOT = os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))
def batch_generate():
model_dir = os.path.join(PROJECT_ROOT, "params")
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"]
responses = model.batch_generate(
queries=inputs,
temperature=0.7,
top_p=0.95,
top_k=30
)
for q, r in zip(inputs, responses):
print((q, r))
if __name__ == "__main__":
batch_generate()

View File

@ -0,0 +1,42 @@
import os
import torch
from khaosz import Khaosz, SemanticTextSplitter, Retriever
PROJECT_ROOT = os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))
if __name__ == "__main__":
model_dir = os.path.join(PROJECT_ROOT, "params")
context_path = os.path.join(PROJECT_ROOT, "README.md")
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
splitter = SemanticTextSplitter(model.encode)
retriever = Retriever()
text = open(context_path, "r", encoding="utf-8").read()
res = splitter.split(text, threshold=0.8, window_size=1)
# print(("\n" + "+"*100 + "\n").join(res))
res_embs = model.encode(res)
for sentence, emb in zip(res, res_embs):
retriever.add_vector(sentence, emb)
retrieve_top_k = 5
query = "作者设计了一个怎样的模型"  # "What kind of model did the author design?"
emb_query = model.encode(query)
retrieved = retriever.retrieve(emb_query, retrieve_top_k)
retrieve_response = model.retrieve_generate(
retrieved=retrieved,
query=query,
temperature=0.7,
top_k=30,
top_p=0.95,
)
print("retrive content:")
print("\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)]))
print("\n\nretrive generate:")
print(retrive_response)

18
setup.py Normal file
View File

@ -0,0 +1,18 @@
import re
from setuptools import find_packages, setup
with open("requirements.txt") as f:
required = [line for line in f.read().splitlines()
if line and re.match(r'^[^=]+==[^=]+$', line.strip())]
setup(
name="khaosz",
version="1.2.0",
packages=find_packages(),
install_requires=required,
dependency_links=[
"https://download.pytorch.org/whl/cu126",
],
python_requires="==3.12.*",
)

103
tests/test_module.py Normal file
View File

@ -0,0 +1,103 @@
import os
import json
import torch
import shutil
import pytest
import tempfile
import safetensors.torch as st
from khaosz.core import *
from khaosz.core.generator import EmbeddingEncoderCore, GeneratorCore
from tokenizers import pre_tokenizers
@pytest.fixture
def test_env():
test_dir = tempfile.mkdtemp()
config_path = os.path.join(test_dir, "config.json")
tokenizer_path = os.path.join(test_dir, "tokenizer.json")
model_path = os.path.join(test_dir, "model.safetensors")
config = {
"vocab_size": 1000,
"n_dim": 128,
"n_head": 4,
"n_kvhead": 2,
"d_ffn": 256,
"m_len": 64,
"n_layer": 2,
"norm_eps": 1e-5
}
with open(config_path, 'w') as f:
json.dump(config, f)
tokenizer = BpeTokenizer()
sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
tokenizer.save(tokenizer_path)
transformer_config = TransformerConfig().load(config_path)
model = Transformer(transformer_config)
st.save_file(model.state_dict(), model_path)
yield {
"test_dir": test_dir,
"model": model,
"tokenizer": tokenizer,
"transformer_config": transformer_config,
}
shutil.rmtree(test_dir)
# parameter loader
def test_parameter_loader(test_env):
loaded_param = ParameterLoader.load(test_env["test_dir"])
assert loaded_param.model is not None
assert loaded_param.tokenizer is not None
assert loaded_param.config == test_env["transformer_config"]
def test_model_parameter(test_env):
save_dir = os.path.join(test_env["test_dir"], "save")
model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"])
model_param.save(save_dir)
assert os.path.exists(os.path.join(save_dir, "model.safetensors"))
assert os.path.exists(os.path.join(save_dir, "tokenizer.json"))
assert os.path.exists(os.path.join(save_dir, "config.json"))
# transformer
def test_transformer(test_env):
model = test_env["model"]
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size,
(4, test_env["transformer_config"].m_len))
output_logits = model(input_ids)["logits"]
target_shape = (4, test_env["transformer_config"].m_len, test_env["transformer_config"].vocab_size)
assert output_logits.shape == target_shape
# generator
def test_embedding_encoder_core(test_env):
parameter = ModelParameter(
test_env["model"],
test_env["tokenizer"],
test_env["transformer_config"]
)
encoder = EmbeddingEncoderCore(parameter)
single_emb = encoder.encode("测试文本")
assert isinstance(single_emb, torch.Tensor)
assert single_emb.shape[-1] == test_env["transformer_config"].n_dim
batch_emb = encoder.encode(["测试1", "测试2"])
assert isinstance(batch_emb, list)
assert len(batch_emb) == 2
def test_generator_core(test_env):
parameter = ModelParameter(
test_env["model"],
test_env["tokenizer"],
test_env["transformer_config"]
)
generator = GeneratorCore(parameter)
logits, incr = generator.compute_logits(torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10)))
assert logits.shape == (4, test_env["transformer_config"].vocab_size)
assert incr == 10

203
tests/test_trainer.py Normal file
View File

@ -0,0 +1,203 @@
import os
import json
import torch
import shutil
import pytest
import pickle
import tempfile
import matplotlib
from torch.utils.data import Dataset
from khaosz.core import *
from khaosz.trainer import *
# to avoid _tkinter.TclError
matplotlib.use('Agg')
@pytest.fixture
def test_env():
test_dir = tempfile.mkdtemp()
config_path = os.path.join(test_dir, "config.json")
config = {
"vocab_size": 1000,
"n_dim": 128,
"n_head": 4,
"n_kvhead": 2,
"d_ffn": 256,
"m_len": 64,
"n_layer": 2,
"norm_eps": 1e-5
}
with open(config_path, 'w') as f:
json.dump(config, f)
transformer_config = TransformerConfig().load(config_path)
model = Transformer(transformer_config)
tokenizer = BpeTokenizer()
class DummyDataset(Dataset):
def __init__(self, length=10):
self.length = length
def __len__(self):
return self.length
def __getitem__(self, idx):
return (
torch.randint(0, 1000, (64,)),
torch.randint(0, 1000, (64,))
)
dataset = DummyDataset()
yield {
"test_dir": test_dir,
"config_path": config_path,
"transformer_config": transformer_config,
"model": model,
"tokenizer": tokenizer,
"dataset": dataset
}
shutil.rmtree(test_dir)
def test_dataset_loader(test_env):
test_dir = test_env["test_dir"]
pkl_path = os.path.join(test_dir, "test_data.pkl")
dummy_data = {"sequence": torch.randint(0, 1000, (64,))}
with open(pkl_path, "wb") as f:
pickle.dump(dummy_data, f)
loaded_dataset = DatasetLoader.load(train_type="seq", load_path=pkl_path, max_len=64, device="cpu")
assert loaded_dataset is not None
def test_training_config(test_env):
optimizer = torch.optim.AdamW(test_env["model"].parameters())
train_config = TrainConfig(
train_type="seq",
dataset=test_env["dataset"],
optimizer=optimizer,
ckpt_dir=test_env["test_dir"],
n_epoch=1,
batch_size=2,
n_iter_ckpt=5,
n_iter_step=1,
max_grad_norm=1.0,
random_seed=42
)
assert train_config.get_kwargs()["batch_size"] == 2
def test_cosine_schedule(test_env):
assert test_env is not None
schedule_config = CosineScheduleConfig(
warning_step=100,
total_iters=1000
)
kwargs = schedule_config.get_kwargs()
assert kwargs["warning_step"] == 100
assert kwargs["lr_decay_iters"] == 900
def test_sgdr_schedule(test_env):
assert test_env is not None
schedule_config = SgdrScheduleConfig(
warning_step=100,
cycle_length=200,
T_mult=2
)
kwargs = schedule_config.get_kwargs()
assert kwargs["warning_step"] == 100
assert kwargs["cycle_length"] == 200
assert kwargs["T_mult"] == 2
def test_trainer_train(test_env):
optimizer = torch.optim.AdamW(test_env["model"].parameters())
train_config = TrainConfig(
train_type="seq",
dataset=test_env["dataset"],
optimizer=optimizer,
ckpt_dir=test_env["test_dir"],
n_epoch=1,
batch_size=2,
n_iter_ckpt=5,
n_iter_step=1,
max_grad_norm=1.0,
random_seed=42
)
schedule_config = CosineScheduleConfig(
warning_step=100,
total_iters=1000
)
model_parameter = ModelParameter(
test_env["model"],
test_env["tokenizer"],
test_env["transformer_config"]
)
trainer = Trainer(model_parameter)
trainer.train(train_config, schedule_config)
def test_checkpoint(test_env):
temp_dir = test_env["test_dir"]
config = test_env["transformer_config"]
model = test_env["model"]
tokenizer = test_env["tokenizer"]
param = ModelParameter(model, tokenizer, config)
checkpoint = Checkpoint(
model=param.model,
tokenizer=param.tokenizer,
config=param.config,
loss_list=[1.0, 2.0, 3.0],
current_iter=3
)
ckpt_dir = os.path.join(temp_dir, "ckpt")
checkpoint.save(ckpt_dir)
loaded_ckpt = Checkpoint()
loaded_ckpt.load(ckpt_dir)
assert loaded_ckpt.current_iter == 3
assert loaded_ckpt.loss_list == [1.0, 2.0, 3.0]
for p1, p2 in zip(model.parameters(), loaded_ckpt.model.parameters()):
assert torch.allclose(p1, p2)
def test_checkpoint_train(test_env):
temp_dir = test_env["test_dir"]
config = test_env["transformer_config"]
model = test_env["model"]
tokenizer = test_env["tokenizer"]
dataset = test_env["dataset"]
param = ModelParameter(model, tokenizer, config)
trainer = Trainer(param)
optimizer = torch.optim.AdamW(test_env["model"].parameters())
train_config = TrainConfig(
train_type="seq",
dataset=dataset,
optimizer=optimizer,
ckpt_dir=test_env["test_dir"],
n_epoch=1,
batch_size=2,
n_iter_ckpt=5,
n_iter_step=1,
max_grad_norm=1.0,
random_seed=42
)
schedule_config = CosineScheduleConfig(
warning_step=100,
total_iters=1000
)
trainer.train(
train_config=train_config,
schedule_config=schedule_config,
)

128
train.py Normal file
View File

@ -0,0 +1,128 @@
import os
import argparse
import torch
from torch.optim import AdamW
from khaosz.core import ParameterLoader
from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
def get_files(root_path: str) -> list[str]:
paths = []
for root, _, files in os.walk(root_path):
paths.extend([os.path.join(root, file) for file in files])
return paths
def train(
train_type: str,
param_path: str,
data_root_path: str,
n_epoch: int,
batch_size: int,
n_iter_step: int,
warning_step: int,
max_lr: float,
n_iter_ckpt: int,
ckpt_dir: str,
dpo_beta: float,
adamw_betas: tuple,
adamw_weight_decay: float,
max_grad_norm: float,
embedding_lr_rate: float,
random_seed: int,
):
assert train_type in ["seq", "sft", "dpo"]
assert os.path.exists(param_path)
parameter = ParameterLoader.load(param_path)
model = parameter.model
device = torch.device("cuda")
model = model.to(device=device, dtype=torch.bfloat16)
cache_files = get_files(data_root_path)
dataset = DatasetLoader.load(
train_type=train_type,
load_path=cache_files,
max_len=parameter.config.m_len,
device=device
)
param_groups = [
{"params": [p for n, p in model.named_parameters() if "embedding" in n], "lr": max_lr * embdeding_lr_rate},
{"params": [p for n, p in model.named_parameters() if "embedding" not in n], "lr": max_lr}
]
optim = AdamW(
param_groups,
betas=adamw_betas,
weight_decay=adamw_weight_decay
)
train_config = TrainConfig(
train_type=train_type,
dataset=dataset,
optimizer=optim,
ckpt_dir=ckpt_dir,
n_epoch=n_epoch,
batch_size=batch_size,
n_iter_ckpt=n_iter_ckpt,
n_iter_step=n_iter_step,
max_grad_norm=max_grad_norm,
random_seed=random_seed,
dpo_beta=dpo_beta
)
schedule_config = CosineScheduleConfig(
warning_step=warning_step,
total_iters=len(dataset) * n_epoch // batch_size,
)
trainer = Trainer(parameter)
trainer.train(
train_config=train_config,
schedule_config=schedule_config,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train the Transformer model.")
parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
parser.add_argument("--n_iter_step", type=int, default=1, help="Number of iterations between each optimizer step.")
parser.add_argument("--warning_step", type=int, default=1000, help="Number of iters between warnings.")
parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
parser.add_argument("--n_iter_ckpt", type=int, default=5000, help="Number of iters between checkpoints.")
parser.add_argument("--ckpt_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
parser.add_argument("--adamw_betas", type=tuple, default=(0.9, 0.95), help="Beta values for AdamW optimizer.")
parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.")
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
args = parser.parse_args()
train(
param_path=args.param_path,
data_root_path=args.data_root_path,
n_epoch=args.n_epoch,
batch_size=args.batch_size,
n_iter_step=args.n_iter_step,
warning_step=args.warning_step,
max_lr=args.max_lr,
dpo_beta=args.dpo_beta,
adamw_betas=args.adamw_betas,
adamw_weight_decay=args.adamw_weight_decay,
max_grad_norm=args.max_grad_norm,
embedding_lr_rate=args.embedding_lr_rate,
n_iter_ckpt=args.n_iter_ckpt,
ckpt_dir=args.ckpt_dir,
train_type=args.train_type,
random_seed=args.random_seed
)