From 8b6509b305e4dd0c70579c8d2e49f03250990b53 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Thu, 2 Apr 2026 20:11:19 +0800
Subject: [PATCH] docs: update design.md project structure and module documentation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
assets/docs/design.md | 201 ++++++++++++++++++++++--------------------
1 file changed, 105 insertions(+), 96 deletions(-)
diff --git a/assets/docs/design.md b/assets/docs/design.md
index e99328d..170731e 100644
--- a/assets/docs/design.md
+++ b/assets/docs/design.md
@@ -9,117 +9,126 @@ Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, su
The system is divided into the following modules:
```mermaid
-graph LR
+flowchart TB
%% Style definitions
- classDef config fill:#e1f5fe,stroke:#01579b;
- classDef trainer fill:#f3e5f5,stroke:#4a148c;
- classDef data fill:#e8f5e8,stroke:#1b5e20;
- classDef model fill:#fff3e0,stroke:#e65100;
- classDef inference fill:#fce4ec,stroke:#880e4f;
- classDef parallel fill:#e0f2f1,stroke:#004d40;
+ classDef config fill:#e1f5fe,stroke:#01579b,stroke-width:2px;
+ classDef data fill:#e8f5e8,stroke:#1b5e20,stroke-width:2px;
+ classDef model fill:#fff3e0,stroke:#e65100,stroke-width:2px;
+ classDef trainer fill:#f3e5f5,stroke:#4a148c,stroke-width:2px;
+ classDef inference fill:#fce4ec,stroke:#880e4f,stroke-width:2px;
+ classDef parallel fill:#e0f2f1,stroke:#004d40,stroke-width:2px;
+ classDef scripts fill:#fffbe6,stroke:#f57f17,stroke-width:2px;
+
+    subgraph Config["Config Module (config/)"]
+        direction LR
+        C1["model_config.py<br/>Model Architecture"]
+        C2["train_config.py<br/>Training Params"]
+        C3["param_config.py<br/>Hyperparameters"]
+        C4["schedule_config.py<br/>Scheduler Config"]
+    end
- %% Config module
- subgraph Config["Config"]
- C1[model_config.py]
- C2[train_config.py]
- C3[scheduler_config.py]
+    subgraph Data["Data Module (data/)"]
+        direction LR
+        D1["dataset.py<br/>Dataset"]
+        D2["sampler.py<br/>Sampler"]
+        D3["serialization.py<br/>Serialization"]
+        D4["tokenizer.py<br/>Tokenizer"]
end
- class Config config;
-
- %% Trainer module
- subgraph Trainer["Trainer"]
- T1[trainer.py]
- T2[train_content.py]
- T3[schedule.py]
- T4[strategy.py]
- T5[train_callback.py]
+
+    subgraph Model["Model Module (model/)"]
+        direction LR
+        M1["transformer.py<br/>Transformer Architecture"]
+        M2["module.py<br/>Model Components"]
end
- class Trainer trainer;
-
- %% Data module
- subgraph Data["Data"]
- D1[dataset.py]
- D2[sampler.py]
- D3[mmap.py]
- D4[tokenizer.py]
- D5[checkpoint.py]
+
+    subgraph Trainer["Trainer Module (trainer/)"]
+        direction TB
+        T1["trainer.py<br/>Trainer Entry"]
+        T2["train_context.py<br/>Training Context"]
+        T3["strategy.py<br/>Training Strategy"]
+        T4["schedule.py<br/>LR Scheduler"]
+        T5["train_callback.py<br/>Callbacks"]
+        T6["metric_util.py<br/>Metrics"]
end
- class Data data;
-
- %% Model module
- subgraph Model["Model"]
- M1[transformer.py]
- M2[module.py]
+
+    subgraph Inference["Inference Module (inference/)"]
+        direction LR
+        I1["generator.py<br/>Text Generation"]
+        I2["core.py<br/>Inference Core"]
+        I3["server.py<br/>API Service"]
end
- class Model model;
-
- %% Inference module
- subgraph Inference["Inference"]
- I1[generator.py]
- I2[core.py]
+
+    subgraph Parallel["Parallel Module (parallel/)"]
+        direction LR
+        P1["setup.py<br/>Parallel Init"]
+        P2["module.py<br/>Parallel Components"]
end
- class Inference inference;
-
- %% Parallel module
- subgraph Parallel["Parallel"]
- P1[setup.py]
- P2[module.py]
+
+    subgraph Scripts["Scripts (scripts/)"]
+        direction LR
+        S1["tools/<br/>Train & Inference"]
+        S2["demo/<br/>Demos"]
end
- class Parallel parallel;
- %% Config dependencies
- C2 -.-> T1
- C1 -.-> M1
- C3 -.-> T3
-
- %% Trainer internal dependencies
- T1 --> T5
- T1 --> T2
- T2 --> T3
- T2 --> T4
-
- %% Data flow
- D1 --> D2
- D1 --> D3
- D1 --> D4
- D1 --> D5
-
- %% Model dependencies
- M1 --> M2
-
- %% Inference dependencies
- I1 --> I2
-
- %% Cross-module dependencies
- T2 -.-> M1
- I1 -.-> M1
- T2 -.-> D1
- T1 -.-> P1
+ %% External config input
+ Config --> Trainer
+
+ %% Training flow
+ Trainer -->|Load Model| Model
+ Trainer -->|Load Data| Data
+ Trainer -->|Setup| Parallel
+
+ %% Inference flow
+ Inference -->|Use Model| Model
+    I1 -->|Use| I2
+
+ %% Data dependency
+ Data -->|Data Pipeline| Model
+
+ %% Parallel dependency
+ Parallel -->|Distributed| Trainer
+
+ %% Scripts
+ Scripts -->|Execute| Trainer
+ Scripts -->|Execute| Inference
```
-### 1. Configuration Management (/config/)
-- **Model Configuration**: Defines model structure parameters (such as layers, heads, dimensions, etc.), managed uniformly through `ModelConfig`.
-- **Training Configuration**: Sets training parameters (such as batch size, training stages PT/SFT/DPO, optimizers, etc.), loaded by `TrainConfig`.
-- **Scheduler Configuration**: Controls learning rate strategies (such as cosine annealing) and training progress.
+### 1. Configuration Module (config/)
+- **model_config.py**: Defines model structure parameters (layers, heads, dimensions, etc.), managed through `ModelConfig`.
+- **train_config.py**: Sets training parameters (batch size, training stages PT/SFT/DPO, optimizers, etc.), loaded by `TrainConfig`.
+- **param_config.py**: Manages hyperparameters for training and inference.
+- **schedule_config.py**: Controls learning rate strategies (cosine annealing) and training progress.
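
For illustration only, the two central configs above could be plain dataclasses along the following lines; every field name here is an assumption, not the actual API of `config/`:

```python
# Hypothetical sketch of ModelConfig / TrainConfig; field names are assumptions.
from dataclasses import dataclass

@dataclass
class ModelConfig:
    n_layers: int = 24       # transformer depth
    n_heads: int = 16        # attention heads per layer
    d_model: int = 2048      # hidden dimension

@dataclass
class TrainConfig:
    stage: str = "PT"        # one of "PT", "SFT", "DPO"
    batch_size: int = 32
    optimizer: str = "adamw"

model_cfg = ModelConfig()
train_cfg = TrainConfig(stage="SFT", batch_size=8)
```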
-### 2. Hardware and Parallelism (/parallel/)
-- **Distributed Initialization**: Initializes multi-GPU/multi-machine training environments through the `setup_parallel` function according to configuration.
+### 2. Data Module (data/)
+- **dataset.py**: Dataset handling and loading.
+- **sampler.py**: Data sampling for different training stages.
+- **serialization.py**: Data serialization and deserialization.
+- **tokenizer.py**: Text tokenization and encoding.
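
A minimal sketch, under assumed names, of how a dataset and a stage-aware sampler could fit together (the real `dataset.py` and `sampler.py` are certainly richer):

```python
# Illustrative only: class and function names are assumptions, not the project's API.
class TextDataset:
    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def stage_indices(dataset, stage):
    # PT streams the full corpus; as a stand-in, SFT keeps only
    # even indices to mimic stage-specific filtering.
    if stage == "PT":
        return list(range(len(dataset)))
    return [i for i in range(len(dataset)) if i % 2 == 0]

ds = TextDataset(["a", "b", "c", "d"])
order = stage_indices(ds, "SFT")
```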
-### 3. Data Processing (/data/)
-- **Efficient Loading**: Uses memory mapping (mmap) technology to load massive corpora, avoiding memory overflow and achieving zero-copy reading.
+### 3. Model Module (model/)
+- **transformer.py**: Transformer architecture implementation.
+- **module.py**: Model components and layers.
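
As a worked example of the core operation any Transformer stack builds on, the scaled dot-product between a query and a key; this is generic math, not code taken from `transformer.py`:

```python
# Scaled dot-product score for one query/key pair (generic Transformer math).
import math

def scaled_dot(q, k):
    d = len(q)  # head dimension
    return sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d)

score = scaled_dot([1.0, 0.0], [1.0, 1.0])  # 1 / sqrt(2)
```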
-### 4. Model and Training (/model/, /trainer/)
-- **Unified Model Architecture**: Based on Transformer, supporting flexible configuration of different scales (such as 7B, 13B).
-- **Strategy-based Trainer**: `Trainer` automatically switches training strategies according to training stages (PT/SFT/DPO), reusing the same training loop.
-- **Training Context Management**: Unifies management of model, optimizer, scheduler, and metrics, supporting seamless multi-stage transitions.
+### 4. Trainer Module (trainer/)
+- **trainer.py**: Main training entry point.
+- **train_context.py**: Training context management (model, optimizer, scheduler, metrics).
+- **strategy.py**: Training strategies for PT/SFT/DPO stages.
+- **schedule.py**: Learning rate scheduler.
+- **train_callback.py**: Training callbacks (checkpoint, early stopping, etc.).
+- **metric_util.py**: Metrics calculation and tracking.
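
The stage-keyed strategy dispatch described above could be sketched as follows; the class names and the stand-in losses are illustrative, not the actual contents of `strategy.py`:

```python
# Hedged sketch of stage-based strategy dispatch; names and losses are stand-ins.
class Strategy:
    def compute_loss(self, batch):
        raise NotImplementedError

class PTStrategy(Strategy):
    def compute_loss(self, batch):
        # stand-in for a language-modelling loss
        return sum(batch) / len(batch)

class SFTStrategy(Strategy):
    def compute_loss(self, batch):
        # stand-in for a supervised fine-tuning loss
        return sum(batch)

STRATEGIES = {"PT": PTStrategy, "SFT": SFTStrategy}

def make_strategy(stage):
    # Trainer would pick the strategy from the configured stage,
    # reusing the same training loop around it.
    return STRATEGIES[stage]()

loss = make_strategy("PT").compute_loss([1.0, 2.0, 3.0])  # 2.0
```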
-### 5. Inference Service (/inference/, /utils/)
-- **Unified Generation Interface**: Provides synchronous, batch, and streaming generation methods, adapting to all training stages.
-- **KV Cache Optimization**: Caches Key/Value during autoregressive generation, utilizing high-speed on-chip memory acceleration on NVIDIA GPU.
-- **RAG Support**: Combines retriever and embedding models to inject relevant information from external knowledge bases, improving answer quality.
-- **Intelligent Text Segmentation**:
- - **Structure-first Segmentation**: Splits by titles, paragraphs, etc.;
- - **Semantic Segmentation**: Based on sentence embedding similarity, ensuring fragment semantic completeness and improving fine-tuning effects.
+### 5. Inference Module (inference/)
+- **generator.py**: Text generation with various methods (sync, batch, streaming).
+- **core.py**: Inference core with KV cache optimization.
+- **server.py**: API service for inference.
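
The streaming method mentioned above might expose an iterator interface like this sketch; the function name and placeholder tokens are invented for illustration:

```python
# Illustrative streaming-generation interface; names and tokens are invented.
from typing import Iterator

def generate_stream(prompt: str, max_new_tokens: int = 3) -> Iterator[str]:
    # A real implementation would decode autoregressively, reusing the
    # key/value cache kept by the inference core between steps.
    for step in range(max_new_tokens):
        yield f"<tok{step}>"

pieces = list(generate_stream("hello"))
text = "".join(pieces)
```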
+
+### 6. Parallel Module (parallel/)
+- **setup.py**: Distributed initialization for multi-GPU/multi-machine training.
+- **module.py**: Parallel communication components.
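
Distributed launchers conventionally pass rank and world size through environment variables; whether `setup.py` reads exactly these names is an assumption:

```python
# Sketch of reading standard launcher environment variables (assumed names).
import os

def get_rank_and_world_size():
    rank = int(os.environ.get("RANK", "0"))          # this process's index
    world_size = int(os.environ.get("WORLD_SIZE", "1"))  # total process count
    return rank, world_size

rank, world_size = get_rank_and_world_size()
```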
+
+### 7. Scripts (scripts/)
+- **tools/**: Main scripts for training and inference (train.py, generate.py, etc.).
+- **demo/**: Demo scripts for interactive chat, batch generation, etc.
## 3. Training Process