commit a4443765eee0ecdeeed0b362f84bd94ad42a4c42 Author: ViperEkura <3081035982@qq.com> Date: Sat Sep 27 12:02:22 2025 +0800 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3ecf496 --- /dev/null +++ b/.gitignore @@ -0,0 +1,22 @@ +# cache +__pycache__ +.pytest_cache + +# params +*.safetensors +*.json +*.pkl +*.db + +# train_log +*.log + +# ignore file +-* + +# vscode file +.vscode + +# build file +build +*.egg-info \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..0d5fa89 --- /dev/null +++ b/README.md @@ -0,0 +1,333 @@ +![image-20250306182014120](/assets/images/project_logo_clipped.png) + +
+# KHAOSZ
+
+This is a Chinese-English bilingual Transformer model. The repository contains the model configuration and the full training workflow; training loads the parameters defined in `param_path/config.json`. The training script `train.py` parses command-line arguments, including the dataset root directory, number of training epochs, batch size, checkpoint interval, and checkpoint directory.
+
+**Model Download Options (Choose One):**
+
+1. Visit [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) and open **Files and versions**
+2. Run `scripts/download.py` to download the parameters
+
+**Demo Video:** [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd)
+
+Training dataset sources are listed in the **Model Card** section of the HuggingFace page above.
+
+**License:** The code is released under the Apache-2.0 license. Please credit the source code when you use it.
+
+- **📊 Device Selection:** The code defaults to CUDA for training
+- **🌐 Performance Optimization:** `dtype=torch.bfloat16` is enabled to accelerate training and reduce memory usage. Make sure your hardware supports it.
+- **🤖 Language Support:** The model supports Chinese and English. The BBPE tokenizer was trained on Chinese and English text only, so OOV (out-of-vocabulary) issues are rare for these two languages but may occur for others.
+
+### 📌 Training Guide
+
+To train this Transformer model, follow these steps:
+
+**(1). Prepare Dataset:**
+
+Place the datasets in the designated root directory. Files should be text documents in Chinese, English, or a mix of both. The format should match the model's input requirements: preferably pre-tokenized token ids stored as a `torch.Tensor` (a `torch.Tensor` uses far less memory than a Python list, whose integers default to 64-bit objects). A small sketch of this step appears at the end of this training guide.
+
+**(2). Install Dependencies:**
+
+```bash
+pip install -r requirements.txt
+pip install .
+```
+
+**(3). Run Training Script:**
+
+```bash
+python train.py \
+--train_type={seq|sft|dpo} \
+--data_root_path=/path/to/dataset \
+--param_path=/path/to/param_path \
+--n_epoch=5 \
+--batch_size=8 \
+--max_lr=2e-4 \
+--n_iter_ckpt=10000 \
+--ckpt_dir=checkpoints
+```
+
+**Parameters Explanation:**
+- `--train_type`: Training type (seq, sft, dpo)
+- `--data_root_path`: Root directory of the dataset
+- `--param_path`: Path to the model training parameters
+- `--n_epoch`: Total number of training epochs
+- `--batch_size`: Batch size
+- `--n_iter_step`: Number of batches per training step
+- `--warning_step`: Number of warmup steps
+- `--max_lr`: Maximum learning rate (warmup + cosine decay)
+- `--n_iter_ckpt`: Checkpoint saving interval, in iterations
+- `--ckpt_dir`: Directory to save checkpoints
+- `--resume_dir`: Resume training from the specified path
+
+Training logs are written to `train_log.txt`. Checkpoints are saved in the specified directory for resuming training or evaluation.
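+
+For step (1), the sketch below shows one possible way to pre-tokenize raw text into token ids stored as a `torch.Tensor`. The file names, the `int32` dtype, and the single-file layout are illustrative assumptions, not a format the repository enforces:
+
+```python
+import torch
+from khaosz import BpeTokenizer
+
+# hypothetical paths: point these at your own tokenizer file and corpus
+tokenizer = BpeTokenizer("param_path/tokenizer.json")
+
+with open("corpus.txt", "r", encoding="utf-8") as f:
+    text = f.read()
+
+# encode once, then store a compact integer tensor instead of a Python list
+token_ids = torch.tensor(tokenizer.encode(text), dtype=torch.int32)
+torch.save(token_ids, "dataset_root/corpus_tokens.pkl")
+```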
+### 👉 Usage Guide
+
+**(1). Chatting with the Model:**
+
+Run `chat.py`, or use the streaming/non-streaming interfaces directly:
+
+**Streaming Output:**
+```python
+import torch
+from khaosz import Khaosz
+
+model_dir = "your_model_parameter_dir"
+model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+history = []
+
+while True:
+    query = input(">> ")
+    if query == "!exit":
+        break
+
+    response_size = 0
+    for response, history in model.stream_generate(
+        query=query,
+        history=history,
+        temperature=0.85,
+        top_p=0.95,
+        top_k=50
+    ):
+        print(response[response_size:], end="")
+        response_size = len(response)
+```
+
+**Non-streaming Output:**
+```python
+import torch
+from khaosz import Khaosz
+
+model_dir = "your_model_parameter_dir"
+model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+history = []
+
+while True:
+    query = input(">> ")
+    if query == "!exit":
+        break
+
+    response = model.generate(
+        query=query,
+        history=history,
+        temperature=0.85,
+        top_p=0.95,
+        top_k=50
+    )
+    print(response)
+```
+
+**(2). Retrieval-Augmented Generation (RAG):**
+
+```python
+import torch
+from khaosz import Khaosz
+
+model_dir = "your_model_parameter_dir"
+model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+
+query = "your question here"
+retrieved_content = model.retrieve_generate(
+    query=query,
+    retrieve_top_k=5,
+    temperature=0.6,
+    top_k=30,
+    top_p=0.95
+)
+print(retrieved_content)
+```
+
+### 📌 Model Specifications
+
+This model is a 24-layer Transformer with parameters defined in `config.json`, totaling approximately 1.0 billion (1.0B) parameters.
+
+**Key Design Choices:**
+- Weight tying between the embedding and the final linear layer (a standard trick in small models to save parameters); a back-of-the-envelope sketch follows at the end of this README
+- Embedding layer optimization: without weight tying, the vocabulary's embedding table alone would consume roughly 102M parameters (~0.1B)
+
+**Limitations:**
+- May struggle with complex language phenomena due to the small parameter count
+- Prone to overfitting on specialized datasets
+- Limited multilingual capabilities
+
+**Advantages:**
+- Runs efficiently on lower-spec hardware
+- Shorter training time compared to larger models
+
+**Training Pipeline:**
+The model has been through the full pre-training + SFT (Supervised Fine-Tuning) + DPO (Direct Preference Optimization) workflow. All of the corresponding training code is included in the repository.
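+
+To make the weight-tying savings under Key Design Choices concrete, here is a back-of-the-envelope sketch (the vocabulary size and hidden width below are illustrative values, not the model's actual configuration):
+
+```python
+V, d = 32000, 1536            # illustrative vocab size and hidden width
+untied = 2 * V * d            # separate input embedding + output projection
+tied = V * d                  # one matrix shared by both layers
+print(f"untied: {untied / 1e6:.1f}M params, tied: {tied / 1e6:.1f}M params")
+# untied: 98.3M params, tied: 49.2M params
+```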


diff --git a/assets/docs/introduction.md b/assets/docs/introduction.md new file mode 100644 index 0000000..7333403 --- /dev/null +++ b/assets/docs/introduction.md @@ -0,0 +1,89 @@
+## Model Introduction
+
+### 1. Model Architecture
+
+This model uses a Transformer architecture with GQA (q_head=24, kv_head=4), which reduces the KV-cache memory footprint compared with conventional MHA (see `kvcache.md` for the cache implementation). The model is built by stacking 24 Transformer layers, for roughly 1.0B parameters. The Transformer is an autoregressive model: it computes the relationships with all preceding tokens to obtain the probability distribution of the next token.
+
+![structure](../images/structure.png)
+
+What is an autoregressive model? After a sentence is split into tokens, the model predicts a probability distribution over the next token. In other words, given the context (the sequence of tokens seen so far), the model computes the candidate next tokens and their probabilities.
+
+#### 1. Autoregression
+
+Suppose a sentence is split into the following token list:
+
+```
+["你好", ",", "今天", "天气"]
+```
+
+The model then predicts the next token from this sequence, given as a probability distribution, for example:
+
+```
+-> {"token": "不错", "probability": 0.4}
+-> {"token": "晴朗", "probability": 0.2}
+-> ......
+```
+
+Here "不错" and "晴朗" are two tokens that could follow "天气", each with its probability of being the next token.
+
+We then sample the next token (the result is tuned via the top_k, top_p, and temperature parameters) and append it to the sequence as the new input:
+
+```
+["你好", ",", "今天", "天气", "不错"]
+```
+
+This process repeats until a control token that ends generation (<|end_of_seqence|>) is produced, at which point the model stops (models generally define such control tokens; otherwise generation would keep going until memory is exhausted).
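+
+The sampling step described above can be sketched in a few lines of PyTorch (a standalone toy example; the repository's actual implementation is `apply_sampling_strategies` in `khaosz/core/generator.py`, and the vocabulary size here is made up):
+
+```python
+import torch
+
+logits = torch.randn(1, 32000)                     # toy next-token logits over the vocabulary
+temperature, top_k, top_p = 0.85, 50, 0.95
+
+logits = logits / temperature                      # <1 sharpens, >1 flattens the distribution
+
+# top-k: drop every token whose logit is below the k-th largest one
+kth_logit = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
+logits[logits < kth_logit] = float("-inf")
+
+# top-p: keep the smallest set of top tokens whose cumulative probability exceeds top_p
+sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
+cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
+remove = cum_probs > top_p
+remove[..., 1:] = remove[..., :-1].clone()         # shift right: always keep the top-1 token
+remove[..., 0] = False
+sorted_logits[remove] = float("-inf")
+logits.scatter_(-1, sorted_idx, sorted_logits)
+
+probs = torch.softmax(logits, dim=-1)
+next_token = torch.multinomial(probs, num_samples=1)   # the sampled next token id
+```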
+#### 2. Causal Mask
+
+The Transformer uses an attention mechanism. The input generally has shape [bsz, seq_len] and the output [bsz, seq_len, n_dim]. To make the model predict the next token, the inputs and the prediction targets must be offset by one position; the same one-position shift is used during training:
+
+```
+sequence :  [[1, 2, 3, 4, 5, 6]]
+input_ids:  [[1, 2, 3, 4, 5]]
+target_ids: [[2, 3, 4, 5, 6]]
+```
+
+The attention scores are computed as
+
+$$ s_{ij} = \mathrm{softmax}\left(\frac{q_i^T k_j}{\sqrt{d_k}} + mask_{ij}\right) $$
+
+where the mask is added to the raw scores before the softmax, and the score reflects how strongly the model attends to the similarity between two tokens.
+
+For a decoder-only model, the mask is needed to keep the model from stealing information from future positions during the attention computation. It is typically a lower-triangular pattern; for a sequence of length n it has shape [n, n]. For a sequence of length 5, the causal mask looks like this:
+
+```
+[[0, -inf, -inf, -inf, -inf],
+ [0, 0, -inf, -inf, -inf],
+ [0, 0, 0, -inf, -inf],
+ [0, 0, 0, 0, -inf],
+ [0, 0, 0, 0, 0]]
+```
+
+Here 0 marks positions that may be attended to, and -inf marks positions that must be hidden. The mask guarantees that every entry with $j > i$ becomes 0 after the softmax (since $e^{-\infty} = 0$), so the model cannot see future information.
+
+#### 3. Rotary Position Embedding (RoPE)
+
+Rotary Position Embedding (RoPE) is a positional-encoding method designed to address the Transformer's lack of any direct modeling of sequence position. Unlike traditional positional encodings (such as sinusoidal encodings), RoPE embeds the position information directly into the query (Q) and key (K) vectors, which lets the model handle relative positions within a sequence more naturally.
+
+$$ q_i = R_i W_q x_i $$
+$$ k_j = R_j W_k x_j $$
+$$ q_i^T k_j = (R_i W_q x_i)^T (R_j W_k x_j) = x_i^T W_q^T R_{j-i} W_k x_j $$
+
+Here the rotation $R_{j-i}$ depends only on the relative distance between the two tokens and controls how their attention decays with that distance: the larger $|i - j|$ is, the stronger the decay. This lets the model learn relative positional relationships and therefore extrapolate to and handle long sequences.
\ No newline at end of file
diff --git a/assets/docs/kvcache.md b/assets/docs/kvcache.md new file mode 100644 index 0000000..304c484 --- /dev/null +++ b/assets/docs/kvcache.md @@ -0,0 +1,27 @@
+## KV Cache Implementation
+
+Start from the attention formula
+
+$$
+\begin{align*}
+o_i &= \sum_j s_{ij} v_j \\
+s_{ij} &= \mathrm{softmax}_j\left(\frac{\sum_n q_{i,n} k_{j,n}}{\sqrt{d_k}}\right)
+\end{align*}
+$$
+
+Because the model is autoregressive, only the output for the last position of the sequence is needed: the row index $i$ is fixed to the final position $t$, so what we compute is $o_t$:
+
+$$
+\begin{align*}
+o_t &= \sum_j s_j v_j \\
+s_j &= \mathrm{softmax}_j\left(\frac{\sum_n q_{t,n} k_{j,n}}{\sqrt{d_k}}\right)
+\end{align*}
+$$
+
+Substituting $s_j$ into the sum:
+
+$$
+o_t = \sum_j \mathrm{softmax}_j\left(\frac{\sum_n q_{t,n} k_{j,n}}{\sqrt{d_k}}\right) v_j
+$$
+
+In this expression only $k$ and $v$ carry the sequence-position index $j$, while $q$ appears only at the current position. So during decoding the query input is just the last generated token, whereas the $k, v$ of all earlier positions must be cached. Note that the positional encoding has to be applied before entries are written into the KV cache; otherwise the positions will be encoded incorrectly.
\ No newline at end of file
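+
+A minimal single-head decoding step with a KV cache can be sketched as follows (a toy example that ignores batching, multiple heads, and RoPE; in the real model the rotary embedding is applied to q and k before caching, as noted above):
+
+```python
+import math
+import torch
+
+d_model, d_k = 128, 64
+Wq, Wk, Wv = (torch.randn(d_k, d_model) for _ in range(3))
+k_cache, v_cache = [], []                 # grow by one entry per decoded token
+
+def decode_step(x_t: torch.Tensor) -> torch.Tensor:
+    """x_t: [d_model] hidden state of the newest token only."""
+    q_t = Wq @ x_t                        # query is needed only for the last position
+    k_cache.append(Wk @ x_t)              # k/v of every position seen so far are cached
+    v_cache.append(Wv @ x_t)
+    K, V = torch.stack(k_cache), torch.stack(v_cache)   # [t, d_k]
+    s = torch.softmax(K @ q_t / math.sqrt(d_k), dim=0)  # attention over cached positions
+    return s @ V                          # o_t: [d_k]
+
+for _ in range(4):                        # four decoding steps; the cache grows 1 -> 4
+    o_t = decode_step(torch.randn(d_model))
+```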
diff --git a/assets/images/project_logo.png b/assets/images/project_logo.png new file mode 100644 index 0000000..342eb9c Binary files /dev/null and b/assets/images/project_logo.png differ
diff --git a/assets/images/project_logo_clipped.png b/assets/images/project_logo_clipped.png new file mode 100644 index 0000000..b02cd08 Binary files /dev/null and b/assets/images/project_logo_clipped.png differ
diff --git a/assets/images/structure.png b/assets/images/structure.png new file mode 100644 index 0000000..5f4f3a2 Binary files /dev/null and b/assets/images/structure.png differ
diff --git a/generate.py b/generate.py new file mode 100644 index 0000000..a18c433 --- /dev/null +++ b/generate.py @@ -0,0 +1,101 @@
+import os
+import json
+import torch
+import argparse
+
+from khaosz import Khaosz
+from typing import List
+from tqdm import tqdm
+
+
+PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
+
+def batch_generate(
+    model: Khaosz,
+    queries: List[str],
+    temperature: float,
+    top_k: int,
+    top_p: float,
+    batch_size: int,
+) -> List:
+    assert batch_size > 0
+    # sort by length (descending) so each batch contains similarly sized prompts,
+    # and remember the original positions via indices (robust to duplicate queries)
+    order = sorted(range(len(queries)), key=lambda i: len(queries[i]), reverse=True)
+    sorted_queries = [queries[i] for i in order]
+
+    responses = [None] * len(queries)
+
+    for start in tqdm(range(0, len(sorted_queries), batch_size), desc="Generating responses"):
+        batch_indices = order[start: start + batch_size]
+        batch_queries = sorted_queries[start: start + batch_size]
+
+        batch_responses = model.batch_generate(
+            queries=batch_queries,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p
+        )
+
+        for batch_query, batch_response in zip(batch_queries, batch_responses):
+            print(f"Q: {batch_query[:50]}\nR: {batch_response[:50]}")
+
+        # write each response back to its original position
+        for original_idx, response in zip(batch_indices, batch_responses):
+            responses[original_idx] = response
+
+    return responses
+
+
+def processor(
+    model: Khaosz,
+    input_json_file: str,
+    output_json_file: str,
+    batch_size: int,
+    temperature: float,
+    top_p: float,
+    top_k: int,
+    question_key: str="question",
+):
+    with open(input_json_file, "r", encoding='utf-8') as f:
+        input_dict = [json.loads(line) for line in f]
+    queries = [item[question_key] for item in input_dict]
+
+    responses = batch_generate(
+        model=model,
+        queries=queries,
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        batch_size=batch_size
+    )
+
+    with open(output_json_file, "w", encoding='utf-8') as f:
+        json.dump(responses, f, indent=4, ensure_ascii=False)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
+
+    parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
+    parser.add_argument("--input_json_file", type=str, required=True, help="Path to the input JSONL file.")
+    parser.add_argument("--output_json_file", type=str, required=True, help="Path to the output JSON file.")
+    parser.add_argument("--question_key", type=str, default="question", help="Key for the question in the input JSON.")
+    parser.add_argument("--temperature", type=float, default=0.60, help="Temperature for generating responses.")
+    parser.add_argument("--top_p", type=float, default=0.95, help="Top-p value for generating responses.")
+    parser.add_argument("--top_k", type=int, default=30, help="Top-k value for generating responses.")
+    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.")
+
+    args = parser.parse_args()
+    model = Khaosz(args.model_dir).to(device='cuda', dtype=torch.bfloat16)
+
+    processor(
+        model,
+        input_json_file=args.input_json_file,
+        output_json_file=args.output_json_file,
+        question_key=args.question_key,
+        batch_size=args.batch_size,
+        temperature=args.temperature,
+        top_k=args.top_k,
+        top_p=args.top_p
+    )
\ No newline at end of file
diff --git a/khaosz/__init__.py b/khaosz/__init__.py new file mode 100644 index 0000000..b4bbcf5 --- /dev/null +++ b/khaosz/__init__.py @@ -0,0 +1,52 @@
+__version__ = "1.2.1"
+__author__ = "ViperEkura"
+
+from khaosz.model import Khaosz
+from khaosz.core.transformer import Transformer, TransformerConfig
+from khaosz.utils.retriever import Retriever
+from khaosz.utils.splitter import (
+    SemanticTextSplitter,
+    PriorityTextSplitter
+)
+from khaosz.core.tokenizer import BpeTokenizer
+from khaosz.core.parameter import ParameterLoader
+from khaosz.core.generator import (
+    TextGenerator,
+    ChatGenerator,
+    StreamGenerator,
+    BatchGenerator,
+    RetrievalGenerator,
+    EmbeddingEncoder
+)
+from khaosz.trainer.trainer import Trainer
+from khaosz.trainer.dataset import SeqDataset, SftDataset, DpoDataset, BaseDataset
+
+
+__all__ = [
+    # model
+    "Khaosz",
+
+    # module
+    "Transformer",
+    "TransformerConfig",
+    "BpeTokenizer",
+    "ParameterLoader",
+    "TextGenerator",
+    "ChatGenerator",
+    "StreamGenerator",
+    "BatchGenerator",
+    "RetrievalGenerator",
+    "EmbeddingEncoder",
+
+    # trainer
+    "Trainer",
+    "SeqDataset",
+    "SftDataset",
+    "DpoDataset",
+    "BaseDataset",
+
+    # utils
+    "Retriever",
+    "SemanticTextSplitter",
+    "PriorityTextSplitter",
+]
diff --git a/khaosz/core/__init__.py b/khaosz/core/__init__.py new file mode 100644 index 0000000..0b31b94 --- /dev/null +++ b/khaosz/core/__init__.py @@ -0,0 +1,27 @@
+from khaosz.core.tokenizer import BpeTokenizer
+from khaosz.core.transformer import Transformer, TransformerConfig
+from khaosz.core.parameter import ParameterLoader, ModelParameter, Checkpoint
+from khaosz.core.generator import (
+    TextGenerator,
+    ChatGenerator,
+    StreamGenerator,
+    BatchGenerator,
+    RetrievalGenerator,
+    EmbeddingEncoder
+)
+
+
+__all__ = [
+    "Transformer",
"TransformerConfig", + "BpeTokenizer", + "ParameterLoader", + "ModelParameter", + "Checkpoint", + "TextGenerator", + "ChatGenerator", + "StreamGenerator", + "BatchGenerator", + "RetrievalGenerator", + "EmbeddingEncoder" +] \ No newline at end of file diff --git a/khaosz/core/generator.py b/khaosz/core/generator.py new file mode 100644 index 0000000..3aba293 --- /dev/null +++ b/khaosz/core/generator.py @@ -0,0 +1,568 @@ +import torch +from torch import Tensor +from typing import List, Tuple, Union, Optional, Generator, Self +from khaosz.core.parameter import ModelParameter + + +def build_prompt(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str: + """ + Build prompt for query and history + + Args: + query(str): query string + history(Optional[List[Tuple[str, str]]]): history list of query and response + + Returns: + str: prompt string + + """ + prompt_parts = [] + + if history is None: + history = [] + + for his_query, his_response in history: + prompt_parts.append(f"<|user|> {his_query} <|system|> {his_response}") + + if query is not None: + prompt_parts.append(f"<|user|> {query} <|system|> ") + + return "\n".join(prompt_parts) + +def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]: + """ + Pad a list of sequences to a fixed length. + + Args: + ids_list (List[List[int]]): A list of sequences. + max_ids_len (int): The maximum length of sequences. + pad_id (int): The id to pad sequences. + + Returns: + List[List[int]]: A list of padded sequences. + + """ + new_ids_list = [] + for ids in ids_list: + pad_len = max_ids_len - len(ids) + padded_seq = [pad_id] * pad_len + ids + new_ids_list.append(padded_seq) + + return new_ids_list + +def apply_sampling_strategies( + logits: Tensor, + temperature: float, + top_k: int, + top_p: float, + filter_value: float = -float("inf") +) -> Tensor: + """ + Apply sampling strategies to the logits tensor. + + Args: + logits (Tensor): The logits tensor. + temperature (float): The temperature parameter. + top_k (int): The top-k parameter. + top_p (float): The top-p parameter. + filter_value (float, optional): The filter value. Defaults to -float("inf"). + + Returns: + Tensor: The sampled logits tensor. 
+ + """ + + if temperature != 1.0: + logits = logits / temperature + + if top_k > 0: + top_k = min(top_k, logits.size(-1)) + indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = torch.zeros_like(logits, dtype=torch.bool) + indices_to_remove.scatter_( + dim=1, + index=sorted_indices, + src=sorted_indices_to_remove + ) + + logits[indices_to_remove] = filter_value + + return logits + + +class KVCacheManager: + def __init__( + self, + num_layers: int, + batch_size: int, + max_len: int, + num_heads: int, + head_dim: int, + device: torch.device = "cuda", + dtype: torch.dtype = torch.bfloat16 + ): + self.num_layers = num_layers + self.batch_size = batch_size + self.max_len = max_len + self.num_heads = num_heads + self.head_dim = head_dim + self.device = device + self.dtype = dtype + + self._kv_cache: List[Tuple[Tensor, Tensor]] = None + self._seq_mask: Tensor = None + self._initialize() + + def _initialize(self): + self._kv_cache = [] + for _ in range(self.num_layers): + k_cache = torch.zeros( + (self.batch_size, self.max_len, self.num_heads, self.head_dim), + device=self.device, dtype=self.dtype + ) + v_cache = torch.zeros( + (self.batch_size, self.max_len, self.num_heads, self.head_dim), + device=self.device, dtype=self.dtype + ) + self._kv_cache.append((k_cache, v_cache)) + + self._seq_mask = torch.ones( + (self.batch_size, self.max_len), + device=self.device, dtype=torch.bool + ) + + def update(self, active_mask: Tensor): + for i in range(self.num_layers): + k_cache, v_cache = self._kv_cache[i] + new_k_cache, new_v_cache = k_cache[active_mask], v_cache[active_mask] + self._kv_cache[i] = (new_k_cache, new_v_cache) + + self._seq_mask = self._seq_mask[active_mask] + + def reset(self, full_reset=False): + if full_reset: + self._kv_cache = None + self._seq_mask = None + else: + self._initialize() + + def set_seq_mask(self, input_ids: Tensor, pad_id: int): + batch_size, seq_len = input_ids.shape + bool_mask = (input_ids != pad_id) + self._seq_mask[: batch_size, : seq_len] = bool_mask + + def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]: + return self._kv_cache + + def get_seq_mask(self) -> Tensor: + return self._seq_mask + + +class GeneratorCore: + def __init__(self, parameter: ModelParameter): + self.model = parameter.model + self.tokenizer = parameter.tokenizer + self.config = parameter.config + + def compute_logits( + self, + input_ids: Tensor, + attn_mask: Optional[Tensor] = None, + kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None, + start_pos: int = 0 + ) -> Tuple[Tensor, int]: + with torch.inference_mode(): + outputs = self.model(input_ids, attn_mask, kv_caches, start_pos) + logits = outputs["logits"][:, -1, :] + cache_increase = input_ids.size(-1) + + return logits, cache_increase + + def to(self, *args, **kargs) -> Self: + self.model.to(*args, **kargs) + return self + + +class EmbeddingEncoderCore: + def __init__(self, parameter: ModelParameter): + self.model = parameter.model + self.tokenizer = parameter.tokenizer + self.config = parameter.config + + def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]: + 
with_batch = isinstance(sentence, list)
+        ids = self.tokenizer.encode(sentence)
+        batch_ids = ids if with_batch else [ids]
+        max_model_len = self.config.m_len
+
+        all_fragments = []
+        fragment_origin_idx = []
+
+        # split over-long sequences into model-sized fragments, remembering
+        # which original sentence each fragment came from
+        for i, seq in enumerate(batch_ids):
+            if len(seq) > max_model_len:
+                fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
+                all_fragments.extend(fragments)
+                fragment_origin_idx.extend([i] * len(fragments))
+            else:
+                all_fragments.append(seq)
+                fragment_origin_idx.append(i)
+
+        # no usable fragments (empty input)
+        if not all_fragments or not ids:
+            return [] if with_batch else torch.tensor([])
+
+        device = next(self.model.parameters()).device
+        max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
+
+        padded_ids = []
+        masks = []
+        for seq in all_fragments:
+            pad_len = max_len - len(seq)
+            padded_seq = seq + [self.tokenizer.pad_id] * pad_len
+            mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
+            padded_ids.append(padded_seq)
+            masks.append(mask)
+
+        input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
+        seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
+
+        with torch.inference_mode():
+            outputs = self.model(input_tensor, seq_mask)["hidden_states"]
+            # zero out padded positions: [num_fragments, seq_len, hidden_size]
+            fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
+
+        sentence_embs: List[Tensor] = []
+        for i in range(len(batch_ids)):
+            indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
+            if indices:
+                sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1)  # [frags, hidden_size]
+                length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1)  # [frags, 1]
+                # mean-pool each fragment over its valid tokens, then sum over fragments
+                emb = torch.sum(sum_frags / length, dim=0)  # [hidden_size]
+                sentence_embs.append(emb.flatten())
+
+        if with_batch:
+            return [emb.flatten() for emb in sentence_embs]
+        else:
+            return sentence_embs[0].flatten()
+
+    def to(self, *args, **kargs) -> Self:
+        self.model.to(*args, **kargs)
+        return self
+
+
+class TextGenerator(GeneratorCore):
+    def __init__(self, parameter: ModelParameter):
+        super().__init__(parameter)
+
+    def generate(
+        self,
+        query: str,
+        temperature: float,
+        top_k: int,
+        top_p: float,
+    ) -> str:
+        assert temperature >= 0.0
+        assert top_k >= 0
+        assert top_p >= 0.0 and top_p <= 1.0
+
+        device = next(self.model.parameters()).device
+        cache_manager = KVCacheManager(
+            num_layers=self.config.n_layer,
+            batch_size=1,
+            max_len=self.config.m_len,
+            num_heads=self.config.n_kvhead,
+            head_dim=self.config.n_dim // self.config.n_head,
+            device=device,
+        )
+
+        ids = self.tokenizer.encode(query)
+        input_ids = torch.tensor([ids], device=device, dtype=torch.long)
+
+        start_cache_pos = len(ids)
+        cur_cache_pos = 0
+        self.model.eval()
+
+        while len(ids) < self.config.m_len:
+            kv_caches = cache_manager.get_kvcache()
+            logits, cache_increase = self.compute_logits(
+                input_ids,
+                kv_caches=kv_caches,
+                start_pos=cur_cache_pos
+            )
+            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
+            probs = torch.softmax(logits, dim=-1)
+            next_token_id = torch.multinomial(probs, num_samples=1)
+
+            # feed only the newly sampled token back in; k/v of earlier
+            # positions are already held in the cache
+            input_ids = next_token_id
+            ids.append(next_token_id.item())
+            cur_cache_pos += cache_increase
+
+            if next_token_id.item() in self.tokenizer.stop_ids:
+                break
+
+        response = self.tokenizer.decode(ids[start_cache_pos:])
+
+        return response
+
+
+class ChatGenerator(GeneratorCore):
+    def __init__(self, parameter: ModelParameter):
+        super().__init__(parameter)
+
+    def generate(
+        self,
+        query: str,
+        history: 
List[Tuple[str, str]], + temperature: float, + top_k: int, + top_p: float, + ) -> str: + + assert temperature >= 0.0 + assert top_k >= 0 + assert top_p >= 0.0 and top_p <= 1.0 + + if history is None: + history = [] + + device = next(self.model.parameters()).device + cache_manager = KVCacheManager( + num_layers=self.config.n_layer, + batch_size=1, + max_len=self.config.m_len, + num_heads=self.config.n_kvhead, + head_dim=self.config.n_dim // self.config.n_head, + device=device, + ) + ids = self.tokenizer.encode(build_prompt(query, history)) + input_ids = torch.tensor([ids], device=device, dtype=torch.long) + cpy_history = history.copy() + + start_cache_pos = len(ids) + cur_cache_pos = 0 + self.model.eval() + + + while len(ids) < self.config.m_len: + kv_caches = cache_manager.get_kvcache() + logits, cache_increase = self.compute_logits( + input_ids, + kv_caches=kv_caches, + start_pos=cur_cache_pos + ) + logits = apply_sampling_strategies(logits, temperature, top_k, top_p) + probs = torch.softmax(logits, dim=-1) + next_token_id = torch.multinomial(probs, num_samples=1) + + input_ids = next_token_id + ids.append(next_token_id.item()) + cur_cache_pos += cache_increase + + if next_token_id.item() in self.tokenizer.stop_ids: + break + + response = self.tokenizer.decode(ids[start_cache_pos:]) + cpy_history.append((query, response)) + + return response, cpy_history + + +class StreamGenerator(GeneratorCore): + def __init__(self, parameter: ModelParameter): + super().__init__(parameter) + + def generate( + self, + query: str, + history: List[Tuple[str, str]], + temperature: float, + top_k: int, + top_p: float, + ) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]: + + assert temperature >= 0.0 + assert top_k >= 0 + assert top_p >= 0.0 and top_p <= 1.0 + + if history is None: + history = [] + + device = next(self.model.parameters()).device + cache_manager = KVCacheManager( + num_layers=self.config.n_layer, + batch_size=1, + max_len=self.config.m_len, + num_heads=self.config.n_kvhead, + head_dim=self.config.n_dim // self.config.n_head, + device=device, + ) + ids = self.tokenizer.encode(build_prompt(query, history)) + input_ids = torch.tensor([ids], device=device, dtype=torch.long) + cpy_history = history.copy() + + start_cache_pos = len(ids) + cur_cache_pos = 0 + self.model.eval() + + + while len(ids) < self.config.m_len: + kv_caches = cache_manager.get_kvcache() + logits, cache_increase = self.compute_logits( + input_ids, + kv_caches=kv_caches, + start_pos=cur_cache_pos + ) + logits = apply_sampling_strategies(logits, temperature, top_k, top_p) + probs = torch.softmax(logits, dim=-1) + next_token_id = torch.multinomial(probs, num_samples=1) + + input_ids = next_token_id + ids.append(next_token_id.item()) + cur_cache_pos += cache_increase + + response = self.tokenizer.decode(ids[start_cache_pos:]) + yield response, cpy_history + [(query, response)] + + if next_token_id.item() in self.tokenizer.stop_ids: + yield response + "\n", cpy_history + [(query, response)] + break + + +class BatchGenerator(GeneratorCore): + def __init__(self, parameter: ModelParameter): + super().__init__(parameter) + + def generate( + self, + queries: List[str], + histories: List[List[Tuple[str, str]]], + temperature: float, + top_k: int, + top_p: float + ) -> List[str]: + + assert temperature >= 0.0 + assert top_k >= 0 + assert top_p >= 0.0 and top_p <= 1.0 + + batch_size = len(queries) + if histories is None: + histories = [[] for _ in range(batch_size)] + + prompts = [build_prompt(query, history) for query, history in 
zip(queries, histories)]
+        ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
+        max_ids_len = max(len(ids) for ids in ids_list)
+        # left-pad all prompts to the same length
+        ids_list = pad_sequence(ids_list, max_ids_len, self.tokenizer.pad_id)
+
+        device = next(self.model.parameters()).device
+        cache_manager = KVCacheManager(
+            num_layers=self.config.n_layer,
+            batch_size=batch_size,
+            max_len=self.config.m_len,
+            num_heads=self.config.n_kvhead,
+            head_dim=self.config.n_dim // self.config.n_head,
+            device=device,
+        )
+
+        input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
+        cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
+        activate_task_mask = [True] * batch_size
+
+        start_cache_pos = max_ids_len
+        cur_cache_pos = 0
+
+        while max_ids_len < self.config.m_len and any(activate_task_mask):
+            kv_caches = cache_manager.get_kvcache()
+            attn_mask = cache_manager.get_seq_mask()
+
+            logits, cache_increase = self.compute_logits(
+                input_tensor,
+                attn_mask=attn_mask,
+                kv_caches=kv_caches,
+                start_pos=cur_cache_pos
+            )
+
+            cur_cache_pos += cache_increase
+            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
+            probs = torch.softmax(logits, dim=-1)
+            next_token_id = torch.multinomial(probs, num_samples=1)
+
+            active_mask = []
+            c_ids = 0
+
+            # append the sampled token to every still-active sequence and mark
+            # sequences that just emitted a stop token as finished
+            for i in range(batch_size):
+                if activate_task_mask[i]:
+                    token = next_token_id[c_ids, :].item()
+                    ids_list[i].append(token)
+                    c_ids += 1
+
+                    is_active = token not in self.tokenizer.stop_ids
+                    activate_task_mask[i] = is_active
+                    active_mask.append(is_active)
+
+            # shrink the KV cache and the input batch to the active sequences
+            active_mask = torch.tensor(active_mask, device=device, dtype=torch.bool)
+            cache_manager.update(active_mask)
+            input_tensor = next_token_id[active_mask, :]
+
+            max_ids_len += 1
+
+        responses = [str()] * batch_size
+        for i in range(batch_size):
+            responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:])
+            histories[i].append((queries[i], responses[i]))
+
+        return responses
+
+
+class RetrievalGenerator(GeneratorCore):
+    def __init__(self, retriever_parameter: ModelParameter):
+        super().__init__(retriever_parameter)
+
+    def generate(
+        self,
+        retrieved: List[str],
+        query: str,
+        history: List[Tuple[str, str]],
+        temperature: float,
+        top_k: int,
+        top_p: float,
+    ) -> str:
+        assert temperature >= 0.0
+        assert top_k >= 0
+        assert top_p >= 0.0 and top_p <= 1.0
+
+        if history is None:
+            history = []
+
+        retrieved = "\n".join([f"{idx + 1}. 
{key}" for idx, key in enumerate(retrieved)]) if retrieved else "" + retrieved_query = f"{retrieved}\n\n根据以上内容回答: {query}" if retrieved else query + parameter = ModelParameter(self.model, self.tokenizer, self.config) + + return ChatGenerator(parameter).generate( + retrieved_query, + history, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + +class EmbeddingEncoder(EmbeddingEncoderCore): + def __init__(self, parameter: ModelParameter): + super().__init__(parameter) + + def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]: + return super().encode(sentence) + \ No newline at end of file diff --git a/khaosz/core/parameter.py b/khaosz/core/parameter.py new file mode 100644 index 0000000..f752792 --- /dev/null +++ b/khaosz/core/parameter.py @@ -0,0 +1,238 @@ +import pickle as pkl +import matplotlib.pyplot as plt +import safetensors.torch as st +import torch.nn as nn +import torch.optim as optim + +from dataclasses import dataclass, field +from typing import Optional, Self, Union +from pathlib import Path + +from khaosz.core.tokenizer import BpeTokenizer +from khaosz.core.transformer import TransformerConfig, Transformer + + +class BaseModelIO: + """Base class for model I/O operations.""" + + def __init__( + self, + model: Optional[nn.Module] = None, + tokenizer: Optional[BpeTokenizer] = None, + config: Optional[TransformerConfig] = None + ): + self.model = model + self.tokenizer = tokenizer or BpeTokenizer() + self.config = config or TransformerConfig() + + def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]: + """Get standardized file paths for model components.""" + dir_path = Path(directory) + return { + "model": dir_path / "model.safetensors", + "config": dir_path / "config.json", + "tokenizer": dir_path / "tokenizer.json" + } + + def save_components(self, save_dir: Union[str, Path]): + """Save core model components.""" + paths = self._get_file_paths(save_dir) + paths["model"].parent.mkdir(parents=True, exist_ok=True) + + if self.model is not None: + st.save_file(self.model.state_dict(), str(paths["model"])) + self.config.save(str(paths["config"])) + self.tokenizer.save(str(paths["tokenizer"])) + + def load_components(self, load_dir: Union[str, Path]) -> Self: + """Load core model components.""" + paths = self._get_file_paths(load_dir) + + self.config.load(str(paths["config"])) + self.tokenizer.load(str(paths["tokenizer"])) + + if paths["model"].exists(): + state_dict = st.load_file(str(paths["model"])) + if self.model is None: + self.model = Transformer(self.config) + self.model.load_state_dict(state_dict) + + return self + + def to(self, *args, **kwargs) -> Self: + """Move model to device.""" + if self.model is not None: + self.model.to(*args, **kwargs) + return self + + +@dataclass +class ModelParameter(BaseModelIO): + """Container for model parameters with serialization capabilities.""" + + model: Optional[nn.Module] = field( + default=None, + metadata={"help": "Transformer model."} + ) + tokenizer: BpeTokenizer = field( + default_factory=BpeTokenizer, + metadata={"help": "Tokenizer for the model."} + ) + config: TransformerConfig = field( + default_factory=TransformerConfig, + metadata={"help": "Transformer model configuration."} + ) + + def save(self, save_dir: Union[str, Path]): + """Save model parameters.""" + self.save_components(save_dir) + + def load(self, load_dir: Union[str, Path]) -> Self: + """Load model parameters.""" + return self.load_components(load_dir) + + +@dataclass +class Checkpoint(BaseModelIO): + 
"""Extended model parameters with training state.""" + + model: Optional[nn.Module] = field( + default=None, + metadata={"help": "Transformer model."} + ) + tokenizer: BpeTokenizer = field( + default_factory=BpeTokenizer, + metadata={"help": "Tokenizer for the model."} + ) + config: TransformerConfig = field( + default_factory=TransformerConfig, + metadata={"help": "Transformer model configuration."} + ) + loss_list: list[float] = field( + default_factory=list, + metadata={"help": "List of training losses."} + ) + current_iter: int = field( + default=0, + metadata={"help": "Current training iteration."} + ) + optimizer: Optional[optim.Optimizer] = field( + default=None, + metadata={"help": "Optimizer state."} + ) + + def __post_init__(self): + # Ensure current_iter matches loss list length if not explicitly set + if self.current_iter == 0 and self.loss_list: + self.current_iter = len(self.loss_list) + + def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]: + """Get file paths for training-specific files.""" + paths = self._get_file_paths(directory) + paths.update({ + "loss_list": paths["model"].parent / "loss.pkl", + "loss_plot": paths["model"].parent / "loss.png", + "optimizer": paths["model"].parent / "optimizer.pkl" + }) + return paths + + def save_training_state(self, save_dir: Union[str, Path]): + """Save training-specific state.""" + paths = self._get_training_paths(save_dir) + + # Save loss plot + self._plot_loss(str(paths["loss_plot"])) + + # Save loss list + with open(str(paths["loss_list"]), "wb") as f: + pkl.dump(self.loss_list, f) + + # Save optimizer state + if self.optimizer is not None: + with open(str(paths["optimizer"]), "wb") as f: + pkl.dump(self.optimizer.state_dict(), f) + + def load_training_state(self, load_dir: Union[str, Path]) -> Self: + """Load training-specific state.""" + paths = self._get_training_paths(load_dir) + + # Load loss list + if paths["loss_list"].exists(): + with open(str(paths["loss_list"]), "rb") as f: + self.loss_list = pkl.load(f) + self.current_iter = len(self.loss_list) + + # Load optimizer state + if paths["optimizer"].exists() and self.optimizer is not None: + with open(str(paths["optimizer"]), "rb") as f: + optim_state = pkl.load(f) + self.optimizer.load_state_dict(optim_state) + + return self + + def _plot_loss(self, save_path: str): + """Plot and save loss curve.""" + if not self.loss_list: + return + + plt.figure(figsize=(10, 6)) + plt.plot(self.loss_list) + plt.title(f"Training Loss - Iteration {self.current_iter}") + plt.xlabel("Batch") + plt.ylabel("Loss") + plt.grid(True) + plt.savefig(save_path, dpi=300, bbox_inches="tight") + plt.close() + + def save(self, save_dir: Union[str, Path]): + """Save complete checkpoint.""" + self.save_components(save_dir) + self.save_training_state(save_dir) + + def load(self, load_dir: Union[str, Path]) -> Self: + """Load complete checkpoint.""" + self.load_components(load_dir) + self.load_training_state(load_dir) + return self + + +class ParameterLoader: + """Factory class for loading model parameters or checkpoints.""" + + @staticmethod + def load(load_dir: Union[str, Path]) -> Union[ModelParameter, Checkpoint]: + """Load either ModelParameter or Checkpoint based on directory contents.""" + load_dir = Path(load_dir) + + # Check for training-specific files + loss_file = load_dir / "loss.pkl" + has_training_data = loss_file.exists() + + # Create appropriate instance + if has_training_data: + checkpoint = Checkpoint() + checkpoint.load(str(load_dir)) + return checkpoint + else: 
+            params = ModelParameter()
+            params.load(str(load_dir))
+            return params
+
+    @staticmethod
+    def create_checkpoint(
+        model: nn.Module,
+        tokenizer: BpeTokenizer,
+        config: TransformerConfig,
+        loss_list: Optional[list[float]] = None,
+        optimizer: Optional[optim.Optimizer] = None
+    ) -> Checkpoint:
+        """Convenience method to create a training checkpoint."""
+        return Checkpoint(
+            model=model,
+            tokenizer=tokenizer,
+            config=config,
+            loss_list=loss_list or [],
+            optimizer=optimizer
+        )
+
+
diff --git a/khaosz/core/tokenizer.py b/khaosz/core/tokenizer.py new file mode 100644 index 0000000..1f0f1ed --- /dev/null +++ b/khaosz/core/tokenizer.py @@ -0,0 +1,111 @@
+from tokenizers import Tokenizer, Encoding
+from tokenizers import decoders, processors, normalizers, pre_tokenizers
+from tokenizers.models import BPE
+from tokenizers.trainers import BpeTrainer
+from typing import List, Union
+
+
+class BpeTokenizer:
+    def __init__(self, path=None):
+        # control tokens for begin-of-sequence, end-of-sequence and padding
+        self._control_tokens = ["<bos>", "<eos>", "<pad>"]
+        self._special_tokens = ["<|user|>", "<|system|>"]
+        model = BPE()
+        tokenizer = Tokenizer(model)
+        tokenizer.normalizer = normalizers.Sequence([
+            normalizers.NFC()
+        ])
+        tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
+            pre_tokenizers.Punctuation(behavior="isolated"),
+            pre_tokenizers.Metaspace(prepend_scheme="never"),
+            pre_tokenizers.Split(pattern=r"(\d+|[a-zA-Z]+|(?:'s|'t|'re|'ve|'m|'ll|'d))", behavior="isolated"),
+            pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
+        ])
+        tokenizer.decoder = decoders.Sequence([
+            decoders.ByteLevel(),
+            decoders.Metaspace(prepend_scheme="never")
+        ])
+        tokenizer.post_processor = processors.Sequence([
+            processors.ByteLevel(trim_offsets=False)
+        ])
+        self._tokenizer = tokenizer
+
+        if path is not None:
+            self._tokenizer = Tokenizer.from_file(path)
+
+    def _prepare_trainer(self, vocab_size: int, min_freq: int, reserved_token_size: int) -> tuple:
+        assert reserved_token_size > len(self._special_tokens)
+        reserved_tokens = [f"<|rsv{i:02d}|>" for i in range(reserved_token_size - len(self._special_tokens))]
+        detail_vocab_size = vocab_size - (len(reserved_tokens) + len(self._special_tokens))
+
+        alphabet = pre_tokenizers.ByteLevel.alphabet()
+        min_size = len(alphabet) + len(self._control_tokens)
+        assert detail_vocab_size > min_size
+
+        trainer = BpeTrainer(
+            vocab_size=detail_vocab_size,
+            min_frequency=min_freq,
+            limit_alphabet=detail_vocab_size // 4,
+            max_token_length=18,
+            special_tokens=self._control_tokens,
+            show_progress=True,
+            initial_alphabet=alphabet,
+        )
+
+        return trainer, detail_vocab_size, reserved_tokens
+
+    def train(self, files, vocab_size, min_freq, reserved_token_size=100):
+        trainer, _, reserved_tokens = self._prepare_trainer(
+            vocab_size=vocab_size,
+            min_freq=min_freq,
+            reserved_token_size=reserved_token_size
+        )
+        self._tokenizer.train(files=files, trainer=trainer)
+        self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
+
+    def train_from_iterator(self, iterator, vocab_size, min_freq, reserved_token_size=100):
+        trainer, _, reserved_tokens = self._prepare_trainer(
+            vocab_size=vocab_size,
+            min_freq=min_freq,
+            reserved_token_size=reserved_token_size
+        )
+        self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
+        self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
+
+    def save(self, path):
+        self._tokenizer.save(path)
+
+    def load(self, path):
+        self._tokenizer = Tokenizer.from_file(path)
+
+    def encode(self, tokens: Union[str, List[str]], out_ids: bool=True, 
add_special_tokens: bool=False) -> List:
+        if isinstance(tokens, str):
+            encoded: Encoding = self._tokenizer.encode(tokens, add_special_tokens=add_special_tokens)
+            return encoded.ids if out_ids else encoded.tokens
+        elif isinstance(tokens, list):
+            encoded_list: List[Encoding] = self._tokenizer.encode_batch(tokens, add_special_tokens=add_special_tokens)
+            return [encoded.ids if out_ids else encoded.tokens for encoded in encoded_list]
+
+    def decode(self, tokens: List[int], skip_special_tokens: bool=True) -> str:
+        return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
+
+    def __len__(self) -> int:
+        return self._tokenizer.get_vocab_size()
+
+    @property
+    def stop_ids(self) -> List[int]:
+        stop_ids = []
+        for token in self._control_tokens:
+            stop_ids.append(self._tokenizer.token_to_id(token))
+        return stop_ids
+
+    @property
+    def bos_id(self) -> int:
+        return self._tokenizer.token_to_id("<bos>")
+
+    @property
+    def eos_id(self) -> int:
+        return self._tokenizer.token_to_id("<eos>")
+
+    @property
+    def pad_id(self) -> int:
+        return self._tokenizer.token_to_id("<pad>")
diff --git a/khaosz/core/transformer.py b/khaosz/core/transformer.py new file mode 100644 index 0000000..b5a2d66 --- /dev/null +++ b/khaosz/core/transformer.py @@ -0,0 +1,341 @@
+import json
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch import Tensor
+from torch.nn import init
+from dataclasses import asdict, dataclass
+from typing import List, Optional, Self, Tuple
+
+
+def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
+    """
+    Repeat the KV heads n_rep times along the head dimension (for GQA).
+    Args:
+        x (Tensor): The input tensor.
+        n_rep (int): The number of repetitions.
+    Returns:
+        Tensor: The repeated tensor.
+    """
+
+    bs, slen, n_heads, head_dim = x.shape
+    if n_rep == 1:
+        return x
+    return (
+        x[:, :, :, None, :]
+        .expand(bs, slen, n_heads, n_rep, head_dim)
+        .reshape(bs, slen, n_heads * n_rep, head_dim)
+    )
+
+def get_rotary_emb(
+        dim: int,
+        max_len: int,
+        base: float = 10000,
+        device: torch.device = "cuda",
+    ) -> torch.Tensor:
+    """
+    Get the rotary embedding for the given dimension and maximum length.
+    Args:
+        dim (int): The dimension of the input.
+        max_len (int): The maximum length of the input.
+        base (float, optional): The base for the frequency. Defaults to 10000.
+        device (torch.device, optional): The device to use. Defaults to "cuda".
+    Returns:
+        Tensor: The rotary embedding tensor.
+    """
+
+    theta = base ** (-torch.arange(0, dim, 2, device=device).float() / dim)
+    t = torch.arange(0, max_len, device=device).float()
+    freqs = torch.outer(t, theta)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+
+    return freqs_cis
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+    """
+    Apply rotary embedding to the input tensor.
+    Args:
+        x (Tensor): The input tensor.
+        freqs_cis (Tensor): The rotary embedding tensor.
+    Returns:
+        Tensor: The output tensor.
+    """
+
+    dtype = x.dtype
+    seq_len = x.size(1)
+
+    x_complex = torch.view_as_complex(x.view(*x.shape[:-1], -1, 2).float())
+    freqs_cis = freqs_cis.reshape(1, seq_len, 1, -1)
+    x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3)
+
+    return x_out.to(dtype)
+
+def create_attention_mask(
+        seq_mask: Tensor,
+        start_pos: int = 0,
+        seq_len: int = 0,
+        is_causal: bool = False,
+        device: torch.device = "cuda",
+        dtype: torch.dtype = torch.float32
+    ) -> Tensor:
+    """
+    Create attention mask for GQA
+    Args:
+        seq_mask (Tensor): A tensor indicating whether each position is valid or not.
+ start_pos (int): The starting position of the sequence. + seq_len (int): The length of the sequence. + is_causal (bool): Whether the attention is causal or not. + device (torch.device): The device to use. + Returns: + Tensor: The attention mask tensor. + """ + + if start_pos != 0 and seq_mask is None: + # for single prompt chat + seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device) + + if seq_mask is None: + return None + + batch_size = seq_mask.size(0) + seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool) + # (bsz, start_pos + seq_len) + expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len) + # (bsz, seq_len, start_pos + seq_len) + + if is_causal: + causal_mask = torch.tril( + torch.ones((seq_len, start_pos + seq_len), dtype=torch.bool, device=device), + diagonal=start_pos + ) + causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, start_pos + seq_len) + expanded_mask = expanded_mask & causal_mask + + attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device) + attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1) + # (bsz, 1, seq_len, seq_len + start_pos) + + return attention_mask + + +@dataclass +class TransformerConfig: + # basic config + vocab_size: Optional[int] = None + n_dim: Optional[int] = None + n_head: Optional[int] = None + n_layer: Optional[int] = None + m_len: Optional[int] = None + norm_eps: Optional[float] = None + d_ffn: Optional[int] = None + + # GQA + n_kvhead: Optional[int] = None + + + def load(self, config_path: str) -> Self: + with open(config_path, 'r') as f: + config: dict = json.load(f) + for key, value in config.items(): + if hasattr(self, key): + setattr(self, key, value) + + return self + + def save(self, config_path: str) -> None: + config_dict = asdict(self) + config_dict = {k: v for k, v in config_dict.items() if v is not None} + with open(config_path, 'w') as f: + json.dump(config_dict, f, indent=4) + + +class Linear(nn.Module): + def __init__(self, in_dim: int, out_dim: int, bias: bool=False): + super().__init__() + self.weight = nn.Parameter(torch.empty((out_dim, in_dim))) + self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None + init.normal_(self.weight, mean=0, std=0.006) + + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight, self.bias) + + +class RMSNorm(nn.Module): + def __init__(self, n_dim, norm_eps): + super().__init__() + self.weight = nn.Parameter(torch.ones(n_dim)) + self.norm_eps = norm_eps + + def forward(self, x: Tensor) -> Tensor: + dtype = x.dtype + x = x.float() + mean_square = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True) + norm = x * torch.rsqrt(mean_square + self.norm_eps) + norm = norm.to(dtype) + out = norm * self.weight + return out + + +class MLP(nn.Module): + def __init__(self, n_dim: int, d_ffn: int): + super().__init__() + self.up = Linear(n_dim, d_ffn) + self.gate = Linear(n_dim, d_ffn) + self.down = Linear(d_ffn, n_dim) + + def forward(self, x: Tensor) -> Tensor: + gated = self.up(x) * F.silu(self.gate(x)) + out = self.down(gated) + return out + + +class GQA(nn.Module): + def __init__( + self, + n_dim: int, + n_head: int, + n_kvhead: int, + ): + super().__init__() + assert n_dim % n_head == 0 + assert n_head % n_kvhead == 0 + + self.head_dim = n_dim // n_head + self.n_dim = n_dim + self.n_heads = n_head + self.n_kvheads = n_kvhead + self.n_rep = n_head // n_kvhead + + self.q_proj = Linear(n_dim, n_head * self.head_dim) + self.k_proj 
= Linear(n_dim, n_kvhead * self.head_dim) + self.v_proj = Linear(n_dim, n_kvhead * self.head_dim) + self.o_proj = Linear(n_dim, n_dim) + + def forward( + self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor = None, + kv_cache: Optional[Tuple[Tensor, Tensor]] = None, + start_pos: int = 0 + ) -> Tensor: + bsz, seq_len, _ = x.size() + # x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim) + q = self._split_heads(self.q_proj(x), self.n_heads) + k = self._split_heads(self.k_proj(x), self.n_kvheads) + v = self._split_heads(self.v_proj(x), self.n_kvheads) + q, k = apply_rotary_emb(q, freqs_cis), apply_rotary_emb(k, freqs_cis) + + if kv_cache is not None: + k_cache, v_cache = kv_cache + + # copy to cache + k_cache[:bsz, start_pos:start_pos + seq_len] = k + v_cache[:bsz, start_pos:start_pos + seq_len] = v + + # get cache + k = k_cache[:bsz, :start_pos + seq_len] + v = v_cache[:bsz, :start_pos + seq_len] + + k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) + + # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim) + q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) + sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=(mask == None)).permute(0, 2, 1, 3) + out = self.o_proj(sdqa_out.contiguous().view(bsz, seq_len, -1)) + + return out + + def _split_heads(self, x: Tensor, n_heads) -> Tensor: + batch_size, seq_len, _ = x.shape + x = x.reshape(batch_size, seq_len, n_heads, self.head_dim) + return x + + +class DecoderBlock(nn.Module): + def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps): + super().__init__() + self.attention = GQA(n_dim, n_head, n_kvhead) + self.norm_attn = RMSNorm(n_dim, norm_eps) + self.ffn = MLP(n_dim, d_ffn) + self.norm_ffn = RMSNorm(n_dim, norm_eps) + + def forward( + self, + x: Tensor, + freqs_cis: Tensor, + attention_mask: Optional[Tensor] = None, + kv_cache: Optional[Tuple[Tensor, Tensor]] = None, + start_pos: int = 0 + ) -> Tensor: + # attention + attn_output = self.attention( + self.norm_attn(x), + freqs_cis, + attention_mask, + kv_cache, + start_pos + ) + x = attn_output + x + + # feed forward + x = self.ffn(self.norm_ffn(x)) + x + + return x + + +class Transformer(nn.Module): + def __init__(self, config: TransformerConfig): + super().__init__() + self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim)) + self.layers = nn.ModuleList([ + DecoderBlock( + config.n_dim, + config.n_head, + config.d_ffn, + config.n_kvhead, + config.norm_eps + ) + for _ in range(config.n_layer) + ]) + self.norm = RMSNorm(config.n_dim, config.norm_eps) + self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len) + init.normal_(self.embedding, mean=0, std=0.02) + + def forward( + self, + input_ids: Tensor, + seq_mask: Optional[Tensor]=None, + persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None, + start_pos: int = 0 + ) -> Tensor: + assert input_ids.ndim == 2 + seq_len = input_ids.size(-1) + x = F.embedding(input_ids, self.embedding) + + self.freq_cis = self.freq_cis.to(x.device) + freqs_cis = self.freq_cis[start_pos:start_pos+seq_len] + has_kvcache = persistent_key_values is not None + + attn_mask = create_attention_mask( + seq_mask, + start_pos=start_pos, + seq_len=seq_len, + is_causal=has_kvcache, + device=x.device, + dtype=x.dtype + ) + + for i, layer in enumerate(self.layers): + kv_cache = persistent_key_values[i] if persistent_key_values else None + x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos) + + hidden_states = self.norm(x) + logits = 
F.linear(hidden_states, self.embedding) + + return { + "logits": logits, + "hidden_states": hidden_states + } + \ No newline at end of file diff --git a/khaosz/model.py b/khaosz/model.py new file mode 100644 index 0000000..3723f29 --- /dev/null +++ b/khaosz/model.py @@ -0,0 +1,112 @@ +from torch import Tensor +from typing import List, Tuple, Generator, Union + +from khaosz.core.generator import ( + TextGenerator, + ChatGenerator, + StreamGenerator, + BatchGenerator, + RetrievalGenerator, + EmbeddingEncoder +) +from khaosz.core.parameter import ParameterLoader + + +class Khaosz: + def __init__(self, model_dir: str): + self.parameter = ParameterLoader.load(model_dir) + + def to(self, *args, **kwargs): + self.parameter.to(*args, **kwargs) + return self + + def generate( + self, + query: str, + history: List[Tuple[str, str]]=None, + temperature: float=0.8, + top_k: int=50, + top_p: float=0.95, + ) -> str: + generator = ChatGenerator(self.parameter) + return generator.generate( + query, + history=history, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + + def batch_generate( + self, + queries: List[str], + histories: List[Tuple[str, str]]=None, + temperature: float=0.8, + top_k: int=50, + top_p: float=0.95, + ) -> List[str]: + generator = BatchGenerator(self.parameter) + return generator.generate( + queries, + histories=histories, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + + + def stream_generate( + self, + query: str, + history: List[Tuple[str, str]]=None, + temperature: float=0.8, + top_k: int=50, + top_p: float=0.95, + ) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]: + stream_generator = StreamGenerator(self.parameter) + return stream_generator.generate( + query, + history=history, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + + def retrieve_generate( + self, + retrieved, + query: str, + history: List[Tuple[str, str]] = None, + temperature: float=0.8, + top_k: int=50, + top_p: float=0.95, + ) -> str: + generator = RetrievalGenerator(self.parameter) + return generator.generate( + retrieved, + query, + history=history, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + + def text_generate( + self, + query: str, + temperature: float=0.8, + top_k: int=50, + top_p: float=0.95, + ) -> str: + generator = TextGenerator(self.parameter) + + return generator.generate( + query, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + + def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]: + encoder = EmbeddingEncoder(self.parameter) + return encoder.encode(sentence) \ No newline at end of file diff --git a/khaosz/trainer/__init__.py b/khaosz/trainer/__init__.py new file mode 100644 index 0000000..7630e6a --- /dev/null +++ b/khaosz/trainer/__init__.py @@ -0,0 +1,11 @@ +from khaosz.trainer.dataset import DatasetLoader +from khaosz.trainer.trainer import Trainer +from khaosz.trainer.strategy import TrainConfig, CosineScheduleConfig, SgdrScheduleConfig + +__all__ = [ + "DatasetLoader", + "Trainer", + "TrainConfig", + "CosineScheduleConfig", + "SgdrScheduleConfig", +] \ No newline at end of file diff --git a/khaosz/trainer/dataset.py b/khaosz/trainer/dataset.py new file mode 100644 index 0000000..439567d --- /dev/null +++ b/khaosz/trainer/dataset.py @@ -0,0 +1,210 @@ +import torch +import bisect +import pickle as pkl +from abc import ABC, abstractmethod +from torch import Tensor +from torch.utils.data import Dataset +from typing import Callable, List, Dict, Literal, Union + +MutiSeg = Dict[str, 
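+# Usage sketch (format inferred from load_pkl_files below and the repo's test
+# fixtures): each .pkl file holds a Dict[str, Tensor], e.g. one flat token
+# stream per file for "seq" training:
+#
+#   import pickle, torch
+#   seg = {"sequence": torch.randint(0, 1000, (4096,))}
+#   with open("part_0.pkl", "wb") as f:
+#       pickle.dump(seg, f)
+#
+# The sft/dpo datasets further down expect keys such as "mask", "chosen",
+# "rejected", "chosen_mask" and "rejected_mask" in the same layout.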
List[Tensor]]
+Seg = Dict[str, Tensor]
+
+def load_pkl_files(paths: List[str]):
+    segments: MutiSeg = {}
+    total_samples = 0
+
+    for path in paths:
+        with open(path, "rb") as f:
+            pkl_file: Seg = pkl.load(f)
+            for key, value in pkl_file.items():
+                if key not in segments:
+                    segments[key] = []
+                segments[key].append(value)
+            first_key = list(pkl_file.keys())[0]
+            total_samples += pkl_file[first_key].numel()
+
+    return segments, total_samples
+
+
+class BaseSegmentFetcher:
+    def __init__(self, segments: List[Tensor]):
+        self.segments = segments
+        self.cum_lengths = []
+        total = 0
+        for seg in segments:
+            total += len(seg)
+            self.cum_lengths.append(total)
+        self.total_length = total if segments else 0
+
+    def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
+        if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
+            raise ValueError("begin_idx or end_idx out of bounds")
+        if begin_idx >= end_idx:
+            return torch.tensor([], dtype=torch.long)
+
+        seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx - 1)
+        seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx - 1)
+
+        result_segments = []
+
+        for i in range(seg_start_idx, seg_end_idx + 1):
+            prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
+            start = max(begin_idx - prev_cum, 0)
+            end = min(end_idx - prev_cum, len(self.segments[i]))
+            result_segments.append(self.segments[i][start:end])
+
+        return torch.cat(result_segments, dim=0)
+
+
+class MutiSegmentFetcher:
+    def __init__(self, muti_segments: MutiSeg):
+        self.muti_keys = list(muti_segments.keys())
+        self.muti_fetchers = {
+            key: BaseSegmentFetcher(segments)
+            for key, segments in muti_segments.items()
+        }
+
+    def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Union[Tensor, Seg]:
+        fetch_dict = {}
+        keys = [keys] if isinstance(keys, str) else keys
+
+        for key in keys:
+            fetcher = self.muti_fetchers[key]
+            fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
+            fetch_dict[key] = fetch_tensor
+
+        return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
+
+    def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Seg]:
+        return self.key_fetch(begin_idx, end_idx, self.muti_keys)
+
+
+class BaseDataset(Dataset, ABC):
+    def __init__(self, chunk_size: int, device: str):
+        super().__init__()
+        self.segments: MutiSeg = {}
+        self.chunk_size = chunk_size
+        self.total_samples = 0
+        self.device = device
+
+    def save(self, save_path: str):
+        # keys must be collected before indexing into self.segments
+        keys = list(self.segments.keys())
+        first_item = self.segments[keys[0]]
+        segment_size = len(first_item)
+
+        for i in range(segment_size):
+            formated_segment = {key: self.segments[key][i] for key in keys}
+            with open(f"{save_path}_{i}.pkl", "wb") as f:
+                pkl.dump(formated_segment, f)
+
+    def load(self, load_path: Union[str, List[str]]):
+        paths = [load_path] if isinstance(load_path, str) else load_path
+        self.segments, self.total_samples = load_pkl_files(paths)
+        self.fetcher = MutiSegmentFetcher(self.segments)
+
+    @abstractmethod
+    def __getitem__(self, index: int):
+        raise NotImplementedError
+
+    def __len__(self) -> int:
+        assert self.total_samples // self.chunk_size > 0
+        return self.total_samples // self.chunk_size
+
+
+class SeqDataset(BaseDataset):
+    def __init__(self, chunk_size, device='cuda'):
+        super().__init__(chunk_size, device)
+        self.fetcher = MutiSegmentFetcher(self.segments)
+
+    def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
+        return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
+
+    def __getitem__(self, index):
+        begin_idx = index * self.chunk_size
+        end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
+
+        x = self._fetch_data(begin_idx, end_idx).to(device=self.device, dtype=torch.long)
+        y = self._fetch_data(begin_idx + 1, end_idx + 1).to(device=self.device, dtype=torch.long)
+
+        return x, y
+
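+# Sketch of SeqDataset.__getitem__ above (indices illustrative): with
+# chunk_size=4 over a token stream [t0, t1, t2, t3, t4, ...], item 0 yields
+#   x = [t0, t1, t2, t3]
+#   y = [t1, t2, t3, t4]
+# i.e. the targets are the inputs shifted one position for next-token
+# prediction.
+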
+class SftDataset(BaseDataset):
+    def __init__(self, chunk_size, device='cuda'):
+        super().__init__(chunk_size, device)
+        self.fetcher = MutiSegmentFetcher(self.segments)
+
+    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
+        return self.fetcher.key_fetch(begin_idx, end_idx, key)
+
+    def __getitem__(self, index):
+        begin_idx = index * self.chunk_size
+        end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
+
+        x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long)
+        y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long)
+        loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "mask").to(device=self.device, dtype=torch.bool)
+
+        return x, y, loss_mask
+
+
+class DpoDataset(BaseDataset):
+    def __init__(self, chunk_size: int, device="cuda"):
+        super().__init__(chunk_size, device)
+        self.fetcher = MutiSegmentFetcher(self.segments)
+
+    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
+        return self.fetcher.key_fetch(begin_idx, end_idx, key)
+
+    def __getitem__(self, index: int):
+        start_idx = index * self.chunk_size
+        end_idx = min(start_idx + self.chunk_size, self.total_samples - 1)
+
+        chosen = self._fetch_data(start_idx, end_idx, "chosen").to(device=self.device, dtype=torch.long)
+        rejected = self._fetch_data(start_idx, end_idx, "rejected").to(device=self.device, dtype=torch.long)
+        chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(device=self.device, dtype=torch.bool)
+        rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(device=self.device, dtype=torch.bool)
+
+        return chosen, rejected, chosen_mask, rejected_mask
+
+
+class PpoDataset(BaseDataset):
+    def __init__(self, chunk_size: int, device="cuda"):
+        super().__init__(chunk_size, device)
+        self.fetcher = MutiSegmentFetcher(self.segments)
+
+    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
+        return self.fetcher.key_fetch(begin_idx, end_idx, key)
+
+    def __getitem__(self, index: int):
+        begin_idx = index * self.chunk_size
+        end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
+
+        # each field is a plain Tensor; trailing commas here would silently
+        # wrap the first three in 1-tuples
+        input_ids = self._fetch_data(begin_idx, end_idx, "input_ids").to(self.device)
+        actions = self._fetch_data(begin_idx, end_idx, "actions").to(self.device)
+        logprobs = self._fetch_data(begin_idx, end_idx, "logprobs").to(self.device)
+        rewards = self._fetch_data(begin_idx, end_idx, "rewards").to(self.device)
+
+        return input_ids, actions, logprobs, rewards
+
+
+class DatasetLoader:
+    @staticmethod
+    def load(
+        train_type: Literal["seq", "sft", "dpo"],
+        load_path: Union[str, List[str]],
+        max_len: int,
+        device: str
+    ) -> BaseDataset:
+
+        dataset_router: Dict[str, Callable[[int, torch.device], BaseDataset]] = {
+            "seq": lambda m_len, device: SeqDataset(m_len, device=device),
+            "sft": lambda m_len, device: SftDataset(m_len, device=device),
+            "dpo": lambda m_len, device: DpoDataset(m_len, device=device),
+        }
+        dataset = dataset_router[train_type](max_len, device)
+        dataset.load(load_path)
+
+        return dataset
\ No newline at end of file
diff --git a/khaosz/trainer/mask.py b/khaosz/trainer/mask.py
new file mode 100644
index 0000000..9de0472
--- /dev/null +++ b/khaosz/trainer/mask.py @@ -0,0 +1,55 @@ + +import torch +from abc import abstractmethod +from torch import Tensor + + + +class MaskBuilder: + def __init__( + self, + bos_token_id: int, + eos_token_id: int, + user_token_id: int, + system_token_id: int, + + ): + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.user_token_id = user_token_id + self.system_token_id = system_token_id + + @abstractmethod + def build(input_ids: Tensor) -> Tensor: + raise NotImplementedError + + + +class LossMaskBuilder(MaskBuilder): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def build(self, input_ids: Tensor) -> Tensor: + token_markers = torch.zeros_like(input_ids, dtype=torch.int8) + + is_user_token = input_ids.eq(self.user_token_id) + is_system_token = input_ids.eq(self.system_token_id) + + token_markers[is_user_token] = 1 + token_markers[is_system_token] = -1 + + cumulative_markers = torch.cumsum(token_markers, dim=-1) + min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values + loss_mask = cumulative_markers - min_cumulative + + return loss_mask + + + + +class AttentionMaskBuilder: + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def build(input_ids: Tensor): + bsz = input_ids.size(0) \ No newline at end of file diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py new file mode 100644 index 0000000..123b43f --- /dev/null +++ b/khaosz/trainer/strategy.py @@ -0,0 +1,388 @@ +import copy +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch import Tensor +from torch.optim import Optimizer +from torch.utils.data import Dataset +from typing import Any, Literal, Tuple, Callable, Dict +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass, field + + +def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id): + input_mask = input_ids.ne(pad_token_id) + logits = model(input_ids, input_mask)["logits"] + log_probs = torch.log_softmax(logits, dim=-1) + + shifted_log_probs = log_probs[:, :-1, :] + shifted_input_ids = input_ids[:, 1:] + shifted_response_mask = mask[:, 1:] + + token_logprobs = torch.gather( + shifted_log_probs, + dim=-1, + index=shifted_input_ids.unsqueeze(-1) + ).squeeze(-1) + + prompt_mask = input_mask[:, 1:] + valid_mask = (prompt_mask & shifted_response_mask).float() + + return (token_logprobs * valid_mask).sum(dim=-1) + + +class MaskBuilder: + def __init__( + self, + bos_token_id: int, + eos_token_id: int, + user_token_id: int, + system_token_id: int, + + ): + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.user_token_id = user_token_id + self.system_token_id = system_token_id + + @abstractmethod + def build(input_ids: Tensor) -> Tensor: + raise NotImplementedError + + + +class LossMaskBuilder(MaskBuilder): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def build(self, input_ids: Tensor) -> Tensor: + token_markers = torch.zeros_like(input_ids, dtype=torch.int8) + + is_user_token = input_ids.eq(self.user_token_id) + is_system_token = input_ids.eq(self.system_token_id) + + token_markers[is_user_token] = 1 + token_markers[is_system_token] = -1 + + cumulative_markers = torch.cumsum(token_markers, dim=-1) + min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values + loss_mask = cumulative_markers - min_cumulative + + return loss_mask.to(dtype=torch.bool) + + + + +class AttentionMaskBuilder(MaskBuilder): + def __init__(self, multi_turn=False, **kwargs): + 
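+        # Worked example for LossMaskBuilder.build above (ids illustrative):
+        # for input ids [<bos>, <|user|>, q1, q2, <|system|>, a1, a2] the
+        # markers are [0, +1, 0, 0, -1, 0, 0], the cumulative sum is
+        # [0, 1, 1, 1, 0, 0, 0] and its minimum is 0, so positions from the
+        # <|user|> tag up to (not including) <|system|> get mask 1 and the
+        # rest get 0, a prefix-sum trick that needs no Python loop.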
super().__init__(**kwargs) + self.multi_turn = multi_turn + + def build(self, input_ids: Tensor): + bsz = input_ids.size(0) + + + def _build_batch(self, input_ids: Tensor): + is_user_token = input_ids.eq(self.user_token_id) + + token_markers = torch.zeros_like(input_ids, dtype=torch.int8) + token_markers[is_user_token] = 1 + cumulative_markers = torch.cumsum(token_markers, dim=-1) + + + + + +class BaseStrategy(ABC): + def __init__(self, model: nn.Module): + self.model = model + + @abstractmethod + def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor: + raise NotImplementedError + + def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor: + return self.compute_loss(batch) + + +class SeqStrategy(BaseStrategy): + def __init__(self, model): + super().__init__(model) + + def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor: + x, y = batch + B, L = x.size() + logits: Tensor = self.model(x)["logits"] + + loss = F.cross_entropy( + logits.view(B * L, -1), y.flatten() + ) + return loss + + +class SftStrategy(BaseStrategy): + def __init__(self, model): + super().__init__(model) + + def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor: + x, y, loss_mask = batch + B, L = x.size() + ignore_idx = -1 + + logits: Tensor = self.model(x)["logits"] + masked_y = y.masked_fill(loss_mask == 0, ignore_idx) + + loss = F.cross_entropy( + logits.view(B * L, -1), + masked_y.flatten(), + ignore_index=ignore_idx + ) + + return loss + +class DpoStrategy(BaseStrategy): + def __init__(self, model, pad_token_id, beta): + super().__init__(model) + ref_model = copy.deepcopy(self.model) + ref_model.requires_grad_(False) + ref_model.eval() + + self.ref_model = ref_model + self.pad_token_id = pad_token_id + self.beta = beta + + def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor: + good_ids, bad_ids, good_mask, bad_mask = batch + + log_pi_good = get_logprobs(self.model, good_ids, good_mask, self.pad_token_id) + log_pi_bad = get_logprobs(self.model, bad_ids, bad_mask, self.pad_token_id) + + with torch.no_grad(): + log_ref_good = get_logprobs(self.ref_model, good_ids, good_mask, self.pad_token_id) + log_ref_bad = get_logprobs(self.ref_model, bad_ids, bad_mask, self.pad_token_id) + + pi_log_ratio = log_pi_good - log_pi_bad + ref_log_ratio = log_ref_good - log_ref_bad + + ratio_diff = pi_log_ratio - ref_log_ratio + + dpo_loss = -F.logsigmoid(self.beta * ratio_diff).mean() + return dpo_loss + + +class PpoStrategy(BaseStrategy): + def __init__(self, model, pad_token_id, epsilon): + super().__init__(model) + ref_model = copy.deepcopy(self.model) + ref_model.requires_grad_(False) + ref_model.eval() + + self.ref_model = ref_model + self.pad_token_id = pad_token_id + self.epsilon = epsilon + + def ppo_clip_loss_masked( + self, + log_probs: Tensor, + old_log_probs: Tensor, + advantages: Tensor, + values: Tensor, + returns: Tensor, + mask: Tensor, + clip_eps: float=0.2, + ): + ratio = torch.exp(log_probs - old_log_probs) + surr1 = ratio * advantages + surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages + policy_loss = -torch.min(surr1, surr2).masked_select(mask).mean() + + value_loss = F.mse_loss(values.masked_select(mask), + returns.masked_select(mask)) + + entropy = -(log_probs.exp() * log_probs).masked_select(mask).mean() + entropy_loss = -entropy + return policy_loss, value_loss, entropy_loss + + + +class StrategyFactory: + + def load(model, train_type, pad_token_id, dpo_beta): + train_strategy: Dict[str, Callable[[], BaseStrategy]] = { + "seq": lambda: SeqStrategy(model), + "sft": lambda: 
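+            # For reference, the DPO objective implemented by DpoStrategy
+            # above:
+            #   loss = -log sigmoid(beta * ((log pi(y_w) - log pi(y_l))
+            #                             - (log ref(y_w) - log ref(y_l))))
+            # where y_w / y_l are the chosen / rejected sequences and ref is
+            # the frozen deep copy of the policy model.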
SftStrategy(model), + "dpo": lambda: DpoStrategy(model, pad_token_id, dpo_beta) + } + strategy = train_strategy[train_type]() + return strategy + + +@dataclass +class TrainConfig: + train_type: str = field( + default_factory=["seq", "sft", "dpo"], + metadata={"help": "Type of training."} + ) + dataset: Dataset = field( + default=None, + metadata={"help": "Dataset for training."} + ) + optimizer: Optimizer = field( + default=None, + metadata={"help": "Optimizer for training."} + ) + ckpt_dir: str = field( + default="./checkpoint", + metadata={"help": "Checkpoint directory."} + ) + n_epoch: int = field( + default=1, + metadata={"help": "Number of epochs for training."} + ) + batch_size: int = field( + default=4, + metadata={"help": "Batch size for training."} + ) + n_iter_ckpt: int = field( + default=5000, + metadata={"help": "Number of iterations between checkpoints."} + ) + n_iter_step: int = field( + default=1, + metadata={"help": "Number of iterations between steps."} + ) + max_grad_norm: float = field( + default=1.0, + metadata={"help": "Maximum gradient norm."} + ) + random_seed: int = field( + default=3407, + metadata={"help": "Random seed."} + ) + dpo_beta: float = field( + default=0.1, + metadata={"help": "DPO beta."} + ) + + def get_kwargs(self)-> Dict[str, Any]: + config_dict = asdict(self) + return {k: v for k, v in config_dict.items() if v is not None} + + +@dataclass +class ScheduleConfig: + schedule_type: str = field( + default_factory=["cosine", "sgdr"], + metadata={"help": "Type of learning rate schedule."} + ) + warning_step: int = field( + default=1000, + metadata= {"help": "Warning up step."} + ) + @abstractmethod + def get_kwargs(self)-> Dict[str, Any]: + raise NotImplementedError + + +@dataclass +class CosineScheduleConfig(ScheduleConfig): + total_iters: int = field( + default=None, + metadata={"help": "Total iterations for cosine schedule."} + ) + min_rate: float = field( + default=0.05, + metadata={"help": "Minimum rate for cosine schedule."} + ) + schedule_type: Literal["cosine"] = "cosine" + + def get_kwargs(self) -> Dict[str, Any]: + return { + "schedule_type": self.schedule_type, + "warning_step": self.warning_step, + "lr_decay_iters": self.total_iters - self.warning_step, + "min_rate": self.min_rate + } + +@dataclass +class SgdrScheduleConfig(ScheduleConfig): + cycle_length: int = field( + default=1000, + metadata={"help": "Cycle length for sgdr schedule."} + ) + min_rate: float = field( + default=0.05, + metadata={"help": "Minimum rate for sgdr schedule."} + ) + T_mult: int = field( + default=2, + metadata={"help": "T_mult for sgdr schedule."} + ) + schedule_type: Literal["sgdr"] = "sgdr" + + def get_kwargs(self) -> Dict[str, Any]: + return { + "schedule_type": self.schedule_type, + "warning_step": self.warning_step, + "cycle_length": self.cycle_length, + "min_rate": self.min_rate, + "T_mult": self.T_mult + } + + +class SchedulerFactory: + + @staticmethod + def get_sgdr_schedule( + warning_step: int, + cycle_length: int, + min_rate: float = 0.1, + T_mult: int = 2 + ) -> Callable[[int], float]: + + def sgdr_schedule(now_iter: int) -> float: + if now_iter < warning_step: + return max(min_rate, now_iter / warning_step) + + adjusted_iter = now_iter - warning_step + total_cycles, current_cycle = 0, 0 + while adjusted_iter >= cycle_length * (T_mult ** total_cycles): + current_cycle += 1 + total_cycles += 1 + + cycle_start = sum(cycle_length * (T_mult ** i) for i in range(current_cycle)) + cycle_pos = adjusted_iter - cycle_start + + cycle_length_current = cycle_length 
* (T_mult ** current_cycle) + return max(min_rate, 0.5 * (1 + math.cos(math.pi * cycle_pos / cycle_length_current))) + + return sgdr_schedule + + @staticmethod + def get_cosine_warmup_schedule( + warning_step: int, + lr_decay_iters: int, + min_rate: float = 0.1 + ) -> Callable[[int], float]: + + def cosine_warmup_schedule(now_iter: int) -> float: + if now_iter <= warning_step: + return max(min_rate, now_iter / warning_step) + else: + rate = (now_iter - warning_step) / (lr_decay_iters - warning_step) + return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * rate))) + + return cosine_warmup_schedule + + @staticmethod + def load_schedule_fn(**kwargs): + strategy = kwargs.pop("schedule_type") + if strategy == "cosine": + return SchedulerFactory.get_cosine_warmup_schedule(**kwargs) + elif strategy == "sgdr": + return SchedulerFactory.get_sgdr_schedule(**kwargs) + else: + raise ValueError(f"Invalid schedule type: {strategy}") + \ No newline at end of file diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py new file mode 100644 index 0000000..a0ea0fe --- /dev/null +++ b/khaosz/trainer/trainer.py @@ -0,0 +1,167 @@ +import os +import torch +import logging + +from typing import Tuple +from torch.nn.utils import clip_grad_norm_ +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, RandomSampler +from tqdm import tqdm + +from khaosz.core import ModelParameter, Checkpoint +from khaosz.trainer.strategy import SchedulerFactory, StrategyFactory, TrainConfig, ScheduleConfig + + +class Trainer: + def __init__( + self, + parameter: ModelParameter, + log_path: str="./train_log.log" + ): + logger = logging.getLogger() + logger.setLevel(level = logging.INFO) + handler = logging.FileHandler(log_path) + handler.setLevel(logging.INFO) + handler.setFormatter(logging.Formatter('%(asctime)s: %(message)s')) + logger.addHandler(handler) + logger.info("initializing trainer ...") + + self.logger = logger + self.model = parameter.model + self.tokenizer = parameter.tokenizer + self.config = parameter.config + + def save_checkpoint( + self, + loss_list: list, + ckpt_dir: str, + current_iter: int, + last_ckpt_iter: int + ): + save_path = os.path.join(ckpt_dir, f"iter_{current_iter}") + Checkpoint( + self.model, + self.tokenizer, + self.config, + loss_list, + current_iter + ).save(save_path) + + diff_iter = current_iter - last_ckpt_iter + avg_loss = sum(loss_list[last_ckpt_iter:current_iter]) / diff_iter + self.logger.info(f"iter: {current_iter} loss: {avg_loss}") + + return current_iter + + def load_checkpoint(self, train_checkpoint: Checkpoint) -> Tuple[list, int]: + self.model = train_checkpoint.model + self.tokenizer = train_checkpoint.tokenizer + self.config = train_checkpoint.config + loss_list = train_checkpoint.loss_list + last_ckpt_iter = train_checkpoint.current_iter + + return loss_list, last_ckpt_iter + + def train( + self, + train_config: TrainConfig, + schedule_config: ScheduleConfig, + train_checkpoint: Checkpoint = None + ): + assert schedule_config.schedule_type in ["cosine", "sgdr"] + assert train_config.train_type in ["seq", "sft", "dpo"] + + if train_checkpoint: + loss_list, last_ckpt_iter = self.load_checkpoint(train_checkpoint) + current_iter = train_checkpoint.current_iter + 1 + self.logger.info(f"Resuming training from checkpoint: iter {current_iter}") + else: + current_iter = 0 + last_ckpt_iter = 0 + loss_list = [] + + lambda_scheduler_fn = SchedulerFactory.load_schedule_fn( + **schedule_config.get_kwargs() + ) + + strategy = StrategyFactory.load( + 
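+            # Sketch of the LR multipliers built by SchedulerFactory above
+            # (numbers illustrative): with warning_step=100 the factor ramps
+            # roughly linearly up to 1.0 over the first 100 iters (floored at
+            # min_rate), then either follows one cosine decay ("cosine") or
+            # warm-restarts with cycle lengths growing by T_mult ("sgdr");
+            # LambdaLR multiplies the base lr by this factor every step.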
self.model, + train_config.train_type, + self.tokenizer.pad_id, + train_config.dpo_beta + ) + + scheduler = LambdaLR( + train_config.optimizer, + lambda_scheduler_fn, + last_epoch=current_iter - 1 if train_checkpoint else -1 + ) + + seed = train_config.random_seed + generator = torch.Generator().manual_seed(seed) + sampler = RandomSampler(train_config.dataset, generator=generator) + remaining_epochs = train_config.n_epoch - current_iter // (len(train_config.dataset) // train_config.batch_size) + + self.logger.info(f"Starting {train_config.train_type.upper()} training for {train_config.n_epoch} epochs") + self.logger.info(f"Checkpoint interval: {train_config.n_iter_ckpt} iterations") + + for epoch in range(remaining_epochs): + self.model.train() + dataloader = DataLoader( + train_config.dataset, + batch_size=train_config.batch_size, + sampler=sampler + ) + progress_bar = tqdm( + dataloader, + desc=f"Epoch {epoch+1}/{train_config.n_epoch}", + dynamic_ncols=True + ) + for batch in progress_bar: + #forward + loss = strategy(batch) + loss_list.append(loss.item()) + #backward + loss.backward() + #step + if current_iter % train_config.n_iter_step == 0: + clip_grad_norm_( + self.model.parameters(), + train_config.max_grad_norm + ) + train_config.optimizer.step() + train_config.optimizer.zero_grad() + + current_iter += 1 + scheduler.step() + progress_bar.set_postfix({ + "loss": f"{loss.item():.4f}", + "lr": f"{train_config.optimizer.param_groups[0]['lr']:.2e}" + }) + #save checkpotint + if current_iter - last_ckpt_iter >= train_config.n_iter_ckpt: + last_ckpt_iter = self.save_checkpoint( + loss_list, + train_config.ckpt_dir, + current_iter, + last_ckpt_iter + ) + + if current_iter != last_ckpt_iter: + last_ckpt_iter = self.save_checkpoint( + loss_list, + train_config.ckpt_dir, + current_iter, + last_ckpt_iter + ) + + self.logger.info("Training completed") + + return Checkpoint( + self.model, + self.tokenizer, + self.config, + loss_list, + current_iter, + train_config.optimizer + ) diff --git a/khaosz/utils/retriever.py b/khaosz/utils/retriever.py new file mode 100644 index 0000000..1264474 --- /dev/null +++ b/khaosz/utils/retriever.py @@ -0,0 +1,88 @@ +import torch +import sqlite3 +import numpy as np +from torch import Tensor +from typing import Dict, List, Tuple + + +class Retriever: + def __init__(self, db_path=None): + self.data: Dict[str, Tensor] = {} + self.embedding_cache: Tensor = None + self.is_caculated: bool = False + + if db_path is not None: + self.load(db_path) + + def retrieve(self, query: Tensor, top_k: int) -> List[Tuple[str, float]]: + if not self.data: + return [] + + query = query.flatten().unsqueeze(1) # [dim, 1] + norm_embeddings = self._embeddings.to( + device=query.device, + dtype=query.dtype + ) # [n_vectors, dim] + sim_scores = torch.matmul(norm_embeddings, query).squeeze() # [n_vectors] + + top_k = min(top_k, len(self.data)) + indices = sim_scores.topk(top_k).indices + keys = list(self.data.keys()) + + return [(keys[i], sim_scores[i].item()) for i in indices] + + def add_vector(self, key: str, vector_data: Tensor): + self.is_caculated = False + self.data[key] = vector_data.flatten().float().cpu() + + def delete_vector(self, key: str): + self.is_caculated = False + self.data.pop(key, None) + + def save(self, db_path): + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + self._init_db(cursor) + cursor.execute('DELETE FROM vectors') + + for item, vec in self.data.items(): + vec_bytes = vec.numpy().tobytes() + cursor.execute('INSERT OR REPLACE INTO vectors (key, 
vector) VALUES (?, ?)',
+                           (item, vec_bytes))
+
+        conn.commit()
+        conn.close()
+
+    def load(self, db_path):
+        conn = sqlite3.connect(db_path)
+        cursor = conn.cursor()
+        self._init_db(cursor)
+        cursor.execute('SELECT key, vector FROM vectors')
+        rows = cursor.fetchall()
+        self.data = {}
+
+        for row in rows:
+            key, vec_bytes = row
+            vec_numpy = np.frombuffer(vec_bytes, dtype=np.float32).copy()
+            vec = torch.from_numpy(vec_numpy)
+            self.data[key] = vec
+
+        conn.close()
+
+    def _init_db(self, cursor: sqlite3.Cursor):
+        # Create table if not exists (in case loading from a new database)
+        cursor.execute('''
+            CREATE TABLE IF NOT EXISTS vectors (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                key TEXT UNIQUE NOT NULL,
+                vector BLOB NOT NULL
+            )''')
+
+    @property
+    def _embeddings(self) -> Tensor:
+        if not self.is_caculated:
+            embeddings = torch.stack(list(self.data.values()))
+            norm_embeddings = embeddings / torch.norm(embeddings, dim=-1, keepdim=True)
+            self.embedding_cache = norm_embeddings
+            # mark the cache valid; add_vector/delete_vector reset this flag
+            self.is_caculated = True
+
+        return self.embedding_cache
\ No newline at end of file
diff --git a/khaosz/utils/splitter.py b/khaosz/utils/splitter.py
new file mode 100644
index 0000000..7f15fd3
--- /dev/null
+++ b/khaosz/utils/splitter.py
@@ -0,0 +1,127 @@
+import re
+import torch
+import torch.nn.functional as F
+
+from abc import ABC, abstractmethod
+from torch import Tensor
+from typing import List, Callable, Optional
+
+
+class BaseTextSplitter(ABC):
+    def __init__(
+        self,
+        max_len: int = 512,
+        chunk_overlap: int = 0,
+    ):
+        if max_len <= 0:
+            raise ValueError("max_len must be > 0")
+        if chunk_overlap < 0:
+            raise ValueError("chunk_overlap must be >= 0")
+
+        self.max_len = max_len
+        self.chunk_overlap = chunk_overlap
+
+    @abstractmethod
+    def split(self, text: str, **kwargs) -> List[str]:
+        raise NotImplementedError
+
+    def preprocess(self, text: str) -> str:
+        return text.strip()
+
+    def postprocess(self, chunks: List[str]) -> List[str]:
+        return [chunk.strip() for chunk in chunks if chunk.strip()]
+
+
+class PriorityTextSplitter(BaseTextSplitter):
+    def __init__(
+        self,
+        separators: List[str],
+        max_len: int = 512,
+        chunk_overlap: int = 0,
+    ):
+        super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
+        if not separators:
+            raise ValueError("separators must be a non-empty list")
+        self.separators = separators
+
+    def split(self, text: str) -> List[str]:
+        text = self.preprocess(text)
+        for sep in self.separators:
+            parts = text.split(sep)
+
+            valid_parts = [p.strip() for p in parts if p.strip()]
+            if len(valid_parts) > 1:
+                return self.postprocess(valid_parts)
+        return [text]
+
+
+class SemanticTextSplitter(BaseTextSplitter):
+
+    DEFAULT_PATTERN = r'(?<=[。!?!?])(?=(?:[^"\'‘’“”]*["\'‘’“”][^"\'‘’“”]*["\'‘’“”])*[^"\'‘’“”]*$)'
+
+    def __init__(
+        self,
+        embedding_func: Callable[[List[str]], List[Tensor]],
+        pattern: Optional[str] = None,
+        max_len: int = 512,
+        chunk_overlap: int = 0,
+    ):
+        super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
+        if not callable(embedding_func):
+            raise TypeError("embedding_func must be callable")
+        self.embedding_func = embedding_func
+        self.pattern = pattern or SemanticTextSplitter.DEFAULT_PATTERN
+
+    def split(
+        self,
+        text: str,
+        threshold: float = 0.5,
+        window_size: int = 1,
+    ) -> List[str]:
+        text = self.preprocess(text)
+        sentences = [s.strip() for s in re.split(self.pattern, text) if s.strip()]
+
+        if len(sentences) <= 1:
+            return self.postprocess(sentences)
+
+        try:
+            sentence_embs = self.embedding_func(sentences)
+        except Exception as e:
+            raise 
RuntimeError(f"Embedding generation failed: {e}") + + if len(sentence_embs) != len(sentences): + raise ValueError("Embedding function must return one vector per sentence") + + chunks = [] + emb_tensor = torch.stack(sentence_embs) # shape: [N, D] + current_chunk: List[str] = [sentences[0]] + + for i in range(1, len(sentences)): + start_prev = max(0, i - window_size) + end_prev = i + start_next = i + end_next = min(len(sentences), i + window_size) + + prev_window_emb = emb_tensor[start_prev:end_prev].mean(dim=0) + next_window_emb = emb_tensor[start_next:end_next].mean(dim=0) + + similarity = F.cosine_similarity( + prev_window_emb.unsqueeze(0), + next_window_emb.unsqueeze(0), + dim=1 + ).item() + + dynamic_threshold = max(threshold * (1 - 0.03 * (end_next - start_prev)), 0.2) + + if similarity < dynamic_threshold: + chunks.append(" ".join(current_chunk)) + overlap_start = max(0, len(current_chunk) - self.chunk_overlap) + current_chunk = current_chunk[overlap_start:] + current_chunk.append(sentences[i]) + else: + current_chunk.append(sentences[i]) + + if current_chunk: + chunks.append(" ".join(current_chunk)) + + return self.postprocess(chunks) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5c8a7bb --- /dev/null +++ b/requirements.txt @@ -0,0 +1,36 @@ +# python=3.12 +--extra-index-url https://download.pytorch.org/whl/cu126 + +certifi==2025.8.3 +charset-normalizer==3.4.2 +colorama==0.4.6 +contourpy==1.3.3 +cycler==0.12.1 +filelock==3.13.1 +fonttools==4.59.0 +fsspec==2024.6.1 +huggingface-hub==0.34.3 +idna==3.10 +Jinja2==3.1.6 +kiwisolver==1.4.8 +MarkupSafe==2.1.5 +matplotlib==3.10.5 +mpmath==1.3.0 +networkx==3.3 +numpy==2.3.2 +packaging==25.0 +pillow==11.3.0 +pyparsing==3.2.3 +python-dateutil==2.9.0.post0 +PyYAML==6.0.2 +requests==2.32.4 +safetensors==0.5.3 +setuptools==78.1.1 +six==1.17.0 +sympy==1.13.3 +tokenizers==0.21.4 +torch==2.7.1+cu126 +tqdm==4.67.1 +typing_extensions==4.12.2 +urllib3==2.5.0 +wheel==0.45.1 diff --git a/scripts/chat.py b/scripts/chat.py new file mode 100644 index 0000000..399717b --- /dev/null +++ b/scripts/chat.py @@ -0,0 +1,32 @@ +import os +import torch +from khaosz import Khaosz + + +PROJECT_ROOT = os.path.dirname( + os.path.dirname(os.path.abspath(__file__))) + +def chat(): + model_dir = os.path.join(PROJECT_ROOT, "params") + model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16) + + histroy = [] + while True: + query = input(">> ") + if query == "!exit": + break + + response_size = 0 + for response, histroy in model.stream_generate( + query=query, + history=histroy, + temperature=0.7, + top_p=0.95, + top_k=30 + ): + print(response[response_size:], end="", flush=True) + response_size = len(response) + + +if __name__ == "__main__": + chat() \ No newline at end of file diff --git a/scripts/download.py b/scripts/download.py new file mode 100644 index 0000000..ca641f2 --- /dev/null +++ b/scripts/download.py @@ -0,0 +1,14 @@ +import os +from huggingface_hub import snapshot_download + + +PROJECT_ROOT = os.path.dirname( + os.path.dirname(os.path.abspath(__file__))) + + +if __name__ == "__main__": + snapshot_download( + repo_id="ViperEk/KHAOSZ", + local_dir=os.path.join(PROJECT_ROOT, "params"), + force_download=True + ) \ No newline at end of file diff --git a/scripts/generate_ar.py b/scripts/generate_ar.py new file mode 100644 index 0000000..23b2832 --- /dev/null +++ b/scripts/generate_ar.py @@ -0,0 +1,27 @@ +import os +import torch +from khaosz import Khaosz + + +PROJECT_ROOT = os.path.dirname( + 
os.path.dirname(os.path.abspath(__file__))) + +def generate_text(): + model_dir = os.path.join(PROJECT_ROOT, "params") + model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16) + + query = input(">> ") + + response = model.text_generate( + query=query, + temperature=0.6, + top_p=0.95, + top_k=30 + ) + + print(response) + + + +if __name__ == "__main__": + generate_text() \ No newline at end of file diff --git a/scripts/generate_batch.py b/scripts/generate_batch.py new file mode 100644 index 0000000..037b04f --- /dev/null +++ b/scripts/generate_batch.py @@ -0,0 +1,25 @@ +import os +import torch +from khaosz import Khaosz + + +PROJECT_ROOT = os.path.dirname( + os.path.dirname(os.path.abspath(__file__))) + +def batch_generate(): + model_dir = os.path.join(PROJECT_ROOT, "params") + model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16) + inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"] + + responses = model.batch_generate( + queries=inputs, + temperature=0.7, + top_p=0.95, + top_k=30 + ) + + for q, r in zip(inputs, responses): + print((q, r)) + +if __name__ == "__main__": + batch_generate() \ No newline at end of file diff --git a/scripts/generate_retrieve.py b/scripts/generate_retrieve.py new file mode 100644 index 0000000..46a9e56 --- /dev/null +++ b/scripts/generate_retrieve.py @@ -0,0 +1,42 @@ +import os +import torch +from khaosz import Khaosz, SemanticTextSplitter, Retriever + + +PROJECT_ROOT = os.path.dirname( + os.path.dirname(os.path.abspath(__file__))) + +if __name__ == "__main__": + model_dir = os.path.join(PROJECT_ROOT, "params") + context_path = os.path.join(PROJECT_ROOT, "README.md") + + model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16) + spliter = SemanticTextSplitter(model.encode) + retriever = Retriever() + text = open(context_path, "r", encoding="utf-8").read() + + res = spliter.split(text, threshold=0.8, window_size=1) + # print(("\n" + "+"*100 + "\n").join(res)) + + res_embs = model.encode(res) + for sentence, emb in zip(res, res_embs): + retriever.add_vector(sentence, emb) + + retrive_top_k = 5 + query = "作者设计了一个怎样的模型" + emb_query = model.encode(query) + retrieved = retriever.retrieve(emb_query, retrive_top_k) + + retrive_response = model.retrieve_generate( + retrieved=retrieved, + query=query, + temperature=0.7, + top_k=30, + top_p=0.95, + ) + + print("retrive content:") + print("\n".join([f"{idx + 1}. 
" + text for idx, (text, _) in enumerate(retrieved)])) + + print("\n\nretrive generate:") + print(retrive_response) \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..ea67044 --- /dev/null +++ b/setup.py @@ -0,0 +1,18 @@ +import re +from setuptools import find_packages, setup + + +with open("requirements.txt") as f: + required = [line for line in f.read().splitlines() + if line and re.match(r'^[^=]+==[^=]+$', line.strip())] + +setup( + name="khaosz", + version="1.2.0", + packages=find_packages(), + install_requires=required, + dependency_links=[ + "https://download.pytorch.org/whl/cu126", + ], + python_requires="==3.12.*", +) \ No newline at end of file diff --git a/tests/test_module.py b/tests/test_module.py new file mode 100644 index 0000000..025188e --- /dev/null +++ b/tests/test_module.py @@ -0,0 +1,103 @@ +import os +import json +import torch +import shutil +import pytest +import tempfile +import safetensors.torch as st +from khaosz.core import * +from khaosz.core.generator import EmbeddingEncoderCore, GeneratorCore +from tokenizers import pre_tokenizers + +@pytest.fixture +def test_env(): + test_dir = tempfile.mkdtemp() + config_path = os.path.join(test_dir, "config.json") + tokenizer_path = os.path.join(test_dir, "tokenizer.json") + model_path = os.path.join(test_dir, "model.safetensors") + + config = { + "vocab_size": 1000, + "n_dim": 128, + "n_head": 4, + "n_kvhead": 2, + "d_ffn": 256, + "m_len": 64, + "n_layer": 2, + "norm_eps": 1e-5 + } + with open(config_path, 'w') as f: + json.dump(config, f) + + tokenizer = BpeTokenizer() + sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet()) + tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1) + tokenizer.save(tokenizer_path) + + transformer_config = TransformerConfig().load(config_path) + model = Transformer(transformer_config) + st.save_file(model.state_dict(), model_path) + + yield { + "test_dir": test_dir, + "model": model, + "tokenizer": tokenizer, + "transformer_config": transformer_config, + } + + shutil.rmtree(test_dir) + +# parameter loader +def test_parameter_loader(test_env): + loaded_param = ParameterLoader.load(test_env["test_dir"]) + assert loaded_param.model is not None + assert loaded_param.tokenizer is not None + assert loaded_param.config == test_env["transformer_config"] + +def test_model_parameter(test_env): + save_dir = os.path.join(test_env["test_dir"], "save") + model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"]) + model_param.save(save_dir) + + assert os.path.exists(os.path.join(save_dir, "model.safetensors")) + assert os.path.exists(os.path.join(save_dir, "tokenizer.json")) + assert os.path.exists(os.path.join(save_dir, "config.json")) + +# transformer +def test_transformer(test_env): + model = test_env["model"] + input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, + (4, test_env["transformer_config"].m_len)) + output_logits = model(input_ids)["logits"] + target_shape = (4, test_env["transformer_config"].m_len, test_env["transformer_config"].vocab_size) + assert output_logits.shape == target_shape + +# generator +def test_embedding_encoder_core(test_env): + parameter = ModelParameter( + test_env["model"], + test_env["tokenizer"], + test_env["transformer_config"] + ) + encoder = EmbeddingEncoderCore(parameter) + + single_emb = encoder.encode("测试文本") + assert isinstance(single_emb, torch.Tensor) + assert single_emb.shape[-1] == test_env["transformer_config"].n_dim + + + batch_emb 
= encoder.encode(["测试1", "测试2"]) + assert isinstance(batch_emb, list) + assert len(batch_emb) == 2 + +def test_generator_core(test_env): + parameter = ModelParameter( + test_env["model"], + test_env["tokenizer"], + test_env["transformer_config"] + ) + generator = GeneratorCore(parameter) + logits, incr = generator.compute_logits(torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10))) + + assert logits.shape == (4, test_env["transformer_config"].vocab_size) + assert incr == 10 \ No newline at end of file diff --git a/tests/test_trainer.py b/tests/test_trainer.py new file mode 100644 index 0000000..1c23544 --- /dev/null +++ b/tests/test_trainer.py @@ -0,0 +1,203 @@ +import os +import json +import torch +import shutil +import pytest +import pickle +import tempfile +import matplotlib + +from torch.utils.data import Dataset +from khaosz.core import * +from khaosz.trainer import * + +# to avoid _tkinter.TclError +matplotlib.use('Agg') + + +@pytest.fixture +def test_env(): + test_dir = tempfile.mkdtemp() + config_path = os.path.join(test_dir, "config.json") + + config = { + "vocab_size": 1000, + "n_dim": 128, + "n_head": 4, + "n_kvhead": 2, + "d_ffn": 256, + "m_len": 64, + "n_layer": 2, + "norm_eps": 1e-5 + } + + with open(config_path, 'w') as f: + json.dump(config, f) + + transformer_config = TransformerConfig().load(config_path) + model = Transformer(transformer_config) + tokenizer = BpeTokenizer() + + class DummyDataset(Dataset): + def __init__(self, length=10): + self.length = length + + def __len__(self): + return self.length + + def __getitem__(self, idx): + return ( + torch.randint(0, 1000, (64,)), + torch.randint(0, 1000, (64,)) + ) + + dataset = DummyDataset() + + yield { + "test_dir": test_dir, + "config_path": config_path, + "transformer_config": transformer_config, + "model": model, + "tokenizer": tokenizer, + "dataset": dataset + } + + shutil.rmtree(test_dir) + +def test_dataset_loader(test_env): + test_dir = test_env["test_dir"] + pkl_path = os.path.join(test_dir, "test_data.pkl") + + dummy_data = {"sequence": torch.randint(0, 1000, (64,))} + with open(pkl_path, "wb") as f: + pickle.dump(dummy_data, f) + + loaded_dataset = DatasetLoader.load(train_type="seq", load_path=pkl_path, max_len=64, device="cpu") + assert loaded_dataset is not None + +def test_training_config(test_env): + optimizer = torch.optim.AdamW(test_env["model"].parameters()) + train_config = TrainConfig( + train_type="seq", + dataset=test_env["dataset"], + optimizer=optimizer, + ckpt_dir=test_env["test_dir"], + n_epoch=1, + batch_size=2, + n_iter_ckpt=5, + n_iter_step=1, + max_grad_norm=1.0, + random_seed=42 + ) + assert train_config.get_kwargs()["batch_size"] == 2 + +def test_cosine_schedule(test_env): + assert test_env is not None + schedule_config = CosineScheduleConfig( + warning_step=100, + total_iters=1000 + ) + kwargs = schedule_config.get_kwargs() + assert kwargs["warning_step"] == 100 + assert kwargs["lr_decay_iters"] == 900 + + +def test_sgdr_schedule(test_env): + assert test_env is not None + schedule_config = SgdrScheduleConfig( + warning_step=100, + cycle_length=200, + T_mult=2 + ) + kwargs = schedule_config.get_kwargs() + assert kwargs["warning_step"] == 100 + assert kwargs["cycle_length"] == 200 + assert kwargs["T_mult"] == 2 + +def test_trainer_train(test_env): + optimizer = torch.optim.AdamW(test_env["model"].parameters()) + train_config = TrainConfig( + train_type="seq", + dataset=test_env["dataset"], + optimizer=optimizer, + ckpt_dir=test_env["test_dir"], + n_epoch=1, + batch_size=2, 
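+        # Values below are from this test: save a checkpoint every 5
+        # iterations and step the optimizer every iteration, i.e. no
+        # gradient accumulation.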
+ n_iter_ckpt=5, + n_iter_step=1, + max_grad_norm=1.0, + random_seed=42 + ) + schedule_config = CosineScheduleConfig( + warning_step=100, + total_iters=1000 + ) + model_parameter = ModelParameter( + test_env["model"], + test_env["tokenizer"], + test_env["transformer_config"] + ) + trainer = Trainer(model_parameter) + trainer.train(train_config, schedule_config) + +def test_checkpoint(test_env): + temp_dir = test_env["test_dir"] + config = test_env["transformer_config"] + model = test_env["model"] + tokenizer = test_env["tokenizer"] + + param = ModelParameter(model, tokenizer, config) + checkpoint = Checkpoint( + model=param.model, + tokenizer=param.tokenizer, + config=param.config, + loss_list=[1.0, 2.0, 3.0], + current_iter=3 + ) + ckpt_dir = os.path.join(temp_dir, "ckpt") + checkpoint.save(ckpt_dir) + + loaded_ckpt = Checkpoint() + loaded_ckpt.load(ckpt_dir) + + assert loaded_ckpt.current_iter == 3 + assert loaded_ckpt.loss_list == [1.0, 2.0, 3.0] + + for p1, p2 in zip(model.parameters(), loaded_ckpt.model.parameters()): + assert torch.allclose(p1, p2) + + +def test_checkpoint_train(test_env): + temp_dir = test_env["test_dir"] + config = test_env["transformer_config"] + model = test_env["model"] + tokenizer = test_env["tokenizer"] + dataset = test_env["dataset"] + + param = ModelParameter(model, tokenizer, config) + trainer = Trainer(param) + + optimizer = torch.optim.AdamW(test_env["model"].parameters()) + train_config = TrainConfig( + train_type="seq", + dataset=dataset, + optimizer=optimizer, + ckpt_dir=test_env["test_dir"], + n_epoch=1, + batch_size=2, + n_iter_ckpt=5, + n_iter_step=1, + max_grad_norm=1.0, + random_seed=42 + ) + schedule_config = CosineScheduleConfig( + warning_step=100, + total_iters=1000 + ) + + trainer.train( + train_config=train_config, + schedule_config=schedule_config, + ) + + \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..707ad15 --- /dev/null +++ b/train.py @@ -0,0 +1,128 @@ +import os +import argparse +import torch + +from torch.optim import AdamW +from khaosz.core import ParameterLoader +from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig + + +PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__)) + +def get_files(root_path: str) -> list[str]: + paths = [] + for root, _, files in os.walk(root_path): + paths.extend([os.path.join(root, file) for file in files]) + + return paths + +def train( + train_type: str, + param_path: str, + data_root_path: str, + n_epoch: int, + batch_size: int, + n_iter_step: int, + warning_step: int, + max_lr: int, + n_iter_ckpt: int, + ckpt_dir: str, + dpo_beta: float, + adamw_betas: tuple, + adamw_weight_decay: float, + max_grad_norm: float, + embdeding_lr_rate: int, + random_seed: int, +): + assert train_type in ["seq", "sft", "dpo"] + assert os.path.exists(param_path) + + parameter = ParameterLoader.load(param_path) + model = parameter.model + + device = torch.device("cuda") + model = model.to(device=device, dtype=torch.bfloat16) + + cache_files = get_files(data_root_path) + dataset = DatasetLoader.load( + train_type=train_type, + load_path=cache_files, + max_len=parameter.config.m_len, + device=device + ) + + param_groups = [ + {"params": [p for n, p in model.named_parameters() if "embedding" in n], "lr": max_lr * embdeding_lr_rate}, + {"params": [p for n, p in model.named_parameters() if "embedding" not in n], "lr": max_lr} + ] + + optim = AdamW( + param_groups, + betas=adamw_betas, + weight_decay=adamw_weight_decay + ) + + train_config = 
TrainConfig( + train_type=train_type, + dataset=dataset, + optimizer=optim, + ckpt_dir=ckpt_dir, + n_epoch=n_epoch, + batch_size=batch_size, + n_iter_ckpt=n_iter_ckpt, + n_iter_step=n_iter_step, + max_grad_norm=max_grad_norm, + random_seed=random_seed, + dpo_beta=dpo_beta + ) + + schedule_config = CosineScheduleConfig( + warning_step=warning_step, + total_iters=len(dataset) * n_epoch // batch_size, + ) + + trainer = Trainer(parameter) + trainer.train( + train_config=train_config, + schedule_config=schedule_config, + ) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Train the Transformer model.") + parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.") + parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.") + parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.") + parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.") + parser.add_argument("--n_iter_step", type=int, default=1, help="Number of iterations between each optimizer step.") + parser.add_argument("--warning_step", type=int, default=1000, help="Number of iters between warnings.") + parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.") + parser.add_argument("--n_iter_ckpt", type=int, default=5000, help="Number of iters between checkpoints.") + parser.add_argument("--ckpt_dir", type=str, default="checkpoint", help="Directory to save checkpoints.") + parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.") + parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.") + parser.add_argument("--adamw_betas", type=tuple, default=(0.9, 0.95), help="Beta values for AdamW optimizer.") + parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.") + parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.") + parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.") + + args = parser.parse_args() + + train( + param_path=args.param_path, + data_root_path=args.data_root_path, + n_epoch=args.n_epoch, + batch_size=args.batch_size, + n_iter_step=args.n_iter_step, + warning_step=args.warning_step, + max_lr=args.max_lr, + dpo_beta=args.dpo_beta, + adamw_betas=args.adamw_betas, + adamw_weight_decay=args.adamw_weight_decay, + max_grad_norm=args.max_grad_norm, + embdeding_lr_rate=args.embdeding_lr_rate, + n_iter_ckpt=args.n_iter_ckpt, + ckpt_dir=args.ckpt_dir, + train_type=args.train_type, + random_seed=args.random_seed + ) \ No newline at end of file
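+
+# Example invocation (paths illustrative, flags as defined above):
+#   python train.py --train_type seq --param_path params \
+#       --data_root_path data --n_epoch 1 --batch_size 8 --max_lr 3e-4 \
+#       --n_iter_ckpt 5000 --ckpt_dir checkpoint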