commit a4443765eee0ecdeeed0b362f84bd94ad42a4c42
Author: ViperEkura <3081035982@qq.com>
Date: Sat Sep 27 12:02:22 2025 +0800
Initial commit
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3ecf496
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,22 @@
+# cache
+__pycache__
+.pytest_cache
+
+# params
+*.safetensors
+*.json
+*.pkl
+*.db
+
+# train_log
+*.log
+
+# ignore file
+-*
+
+# vscode file
+.vscode
+
+# build file
+build
+*.egg-info
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..261eeb9
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..0d5fa89
--- /dev/null
+++ b/README.md
@@ -0,0 +1,333 @@
+
+## English Version
+
+This is a Chinese-English bilingual Transformer model. The repository contains the model configuration and training workflow; training loads the parameters defined in `param_path/config.json`. The training script `train.py` parses command-line arguments, including the dataset root directory, the number of training epochs, the batch size, the checkpoint interval, and the checkpoint directory.
+
+**Model Download Options (Choose One):**
+
+1. Visit [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) to access **Files and versions**
+2. Run `scripts/download.py` to download parameters
+
+**Demo Video:** [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd)
+
+Training dataset sources are listed in the **Model Card** section of the HuggingFace download link.
+
+**License:** The code is released under the Apache-2.0 license. Please credit this repository when using the code.
+
+- **📊 Device Selection:** The code defaults to training on CUDA.
+- **🌐 Performance Optimization:** `dtype=torch.bfloat16` is enabled to accelerate training and reduce memory usage; make sure your hardware supports it.
+- **🤖 Language Support:** The model supports Chinese and English training. The BBPE tokenizer was trained only on Chinese and English text, so OOV (out-of-vocabulary) issues are rare for these two languages but may occur for others.
+
+### 📌 Training Guide
+
+To train this Transformer model, follow these steps:
+
+**(1). Prepare Dataset:**
+
+Place your datasets in the designated root directory. The files should be text documents in Chinese, English, or a mixture of the two. The format should match the model's input requirements: preferably pre-tokenized token ids stored as a `torch.Tensor` (a tensor is far more memory-efficient than a Python list, which stores each token id as a full Python integer object).
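+
+A minimal sketch of this preprocessing, assuming a trained tokenizer file and illustrative paths (the exact on-disk format expected by the dataset classes may differ):
+
+```python
+import torch
+from khaosz import BpeTokenizer
+
+tokenizer = BpeTokenizer("params/tokenizer.json")  # hypothetical path
+
+with open("corpus.txt", "r", encoding="utf-8") as f:
+    text = f.read()
+
+# Store the token ids as a compact integer tensor rather than a Python list.
+token_ids = torch.tensor(tokenizer.encode(text), dtype=torch.long)
+torch.save(token_ids, "dataset/corpus.pkl")  # illustrative file name
+```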
+
+**(2). Install Dependencies:**
+
+```bash
+pip install -r requirements.txt
+pip install .
+```
+
+**(3). Run Training Script:**
+
+```bash
+python train.py \
+--train_type={seq|sft|dpo} \
+--data_root_path=/path/to/dataset \
+--param_path=/path/to/param_path \
+--n_epoch=5 \
+--batch_size=8 \
+--max_lr=2e-4 \
+--n_iter_ckpt=10000 \
+--ckpt_dir=checkpoints
+```
+
+**Parameter Descriptions:**
+- `--train_type`: Training type (seq, sft, dpo)
+- `--data_root_path`: Root directory of the dataset
+- `--param_path`: Path to the model training parameters
+- `--n_epoch`: Total number of training epochs
+- `--batch_size`: Batch size
+- `--n_iter_step`: Number of batches per training step
+- `--warning_step`: Number of warmup steps
+- `--max_lr`: Maximum learning rate (using warmup + cosine decay)
+- `--n_iter_ckpt`: Checkpoint saving interval
+- `--ckpt_dir`: Directory to save checkpoints
+- `--resume_dir`: Resume training from the specified path
+
+Training logs will be saved in `train_log.txt`. Checkpoints will be saved in the specified directory for resuming training or evaluation.
+
+### 👉 Usage Guide
+
+**(1). Chatting with the Model:**
+
+Open `chat.py` or use streaming/non-streaming interfaces:
+
+**Streaming Output:**
+```python
+import torch
+from khaosz import Khaosz
+
+model_dir = "your_model_parameter_dir"
+model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+history = []
+
+while True:
+ query = input(">> ")
+ if query == "!exit":
+ break
+
+ response_size = 0
+ for response, history in model.stream_generate(
+ query=query,
+ history=history,
+ temperature=0.85,
+ top_p=0.95,
+ top_k=50
+ ):
+ print(response[response_size:], end="")
+ response_size = len(response)
+```
+
+**Non-streaming Output:**
+```python
+import torch
+from khaosz import Khaosz
+
+model_dir = "your_model_parameter_dir"
+model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+history = []
+
+while True:
+ query = input(">> ")
+ if query == "!exit":
+ break
+
+ response = model.generate(
+ query=query,
+ history=history,
+ temperature=0.85,
+ top_p=0.95,
+ top_k=50
+ )
+ print(response)
+```
+
+**(2). Retrieval-Augmented Generation (RAG):**
+
+```python
+import torch
+from khaosz import Khaosz
+
+model_dir = "your_model_parameter_dir"
+model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+query = "your question here"  # define the query before calling retrieve_generate
+
+retrieved_content = model.retrieve_generate(
+ query=query,
+ retrieve_top_k=5,
+ temperature=0.6,
+ top_k=30,
+ top_p=0.95
+)
+print(retrieved_content)
+```
+
+### 📌 Model Specifications
+
+This model is based on a 24-layer Transformer with parameters defined in `config.json`, totaling approximately 1.0 billion (1.0B) parameters.
+
+**Key Design Choices:**
+- Weight tying between the embedding and the final linear layer (a standard parameter-saving choice for small models); see the sketch below
+- Embedding layer optimization: without weight tying, a 10,000-word vocabulary would consume an extra ~102M parameters (≈0.1B)
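+
+A minimal sketch of weight tying in PyTorch (illustrative sizes, not the repository's exact code):
+
+```python
+import torch.nn as nn
+
+vocab_size, n_dim = 32000, 1536  # illustrative sizes
+embedding = nn.Embedding(vocab_size, n_dim)
+lm_head = nn.Linear(n_dim, vocab_size, bias=False)
+
+# Share a single matrix between the input embedding and the output
+# projection, saving vocab_size * n_dim parameters.
+lm_head.weight = embedding.weight
+```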
+
+**Limitations:**
+- May struggle with complex language phenomena due to smaller parameter size
+- Prone to overfitting on specialized datasets
+- Limited multilingual capabilities
+
+**Advantages:**
+- Runs efficiently on lower-spec hardware
+- Shorter training time compared to larger models
+
+**Training Pipeline:**
+The model has completed pre-training + SFT (Supervised Fine-Tuning) + DPO (Direct Preference Optimization) workflows. All corresponding training code is included in the repository.
+
+
+## 中文版本
+
+这是一个支持中英文双语的 Transformer 模型。模型包含配置文件和训练流程,通过加载 `param_path/config.json` 中定义的参数完成训练。训练脚本 `train.py` 支持命令行参数解析,包括数据集根目录、训练轮数(epochs)、批量大小(batch size)、检查点保存间隔、检查点目录等。
+
+**模型下载选项(任选其一):**
+
+1. 访问 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 查看 **Files and versions**
+2. 运行 `scripts/download.py` 下载模型参数
+
+**演示视频:** [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd)
+
+训练数据来源请参见 HuggingFace 下载页面中的 **Model Card** 部分。
+
+**许可证:** 代码遵循 Apache-2.0 协议,使用时请注明出处。
+
+- **📊 设备选择:** 默认使用 CUDA 进行训练
+- **🌐 性能优化:** 启用 `dtype=torch.bfloat16` 以加速训练并减少内存占用,请确保硬件支持该特性
+- **🤖 语言支持:** 模型支持中文和英文训练。由于 BBPE 分词器仅使用中英文文本训练,因此中英文的 OOV(未登录词)问题较少,其他语言可能存在 OOV 问题
+
+
+
+### 📌 训练指南
+
+要训练该 Transformer 模型,请按照以下步骤操作:
+
+#### **(1). 准备数据集:**
+
+将数据集放置在指定的根目录下。文件应为包含中文、英文或混合文本的文本文档。格式应符合模型输入要求——建议使用预分词后的 `token_ids` 并以 `torch.Tensor` 格式保存(相比 Python 列表,`torch.Tensor` 更节省内存:列表会将每个 token id 存为完整的 Python 整数对象)。
+
+#### **(2). 安装依赖:**
+
+```bash
+pip install -r requirements.txt
+pip install .
+```
+
+#### **(3). 运行训练脚本:**
+
+```bash
+python train.py \
+--train_type={seq|sft|dpo} \
+--data_root_path=/path/to/dataset \
+--param_path=/path/to/param_path \
+--n_epoch=5 \
+--batch_size=8 \
+--max_lr=2e-4 \
+--n_iter_ckpt=10000 \
+--ckpt_dir=checkpoints
+```
+
+**参数说明:**
+- `--train_type`: 训练类型(seq, sft, dpo)
+- `--data_root_path`: 数据集根目录
+- `--param_path`: 模型训练参数路径
+- `--n_epoch`: 总训练轮数
+- `--batch_size`: 批量大小
+- `--n_iter_step`: 每个训练步骤的 batch 数量
+- `--warning_step`: 预热步数(warmup steps)
+- `--max_lr`: 最大学习率(使用预热 + 余弦衰减)
+- `--n_iter_ckpt`: 检查点保存间隔
+- `--ckpt_dir`: 检查点保存目录
+- `--resume_dir`: 从指定路径恢复训练
+
+训练日志将保存在 `train_log.txt` 中。检查点将保存在指定目录,用于恢复训练或评估。
+
+
+
+### 👉 使用指南
+
+#### **(1). 与模型对话:**
+
+打开 `chat.py` 或使用流式/非流式接口:
+
+**流式输出:**
+```python
+import torch
+from khaosz import Khaosz
+
+model_dir = "your_model_parameter_dir"
+model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+history = []
+
+while True:
+ query = input(">> ")
+ if query == "!exit":
+ break
+
+ response_size = 0
+ for response, history in model.stream_generate(
+ query=query,
+ history=history,
+ temperature=0.85,
+ top_p=0.95,
+ top_k=50
+ ):
+ print(response[response_size:], end="")
+ response_size = len(response)
+```
+
+**非流式输出:**
+```python
+import torch
+from khaosz import Khaosz
+
+model_dir = "your_model_parameter_dir"
+model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+history = []
+
+while True:
+ query = input(">> ")
+ if query == "!exit":
+ break
+
+ response = model.generate(
+ query=query,
+ history=history,
+ temperature=0.85,
+ top_p=0.95,
+ top_k=50
+ )
+ print(response)
+```
+
+#### **(2). 基于检索的生成(RAG):**
+
+```python
+import torch
+from khaosz import Khaosz
+
+model_dir = "your_model_parameter_dir"
+model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+query = "你的问题"  # 调用 retrieve_generate 前需先定义 query
+
+retrieved_content = model.retrieve_generate(
+ query=query,
+ retrieve_top_k=5,
+ temperature=0.6,
+ top_k=30,
+ top_p=0.95
+)
+print(retrieved_content)
+```
+
+
+
+### 📌 模型规格说明
+
+该模型基于一个 24 层的 Transformer 架构,参数配置定义在 `config.json` 中,总参数量约为 10 亿(1.0B)。
+
+**关键设计选择:**
+- 在嵌入层(embedding)与最终线性层之间进行权重绑定(weight tying),这是小型模型中常见的节省参数量的做法
+- 嵌入层优化:若不进行权重绑定,一个包含 10,000 个词的词汇表将消耗约 1.02 亿(0.1B)参数
+
+**局限性:**
+- 由于参数规模较小,可能在处理复杂语言现象时表现受限
+- 在特定领域的数据集上容易出现过拟合
+- 多语言能力有限
+
+**优势:**
+- 可在低配置硬件上高效运行
+- 相较于大型模型,训练时间更短
+
+**训练流程:**
+该模型已完成预训练(pre-training)+ 监督微调(SFT, Supervised Fine-Tuning)+ 直接偏好优化(DPO, Direct Preference Optimization)的全流程。所有相关的训练代码均已包含在代码库中。
\ No newline at end of file
diff --git a/assets/docs/introduction.md b/assets/docs/introduction.md
new file mode 100644
index 0000000..7333403
--- /dev/null
+++ b/assets/docs/introduction.md
@@ -0,0 +1,89 @@
+## Model Introduction
+
+
+
+### 1. Model Architecture
+
+This model uses the Transformer architecture with GQA attention (q_head=24, kv_head=4), which reduces the KV-cache memory footprint compared with conventional MHA (although the KV cache itself had not been implemented at this point). The model is built by stacking 24 Transformer layers, for a total of 1.0B parameters. A Transformer is an autoregressive model: it computes the relationships among all preceding tokens to obtain the probability distribution of the next token.
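+
+A minimal sketch of the head-sharing idea behind GQA (shapes only; `repeat_kv` in `khaosz/core/transformer.py` performs the same expansion):
+
+```python
+import torch
+
+# 24 query heads share 4 KV heads: each KV head serves 24 // 4 = 6 query heads.
+bsz, seq_len, head_dim = 2, 16, 64
+q = torch.randn(bsz, seq_len, 24, head_dim)
+k = torch.randn(bsz, seq_len, 4, head_dim)
+
+# Repeat each KV head 6 times so it lines up with the query heads.
+k_expanded = k.repeat_interleave(6, dim=2)  # (bsz, seq_len, 24, head_dim)
+assert k_expanded.shape == q.shape
+```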
+
+
+
+So what is an autoregressive model? After a sentence is split into tokens, the model predicts the probability distribution of the next token: given the context (the sequence of tokens seen so far), it computes the candidate next tokens and their probabilities.
+
+
+
+#### 1. Autoregression
+
+Suppose we have a sentence split into the following token list:
+
+```
+["你好", ",", "今天", "天气"]
+```
+
+The model then predicts the next token for this sequence, given as a probability distribution, for example:
+
+```
+-> {"token": "不错", "probability": 0.4}
+-> {"token": "晴朗", "probability": 0.2}
+-> ......
+```
+
+Here, "不错" and "晴朗" are two tokens that could follow "天气", each with the probability of it being the next token.
+
+Next, we sample the next token (tuning the outcome with the top_k, top_p, and temperature parameters) and append it to the sequence as the new input:
+
+```
+["你好", ",", "今天", "天气", "不错"]
+```
+
+This process then repeats until a control token that ends generation (`<|end_of_seqence|>`) is produced and the model stops (models generally define such control tokens; otherwise they would keep generating until GPU memory runs out).
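+
+A minimal sketch of this decoding loop (`model` and `eos_id` here are assumed placeholders; the repository's real loop lives in `khaosz/core/generator.py`):
+
+```python
+import torch
+
+def sample_next(logits: torch.Tensor, temperature: float = 0.85, top_k: int = 50) -> int:
+    # Temperature reshapes the distribution; top-k keeps only the k most likely tokens.
+    values, indices = torch.topk(logits / temperature, top_k)
+    probs = torch.softmax(values, dim=-1)
+    return indices[torch.multinomial(probs, num_samples=1)].item()
+
+ids = [3, 4, 5]   # token ids generated so far (illustrative)
+eos_id = 2        # hypothetical end-of-sequence id
+while len(ids) < 512:
+    logits = model(torch.tensor([ids]))["logits"][0, -1]  # `model` is assumed given
+    ids.append(sample_next(logits))
+    if ids[-1] == eos_id:
+        break
+```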
+
+
+
+
+
+#### 2. Causal Mask
+
+The Transformer uses an attention mechanism. The input typically has shape [bsz, seq_len] and the output [bsz, seq_len, n_dim]. To predict the next token, the model's input and prediction target must be offset by one position; we use the same one-position shift during training:
+
+```
+sequence : [[1, 2, 3, 4, 5, 6]]
+input_ids: [[1, 2, 3, 4, 5]]
+target_ids: [[2, 3, 4, 5, 6]]
+```
+
+
+
+The attention scores are computed as
+
+$$ \tilde{s}_{ij} = \frac{q_i^T k_j}{\sqrt{d_k}} + mask_{ij} $$
+$$ s_{ij} = \mathrm{softmax}\left(\tilde{s}_{ij}\right) $$
+
+where the attention score reflects how strongly the model weighs the relationship between two tokens.
+
+For a decoder-only model, to keep it from stealing information from future positions, a mask must be applied to the attention scores before the softmax. This mask is typically a lower-triangular matrix; for a sequence of length n it has shape [n, n]. Taking a sequence of length 5 as an example, the causal mask matrix looks like this:
+
+```
+[[0, -inf, -inf, -inf, -inf],
+ [0, 0, -inf, -inf, -inf],
+ [0, 0, 0, -inf, -inf],
+ [0, 0, 0, 0, -inf],
+ [0, 0, 0, 0, 0]]
+```
+
+In this matrix, 0 marks positions that can be attended to, while -inf marks positions that must be masked (i.e., not attended to). The mask guarantees that entries with $j > i$ go from -inf to 0 after the softmax, so the model cannot see future information.
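+
+A minimal PyTorch sketch of building this mask and applying it before the softmax:
+
+```python
+import torch
+
+seq_len = 5
+# -inf above the diagonal: position i may only attend to positions j <= i.
+mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
+scores = torch.randn(seq_len, seq_len)       # raw attention scores
+attn = torch.softmax(scores + mask, dim=-1)  # masked entries become exactly 0
+```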
+
+
+
+#### 3. Rotary Position Embedding
+
+Rotary Position Embedding (RoPE) is a positional-encoding method designed to address the Transformer's lack of direct modeling of sequence positions. Unlike traditional positional encodings (such as sinusoidal encodings), RoPE embeds position information directly into the query (Q) and key (K) vectors, allowing the model to handle relative positions in a sequence more naturally.
+
+
+$$ q_i = R_i W_q x_i $$
+$$ k_j = R_j W_k x_j $$
+$$ q_i^T k_j = (R_i W_q x_i)^T( R_j W_k x_j) = x_i^T W_q^T R_{i-j} W_k x_j $$
+
+Here $R_{i-j}$ governs how attention between tokens decays with relative distance: the larger $|i - j|$ is, the stronger the decay. This lets the model learn relative positional relationships and thus extrapolate to longer sequences.
\ No newline at end of file
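+
+A self-contained sketch of the rotation (mirroring `get_rotary_emb` and `apply_rotary_emb` in `khaosz/core/transformer.py`):
+
+```python
+import torch
+
+def rope_freqs(dim: int, max_len: int, base: float = 10000.0) -> torch.Tensor:
+    # Complex rotations e^{i * t * theta}: one angle per position and feature pair.
+    theta = base ** (-torch.arange(0, dim, 2).float() / dim)
+    angles = torch.outer(torch.arange(max_len).float(), theta)
+    return torch.polar(torch.ones_like(angles), angles)  # (max_len, dim // 2)
+
+q = torch.randn(1, 8, 1, 64)  # (bsz, seq_len, n_head, head_dim)
+freqs = rope_freqs(64, 8)
+q_complex = torch.view_as_complex(q.float().reshape(1, 8, 1, 32, 2))
+q_rotated = torch.view_as_real(q_complex * freqs.view(1, 8, 1, 32)).flatten(3)
+```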
diff --git a/assets/docs/kvcache.md b/assets/docs/kvcache.md
new file mode 100644
index 0000000..304c484
--- /dev/null
+++ b/assets/docs/kvcache.md
@@ -0,0 +1,27 @@
+## KV-cache Implementation
+
+Starting from the attention formula
+
+$$
+\begin{align*}
+o_i &= \sum_j s_{ij} v_{j} \\
+s_{ij} &= \text{softmax}\left( \sum_n \frac{q_{i,n} k_{j,n}}{\sqrt{d_k}} \right)
+\end{align*}
+$$
+
+Because the model is autoregressive, we only need the output at the last position of the sequence. That is, the index $i$ is fixed to the last element, and what we compute is $o_{n}$:
+
+$$
+\begin{align*}
+o_n &= \sum_j s_{j}v_{j,n} \\
+s_j &= \text{softmax}\left(\sum_n\frac{q_n k_{j,n}}{\sqrt{d_k}} \right)
+\end{align*}
+$$
+
+Expanding the expression gives
+
+$$
+o_n = \sum_j \sum_n \text{softmax}\left(\frac{q_n k_{j,n}}{\sqrt{d_k}}\right)v_{j,n}
+$$
+
+In the expression above, only $k$ and $v$ carry the sequence-length index, while $q$ does not. So during decoding the $q$ input is just the last token of the previous step, while $k$ and $v$ must be cached for all already-processed positions. Note that the positional encoding must be applied before the K/V values enter the cache; otherwise the positional encoding will be computed incorrectly.
\ No newline at end of file
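+
+A minimal single-head sketch of this caching scheme (illustrative shapes; the repository's batched version is `KVCacheManager` in `khaosz/core/generator.py`):
+
+```python
+import torch
+
+d_k = 64
+k_cache = torch.empty(0, d_k)  # cached keys,   shape (cached_len, d_k)
+v_cache = torch.empty(0, d_k)  # cached values, shape (cached_len, d_k)
+
+def decode_step(q_new: torch.Tensor, k_new: torch.Tensor, v_new: torch.Tensor) -> torch.Tensor:
+    """Append the new token's K/V (RoPE already applied), then attend with the new query only."""
+    global k_cache, v_cache
+    k_cache = torch.cat([k_cache, k_new])  # (t, d_k)
+    v_cache = torch.cat([v_cache, v_new])  # (t, d_k)
+    scores = torch.softmax(k_cache @ q_new / d_k ** 0.5, dim=0)  # (t,)
+    return scores @ v_cache                # (d_k,)
+
+out = decode_step(torch.randn(d_k), torch.randn(1, d_k), torch.randn(1, d_k))
+```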
diff --git a/assets/images/project_logo.png b/assets/images/project_logo.png
new file mode 100644
index 0000000..342eb9c
Binary files /dev/null and b/assets/images/project_logo.png differ
diff --git a/assets/images/project_logo_clipped.png b/assets/images/project_logo_clipped.png
new file mode 100644
index 0000000..b02cd08
Binary files /dev/null and b/assets/images/project_logo_clipped.png differ
diff --git a/assets/images/structure.png b/assets/images/structure.png
new file mode 100644
index 0000000..5f4f3a2
Binary files /dev/null and b/assets/images/structure.png differ
diff --git a/generate.py b/generate.py
new file mode 100644
index 0000000..a18c433
--- /dev/null
+++ b/generate.py
@@ -0,0 +1,101 @@
+import os
+import json
+import torch
+import argparse
+
+from khaosz import Khaosz
+from typing import List
+from tqdm import tqdm
+
+
+PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
+
+def batch_generate(
+ model: Khaosz,
+ queries: List[str],
+ temperature: float,
+ top_k: int,
+ top_p: float,
+ batch_size: int,
+) -> List:
+ assert batch_size > 0
+ sorted_queries = sorted(queries, key=lambda x: len(x), reverse=True)
+ original_indices = {query: idx for idx, query in enumerate(queries)}
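+    # NOTE: this index map assumes all queries are unique; duplicates would share one slot.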
+
+ responses = [None] * len(queries)
+ total_batches = (len(sorted_queries) + batch_size - 1) // batch_size
+
+ for i in tqdm(range(0, total_batches * batch_size, batch_size), desc="Generating responses"):
+        batch_queries = sorted_queries[i: i + batch_size]  # slicing already clamps at the list end
+
+ batch_responses = model.batch_generate(
+ queries=batch_queries,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p
+ )
+
+ for batch_query, batch_response in zip(batch_queries, batch_responses):
+ print(f"Q: {batch_query[:50]} \nR: {batch_response[:50]})")
+
+ for query, response in zip(batch_queries, batch_responses):
+ original_idx = original_indices[query]
+ responses[original_idx] = response
+
+ return responses
+
+
+def processor(
+ model: Khaosz,
+ input_json_file: str,
+ output_json_file: str,
+ batch_size: int,
+ temperature: float,
+ top_p: float,
+ top_k: int,
+ question_key: str="question",
+):
+ with open(input_json_file, "r", encoding='utf-8') as f:
+ input_dict = [json.loads(line) for line in f]
+ queries = [item[question_key] for item in input_dict]
+
+ output_dict = batch_generate(
+ model=model,
+ queries=queries,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ batch_size=batch_size
+ )
+
+ with open(output_json_file, "w", encoding='utf-8') as f:
+ json.dump(output_dict, f, indent=4, ensure_ascii=False)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
+
+ parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
+ parser.add_argument("--input_json_file", type=str, required=True, help="Path to the input JSONL file.")
+ parser.add_argument("--output_json_file", type=str, required=True, help="Path to the output JSONL file.")
+ parser.add_argument("--question_key", type=str, default="question", help="Key for the question in the input JSON.")
+ parser.add_argument("--temperature", type=float, default=0.60, help="Temperature for generating responses.")
+ parser.add_argument("--top_p", type=float, default=0.95, help="Top-p value for generating responses.")
+ parser.add_argument("--top_k", type=int, default=30, help="Top-k value for generating responses.")
+ parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.")
+
+ args = parser.parse_args()
+ model = Khaosz(args.model_dir).to(device='cuda', dtype=torch.bfloat16)
+
+ processor(
+ model,
+ input_json_file=args.input_json_file,
+ output_json_file=args.output_json_file,
+ question_key=args.question_key,
+ batch_size=args.batch_size,
+ temperature=args.temperature,
+ top_k=args.top_k,
+ top_p=args.top_p
+ )
\ No newline at end of file
diff --git a/khaosz/__init__.py b/khaosz/__init__.py
new file mode 100644
index 0000000..b4bbcf5
--- /dev/null
+++ b/khaosz/__init__.py
@@ -0,0 +1,52 @@
+__version__ = "1.2.1"
+__author__ = "ViperEkura"
+
+from khaosz.model import Khaosz
+from khaosz.core.transformer import Transformer, TransformerConfig
+from khaosz.utils.retriever import Retriever
+from khaosz.utils.splitter import (
+ SemanticTextSplitter,
+ PriorityTextSplitter
+)
+from khaosz.core.tokenizer import BpeTokenizer
+from khaosz.core.parameter import ParameterLoader
+from khaosz.core.generator import (
+ TextGenerator,
+ ChatGenerator,
+ StreamGenerator,
+ BatchGenerator,
+ RetrievalGenerator,
+ EmbeddingEncoder
+)
+from khaosz.trainer.trainer import Trainer
+from khaosz.trainer.dataset import SeqDataset, SftDataset, DpoDataset, BaseDataset
+
+
+__all__ = [
+ # model
+ "Khaosz",
+
+ # module
+ "Transformer",
+ "TransformerConfig",
+ "BpeTokenizer",
+ "ParameterLoader",
+ "TextGenerator",
+ "ChatGenerator",
+ "StreamGenerator",
+ "BatchGenerator",
+ "RetrievalGenerator",
+ "EmbeddingEncoder",
+
+ # trainer
+ "Trainer",
+ "SeqDataset",
+ "SftDataset",
+ "DpoDataset",
+ "BaseDataset",
+
+ # utils
+ "Retriever",
+ "SemanticTextSplitter",
+ "PriorityTextSplitter",
+]
diff --git a/khaosz/core/__init__.py b/khaosz/core/__init__.py
new file mode 100644
index 0000000..0b31b94
--- /dev/null
+++ b/khaosz/core/__init__.py
@@ -0,0 +1,27 @@
+from khaosz.core.tokenizer import BpeTokenizer
+from khaosz.core.transformer import Transformer, TransformerConfig
+from khaosz.core.parameter import ParameterLoader, ModelParameter, Checkpoint
+from khaosz.core.generator import (
+ TextGenerator,
+ ChatGenerator,
+ StreamGenerator,
+ BatchGenerator,
+ RetrievalGenerator,
+ EmbeddingEncoder
+)
+
+
+__all__ = [
+ "Transformer",
+ "TransformerConfig",
+ "BpeTokenizer",
+ "ParameterLoader",
+ "ModelParameter",
+ "Checkpoint",
+ "TextGenerator",
+ "ChatGenerator",
+ "StreamGenerator",
+ "BatchGenerator",
+ "RetrievalGenerator",
+ "EmbeddingEncoder"
+]
\ No newline at end of file
diff --git a/khaosz/core/generator.py b/khaosz/core/generator.py
new file mode 100644
index 0000000..3aba293
--- /dev/null
+++ b/khaosz/core/generator.py
@@ -0,0 +1,568 @@
+import torch
+from torch import Tensor
+from typing import List, Tuple, Union, Optional, Generator, Self
+from khaosz.core.parameter import ModelParameter
+
+
+def build_prompt(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
+ """
+ Build prompt for query and history
+
+ Args:
+ query(str): query string
+ history(Optional[List[Tuple[str, str]]]): history list of query and response
+
+ Returns:
+ str: prompt string
+
+ """
+ prompt_parts = []
+
+ if history is None:
+ history = []
+
+ for his_query, his_response in history:
+ prompt_parts.append(f"<|user|> {his_query} <|system|> {his_response}")
+
+ if query is not None:
+ prompt_parts.append(f"<|user|> {query} <|system|> ")
+
+ return "\n".join(prompt_parts)
+
+def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]:
+ """
+ Pad a list of sequences to a fixed length.
+
+ Args:
+ ids_list (List[List[int]]): A list of sequences.
+ max_ids_len (int): The maximum length of sequences.
+ pad_id (int): The id to pad sequences.
+
+ Returns:
+ List[List[int]]: A list of padded sequences.
+
+ """
+ new_ids_list = []
+ for ids in ids_list:
+ pad_len = max_ids_len - len(ids)
+ padded_seq = [pad_id] * pad_len + ids
+ new_ids_list.append(padded_seq)
+
+ return new_ids_list
+
+def apply_sampling_strategies(
+ logits: Tensor,
+ temperature: float,
+ top_k: int,
+ top_p: float,
+ filter_value: float = -float("inf")
+) -> Tensor:
+ """
+ Apply sampling strategies to the logits tensor.
+
+ Args:
+ logits (Tensor): The logits tensor.
+ temperature (float): The temperature parameter.
+ top_k (int): The top-k parameter.
+ top_p (float): The top-p parameter.
+ filter_value (float, optional): The filter value. Defaults to -float("inf").
+
+ Returns:
+ Tensor: The sampled logits tensor.
+
+ """
+
+ if temperature != 1.0:
+ logits = logits / temperature
+
+ if top_k > 0:
+ top_k = min(top_k, logits.size(-1))
+ indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
+ logits[indices_to_remove] = filter_value
+
+ if top_p < 1.0:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+
+ sorted_indices_to_remove = cumulative_probs > top_p
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+ sorted_indices_to_remove[..., 0] = 0
+
+ indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
+ indices_to_remove.scatter_(
+ dim=1,
+ index=sorted_indices,
+ src=sorted_indices_to_remove
+ )
+
+ logits[indices_to_remove] = filter_value
+
+ return logits
+
+
+class KVCacheManager:
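+    """Pre-allocated per-layer K/V tensors plus a sequence mask for incremental decoding.
+
+    Caches have shape (batch_size, max_len, num_heads, head_dim); update() drops
+    finished sequences from the batch, and set_seq_mask() records which positions
+    hold real (non-padding) tokens.
+    """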
+ def __init__(
+ self,
+ num_layers: int,
+ batch_size: int,
+ max_len: int,
+ num_heads: int,
+ head_dim: int,
+ device: torch.device = "cuda",
+ dtype: torch.dtype = torch.bfloat16
+ ):
+ self.num_layers = num_layers
+ self.batch_size = batch_size
+ self.max_len = max_len
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+ self.device = device
+ self.dtype = dtype
+
+ self._kv_cache: List[Tuple[Tensor, Tensor]] = None
+ self._seq_mask: Tensor = None
+ self._initialize()
+
+ def _initialize(self):
+ self._kv_cache = []
+ for _ in range(self.num_layers):
+ k_cache = torch.zeros(
+ (self.batch_size, self.max_len, self.num_heads, self.head_dim),
+ device=self.device, dtype=self.dtype
+ )
+ v_cache = torch.zeros(
+ (self.batch_size, self.max_len, self.num_heads, self.head_dim),
+ device=self.device, dtype=self.dtype
+ )
+ self._kv_cache.append((k_cache, v_cache))
+
+ self._seq_mask = torch.ones(
+ (self.batch_size, self.max_len),
+ device=self.device, dtype=torch.bool
+ )
+
+ def update(self, active_mask: Tensor):
+ for i in range(self.num_layers):
+ k_cache, v_cache = self._kv_cache[i]
+ new_k_cache, new_v_cache = k_cache[active_mask], v_cache[active_mask]
+ self._kv_cache[i] = (new_k_cache, new_v_cache)
+
+ self._seq_mask = self._seq_mask[active_mask]
+
+ def reset(self, full_reset=False):
+ if full_reset:
+ self._kv_cache = None
+ self._seq_mask = None
+ else:
+ self._initialize()
+
+ def set_seq_mask(self, input_ids: Tensor, pad_id: int):
+ batch_size, seq_len = input_ids.shape
+ bool_mask = (input_ids != pad_id)
+ self._seq_mask[: batch_size, : seq_len] = bool_mask
+
+ def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]:
+ return self._kv_cache
+
+ def get_seq_mask(self) -> Tensor:
+ return self._seq_mask
+
+
+class GeneratorCore:
+ def __init__(self, parameter: ModelParameter):
+ self.model = parameter.model
+ self.tokenizer = parameter.tokenizer
+ self.config = parameter.config
+
+ def compute_logits(
+ self,
+ input_ids: Tensor,
+ attn_mask: Optional[Tensor] = None,
+ kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
+ start_pos: int = 0
+ ) -> Tuple[Tensor, int]:
+ with torch.inference_mode():
+ outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
+ logits = outputs["logits"][:, -1, :]
+ cache_increase = input_ids.size(-1)
+
+ return logits, cache_increase
+
+    def to(self, *args, **kwargs) -> Self:
+        self.model.to(*args, **kwargs)
+ return self
+
+
+class EmbeddingEncoderCore:
+ def __init__(self, parameter: ModelParameter):
+ self.model = parameter.model
+ self.tokenizer = parameter.tokenizer
+ self.config = parameter.config
+
+ def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
+ with_batch = isinstance(sentence, list)
+ ids = self.tokenizer.encode(sentence)
+ batch_ids = ids if with_batch else [ids]
+ max_model_len = self.config.m_len
+
+ all_fragments = []
+ fragment_origin_idx = []
+
+ for i, seq in enumerate(batch_ids):
+ if len(seq) > max_model_len:
+ fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
+ all_fragments.extend(fragments)
+ fragment_origin_idx.extend([i] * len(fragments))
+ else:
+ all_fragments.append(seq)
+ fragment_origin_idx.append(i)
+
+        # if there is nothing to encode, return an empty result
+ if not all_fragments or not ids:
+ return [] if with_batch else torch.tensor([])
+
+ device = next(self.model.parameters()).device
+ max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
+
+ padded_ids = []
+ masks = []
+ for seq in all_fragments:
+ pad_len = max_len - len(seq)
+ padded_seq = seq + [self.tokenizer.pad_id] * pad_len
+ mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
+ padded_ids.append(padded_seq)
+ masks.append(mask)
+
+ input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
+ seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
+
+ with torch.inference_mode():
+ outputs = self.model(input_tensor, seq_mask)["hidden_states"]
+ # [num_fragments, seq_len, hidden_size]
+ fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
+
+ sentence_embs: List[Tensor] = []
+ for i in range(len(batch_ids)):
+            indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
+            if indices:  # `is not None` was always true; an empty list must be skipped
+                sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1)  # [frags, hidden_size]
+                length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1)  # [frags, 1]
+                emb = torch.sum(sum_frags / length, dim=0)  # [hidden_size]
+                sentence_embs.append(emb.flatten())
+
+ if with_batch:
+ return [emb.flatten() for emb in sentence_embs]
+ else:
+ return sentence_embs[0].flatten()
+
+    def to(self, *args, **kwargs) -> Self:
+        self.model.to(*args, **kwargs)
+ return self
+
+
+class TextGenerator(GeneratorCore):
+ def __init__(self, parameter: ModelParameter):
+ super().__init__(parameter)
+
+ def generate(
+ self,
+ query: str,
+ temperature: float,
+ top_k: int,
+ top_p: float,
+ ) -> str:
+ assert temperature >= 0.0
+ assert top_k >= 0
+ assert top_p >= 0.0 and top_p <= 1.0
+
+ device = next(self.model.parameters()).device
+ cache_manager = KVCacheManager(
+ num_layers=self.config.n_layer,
+ batch_size=1,
+ max_len=self.config.m_len,
+ num_heads=self.config.n_kvhead,
+ head_dim=self.config.n_dim // self.config.n_head,
+ device=device,
+ )
+
+ ids = self.tokenizer.encode(query)
+ input_ids = torch.tensor([ids], device=device, dtype=torch.long)
+
+ start_cache_pos = len(ids)
+ cur_cache_pos = 0
+ self.model.eval()
+
+ while len(ids) < self.config.m_len:
+ kv_caches = cache_manager.get_kvcache()
+ logits, cache_increase = self.compute_logits(
+ input_ids,
+ kv_caches=kv_caches,
+ start_pos=cur_cache_pos
+ )
+ logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
+ probs = torch.softmax(logits, dim=-1)
+ next_token_id = torch.multinomial(probs, num_samples=1)
+
+ input_ids = next_token_id
+ ids.append(next_token_id.item())
+ cur_cache_pos += cache_increase
+
+ if next_token_id.item() in self.tokenizer.stop_ids:
+ break
+
+ response = self.tokenizer.decode(ids[start_cache_pos:])
+
+ return response
+
+
+
+class ChatGenerator(GeneratorCore):
+ def __init__(self, parameter: ModelParameter):
+ super().__init__(parameter)
+
+ def generate(
+ self,
+ query: str,
+ history: List[Tuple[str, str]],
+ temperature: float,
+ top_k: int,
+ top_p: float,
+    ) -> Tuple[str, List[Tuple[str, str]]]:
+
+ assert temperature >= 0.0
+ assert top_k >= 0
+ assert top_p >= 0.0 and top_p <= 1.0
+
+ if history is None:
+ history = []
+
+ device = next(self.model.parameters()).device
+ cache_manager = KVCacheManager(
+ num_layers=self.config.n_layer,
+ batch_size=1,
+ max_len=self.config.m_len,
+ num_heads=self.config.n_kvhead,
+ head_dim=self.config.n_dim // self.config.n_head,
+ device=device,
+ )
+ ids = self.tokenizer.encode(build_prompt(query, history))
+ input_ids = torch.tensor([ids], device=device, dtype=torch.long)
+ cpy_history = history.copy()
+
+ start_cache_pos = len(ids)
+ cur_cache_pos = 0
+ self.model.eval()
+
+
+ while len(ids) < self.config.m_len:
+ kv_caches = cache_manager.get_kvcache()
+ logits, cache_increase = self.compute_logits(
+ input_ids,
+ kv_caches=kv_caches,
+ start_pos=cur_cache_pos
+ )
+ logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
+ probs = torch.softmax(logits, dim=-1)
+ next_token_id = torch.multinomial(probs, num_samples=1)
+
+ input_ids = next_token_id
+ ids.append(next_token_id.item())
+ cur_cache_pos += cache_increase
+
+ if next_token_id.item() in self.tokenizer.stop_ids:
+ break
+
+ response = self.tokenizer.decode(ids[start_cache_pos:])
+ cpy_history.append((query, response))
+
+ return response, cpy_history
+
+
+class StreamGenerator(GeneratorCore):
+ def __init__(self, parameter: ModelParameter):
+ super().__init__(parameter)
+
+ def generate(
+ self,
+ query: str,
+ history: List[Tuple[str, str]],
+ temperature: float,
+ top_k: int,
+ top_p: float,
+ ) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
+
+ assert temperature >= 0.0
+ assert top_k >= 0
+ assert top_p >= 0.0 and top_p <= 1.0
+
+ if history is None:
+ history = []
+
+ device = next(self.model.parameters()).device
+ cache_manager = KVCacheManager(
+ num_layers=self.config.n_layer,
+ batch_size=1,
+ max_len=self.config.m_len,
+ num_heads=self.config.n_kvhead,
+ head_dim=self.config.n_dim // self.config.n_head,
+ device=device,
+ )
+ ids = self.tokenizer.encode(build_prompt(query, history))
+ input_ids = torch.tensor([ids], device=device, dtype=torch.long)
+ cpy_history = history.copy()
+
+ start_cache_pos = len(ids)
+ cur_cache_pos = 0
+ self.model.eval()
+
+
+ while len(ids) < self.config.m_len:
+ kv_caches = cache_manager.get_kvcache()
+ logits, cache_increase = self.compute_logits(
+ input_ids,
+ kv_caches=kv_caches,
+ start_pos=cur_cache_pos
+ )
+ logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
+ probs = torch.softmax(logits, dim=-1)
+ next_token_id = torch.multinomial(probs, num_samples=1)
+
+ input_ids = next_token_id
+ ids.append(next_token_id.item())
+ cur_cache_pos += cache_increase
+
+ response = self.tokenizer.decode(ids[start_cache_pos:])
+ yield response, cpy_history + [(query, response)]
+
+ if next_token_id.item() in self.tokenizer.stop_ids:
+ yield response + "\n", cpy_history + [(query, response)]
+ break
+
+
+class BatchGenerator(GeneratorCore):
+ def __init__(self, parameter: ModelParameter):
+ super().__init__(parameter)
+
+ def generate(
+ self,
+ queries: List[str],
+ histories: List[List[Tuple[str, str]]],
+ temperature: float,
+ top_k: int,
+ top_p: float
+ ) -> List[str]:
+
+ assert temperature >= 0.0
+ assert top_k >= 0
+ assert top_p >= 0.0 and top_p <= 1.0
+
+ batch_size = len(queries)
+ if histories is None:
+ histories = [[] for _ in range(batch_size)]
+
+ prompts = [build_prompt(query, history) for query, history in zip(queries, histories)]
+ ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
+ max_ids_len = max(len(ids) for ids in ids_list)
+ ids_list = pad_sequence(ids_list, max_ids_len, self.tokenizer.pad_id)
+
+ device = next(self.model.parameters()).device
+ cache_manager = KVCacheManager(
+ num_layers=self.config.n_layer,
+ batch_size=batch_size,
+ max_len=self.config.m_len,
+ num_heads=self.config.n_kvhead,
+ head_dim=self.config.n_dim // self.config.n_head,
+ device=device,
+ )
+
+ input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
+ cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
+ activate_task_mask = [True] * batch_size
+
+ start_cache_pos = max_ids_len
+ cur_cache_pos = 0
+
+ while max_ids_len < self.config.m_len and sum(activate_task_mask) != 0:
+ kv_caches = cache_manager.get_kvcache()
+            attn_mask = cache_manager.get_seq_mask()
+
+ logits, cache_increase = self.compute_logits(
+ input_tensor,
+ attn_mask=attn_mask,
+ kv_caches=kv_caches,
+ start_pos=cur_cache_pos
+ )
+
+ cur_cache_pos += cache_increase
+ logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
+ probs = torch.softmax(logits, dim=-1)
+ next_token_id = torch.multinomial(probs, num_samples=1)
+
+ active_mask = []
+ c_ids = 0
+
+ for i in range(batch_size):
+ if activate_task_mask[i]:
+ token = next_token_id[c_ids, :].item()
+ ids_list[i].append(token)
+ c_ids += 1
+
+                    is_active = token not in self.tokenizer.stop_ids
+ activate_task_mask[i] = is_active
+ active_mask.append(is_active)
+
+ active_mask = torch.tensor(active_mask, device=device, dtype=torch.bool)
+ cache_manager.update(active_mask)
+ input_tensor = next_token_id[active_mask, :]
+
+ max_ids_len += 1
+
+
+ responses = [str()] * batch_size
+ for i in range(batch_size):
+ responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:])
+ histories[i].append((queries[i], responses[i]))
+
+ return responses
+
+
+
+class RetrievalGenerator(GeneratorCore):
+ def __init__(self, retriever_parameter: ModelParameter):
+ super().__init__(retriever_parameter)
+
+ def generate(
+ self,
+ retrieved: List[str],
+ query: str,
+ history: List[Tuple[str, str]],
+ temperature: float,
+ top_k: int,
+ top_p: float,
+    ) -> Tuple[str, List[Tuple[str, str]]]:
+ assert temperature >= 0.0
+ assert top_k >= 0
+ assert top_p >= 0.0 and top_p <= 1.0
+
+ if history is None:
+ history = []
+
+ retrieved = "\n".join([f"{idx + 1}. {key}" for idx, key in enumerate(retrieved)]) if retrieved else ""
+ retrieved_query = f"{retrieved}\n\n根据以上内容回答: {query}" if retrieved else query
+ parameter = ModelParameter(self.model, self.tokenizer, self.config)
+
+ return ChatGenerator(parameter).generate(
+ retrieved_query,
+ history,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ )
+
+class EmbeddingEncoder(EmbeddingEncoderCore):
+ def __init__(self, parameter: ModelParameter):
+ super().__init__(parameter)
+
+ def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
+ return super().encode(sentence)
+
\ No newline at end of file
diff --git a/khaosz/core/parameter.py b/khaosz/core/parameter.py
new file mode 100644
index 0000000..f752792
--- /dev/null
+++ b/khaosz/core/parameter.py
@@ -0,0 +1,238 @@
+import pickle as pkl
+import matplotlib.pyplot as plt
+import safetensors.torch as st
+import torch.nn as nn
+import torch.optim as optim
+
+from dataclasses import dataclass, field
+from typing import Optional, Self, Union
+from pathlib import Path
+
+from khaosz.core.tokenizer import BpeTokenizer
+from khaosz.core.transformer import TransformerConfig, Transformer
+
+
+class BaseModelIO:
+ """Base class for model I/O operations."""
+
+ def __init__(
+ self,
+ model: Optional[nn.Module] = None,
+ tokenizer: Optional[BpeTokenizer] = None,
+ config: Optional[TransformerConfig] = None
+ ):
+ self.model = model
+ self.tokenizer = tokenizer or BpeTokenizer()
+ self.config = config or TransformerConfig()
+
+ def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
+ """Get standardized file paths for model components."""
+ dir_path = Path(directory)
+ return {
+ "model": dir_path / "model.safetensors",
+ "config": dir_path / "config.json",
+ "tokenizer": dir_path / "tokenizer.json"
+ }
+
+ def save_components(self, save_dir: Union[str, Path]):
+ """Save core model components."""
+ paths = self._get_file_paths(save_dir)
+ paths["model"].parent.mkdir(parents=True, exist_ok=True)
+
+ if self.model is not None:
+ st.save_file(self.model.state_dict(), str(paths["model"]))
+ self.config.save(str(paths["config"]))
+ self.tokenizer.save(str(paths["tokenizer"]))
+
+ def load_components(self, load_dir: Union[str, Path]) -> Self:
+ """Load core model components."""
+ paths = self._get_file_paths(load_dir)
+
+ self.config.load(str(paths["config"]))
+ self.tokenizer.load(str(paths["tokenizer"]))
+
+ if paths["model"].exists():
+ state_dict = st.load_file(str(paths["model"]))
+ if self.model is None:
+ self.model = Transformer(self.config)
+ self.model.load_state_dict(state_dict)
+
+ return self
+
+ def to(self, *args, **kwargs) -> Self:
+ """Move model to device."""
+ if self.model is not None:
+ self.model.to(*args, **kwargs)
+ return self
+
+
+@dataclass
+class ModelParameter(BaseModelIO):
+ """Container for model parameters with serialization capabilities."""
+
+ model: Optional[nn.Module] = field(
+ default=None,
+ metadata={"help": "Transformer model."}
+ )
+ tokenizer: BpeTokenizer = field(
+ default_factory=BpeTokenizer,
+ metadata={"help": "Tokenizer for the model."}
+ )
+ config: TransformerConfig = field(
+ default_factory=TransformerConfig,
+ metadata={"help": "Transformer model configuration."}
+ )
+
+ def save(self, save_dir: Union[str, Path]):
+ """Save model parameters."""
+ self.save_components(save_dir)
+
+ def load(self, load_dir: Union[str, Path]) -> Self:
+ """Load model parameters."""
+ return self.load_components(load_dir)
+
+
+@dataclass
+class Checkpoint(BaseModelIO):
+ """Extended model parameters with training state."""
+
+ model: Optional[nn.Module] = field(
+ default=None,
+ metadata={"help": "Transformer model."}
+ )
+ tokenizer: BpeTokenizer = field(
+ default_factory=BpeTokenizer,
+ metadata={"help": "Tokenizer for the model."}
+ )
+ config: TransformerConfig = field(
+ default_factory=TransformerConfig,
+ metadata={"help": "Transformer model configuration."}
+ )
+ loss_list: list[float] = field(
+ default_factory=list,
+ metadata={"help": "List of training losses."}
+ )
+ current_iter: int = field(
+ default=0,
+ metadata={"help": "Current training iteration."}
+ )
+ optimizer: Optional[optim.Optimizer] = field(
+ default=None,
+ metadata={"help": "Optimizer state."}
+ )
+
+ def __post_init__(self):
+ # Ensure current_iter matches loss list length if not explicitly set
+ if self.current_iter == 0 and self.loss_list:
+ self.current_iter = len(self.loss_list)
+
+ def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
+ """Get file paths for training-specific files."""
+ paths = self._get_file_paths(directory)
+ paths.update({
+ "loss_list": paths["model"].parent / "loss.pkl",
+ "loss_plot": paths["model"].parent / "loss.png",
+ "optimizer": paths["model"].parent / "optimizer.pkl"
+ })
+ return paths
+
+ def save_training_state(self, save_dir: Union[str, Path]):
+ """Save training-specific state."""
+ paths = self._get_training_paths(save_dir)
+
+ # Save loss plot
+ self._plot_loss(str(paths["loss_plot"]))
+
+ # Save loss list
+ with open(str(paths["loss_list"]), "wb") as f:
+ pkl.dump(self.loss_list, f)
+
+ # Save optimizer state
+ if self.optimizer is not None:
+ with open(str(paths["optimizer"]), "wb") as f:
+ pkl.dump(self.optimizer.state_dict(), f)
+
+ def load_training_state(self, load_dir: Union[str, Path]) -> Self:
+ """Load training-specific state."""
+ paths = self._get_training_paths(load_dir)
+
+ # Load loss list
+ if paths["loss_list"].exists():
+ with open(str(paths["loss_list"]), "rb") as f:
+ self.loss_list = pkl.load(f)
+ self.current_iter = len(self.loss_list)
+
+ # Load optimizer state
+ if paths["optimizer"].exists() and self.optimizer is not None:
+ with open(str(paths["optimizer"]), "rb") as f:
+ optim_state = pkl.load(f)
+ self.optimizer.load_state_dict(optim_state)
+
+ return self
+
+ def _plot_loss(self, save_path: str):
+ """Plot and save loss curve."""
+ if not self.loss_list:
+ return
+
+ plt.figure(figsize=(10, 6))
+ plt.plot(self.loss_list)
+ plt.title(f"Training Loss - Iteration {self.current_iter}")
+ plt.xlabel("Batch")
+ plt.ylabel("Loss")
+ plt.grid(True)
+ plt.savefig(save_path, dpi=300, bbox_inches="tight")
+ plt.close()
+
+ def save(self, save_dir: Union[str, Path]):
+ """Save complete checkpoint."""
+ self.save_components(save_dir)
+ self.save_training_state(save_dir)
+
+ def load(self, load_dir: Union[str, Path]) -> Self:
+ """Load complete checkpoint."""
+ self.load_components(load_dir)
+ self.load_training_state(load_dir)
+ return self
+
+
+class ParameterLoader:
+ """Factory class for loading model parameters or checkpoints."""
+
+ @staticmethod
+ def load(load_dir: Union[str, Path]) -> Union[ModelParameter, Checkpoint]:
+ """Load either ModelParameter or Checkpoint based on directory contents."""
+ load_dir = Path(load_dir)
+
+ # Check for training-specific files
+ loss_file = load_dir / "loss.pkl"
+ has_training_data = loss_file.exists()
+
+ # Create appropriate instance
+ if has_training_data:
+ checkpoint = Checkpoint()
+ checkpoint.load(str(load_dir))
+ return checkpoint
+ else:
+ params = ModelParameter()
+ params.load(str(load_dir))
+ return params
+
+ @staticmethod
+ def create_checkpoint(
+ model: nn.Module,
+ tokenizer: BpeTokenizer,
+ config: TransformerConfig,
+ loss_list: Optional[list[float]] = None,
+ optimizer: Optional[optim.Optimizer] = None
+ ) -> Checkpoint:
+ """Convenience method to create a training checkpoint."""
+ return Checkpoint(
+ model=model,
+ tokenizer=tokenizer,
+ config=config,
+ loss_list=loss_list or [],
+ optimizer=optimizer
+ )
+
+
diff --git a/khaosz/core/tokenizer.py b/khaosz/core/tokenizer.py
new file mode 100644
index 0000000..1f0f1ed
--- /dev/null
+++ b/khaosz/core/tokenizer.py
@@ -0,0 +1,111 @@
+from tokenizers import Tokenizer, Encoding
+from tokenizers import decoders, processors, normalizers, pre_tokenizers
+from tokenizers.models import BPE
+from tokenizers.trainers import BpeTrainer
+from typing import List, Union
+
+
+class BpeTokenizer:
+    def __init__(self, path=None):
+        # "<bos>", "<eos>" and "<pad>" are assumed surface forms for the three
+        # control tokens (BOS/EOS/PAD); the bos_id/eos_id/pad_id properties below
+        # look them up by these names.
+        self._control_tokens = ["<bos>", "<eos>", "<pad>"]
+        self._special_tokens = ["<|user|>", "<|system|>"]
+ model = BPE()
+ tokenizer = Tokenizer(model)
+ tokenizer.normalizer = normalizers.Sequence([
+ normalizers.NFC()
+ ])
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
+ pre_tokenizers.Punctuation(behavior="isolated"),
+ pre_tokenizers.Metaspace(prepend_scheme="never"),
+ pre_tokenizers.Split(pattern=r"(\d+|[a-zA-Z]+|(?:'s|'t|'re|'ve|'m|'ll|'d))", behavior="isolated"),
+ pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
+ ])
+ tokenizer.decoder = decoders.Sequence([
+ decoders.ByteLevel(),
+ decoders.Metaspace(prepend_scheme="never")
+ ])
+ tokenizer.post_processor = processors.Sequence([
+ processors.ByteLevel(trim_offsets=False)
+ ])
+ self._tokenizer = tokenizer
+
+ if path is not None:
+ self._tokenizer = Tokenizer.from_file(path)
+
+ def _prepare_trainer(self, vocab_size: int, min_freq: int, reserved_token_size: int) -> tuple:
+ assert reserved_token_size > len(self._special_tokens)
+ reserved_tokens = [f"<|rsv{i:02d}|>" for i in range(reserved_token_size - len(self._special_tokens))]
+ detail_vocab_size = vocab_size - (len(reserved_tokens) + len(self._special_tokens))
+
+ alphabet = pre_tokenizers.ByteLevel.alphabet()
+ min_size = len(alphabet) + len(self._control_tokens)
+ assert detail_vocab_size > min_size
+
+ trainer = BpeTrainer(
+ vocab_size=detail_vocab_size,
+ min_frequency=min_freq,
+ limit_alphabet=detail_vocab_size // 4,
+ max_token_length=18,
+ special_tokens=self._control_tokens,
+ show_progress=True,
+ initial_alphabet=alphabet,
+ )
+
+ return trainer, detail_vocab_size, reserved_tokens
+
+ def train(self, files, vocab_size, min_freq, reserved_token_size=100):
+ trainer, _, reserved_tokens = self._prepare_trainer(
+ vocab_size=vocab_size,
+ min_freq=min_freq,
+ reserved_token_size=reserved_token_size
+ )
+ self._tokenizer.train(files=files, trainer=trainer)
+ self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
+
+ def train_from_iterator(self, iterator, vocab_size, min_freq, reserved_token_size=100):
+ trainer, _, reserved_tokens = self._prepare_trainer(
+ vocab_size=vocab_size,
+ min_freq=min_freq,
+ reserved_token_size=reserved_token_size
+ )
+ self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
+ self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
+
+ def save(self, path):
+ self._tokenizer.save(path)
+
+ def load(self, path):
+ self._tokenizer = Tokenizer.from_file(path)
+
+ def encode(self, tokens: Union[str, List[str]], out_ids: bool=True, add_special_tokens: bool=False) -> List:
+ if isinstance(tokens, str):
+ encoded: Encoding = self._tokenizer.encode(tokens, add_special_tokens=add_special_tokens)
+ return encoded.ids if out_ids else encoded.tokens
+ elif isinstance(tokens, list):
+ encoded_list: List[Encoding] = self._tokenizer.encode_batch(tokens, add_special_tokens=add_special_tokens)
+ return [encoded.ids if out_ids else encoded.tokens for encoded in encoded_list]
+
+ def decode(self, tokens: List[int], skip_special_tokens: bool=True) -> str:
+ return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
+
+ def __len__(self) -> int:
+ return self._tokenizer.get_vocab_size()
+
+ @property
+ def stop_ids(self) -> List[int]:
+ stop_ids = []
+ for token in self._control_tokens:
+ stop_ids.append(self._tokenizer.token_to_id(token))
+ return stop_ids
+
+    @property
+    def bos_id(self) -> int:
+        return self._tokenizer.token_to_id("<bos>")
+
+    @property
+    def eos_id(self) -> int:
+        return self._tokenizer.token_to_id("<eos>")
+
+    @property
+    def pad_id(self) -> int:
+        return self._tokenizer.token_to_id("<pad>")
diff --git a/khaosz/core/transformer.py b/khaosz/core/transformer.py
new file mode 100644
index 0000000..b5a2d66
--- /dev/null
+++ b/khaosz/core/transformer.py
@@ -0,0 +1,341 @@
+import json
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch import Tensor
+from torch.nn import init
+from dataclasses import asdict, dataclass
+from typing import List, Optional, Self, Tuple
+
+
+def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
+ """
+    Repeat the KV heads n_rep times along the head dimension (used for GQA).
+ Args:
+ x (Tensor): The input tensor.
+ n_rep (int): The number of repetitions.
+ Returns:
+ Tensor: The repeated tensor.
+ """
+
+ bs, slen, n_heads, head_dim = x.shape
+ if n_rep == 1:
+ return x
+ return (
+ x[:, :, :, None, :]
+ .expand(bs, slen, n_heads, n_rep, head_dim)
+ .reshape(bs, slen, n_heads * n_rep, head_dim)
+ )
+
+def get_rotary_emb(
+ dim: int,
+ max_len: int,
+ base: float = 10000,
+ device: torch.device = "cuda",
+ ) -> torch.Tensor:
+ """
+ Get the rotary embedding for the given dimension and maximum length.
+ Args:
+ dim (int): The dimension of the input.
+ max_len (int): The maximum length of the input.
+ base (float, optional): The base for the frequency. Defaults to 10000.
+ device (torch.device, optional): The device to use. Defaults to "cuda".
+ Returns:
+ Tensor: The rotary embedding tensor.
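+    Example (illustrative sizes; the result is a complex tensor):
+        >>> freqs_cis = get_rotary_emb(dim=8, max_len=16, device="cpu")
+        >>> freqs_cis.shape  # (max_len, dim // 2)
+        torch.Size([16, 4])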
+ """
+
+ theta = base ** (-torch.arange(0, dim, 2, device=device).float() / dim)
+ t = torch.arange(0, max_len, device=device).float()
+ freqs = torch.outer(t, theta)
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+
+ return freqs_cis
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+ """
+ Apply rotary embedding to the input tensor.
+ Args:
+ x (Tensor): The input tensor.
+ freqs_cis (Tensor): The rotary embedding tensor.
+ Returns:
+ Tensor: The output tensor.
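+    Example (illustrative sizes; the input shape is preserved):
+        >>> q = torch.randn(2, 16, 4, 8)  # (bsz, seq_len, n_heads, head_dim)
+        >>> freqs_cis = get_rotary_emb(8, 16, device="cpu")
+        >>> apply_rotary_emb(q, freqs_cis).shape
+        torch.Size([2, 16, 4, 8])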
+ """
+
+ dtype = x.dtype
+ seq_len = x.size(1)
+
+ x_complex = torch.view_as_complex(x.view(*x.shape[:-1], -1, 2).float())
+ freqs_cis = freqs_cis.reshape(1, seq_len, 1, -1)
+ x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3)
+
+ return x_out.to(dtype)
+
+def create_attention_mask(
+ seq_mask: Tensor,
+ start_pos: int = 0,
+ seq_len: int = 0,
+ is_causal: bool = False,
+ device: torch.device = "cuda",
+ dtype: torch.dtype = torch.float32
+ ) -> Tensor:
+ """
+ Create attention mask for GQA
+ Args:
+ seq_mask (Tensor): A tensor indicating whether each position is valid or not.
+ start_pos (int): The starting position of the sequence.
+ seq_len (int): The length of the sequence.
+ is_causal (bool): Whether the attention is causal or not.
+        device (torch.device): The device to use.
+        dtype (torch.dtype): The dtype of the additive mask. Defaults to torch.float32.
+ Returns:
+ Tensor: The attention mask tensor.
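+    Example (illustrative; a 3-token batch whose last position is padding):
+        >>> seq_mask = torch.tensor([[True, True, False]])
+        >>> mask = create_attention_mask(seq_mask, seq_len=3, is_causal=True, device="cpu")
+        >>> mask.shape  # (bsz, 1, seq_len, start_pos + seq_len)
+        torch.Size([1, 1, 3, 3])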
+ """
+
+ if start_pos != 0 and seq_mask is None:
+ # for single prompt chat
+ seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
+
+ if seq_mask is None:
+ return None
+
+ batch_size = seq_mask.size(0)
+ seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
+ # (bsz, start_pos + seq_len)
+ expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len)
+ # (bsz, seq_len, start_pos + seq_len)
+
+ if is_causal:
+ causal_mask = torch.tril(
+ torch.ones((seq_len, start_pos + seq_len), dtype=torch.bool, device=device),
+ diagonal=start_pos
+ )
+ causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, start_pos + seq_len)
+ expanded_mask = expanded_mask & causal_mask
+
+ attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
+ attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
+ # (bsz, 1, seq_len, seq_len + start_pos)
+
+ return attention_mask
+
+
+@dataclass
+class TransformerConfig:
+ # basic config
+ vocab_size: Optional[int] = None
+ n_dim: Optional[int] = None
+ n_head: Optional[int] = None
+ n_layer: Optional[int] = None
+ m_len: Optional[int] = None
+ norm_eps: Optional[float] = None
+ d_ffn: Optional[int] = None
+
+ # GQA
+ n_kvhead: Optional[int] = None
+
+
+ def load(self, config_path: str) -> Self:
+ with open(config_path, 'r') as f:
+ config: dict = json.load(f)
+ for key, value in config.items():
+ if hasattr(self, key):
+ setattr(self, key, value)
+
+ return self
+
+ def save(self, config_path: str) -> None:
+ config_dict = asdict(self)
+ config_dict = {k: v for k, v in config_dict.items() if v is not None}
+ with open(config_path, 'w') as f:
+ json.dump(config_dict, f, indent=4)
+
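+# Example config.json consumed by TransformerConfig.load (values are
+# illustrative; the unit tests use a similarly small configuration):
+# {"vocab_size": 1000, "n_dim": 128, "n_head": 4, "n_kvhead": 2,
+#  "d_ffn": 256, "m_len": 64, "n_layer": 2, "norm_eps": 1e-5}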
+
+class Linear(nn.Module):
+ def __init__(self, in_dim: int, out_dim: int, bias: bool=False):
+ super().__init__()
+ self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
+ self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
+ init.normal_(self.weight, mean=0, std=0.006)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return F.linear(x, self.weight, self.bias)
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, n_dim, norm_eps):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(n_dim))
+ self.norm_eps = norm_eps
+
+ def forward(self, x: Tensor) -> Tensor:
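+        # RMSNorm, computed in float32 for numerical stability:
+        #   y = x / sqrt(mean(x**2, dim=-1) + eps) * weight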
+ dtype = x.dtype
+ x = x.float()
+ mean_square = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
+ norm = x * torch.rsqrt(mean_square + self.norm_eps)
+ norm = norm.to(dtype)
+ out = norm * self.weight
+ return out
+
+
+class MLP(nn.Module):
+ def __init__(self, n_dim: int, d_ffn: int):
+ super().__init__()
+ self.up = Linear(n_dim, d_ffn)
+ self.gate = Linear(n_dim, d_ffn)
+ self.down = Linear(d_ffn, n_dim)
+
+ def forward(self, x: Tensor) -> Tensor:
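+        # SwiGLU-style gating: down(up(x) * SiLU(gate(x)))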
+ gated = self.up(x) * F.silu(self.gate(x))
+ out = self.down(gated)
+ return out
+
+
+class GQA(nn.Module):
+ def __init__(
+ self,
+ n_dim: int,
+ n_head: int,
+ n_kvhead: int,
+ ):
+ super().__init__()
+ assert n_dim % n_head == 0
+ assert n_head % n_kvhead == 0
+
+ self.head_dim = n_dim // n_head
+ self.n_dim = n_dim
+ self.n_heads = n_head
+ self.n_kvheads = n_kvhead
+ self.n_rep = n_head // n_kvhead
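+        # grouped-query attention: each group of n_rep query heads
+        # shares a single key/value head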
+
+ self.q_proj = Linear(n_dim, n_head * self.head_dim)
+ self.k_proj = Linear(n_dim, n_kvhead * self.head_dim)
+ self.v_proj = Linear(n_dim, n_kvhead * self.head_dim)
+ self.o_proj = Linear(n_dim, n_dim)
+
+ def forward(
+ self,
+ x: Tensor,
+ freqs_cis: Tensor,
+ mask: Tensor = None,
+ kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
+ start_pos: int = 0
+ ) -> Tensor:
+ bsz, seq_len, _ = x.size()
+ # x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
+ q = self._split_heads(self.q_proj(x), self.n_heads)
+ k = self._split_heads(self.k_proj(x), self.n_kvheads)
+ v = self._split_heads(self.v_proj(x), self.n_kvheads)
+ q, k = apply_rotary_emb(q, freqs_cis), apply_rotary_emb(k, freqs_cis)
+
+ if kv_cache is not None:
+ k_cache, v_cache = kv_cache
+
+ # copy to cache
+ k_cache[:bsz, start_pos:start_pos + seq_len] = k
+ v_cache[:bsz, start_pos:start_pos + seq_len] = v
+
+ # get cache
+ k = k_cache[:bsz, :start_pos + seq_len]
+ v = v_cache[:bsz, :start_pos + seq_len]
+
+ k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
+
+ # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
+ q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
+        sdpa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=(mask is None)).permute(0, 2, 1, 3)
+        out = self.o_proj(sdpa_out.contiguous().view(bsz, seq_len, -1))
+
+ return out
+
+ def _split_heads(self, x: Tensor, n_heads) -> Tensor:
+ batch_size, seq_len, _ = x.shape
+ x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
+ return x
+
+
+class DecoderBlock(nn.Module):
+ def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps):
+ super().__init__()
+ self.attention = GQA(n_dim, n_head, n_kvhead)
+ self.norm_attn = RMSNorm(n_dim, norm_eps)
+ self.ffn = MLP(n_dim, d_ffn)
+ self.norm_ffn = RMSNorm(n_dim, norm_eps)
+
+ def forward(
+ self,
+ x: Tensor,
+ freqs_cis: Tensor,
+ attention_mask: Optional[Tensor] = None,
+ kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
+ start_pos: int = 0
+ ) -> Tensor:
+ # attention
+ attn_output = self.attention(
+ self.norm_attn(x),
+ freqs_cis,
+ attention_mask,
+ kv_cache,
+ start_pos
+ )
+ x = attn_output + x
+
+ # feed forward
+ x = self.ffn(self.norm_ffn(x)) + x
+
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(self, config: TransformerConfig):
+ super().__init__()
+ self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
+ self.layers = nn.ModuleList([
+ DecoderBlock(
+ config.n_dim,
+ config.n_head,
+ config.d_ffn,
+ config.n_kvhead,
+ config.norm_eps
+ )
+ for _ in range(config.n_layer)
+ ])
+ self.norm = RMSNorm(config.n_dim, config.norm_eps)
+ self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
+ init.normal_(self.embedding, mean=0, std=0.02)
+
+ def forward(
+ self,
+ input_ids: Tensor,
+ seq_mask: Optional[Tensor]=None,
+ persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
+ start_pos: int = 0
+    ) -> Dict[str, Tensor]:
+ assert input_ids.ndim == 2
+ seq_len = input_ids.size(-1)
+ x = F.embedding(input_ids, self.embedding)
+
+ self.freq_cis = self.freq_cis.to(x.device)
+ freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
+ has_kvcache = persistent_key_values is not None
+
+ attn_mask = create_attention_mask(
+ seq_mask,
+ start_pos=start_pos,
+ seq_len=seq_len,
+ is_causal=has_kvcache,
+ device=x.device,
+ dtype=x.dtype
+ )
+
+ for i, layer in enumerate(self.layers):
+ kv_cache = persistent_key_values[i] if persistent_key_values else None
+ x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos)
+
+ hidden_states = self.norm(x)
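+        # weight-tied output projection: logits reuse the embedding matrix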
+ logits = F.linear(hidden_states, self.embedding)
+
+ return {
+ "logits": logits,
+ "hidden_states": hidden_states
+ }
+
\ No newline at end of file
diff --git a/khaosz/model.py b/khaosz/model.py
new file mode 100644
index 0000000..3723f29
--- /dev/null
+++ b/khaosz/model.py
@@ -0,0 +1,112 @@
+from torch import Tensor
+from typing import List, Tuple, Generator, Union
+
+from khaosz.core.generator import (
+ TextGenerator,
+ ChatGenerator,
+ StreamGenerator,
+ BatchGenerator,
+ RetrievalGenerator,
+ EmbeddingEncoder
+)
+from khaosz.core.parameter import ParameterLoader
+
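+# Typical usage (paths and sampling values are illustrative):
+#
+#   model = Khaosz("params").to(device="cuda", dtype=torch.bfloat16)
+#   reply = model.generate("hello", history=[], temperature=0.8)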
+
+class Khaosz:
+ def __init__(self, model_dir: str):
+ self.parameter = ParameterLoader.load(model_dir)
+
+ def to(self, *args, **kwargs):
+ self.parameter.to(*args, **kwargs)
+ return self
+
+ def generate(
+ self,
+ query: str,
+ history: List[Tuple[str, str]]=None,
+ temperature: float=0.8,
+ top_k: int=50,
+ top_p: float=0.95,
+ ) -> str:
+ generator = ChatGenerator(self.parameter)
+ return generator.generate(
+ query,
+ history=history,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ )
+
+ def batch_generate(
+ self,
+ queries: List[str],
+ histories: List[Tuple[str, str]]=None,
+ temperature: float=0.8,
+ top_k: int=50,
+ top_p: float=0.95,
+ ) -> List[str]:
+ generator = BatchGenerator(self.parameter)
+ return generator.generate(
+ queries,
+ histories=histories,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ )
+
+
+ def stream_generate(
+ self,
+ query: str,
+ history: List[Tuple[str, str]]=None,
+ temperature: float=0.8,
+ top_k: int=50,
+ top_p: float=0.95,
+ ) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
+ stream_generator = StreamGenerator(self.parameter)
+ return stream_generator.generate(
+ query,
+ history=history,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ )
+
+ def retrieve_generate(
+ self,
+ retrieved,
+ query: str,
+ history: List[Tuple[str, str]] = None,
+ temperature: float=0.8,
+ top_k: int=50,
+ top_p: float=0.95,
+ ) -> str:
+ generator = RetrievalGenerator(self.parameter)
+ return generator.generate(
+ retrieved,
+ query,
+ history=history,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ )
+
+ def text_generate(
+ self,
+ query: str,
+ temperature: float=0.8,
+ top_k: int=50,
+ top_p: float=0.95,
+ ) -> str:
+ generator = TextGenerator(self.parameter)
+
+ return generator.generate(
+ query,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ )
+
+ def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
+ encoder = EmbeddingEncoder(self.parameter)
+ return encoder.encode(sentence)
\ No newline at end of file
diff --git a/khaosz/trainer/__init__.py b/khaosz/trainer/__init__.py
new file mode 100644
index 0000000..7630e6a
--- /dev/null
+++ b/khaosz/trainer/__init__.py
@@ -0,0 +1,11 @@
+from khaosz.trainer.dataset import DatasetLoader
+from khaosz.trainer.trainer import Trainer
+from khaosz.trainer.strategy import TrainConfig, CosineScheduleConfig, SgdrScheduleConfig
+
+__all__ = [
+ "DatasetLoader",
+ "Trainer",
+ "TrainConfig",
+ "CosineScheduleConfig",
+ "SgdrScheduleConfig",
+]
\ No newline at end of file
diff --git a/khaosz/trainer/dataset.py b/khaosz/trainer/dataset.py
new file mode 100644
index 0000000..439567d
--- /dev/null
+++ b/khaosz/trainer/dataset.py
@@ -0,0 +1,210 @@
+import torch
+import bisect
+import pickle as pkl
+from abc import ABC, abstractmethod
+from torch import Tensor
+from torch.utils.data import Dataset
+from typing import Callable, List, Dict, Literal, Tuple, Union
+
+MultiSeg = Dict[str, List[Tensor]]
+Seg = Dict[str, Tensor]
+
+def load_pkl_files(paths: List[str]):
+    segments: MultiSeg = {}
+ total_samples = 0
+
+ for path in paths:
+ with open(path, "rb") as f:
+ pkl_file: Seg = pkl.load(f)
+ for key, value in pkl_file.items():
+ if key not in segments:
+ segments[key] = []
+ segments[key].append(value)
+ first_key = list(pkl_file.keys())[0]
+ total_samples += pkl_file[first_key].numel()
+
+ return segments, total_samples
+
+
+class BaseSegmentFetcher:
+ def __init__(self, segments: List[Tensor]):
+ self.segments = segments
+ self.cum_lengths = []
+ total = 0
+ for seg in segments:
+ total += len(seg)
+ self.cum_lengths.append(total)
+ self.total_length = total if segments else 0
+
+ def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
+ if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
+ raise ValueError("begin_idx or end_idx out of bounds")
+ if begin_idx >= end_idx:
+ return torch.tensor([], dtype=torch.long)
+
+        # locate the segments containing begin_idx and end_idx - 1
+        seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
+        seg_end_idx = bisect.bisect_right(self.cum_lengths, end_idx - 1)
+
+ result_segments = []
+
+ for i in range(seg_start_idx, seg_end_idx + 1):
+ prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
+ start = max(begin_idx - prev_cum, 0)
+ end = min(end_idx - prev_cum, len(self.segments[i]))
+ result_segments.append(self.segments[i][start:end])
+
+ return torch.cat(result_segments, dim=0)
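+
+    # Example: with segment lengths [3, 3], cum_lengths == [3, 6] and
+    # fetch_data(2, 4) concatenates segments[0][2:3] with segments[1][0:1].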
+
+
+class MultiSegmentFetcher:
+    def __init__(self, multi_segments: MultiSeg):
+        self.multi_keys = list(multi_segments.keys())
+        self.multi_fetchers = {
+            key: BaseSegmentFetcher(segments)
+            for key, segments in multi_segments.items()
+        }
+
+ def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Union[Tensor, Seg]:
+ fetch_dict = {}
+ keys = [keys] if isinstance(keys, str) else keys
+
+ for key in keys:
+            fetcher = self.multi_fetchers[key]
+ fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
+ fetch_dict[key] = fetch_tensor
+
+ return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
+
+ def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Seg]:
+        return self.key_fetch(begin_idx, end_idx, self.multi_keys)
+
+
+class BaseDataset(Dataset, ABC):
+ def __init__(self, chunk_size: int, device: str):
+ super().__init__()
+        self.segments: MultiSeg = {}
+ self.chunk_size = chunk_size
+ self.total_samples = 0
+ self.device = device
+
+    def save(self, save_path: str):
+        keys = list(self.segments.keys())
+        first_item = self.segments[keys[0]]
+        segment_size = len(first_item)
+
+        for i in range(segment_size):
+            formatted_segment = {key: self.segments[key][i] for key in keys}
+            with open(f"{save_path}_{i}.pkl", "wb") as f:
+                pkl.dump(formatted_segment, f)
+
+
+ def load(self, load_path: Union[str, List[str]]):
+ paths = [load_path] if isinstance(load_path, str) else load_path
+ self.segments, self.total_samples = load_pkl_files(paths)
+        self.fetcher = MultiSegmentFetcher(self.segments)
+
+ @abstractmethod
+ def __getitem__(self, index: int):
+ raise NotImplementedError
+
+ def __len__(self) -> int:
+ assert self.total_samples // self.chunk_size > 0
+ return self.total_samples // self.chunk_size
+
+
+class SeqDataset(BaseDataset):
+ def __init__(self, chunk_size , device='cuda'):
+ super().__init__(chunk_size, device)
+        self.fetcher = MultiSegmentFetcher(self.segments)
+
+ def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
+ return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
+
+ def __getitem__(self, index):
+ begin_idx = index * self.chunk_size
+ end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
+
+ x = self._fetch_data(begin_idx, end_idx).to(device=self.device, dtype=torch.long)
+ y = self._fetch_data(begin_idx + 1, end_idx + 1).to(device=self.device, dtype=torch.long)
+
+ return x, y
+
+
+class SftDataset(BaseDataset):
+ def __init__(self, chunk_size, device='cuda'):
+ super().__init__(chunk_size, device)
+        self.fetcher = MultiSegmentFetcher(self.segments)
+
+ def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
+ return self.fetcher.key_fetch(begin_idx, end_idx, key)
+
+ def __getitem__(self, index):
+ begin_idx = index * self.chunk_size
+ end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
+
+ x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long)
+ y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long)
+ loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "mask").to(device=self.device, dtype=torch.bool)
+
+ return x, y, loss_mask
+
+
+class DpoDataset(BaseDataset):
+ def __init__(self, chunk_size: int, device="cuda"):
+ super().__init__(chunk_size, device)
+        self.fetcher = MultiSegmentFetcher(self.segments)
+
+ def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
+ return self.fetcher.key_fetch(begin_idx, end_idx, key)
+
+ def __getitem__(self, index: int):
+ start_idx = index * self.chunk_size
+ end_idx = min(start_idx + self.chunk_size, self.total_samples - 1)
+
+ chosen = self._fetch_data(start_idx, end_idx, "chosen").to(device=self.device, dtype=torch.long)
+ rejected = self._fetch_data(start_idx, end_idx, "rejected").to(device=self.device, dtype=torch.long)
+ chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(device=self.device, dtype=torch.bool)
+ rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(device=self.device, dtype=torch.bool)
+
+ return chosen, rejected, chosen_mask, rejected_mask
+
+
+class PpoDataset(BaseDataset):
+ def __init__(self, chunk_size: int, device="cuda"):
+ super().__init__(chunk_size, device)
+        self.fetcher = MultiSegmentFetcher(self.segments)
+
+ def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
+ return self.fetcher.key_fetch(begin_idx, end_idx, key)
+
+    def __getitem__(self, index: int) -> Tuple[Tensor, ...]:
+        begin_idx = index * self.chunk_size
+        end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
+
+        # note: no trailing commas here, otherwise each value becomes a 1-tuple
+        input_ids = self._fetch_data(begin_idx, end_idx, "input_ids").to(self.device)
+        actions = self._fetch_data(begin_idx, end_idx, "actions").to(self.device)
+        logprobs = self._fetch_data(begin_idx, end_idx, "logprobs").to(self.device)
+        rewards = self._fetch_data(begin_idx, end_idx, "rewards").to(self.device)
+
+        return input_ids, actions, logprobs, rewards
+
+
+class DatasetLoader:
+ @staticmethod
+ def load(
+ train_type: Literal["seq", "sft", "dpo"],
+ load_path: Union[str, List[str]],
+ max_len: int,
+ device: str
+ ) -> BaseDataset:
+
+ dataset_router: Dict[str, Callable[[int, torch.device], BaseDataset]] = {
+ "seq": lambda m_len, device: SeqDataset(m_len, device=device),
+ "sft": lambda m_len, device: SftDataset(m_len, device=device),
+ "dpo": lambda m_len, device: DpoDataset(m_len, device=device),
+ }
+ dataset = dataset_router[train_type](max_len, device)
+ dataset.load(load_path)
+
+ return dataset
\ No newline at end of file
diff --git a/khaosz/trainer/mask.py b/khaosz/trainer/mask.py
new file mode 100644
index 0000000..9de0472
--- /dev/null
+++ b/khaosz/trainer/mask.py
@@ -0,0 +1,55 @@
+import torch
+from abc import ABC, abstractmethod
+from torch import Tensor
+
+
+class MaskBuilder(ABC):
+ def __init__(
+ self,
+ bos_token_id: int,
+ eos_token_id: int,
+ user_token_id: int,
+        system_token_id: int,
+    ):
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.user_token_id = user_token_id
+ self.system_token_id = system_token_id
+
+ @abstractmethod
+    def build(self, input_ids: Tensor) -> Tensor:
+ raise NotImplementedError
+
+
+
+class LossMaskBuilder(MaskBuilder):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def build(self, input_ids: Tensor) -> Tensor:
+ token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
+
+ is_user_token = input_ids.eq(self.user_token_id)
+ is_system_token = input_ids.eq(self.system_token_id)
+
+ token_markers[is_user_token] = 1
+ token_markers[is_system_token] = -1
+
+ cumulative_markers = torch.cumsum(token_markers, dim=-1)
+ min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
+ loss_mask = cumulative_markers - min_cumulative
+
+        return loss_mask.to(dtype=torch.bool)
+
+
+
+
+class AttentionMaskBuilder(MaskBuilder):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def build(self, input_ids: Tensor) -> Tensor:
+        # TODO: construct the attention mask from input_ids
+        bsz = input_ids.size(0)
\ No newline at end of file
diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py
new file mode 100644
index 0000000..123b43f
--- /dev/null
+++ b/khaosz/trainer/strategy.py
@@ -0,0 +1,388 @@
+import copy
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch import Tensor
+from torch.optim import Optimizer
+from torch.utils.data import Dataset
+from typing import Any, Literal, Tuple, Callable, Dict
+from abc import ABC, abstractmethod
+from dataclasses import asdict, dataclass, field
+
+
+def get_logprobs(model: nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int) -> Tensor:
+ input_mask = input_ids.ne(pad_token_id)
+ logits = model(input_ids, input_mask)["logits"]
+ log_probs = torch.log_softmax(logits, dim=-1)
+
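+    # align predictions with targets: the logit at position t scores token
+    # t + 1, so drop the last logit and the first target token before gather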
+ shifted_log_probs = log_probs[:, :-1, :]
+ shifted_input_ids = input_ids[:, 1:]
+ shifted_response_mask = mask[:, 1:]
+
+ token_logprobs = torch.gather(
+ shifted_log_probs,
+ dim=-1,
+ index=shifted_input_ids.unsqueeze(-1)
+ ).squeeze(-1)
+
+ prompt_mask = input_mask[:, 1:]
+ valid_mask = (prompt_mask & shifted_response_mask).float()
+
+ return (token_logprobs * valid_mask).sum(dim=-1)
+
+
+class MaskBuilder(ABC):
+ def __init__(
+ self,
+ bos_token_id: int,
+ eos_token_id: int,
+ user_token_id: int,
+        system_token_id: int,
+    ):
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.user_token_id = user_token_id
+ self.system_token_id = system_token_id
+
+ @abstractmethod
+    def build(self, input_ids: Tensor) -> Tensor:
+ raise NotImplementedError
+
+
+
+class LossMaskBuilder(MaskBuilder):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def build(self, input_ids: Tensor) -> Tensor:
+ token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
+
+ is_user_token = input_ids.eq(self.user_token_id)
+ is_system_token = input_ids.eq(self.system_token_id)
+
+ token_markers[is_user_token] = 1
+ token_markers[is_system_token] = -1
+
+ cumulative_markers = torch.cumsum(token_markers, dim=-1)
+ min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
+ loss_mask = cumulative_markers - min_cumulative
+
+ return loss_mask.to(dtype=torch.bool)
+
+
+
+
+class AttentionMaskBuilder(MaskBuilder):
+ def __init__(self, multi_turn=False, **kwargs):
+ super().__init__(**kwargs)
+ self.multi_turn = multi_turn
+
+ def build(self, input_ids: Tensor):
+        # TODO: mask construction is not implemented yet
+        bsz = input_ids.size(0)
+
+
+ def _build_batch(self, input_ids: Tensor):
+ is_user_token = input_ids.eq(self.user_token_id)
+
+ token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
+ token_markers[is_user_token] = 1
+        cumulative_markers = torch.cumsum(token_markers, dim=-1)
+        # TODO: derive a per-turn attention mask from cumulative_markers
+
+
+
+
+
+class BaseStrategy(ABC):
+ def __init__(self, model: nn.Module):
+ self.model = model
+
+ @abstractmethod
+ def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
+ raise NotImplementedError
+
+ def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor:
+ return self.compute_loss(batch)
+
+
+class SeqStrategy(BaseStrategy):
+ def __init__(self, model):
+ super().__init__(model)
+
+ def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
+ x, y = batch
+ B, L = x.size()
+ logits: Tensor = self.model(x)["logits"]
+
+ loss = F.cross_entropy(
+ logits.view(B * L, -1), y.flatten()
+ )
+ return loss
+
+
+class SftStrategy(BaseStrategy):
+ def __init__(self, model):
+ super().__init__(model)
+
+ def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
+ x, y, loss_mask = batch
+ B, L = x.size()
+ ignore_idx = -1
+
+ logits: Tensor = self.model(x)["logits"]
+ masked_y = y.masked_fill(loss_mask == 0, ignore_idx)
+
+ loss = F.cross_entropy(
+ logits.view(B * L, -1),
+ masked_y.flatten(),
+ ignore_index=ignore_idx
+ )
+
+ return loss
+
+class DpoStrategy(BaseStrategy):
+ def __init__(self, model, pad_token_id, beta):
+ super().__init__(model)
+ ref_model = copy.deepcopy(self.model)
+ ref_model.requires_grad_(False)
+ ref_model.eval()
+
+ self.ref_model = ref_model
+ self.pad_token_id = pad_token_id
+ self.beta = beta
+
+ def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
+ good_ids, bad_ids, good_mask, bad_mask = batch
+
+ log_pi_good = get_logprobs(self.model, good_ids, good_mask, self.pad_token_id)
+ log_pi_bad = get_logprobs(self.model, bad_ids, bad_mask, self.pad_token_id)
+
+ with torch.no_grad():
+ log_ref_good = get_logprobs(self.ref_model, good_ids, good_mask, self.pad_token_id)
+ log_ref_bad = get_logprobs(self.ref_model, bad_ids, bad_mask, self.pad_token_id)
+
+ pi_log_ratio = log_pi_good - log_pi_bad
+ ref_log_ratio = log_ref_good - log_ref_bad
+
+ ratio_diff = pi_log_ratio - ref_log_ratio
+
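+        # DPO objective:
+        #   -log sigmoid(beta * ((log pi(y_w) - log pi(y_l)) - (log ref(y_w) - log ref(y_l))))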
+ dpo_loss = -F.logsigmoid(self.beta * ratio_diff).mean()
+ return dpo_loss
+
+
+class PpoStrategy(BaseStrategy):
+ def __init__(self, model, pad_token_id, epsilon):
+ super().__init__(model)
+ ref_model = copy.deepcopy(self.model)
+ ref_model.requires_grad_(False)
+ ref_model.eval()
+
+ self.ref_model = ref_model
+ self.pad_token_id = pad_token_id
+ self.epsilon = epsilon
+
+ def ppo_clip_loss_masked(
+ self,
+ log_probs: Tensor,
+ old_log_probs: Tensor,
+ advantages: Tensor,
+ values: Tensor,
+ returns: Tensor,
+ mask: Tensor,
+ clip_eps: float=0.2,
+ ):
+ ratio = torch.exp(log_probs - old_log_probs)
+ surr1 = ratio * advantages
+ surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
+ policy_loss = -torch.min(surr1, surr2).masked_select(mask).mean()
+
+ value_loss = F.mse_loss(values.masked_select(mask),
+ returns.masked_select(mask))
+
+ entropy = -(log_probs.exp() * log_probs).masked_select(mask).mean()
+ entropy_loss = -entropy
+ return policy_loss, value_loss, entropy_loss
+
+
+
+class StrategyFactory:
+
+    @staticmethod
+    def load(model, train_type, pad_token_id, dpo_beta) -> BaseStrategy:
+ train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
+ "seq": lambda: SeqStrategy(model),
+ "sft": lambda: SftStrategy(model),
+ "dpo": lambda: DpoStrategy(model, pad_token_id, dpo_beta)
+ }
+ strategy = train_strategy[train_type]()
+ return strategy
+
+
+@dataclass
+class TrainConfig:
+    train_type: str = field(
+        default="seq",
+        metadata={"help": "Type of training: seq, sft, or dpo."}
+    )
+ dataset: Dataset = field(
+ default=None,
+ metadata={"help": "Dataset for training."}
+ )
+ optimizer: Optimizer = field(
+ default=None,
+ metadata={"help": "Optimizer for training."}
+ )
+ ckpt_dir: str = field(
+ default="./checkpoint",
+ metadata={"help": "Checkpoint directory."}
+ )
+ n_epoch: int = field(
+ default=1,
+ metadata={"help": "Number of epochs for training."}
+ )
+ batch_size: int = field(
+ default=4,
+ metadata={"help": "Batch size for training."}
+ )
+ n_iter_ckpt: int = field(
+ default=5000,
+ metadata={"help": "Number of iterations between checkpoints."}
+ )
+ n_iter_step: int = field(
+ default=1,
+ metadata={"help": "Number of iterations between steps."}
+ )
+ max_grad_norm: float = field(
+ default=1.0,
+ metadata={"help": "Maximum gradient norm."}
+ )
+ random_seed: int = field(
+ default=3407,
+ metadata={"help": "Random seed."}
+ )
+ dpo_beta: float = field(
+ default=0.1,
+ metadata={"help": "DPO beta."}
+ )
+
+    def get_kwargs(self) -> Dict[str, Any]:
+ config_dict = asdict(self)
+ return {k: v for k, v in config_dict.items() if v is not None}
+
+
+@dataclass
+class ScheduleConfig:
+    schedule_type: str = field(
+        default="cosine",
+        metadata={"help": "Type of learning rate schedule: cosine or sgdr."}
+    )
+    warmup_step: int = field(
+        default=1000,
+        metadata={"help": "Number of learning-rate warm-up steps."}
+    )
+
+ @abstractmethod
+    def get_kwargs(self) -> Dict[str, Any]:
+ raise NotImplementedError
+
+
+@dataclass
+class CosineScheduleConfig(ScheduleConfig):
+ total_iters: int = field(
+ default=None,
+ metadata={"help": "Total iterations for cosine schedule."}
+ )
+ min_rate: float = field(
+ default=0.05,
+ metadata={"help": "Minimum rate for cosine schedule."}
+ )
+ schedule_type: Literal["cosine"] = "cosine"
+
+ def get_kwargs(self) -> Dict[str, Any]:
+ return {
+ "schedule_type": self.schedule_type,
+ "warning_step": self.warning_step,
+ "lr_decay_iters": self.total_iters - self.warning_step,
+ "min_rate": self.min_rate
+ }
+
+@dataclass
+class SgdrScheduleConfig(ScheduleConfig):
+ cycle_length: int = field(
+ default=1000,
+ metadata={"help": "Cycle length for sgdr schedule."}
+ )
+ min_rate: float = field(
+ default=0.05,
+ metadata={"help": "Minimum rate for sgdr schedule."}
+ )
+ T_mult: int = field(
+ default=2,
+ metadata={"help": "T_mult for sgdr schedule."}
+ )
+ schedule_type: Literal["sgdr"] = "sgdr"
+
+ def get_kwargs(self) -> Dict[str, Any]:
+ return {
+ "schedule_type": self.schedule_type,
+ "warning_step": self.warning_step,
+ "cycle_length": self.cycle_length,
+ "min_rate": self.min_rate,
+ "T_mult": self.T_mult
+ }
+
+
+class SchedulerFactory:
+
+ @staticmethod
+ def get_sgdr_schedule(
+        warmup_step: int,
+ cycle_length: int,
+ min_rate: float = 0.1,
+ T_mult: int = 2
+ ) -> Callable[[int], float]:
+
+        def sgdr_schedule(now_iter: int) -> float:
+            if now_iter < warmup_step:
+                return max(min_rate, now_iter / warmup_step)
+
+            # locate the current restart cycle: cycle k has length
+            # cycle_length * T_mult ** k and starts where the previous one ends
+            adjusted_iter = now_iter - warmup_step
+            cycle_start = 0
+            current_cycle_length = cycle_length
+            while adjusted_iter >= cycle_start + current_cycle_length:
+                cycle_start += current_cycle_length
+                current_cycle_length *= T_mult
+
+            cycle_pos = adjusted_iter - cycle_start
+            return max(min_rate, 0.5 * (1 + math.cos(math.pi * cycle_pos / current_cycle_length)))
+
+ return sgdr_schedule
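+
+    # Example: warmup_step=100, cycle_length=200, T_mult=2 yields restart
+    # cycles of 200, 400, 800, ... iterations after warm-up, each decaying
+    # from the peak rate toward min_rate.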
+
+ @staticmethod
+ def get_cosine_warmup_schedule(
+        warmup_step: int,
+ lr_decay_iters: int,
+ min_rate: float = 0.1
+ ) -> Callable[[int], float]:
+
+ def cosine_warmup_schedule(now_iter: int) -> float:
+            if now_iter <= warmup_step:
+                return max(min_rate, now_iter / warmup_step)
+            else:
+                rate = (now_iter - warmup_step) / (lr_decay_iters - warmup_step)
+ return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * rate)))
+
+ return cosine_warmup_schedule
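+
+    # Example: warmup_step=1000 ramps the rate linearly to 1.0 over the
+    # first 1000 iterations, then a cosine decay reaches min_rate as
+    # now_iter approaches lr_decay_iters.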
+
+ @staticmethod
+ def load_schedule_fn(**kwargs):
+ strategy = kwargs.pop("schedule_type")
+ if strategy == "cosine":
+ return SchedulerFactory.get_cosine_warmup_schedule(**kwargs)
+ elif strategy == "sgdr":
+ return SchedulerFactory.get_sgdr_schedule(**kwargs)
+ else:
+ raise ValueError(f"Invalid schedule type: {strategy}")
+
\ No newline at end of file
diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
new file mode 100644
index 0000000..a0ea0fe
--- /dev/null
+++ b/khaosz/trainer/trainer.py
@@ -0,0 +1,167 @@
+import os
+import torch
+import logging
+
+from typing import Tuple
+from torch.nn.utils import clip_grad_norm_
+from torch.optim.lr_scheduler import LambdaLR
+from torch.utils.data import DataLoader, RandomSampler
+from tqdm import tqdm
+
+from khaosz.core import ModelParameter, Checkpoint
+from khaosz.trainer.strategy import SchedulerFactory, StrategyFactory, TrainConfig, ScheduleConfig
+
+
+class Trainer:
+ def __init__(
+ self,
+ parameter: ModelParameter,
+ log_path: str="./train_log.log"
+ ):
+ logger = logging.getLogger()
+        logger.setLevel(logging.INFO)
+ handler = logging.FileHandler(log_path)
+ handler.setLevel(logging.INFO)
+ handler.setFormatter(logging.Formatter('%(asctime)s: %(message)s'))
+ logger.addHandler(handler)
+ logger.info("initializing trainer ...")
+
+ self.logger = logger
+ self.model = parameter.model
+ self.tokenizer = parameter.tokenizer
+ self.config = parameter.config
+
+ def save_checkpoint(
+ self,
+ loss_list: list,
+ ckpt_dir: str,
+ current_iter: int,
+ last_ckpt_iter: int
+ ):
+ save_path = os.path.join(ckpt_dir, f"iter_{current_iter}")
+ Checkpoint(
+ self.model,
+ self.tokenizer,
+ self.config,
+ loss_list,
+ current_iter
+ ).save(save_path)
+
+ diff_iter = current_iter - last_ckpt_iter
+ avg_loss = sum(loss_list[last_ckpt_iter:current_iter]) / diff_iter
+ self.logger.info(f"iter: {current_iter} loss: {avg_loss}")
+
+ return current_iter
+
+ def load_checkpoint(self, train_checkpoint: Checkpoint) -> Tuple[list, int]:
+ self.model = train_checkpoint.model
+ self.tokenizer = train_checkpoint.tokenizer
+ self.config = train_checkpoint.config
+ loss_list = train_checkpoint.loss_list
+ last_ckpt_iter = train_checkpoint.current_iter
+
+ return loss_list, last_ckpt_iter
+
+ def train(
+ self,
+ train_config: TrainConfig,
+ schedule_config: ScheduleConfig,
+ train_checkpoint: Checkpoint = None
+ ):
+ assert schedule_config.schedule_type in ["cosine", "sgdr"]
+ assert train_config.train_type in ["seq", "sft", "dpo"]
+
+ if train_checkpoint:
+ loss_list, last_ckpt_iter = self.load_checkpoint(train_checkpoint)
+ current_iter = train_checkpoint.current_iter + 1
+ self.logger.info(f"Resuming training from checkpoint: iter {current_iter}")
+ else:
+ current_iter = 0
+ last_ckpt_iter = 0
+ loss_list = []
+
+ lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
+ **schedule_config.get_kwargs()
+ )
+
+ strategy = StrategyFactory.load(
+ self.model,
+ train_config.train_type,
+ self.tokenizer.pad_id,
+ train_config.dpo_beta
+ )
+
+ scheduler = LambdaLR(
+ train_config.optimizer,
+ lambda_scheduler_fn,
+ last_epoch=current_iter - 1 if train_checkpoint else -1
+ )
+
+ seed = train_config.random_seed
+ generator = torch.Generator().manual_seed(seed)
+ sampler = RandomSampler(train_config.dataset, generator=generator)
+ remaining_epochs = train_config.n_epoch - current_iter // (len(train_config.dataset) // train_config.batch_size)
+
+ self.logger.info(f"Starting {train_config.train_type.upper()} training for {train_config.n_epoch} epochs")
+ self.logger.info(f"Checkpoint interval: {train_config.n_iter_ckpt} iterations")
+
+ for epoch in range(remaining_epochs):
+ self.model.train()
+ dataloader = DataLoader(
+ train_config.dataset,
+ batch_size=train_config.batch_size,
+ sampler=sampler
+ )
+ progress_bar = tqdm(
+ dataloader,
+ desc=f"Epoch {epoch+1}/{train_config.n_epoch}",
+ dynamic_ncols=True
+ )
+ for batch in progress_bar:
+            # forward
+ loss = strategy(batch)
+ loss_list.append(loss.item())
+            # backward
+ loss.backward()
+            # step
+ if current_iter % train_config.n_iter_step == 0:
+ clip_grad_norm_(
+ self.model.parameters(),
+ train_config.max_grad_norm
+ )
+ train_config.optimizer.step()
+ train_config.optimizer.zero_grad()
+
+ current_iter += 1
+ scheduler.step()
+ progress_bar.set_postfix({
+ "loss": f"{loss.item():.4f}",
+ "lr": f"{train_config.optimizer.param_groups[0]['lr']:.2e}"
+ })
+            # save checkpoint
+ if current_iter - last_ckpt_iter >= train_config.n_iter_ckpt:
+ last_ckpt_iter = self.save_checkpoint(
+ loss_list,
+ train_config.ckpt_dir,
+ current_iter,
+ last_ckpt_iter
+ )
+
+ if current_iter != last_ckpt_iter:
+ last_ckpt_iter = self.save_checkpoint(
+ loss_list,
+ train_config.ckpt_dir,
+ current_iter,
+ last_ckpt_iter
+ )
+
+ self.logger.info("Training completed")
+
+ return Checkpoint(
+ self.model,
+ self.tokenizer,
+ self.config,
+ loss_list,
+ current_iter,
+ train_config.optimizer
+ )
diff --git a/khaosz/utils/retriever.py b/khaosz/utils/retriever.py
new file mode 100644
index 0000000..1264474
--- /dev/null
+++ b/khaosz/utils/retriever.py
@@ -0,0 +1,88 @@
+import torch
+import sqlite3
+import numpy as np
+from torch import Tensor
+from typing import Dict, List, Tuple
+
+
+class Retriever:
+ def __init__(self, db_path=None):
+ self.data: Dict[str, Tensor] = {}
+ self.embedding_cache: Tensor = None
+        self.is_calculated: bool = False
+
+ if db_path is not None:
+ self.load(db_path)
+
+ def retrieve(self, query: Tensor, top_k: int) -> List[Tuple[str, float]]:
+ if not self.data:
+ return []
+
+ query = query.flatten().unsqueeze(1) # [dim, 1]
+ norm_embeddings = self._embeddings.to(
+ device=query.device,
+ dtype=query.dtype
+ ) # [n_vectors, dim]
+ sim_scores = torch.matmul(norm_embeddings, query).squeeze() # [n_vectors]
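+        # rows of _embeddings are L2-normalized, so this ranking matches
+        # cosine similarity (the query norm scales every score equally)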
+
+ top_k = min(top_k, len(self.data))
+ indices = sim_scores.topk(top_k).indices
+ keys = list(self.data.keys())
+
+ return [(keys[i], sim_scores[i].item()) for i in indices]
+
+ def add_vector(self, key: str, vector_data: Tensor):
+        self.is_calculated = False
+ self.data[key] = vector_data.flatten().float().cpu()
+
+ def delete_vector(self, key: str):
+        self.is_calculated = False
+ self.data.pop(key, None)
+
+ def save(self, db_path):
+ conn = sqlite3.connect(db_path)
+ cursor = conn.cursor()
+ self._init_db(cursor)
+ cursor.execute('DELETE FROM vectors')
+
+ for item, vec in self.data.items():
+ vec_bytes = vec.numpy().tobytes()
+ cursor.execute('INSERT OR REPLACE INTO vectors (key, vector) VALUES (?, ?)',
+ (item, vec_bytes))
+
+ conn.commit()
+ conn.close()
+
+ def load(self, db_path):
+ conn = sqlite3.connect(db_path)
+ cursor = conn.cursor()
+ self._init_db(cursor)
+ cursor.execute('SELECT key, vector FROM vectors')
+ rows = cursor.fetchall()
+ self.data = {}
+
+ for row in rows:
+ key, vec_bytes = row
+ vec_numpy = np.frombuffer(vec_bytes, dtype=np.float32).copy()
+ vec = torch.from_numpy(vec_numpy)
+ self.data[key] = vec
+
+ conn.close()
+
+    def _init_db(self, cursor: sqlite3.Cursor):
+ # Create table if not exists (in case loading from a new database)
+ cursor.execute('''
+ CREATE TABLE IF NOT EXISTS vectors (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ key TEXT UNIQUE NOT NULL,
+ vector BLOB NOT NULL
+ )''')
+
+ @property
+    def _embeddings(self) -> Tensor:
+        if not self.is_calculated:
+            embeddings = torch.stack(list(self.data.values()))
+            norm_embeddings = embeddings / torch.norm(embeddings, dim=-1, keepdim=True)
+            self.embedding_cache = norm_embeddings
+            self.is_calculated = True
+
+        return self.embedding_cache
\ No newline at end of file
diff --git a/khaosz/utils/splitter.py b/khaosz/utils/splitter.py
new file mode 100644
index 0000000..7f15fd3
--- /dev/null
+++ b/khaosz/utils/splitter.py
@@ -0,0 +1,127 @@
+import re
+import torch
+import torch.nn.functional as F
+
+from abc import ABC, abstractmethod
+from torch import Tensor
+from typing import List, Callable, Optional
+
+
+class BaseTextSplitter(ABC):
+ def __init__(
+ self,
+ max_len: int = 512,
+ chunk_overlap: int = 0,
+ ):
+ if max_len <= 0:
+ raise ValueError("max_len must be > 0")
+ if chunk_overlap < 0:
+ raise ValueError("chunk_overlap must be >= 0")
+
+ self.max_len = max_len
+ self.chunk_overlap = chunk_overlap
+
+ @abstractmethod
+ def split(self, text: str, **kwargs) -> List[str]:
+ raise NotImplementedError
+
+ def preprocess(self, text: str) -> str:
+ return text.strip()
+
+ def postprocess(self, chunks: List[str]) -> List[str]:
+ return [chunk.strip() for chunk in chunks if chunk.strip()]
+
+
+class PriorityTextSplitter(BaseTextSplitter):
+ def __init__(
+ self,
+ separators: List[str],
+ max_len: int = 512,
+ chunk_overlap: int = 0,
+ ):
+ super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
+ if not separators:
+ raise ValueError("separators must be a non-empty list")
+ self.separators = separators
+
+ def split(self, text: str) -> List[str]:
+ text = self.preprocess(text)
+ for sep in self.separators:
+ parts = text.split(sep)
+
+ valid_parts = [p.strip() for p in parts if p.strip()]
+ if len(valid_parts) > 1:
+ return self.postprocess(valid_parts)
+ return [text]
+
+
+class SemanticTextSplitter(BaseTextSplitter):
+
+ DEFAULT_PATTERN = r'(?<=[。!?!?])(?=(?:[^"\'‘’“”]*["\'‘’“”][^"\'‘’“”]*["\'‘’“”])*[^"\'‘’“”]*$)'
+
+ def __init__(
+ self,
+ embedding_func: Callable[[List[str]], List[Tensor]],
+ pattern: Optional[str] = None,
+ max_len: int = 512,
+ chunk_overlap: int = 0,
+ ):
+ super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
+ if not callable(embedding_func):
+ raise TypeError("embedding_func must be callable")
+ self.embedding_func = embedding_func
+ self.pattern = pattern or SemanticTextSplitter.DEFAULT_PATTERN
+
+ def split(
+ self,
+ text: str,
+ threshold: float = 0.5,
+ window_size: int = 1,
+ ) -> List[str]:
+ text = self.preprocess(text)
+ sentences = [s.strip() for s in re.split(self.pattern, text) if s.strip()]
+
+ if len(sentences) <= 1:
+ return self.postprocess(sentences)
+
+ try:
+ sentence_embs = self.embedding_func(sentences)
+ except Exception as e:
+ raise RuntimeError(f"Embedding generation failed: {e}")
+
+ if len(sentence_embs) != len(sentences):
+ raise ValueError("Embedding function must return one vector per sentence")
+
+ chunks = []
+ emb_tensor = torch.stack(sentence_embs) # shape: [N, D]
+ current_chunk: List[str] = [sentences[0]]
+
+ for i in range(1, len(sentences)):
+ start_prev = max(0, i - window_size)
+ end_prev = i
+ start_next = i
+ end_next = min(len(sentences), i + window_size)
+
+ prev_window_emb = emb_tensor[start_prev:end_prev].mean(dim=0)
+ next_window_emb = emb_tensor[start_next:end_next].mean(dim=0)
+
+ similarity = F.cosine_similarity(
+ prev_window_emb.unsqueeze(0),
+ next_window_emb.unsqueeze(0),
+ dim=1
+ ).item()
+
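+            # widen tolerance for larger windows: lower the split threshold
+            # slightly per sentence spanned, with a floor of 0.2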
+ dynamic_threshold = max(threshold * (1 - 0.03 * (end_next - start_prev)), 0.2)
+
+ if similarity < dynamic_threshold:
+ chunks.append(" ".join(current_chunk))
+ overlap_start = max(0, len(current_chunk) - self.chunk_overlap)
+ current_chunk = current_chunk[overlap_start:]
+ current_chunk.append(sentences[i])
+ else:
+ current_chunk.append(sentences[i])
+
+ if current_chunk:
+ chunks.append(" ".join(current_chunk))
+
+ return self.postprocess(chunks)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..5c8a7bb
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,36 @@
+# python=3.12
+--extra-index-url https://download.pytorch.org/whl/cu126
+
+certifi==2025.8.3
+charset-normalizer==3.4.2
+colorama==0.4.6
+contourpy==1.3.3
+cycler==0.12.1
+filelock==3.13.1
+fonttools==4.59.0
+fsspec==2024.6.1
+huggingface-hub==0.34.3
+idna==3.10
+Jinja2==3.1.6
+kiwisolver==1.4.8
+MarkupSafe==2.1.5
+matplotlib==3.10.5
+mpmath==1.3.0
+networkx==3.3
+numpy==2.3.2
+packaging==25.0
+pillow==11.3.0
+pyparsing==3.2.3
+python-dateutil==2.9.0.post0
+PyYAML==6.0.2
+requests==2.32.4
+safetensors==0.5.3
+setuptools==78.1.1
+six==1.17.0
+sympy==1.13.3
+tokenizers==0.21.4
+torch==2.7.1+cu126
+tqdm==4.67.1
+typing_extensions==4.12.2
+urllib3==2.5.0
+wheel==0.45.1
diff --git a/scripts/chat.py b/scripts/chat.py
new file mode 100644
index 0000000..399717b
--- /dev/null
+++ b/scripts/chat.py
@@ -0,0 +1,32 @@
+import os
+import torch
+from khaosz import Khaosz
+
+
+PROJECT_ROOT = os.path.dirname(
+ os.path.dirname(os.path.abspath(__file__)))
+
+def chat():
+ model_dir = os.path.join(PROJECT_ROOT, "params")
+ model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+
+    history = []
+ while True:
+ query = input(">> ")
+ if query == "!exit":
+ break
+
+ response_size = 0
+        for response, history in model.stream_generate(
+            query=query,
+            history=history,
+ temperature=0.7,
+ top_p=0.95,
+ top_k=30
+ ):
+ print(response[response_size:], end="", flush=True)
+ response_size = len(response)
+
+
+if __name__ == "__main__":
+ chat()
\ No newline at end of file
diff --git a/scripts/download.py b/scripts/download.py
new file mode 100644
index 0000000..ca641f2
--- /dev/null
+++ b/scripts/download.py
@@ -0,0 +1,14 @@
+import os
+from huggingface_hub import snapshot_download
+
+
+PROJECT_ROOT = os.path.dirname(
+ os.path.dirname(os.path.abspath(__file__)))
+
+
+if __name__ == "__main__":
+ snapshot_download(
+ repo_id="ViperEk/KHAOSZ",
+ local_dir=os.path.join(PROJECT_ROOT, "params"),
+ force_download=True
+ )
\ No newline at end of file
diff --git a/scripts/generate_ar.py b/scripts/generate_ar.py
new file mode 100644
index 0000000..23b2832
--- /dev/null
+++ b/scripts/generate_ar.py
@@ -0,0 +1,27 @@
+import os
+import torch
+from khaosz import Khaosz
+
+
+PROJECT_ROOT = os.path.dirname(
+ os.path.dirname(os.path.abspath(__file__)))
+
+def generate_text():
+ model_dir = os.path.join(PROJECT_ROOT, "params")
+ model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+
+ query = input(">> ")
+
+ response = model.text_generate(
+ query=query,
+ temperature=0.6,
+ top_p=0.95,
+ top_k=30
+ )
+
+ print(response)
+
+
+
+if __name__ == "__main__":
+ generate_text()
\ No newline at end of file
diff --git a/scripts/generate_batch.py b/scripts/generate_batch.py
new file mode 100644
index 0000000..037b04f
--- /dev/null
+++ b/scripts/generate_batch.py
@@ -0,0 +1,25 @@
+import os
+import torch
+from khaosz import Khaosz
+
+
+PROJECT_ROOT = os.path.dirname(
+ os.path.dirname(os.path.abspath(__file__)))
+
+def batch_generate():
+ model_dir = os.path.join(PROJECT_ROOT, "params")
+ model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+ inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"]
+
+ responses = model.batch_generate(
+ queries=inputs,
+ temperature=0.7,
+ top_p=0.95,
+ top_k=30
+ )
+
+ for q, r in zip(inputs, responses):
+ print((q, r))
+
+if __name__ == "__main__":
+ batch_generate()
\ No newline at end of file
diff --git a/scripts/generate_retrieve.py b/scripts/generate_retrieve.py
new file mode 100644
index 0000000..46a9e56
--- /dev/null
+++ b/scripts/generate_retrieve.py
@@ -0,0 +1,42 @@
+import os
+import torch
+from khaosz import Khaosz, SemanticTextSplitter, Retriever
+
+
+PROJECT_ROOT = os.path.dirname(
+ os.path.dirname(os.path.abspath(__file__)))
+
+if __name__ == "__main__":
+ model_dir = os.path.join(PROJECT_ROOT, "params")
+ context_path = os.path.join(PROJECT_ROOT, "README.md")
+
+ model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+    splitter = SemanticTextSplitter(model.encode)
+ retriever = Retriever()
+ text = open(context_path, "r", encoding="utf-8").read()
+
+    res = splitter.split(text, threshold=0.8, window_size=1)
+ # print(("\n" + "+"*100 + "\n").join(res))
+
+ res_embs = model.encode(res)
+ for sentence, emb in zip(res, res_embs):
+ retriever.add_vector(sentence, emb)
+
+    retrieve_top_k = 5
+    query = "作者设计了一个怎样的模型"
+    emb_query = model.encode(query)
+    retrieved = retriever.retrieve(emb_query, retrieve_top_k)
+
+    retrieve_response = model.retrieve_generate(
+ retrieved=retrieved,
+ query=query,
+ temperature=0.7,
+ top_k=30,
+ top_p=0.95,
+ )
+
+ print("retrive content:")
+ print("\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)]))
+
+ print("\n\nretrive generate:")
+ print(retrive_response)
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..ea67044
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,18 @@
+import re
+from setuptools import find_packages, setup
+
+
+with open("requirements.txt") as f:
+ required = [line for line in f.read().splitlines()
+ if line and re.match(r'^[^=]+==[^=]+$', line.strip())]
+
+setup(
+ name="khaosz",
+ version="1.2.0",
+ packages=find_packages(),
+ install_requires=required,
+ dependency_links=[
+ "https://download.pytorch.org/whl/cu126",
+ ],
+ python_requires="==3.12.*",
+)
\ No newline at end of file
diff --git a/tests/test_module.py b/tests/test_module.py
new file mode 100644
index 0000000..025188e
--- /dev/null
+++ b/tests/test_module.py
@@ -0,0 +1,103 @@
+import os
+import json
+import torch
+import shutil
+import pytest
+import tempfile
+import safetensors.torch as st
+from khaosz.core import *
+from khaosz.core.generator import EmbeddingEncoderCore, GeneratorCore
+from tokenizers import pre_tokenizers
+
+@pytest.fixture
+def test_env():
+ test_dir = tempfile.mkdtemp()
+ config_path = os.path.join(test_dir, "config.json")
+ tokenizer_path = os.path.join(test_dir, "tokenizer.json")
+ model_path = os.path.join(test_dir, "model.safetensors")
+
+ config = {
+ "vocab_size": 1000,
+ "n_dim": 128,
+ "n_head": 4,
+ "n_kvhead": 2,
+ "d_ffn": 256,
+ "m_len": 64,
+ "n_layer": 2,
+ "norm_eps": 1e-5
+ }
+ with open(config_path, 'w') as f:
+ json.dump(config, f)
+
+ tokenizer = BpeTokenizer()
+ sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
+ tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
+ tokenizer.save(tokenizer_path)
+
+ transformer_config = TransformerConfig().load(config_path)
+ model = Transformer(transformer_config)
+ st.save_file(model.state_dict(), model_path)
+
+ yield {
+ "test_dir": test_dir,
+ "model": model,
+ "tokenizer": tokenizer,
+ "transformer_config": transformer_config,
+ }
+
+ shutil.rmtree(test_dir)
+
+# parameter loader
+def test_parameter_loader(test_env):
+ loaded_param = ParameterLoader.load(test_env["test_dir"])
+ assert loaded_param.model is not None
+ assert loaded_param.tokenizer is not None
+ assert loaded_param.config == test_env["transformer_config"]
+
+def test_model_parameter(test_env):
+ save_dir = os.path.join(test_env["test_dir"], "save")
+ model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"])
+ model_param.save(save_dir)
+
+ assert os.path.exists(os.path.join(save_dir, "model.safetensors"))
+ assert os.path.exists(os.path.join(save_dir, "tokenizer.json"))
+ assert os.path.exists(os.path.join(save_dir, "config.json"))
+
+# transformer
+def test_transformer(test_env):
+ model = test_env["model"]
+ input_ids = torch.randint(0, test_env["transformer_config"].vocab_size,
+ (4, test_env["transformer_config"].m_len))
+ output_logits = model(input_ids)["logits"]
+ target_shape = (4, test_env["transformer_config"].m_len, test_env["transformer_config"].vocab_size)
+ assert output_logits.shape == target_shape
+
+# generator
+def test_embedding_encoder_core(test_env):
+ parameter = ModelParameter(
+ test_env["model"],
+ test_env["tokenizer"],
+ test_env["transformer_config"]
+ )
+ encoder = EmbeddingEncoderCore(parameter)
+
+ single_emb = encoder.encode("测试文本")
+ assert isinstance(single_emb, torch.Tensor)
+ assert single_emb.shape[-1] == test_env["transformer_config"].n_dim
+
+
+ batch_emb = encoder.encode(["测试1", "测试2"])
+ assert isinstance(batch_emb, list)
+ assert len(batch_emb) == 2
+
+def test_generator_core(test_env):
+ parameter = ModelParameter(
+ test_env["model"],
+ test_env["tokenizer"],
+ test_env["transformer_config"]
+ )
+ generator = GeneratorCore(parameter)
+ logits, incr = generator.compute_logits(torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10)))
+
+ assert logits.shape == (4, test_env["transformer_config"].vocab_size)
+ assert incr == 10
\ No newline at end of file
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
new file mode 100644
index 0000000..1c23544
--- /dev/null
+++ b/tests/test_trainer.py
@@ -0,0 +1,203 @@
+import os
+import json
+import torch
+import shutil
+import pytest
+import pickle
+import tempfile
+import matplotlib
+
+from torch.utils.data import Dataset
+from khaosz.core import *
+from khaosz.trainer import *
+
+# to avoid _tkinter.TclError
+matplotlib.use('Agg')
+
+
+@pytest.fixture
+def test_env():
+ test_dir = tempfile.mkdtemp()
+ config_path = os.path.join(test_dir, "config.json")
+
+ config = {
+ "vocab_size": 1000,
+ "n_dim": 128,
+ "n_head": 4,
+ "n_kvhead": 2,
+ "d_ffn": 256,
+ "m_len": 64,
+ "n_layer": 2,
+ "norm_eps": 1e-5
+ }
+
+ with open(config_path, 'w') as f:
+ json.dump(config, f)
+
+ transformer_config = TransformerConfig().load(config_path)
+ model = Transformer(transformer_config)
+ tokenizer = BpeTokenizer()
+
+ class DummyDataset(Dataset):
+ def __init__(self, length=10):
+ self.length = length
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, idx):
+ return (
+ torch.randint(0, 1000, (64,)),
+ torch.randint(0, 1000, (64,))
+ )
+
+ dataset = DummyDataset()
+
+ yield {
+ "test_dir": test_dir,
+ "config_path": config_path,
+ "transformer_config": transformer_config,
+ "model": model,
+ "tokenizer": tokenizer,
+ "dataset": dataset
+ }
+
+ shutil.rmtree(test_dir)
+
+def test_dataset_loader(test_env):
+ test_dir = test_env["test_dir"]
+ pkl_path = os.path.join(test_dir, "test_data.pkl")
+
+ dummy_data = {"sequence": torch.randint(0, 1000, (64,))}
+ with open(pkl_path, "wb") as f:
+ pickle.dump(dummy_data, f)
+
+ loaded_dataset = DatasetLoader.load(train_type="seq", load_path=pkl_path, max_len=64, device="cpu")
+ assert loaded_dataset is not None
+
+def test_training_config(test_env):
+ optimizer = torch.optim.AdamW(test_env["model"].parameters())
+ train_config = TrainConfig(
+ train_type="seq",
+ dataset=test_env["dataset"],
+ optimizer=optimizer,
+ ckpt_dir=test_env["test_dir"],
+ n_epoch=1,
+ batch_size=2,
+ n_iter_ckpt=5,
+ n_iter_step=1,
+ max_grad_norm=1.0,
+ random_seed=42
+ )
+ assert train_config.get_kwargs()["batch_size"] == 2
+
+def test_cosine_schedule(test_env):
+ assert test_env is not None
+ schedule_config = CosineScheduleConfig(
+        warmup_step=100,
+ total_iters=1000
+ )
+ kwargs = schedule_config.get_kwargs()
+ assert kwargs["warning_step"] == 100
+ assert kwargs["lr_decay_iters"] == 900
+
+
+def test_sgdr_schedule(test_env):
+ assert test_env is not None
+ schedule_config = SgdrScheduleConfig(
+        warmup_step=100,
+ cycle_length=200,
+ T_mult=2
+ )
+ kwargs = schedule_config.get_kwargs()
+ assert kwargs["warning_step"] == 100
+ assert kwargs["cycle_length"] == 200
+ assert kwargs["T_mult"] == 2
+
+def test_trainer_train(test_env):
+ optimizer = torch.optim.AdamW(test_env["model"].parameters())
+ train_config = TrainConfig(
+ train_type="seq",
+ dataset=test_env["dataset"],
+ optimizer=optimizer,
+ ckpt_dir=test_env["test_dir"],
+ n_epoch=1,
+ batch_size=2,
+ n_iter_ckpt=5,
+ n_iter_step=1,
+ max_grad_norm=1.0,
+ random_seed=42
+ )
+ schedule_config = CosineScheduleConfig(
+        warmup_step=100,
+ total_iters=1000
+ )
+ model_parameter = ModelParameter(
+ test_env["model"],
+ test_env["tokenizer"],
+ test_env["transformer_config"]
+ )
+ trainer = Trainer(model_parameter)
+ trainer.train(train_config, schedule_config)
+
+def test_checkpoint(test_env):
+ temp_dir = test_env["test_dir"]
+ config = test_env["transformer_config"]
+ model = test_env["model"]
+ tokenizer = test_env["tokenizer"]
+
+ param = ModelParameter(model, tokenizer, config)
+ checkpoint = Checkpoint(
+ model=param.model,
+ tokenizer=param.tokenizer,
+ config=param.config,
+ loss_list=[1.0, 2.0, 3.0],
+ current_iter=3
+ )
+ ckpt_dir = os.path.join(temp_dir, "ckpt")
+ checkpoint.save(ckpt_dir)
+
+ loaded_ckpt = Checkpoint()
+ loaded_ckpt.load(ckpt_dir)
+
+ assert loaded_ckpt.current_iter == 3
+ assert loaded_ckpt.loss_list == [1.0, 2.0, 3.0]
+
+ for p1, p2 in zip(model.parameters(), loaded_ckpt.model.parameters()):
+ assert torch.allclose(p1, p2)
+
+
+def test_checkpoint_train(test_env):
+ temp_dir = test_env["test_dir"]
+ config = test_env["transformer_config"]
+ model = test_env["model"]
+ tokenizer = test_env["tokenizer"]
+ dataset = test_env["dataset"]
+
+ param = ModelParameter(model, tokenizer, config)
+ trainer = Trainer(param)
+
+ optimizer = torch.optim.AdamW(test_env["model"].parameters())
+ train_config = TrainConfig(
+ train_type="seq",
+ dataset=dataset,
+ optimizer=optimizer,
+ ckpt_dir=test_env["test_dir"],
+ n_epoch=1,
+ batch_size=2,
+ n_iter_ckpt=5,
+ n_iter_step=1,
+ max_grad_norm=1.0,
+ random_seed=42
+ )
+ schedule_config = CosineScheduleConfig(
+        warmup_step=100,
+ total_iters=1000
+ )
+
+ trainer.train(
+ train_config=train_config,
+ schedule_config=schedule_config,
+ )
+
+
\ No newline at end of file
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..707ad15
--- /dev/null
+++ b/train.py
@@ -0,0 +1,128 @@
+import os
+import argparse
+import torch
+
+from torch.optim import AdamW
+from khaosz.core import ParameterLoader
+from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig
+
+
+PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
+
+def get_files(root_path: str) -> list[str]:
+ paths = []
+ for root, _, files in os.walk(root_path):
+ paths.extend([os.path.join(root, file) for file in files])
+
+ return paths
+
+def train(
+ train_type: str,
+ param_path: str,
+ data_root_path: str,
+ n_epoch: int,
+ batch_size: int,
+ n_iter_step: int,
+    warmup_step: int,
+    max_lr: float,
+ n_iter_ckpt: int,
+ ckpt_dir: str,
+ dpo_beta: float,
+ adamw_betas: tuple,
+ adamw_weight_decay: float,
+ max_grad_norm: float,
+    embedding_lr_rate: float,
+ random_seed: int,
+):
+ assert train_type in ["seq", "sft", "dpo"]
+ assert os.path.exists(param_path)
+
+ parameter = ParameterLoader.load(param_path)
+ model = parameter.model
+
+ device = torch.device("cuda")
+ model = model.to(device=device, dtype=torch.bfloat16)
+
+ cache_files = get_files(data_root_path)
+ dataset = DatasetLoader.load(
+ train_type=train_type,
+ load_path=cache_files,
+ max_len=parameter.config.m_len,
+ device=device
+ )
+
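+    # the embedding matrix (tied with the output projection) can use a
+    # scaled learning rate via --embedding_lr_rate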
+ param_groups = [
+ {"params": [p for n, p in model.named_parameters() if "embedding" in n], "lr": max_lr * embdeding_lr_rate},
+ {"params": [p for n, p in model.named_parameters() if "embedding" not in n], "lr": max_lr}
+ ]
+
+ optim = AdamW(
+ param_groups,
+ betas=adamw_betas,
+ weight_decay=adamw_weight_decay
+ )
+
+ train_config = TrainConfig(
+ train_type=train_type,
+ dataset=dataset,
+ optimizer=optim,
+ ckpt_dir=ckpt_dir,
+ n_epoch=n_epoch,
+ batch_size=batch_size,
+ n_iter_ckpt=n_iter_ckpt,
+ n_iter_step=n_iter_step,
+ max_grad_norm=max_grad_norm,
+ random_seed=random_seed,
+ dpo_beta=dpo_beta
+ )
+
+ schedule_config = CosineScheduleConfig(
+        warmup_step=warmup_step,
+ total_iters=len(dataset) * n_epoch // batch_size,
+ )
+
+ trainer = Trainer(parameter)
+ trainer.train(
+ train_config=train_config,
+ schedule_config=schedule_config,
+ )
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Train the Transformer model.")
+ parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
+ parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
+ parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
+ parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
+ parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
+ parser.add_argument("--n_iter_step", type=int, default=1, help="Number of iterations between each optimizer step.")
+ parser.add_argument("--warning_step", type=int, default=1000, help="Number of iters between warnings.")
+ parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
+ parser.add_argument("--n_iter_ckpt", type=int, default=5000, help="Number of iters between checkpoints.")
+ parser.add_argument("--ckpt_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
+ parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
+ parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
+ parser.add_argument("--adamw_betas", type=tuple, default=(0.9, 0.95), help="Beta values for AdamW optimizer.")
+ parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
+ parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.")
+ parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
+
+ args = parser.parse_args()
+
+ train(
+ param_path=args.param_path,
+ data_root_path=args.data_root_path,
+ n_epoch=args.n_epoch,
+ batch_size=args.batch_size,
+ n_iter_step=args.n_iter_step,
+        warmup_step=args.warmup_step,
+ max_lr=args.max_lr,
+ dpo_beta=args.dpo_beta,
+ adamw_betas=args.adamw_betas,
+ adamw_weight_decay=args.adamw_weight_decay,
+ max_grad_norm=args.max_grad_norm,
+        embedding_lr_rate=args.embedding_lr_rate,
+ n_iter_ckpt=args.n_iter_ckpt,
+ ckpt_dir=args.ckpt_dir,
+ train_type=args.train_type,
+ random_seed=args.random_seed
+ )
\ No newline at end of file