docs: update design.md project structure and module documentation

This commit is contained in:
ViperEkura 2026-04-02 20:11:19 +08:00
parent 912d7c7f54
commit 8b6509b305
1 changed file with 105 additions and 96 deletions

Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, su
The system is divided into the following modules:
```mermaid
flowchart TB
    %% Style definitions
    classDef config fill:#e1f5fe,stroke:#01579b,stroke-width:2px;
    classDef data fill:#e8f5e8,stroke:#1b5e20,stroke-width:2px;
    classDef model fill:#fff3e0,stroke:#e65100,stroke-width:2px;
    classDef trainer fill:#f3e5f5,stroke:#4a148c,stroke-width:2px;
    classDef inference fill:#fce4ec,stroke:#880e4f,stroke-width:2px;
    classDef parallel fill:#e0f2f1,stroke:#004d40,stroke-width:2px;
    classDef scripts fill:#fffbe6,stroke:#f57f17,stroke-width:2px;

    %% Config module
    subgraph Config["Config Module (config/)"]
        direction LR
        C1[model_config.py<br/>Model Architecture]
        C2[train_config.py<br/>Training Params]
        C3[param_config.py<br/>Hyperparameters]
        C4[schedule_config.py<br/>Scheduler Config]
    end
    class Config config;

    %% Data module
    subgraph Data["Data Module (data/)"]
        direction LR
        D1[dataset.py<br/>Dataset]
        D2[sampler.py<br/>Sampler]
        D3[serialization.py<br/>Serialization]
        D4[tokenizer.py<br/>Tokenizer]
    end
    class Data data;

    %% Model module
    subgraph Model["Model Module (model/)"]
        direction LR
        M1[transformer.py<br/>Transformer Architecture]
        M2[module.py<br/>Model Components]
    end
    class Model model;

    %% Trainer module
    subgraph Trainer["Trainer Module (trainer/)"]
        direction TB
        T1[trainer.py<br/>Trainer Entry]
        T2[train_context.py<br/>Training Context]
        T3[strategy.py<br/>Training Strategy]
        T4[schedule.py<br/>LR Scheduler]
        T5[train_callback.py<br/>Callbacks]
        T6[metric_util.py<br/>Metrics]
    end
    class Trainer trainer;

    %% Inference module
    subgraph Inference["Inference Module (inference/)"]
        direction LR
        I1[generator.py<br/>Text Generation]
        I2[core.py<br/>Inference Core]
        I3[server.py<br/>API Service]
    end
    class Inference inference;

    %% Parallel module
    subgraph Parallel["Parallel Module (parallel/)"]
        direction LR
        P1[setup.py<br/>Parallel Init]
        P2[module.py<br/>Parallel Components]
    end
    class Parallel parallel;

    %% Scripts
    subgraph Scripts["Scripts (scripts/)"]
        direction LR
        S1[tools/<br/>Train & Inference]
        S2[demo/<br/>Demos]
    end
    class Scripts scripts;

    %% External config input
    Config --> Trainer
    %% Training flow
    Trainer -->|Load Model| Model
    Trainer -->|Load Data| Data
    Trainer -->|Setup| Parallel
    %% Data dependency
    Data -->|Data Pipeline| Model
    %% Inference flow
    Inference -->|Use Model| Model
    %% Parallel dependency
    Parallel -->|Distributed| Trainer
    %% Scripts
    Scripts -->|Execute| Trainer
    Scripts -->|Execute| Inference
```
### 1. Configuration Module (config/)
- **model_config.py**: Defines model structure parameters (layers, heads, dimensions, etc.), managed through `ModelConfig`.
- **train_config.py**: Sets training parameters (batch size, training stages PT/SFT/DPO, optimizers, etc.), loaded by `TrainConfig`.
- **param_config.py**: Manages hyperparameters for training and inference.
- **schedule_config.py**: Controls learning rate strategies (cosine annealing) and training progress.
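As a rough illustration of how such configs are typically carried (a minimal sketch; the field names and defaults are assumptions, not AstrAI's actual `ModelConfig`/`TrainConfig` definitions):

```python
from dataclasses import dataclass

# Hypothetical sketch of the config objects described above.
# All field names and default values are illustrative assumptions.
@dataclass
class ModelConfig:
    num_layers: int = 24       # transformer depth
    num_heads: int = 16        # attention heads
    hidden_size: int = 2048    # model dimension
    vocab_size: int = 64000    # tokenizer vocabulary

@dataclass
class TrainConfig:
    stage: str = "PT"          # "PT", "SFT", or "DPO"
    batch_size: int = 32
    learning_rate: float = 3e-4
```

Keeping each concern in its own dataclass is what lets the trainer and model consume configuration independently.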
### 2. Data Module (data/)
- **dataset.py**: Dataset handling and loading.
- **sampler.py**: Data sampling for different training stages.
- **serialization.py**: Data serialization and deserialization.
- **tokenizer.py**: Text tokenization and encoding.
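Stage-aware sampling might look like the following sketch (the function and argument names are assumptions, not the actual `sampler.py` API): pre-training draws uniformly over the whole corpus, while SFT/DPO restrict sampling to instruction-formatted examples.

```python
import random

def sample_indices(num_examples, stage, instruction_ids, k, seed=0):
    """Pick k example indices for one pass, depending on the training stage."""
    rng = random.Random(seed)
    if stage == "PT":
        pool = list(range(num_examples))   # pre-training: whole corpus
    else:
        pool = list(instruction_ids)       # SFT/DPO: instruction data only
    return rng.sample(pool, k)
```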
### 3. Model Module (model/)
- **transformer.py**: Transformer architecture implementation.
- **module.py**: Model components and layers.
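The core computation inside any `transformer.py` is scaled dot-product attention; a minimal causal, single-head NumPy sketch (not the project's actual implementation) looks like this:

```python
import numpy as np

def attention(q, k, v):
    """Causal scaled dot-product attention for one head.

    q, k, v: (T, d) arrays of queries, keys, and values.
    """
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                   # (T, T) similarities
    # causal mask: position i may attend only to positions <= i
    mask = np.triu(np.ones_like(scores, dtype=bool), 1)
    scores = np.where(mask, -1e9, scores)
    # row-wise softmax
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                              # (T, d) mixed values
```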
### 4. Trainer Module (trainer/)
- **trainer.py**: Main training entry point.
- **train_context.py**: Training context management (model, optimizer, scheduler, metrics).
- **strategy.py**: Training strategies for PT/SFT/DPO stages.
- **schedule.py**: Learning rate scheduler.
- **train_callback.py**: Training callbacks (checkpoint, early stopping, etc.).
- **metric_util.py**: Metrics calculation and tracking.
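The stage switch in `strategy.py` can be pictured as a classic strategy pattern (a schematic sketch under assumed names; the real classes certainly differ, and DPO's preference loss is elided): each stage supplies its own loss while the trainer reuses one loop.

```python
class Strategy:
    """One loss definition per training stage; the training loop is shared."""
    def loss(self, batch):
        raise NotImplementedError

class PTStrategy(Strategy):
    def loss(self, batch):
        return batch["lm_loss"]                       # plain next-token loss

class SFTStrategy(Strategy):
    def loss(self, batch):
        # supervise only the answer tokens, not the prompt
        # ("answer_frac" is a hypothetical field, not an actual batch key)
        return batch["lm_loss"] * batch["answer_frac"]

STRATEGIES = {"PT": PTStrategy, "SFT": SFTStrategy}

def make_strategy(stage):
    return STRATEGIES[stage]()
```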
### 5. Inference Module (inference/)
- **generator.py**: Text generation with various methods (sync, batch, streaming).
- **core.py**: Inference core with KV cache optimization.
- **server.py**: API service for inference.
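The KV-cache idea behind `core.py` can be shown with a toy decode loop (illustrative only; the class and the `step_fn` callback are assumptions, not the project's API): keys/values of past tokens are stored once, so step t attends to the cached entries instead of re-encoding the whole prefix.

```python
class KVCache:
    """Stores one (key, value) pair per generated position."""
    def __init__(self):
        self.keys, self.values = [], []

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)

    def __len__(self):
        return len(self.keys)

def decode_steps(tokens, step_fn):
    """Run step_fn once per token, passing the growing cache to each step."""
    cache = KVCache()
    outputs = []
    for t in tokens:
        y, k, v = step_fn(t, cache)   # attends over len(cache) past entries
        cache.append(k, v)
        outputs.append(y)
    return outputs, cache
```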
### 6. Parallel Module (parallel/)
- **setup.py**: Distributed initialization for multi-GPU/multi-machine training.
- **module.py**: Parallel communication components.
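What `setup.py`'s initialization presumably amounts to can be sketched as follows (the env-var names follow the torchrun convention; the actual function and its `torch.distributed` call may differ):

```python
import os

def setup_parallel():
    """Read the process layout injected by the launcher (e.g. torchrun)."""
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", str(rank)))
    # A real implementation would now initialize the process group, e.g.:
    #   torch.distributed.init_process_group("nccl", rank=rank,
    #                                        world_size=world_size)
    #   torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank
```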
### 7. Scripts (scripts/)
- **tools/**: Main scripts for training and inference (train.py, generate.py, etc.).
- **demo/**: Demo scripts for interactive chat, batch generation, etc.
## 3. Training Process