From c0e0e6afd9d5a0c4e1a5798768f1c5765593cdcd Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Fri, 3 Apr 2026 22:11:19 +0800
Subject: [PATCH] docs: update documentation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
assets/docs/design.md | 379 ++++++++++++++++++++++--------------------
1 file changed, 202 insertions(+), 177 deletions(-)
diff --git a/assets/docs/design.md b/assets/docs/design.md
index b9fa3d3..c4f0952 100644
--- a/assets/docs/design.md
+++ b/assets/docs/design.md
@@ -2,210 +2,235 @@
There are many large language models on the market today, such as GPT, LLaMA, and others, with tens of billions or even hundreds of billions of parameters. Honestly, though, these models have such high hardware requirements that they are out of reach for ordinary developers. So I asked myself: **Can we create a model that is both useful and can run on an ordinary computer?** This is also what many people want right now - a locally deployable AI project that keeps data fully private while maintaining a reasonable level of intelligence.
-Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, supporting dialogue, text generation, RAG retrieval, and the training code is open source!
+Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, supporting dialogue and text generation, with fully open-source training code!
## 2. System Architecture
-The system is divided into the following modules:
-
```mermaid
-flowchart TB
- %% Style definitions
- classDef config fill:#e1f5fe,stroke:#01579b,stroke-width:2px;
- classDef data fill:#e8f5e8,stroke:#1b5e20,stroke-width:2px;
- classDef model fill:#fff3e0,stroke:#e65100,stroke-width:2px;
- classDef trainer fill:#f3e5f5,stroke:#4a148c,stroke-width:2px;
- classDef inference fill:#fce4ec,stroke:#880e4f,stroke-width:2px;
- classDef parallel fill:#e0f2f1,stroke:#004d40,stroke-width:2px;
- classDef scripts fill:#fffbe6,stroke:#f57f17,stroke-width:2px;
+classDiagram
+ %% Configuration Classes
+ class ModelConfig {
+ +int vocab_size
+ +int dim
+ +int n_layers
+ +float norm_eps
+ +int dim_ffn
+ +int max_len
+ +float rope_theta
+ +int n_heads
+ +int n_kv_heads
+ +bool use_qk_norm
+ +bool use_gated_attention
+ +load(config_path) ModelConfig
+ +save(config_path)
+ }
- subgraph Config["Config Module (config/)"]
- direction LR
- C1[model_config.py
Model Architecture]
- C2[train_config.py
Training Params]
- C3[param_config.py
Hyperparameters]
- end
-
- subgraph Data["Data Module (data/)"]
- direction LR
- D1[dataset.py
Dataset]
- D2[sampler.py
Sampler]
- D3[serialization.py
Serialization]
- D4[tokenizer.py
Tokenizer]
- end
-
- subgraph Model["Model Module (model/)"]
- direction LR
- M1[transformer.py
Transformer Architecture]
- M2[module.py
Model Components]
- end
-
- subgraph Trainer["Trainer Module (trainer/)"]
- direction TB
- T1[trainer.py
Trainer Entry]
- T2[train_context.py
Training Context]
- T3[strategy.py
Training Strategy]
- T4[schedule.py
LR Scheduler]
- T5[train_callback.py
Callbacks]
- T6[metric_util.py
Metrics]
- end
-
- subgraph Inference["Inference Module (inference/)"]
- direction LR
- I1[generator.py
Text Generation]
- I2[core.py
Inference Core]
- I3[server.py
API Service]
- end
-
- subgraph Parallel["Parallel Module (parallel/)"]
- direction LR
- P1[setup.py
Parallel Init]
- P2[module.py
Parallel Components]
- end
-
- subgraph Scripts["Scripts (scripts/)"]
- direction LR
- S1[tools/
Train & Inference]
- S2[demo/
Demos]
- end
+ class TrainConfig {
+ +nn.Module model
+ +str strategy
+ +Dataset dataset
+ +Callable optimizer_fn
+ +Callable scheduler_fn
+ +int n_epoch
+ +int batch_size
+ +int accumulation_steps
+ +float max_grad_norm
+ +str ckpt_dir
+ +int ckpt_interval
+ +int nprocs
+ +str backend
+ +validate()
+ }
- %% External config input
- Config --> Trainer
-
- %% Training flow
- Trainer -->|Load Model| Model
- Trainer -->|Load Data| Data
- Trainer -->|Setup| Parallel
-
- %% Inference flow
- Inference -->|Use Model| Model
- Inference -->|Use| generator
-
- %% Data dependency
- Data -->|Data Pipeline| Model
-
- %% Parallel dependency
- Parallel -->|Distributed| Trainer
-
- %% Scripts
- Scripts -->|Execute| Trainer
- Scripts -->|Execute| Inference
+ %% Data Classes
+ class Dataset {
+ +__len__()
+ +__getitem__()
+ }
+
+ class Checkpoint {
+ +dict state_dict
+ +int epoch
+ +int iteration
+ }
+
+ class Tokenizer {
+ +encode(text) List[int]
+ +decode(ids) str
+ }
+
+ %% Model Classes
+ class Transformer {
+ +forward(input_ids, mask) Dict
+ }
+
+ %% Trainer Classes
+ class Trainer {
+ +TrainConfig train_config
+ +List~TrainCallback~ callbacks
+ +train()
+ +_build_context() TrainContext
+ }
+
+ class TrainContext {
+ +nn.Module model
+ +BaseStrategy strategy
+ +DataLoader dataloader
+ +Optimizer optimizer
+ +LRScheduler scheduler
+ +Checkpoint checkpoint
+ +int epoch
+ +int iteration
+ }
+
+ class TrainContextBuilder {
+ +TrainConfig config
+ +with_checkpoint(Checkpoint) TrainContextBuilder
+ +with_dataloader() TrainContextBuilder
+ +with_strategy() TrainContextBuilder
+ +build() TrainContext
+ }
+
+ class BaseStrategy {
+ +nn.Module model
+ +str device
+ +compute_loss(batch) Tensor
+ }
+
+ class StrategyFactory {
+ +frozenset SUPPORTED_STRATEGIES
+ +Dict STRATEGY_MAP
+ +register(name) decorator
+ +create(model, train_type, device) BaseStrategy
+ +available_strategies() list
+ }
+
+ class SEQStrategy {
+ +float label_smoothing
+ +compute_loss(batch) Tensor
+ }
+
+ class SFTStrategy {
+ +float label_smoothing
+ +compute_loss(batch) Tensor
+ }
+
+ class DPOStrategy {
+ +nn.Module ref_model
+ +float beta
+ +str reduction
+ +compute_loss(batch) Tensor
+ }
+
+ class GRPOStrategy {
+ +nn.Module ref_model
+ +float clip_eps
+ +float kl_coef
+ +int group_size
+ +compute_loss(batch) Tensor
+ }
+
+ class TrainCallback {
+ +on_train_begin(trainer)
+ +on_train_end(trainer)
+ +on_epoch_begin(epoch, trainer)
+ +on_epoch_end(epoch, trainer)
+ +on_batch_begin(batch, trainer)
+ +on_batch_end(batch, trainer)
+ }
+
+ class Schedule {
+ +step()
+ }
+
+ %% Inference Classes
+ class Generator {
+ +generate(prompt, config) str
+ +generate_batch(prompts, config) List[str]
+ +stream_generate(prompt, config) Generator
+ }
+
+ class InferenceCore {
+ +forward(input_ids) Dict
+ +apply_sampling_strategies()
+ }
+
+ class Server {
+ +start()
+ +predict(request)
+ }
+
+ %% Parallel Classes
+ class ParallelSetup {
+ +spawn_parallel_fn(fn, nprocs)
+ }
+
+ %% Relationships
+ TrainConfig --> ModelConfig : contains
+ TrainConfig --> Dataset : uses
+ TrainConfig --> Transformer : uses
+ Trainer --> TrainConfig : configures
+ Trainer --> TrainContextBuilder : builds
+ Trainer --> TrainCallback : manages
+ TrainContextBuilder --> TrainContext : creates
+ TrainContext --> Checkpoint : manages
+ StrategyFactory ..> BaseStrategy : creates
+ BaseStrategy <|-- SEQStrategy
+ BaseStrategy <|-- SFTStrategy
+ BaseStrategy <|-- DPOStrategy
+ BaseStrategy <|-- GRPOStrategy
+ TrainContext --> BaseStrategy : uses
+ Generator --> InferenceCore : uses
+ Generator --> Transformer : uses
+ Server --> Generator : uses
+ ParallelSetup --> Trainer : enables
+ TrainConfig --> StrategyFactory : selects
+ TrainCallback <|-- CheckpointCallback
+ TrainCallback <|-- MetricLoggerCallback
+ TrainCallback <|-- SchedulerCallback
+ TrainContext --> Schedule : uses
```
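The configuration class at the top of the diagram can be sketched as a plain dataclass with JSON load/save. The field names and method signatures follow the diagram; the JSON on-disk format and the default values here are assumptions for illustration, not the project's actual code.

```python
# Sketch of ModelConfig from the class diagram. Defaults and the JSON
# format are illustrative assumptions, not AstrAI's real implementation.
from dataclasses import dataclass, asdict
import json
import os
import tempfile


@dataclass
class ModelConfig:
    vocab_size: int = 32000
    dim: int = 2048
    n_layers: int = 16
    norm_eps: float = 1e-5
    dim_ffn: int = 5632
    max_len: int = 2048
    rope_theta: float = 10000.0
    n_heads: int = 16
    n_kv_heads: int = 4
    use_qk_norm: bool = True
    use_gated_attention: bool = False

    @classmethod
    def load(cls, config_path: str) -> "ModelConfig":
        # Read a JSON file and build a config from its key/value pairs.
        with open(config_path) as f:
            return cls(**json.load(f))

    def save(self, config_path: str) -> None:
        # Serialize all fields to JSON.
        with open(config_path, "w") as f:
            json.dump(asdict(self), f, indent=2)


# Quick round-trip: save, reload, and compare.
_path = os.path.join(tempfile.gettempdir(), "astraai_model_config.json")
cfg = ModelConfig(dim=1024)
cfg.save(_path)
loaded = ModelConfig.load(_path)
```

A dataclass keeps the config declarative, and the round trip through JSON makes `save`/`load` symmetric, which is what the diagram's paired `load(config_path)`/`save(config_path)` methods suggest.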
-### 1. Configuration Module (config/)
-- **model_config.py**: Defines model structure parameters (layers, heads, dimensions, etc.), managed through `ModelConfig`.
-- **train_config.py**: Sets training parameters (batch size, training stages SEQ/SFT/GRPO/DPO, optimizers, etc.), loaded by `TrainConfig`.
-- **param_config.py**: Manages hyperparameters for training and inference.
+### Design Pattern Summary
-### 2. Data Module (data/)
-- **dataset.py**: Dataset handling and loading.
-- **sampler.py**: Data sampling for different training stages.
-- **serialization.py**: Data serialization and deserialization, checkpoint management.
-- **tokenizer.py**: Text tokenization and encoding.
+| Pattern | Classes | Purpose |
+|---------|---------|---------|
+| **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible switching between training strategies; supports SEQ/SFT/DPO/GRPO |
+| **Builder** | `TrainContextBuilder` | Builds the training context through chained calls, initializing each component step by step |
+| **Factory** | `StrategyFactory` | Decorator-based registration that creates training strategies dynamically |
+| **Observer** | `TrainCallback` | Callback hooks for monitoring training (checkpointing, early stopping, metrics) |
+| **Singleton** | `TrainContext` | Global state management for the training process |
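The Strategy and Factory rows combine into one mechanism: strategies register themselves with the factory via a decorator, and the factory instantiates them by name. A minimal sketch, using the class and method names from the diagram but with illustrative bodies:

```python
# Sketch of the Strategy + Factory combination. Names follow the class
# diagram; the method bodies are assumptions, not AstrAI's actual code.
from typing import Callable, Dict


class BaseStrategy:
    def __init__(self, model, device: str = "cpu"):
        self.model = model
        self.device = device

    def compute_loss(self, batch):
        raise NotImplementedError


class StrategyFactory:
    STRATEGY_MAP: Dict[str, type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[type], type]:
        # Decorator that records a strategy class under a string key.
        def decorator(strategy_cls: type) -> type:
            cls.STRATEGY_MAP[name] = strategy_cls
            return strategy_cls
        return decorator

    @classmethod
    def create(cls, model, train_type: str, device: str = "cpu") -> BaseStrategy:
        if train_type not in cls.STRATEGY_MAP:
            raise ValueError(f"unknown strategy: {train_type}")
        return cls.STRATEGY_MAP[train_type](model, device)

    @classmethod
    def available_strategies(cls) -> list:
        return sorted(cls.STRATEGY_MAP)


@StrategyFactory.register("sft")
class SFTStrategy(BaseStrategy):
    def compute_loss(self, batch):
        # Placeholder: a real SFT strategy would mask prompt tokens and
        # compute cross-entropy over the response tokens only.
        return 0.0
```

With this shape, adding a new training stage (e.g. a DPO variant) only requires defining the class and decorating it; no factory code changes.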
-### 3. Model Module (model/)
-- **transformer.py**: Transformer architecture implementation.
-- **module.py**: Model components and layers.
+### Core Relationships
-### 4. Trainer Module (trainer/)
-- **trainer.py**: Main training entry point.
-- **train_context.py**: Training context management (model, optimizer, scheduler, metrics).
-- **strategy.py**: Training strategies for SEQ/SFT/GRPO/DPO stages via `StrategyFactory`.
-- **schedule.py**: Learning rate scheduler implementation (cosine, SGDR, etc.).
-- **train_callback.py**: Training callbacks (checkpoint, early stopping, etc.).
-- **metric_util.py**: Metrics calculation and tracking.
-
-### 5. Inference Module (inference/)
-- **generator.py**: Text generation with various methods (sync, batch, streaming).
-- **core.py**: Inference core with KV cache optimization.
-- **server.py**: API service for inference (FastAPI + Uvicorn).
-
-### 6. Parallel Module (parallel/)
-- **setup.py**: Distributed initialization for multi-GPU/multi-machine training.
-- **module.py**: Parallel communication components.
-
-### 7. Scripts (scripts/)
-- **tools/**: Main scripts for training and inference (train.py, generate.py, etc.).
-- **demo/**: Demo scripts for interactive chat, batch generation, etc.
+1. **Configuration → Training**: `TrainConfig` contains `ModelConfig` and holds references to the model, dataset, optimizer, and other components
+2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, with `BaseStrategy` computing the loss
+3. **Strategy Selection**: `StrategyFactory` creates the matching strategy instance based on `train_type`
+4. **Inference Flow**: `Server` → `Generator` → `InferenceCore` → `Transformer`
+5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
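The training-flow relationship (`Trainer` → `TrainContextBuilder` → `TrainContext`) can be sketched as a chained builder. The `with_*` method names mirror the diagram; everything inside them is an illustrative assumption:

```python
# Sketch of the Builder flow. Method names match the diagram; the
# internals are simplified stand-ins, not the project's real code.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class TrainContext:
    dataloader: Any = None
    strategy: Any = None
    checkpoint: Optional[dict] = None
    epoch: int = 0
    iteration: int = 0


class TrainContextBuilder:
    def __init__(self, config: dict):
        self.config = config
        self._ctx = TrainContext()

    def with_checkpoint(self, checkpoint: dict) -> "TrainContextBuilder":
        # Resume bookkeeping from a saved checkpoint.
        self._ctx.checkpoint = checkpoint
        self._ctx.epoch = checkpoint.get("epoch", 0)
        self._ctx.iteration = checkpoint.get("iteration", 0)
        return self

    def with_dataloader(self) -> "TrainContextBuilder":
        # Stand-in for wrapping the dataset in a DataLoader.
        self._ctx.dataloader = iter(self.config.get("dataset", []))
        return self

    def with_strategy(self) -> "TrainContextBuilder":
        # Stand-in for StrategyFactory.create(...).
        self._ctx.strategy = self.config.get("strategy", "seq")
        return self

    def build(self) -> TrainContext:
        return self._ctx


# Each with_* call initializes one component, then build() hands the
# assembled context to the Trainer.
ctx = (TrainContextBuilder({"strategy": "sft", "dataset": [1, 2, 3]})
       .with_checkpoint({"epoch": 2, "iteration": 500})
       .with_dataloader()
       .with_strategy()
       .build())
```

The benefit of the chained form is that optional steps (e.g. skipping `with_checkpoint` for a fresh run) compose cleanly without a combinatorial set of constructors.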
## 3. Training Process
-The common training process for large language models (LLM) typically includes three stages: **Pre-training (PT)**, **Supervised Fine-Tuning (SFT)**, and **Reinforcement Learning from Human Feedback (RLHF)**. This system is designed to support seamless end-to-end flow, achieving efficient switching and state management of different training stages through modular strategies, ensuring the model's capabilities gradually evolve from general language understanding to human-preference-aligned dialogue and instruction execution.
+The common training process for large language models (LLMs) typically includes three stages: **Pre-training (SEQ)**, **Supervised Fine-Tuning (SFT)**, and **preference alignment with human feedback (DPO/GRPO)**. The system supports a seamless end-to-end flow, using modular strategies to switch between stages and manage their state.
-### **2.1 Pre-training Stage (SEQ/PT)**
+### Core Formulas
-The pre-training stage aims to build the model's foundational language capabilities and general knowledge representation. This stage performs self-supervised learning on large-scale, unlabeled corpora (typically covering hundreds of GB to TB of text data). The model architecture is based on the standard Transformer Decoder, trained through masked language modeling objectives (such as causal language modeling), enabling the model to learn vocabulary, grammar, semantics, and world knowledge embedded in text.
-
-**Core Formula: Causal Language Modeling**
+**Pre-training (SEQ):**
$$
L_{\text{PT}} = - \sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
$$
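The loss above is just the negative sum of the log-probabilities the model assigns to each next token. A tiny numeric illustration, where the "model" is a fixed probability table rather than a real network:

```python
# Numeric illustration of L_PT = -sum_t log P(x_t | x_<t). The
# probabilities below are made up purely to show the arithmetic.
import math

# P(x_t | x_<t): probability the model gave to the token that actually
# appeared at each position t of a length-4 sequence (T = 4).
next_token_probs = [0.5, 0.25, 0.8, 0.1]

loss = -sum(math.log(p) for p in next_token_probs)
```

Note how the rare token (probability 0.1) dominates the sum: the causal LM objective pushes the model hardest on the tokens it currently finds most surprising.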
-**Symbol Description:**
-
-- $T$: Sequence length
-- $x_t$: The $t$-th token in the sequence
-- $x_{