Initial commit
This commit is contained in: commit a4443765ee
@ -0,0 +1,22 @@
# cache
__pycache__
.pytest_cache

# params
*.safetensors
*.json
*.pkl
*.db

# train_log
*.log

# ignore file
-*

# vscode file
.vscode

# build file
build
*.egg-info
@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
@ -0,0 +1,333 @@


<div style="display: flex; flex-direction: column; align-items: center; justify-content: center; text-align: center; font-size: 16px; font-weight: bold; margin-top: 50px;">
  <div>
    <a href="#english" style="text-decoration: none; margin: 0 10px; color: blue;">English</a> |
    <a href="#chinese" style="text-decoration: none; margin: 0 10px; color: blue;">中文</a>
  </div>
  <h1 style="margin: 20px 0 0 0; font-size: 2.5em; font-weight: bold;">KHAOSZ</h1>
</div>

<h2 id="english">English Version</h2>

This is a bilingual Chinese-English Transformer model. The repository contains the model configuration and the training workflow; training loads the parameters defined in `param_path/config.json`. The training script `train.py` parses command-line arguments, including the dataset root directory, number of training epochs, batch size, checkpoint interval, and checkpoint directory.

**Model Download Options (Choose One):**

1. Visit [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) and open **Files and versions**
2. Run `scripts/download.py` to download the parameters

**Demo Video:** [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd)

Training dataset sources are listed in the **Model Card** section of the HuggingFace page above.

**License:** The code is released under the Apache-2.0 license. Please credit the source code when using it.

- **📊 Device Selection:** The code defaults to CUDA training
- **🌐 Performance Optimization:** `dtype=torch.bfloat16` is enabled to accelerate training and reduce memory usage. Make sure your hardware supports this feature.
- **🤖 Language Support:** The model supports Chinese and English training. The BBPE tokenizer was trained on Chinese and English text only, so OOV (out-of-vocabulary) issues are rare for these languages but may occur for others.

### 📌 Training Guide

To train this Transformer model, follow these steps:

**(1). Prepare Dataset:**

Place datasets in the designated root directory. Files should be text documents in Chinese, English, or a mixture of both. The format should match the model's input requirements: preferably pre-tokenized token ids stored as a `torch.Tensor`, which saves memory compared with Python lists (whose integers default to 64-bit precision).
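As a minimal sketch of this storage idea (the ids below are placeholders and the in-memory buffer stands in for a file on disk; the repository's actual preprocessing pipeline may differ), pre-tokenized ids can be saved as a compact integer tensor:

```python
import io
import torch

# Hypothetical sketch: encode text with your tokenizer first; these ids
# are placeholders for real tokenizer output.
token_ids = [101, 2345, 678, 9, 102]
tensor_ids = torch.tensor(token_ids, dtype=torch.int32)  # 4 bytes per id

buffer = io.BytesIO()            # in practice: a .pt file under the dataset root
torch.save(tensor_ids, buffer)
buffer.seek(0)
loaded = torch.load(buffer)      # the dataloader reads the tensor back directly
```

A 32-bit tensor stores each id in 4 bytes, while a Python list keeps full 64-bit integer objects plus per-element overhead.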

**(2). Install Dependencies:**

```bash
pip install -r requirements.txt
pip install .
```

**(3). Run Training Script:**

```bash
python train.py \
    --train_type=seq \
    --data_root_path=/path/to/dataset \
    --param_path=/path/to/param_path \
    --n_epoch=5 \
    --batch_size=8 \
    --max_lr=2e-4 \
    --n_iter_ckpt=10000 \
    --ckpt_dir=checkpoints
```

**Parameters Explanation:**
- `--train_type`: Training type (`seq`, `sft`, or `dpo`)
- `--data_root_path`: Root directory of the dataset
- `--param_path`: Path to the model training parameters
- `--n_epoch`: Total number of training epochs
- `--batch_size`: Batch size
- `--n_iter_step`: Number of batches per training step
- `--warning_step`: Number of warmup steps
- `--max_lr`: Maximum learning rate (warmup + cosine decay)
- `--n_iter_ckpt`: Checkpoint saving interval (in iterations)
- `--ckpt_dir`: Directory to save checkpoints
- `--resume_dir`: Resume training from the specified path

Training logs are saved to `train_log.txt`. Checkpoints are saved in the specified directory for resuming training or evaluation.

### 👉 Usage Guide

**(1). Chatting with the Model:**

Run `chat.py`, or use the streaming/non-streaming interfaces:

**Streaming Output:**
```python
import torch
from khaosz import Khaosz

model_dir = "your_model_parameter_dir"
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
history = []

while True:
    query = input(">> ")
    if query == "!exit":
        break

    response_size = 0
    for response, history in model.stream_generate(
        query=query,
        history=history,
        temperature=0.85,
        top_p=0.95,
        top_k=50
    ):
        print(response[response_size:], end="")
        response_size = len(response)
```

**Non-streaming Output:**
```python
import torch
from khaosz import Khaosz

model_dir = "your_model_parameter_dir"
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
history = []

while True:
    query = input(">> ")
    if query == "!exit":
        break

    response = model.generate(
        query=query,
        history=history,
        temperature=0.85,
        top_p=0.95,
        top_k=50
    )
    print(response)
```

**(2). Retrieval-Augmented Generation (RAG):**

```python
import torch
from khaosz import Khaosz

model_dir = "your_model_parameter_dir"
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)

query = "your question here"
retrieved_content = model.retrieve_generate(
    query=query,
    retrieve_top_k=5,
    temperature=0.6,
    top_k=30,
    top_p=0.95
)
print(retrieved_content)
```

### 📌 Model Specifications

The model is a 24-layer Transformer whose parameters are defined in `config.json`, totaling approximately 1.0 billion (1.0B) parameters.

**Key Design Choices:**
- Weight tying between the embedding and final linear layers (a standard parameter-saving technique for small models)
- Embedding layer optimization: without weight tying, a 10,000-word vocabulary would consume ~102M parameters (0.1B)
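Weight tying can be sketched in a few lines of PyTorch (the sizes and module names below are illustrative, not the model's actual configuration):

```python
import torch.nn as nn

vocab_size, n_dim = 32000, 1024  # illustrative sizes, not the model's real config

embedding = nn.Embedding(vocab_size, n_dim)
lm_head = nn.Linear(n_dim, vocab_size, bias=False)

# Share one matrix between the input embedding and the output projection,
# saving vocab_size * n_dim parameters.
lm_head.weight = embedding.weight

# Count each shared parameter tensor only once.
unique_params = {id(p): p for m in (embedding, lm_head) for p in m.parameters()}
n_params = sum(p.numel() for p in unique_params.values())
```

After tying, the two modules hold a single `vocab_size x n_dim` matrix instead of two.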

**Limitations:**
- May struggle with complex language phenomena due to its smaller parameter count
- Prone to overfitting on specialized datasets
- Limited multilingual capabilities

**Advantages:**
- Runs efficiently on lower-spec hardware
- Shorter training time compared to larger models

**Training Pipeline:**
The model has completed the pre-training + SFT (Supervised Fine-Tuning) + DPO (Direct Preference Optimization) workflow. All corresponding training code is included in the repository.
@ -0,0 +1,89 @@
## Model Introduction

### 1. Model Architecture

The model uses a Transformer architecture with GQA (grouped-query attention, q_head=24, kv_head=4), which reduces the KV-cache memory footprint compared with conventional MHA (although KV caching is not implemented yet). The model is built by stacking 24 Transformer layers, for about 1.0B parameters. A Transformer is an autoregressive model: it derives the probability distribution of the next token from the relationships among all preceding tokens.
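How GQA shares KV heads can be sketched as a simplified attention step with the stated head counts (shapes and names are illustrative, not the repository's code):

```python
import torch

bsz, seq_len, head_dim = 2, 16, 64
q_head, kv_head = 24, 4
group = q_head // kv_head  # 6 query heads share each KV head

q = torch.randn(bsz, q_head, seq_len, head_dim)
k = torch.randn(bsz, kv_head, seq_len, head_dim)  # only 4 KV heads to cache
v = torch.randn(bsz, kv_head, seq_len, head_dim)

# Expand KV heads so each group of query heads reuses the same K/V.
k_exp = k.repeat_interleave(group, dim=1)  # [bsz, 24, seq_len, head_dim]
v_exp = v.repeat_interleave(group, dim=1)

scores = q @ k_exp.transpose(-2, -1) / head_dim ** 0.5
attn = torch.softmax(scores, dim=-1) @ v_exp   # [bsz, 24, seq_len, head_dim]
```

Only the 4-head `k` and `v` tensors would need to be cached, a 6x reduction over caching 24 KV heads.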


What is an autoregressive model? After a sentence is split into tokens, the model predicts a probability distribution over the next token: given the context (the sequence of tokens seen so far), it computes each possible next token and its corresponding probability.

#### 1. Autoregression

Suppose a sentence is split into the following token list:

```
["你好", ",", "今天", "天气"]
```

The model then predicts the next possible token based on this sequence, usually as a probability distribution, for example:

```
-> {"token": "不错", "probability": 0.4}
-> {"token": "晴朗", "probability": 0.2}
-> ......
```

Here "不错" and "晴朗" are two tokens that may follow "天气", each with its probability of being the next token.

We then sample the next token (adjusting the result with the top_k, top_p, and temperature parameters) and append it to the sequence as the new input:

```
["你好", ",", "今天", "天气", "不错"]
```

This process repeats until a control token that ends generation (`<|end_of_seqence|>`) is produced, at which point the model stops (models generally define such control tokens; otherwise generation would continue until GPU memory is exhausted).
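The sampling step described above can be sketched as a single-step function (a simplified illustration of temperature, top_k, and top_p; not the repository's actual sampler):

```python
import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 0.85,
                      top_k: int = 50, top_p: float = 0.95) -> int:
    # logits: [vocab_size] scores for the next token
    logits = logits / temperature                      # sharpen or flatten the distribution
    topk_vals, topk_idx = torch.topk(logits, top_k)    # keep only the k best candidates

    probs = torch.softmax(topk_vals, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    keep = cumulative - sorted_probs < top_p           # nucleus: smallest set with mass >= top_p
    sorted_probs = sorted_probs * keep
    sorted_probs = sorted_probs / sorted_probs.sum()

    choice = torch.multinomial(sorted_probs, 1)        # sample from the filtered distribution
    return topk_idx[sorted_idx[choice]].item()

logits = torch.randn(1000)
token_id = sample_next_token(logits)
```

Lower temperature and smaller top_k/top_p make the output more deterministic; higher values make it more diverse.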

#### 2. Causal Mask

The Transformer uses attention; the input shape is generally [bsz, seq_len] and the output is [bsz, seq_len, n_dim]. To make the model predict the next token, the inputs and prediction targets must be shifted by one position, and training uses the same one-position shift:

```
sequence :  [[1, 2, 3, 4, 5, 6]]
input_ids:  [[1, 2, 3, 4, 5]]
target_ids: [[2, 3, 4, 5, 6]]
```

The attention score is computed as

$$ s_{ij} = \text{softmax}\left(\frac{q_i^T k_j}{\sqrt{d_k}} + mask_{ij}\right) $$

where the attention score represents how much the model attends to the similarity between two tokens.

For a decoder-only model, to prevent the model from stealing information from future positions, a mask must be added to the scores before the softmax. This mask is usually a lower-triangular matrix; for a sequence of length n it has shape [n, n]. For a sequence of length 5, the causal mask looks like this:

```
[[0, -inf, -inf, -inf, -inf],
 [0,    0, -inf, -inf, -inf],
 [0,    0,    0, -inf, -inf],
 [0,    0,    0,    0, -inf],
 [0,    0,    0,    0,    0]]
```

In this matrix, 0 marks positions that may be attended to and -inf marks positions that must be masked (i.e., not attended to). The -inf entries guarantee that every attention score with $j > i$ becomes 0 after the softmax, so the model cannot see future information.
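A minimal PyTorch sketch of building this mask and adding it to the scores before the softmax:

```python
import torch

def causal_mask(n: int) -> torch.Tensor:
    # -inf strictly above the diagonal, 0 on and below it.
    mask = torch.full((n, n), float("-inf"))
    return torch.triu(mask, diagonal=1)

n, d_k = 5, 8
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)

scores = q @ k.T / d_k ** 0.5 + causal_mask(n)  # mask added before softmax
attn = torch.softmax(scores, dim=-1)            # masked entries become exactly 0
```

Because `exp(-inf)` is 0, each row of `attn` puts all of its weight on positions at or before the current one.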

#### 3. Rotary Position Embedding

Rotary Position Embedding (RoPE) is a positional encoding method designed to address the Transformer's lack of direct modeling of sequence position. Unlike traditional positional encodings (such as sinusoidal encodings), RoPE embeds position information directly into the query (Q) and key (K) vectors, letting the model handle relative positions in the sequence more naturally.

$$ q_i = R_i W_q x_i $$

$$ k_j = R_j W_k x_j $$

$$ q_i^T k_j = (R_i W_q x_i)^T (R_j W_k x_j) = x_i^T W_q^T R_{j-i} W_k x_j $$

Here $R_{j-i}$ governs how attention between tokens decays with relative distance: the larger $|i - j|$ is, the stronger the decay. This lets the model learn relative positional relationships, so it can extend to and cope with long sequences.
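A simplified RoPE sketch that rotates each pair of dimensions by a position-dependent angle (illustrative only; the repository's implementation may differ in pairing and layout):

```python
import torch

def rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: [seq_len, dim], dim even. Each dimension pair (2i, 2i+1) is rotated
    # by angle pos * base**(-2i/dim), embedding position as a rotation.
    seq_len, dim = x.shape
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)        # [seq_len, 1]
    inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    angles = pos * inv_freq                                              # [seq_len, dim/2]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = torch.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out

q = torch.randn(6, 8)
q_rot = rope(q)
```

Rotations preserve vector norms, and the dot product of two rotated vectors depends only on the difference of their rotation angles, which is why $q_i^T k_j$ ends up depending on the relative position $j - i$.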
@ -0,0 +1,27 @@

## kv_cache Implementation

Start from the attention formula:

$$
\begin{align*}
o_i &= \sum_j s_{ij} v_{j} \\
s_{ij} &= \text{softmax}_j\left( \frac{\sum_n q_{i,n} k_{j,n}}{\sqrt{d_k}} \right)
\end{align*}
$$

Since the model is autoregressive, we only need the output for the last part of the sequence; the index $i$ is fixed to the last element, call it $t$, and we compute $o_t$:

$$
\begin{align*}
o_t &= \sum_j s_{j} v_{j} \\
s_{j} &= \text{softmax}_j\left(\frac{\sum_n q_{t,n} k_{j,n}}{\sqrt{d_k}} \right)
\end{align*}
$$

Expanding the expression:

$$
o_t = \sum_j \text{softmax}_j\left(\frac{\sum_n q_{t,n} k_{j,n}}{\sqrt{d_k}}\right) v_{j}
$$

In this expression only $k$ and $v$ carry the sequence-length index $j$, while $q$ does not: during decoding the query input is fixed to the last token of the previous step, whereas $k$ and $v$ must be cached for all earlier positions. Note that positional encoding must be applied before entries are written to the KV cache; otherwise the positional encoding is computed incorrectly.
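A minimal single-head decode-step sketch of such a cache (illustrative, since the repository notes KV caching is not implemented yet; in a real model RoPE would be applied to the new query and key before caching):

```python
import torch

def decode_step(q_t, k_t, v_t, k_cache, v_cache):
    # q_t, k_t, v_t: [d_k] projections of the newest token.
    # Append the new key/value, then attend with only the latest query.
    k_cache = torch.cat([k_cache, k_t.unsqueeze(0)], dim=0)  # [t, d_k]
    v_cache = torch.cat([v_cache, v_t.unsqueeze(0)], dim=0)  # [t, d_k]

    scores = k_cache @ q_t / k_cache.shape[-1] ** 0.5        # [t]
    weights = torch.softmax(scores, dim=-1)
    o_t = weights @ v_cache                                  # [d_k]
    return o_t, k_cache, v_cache

d_k = 8
k_cache = torch.empty(0, d_k)
v_cache = torch.empty(0, d_k)
for _ in range(4):  # four decode steps
    q_t, k_t, v_t = torch.randn(3, d_k).unbind(0)
    o_t, k_cache, v_cache = decode_step(q_t, k_t, v_t, k_cache, v_cache)
```

Each step recomputes projections only for the newest token and reuses the cached $k$, $v$ of all earlier positions.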
|
||||||
Binary file not shown.
|
After Width: | Height: | Size: 21 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 11 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 590 KiB |
|
|
@ -0,0 +1,101 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from khaosz import Khaosz
|
||||||
|
from typing import List
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
def batch_generate(
|
||||||
|
model: Khaosz,
|
||||||
|
queries: List[str],
|
||||||
|
temperature: float,
|
||||||
|
top_k: int,
|
    top_p: float,
    batch_size: int,
) -> List:
    assert batch_size > 0
    # Sort longest-first so each batch pads to a similar length.
    sorted_queries = sorted(queries, key=lambda x: len(x), reverse=True)
    # NOTE: assumes queries are unique; duplicates would collapse to one index.
    original_indices = {query: idx for idx, query in enumerate(queries)}

    responses = [None] * len(queries)
    total_batches = (len(sorted_queries) + batch_size - 1) // batch_size

    for i in tqdm(range(0, total_batches * batch_size, batch_size), desc="Generating responses"):
        batch_queries = sorted_queries[i: i + batch_size]
        batch_responses = model.batch_generate(
            queries=batch_queries,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )

        for batch_query, batch_response in zip(batch_queries, batch_responses):
            print(f"Q: {batch_query[:50]}\nR: {batch_response[:50]}")

        # Write each response back to its query's original slot.
        for query, response in zip(batch_queries, batch_responses):
            original_idx = original_indices[query]
            responses[original_idx] = response

    return responses


def processor(
    model: Khaosz,
    input_json_file: str,
    output_json_file: str,
    batch_size: int,
    temperature: float,
    top_p: float,
    top_k: int,
    question_key: str = "question",
):
    # The input file is JSONL: one JSON object per line.
    with open(input_json_file, "r", encoding="utf-8") as f:
        input_items = [json.loads(line) for line in f]
    queries = [item[question_key] for item in input_items]

    responses = batch_generate(
        model=model,
        queries=queries,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        batch_size=batch_size
    )

    with open(output_json_file, "w", encoding="utf-8") as f:
        json.dump(responses, f, indent=4, ensure_ascii=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run generation with a Khaosz model.")

    parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
    parser.add_argument("--input_json_file", type=str, required=True, help="Path to the input JSONL file.")
    parser.add_argument("--output_json_file", type=str, required=True, help="Path to the output JSON file.")
    parser.add_argument("--question_key", type=str, default="question", help="Key for the question in the input JSON.")
    parser.add_argument("--temperature", type=float, default=0.60, help="Temperature for generating responses.")
    parser.add_argument("--top_p", type=float, default=0.95, help="Top-p value for generating responses.")
    parser.add_argument("--top_k", type=int, default=30, help="Top-k value for generating responses.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.")

    args = parser.parse_args()
    model = Khaosz(args.model_dir).to(device="cuda", dtype=torch.bfloat16)

    processor(
        model,
        input_json_file=args.input_json_file,
        output_json_file=args.output_json_file,
        question_key=args.question_key,
        batch_size=args.batch_size,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p
    )
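The sort-then-restore pattern used above (batch longest-first for even padding, then write each response back to its query's original slot) can be exercised without a model; `fake_generate` below is a hypothetical stand-in for `model.batch_generate`, and queries are assumed unique:

```python
def restore_order(queries, batch_size, generate_fn):
    # Sort longest-first so batches pad to similar lengths.
    sorted_queries = sorted(queries, key=len, reverse=True)
    original_indices = {q: i for i, q in enumerate(queries)}  # assumes unique queries
    responses = [None] * len(queries)
    for i in range(0, len(sorted_queries), batch_size):
        batch = sorted_queries[i: i + batch_size]
        for q, r in zip(batch, generate_fn(batch)):
            responses[original_indices[q]] = r
    return responses

fake_generate = lambda batch: [q.upper() for q in batch]
print(restore_order(["bb", "a", "cccc"], 2, fake_generate))  # ['BB', 'A', 'CCCC']
```

Even though the batches are processed in length order, the output list lines up with the caller's original query order.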
@@ -0,0 +1,52 @@
__version__ = "1.2.1"
__author__ = "ViperEkura"

from khaosz.model import Khaosz
from khaosz.core.transformer import Transformer, TransformerConfig
from khaosz.utils.retriever import Retriever
from khaosz.utils.splitter import (
    SemanticTextSplitter,
    PriorityTextSplitter
)
from khaosz.core.tokenizer import BpeTokenizer
from khaosz.core.parameter import ParameterLoader
from khaosz.core.generator import (
    TextGenerator,
    ChatGenerator,
    StreamGenerator,
    BatchGenerator,
    RetrievalGenerator,
    EmbeddingEncoder
)
from khaosz.trainer.trainer import Trainer
from khaosz.trainer.dataset import SeqDataset, SftDataset, DpoDataset, BaseDataset


__all__ = [
    # model
    "Khaosz",

    # module
    "Transformer",
    "TransformerConfig",
    "BpeTokenizer",
    "ParameterLoader",
    "TextGenerator",
    "ChatGenerator",
    "StreamGenerator",
    "BatchGenerator",
    "RetrievalGenerator",
    "EmbeddingEncoder",

    # trainer
    "Trainer",
    "SeqDataset",
    "SftDataset",
    "DpoDataset",
    "BaseDataset",

    # utils
    "Retriever",
    "SemanticTextSplitter",
    "PriorityTextSplitter",
]
@@ -0,0 +1,27 @@
from khaosz.core.tokenizer import BpeTokenizer
from khaosz.core.transformer import Transformer, TransformerConfig
from khaosz.core.parameter import ParameterLoader, ModelParameter, Checkpoint
from khaosz.core.generator import (
    TextGenerator,
    ChatGenerator,
    StreamGenerator,
    BatchGenerator,
    RetrievalGenerator,
    EmbeddingEncoder
)


__all__ = [
    "Transformer",
    "TransformerConfig",
    "BpeTokenizer",
    "ParameterLoader",
    "ModelParameter",
    "Checkpoint",
    "TextGenerator",
    "ChatGenerator",
    "StreamGenerator",
    "BatchGenerator",
    "RetrievalGenerator",
    "EmbeddingEncoder"
]
@@ -0,0 +1,568 @@
import torch
from torch import Tensor
from typing import List, Tuple, Union, Optional, Generator, Self
from khaosz.core.parameter import ModelParameter


def build_prompt(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
    """
    Build a prompt from the current query and the conversation history.

    Args:
        query (str): The current query string.
        history (Optional[List[Tuple[str, str]]]): Previous (query, response) pairs.

    Returns:
        str: The assembled prompt string.
    """
    prompt_parts = []

    if history is None:
        history = []

    for his_query, his_response in history:
        prompt_parts.append(f"<|user|> {his_query} <|system|> <bos>{his_response}<eos>")

    if query is not None:
        prompt_parts.append(f"<|user|> {query} <|system|> <bos>")

    return "\n".join(prompt_parts)
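For reference, a one-turn history plus a new query yields a prompt of the following shape; the template is assembled by hand here to show the wire format the model is trained on:

```python
history = [("hi", "hello")]
query = "how are you"
# Each past turn is closed with <eos>; the new turn ends at <bos>,
# which is where generation continues.
parts = [f"<|user|> {q} <|system|> <bos>{r}<eos>" for q, r in history]
parts.append(f"<|user|> {query} <|system|> <bos>")
prompt = "\n".join(parts)
print(prompt)
# <|user|> hi <|system|> <bos>hello<eos>
# <|user|> how are you <|system|> <bos>
```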
def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]:
    """
    Left-pad a list of token-id sequences to a fixed length.

    Args:
        ids_list (List[List[int]]): A list of token-id sequences.
        max_ids_len (int): The target length after padding.
        pad_id (int): The id used for padding.

    Returns:
        List[List[int]]: The padded sequences.
    """
    new_ids_list = []
    for ids in ids_list:
        pad_len = max_ids_len - len(ids)
        padded_seq = [pad_id] * pad_len + ids
        new_ids_list.append(padded_seq)

    return new_ids_list
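Note that the padding goes on the left, so the real tokens of every sequence end at the same position, right where generation continues. A quick sketch of the same computation:

```python
ids_list = [[5, 6, 7], [9]]
pad_id = 0
max_len = max(len(ids) for ids in ids_list)
# Left-pad: pad ids go before the sequence, not after it.
padded = [[pad_id] * (max_len - len(ids)) + ids for ids in ids_list]
print(padded)  # [[5, 6, 7], [0, 0, 9]]
```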
def apply_sampling_strategies(
    logits: Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    filter_value: float = -float("inf")
) -> Tensor:
    """
    Apply temperature scaling, top-k, and top-p (nucleus) filtering to logits.

    Args:
        logits (Tensor): The logits tensor; filtered entries are modified in place.
        temperature (float): Softmax temperature; values below 1 sharpen the distribution.
        top_k (int): Keep only the k highest logits (0 disables top-k).
        top_p (float): Keep the smallest set of tokens whose cumulative probability
            exceeds top_p (1.0 disables top-p).
        filter_value (float, optional): Value assigned to filtered logits.
            Defaults to -float("inf").

    Returns:
        Tensor: The filtered logits tensor.
    """
    if temperature != 1.0:
        logits = logits / temperature

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

        # Shift the mask right so the first token crossing the threshold is kept.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the removal mask back to the original token order.
        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
        indices_to_remove.scatter_(
            dim=1,
            index=sorted_indices,
            src=sorted_indices_to_remove
        )

        logits[indices_to_remove] = filter_value

    return logits
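The top-p step can be traced on paper: softmax the logits, sort the probabilities in descending order, and cut once the running mass passes `top_p`, keeping the first token that crosses the threshold (that is what the mask shift above achieves). A torch-free sketch of which indices survive, for illustration only:

```python
import math

def nucleus_keep(logits, top_p):
    # Softmax over raw logits.
    exps = [math.exp(x) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Walk tokens from most to least probable.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)       # the token that crosses top_p is still kept
        cum += probs[i]
        if cum > top_p:
            break
    return sorted(kept)

print(nucleus_keep([2.0, 1.0, 0.1, -1.0], top_p=0.7))  # [0, 1]
```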


class KVCacheManager:
    def __init__(
        self,
        num_layers: int,
        batch_size: int,
        max_len: int,
        num_heads: int,
        head_dim: int,
        device: torch.device = "cuda",
        dtype: torch.dtype = torch.bfloat16
    ):
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.max_len = max_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.device = device
        self.dtype = dtype

        self._kv_cache: Optional[List[Tuple[Tensor, Tensor]]] = None
        self._seq_mask: Optional[Tensor] = None
        self._initialize()

    def _initialize(self):
        self._kv_cache = []
        for _ in range(self.num_layers):
            k_cache = torch.zeros(
                (self.batch_size, self.max_len, self.num_heads, self.head_dim),
                device=self.device, dtype=self.dtype
            )
            v_cache = torch.zeros(
                (self.batch_size, self.max_len, self.num_heads, self.head_dim),
                device=self.device, dtype=self.dtype
            )
            self._kv_cache.append((k_cache, v_cache))

        self._seq_mask = torch.ones(
            (self.batch_size, self.max_len),
            device=self.device, dtype=torch.bool
        )

    def update(self, active_mask: Tensor):
        # Drop cache rows for finished sequences so the cache shrinks with the batch.
        for i in range(self.num_layers):
            k_cache, v_cache = self._kv_cache[i]
            self._kv_cache[i] = (k_cache[active_mask], v_cache[active_mask])

        self._seq_mask = self._seq_mask[active_mask]

    def reset(self, full_reset=False):
        if full_reset:
            self._kv_cache = None
            self._seq_mask = None
        else:
            self._initialize()

    def set_seq_mask(self, input_ids: Tensor, pad_id: int):
        batch_size, seq_len = input_ids.shape
        bool_mask = (input_ids != pad_id)
        self._seq_mask[:batch_size, :seq_len] = bool_mask

    def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]:
        return self._kv_cache

    def get_seq_mask(self) -> Tensor:
        return self._seq_mask


class GeneratorCore:
    def __init__(self, parameter: ModelParameter):
        self.model = parameter.model
        self.tokenizer = parameter.tokenizer
        self.config = parameter.config

    def compute_logits(
        self,
        input_ids: Tensor,
        attn_mask: Optional[Tensor] = None,
        kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
        start_pos: int = 0
    ) -> Tuple[Tensor, int]:
        with torch.inference_mode():
            outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
            logits = outputs["logits"][:, -1, :]
            cache_increase = input_ids.size(-1)

        return logits, cache_increase

    def to(self, *args, **kwargs) -> Self:
        self.model.to(*args, **kwargs)
        return self


class EmbeddingEncoderCore:
    def __init__(self, parameter: ModelParameter):
        self.model = parameter.model
        self.tokenizer = parameter.tokenizer
        self.config = parameter.config

    def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
        with_batch = isinstance(sentence, list)
        ids = self.tokenizer.encode(sentence)
        batch_ids = ids if with_batch else [ids]
        max_model_len = self.config.m_len

        # Split sequences longer than the model window into fragments,
        # remembering which input each fragment came from.
        all_fragments = []
        fragment_origin_idx = []

        for i, seq in enumerate(batch_ids):
            if len(seq) > max_model_len:
                fragments = [seq[j:j + max_model_len] for j in range(0, len(seq), max_model_len)]
                all_fragments.extend(fragments)
                fragment_origin_idx.extend([i] * len(fragments))
            else:
                all_fragments.append(seq)
                fragment_origin_idx.append(i)

        # Handle empty input.
        if not all_fragments or not ids:
            return [] if with_batch else torch.tensor([])

        device = next(self.model.parameters()).device
        max_len = min(max(len(seq) for seq in all_fragments), max_model_len)

        # Right-pad the fragments and build the matching attention mask.
        padded_ids = []
        masks = []
        for seq in all_fragments:
            pad_len = max_len - len(seq)
            padded_seq = seq + [self.tokenizer.pad_id] * pad_len
            mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
            padded_ids.append(padded_seq)
            masks.append(mask)

        input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
        seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)

        with torch.inference_mode():
            outputs = self.model(input_tensor, seq_mask)["hidden_states"]
            # [num_fragments, seq_len, hidden_size], padded positions zeroed out
            fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))

        sentence_embs: List[Tensor] = []
        for i in range(len(batch_ids)):
            indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
            if indices:
                sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1)    # [frags, hidden_size]
                length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1)  # [frags, 1]
                emb = torch.sum(sum_frags / length, dim=0)                    # [hidden_size]
                sentence_embs.append(emb.flatten())

        if with_batch:
            return [emb.flatten() for emb in sentence_embs]
        else:
            return sentence_embs[0].flatten()

    def to(self, *args, **kwargs) -> Self:
        self.model.to(*args, **kwargs)
        return self
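The pooling above is a masked mean per fragment (zero the pad positions, sum over the sequence, divide by the real length) followed by a sum over fragments. The arithmetic can be checked without torch on a hypothetical fragment of 3 positions with 2 hidden dims, where the last position is padding:

```python
# One fragment: 3 positions x 2 hidden dims, last position is padding.
hidden = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]
mask = [1, 1, 0]
# Zero the padded positions, then divide the column sums by the real length.
masked = [[h * m for h in row] for row, m in zip(hidden, mask)]
length = sum(mask)
pooled = [sum(col) / length for col in zip(*masked)]
print(pooled)  # [2.0, 3.0]
```

The pad row contributes nothing, so the result is the mean of the two real positions.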


class TextGenerator(GeneratorCore):
    def __init__(self, parameter: ModelParameter):
        super().__init__(parameter)

    def generate(
        self,
        query: str,
        temperature: float,
        top_k: int,
        top_p: float,
    ) -> str:
        assert temperature >= 0.0
        assert top_k >= 0
        assert 0.0 <= top_p <= 1.0

        device = next(self.model.parameters()).device
        cache_manager = KVCacheManager(
            num_layers=self.config.n_layer,
            batch_size=1,
            max_len=self.config.m_len,
            num_heads=self.config.n_kvhead,
            head_dim=self.config.n_dim // self.config.n_head,
            device=device,
        )

        ids = self.tokenizer.encode(query)
        input_ids = torch.tensor([ids], device=device, dtype=torch.long)

        prompt_len = len(ids)
        cur_cache_pos = 0
        self.model.eval()

        while len(ids) < self.config.m_len:
            kv_caches = cache_manager.get_kvcache()
            logits, cache_increase = self.compute_logits(
                input_ids,
                kv_caches=kv_caches,
                start_pos=cur_cache_pos
            )
            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
            probs = torch.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)

            # Only the new token is fed on the next step; the rest is cached.
            input_ids = next_token_id
            ids.append(next_token_id.item())
            cur_cache_pos += cache_increase

            if next_token_id.item() in self.tokenizer.stop_ids:
                break

        response = self.tokenizer.decode(ids[prompt_len:])

        return response


class ChatGenerator(GeneratorCore):
    def __init__(self, parameter: ModelParameter):
        super().__init__(parameter)

    def generate(
        self,
        query: str,
        history: List[Tuple[str, str]],
        temperature: float,
        top_k: int,
        top_p: float,
    ) -> Tuple[str, List[Tuple[str, str]]]:
        assert temperature >= 0.0
        assert top_k >= 0
        assert 0.0 <= top_p <= 1.0

        if history is None:
            history = []

        device = next(self.model.parameters()).device
        cache_manager = KVCacheManager(
            num_layers=self.config.n_layer,
            batch_size=1,
            max_len=self.config.m_len,
            num_heads=self.config.n_kvhead,
            head_dim=self.config.n_dim // self.config.n_head,
            device=device,
        )
        ids = self.tokenizer.encode(build_prompt(query, history))
        input_ids = torch.tensor([ids], device=device, dtype=torch.long)
        cpy_history = history.copy()

        prompt_len = len(ids)
        cur_cache_pos = 0
        self.model.eval()

        while len(ids) < self.config.m_len:
            kv_caches = cache_manager.get_kvcache()
            logits, cache_increase = self.compute_logits(
                input_ids,
                kv_caches=kv_caches,
                start_pos=cur_cache_pos
            )
            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
            probs = torch.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)

            input_ids = next_token_id
            ids.append(next_token_id.item())
            cur_cache_pos += cache_increase

            if next_token_id.item() in self.tokenizer.stop_ids:
                break

        response = self.tokenizer.decode(ids[prompt_len:])
        cpy_history.append((query, response))

        return response, cpy_history


class StreamGenerator(GeneratorCore):
    def __init__(self, parameter: ModelParameter):
        super().__init__(parameter)

    def generate(
        self,
        query: str,
        history: List[Tuple[str, str]],
        temperature: float,
        top_k: int,
        top_p: float,
    ) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
        assert temperature >= 0.0
        assert top_k >= 0
        assert 0.0 <= top_p <= 1.0

        if history is None:
            history = []

        device = next(self.model.parameters()).device
        cache_manager = KVCacheManager(
            num_layers=self.config.n_layer,
            batch_size=1,
            max_len=self.config.m_len,
            num_heads=self.config.n_kvhead,
            head_dim=self.config.n_dim // self.config.n_head,
            device=device,
        )
        ids = self.tokenizer.encode(build_prompt(query, history))
        input_ids = torch.tensor([ids], device=device, dtype=torch.long)
        cpy_history = history.copy()

        prompt_len = len(ids)
        cur_cache_pos = 0
        self.model.eval()

        while len(ids) < self.config.m_len:
            kv_caches = cache_manager.get_kvcache()
            logits, cache_increase = self.compute_logits(
                input_ids,
                kv_caches=kv_caches,
                start_pos=cur_cache_pos
            )
            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
            probs = torch.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)

            input_ids = next_token_id
            ids.append(next_token_id.item())
            cur_cache_pos += cache_increase

            # Yield the partial response after every sampled token.
            response = self.tokenizer.decode(ids[prompt_len:])
            yield response, cpy_history + [(query, response)]

            if next_token_id.item() in self.tokenizer.stop_ids:
                yield response + "\n", cpy_history + [(query, response)]
                break


class BatchGenerator(GeneratorCore):
    def __init__(self, parameter: ModelParameter):
        super().__init__(parameter)

    def generate(
        self,
        queries: List[str],
        histories: List[List[Tuple[str, str]]],
        temperature: float,
        top_k: int,
        top_p: float
    ) -> List[str]:
        assert temperature >= 0.0
        assert top_k >= 0
        assert 0.0 <= top_p <= 1.0

        batch_size = len(queries)
        if histories is None:
            histories = [[] for _ in range(batch_size)]

        prompts = [build_prompt(query, history) for query, history in zip(queries, histories)]
        ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
        max_ids_len = max(len(ids) for ids in ids_list)
        ids_list = pad_sequence(ids_list, max_ids_len, self.tokenizer.pad_id)

        device = next(self.model.parameters()).device
        cache_manager = KVCacheManager(
            num_layers=self.config.n_layer,
            batch_size=batch_size,
            max_len=self.config.m_len,
            num_heads=self.config.n_kvhead,
            head_dim=self.config.n_dim // self.config.n_head,
            device=device,
        )

        input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
        cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
        activate_task_mask = [True] * batch_size

        prompt_len = max_ids_len
        cur_cache_pos = 0

        while max_ids_len < self.config.m_len and sum(activate_task_mask) != 0:
            kv_caches = cache_manager.get_kvcache()
            attn_mask = cache_manager.get_seq_mask()

            logits, cache_increase = self.compute_logits(
                input_tensor,
                attn_mask=attn_mask,
                kv_caches=kv_caches,
                start_pos=cur_cache_pos
            )

            cur_cache_pos += cache_increase
            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
            probs = torch.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)

            # Map rows of the shrinking batch back to their original indices.
            active_mask = []
            c_ids = 0

            for i in range(batch_size):
                if activate_task_mask[i]:
                    token = next_token_id[c_ids, :].item()
                    ids_list[i].append(token)
                    c_ids += 1

                    is_active = token not in self.tokenizer.stop_ids
                    activate_task_mask[i] = is_active
                    active_mask.append(is_active)

            # Compact the KV cache and inputs down to the still-active sequences.
            active_mask = torch.tensor(active_mask, device=device, dtype=torch.bool)
            cache_manager.update(active_mask)
            input_tensor = next_token_id[active_mask, :]

            max_ids_len += 1

        responses = [""] * batch_size
        for i in range(batch_size):
            responses[i] = self.tokenizer.decode(ids_list[i][prompt_len:])
            histories[i].append((queries[i], responses[i]))

        return responses
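The loop's bookkeeping uses two masks: `activate_task_mask` over the original batch indices, and a compaction mask over the current (already shrunken) rows. One decoding step can be simulated without a model; `stop=0` is a hypothetical stop-token id:

```python
def step(tokens, activate_task_mask, ids_list, stop):
    # tokens: one sampled token per still-active row, in compacted order.
    active_mask, c = [], 0
    for i in range(len(activate_task_mask)):
        if activate_task_mask[i]:
            token = tokens[c]
            ids_list[i].append(token)
            c += 1
            alive = token != stop
            activate_task_mask[i] = alive
            active_mask.append(alive)
    return active_mask  # which compacted rows survive to the next step

ids_list = [[1], [2], [3]]
mask = [True, True, True]
print(step([7, 0, 8], mask, ids_list, stop=0))  # [True, False, True]
print(ids_list)  # [[1, 7], [2, 0], [3, 8]]
```

The middle sequence hit the stop token, so its row is dropped from the next step's batch while its tokens stay in `ids_list` for decoding.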


class RetrievalGenerator(GeneratorCore):
    def __init__(self, retriever_parameter: ModelParameter):
        super().__init__(retriever_parameter)

    def generate(
        self,
        retrieved: List[str],
        query: str,
        history: List[Tuple[str, str]],
        temperature: float,
        top_k: int,
        top_p: float,
    ) -> Tuple[str, List[Tuple[str, str]]]:
        assert temperature >= 0.0
        assert top_k >= 0
        assert 0.0 <= top_p <= 1.0

        if history is None:
            history = []

        # Prepend the numbered retrieved passages to the query, if any.
        # "根据以上内容回答" = "answer based on the content above".
        retrieved_text = "\n".join(f"{idx + 1}. {key}" for idx, key in enumerate(retrieved)) if retrieved else ""
        retrieved_query = f"{retrieved_text}<eos>\n\n根据以上内容回答: {query}" if retrieved_text else query
        parameter = ModelParameter(self.model, self.tokenizer, self.config)

        return ChatGenerator(parameter).generate(
            retrieved_query,
            history,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )


class EmbeddingEncoder(EmbeddingEncoderCore):
    def __init__(self, parameter: ModelParameter):
        super().__init__(parameter)

    def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
        return super().encode(sentence)
@@ -0,0 +1,238 @@
import pickle as pkl
import matplotlib.pyplot as plt
import safetensors.torch as st
import torch.nn as nn
import torch.optim as optim

from dataclasses import dataclass, field
from typing import Optional, Self, Union
from pathlib import Path

from khaosz.core.tokenizer import BpeTokenizer
from khaosz.core.transformer import TransformerConfig, Transformer


class BaseModelIO:
    """Base class for model I/O operations."""

    def __init__(
        self,
        model: Optional[nn.Module] = None,
        tokenizer: Optional[BpeTokenizer] = None,
        config: Optional[TransformerConfig] = None
    ):
        self.model = model
        self.tokenizer = tokenizer or BpeTokenizer()
        self.config = config or TransformerConfig()

    def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
        """Get standardized file paths for model components."""
        dir_path = Path(directory)
        return {
            "model": dir_path / "model.safetensors",
            "config": dir_path / "config.json",
            "tokenizer": dir_path / "tokenizer.json"
        }

    def save_components(self, save_dir: Union[str, Path]):
        """Save core model components."""
        paths = self._get_file_paths(save_dir)
        paths["model"].parent.mkdir(parents=True, exist_ok=True)

        if self.model is not None:
            st.save_file(self.model.state_dict(), str(paths["model"]))
        self.config.save(str(paths["config"]))
        self.tokenizer.save(str(paths["tokenizer"]))

    def load_components(self, load_dir: Union[str, Path]) -> Self:
        """Load core model components."""
        paths = self._get_file_paths(load_dir)

        self.config.load(str(paths["config"]))
        self.tokenizer.load(str(paths["tokenizer"]))

        if paths["model"].exists():
            state_dict = st.load_file(str(paths["model"]))
            if self.model is None:
                self.model = Transformer(self.config)
            self.model.load_state_dict(state_dict)

        return self

    def to(self, *args, **kwargs) -> Self:
        """Move the model to a device and/or dtype."""
        if self.model is not None:
            self.model.to(*args, **kwargs)
        return self


@dataclass
class ModelParameter(BaseModelIO):
    """Container for model parameters with serialization capabilities."""

    model: Optional[nn.Module] = field(
        default=None,
        metadata={"help": "Transformer model."}
    )
    tokenizer: BpeTokenizer = field(
        default_factory=BpeTokenizer,
        metadata={"help": "Tokenizer for the model."}
    )
    config: TransformerConfig = field(
        default_factory=TransformerConfig,
        metadata={"help": "Transformer model configuration."}
    )

    def save(self, save_dir: Union[str, Path]):
        """Save model parameters."""
        self.save_components(save_dir)

    def load(self, load_dir: Union[str, Path]) -> Self:
        """Load model parameters."""
        return self.load_components(load_dir)


@dataclass
class Checkpoint(BaseModelIO):
    """Extended model parameters with training state."""

    model: Optional[nn.Module] = field(
        default=None,
        metadata={"help": "Transformer model."}
    )
    tokenizer: BpeTokenizer = field(
        default_factory=BpeTokenizer,
        metadata={"help": "Tokenizer for the model."}
    )
    config: TransformerConfig = field(
        default_factory=TransformerConfig,
        metadata={"help": "Transformer model configuration."}
    )
    loss_list: list[float] = field(
        default_factory=list,
        metadata={"help": "List of training losses."}
    )
    current_iter: int = field(
        default=0,
        metadata={"help": "Current training iteration."}
    )
    optimizer: Optional[optim.Optimizer] = field(
        default=None,
        metadata={"help": "Optimizer state."}
    )

    def __post_init__(self):
        # Ensure current_iter matches the loss list length if not explicitly set.
        if self.current_iter == 0 and self.loss_list:
            self.current_iter = len(self.loss_list)

    def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
        """Get file paths for training-specific files."""
        paths = self._get_file_paths(directory)
        paths.update({
            "loss_list": paths["model"].parent / "loss.pkl",
            "loss_plot": paths["model"].parent / "loss.png",
            "optimizer": paths["model"].parent / "optimizer.pkl"
        })
        return paths

    def save_training_state(self, save_dir: Union[str, Path]):
        """Save training-specific state."""
        paths = self._get_training_paths(save_dir)

        # Save loss plot
        self._plot_loss(str(paths["loss_plot"]))

        # Save loss list
        with open(str(paths["loss_list"]), "wb") as f:
            pkl.dump(self.loss_list, f)

        # Save optimizer state
        if self.optimizer is not None:
            with open(str(paths["optimizer"]), "wb") as f:
                pkl.dump(self.optimizer.state_dict(), f)

    def load_training_state(self, load_dir: Union[str, Path]) -> Self:
        """Load training-specific state."""
        paths = self._get_training_paths(load_dir)

        # Load loss list
        if paths["loss_list"].exists():
            with open(str(paths["loss_list"]), "rb") as f:
                self.loss_list = pkl.load(f)
            self.current_iter = len(self.loss_list)

        # Load optimizer state
        if paths["optimizer"].exists() and self.optimizer is not None:
            with open(str(paths["optimizer"]), "rb") as f:
                optim_state = pkl.load(f)
            self.optimizer.load_state_dict(optim_state)

        return self

    def _plot_loss(self, save_path: str):
        """Plot and save the loss curve."""
        if not self.loss_list:
            return

        plt.figure(figsize=(10, 6))
        plt.plot(self.loss_list)
        plt.title(f"Training Loss - Iteration {self.current_iter}")
        plt.xlabel("Batch")
        plt.ylabel("Loss")
        plt.grid(True)
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.close()

    def save(self, save_dir: Union[str, Path]):
        """Save complete checkpoint."""
        self.save_components(save_dir)
        self.save_training_state(save_dir)

    def load(self, load_dir: Union[str, Path]) -> Self:
        """Load complete checkpoint."""
        self.load_components(load_dir)
        self.load_training_state(load_dir)
        return self
|
||||||
|
|
||||||
|
|
||||||
|
class ParameterLoader:
|
||||||
|
"""Factory class for loading model parameters or checkpoints."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(load_dir: Union[str, Path]) -> Union[ModelParameter, Checkpoint]:
|
||||||
|
"""Load either ModelParameter or Checkpoint based on directory contents."""
|
||||||
|
load_dir = Path(load_dir)
|
||||||
|
|
||||||
|
# Check for training-specific files
|
||||||
|
loss_file = load_dir / "loss.pkl"
|
||||||
|
has_training_data = loss_file.exists()
|
||||||
|
|
||||||
|
# Create appropriate instance
|
||||||
|
if has_training_data:
|
||||||
|
checkpoint = Checkpoint()
|
||||||
|
checkpoint.load(str(load_dir))
|
||||||
|
return checkpoint
|
||||||
|
else:
|
||||||
|
params = ModelParameter()
|
||||||
|
params.load(str(load_dir))
|
||||||
|
return params
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_checkpoint(
|
||||||
|
model: nn.Module,
|
||||||
|
tokenizer: BpeTokenizer,
|
||||||
|
config: TransformerConfig,
|
||||||
|
loss_list: Optional[list[float]] = None,
|
||||||
|
optimizer: Optional[optim.Optimizer] = None
|
||||||
|
) -> Checkpoint:
|
||||||
|
"""Convenience method to create a training checkpoint."""
|
||||||
|
return Checkpoint(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
config=config,
|
||||||
|
loss_list=loss_list or [],
|
||||||
|
optimizer=optimizer
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
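The `ParameterLoader.load` dispatch rule above — a directory is treated as a training checkpoint exactly when it contains the training artifacts, keyed off `loss.pkl` — can be sketched with the standard library alone. `classify_dir` is a hypothetical helper mirroring that rule, not part of the project's API:

```python
import pickle
import tempfile
from pathlib import Path

def classify_dir(load_dir):
    # Mirror ParameterLoader.load's rule: a directory is a training
    # checkpoint iff it contains "loss.pkl"; otherwise plain parameters.
    return "checkpoint" if (Path(load_dir) / "loss.pkl").exists() else "parameter"

with tempfile.TemporaryDirectory() as d:
    assert classify_dir(d) == "parameter"
    with open(Path(d) / "loss.pkl", "wb") as f:
        pickle.dump([0.5, 0.4], f)
    assert classify_dir(d) == "checkpoint"
```

A marker-file check like this is cheap but fragile: deleting `loss.pkl` silently downgrades a checkpoint directory to plain parameters.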
@@ -0,0 +1,111 @@
from tokenizers import Tokenizer, Encoding
from tokenizers import decoders, processors, normalizers, pre_tokenizers
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from typing import List, Union


class BpeTokenizer:
    def __init__(self, path=None):
        self._control_tokens = ["<bos>", "<eos>", "<pad>"]
        self._special_tokens = ["<|user|>", "<|system|>"]
        model = BPE()
        tokenizer = Tokenizer(model)
        tokenizer.normalizer = normalizers.Sequence([
            normalizers.NFC()
        ])
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
            pre_tokenizers.Punctuation(behavior="isolated"),
            pre_tokenizers.Metaspace(prepend_scheme="never"),
            pre_tokenizers.Split(pattern=r"(\d+|[a-zA-Z]+|(?:'s|'t|'re|'ve|'m|'ll|'d))", behavior="isolated"),
            pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
        ])
        tokenizer.decoder = decoders.Sequence([
            decoders.ByteLevel(),
            decoders.Metaspace(prepend_scheme="never")
        ])
        tokenizer.post_processor = processors.Sequence([
            processors.ByteLevel(trim_offsets=False)
        ])
        self._tokenizer = tokenizer

        if path is not None:
            self._tokenizer = Tokenizer.from_file(path)

    def _prepare_trainer(self, vocab_size: int, min_freq: int, reserved_token_size: int) -> tuple:
        assert reserved_token_size > len(self._special_tokens)
        reserved_tokens = [f"<|rsv{i:02d}|>" for i in range(reserved_token_size - len(self._special_tokens))]
        detail_vocab_size = vocab_size - (len(reserved_tokens) + len(self._special_tokens))

        alphabet = pre_tokenizers.ByteLevel.alphabet()
        min_size = len(alphabet) + len(self._control_tokens)
        assert detail_vocab_size > min_size

        trainer = BpeTrainer(
            vocab_size=detail_vocab_size,
            min_frequency=min_freq,
            limit_alphabet=detail_vocab_size // 4,
            max_token_length=18,
            special_tokens=self._control_tokens,
            show_progress=True,
            initial_alphabet=alphabet,
        )

        return trainer, detail_vocab_size, reserved_tokens

    def train(self, files, vocab_size, min_freq, reserved_token_size=100):
        trainer, _, reserved_tokens = self._prepare_trainer(
            vocab_size=vocab_size,
            min_freq=min_freq,
            reserved_token_size=reserved_token_size
        )
        self._tokenizer.train(files=files, trainer=trainer)
        self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)

    def train_from_iterator(self, iterator, vocab_size, min_freq, reserved_token_size=100):
        trainer, _, reserved_tokens = self._prepare_trainer(
            vocab_size=vocab_size,
            min_freq=min_freq,
            reserved_token_size=reserved_token_size
        )
        self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
        self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)

    def save(self, path):
        self._tokenizer.save(path)

    def load(self, path):
        self._tokenizer = Tokenizer.from_file(path)

    def encode(self, tokens: Union[str, List[str]], out_ids: bool=True, add_special_tokens: bool=False) -> List:
        if isinstance(tokens, str):
            encoded: Encoding = self._tokenizer.encode(tokens, add_special_tokens=add_special_tokens)
            return encoded.ids if out_ids else encoded.tokens
        elif isinstance(tokens, list):
            encoded_list: List[Encoding] = self._tokenizer.encode_batch(tokens, add_special_tokens=add_special_tokens)
            return [encoded.ids if out_ids else encoded.tokens for encoded in encoded_list]

    def decode(self, tokens: List[int], skip_special_tokens: bool=True) -> str:
        return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)

    def __len__(self) -> int:
        return self._tokenizer.get_vocab_size()

    @property
    def stop_ids(self) -> List[int]:
        stop_ids = []
        for token in self._control_tokens:
            stop_ids.append(self._tokenizer.token_to_id(token))
        return stop_ids

    @property
    def bos_id(self) -> int:
        return self._tokenizer.token_to_id("<bos>")

    @property
    def eos_id(self) -> int:
        return self._tokenizer.token_to_id("<eos>")

    @property
    def pad_id(self) -> int:
        return self._tokenizer.token_to_id("<pad>")
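The vocabulary budget in `_prepare_trainer` is worth making explicit: the requested `vocab_size` is split between the BPE trainer's vocabulary, the two named special tokens, and the remaining reserved `<|rsvNN|>` slots. A stdlib-only sketch of that arithmetic (the helper name is ours, not the project's):

```python
def split_vocab_budget(vocab_size, reserved_token_size, n_special=2):
    # Reserved placeholders fill the slots not taken by the named
    # special tokens, matching _prepare_trainer's computation.
    n_reserved = reserved_token_size - n_special
    trainer_vocab = vocab_size - (n_reserved + n_special)
    return trainer_vocab, n_reserved

trainer_vocab, n_reserved = split_vocab_budget(32000, 100)
# The three parts always recompose to the requested vocab_size.
assert trainer_vocab + n_reserved + 2 == 32000
assert n_reserved == 98
```

Pre-allocating reserved slots keeps the embedding matrix shape stable when new special tokens are added later, at the cost of a slightly smaller learned vocabulary.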
@@ -0,0 +1,341 @@
import json
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from torch.nn import init
from dataclasses import asdict, dataclass
from typing import List, Optional, Self, Tuple


def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
    """
    Repeat the key/value heads n_rep times along the head dimension.

    Args:
        x (Tensor): The input tensor of shape (bs, slen, n_heads, head_dim).
        n_rep (int): The number of repetitions.

    Returns:
        Tensor: The repeated tensor of shape (bs, slen, n_heads * n_rep, head_dim).
    """

    bs, slen, n_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_heads, n_rep, head_dim)
        .reshape(bs, slen, n_heads * n_rep, head_dim)
    )


def get_rotary_emb(
    dim: int,
    max_len: int,
    base: float = 10000,
    device: torch.device = "cuda",
) -> torch.Tensor:
    """
    Get the rotary embedding for the given dimension and maximum length.

    Args:
        dim (int): The dimension of the input.
        max_len (int): The maximum length of the input.
        base (float, optional): The base for the frequency. Defaults to 10000.
        device (torch.device, optional): The device to use. Defaults to "cuda".

    Returns:
        Tensor: The rotary embedding tensor.
    """

    theta = base ** (-torch.arange(0, dim, 2, device=device).float() / dim)
    t = torch.arange(0, max_len, device=device).float()
    freqs = torch.outer(t, theta)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)

    return freqs_cis


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    """
    Apply rotary embedding to the input tensor.

    Args:
        x (Tensor): The input tensor.
        freqs_cis (Tensor): The rotary embedding tensor.

    Returns:
        Tensor: The output tensor.
    """

    dtype = x.dtype
    seq_len = x.size(1)

    x_complex = torch.view_as_complex(x.view(*x.shape[:-1], -1, 2).float())
    freqs_cis = freqs_cis.reshape(1, seq_len, 1, -1)
    x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3)

    return x_out.to(dtype)


def create_attention_mask(
    seq_mask: Tensor,
    start_pos: int = 0,
    seq_len: int = 0,
    is_causal: bool = False,
    device: torch.device = "cuda",
    dtype: torch.dtype = torch.float32
) -> Tensor:
    """
    Create an attention mask for GQA.

    Args:
        seq_mask (Tensor): A tensor indicating whether each position is valid or not.
        start_pos (int): The starting position of the sequence.
        seq_len (int): The length of the sequence.
        is_causal (bool): Whether the attention is causal or not.
        device (torch.device): The device to use.
        dtype (torch.dtype): The dtype of the mask values.

    Returns:
        Tensor: The attention mask tensor, or None when no mask is needed.
    """

    if start_pos != 0 and seq_mask is None:
        # for single-prompt chat
        seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)

    if seq_mask is None:
        return None

    batch_size = seq_mask.size(0)
    seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
    # (bsz, start_pos + seq_len)
    expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len)
    # (bsz, seq_len, start_pos + seq_len)

    if is_causal:
        causal_mask = torch.tril(
            torch.ones((seq_len, start_pos + seq_len), dtype=torch.bool, device=device),
            diagonal=start_pos
        )
        causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, start_pos + seq_len)
        expanded_mask = expanded_mask & causal_mask

    attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
    attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
    # (bsz, 1, seq_len, seq_len + start_pos)

    return attention_mask


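The `torch.tril(..., diagonal=start_pos)` call in `create_attention_mask` encodes one rule: the query at absolute position `start_pos + q` may attend to keys `0 .. start_pos + q`. A stdlib-only sketch of that predicate (`causal_keep` is our name for illustration):

```python
def causal_keep(q_idx, k_idx, start_pos):
    # Entry (q_idx, k_idx) of tril(ones(seq_len, start_pos + seq_len),
    # diagonal=start_pos) is True exactly when k_idx <= start_pos + q_idx.
    return k_idx <= start_pos + q_idx

# Prefill (start_pos=0): each token sees itself and earlier tokens only.
assert causal_keep(0, 0, 0) and not causal_keep(0, 1, 0)
# Decode step with 4 cached positions: the single new query sees all 5 keys.
assert all(causal_keep(0, k, 4) for k in range(5))
```

The `diagonal=start_pos` offset is what lets the same mask shape work during incremental decoding, where the query block is short but the key axis spans the whole cache.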
@dataclass
class TransformerConfig:
    # basic config
    vocab_size: Optional[int] = None
    n_dim: Optional[int] = None
    n_head: Optional[int] = None
    n_layer: Optional[int] = None
    m_len: Optional[int] = None
    norm_eps: Optional[float] = None
    d_ffn: Optional[int] = None

    # GQA
    n_kvhead: Optional[int] = None

    def load(self, config_path: str) -> Self:
        with open(config_path, 'r') as f:
            config: dict = json.load(f)
        for key, value in config.items():
            if hasattr(self, key):
                setattr(self, key, value)

        return self

    def save(self, config_path: str) -> None:
        config_dict = asdict(self)
        config_dict = {k: v for k, v in config_dict.items() if v is not None}
        with open(config_path, 'w') as f:
            json.dump(config_dict, f, indent=4)


class Linear(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, bias: bool=False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
        self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
        init.normal_(self.weight, mean=0, std=0.006)

    def forward(self, x: Tensor) -> Tensor:
        return F.linear(x, self.weight, self.bias)


class RMSNorm(nn.Module):
    def __init__(self, n_dim, norm_eps):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(n_dim))
        self.norm_eps = norm_eps

    def forward(self, x: Tensor) -> Tensor:
        dtype = x.dtype
        x = x.float()
        mean_square = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
        norm = x * torch.rsqrt(mean_square + self.norm_eps)
        norm = norm.to(dtype)
        out = norm * self.weight
        return out


class MLP(nn.Module):
    def __init__(self, n_dim: int, d_ffn: int):
        super().__init__()
        self.up = Linear(n_dim, d_ffn)
        self.gate = Linear(n_dim, d_ffn)
        self.down = Linear(d_ffn, n_dim)

    def forward(self, x: Tensor) -> Tensor:
        gated = self.up(x) * F.silu(self.gate(x))
        out = self.down(gated)
        return out


class GQA(nn.Module):
    def __init__(
        self,
        n_dim: int,
        n_head: int,
        n_kvhead: int,
    ):
        super().__init__()
        assert n_dim % n_head == 0
        assert n_head % n_kvhead == 0

        self.head_dim = n_dim // n_head
        self.n_dim = n_dim
        self.n_heads = n_head
        self.n_kvheads = n_kvhead
        self.n_rep = n_head // n_kvhead

        self.q_proj = Linear(n_dim, n_head * self.head_dim)
        self.k_proj = Linear(n_dim, n_kvhead * self.head_dim)
        self.v_proj = Linear(n_dim, n_kvhead * self.head_dim)
        self.o_proj = Linear(n_dim, n_dim)

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
        start_pos: int = 0
    ) -> Tensor:
        bsz, seq_len, _ = x.size()
        # x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
        q = self._split_heads(self.q_proj(x), self.n_heads)
        k = self._split_heads(self.k_proj(x), self.n_kvheads)
        v = self._split_heads(self.v_proj(x), self.n_kvheads)
        q, k = apply_rotary_emb(q, freqs_cis), apply_rotary_emb(k, freqs_cis)

        if kv_cache is not None:
            k_cache, v_cache = kv_cache

            # copy to cache
            k_cache[:bsz, start_pos:start_pos + seq_len] = k
            v_cache[:bsz, start_pos:start_pos + seq_len] = v

            # get cache
            k = k_cache[:bsz, :start_pos + seq_len]
            v = v_cache[:bsz, :start_pos + seq_len]

        k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)

        # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
        q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
        sdpa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=(mask is None)).permute(0, 2, 1, 3)
        out = self.o_proj(sdpa_out.contiguous().view(bsz, seq_len, -1))

        return out

    def _split_heads(self, x: Tensor, n_heads) -> Tensor:
        batch_size, seq_len, _ = x.shape
        x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
        return x


class DecoderBlock(nn.Module):
    def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps):
        super().__init__()
        self.attention = GQA(n_dim, n_head, n_kvhead)
        self.norm_attn = RMSNorm(n_dim, norm_eps)
        self.ffn = MLP(n_dim, d_ffn)
        self.norm_ffn = RMSNorm(n_dim, norm_eps)

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        attention_mask: Optional[Tensor] = None,
        kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
        start_pos: int = 0
    ) -> Tensor:
        # attention
        attn_output = self.attention(
            self.norm_attn(x),
            freqs_cis,
            attention_mask,
            kv_cache,
            start_pos
        )
        x = attn_output + x

        # feed forward
        x = self.ffn(self.norm_ffn(x)) + x

        return x


class Transformer(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
        self.layers = nn.ModuleList([
            DecoderBlock(
                config.n_dim,
                config.n_head,
                config.d_ffn,
                config.n_kvhead,
                config.norm_eps
            )
            for _ in range(config.n_layer)
        ])
        self.norm = RMSNorm(config.n_dim, config.norm_eps)
        self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
        init.normal_(self.embedding, mean=0, std=0.02)

    def forward(
        self,
        input_ids: Tensor,
        seq_mask: Optional[Tensor]=None,
        persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
        start_pos: int = 0
    ) -> Tensor:
        assert input_ids.ndim == 2
        seq_len = input_ids.size(-1)
        x = F.embedding(input_ids, self.embedding)

        self.freq_cis = self.freq_cis.to(x.device)
        freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
        has_kvcache = persistent_key_values is not None

        attn_mask = create_attention_mask(
            seq_mask,
            start_pos=start_pos,
            seq_len=seq_len,
            is_causal=has_kvcache,
            device=x.device,
            dtype=x.dtype
        )

        for i, layer in enumerate(self.layers):
            kv_cache = persistent_key_values[i] if persistent_key_values else None
            x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos)

        hidden_states = self.norm(x)
        logits = F.linear(hidden_states, self.embedding)

        return {
            "logits": logits,
            "hidden_states": hidden_states
        }
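Two of the building blocks above reduce to arithmetic that is easy to check without torch. First, the `expand`/`reshape` in `repeat_kv` means expanded head `h` reads from key/value head `h // n_rep`; second, `RMSNorm` rescales each vector so its root-mean-square is 1 (with `weight` at its initial value of ones). A stdlib sketch of both, under those stated assumptions:

```python
import math

def repeat_kv_indices(n_kvheads, n_rep):
    # expand over a new rep axis then flatten: expanded head h = kv_head * n_rep + rep,
    # so each expanded head maps back to kv head h // n_rep.
    return [h // n_rep for h in range(n_kvheads * n_rep)]

assert repeat_kv_indices(2, 3) == [0, 0, 0, 1, 1, 1]

def rms_norm(xs, eps=1e-5):
    # x * rsqrt(mean(x^2) + eps); the learned weight starts as all ones.
    ms = sum(v * v for v in xs) / len(xs)
    inv = 1.0 / math.sqrt(ms + eps)
    return [v * inv for v in xs]

out = rms_norm([3.0, -4.0])
rms = math.sqrt(sum(v * v for v in out) / len(out))
assert abs(rms - 1.0) < 1e-3
```

The head-index mapping is why GQA saves KV-cache memory: only `n_kvheads` key/value heads are stored, and the repetition is a view-level broadcast rather than new parameters.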
@@ -0,0 +1,112 @@
from torch import Tensor
from typing import List, Tuple, Generator, Union

from khaosz.core.generator import (
    TextGenerator,
    ChatGenerator,
    StreamGenerator,
    BatchGenerator,
    RetrievalGenerator,
    EmbeddingEncoder
)
from khaosz.core.parameter import ParameterLoader


class Khaosz:
    def __init__(self, model_dir: str):
        self.parameter = ParameterLoader.load(model_dir)

    def to(self, *args, **kwargs):
        self.parameter.to(*args, **kwargs)
        return self

    def generate(
        self,
        query: str,
        history: List[Tuple[str, str]]=None,
        temperature: float=0.8,
        top_k: int=50,
        top_p: float=0.95,
    ) -> str:
        generator = ChatGenerator(self.parameter)
        return generator.generate(
            query,
            history=history,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    def batch_generate(
        self,
        queries: List[str],
        histories: List[Tuple[str, str]]=None,
        temperature: float=0.8,
        top_k: int=50,
        top_p: float=0.95,
    ) -> List[str]:
        generator = BatchGenerator(self.parameter)
        return generator.generate(
            queries,
            histories=histories,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    def stream_generate(
        self,
        query: str,
        history: List[Tuple[str, str]]=None,
        temperature: float=0.8,
        top_k: int=50,
        top_p: float=0.95,
    ) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
        stream_generator = StreamGenerator(self.parameter)
        return stream_generator.generate(
            query,
            history=history,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    def retrieve_generate(
        self,
        retrieved,
        query: str,
        history: List[Tuple[str, str]] = None,
        temperature: float=0.8,
        top_k: int=50,
        top_p: float=0.95,
    ) -> str:
        generator = RetrievalGenerator(self.parameter)
        return generator.generate(
            retrieved,
            query,
            history=history,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    def text_generate(
        self,
        query: str,
        temperature: float=0.8,
        top_k: int=50,
        top_p: float=0.95,
    ) -> str:
        generator = TextGenerator(self.parameter)

        return generator.generate(
            query,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
        encoder = EmbeddingEncoder(self.parameter)
        return encoder.encode(sentence)
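Every generation entry point above forwards `temperature`, `top_k`, and `top_p` to a generator. The generator internals are not shown in this commit, so as a hedged illustration only, here is the standard top-k/top-p (nucleus) filtering those parameters conventionally imply; `top_k_top_p_filter` is our hypothetical helper, not the project's implementation:

```python
def top_k_top_p_filter(probs, top_k, top_p):
    # Keep the top_k most likely tokens, then truncate to the smallest
    # prefix whose cumulative probability reaches top_p; renormalize.
    order = sorted(range(len(probs)), key=lambda i: -probs[i])[:top_k]
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}

dist = top_k_top_p_filter([0.5, 0.3, 0.1, 0.1], top_k=3, top_p=0.7)
assert set(dist) == {0, 1}          # token 2 falls outside the nucleus
assert abs(sum(dist.values()) - 1.0) < 1e-9
```

Note the interaction: a small `top_k` can cut the distribution before `top_p` is reached, so the two thresholds compose as an intersection of the candidate sets.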
@@ -0,0 +1,11 @@
from khaosz.trainer.dataset import DatasetLoader
from khaosz.trainer.trainer import Trainer
from khaosz.trainer.strategy import TrainConfig, CosineScheduleConfig, SgdrScheduleConfig

__all__ = [
    "DatasetLoader",
    "Trainer",
    "TrainConfig",
    "CosineScheduleConfig",
    "SgdrScheduleConfig",
]
@@ -0,0 +1,210 @@
import torch
import bisect
import pickle as pkl
from abc import ABC, abstractmethod
from torch import Tensor
from torch.utils.data import Dataset
from typing import Callable, List, Dict, Literal, Union

MutiSeg = Dict[str, List[Tensor]]
Seg = Dict[str, Tensor]

def load_pkl_files(paths: List[str]):
    segments: MutiSeg = {}
    total_samples = 0

    for path in paths:
        with open(path, "rb") as f:
            pkl_file: Seg = pkl.load(f)
        for key, value in pkl_file.items():
            if key not in segments:
                segments[key] = []
            segments[key].append(value)
        first_key = list(pkl_file.keys())[0]
        total_samples += pkl_file[first_key].numel()

    return segments, total_samples


class BaseSegmentFetcher:
    def __init__(self, segments: List[Tensor]):
        self.segments = segments
        self.cum_lengths = []
        total = 0
        for seg in segments:
            total += len(seg)
            self.cum_lengths.append(total)
        self.total_length = total if segments else 0

    def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
            raise ValueError("begin_idx or end_idx out of bounds")
        if begin_idx >= end_idx:
            return torch.tensor([], dtype=torch.long)

        seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx - 1)
        seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx - 1)

        result_segments = []

        for i in range(seg_start_idx, seg_end_idx + 1):
            prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
            start = max(begin_idx - prev_cum, 0)
            end = min(end_idx - prev_cum, len(self.segments[i]))
            result_segments.append(self.segments[i][start:end])

        return torch.cat(result_segments, dim=0)


class MutiSegmentFetcher:
    def __init__(self, muti_segments: MutiSeg):
        self.muti_keys = list(muti_segments.keys())
        self.muti_fetchers = {
            key: BaseSegmentFetcher(segments)
            for key, segments in muti_segments.items()
        }

    def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Union[Tensor, Seg]:
        fetch_dict = {}
        keys = [keys] if isinstance(keys, str) else keys

        for key in keys:
            fetcher = self.muti_fetchers[key]
            fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
            fetch_dict[key] = fetch_tensor

        return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]

    def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Seg]:
        return self.key_fetch(begin_idx, end_idx, self.muti_keys)


class BaseDataset(Dataset, ABC):
    def __init__(self, chunk_size: int, device: str):
        super().__init__()
        self.segments: MutiSeg = {}
        self.chunk_size = chunk_size
        self.total_samples = 0
        self.device = device

    def save(self, save_path: str):
        # keys must be collected before indexing into self.segments
        keys = list(self.segments.keys())
        first_item = self.segments[keys[0]]
        segment_size = len(first_item)

        for i in range(segment_size):
            formatted_segment = {key: self.segments[key][i] for key in keys}
            with open(f"{save_path}_{i}.pkl", "wb") as f:
                pkl.dump(formatted_segment, f)

    def load(self, load_path: Union[str, List[str]]):
        paths = [load_path] if isinstance(load_path, str) else load_path
        self.segments, self.total_samples = load_pkl_files(paths)
        self.fetcher = MutiSegmentFetcher(self.segments)

    @abstractmethod
    def __getitem__(self, index: int):
        raise NotImplementedError

    def __len__(self) -> int:
        assert self.total_samples // self.chunk_size > 0
        return self.total_samples // self.chunk_size


class SeqDataset(BaseDataset):
    def __init__(self, chunk_size, device='cuda'):
        super().__init__(chunk_size, device)
        self.fetcher = MutiSegmentFetcher(self.segments)

    def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")

    def __getitem__(self, index):
        begin_idx = index * self.chunk_size
        end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)

        x = self._fetch_data(begin_idx, end_idx).to(device=self.device, dtype=torch.long)
        y = self._fetch_data(begin_idx + 1, end_idx + 1).to(device=self.device, dtype=torch.long)

        return x, y


class SftDataset(BaseDataset):
    def __init__(self, chunk_size, device='cuda'):
        super().__init__(chunk_size, device)
        self.fetcher = MutiSegmentFetcher(self.segments)

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index):
        begin_idx = index * self.chunk_size
        end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)

        x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long)
        y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long)
        loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "mask").to(device=self.device, dtype=torch.bool)

        return x, y, loss_mask


class DpoDataset(BaseDataset):
    def __init__(self, chunk_size: int, device="cuda"):
        super().__init__(chunk_size, device)
        self.fetcher = MutiSegmentFetcher(self.segments)

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int):
        start_idx = index * self.chunk_size
        end_idx = min(start_idx + self.chunk_size, self.total_samples - 1)

        chosen = self._fetch_data(start_idx, end_idx, "chosen").to(device=self.device, dtype=torch.long)
        rejected = self._fetch_data(start_idx, end_idx, "rejected").to(device=self.device, dtype=torch.long)
        chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(device=self.device, dtype=torch.bool)
        rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(device=self.device, dtype=torch.bool)

        return chosen, rejected, chosen_mask, rejected_mask


class PpoDataset(BaseDataset):
    def __init__(self, chunk_size: int, device="cuda"):
        super().__init__(chunk_size, device)
        self.fetcher = MutiSegmentFetcher(self.segments)

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int):
        begin_idx = index * self.chunk_size
        end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)

        # no trailing commas here: they would wrap each tensor in a 1-tuple
        input_ids = self._fetch_data(begin_idx, end_idx, "input_ids").to(self.device)
        actions = self._fetch_data(begin_idx, end_idx, "actions").to(self.device)
        logprobs = self._fetch_data(begin_idx, end_idx, "logprobs").to(self.device)
        rewards = self._fetch_data(begin_idx, end_idx, "rewards").to(self.device)

        return input_ids, actions, logprobs, rewards


class DatasetLoader:
    @staticmethod
    def load(
        train_type: Literal["seq", "sft", "dpo"],
        load_path: Union[str, List[str]],
        max_len: int,
        device: str
    ) -> BaseDataset:

        dataset_router: Dict[str, Callable[[int, torch.device], BaseDataset]] = {
            "seq": lambda m_len, device: SeqDataset(m_len, device=device),
            "sft": lambda m_len, device: SftDataset(m_len, device=device),
            "dpo": lambda m_len, device: DpoDataset(m_len, device=device),
        }
        dataset = dataset_router[train_type](max_len, device)
        dataset.load(load_path)

        return dataset
@ -0,0 +1,55 @@
import torch
from abc import abstractmethod
from torch import Tensor


class MaskBuilder:
    def __init__(
        self,
        bos_token_id: int,
        eos_token_id: int,
        user_token_id: int,
        system_token_id: int,
    ):
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.user_token_id = user_token_id
        self.system_token_id = system_token_id

    @abstractmethod
    def build(self, input_ids: Tensor) -> Tensor:
        raise NotImplementedError


class LossMaskBuilder(MaskBuilder):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_ids: Tensor) -> Tensor:
        token_markers = torch.zeros_like(input_ids, dtype=torch.int8)

        is_user_token = input_ids.eq(self.user_token_id)
        is_system_token = input_ids.eq(self.system_token_id)

        token_markers[is_user_token] = 1
        token_markers[is_system_token] = -1

        cumulative_markers = torch.cumsum(token_markers, dim=-1)
        min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
        loss_mask = cumulative_markers - min_cumulative

        return loss_mask
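A standalone sketch of the cumulative-marker trick that `LossMaskBuilder.build` relies on (the special-token ids `1` and `2` here are made up for illustration; the real ids come from the tokenizer):

```python
import torch

USER_ID, SYSTEM_ID = 1, 2  # hypothetical special-token ids

def build_loss_mask(input_ids: torch.Tensor) -> torch.Tensor:
    # +1 at each user marker, -1 at each system marker, 0 elsewhere
    markers = torch.zeros_like(input_ids, dtype=torch.int8)
    markers[input_ids.eq(USER_ID)] = 1
    markers[input_ids.eq(SYSTEM_ID)] = -1
    # the running sum is constant inside each turn, so subtracting the
    # minimum leaves 1 on one speaker's spans and 0 on the other's
    cum = torch.cumsum(markers, dim=-1)
    return cum - cum.min(dim=-1, keepdim=True).values

ids = torch.tensor([[1, 5, 6, 2, 7, 8, 1, 9]])  # user..., system..., user...
print(build_loss_mask(ids).tolist())  # → [[1, 1, 1, 0, 0, 0, 1, 1]]
```

The marker tokens themselves end up inside the span they open, which matches the class above; only the ids differ per tokenizer.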


class AttentionMaskBuilder(MaskBuilder):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_ids: Tensor):
        bsz = input_ids.size(0)
@ -0,0 +1,388 @@
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import Dataset
from typing import Any, Literal, Tuple, Callable, Dict
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field


def get_logprobs(model: nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id):
    input_mask = input_ids.ne(pad_token_id)
    logits = model(input_ids, input_mask)["logits"]
    log_probs = torch.log_softmax(logits, dim=-1)

    shifted_log_probs = log_probs[:, :-1, :]
    shifted_input_ids = input_ids[:, 1:]
    shifted_response_mask = mask[:, 1:]

    token_logprobs = torch.gather(
        shifted_log_probs,
        dim=-1,
        index=shifted_input_ids.unsqueeze(-1)
    ).squeeze(-1)

    prompt_mask = input_mask[:, 1:]
    valid_mask = (prompt_mask & shifted_response_mask).float()

    return (token_logprobs * valid_mask).sum(dim=-1)
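The shift-and-gather pattern in `get_logprobs` can be checked on a tiny toy distribution (a standalone sketch; no model involved, the logits here are made up zeros so the distribution is uniform):

```python
import math
import torch

# toy batch: 1 sequence, 4 tokens, vocab of 3
log_probs = torch.log_softmax(torch.zeros(1, 4, 3), dim=-1)  # uniform: log(1/3)
input_ids = torch.tensor([[0, 1, 2, 1]])

# predict token t+1 from position t: drop the last position, drop the first target
shifted_log_probs = log_probs[:, :-1, :]          # [1, 3, 3]
targets = input_ids[:, 1:]                        # [1, 3]
token_logprobs = torch.gather(
    shifted_log_probs, dim=-1, index=targets.unsqueeze(-1)
).squeeze(-1)                                     # [1, 3]

# with a uniform distribution every target scores log(1/3)
print(torch.allclose(token_logprobs, torch.full((1, 3), math.log(1 / 3))))  # → True
```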


class MaskBuilder:
    def __init__(
        self,
        bos_token_id: int,
        eos_token_id: int,
        user_token_id: int,
        system_token_id: int,
    ):
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.user_token_id = user_token_id
        self.system_token_id = system_token_id

    @abstractmethod
    def build(self, input_ids: Tensor) -> Tensor:
        raise NotImplementedError


class LossMaskBuilder(MaskBuilder):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_ids: Tensor) -> Tensor:
        token_markers = torch.zeros_like(input_ids, dtype=torch.int8)

        is_user_token = input_ids.eq(self.user_token_id)
        is_system_token = input_ids.eq(self.system_token_id)

        token_markers[is_user_token] = 1
        token_markers[is_system_token] = -1

        cumulative_markers = torch.cumsum(token_markers, dim=-1)
        min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
        loss_mask = cumulative_markers - min_cumulative

        return loss_mask.to(dtype=torch.bool)


class AttentionMaskBuilder(MaskBuilder):
    def __init__(self, multi_turn=False, **kwargs):
        super().__init__(**kwargs)
        self.multi_turn = multi_turn

    def build(self, input_ids: Tensor):
        bsz = input_ids.size(0)

    def _build_batch(self, input_ids: Tensor):
        is_user_token = input_ids.eq(self.user_token_id)

        token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
        token_markers[is_user_token] = 1
        cumulative_markers = torch.cumsum(token_markers, dim=-1)


class BaseStrategy(ABC):
    def __init__(self, model: nn.Module):
        self.model = model

    @abstractmethod
    def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
        raise NotImplementedError

    def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor:
        return self.compute_loss(batch)


class SeqStrategy(BaseStrategy):
    def __init__(self, model):
        super().__init__(model)

    def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
        x, y = batch
        B, L = x.size()
        logits: Tensor = self.model(x)["logits"]

        loss = F.cross_entropy(
            logits.view(B * L, -1), y.flatten()
        )
        return loss
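`SeqStrategy` flattens `[B, L, V]` logits into `[B*L, V]` before `F.cross_entropy`; a quick standalone check that this equals the mean of per-position losses (toy shapes, random logits):

```python
import torch
import torch.nn.functional as F

B, L, V = 2, 3, 5
torch.manual_seed(0)
logits = torch.randn(B, L, V)
targets = torch.randint(0, V, (B, L))

# flattened form, as in SeqStrategy.compute_loss
flat_loss = F.cross_entropy(logits.view(B * L, -1), targets.flatten())

# explicit mean over every (batch, position) pair
per_token = torch.stack([
    F.cross_entropy(logits[b, l].unsqueeze(0), targets[b, l].unsqueeze(0))
    for b in range(B) for l in range(L)
])
print(torch.allclose(flat_loss, per_token.mean()))  # → True
```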


class SftStrategy(BaseStrategy):
    def __init__(self, model):
        super().__init__(model)

    def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
        x, y, loss_mask = batch
        B, L = x.size()
        ignore_idx = -1

        logits: Tensor = self.model(x)["logits"]
        masked_y = y.masked_fill(loss_mask == 0, ignore_idx)

        loss = F.cross_entropy(
            logits.view(B * L, -1),
            masked_y.flatten(),
            ignore_index=ignore_idx
        )

        return loss


class DpoStrategy(BaseStrategy):
    def __init__(self, model, pad_token_id, beta):
        super().__init__(model)
        ref_model = copy.deepcopy(self.model)
        ref_model.requires_grad_(False)
        ref_model.eval()

        self.ref_model = ref_model
        self.pad_token_id = pad_token_id
        self.beta = beta

    def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
        good_ids, bad_ids, good_mask, bad_mask = batch

        log_pi_good = get_logprobs(self.model, good_ids, good_mask, self.pad_token_id)
        log_pi_bad = get_logprobs(self.model, bad_ids, bad_mask, self.pad_token_id)

        with torch.no_grad():
            log_ref_good = get_logprobs(self.ref_model, good_ids, good_mask, self.pad_token_id)
            log_ref_bad = get_logprobs(self.ref_model, bad_ids, bad_mask, self.pad_token_id)

        pi_log_ratio = log_pi_good - log_pi_bad
        ref_log_ratio = log_ref_good - log_ref_bad

        ratio_diff = pi_log_ratio - ref_log_ratio

        dpo_loss = -F.logsigmoid(self.beta * ratio_diff).mean()
        return dpo_loss
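The objective above is `-logsigmoid(beta * ((log pi_w - log pi_l) - (log ref_w - log ref_l)))`; a standalone numeric sketch with made-up summed sequence log-probabilities (no model, no `get_logprobs`):

```python
import torch
import torch.nn.functional as F

beta = 0.1

# hypothetical summed sequence log-probs for policy and frozen reference
log_pi_good, log_pi_bad = torch.tensor([-5.0]), torch.tensor([-9.0])
log_ref_good, log_ref_bad = torch.tensor([-6.0]), torch.tensor([-8.0])

ratio_diff = (log_pi_good - log_pi_bad) - (log_ref_good - log_ref_bad)  # 4 - 2 = 2
loss = -F.logsigmoid(beta * ratio_diff).mean()
print(round(loss.item(), 4))  # → 0.5981
```

The policy already prefers the chosen answer more strongly than the reference does (positive `ratio_diff`), so the loss sits below `log 2 ≈ 0.6931`, the value at `ratio_diff = 0`.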


class PpoStrategy(BaseStrategy):
    def __init__(self, model, pad_token_id, epsilon):
        super().__init__(model)
        ref_model = copy.deepcopy(self.model)
        ref_model.requires_grad_(False)
        ref_model.eval()

        self.ref_model = ref_model
        self.pad_token_id = pad_token_id
        self.epsilon = epsilon

    def ppo_clip_loss_masked(
        self,
        log_probs: Tensor,
        old_log_probs: Tensor,
        advantages: Tensor,
        values: Tensor,
        returns: Tensor,
        mask: Tensor,
        clip_eps: float = 0.2,
    ):
        ratio = torch.exp(log_probs - old_log_probs)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
        policy_loss = -torch.min(surr1, surr2).masked_select(mask).mean()

        value_loss = F.mse_loss(values.masked_select(mask),
                                returns.masked_select(mask))

        entropy = -(log_probs.exp() * log_probs).masked_select(mask).mean()
        entropy_loss = -entropy
        return policy_loss, value_loss, entropy_loss
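A standalone check of the clipped surrogate used in `ppo_clip_loss_masked`: when the probability ratio drifts past `1 + clip_eps` under a positive advantage, the clipped term caps the objective (all per-token values here are hypothetical):

```python
import torch

clip_eps = 0.2
advantages = torch.tensor([1.0, 1.0])
ratio = torch.tensor([0.9, 1.5])  # second token moved far from the old policy

surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
policy_obj = torch.min(surr1, surr2)  # per-token clipped objective
print([round(v, 4) for v in policy_obj.tolist()])
```

The first token stays inside the trust region (0.9), while the second is capped at `1 + clip_eps = 1.2` instead of 1.5, which is what removes the incentive for oversized policy updates.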


class StrategyFactory:

    @staticmethod
    def load(model, train_type, pad_token_id, dpo_beta):
        train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
            "seq": lambda: SeqStrategy(model),
            "sft": lambda: SftStrategy(model),
            "dpo": lambda: DpoStrategy(model, pad_token_id, dpo_beta)
        }
        strategy = train_strategy[train_type]()
        return strategy


@dataclass
class TrainConfig:
    train_type: str = field(
        default="seq",
        metadata={"help": "Type of training: one of 'seq', 'sft', 'dpo'."}
    )
    dataset: Dataset = field(
        default=None,
        metadata={"help": "Dataset for training."}
    )
    optimizer: Optimizer = field(
        default=None,
        metadata={"help": "Optimizer for training."}
    )
    ckpt_dir: str = field(
        default="./checkpoint",
        metadata={"help": "Checkpoint directory."}
    )
    n_epoch: int = field(
        default=1,
        metadata={"help": "Number of epochs for training."}
    )
    batch_size: int = field(
        default=4,
        metadata={"help": "Batch size for training."}
    )
    n_iter_ckpt: int = field(
        default=5000,
        metadata={"help": "Number of iterations between checkpoints."}
    )
    n_iter_step: int = field(
        default=1,
        metadata={"help": "Number of iterations between optimizer steps."}
    )
    max_grad_norm: float = field(
        default=1.0,
        metadata={"help": "Maximum gradient norm."}
    )
    random_seed: int = field(
        default=3407,
        metadata={"help": "Random seed."}
    )
    dpo_beta: float = field(
        default=0.1,
        metadata={"help": "DPO beta."}
    )

    def get_kwargs(self) -> Dict[str, Any]:
        config_dict = asdict(self)
        return {k: v for k, v in config_dict.items() if v is not None}


@dataclass
class ScheduleConfig:
    schedule_type: str = field(
        default="cosine",
        metadata={"help": "Type of learning rate schedule: 'cosine' or 'sgdr'."}
    )
    warning_step: int = field(
        default=1000,
        metadata={"help": "Warmup step count."}
    )

    @abstractmethod
    def get_kwargs(self) -> Dict[str, Any]:
        raise NotImplementedError


@dataclass
class CosineScheduleConfig(ScheduleConfig):
    total_iters: int = field(
        default=None,
        metadata={"help": "Total iterations for cosine schedule."}
    )
    min_rate: float = field(
        default=0.05,
        metadata={"help": "Minimum rate for cosine schedule."}
    )
    schedule_type: Literal["cosine"] = "cosine"

    def get_kwargs(self) -> Dict[str, Any]:
        return {
            "schedule_type": self.schedule_type,
            "warning_step": self.warning_step,
            "lr_decay_iters": self.total_iters - self.warning_step,
            "min_rate": self.min_rate
        }


@dataclass
class SgdrScheduleConfig(ScheduleConfig):
    cycle_length: int = field(
        default=1000,
        metadata={"help": "Cycle length for sgdr schedule."}
    )
    min_rate: float = field(
        default=0.05,
        metadata={"help": "Minimum rate for sgdr schedule."}
    )
    T_mult: int = field(
        default=2,
        metadata={"help": "T_mult for sgdr schedule."}
    )
    schedule_type: Literal["sgdr"] = "sgdr"

    def get_kwargs(self) -> Dict[str, Any]:
        return {
            "schedule_type": self.schedule_type,
            "warning_step": self.warning_step,
            "cycle_length": self.cycle_length,
            "min_rate": self.min_rate,
            "T_mult": self.T_mult
        }


class SchedulerFactory:

    @staticmethod
    def get_sgdr_schedule(
        warning_step: int,
        cycle_length: int,
        min_rate: float = 0.1,
        T_mult: int = 2
    ) -> Callable[[int], float]:

        def sgdr_schedule(now_iter: int) -> float:
            if now_iter < warning_step:
                return max(min_rate, now_iter / warning_step)

            adjusted_iter = now_iter - warning_step

            # walk cumulative cycle boundaries: each restart is T_mult times longer
            cycle_start = 0
            current_cycle_length = cycle_length
            while adjusted_iter >= cycle_start + current_cycle_length:
                cycle_start += current_cycle_length
                current_cycle_length *= T_mult

            cycle_pos = adjusted_iter - cycle_start
            return max(min_rate, 0.5 * (1 + math.cos(math.pi * cycle_pos / current_cycle_length)))

        return sgdr_schedule
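The warmup-then-cosine shape these factories produce can be sanity-checked with a standalone re-derivation (a sketch, not an import of `SchedulerFactory`; `decay_iters` plays the role of `lr_decay_iters`):

```python
import math

def cosine_warmup(now_iter: int, warmup: int, decay_iters: int, min_rate: float = 0.1) -> float:
    # linear warmup to 1.0, then cosine decay clamped at min_rate
    if now_iter <= warmup:
        return max(min_rate, now_iter / warmup)
    progress = min(1.0, (now_iter - warmup) / decay_iters)
    return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * progress)))

print(cosine_warmup(500, 1000, 9000))    # mid-warmup → 0.5
print(cosine_warmup(1000, 1000, 9000))   # warmup end → 1.0
print(cosine_warmup(10000, 1000, 9000))  # fully decayed → clamped at 0.1
```

This multiplier is exactly what `LambdaLR` expects: a factor applied to the optimizer's base learning rate at each step.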

    @staticmethod
    def get_cosine_warmup_schedule(
        warning_step: int,
        lr_decay_iters: int,
        min_rate: float = 0.1
    ) -> Callable[[int], float]:

        def cosine_warmup_schedule(now_iter: int) -> float:
            if now_iter <= warning_step:
                return max(min_rate, now_iter / warning_step)
            else:
                rate = min(1.0, (now_iter - warning_step) / lr_decay_iters)
                return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * rate)))

        return cosine_warmup_schedule

    @staticmethod
    def load_schedule_fn(**kwargs):
        strategy = kwargs.pop("schedule_type")
        if strategy == "cosine":
            return SchedulerFactory.get_cosine_warmup_schedule(**kwargs)
        elif strategy == "sgdr":
            return SchedulerFactory.get_sgdr_schedule(**kwargs)
        else:
            raise ValueError(f"Invalid schedule type: {strategy}")
@ -0,0 +1,167 @@
import os
import torch
import logging

from typing import Tuple
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, RandomSampler
from tqdm import tqdm

from khaosz.core import ModelParameter, Checkpoint
from khaosz.trainer.strategy import SchedulerFactory, StrategyFactory, TrainConfig, ScheduleConfig


class Trainer:
    def __init__(
        self,
        parameter: ModelParameter,
        log_path: str = "./train_log.log"
    ):
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        handler = logging.FileHandler(log_path)
        handler.setLevel(logging.INFO)
        handler.setFormatter(logging.Formatter('%(asctime)s: %(message)s'))
        logger.addHandler(handler)
        logger.info("initializing trainer ...")

        self.logger = logger
        self.model = parameter.model
        self.tokenizer = parameter.tokenizer
        self.config = parameter.config

    def save_checkpoint(
        self,
        loss_list: list,
        ckpt_dir: str,
        current_iter: int,
        last_ckpt_iter: int
    ):
        save_path = os.path.join(ckpt_dir, f"iter_{current_iter}")
        Checkpoint(
            self.model,
            self.tokenizer,
            self.config,
            loss_list,
            current_iter
        ).save(save_path)

        diff_iter = current_iter - last_ckpt_iter
        avg_loss = sum(loss_list[last_ckpt_iter:current_iter]) / diff_iter
        self.logger.info(f"iter: {current_iter} loss: {avg_loss}")

        return current_iter

    def load_checkpoint(self, train_checkpoint: Checkpoint) -> Tuple[list, int]:
        self.model = train_checkpoint.model
        self.tokenizer = train_checkpoint.tokenizer
        self.config = train_checkpoint.config
        loss_list = train_checkpoint.loss_list
        last_ckpt_iter = train_checkpoint.current_iter

        return loss_list, last_ckpt_iter

    def train(
        self,
        train_config: TrainConfig,
        schedule_config: ScheduleConfig,
        train_checkpoint: Checkpoint = None
    ):
        assert schedule_config.schedule_type in ["cosine", "sgdr"]
        assert train_config.train_type in ["seq", "sft", "dpo"]

        if train_checkpoint:
            loss_list, last_ckpt_iter = self.load_checkpoint(train_checkpoint)
            current_iter = train_checkpoint.current_iter + 1
            self.logger.info(f"Resuming training from checkpoint: iter {current_iter}")
        else:
            current_iter = 0
            last_ckpt_iter = 0
            loss_list = []

        lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
            **schedule_config.get_kwargs()
        )

        strategy = StrategyFactory.load(
            self.model,
            train_config.train_type,
            self.tokenizer.pad_id,
            train_config.dpo_beta
        )

        scheduler = LambdaLR(
            train_config.optimizer,
            lambda_scheduler_fn,
            last_epoch=current_iter - 1 if train_checkpoint else -1
        )

        seed = train_config.random_seed
        generator = torch.Generator().manual_seed(seed)
        sampler = RandomSampler(train_config.dataset, generator=generator)
        remaining_epochs = train_config.n_epoch - current_iter // (len(train_config.dataset) // train_config.batch_size)

        self.logger.info(f"Starting {train_config.train_type.upper()} training for {train_config.n_epoch} epochs")
        self.logger.info(f"Checkpoint interval: {train_config.n_iter_ckpt} iterations")

        for epoch in range(remaining_epochs):
            self.model.train()
            dataloader = DataLoader(
                train_config.dataset,
                batch_size=train_config.batch_size,
                sampler=sampler
            )
            progress_bar = tqdm(
                dataloader,
                desc=f"Epoch {epoch+1}/{train_config.n_epoch}",
                dynamic_ncols=True
            )
            for batch in progress_bar:
                # forward
                loss = strategy(batch)
                loss_list.append(loss.item())
                # backward
                loss.backward()
                # step
                if current_iter % train_config.n_iter_step == 0:
                    clip_grad_norm_(
                        self.model.parameters(),
                        train_config.max_grad_norm
                    )
                    train_config.optimizer.step()
                    train_config.optimizer.zero_grad()

                current_iter += 1
                scheduler.step()
                progress_bar.set_postfix({
                    "loss": f"{loss.item():.4f}",
                    "lr": f"{train_config.optimizer.param_groups[0]['lr']:.2e}"
                })
                # save checkpoint
                if current_iter - last_ckpt_iter >= train_config.n_iter_ckpt:
                    last_ckpt_iter = self.save_checkpoint(
                        loss_list,
                        train_config.ckpt_dir,
                        current_iter,
                        last_ckpt_iter
                    )

        if current_iter != last_ckpt_iter:
            last_ckpt_iter = self.save_checkpoint(
                loss_list,
                train_config.ckpt_dir,
                current_iter,
                last_ckpt_iter
            )

        self.logger.info("Training completed")

        return Checkpoint(
            self.model,
            self.tokenizer,
            self.config,
            loss_list,
            current_iter,
            train_config.optimizer
        )
@ -0,0 +1,88 @@
import torch
import sqlite3
import numpy as np
from torch import Tensor
from typing import Dict, List, Tuple


class Retriever:
    def __init__(self, db_path=None):
        self.data: Dict[str, Tensor] = {}
        self.embedding_cache: Tensor = None
        self.is_calculated: bool = False

        if db_path is not None:
            self.load(db_path)

    def retrieve(self, query: Tensor, top_k: int) -> List[Tuple[str, float]]:
        if not self.data:
            return []

        query = query.flatten().unsqueeze(1)  # [dim, 1]
        norm_embeddings = self._embeddings.to(
            device=query.device,
            dtype=query.dtype
        )  # [n_vectors, dim]
        sim_scores = torch.matmul(norm_embeddings, query).squeeze()  # [n_vectors]

        top_k = min(top_k, len(self.data))
        indices = sim_scores.topk(top_k).indices
        keys = list(self.data.keys())

        return [(keys[i], sim_scores[i].item()) for i in indices]

    def add_vector(self, key: str, vector_data: Tensor):
        self.is_calculated = False
        self.data[key] = vector_data.flatten().float().cpu()

    def delete_vector(self, key: str):
        self.is_calculated = False
        self.data.pop(key, None)

    def save(self, db_path):
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        self._init_db(cursor)
        cursor.execute('DELETE FROM vectors')

        for item, vec in self.data.items():
            vec_bytes = vec.numpy().tobytes()
            cursor.execute('INSERT OR REPLACE INTO vectors (key, vector) VALUES (?, ?)',
                           (item, vec_bytes))

        conn.commit()
        conn.close()

    def load(self, db_path):
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        self._init_db(cursor)
        cursor.execute('SELECT key, vector FROM vectors')
        rows = cursor.fetchall()
        self.data = {}

        for row in rows:
            key, vec_bytes = row
            vec_numpy = np.frombuffer(vec_bytes, dtype=np.float32).copy()
            vec = torch.from_numpy(vec_numpy)
            self.data[key] = vec

        conn.close()

    def _init_db(self, cursor: sqlite3.Cursor):
        # Create table if not exists (in case loading from a new database)
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS vectors (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                key TEXT UNIQUE NOT NULL,
                vector BLOB NOT NULL
            )''')

    @property
    def _embeddings(self) -> Tensor:
        if not self.is_calculated:
            embeddings = torch.stack(list(self.data.values()))
            norm_embeddings = embeddings / torch.norm(embeddings, dim=-1, keepdim=True)
            self.embedding_cache = norm_embeddings
            self.is_calculated = True

        return self.embedding_cache
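The retrieval math above is a dot product against row-normalized embeddings; a standalone sketch with made-up 2-D vectors (no SQLite involved, keys and values are arbitrary):

```python
import torch

store = {
    "cats": torch.tensor([1.0, 0.0]),
    "dogs": torch.tensor([0.8, 0.6]),
    "cars": torch.tensor([0.0, 1.0]),
}

embs = torch.stack(list(store.values()))
embs = embs / torch.norm(embs, dim=-1, keepdim=True)  # unit-normalize rows

query = torch.tensor([1.0, 0.1])
scores = embs @ query  # dot products; query norm only scales all scores equally
keys = list(store.keys())
top = [keys[i] for i in scores.topk(2).indices]
print(top)  # → ['cats', 'dogs']
```

Leaving the query unnormalized, as `retrieve` does, is fine for ranking: it multiplies every score by the same constant.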
|
||||||
|
|
@ -0,0 +1,127 @@
import re
import torch
import torch.nn.functional as F

from abc import ABC, abstractmethod
from torch import Tensor
from typing import List, Callable, Optional


class BaseTextSplitter(ABC):
    def __init__(
        self,
        max_len: int = 512,
        chunk_overlap: int = 0,
    ):
        if max_len <= 0:
            raise ValueError("max_len must be > 0")
        if chunk_overlap < 0:
            raise ValueError("chunk_overlap must be >= 0")

        self.max_len = max_len
        self.chunk_overlap = chunk_overlap

    @abstractmethod
    def split(self, text: str, **kwargs) -> List[str]:
        raise NotImplementedError

    def preprocess(self, text: str) -> str:
        return text.strip()

    def postprocess(self, chunks: List[str]) -> List[str]:
        return [chunk.strip() for chunk in chunks if chunk.strip()]


class PriorityTextSplitter(BaseTextSplitter):
    def __init__(
        self,
        separators: List[str],
        max_len: int = 512,
        chunk_overlap: int = 0,
    ):
        super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
        if not separators:
            raise ValueError("separators must be a non-empty list")
        self.separators = separators

    def split(self, text: str) -> List[str]:
        text = self.preprocess(text)
        for sep in self.separators:
            parts = text.split(sep)

            valid_parts = [p.strip() for p in parts if p.strip()]
            if len(valid_parts) > 1:
                return self.postprocess(valid_parts)
        return [text]


class SemanticTextSplitter(BaseTextSplitter):

    DEFAULT_PATTERN = r'(?<=[。!?!?])(?=(?:[^"\'‘’“”]*["\'‘’“”][^"\'‘’“”]*["\'‘’“”])*[^"\'‘’“”]*$)'

    def __init__(
        self,
        embedding_func: Callable[[List[str]], List[Tensor]],
        pattern: Optional[str] = None,
        max_len: int = 512,
        chunk_overlap: int = 0,
    ):
        super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
        if not callable(embedding_func):
            raise TypeError("embedding_func must be callable")
        self.embedding_func = embedding_func
        self.pattern = pattern or SemanticTextSplitter.DEFAULT_PATTERN

    def split(
        self,
        text: str,
        threshold: float = 0.5,
        window_size: int = 1,
    ) -> List[str]:
        text = self.preprocess(text)
        sentences = [s.strip() for s in re.split(self.pattern, text) if s.strip()]

        if len(sentences) <= 1:
            return self.postprocess(sentences)

        try:
            sentence_embs = self.embedding_func(sentences)
        except Exception as e:
            raise RuntimeError(f"Embedding generation failed: {e}")

        if len(sentence_embs) != len(sentences):
            raise ValueError("Embedding function must return one vector per sentence")

        chunks = []
        emb_tensor = torch.stack(sentence_embs)  # shape: [N, D]
        current_chunk: List[str] = [sentences[0]]

        for i in range(1, len(sentences)):
            start_prev = max(0, i - window_size)
            end_prev = i
            start_next = i
            end_next = min(len(sentences), i + window_size)

            prev_window_emb = emb_tensor[start_prev:end_prev].mean(dim=0)
            next_window_emb = emb_tensor[start_next:end_next].mean(dim=0)

            similarity = F.cosine_similarity(
                prev_window_emb.unsqueeze(0),
                next_window_emb.unsqueeze(0),
                dim=1
            ).item()

            dynamic_threshold = max(threshold * (1 - 0.03 * (end_next - start_prev)), 0.2)

            if similarity < dynamic_threshold:
                chunks.append(" ".join(current_chunk))
                overlap_start = max(0, len(current_chunk) - self.chunk_overlap)
                current_chunk = current_chunk[overlap_start:]
                current_chunk.append(sentences[i])
            else:
                current_chunk.append(sentences[i])

        if current_chunk:
            chunks.append(" ".join(current_chunk))

        return self.postprocess(chunks)
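For `window_size=1` and a fixed threshold, the boundary decision inside `SemanticTextSplitter.split` reduces to comparing adjacent sentence embeddings; a standalone sketch with made-up 2-D "embeddings" that switch topic at index 3:

```python
import torch
import torch.nn.functional as F

# toy sentence embeddings: two topics with a sharp switch (all values made up)
embs = torch.tensor([
    [1.0, 0.0], [0.9, 0.1], [1.0, 0.1],   # topic A
    [0.0, 1.0], [0.1, 0.9],               # topic B
])

threshold = 0.5
boundaries = []
for i in range(1, len(embs)):
    sim = F.cosine_similarity(embs[i - 1].unsqueeze(0), embs[i].unsqueeze(0), dim=1).item()
    if sim < threshold:
        boundaries.append(i)  # low similarity => start a new chunk here
print(boundaries)  # → [3]
```

The real class averages `window_size` sentences on each side and lowers the threshold as the window widens, but the split criterion is this same cosine-similarity drop.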
@ -0,0 +1,36 @@
# python=3.12
--extra-index-url https://download.pytorch.org/whl/cu126

certifi==2025.8.3
charset-normalizer==3.4.2
colorama==0.4.6
contourpy==1.3.3
cycler==0.12.1
filelock==3.13.1
fonttools==4.59.0
fsspec==2024.6.1
huggingface-hub==0.34.3
idna==3.10
Jinja2==3.1.6
kiwisolver==1.4.8
MarkupSafe==2.1.5
matplotlib==3.10.5
mpmath==1.3.0
networkx==3.3
numpy==2.3.2
packaging==25.0
pillow==11.3.0
pyparsing==3.2.3
python-dateutil==2.9.0.post0
PyYAML==6.0.2
requests==2.32.4
safetensors==0.5.3
setuptools==78.1.1
six==1.17.0
sympy==1.13.3
tokenizers==0.21.4
torch==2.7.1+cu126
tqdm==4.67.1
typing_extensions==4.12.2
urllib3==2.5.0
wheel==0.45.1
@ -0,0 +1,32 @@
import os
import torch
from khaosz import Khaosz


PROJECT_ROOT = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__)))

def chat():
    model_dir = os.path.join(PROJECT_ROOT, "params")
    model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)

    history = []
    while True:
        query = input(">> ")
        if query == "!exit":
            break

        response_size = 0
        for response, history in model.stream_generate(
            query=query,
            history=history,
            temperature=0.7,
            top_p=0.95,
            top_k=30
        ):
            print(response[response_size:], end="", flush=True)
            response_size = len(response)


if __name__ == "__main__":
    chat()
@ -0,0 +1,14 @@
import os
from huggingface_hub import snapshot_download


PROJECT_ROOT = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__)))


if __name__ == "__main__":
    snapshot_download(
        repo_id="ViperEk/KHAOSZ",
        local_dir=os.path.join(PROJECT_ROOT, "params"),
        force_download=True
    )
@@ -0,0 +1,27 @@
import os
import torch
from khaosz import Khaosz


PROJECT_ROOT = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__)))


def generate_text():
    model_dir = os.path.join(PROJECT_ROOT, "params")
    model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)

    query = input(">> ")

    response = model.text_generate(
        query=query,
        temperature=0.6,
        top_p=0.95,
        top_k=30
    )

    print(response)


if __name__ == "__main__":
    generate_text()
@@ -0,0 +1,25 @@
import os
import torch
from khaosz import Khaosz


PROJECT_ROOT = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__)))


def batch_generate():
    model_dir = os.path.join(PROJECT_ROOT, "params")
    model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
    # "Hello" / "What is artificial intelligence?" / "How is the weather today?"
    # / "I feel anxious, what should I do?" / "What is a graphics card?"
    inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"]

    responses = model.batch_generate(
        queries=inputs,
        temperature=0.7,
        top_p=0.95,
        top_k=30
    )

    for q, r in zip(inputs, responses):
        print((q, r))


if __name__ == "__main__":
    batch_generate()
@@ -0,0 +1,42 @@
import os
import torch
from khaosz import Khaosz, SemanticTextSplitter, Retriever


PROJECT_ROOT = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__)))


if __name__ == "__main__":
    model_dir = os.path.join(PROJECT_ROOT, "params")
    context_path = os.path.join(PROJECT_ROOT, "README.md")

    model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
    splitter = SemanticTextSplitter(model.encode)
    retriever = Retriever()
    with open(context_path, "r", encoding="utf-8") as f:
        text = f.read()

    res = splitter.split(text, threshold=0.8, window_size=1)
    # print(("\n" + "+"*100 + "\n").join(res))

    res_embs = model.encode(res)
    for sentence, emb in zip(res, res_embs):
        retriever.add_vector(sentence, emb)

    retrieve_top_k = 5
    query = "作者设计了一个怎样的模型"  # "What kind of model did the author design?"
    emb_query = model.encode(query)
    retrieved = retriever.retrieve(emb_query, retrieve_top_k)

    retrieve_response = model.retrieve_generate(
        retrieved=retrieved,
        query=query,
        temperature=0.7,
        top_k=30,
        top_p=0.95,
    )

    print("retrieved content:")
    print("\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)]))

    print("\n\nretrieve generate:")
    print(retrieve_response)
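The real `Retriever` is implemented inside khaosz; its exact internals are not shown in this commit. A minimal stand-in (all names here are assumptions) that matches the `add_vector` / `retrieve` flow used above, returning `(text, score)` pairs ranked by cosine similarity, might look like:

```python
import math

# Hypothetical stand-in for khaosz's Retriever: stores (text, vector) pairs
# and returns the top-k by cosine similarity, as (text, score) tuples --
# the same shape the example unpacks with `for idx, (text, _) in ...`.
class TinyRetriever:
    def __init__(self):
        self.items = []  # list of (text, vector)

    def add_vector(self, text, vec):
        self.items.append((text, vec))

    def retrieve(self, query_vec, top_k):
        def cos(a, b):
            dot = sum(x * y for x, y in zip(a, b))
            na = math.sqrt(sum(x * x for x in a))
            nb = math.sqrt(sum(x * x for x in b))
            return dot / (na * nb)

        scored = [(text, cos(query_vec, v)) for text, v in self.items]
        scored.sort(key=lambda t: t[1], reverse=True)
        return scored[:top_k]

retr = TinyRetriever()
retr.add_vector("alpha", [1.0, 0.0])
retr.add_vector("beta", [0.0, 1.0])
top = retr.retrieve([1.0, 0.0], 1)  # → [("alpha", 1.0)]
```

This is a sketch of the retrieval step only; the library version presumably uses batched tensor math rather than per-pair Python loops.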
@@ -0,0 +1,18 @@
import re
from setuptools import find_packages, setup


with open("requirements.txt") as f:
    required = [line for line in f.read().splitlines()
                if line and re.match(r'^[^=]+==[^=]+$', line.strip())]

setup(
    name="khaosz",
    version="1.2.0",
    packages=find_packages(),
    install_requires=required,
    dependency_links=[
        "https://download.pytorch.org/whl/cu126",
    ],
    python_requires="==3.12.*",
)
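The filter in setup.py keeps only strictly pinned `name==version` lines from requirements.txt, dropping comments, blanks, and option lines. A quick check with hypothetical sample lines:

```python
import re

# Sample requirement-file lines (hypothetical) run through the same filter
# setup.py uses: only "name==version" pins survive.
lines = [
    "numpy==2.3.2",
    "# a comment",
    "",
    "torch==2.7.1+cu126",
    "--extra-index-url https://download.pytorch.org/whl/cu126",
]
required = [line for line in lines
            if line and re.match(r'^[^=]+==[^=]+$', line.strip())]
```

Note that `torch==2.7.1+cu126` passes because the local-version suffix contains no `=`, while the `--extra-index-url` line is rejected for lacking `==` entirely.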
@@ -0,0 +1,103 @@
import os
import json
import torch
import shutil
import pytest
import tempfile
import safetensors.torch as st
from khaosz.core import *
from khaosz.core.generator import EmbeddingEncoderCore, GeneratorCore
from tokenizers import pre_tokenizers


@pytest.fixture
def test_env():
    test_dir = tempfile.mkdtemp()
    config_path = os.path.join(test_dir, "config.json")
    tokenizer_path = os.path.join(test_dir, "tokenizer.json")
    model_path = os.path.join(test_dir, "model.safetensors")

    config = {
        "vocab_size": 1000,
        "n_dim": 128,
        "n_head": 4,
        "n_kvhead": 2,
        "d_ffn": 256,
        "m_len": 64,
        "n_layer": 2,
        "norm_eps": 1e-5
    }
    with open(config_path, 'w') as f:
        json.dump(config, f)

    tokenizer = BpeTokenizer()
    sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
    tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
    tokenizer.save(tokenizer_path)

    transformer_config = TransformerConfig().load(config_path)
    model = Transformer(transformer_config)
    st.save_file(model.state_dict(), model_path)

    yield {
        "test_dir": test_dir,
        "model": model,
        "tokenizer": tokenizer,
        "transformer_config": transformer_config,
    }

    shutil.rmtree(test_dir)


# parameter loader
def test_parameter_loader(test_env):
    loaded_param = ParameterLoader.load(test_env["test_dir"])
    assert loaded_param.model is not None
    assert loaded_param.tokenizer is not None
    assert loaded_param.config == test_env["transformer_config"]


def test_model_parameter(test_env):
    save_dir = os.path.join(test_env["test_dir"], "save")
    model_param = ModelParameter(test_env["model"], test_env["tokenizer"], test_env["transformer_config"])
    model_param.save(save_dir)

    assert os.path.exists(os.path.join(save_dir, "model.safetensors"))
    assert os.path.exists(os.path.join(save_dir, "tokenizer.json"))
    assert os.path.exists(os.path.join(save_dir, "config.json"))


# transformer
def test_transformer(test_env):
    model = test_env["model"]
    input_ids = torch.randint(0, test_env["transformer_config"].vocab_size,
                              (4, test_env["transformer_config"].m_len))
    output_logits = model(input_ids)["logits"]
    target_shape = (4, test_env["transformer_config"].m_len, test_env["transformer_config"].vocab_size)
    assert output_logits.shape == target_shape


# generator
def test_embedding_encoder_core(test_env):
    parameter = ModelParameter(
        test_env["model"],
        test_env["tokenizer"],
        test_env["transformer_config"]
    )
    encoder = EmbeddingEncoderCore(parameter)

    single_emb = encoder.encode("测试文本")  # "test text"
    assert isinstance(single_emb, torch.Tensor)
    assert single_emb.shape[-1] == test_env["transformer_config"].n_dim

    batch_emb = encoder.encode(["测试1", "测试2"])  # "test 1", "test 2"
    assert isinstance(batch_emb, list)
    assert len(batch_emb) == 2


def test_generator_core(test_env):
    parameter = ModelParameter(
        test_env["model"],
        test_env["tokenizer"],
        test_env["transformer_config"]
    )
    generator = GeneratorCore(parameter)
    logits, incr = generator.compute_logits(torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10)))

    assert logits.shape == (4, test_env["transformer_config"].vocab_size)
    assert incr == 10
@@ -0,0 +1,203 @@
import os
import json
import torch
import shutil
import pytest
import pickle
import tempfile
import matplotlib

from torch.utils.data import Dataset
from khaosz.core import *
from khaosz.trainer import *

# to avoid _tkinter.TclError
matplotlib.use('Agg')


@pytest.fixture
def test_env():
    test_dir = tempfile.mkdtemp()
    config_path = os.path.join(test_dir, "config.json")

    config = {
        "vocab_size": 1000,
        "n_dim": 128,
        "n_head": 4,
        "n_kvhead": 2,
        "d_ffn": 256,
        "m_len": 64,
        "n_layer": 2,
        "norm_eps": 1e-5
    }

    with open(config_path, 'w') as f:
        json.dump(config, f)

    transformer_config = TransformerConfig().load(config_path)
    model = Transformer(transformer_config)
    tokenizer = BpeTokenizer()

    class DummyDataset(Dataset):
        def __init__(self, length=10):
            self.length = length

        def __len__(self):
            return self.length

        def __getitem__(self, idx):
            return (
                torch.randint(0, 1000, (64,)),
                torch.randint(0, 1000, (64,))
            )

    dataset = DummyDataset()

    yield {
        "test_dir": test_dir,
        "config_path": config_path,
        "transformer_config": transformer_config,
        "model": model,
        "tokenizer": tokenizer,
        "dataset": dataset
    }

    shutil.rmtree(test_dir)


def test_dataset_loader(test_env):
    test_dir = test_env["test_dir"]
    pkl_path = os.path.join(test_dir, "test_data.pkl")

    dummy_data = {"sequence": torch.randint(0, 1000, (64,))}
    with open(pkl_path, "wb") as f:
        pickle.dump(dummy_data, f)

    loaded_dataset = DatasetLoader.load(train_type="seq", load_path=pkl_path, max_len=64, device="cpu")
    assert loaded_dataset is not None


def test_training_config(test_env):
    optimizer = torch.optim.AdamW(test_env["model"].parameters())
    train_config = TrainConfig(
        train_type="seq",
        dataset=test_env["dataset"],
        optimizer=optimizer,
        ckpt_dir=test_env["test_dir"],
        n_epoch=1,
        batch_size=2,
        n_iter_ckpt=5,
        n_iter_step=1,
        max_grad_norm=1.0,
        random_seed=42
    )
    assert train_config.get_kwargs()["batch_size"] == 2


def test_cosine_schedule(test_env):
    assert test_env is not None
    schedule_config = CosineScheduleConfig(
        warning_step=100,
        total_iters=1000
    )
    kwargs = schedule_config.get_kwargs()
    assert kwargs["warning_step"] == 100
    assert kwargs["lr_decay_iters"] == 900


def test_sgdr_schedule(test_env):
    assert test_env is not None
    schedule_config = SgdrScheduleConfig(
        warning_step=100,
        cycle_length=200,
        T_mult=2
    )
    kwargs = schedule_config.get_kwargs()
    assert kwargs["warning_step"] == 100
    assert kwargs["cycle_length"] == 200
    assert kwargs["T_mult"] == 2


def test_trainer_train(test_env):
    optimizer = torch.optim.AdamW(test_env["model"].parameters())
    train_config = TrainConfig(
        train_type="seq",
        dataset=test_env["dataset"],
        optimizer=optimizer,
        ckpt_dir=test_env["test_dir"],
        n_epoch=1,
        batch_size=2,
        n_iter_ckpt=5,
        n_iter_step=1,
        max_grad_norm=1.0,
        random_seed=42
    )
    schedule_config = CosineScheduleConfig(
        warning_step=100,
        total_iters=1000
    )
    model_parameter = ModelParameter(
        test_env["model"],
        test_env["tokenizer"],
        test_env["transformer_config"]
    )
    trainer = Trainer(model_parameter)
    trainer.train(train_config, schedule_config)


def test_checkpoint(test_env):
    temp_dir = test_env["test_dir"]
    config = test_env["transformer_config"]
    model = test_env["model"]
    tokenizer = test_env["tokenizer"]

    param = ModelParameter(model, tokenizer, config)
    checkpoint = Checkpoint(
        model=param.model,
        tokenizer=param.tokenizer,
        config=param.config,
        loss_list=[1.0, 2.0, 3.0],
        current_iter=3
    )
    ckpt_dir = os.path.join(temp_dir, "ckpt")
    checkpoint.save(ckpt_dir)

    loaded_ckpt = Checkpoint()
    loaded_ckpt.load(ckpt_dir)

    assert loaded_ckpt.current_iter == 3
    assert loaded_ckpt.loss_list == [1.0, 2.0, 3.0]

    for p1, p2 in zip(model.parameters(), loaded_ckpt.model.parameters()):
        assert torch.allclose(p1, p2)


def test_checkpoint_train(test_env):
    temp_dir = test_env["test_dir"]
    config = test_env["transformer_config"]
    model = test_env["model"]
    tokenizer = test_env["tokenizer"]
    dataset = test_env["dataset"]

    param = ModelParameter(model, tokenizer, config)
    trainer = Trainer(param)

    optimizer = torch.optim.AdamW(test_env["model"].parameters())
    train_config = TrainConfig(
        train_type="seq",
        dataset=dataset,
        optimizer=optimizer,
        ckpt_dir=test_env["test_dir"],
        n_epoch=1,
        batch_size=2,
        n_iter_ckpt=5,
        n_iter_step=1,
        max_grad_norm=1.0,
        random_seed=42
    )
    schedule_config = CosineScheduleConfig(
        warning_step=100,
        total_iters=1000
    )

    trainer.train(
        train_config=train_config,
        schedule_config=schedule_config,
    )
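The schedule tests above pin down the semantics of `CosineScheduleConfig`: `warning_step` warmup iterations, then decay over `lr_decay_iters = total_iters - warning_step`. The actual schedule lives in khaosz.trainer; a sketch that mirrors only those asserted semantics (the function name and min_lr default are assumptions) would be:

```python
import math

# Sketch of linear warmup + cosine decay matching the tested config semantics:
# warm up for `warning_step` iters, then cosine-decay over the remaining
# total_iters - warning_step iters (the "lr_decay_iters" the test asserts).
def cosine_lr(it, max_lr, warning_step, total_iters, min_lr=0.0):
    if it < warning_step:
        # Linear warmup from ~0 to max_lr.
        return max_lr * (it + 1) / warning_step
    decay_iters = total_iters - warning_step
    progress = min((it - warning_step) / decay_iters, 1.0)
    # Cosine from max_lr (progress=0) down to min_lr (progress=1).
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

With `warning_step=100, total_iters=1000` as in the tests, the rate peaks at `max_lr` on iteration 100 and reaches `min_lr` at iteration 1000.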
@@ -0,0 +1,128 @@
import os
import argparse
import torch

from torch.optim import AdamW
from khaosz.core import ParameterLoader
from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig


PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

def get_files(root_path: str) -> list[str]:
    paths = []
    for root, _, files in os.walk(root_path):
        paths.extend([os.path.join(root, file) for file in files])

    return paths

def train(
    train_type: str,
    param_path: str,
    data_root_path: str,
    n_epoch: int,
    batch_size: int,
    n_iter_step: int,
    warning_step: int,
    max_lr: float,
    n_iter_ckpt: int,
    ckpt_dir: str,
    dpo_beta: float,
    adamw_betas: tuple,
    adamw_weight_decay: float,
    max_grad_norm: float,
    embedding_lr_rate: float,
    random_seed: int,
):
    assert train_type in ["seq", "sft", "dpo"]
    assert os.path.exists(param_path)

    parameter = ParameterLoader.load(param_path)
    model = parameter.model

    device = torch.device("cuda")
    model = model.to(device=device, dtype=torch.bfloat16)

    cache_files = get_files(data_root_path)
    dataset = DatasetLoader.load(
        train_type=train_type,
        load_path=cache_files,
        max_len=parameter.config.m_len,
        device=device
    )

    param_groups = [
        {"params": [p for n, p in model.named_parameters() if "embedding" in n], "lr": max_lr * embedding_lr_rate},
        {"params": [p for n, p in model.named_parameters() if "embedding" not in n], "lr": max_lr}
    ]

    optim = AdamW(
        param_groups,
        betas=adamw_betas,
        weight_decay=adamw_weight_decay
    )

    train_config = TrainConfig(
        train_type=train_type,
        dataset=dataset,
        optimizer=optim,
        ckpt_dir=ckpt_dir,
        n_epoch=n_epoch,
        batch_size=batch_size,
        n_iter_ckpt=n_iter_ckpt,
        n_iter_step=n_iter_step,
        max_grad_norm=max_grad_norm,
        random_seed=random_seed,
        dpo_beta=dpo_beta
    )

    schedule_config = CosineScheduleConfig(
        warning_step=warning_step,
        total_iters=len(dataset) * n_epoch // batch_size,
    )

    trainer = Trainer(parameter)
    trainer.train(
        train_config=train_config,
        schedule_config=schedule_config,
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train the Transformer model.")
    parser.add_argument("--train_type", choices=["seq", "sft", "dpo"], required=True, help="Train type.")
    parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
    parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
    parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
    parser.add_argument("--n_iter_step", type=int, default=1, help="Number of iterations between each optimizer step.")
    parser.add_argument("--warning_step", type=int, default=1000, help="Number of warmup iterations for the learning-rate scheduler.")
    parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
    parser.add_argument("--n_iter_ckpt", type=int, default=5000, help="Number of iters between checkpoints.")
    parser.add_argument("--ckpt_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
    parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
    parser.add_argument("--adamw_betas", type=float, nargs=2, default=(0.9, 0.95), help="Beta values for AdamW optimizer.")
    parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
    parser.add_argument("--embedding_lr_rate", type=float, default=1.0, help="Ratio of the embedding layers' learning rate to the max learning rate.")
    parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")

    args = parser.parse_args()

    train(
        param_path=args.param_path,
        data_root_path=args.data_root_path,
        n_epoch=args.n_epoch,
        batch_size=args.batch_size,
        n_iter_step=args.n_iter_step,
        warning_step=args.warning_step,
        max_lr=args.max_lr,
        dpo_beta=args.dpo_beta,
        adamw_betas=tuple(args.adamw_betas),
        adamw_weight_decay=args.adamw_weight_decay,
        max_grad_norm=args.max_grad_norm,
        embedding_lr_rate=args.embedding_lr_rate,
        n_iter_ckpt=args.n_iter_ckpt,
        ckpt_dir=args.ckpt_dir,
        train_type=args.train_type,
        random_seed=args.random_seed
    )
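The two-entry `param_groups` list in `train()` splits parameters by name so the embedding layer can train at a scaled learning rate. A dependency-free sketch (the parameter names below are hypothetical stand-ins for `model.named_parameters()`) shows which group each parameter lands in:

```python
# Parameters whose names contain "embedding" go to the scaled-lr group;
# everything else trains at max_lr. Names are hypothetical examples.
named_parameters = [
    ("embedding.weight", "W_emb"),
    ("attn.q_proj.weight", "W_q"),
    ("ffn.weight", "W_ffn"),
]
max_lr, embedding_lr_rate = 3e-4, 0.5
param_groups = [
    {"params": [p for n, p in named_parameters if "embedding" in n],
     "lr": max_lr * embedding_lr_rate},
    {"params": [p for n, p in named_parameters if "embedding" not in n],
     "lr": max_lr},
]
```

Passing a list of such dicts to `torch.optim.AdamW` gives each group its own learning rate while sharing the remaining optimizer hyperparameters.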