From 8b6509b305e4dd0c70579c8d2e49f03250990b53 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Thu, 2 Apr 2026 20:11:19 +0800
Subject: [PATCH] =?UTF-8?q?docs:=20=E6=9B=B4=E6=96=B0=20design.md=20?=
 =?UTF-8?q?=E9=A1=B9=E7=9B=AE=E7=BB=93=E6=9E=84=E5=92=8C=E6=A8=A1=E5=9D=97?=
 =?UTF-8?q?=E6=96=87=E6=A1=A3?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 assets/docs/design.md | 201 ++++++++++++++++++++++--------------------
 1 file changed, 105 insertions(+), 96 deletions(-)

diff --git a/assets/docs/design.md b/assets/docs/design.md
index e99328d..170731e 100644
--- a/assets/docs/design.md
+++ b/assets/docs/design.md
@@ -9,117 +9,126 @@ Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, su
 The system is divided into the following modules:
 
 ```mermaid
-graph LR
+flowchart TB
     %% Style definitions
-    classDef config fill:#e1f5fe,stroke:#01579b;
-    classDef trainer fill:#f3e5f5,stroke:#4a148c;
-    classDef data fill:#e8f5e8,stroke:#1b5e20;
-    classDef model fill:#fff3e0,stroke:#e65100;
-    classDef inference fill:#fce4ec,stroke:#880e4f;
-    classDef parallel fill:#e0f2f1,stroke:#004d40;
+    classDef config fill:#e1f5fe,stroke:#01579b,stroke-width:2px;
+    classDef data fill:#e8f5e8,stroke:#1b5e20,stroke-width:2px;
+    classDef model fill:#fff3e0,stroke:#e65100,stroke-width:2px;
+    classDef trainer fill:#f3e5f5,stroke:#4a148c,stroke-width:2px;
+    classDef inference fill:#fce4ec,stroke:#880e4f,stroke-width:2px;
+    classDef parallel fill:#e0f2f1,stroke:#004d40,stroke-width:2px;
+    classDef scripts fill:#fffbe6,stroke:#f57f17,stroke-width:2px;
+
+    subgraph Config["Config Module (config/)"]
+        direction LR
+        C1[model_config.py<br>Model Architecture]
+        C2[train_config.py<br>Training Params]
+        C3[param_config.py<br>Hyperparameters]
+        C4[schedule_config.py<br>Scheduler Config]
+    end
 
-    %% Config module
-    subgraph Config["Config"]
-        C1[model_config.py]
-        C2[train_config.py]
-        C3[scheduler_config.py]
+    subgraph Data["Data Module (data/)"]
+        direction LR
+        D1[dataset.py<br>Dataset]
+        D2[sampler.py<br>Sampler]
+        D3[serialization.py<br>Serialization]
+        D4[tokenizer.py<br>Tokenizer]
     end
-    class Config config;
-
-    %% Trainer module
-    subgraph Trainer["Trainer"]
-        T1[trainer.py]
-        T2[train_content.py]
-        T3[schedule.py]
-        T4[strategy.py]
-        T5[train_callback.py]
+
+    subgraph Model["Model Module (model/)"]
+        direction LR
+        M1[transformer.py<br>Transformer Architecture]
+        M2[module.py<br>Model Components]
     end
-    class Trainer trainer;
-
-    %% Data module
-    subgraph Data["Data"]
-        D1[dataset.py]
-        D2[sampler.py]
-        D3[mmap.py]
-        D4[tokenizer.py]
-        D5[checkpoint.py]
+
+    subgraph Trainer["Trainer Module (trainer/)"]
+        direction TB
+        T1[trainer.py<br>Trainer Entry]
+        T2[train_context.py<br>Training Context]
+        T3[strategy.py<br>Training Strategy]
+        T4[schedule.py<br>LR Scheduler]
+        T5[train_callback.py<br>Callbacks]
+        T6[metric_util.py<br>Metrics]
     end
-    class Data data;
-
-    %% Model module
-    subgraph Model["Model"]
-        M1[transformer.py]
-        M2[module.py]
+
+    subgraph Inference["Inference Module (inference/)"]
+        direction LR
+        I1[generator.py<br>Text Generation]
+        I2[core.py<br>Inference Core]
+        I3[server.py<br>API Service]
     end
-    class Model model;
-
-    %% Inference module
-    subgraph Inference["Inference"]
-        I1[generator.py]
-        I2[core.py]
+
+    subgraph Parallel["Parallel Module (parallel/)"]
+        direction LR
+        P1[setup.py<br>Parallel Init]
+        P2[module.py<br>Parallel Components]
     end
-    class Inference inference;
-
-    %% Parallel module
-    subgraph Parallel["Parallel"]
-        P1[setup.py]
-        P2[module.py]
+
+    subgraph Scripts["Scripts (scripts/)"]
+        direction LR
+        S1[tools/<br>Train & Inference]
+        S2[demo/<br>Demos]
     end
-    class Parallel parallel;
 
-    %% Config dependencies
-    C2 -.-> T1
-    C1 -.-> M1
-    C3 -.-> T3
-
-    %% Trainer internal dependencies
-    T1 --> T5
-    T1 --> T2
-    T2 --> T3
-    T2 --> T4
-
-    %% Data flow
-    D1 --> D2
-    D1 --> D3
-    D1 --> D4
-    D1 --> D5
-
-    %% Model dependencies
-    M1 --> M2
-
-    %% Inference dependencies
-    I1 --> I2
-
-    %% Cross-module dependencies
-    T2 -.-> M1
-    I1 -.-> M1
-    T2 -.-> D1
-    T1 -.-> P1
+    %% External config input
+    Config --> Trainer
+
+    %% Training flow
+    Trainer -->|Load Model| Model
+    Trainer -->|Load Data| Data
+    Trainer -->|Setup| Parallel
+
+    %% Inference flow
+    Inference -->|Use Model| Model
+    Inference -->|Use| I1
+
+    %% Data dependency
+    Data -->|Data Pipeline| Model
+
+    %% Parallel dependency
+    Parallel -->|Distributed| Trainer
+
+    %% Scripts
+    Scripts -->|Execute| Trainer
+    Scripts -->|Execute| Inference
 ```
 
-### 1. Configuration Management (/config/)
-- **Model Configuration**: Defines model structure parameters (such as layers, heads, dimensions, etc.), managed uniformly through `ModelConfig`.
-- **Training Configuration**: Sets training parameters (such as batch size, training stages PT/SFT/DPO, optimizers, etc.), loaded by `TrainConfig`.
-- **Scheduler Configuration**: Controls learning rate strategies (such as cosine annealing) and training progress.
+### 1. Configuration Module (config/)
+- **model_config.py**: Defines model structure parameters (layers, heads, dimensions, etc.), managed through `ModelConfig`.
+- **train_config.py**: Sets training parameters (batch size, training stages PT/SFT/DPO, optimizers, etc.), loaded by `TrainConfig`.
+- **param_config.py**: Manages hyperparameters for training and inference.
+- **schedule_config.py**: Controls learning rate strategies (cosine annealing) and training progress.
 
-### 2. Hardware and Parallelism (/parallel/)
-- **Distributed Initialization**: Initializes multi-GPU/multi-machine training environments through the `setup_parallel` function according to configuration.
+### 2. Data Module (data/)
+- **dataset.py**: Dataset handling and loading.
+- **sampler.py**: Data sampling for different training stages.
+- **serialization.py**: Data serialization and deserialization.
+- **tokenizer.py**: Text tokenization and encoding.
 
-### 3. Data Processing (/data/)
-- **Efficient Loading**: Uses memory mapping (mmap) technology to load massive corpora, avoiding memory overflow and achieving zero-copy reading.
+### 3. Model Module (model/)
+- **transformer.py**: Transformer architecture implementation.
+- **module.py**: Model components and layers.
 
-### 4. Model and Training (/model/, /trainer/)
-- **Unified Model Architecture**: Based on Transformer, supporting flexible configuration of different scales (such as 7B, 13B).
-- **Strategy-based Trainer**: `Trainer` automatically switches training strategies according to training stages (PT/SFT/DPO), reusing the same training loop.
-- **Training Context Management**: Unifies management of model, optimizer, scheduler, and metrics, supporting seamless multi-stage transitions.
+### 4. Trainer Module (trainer/)
+- **trainer.py**: Main training entry point.
+- **train_context.py**: Training context management (model, optimizer, scheduler, metrics).
+- **strategy.py**: Training strategies for PT/SFT/DPO stages.
+- **schedule.py**: Learning rate scheduler.
+- **train_callback.py**: Training callbacks (checkpoint, early stopping, etc.).
+- **metric_util.py**: Metrics calculation and tracking.
 
-### 5. Inference Service (/inference/, /utils/)
-- **Unified Generation Interface**: Provides synchronous, batch, and streaming generation methods, adapting to all training stages.
-- **KV Cache Optimization**: Caches Key/Value during autoregressive generation, utilizing high-speed on-chip memory acceleration on NVIDIA GPU.
-- **RAG Support**: Combines retriever and embedding models to inject relevant information from external knowledge bases, improving answer quality.
-- **Intelligent Text Segmentation**:
-  - **Structure-first Segmentation**: Splits by titles, paragraphs, etc.;
-  - **Semantic Segmentation**: Based on sentence embedding similarity, ensuring fragment semantic completeness and improving fine-tuning effects.
+### 5. Inference Module (inference/)
+- **generator.py**: Text generation with various methods (sync, batch, streaming).
+- **core.py**: Inference core with KV cache optimization.
+- **server.py**: API service for inference.
+
+### 6. Parallel Module (parallel/)
+- **setup.py**: Distributed initialization for multi-GPU/multi-machine training.
+- **module.py**: Parallel communication components.
+
+### 7. Scripts (scripts/)
+- **tools/**: Main scripts for training and inference (train.py, generate.py, etc.).
+- **demo/**: Demo scripts for interactive chat, batch generation, etc.
 
 ## 3. Training Process