## 1. Why I Created This Project
There are many large language models on the market today, such as GPT and LLaMA, with tens or even hundreds of billions of parameters. But honestly, these models have extremely high hardware requirements, putting them out of reach for ordinary developers. I thought: **Can we build a model that is both useful and runs on an ordinary computer?** This is also what many people want right now: a locally deployable AI project that is fully private while retaining a reasonable level of intelligence.
Thus, the AstrAI project was born: 1B parameters, Chinese-English bilingual, supporting dialogue, text generation, and RAG retrieval, with the training code fully open source!
## 2. System Architecture
The system is divided into the following modules:
```mermaid
flowchart TB
%% Style definitions
classDef config fill:#e1f5fe,stroke:#01579b,stroke-width:2px;
classDef data fill:#e8f5e8,stroke:#1b5e20,stroke-width:2px;
classDef model fill:#fff3e0,stroke:#e65100,stroke-width:2px;
classDef trainer fill:#f3e5f5,stroke:#4a148c,stroke-width:2px;
classDef inference fill:#fce4ec,stroke:#880e4f,stroke-width:2px;
classDef parallel fill:#e0f2f1,stroke:#004d40,stroke-width:2px;
classDef scripts fill:#fffbe6,stroke:#f57f17,stroke-width:2px;
subgraph Config["Config Module (config/)"]
direction LR
C1[model_config.py<br/>Model Architecture]
C2[train_config.py<br/>Training Params]
C3[param_config.py<br/>Hyperparameters]
C4[schedule_config.py<br/>Scheduler Config]
end
subgraph Data["Data Module (data/)"]
direction LR
D1[dataset.py<br/>Dataset]
D2[sampler.py<br/>Sampler]
D3[serialization.py<br/>Serialization]
D4[tokenizer.py<br/>Tokenizer]
end
subgraph Model["Model Module (model/)"]
direction LR
M1[transformer.py<br/>Transformer Architecture]
M2[module.py<br/>Model Components]
end
subgraph Trainer["Trainer Module (trainer/)"]
direction TB
T1[trainer.py<br/>Trainer Entry]
T2[train_context.py<br/>Training Context]
T3[strategy.py<br/>Training Strategy]
T4[schedule.py<br/>LR Scheduler]
T5[train_callback.py<br/>Callbacks]
T6[metric_util.py<br/>Metrics]
end
subgraph Inference["Inference Module (inference/)"]
direction LR
I1[generator.py<br/>Text Generation]
I2[core.py<br/>Inference Core]
I3[server.py<br/>API Service]
end
subgraph Parallel["Parallel Module (parallel/)"]
direction LR
P1[setup.py<br/>Parallel Init]
P2[module.py<br/>Parallel Components]
end
subgraph Scripts["Scripts (scripts/)"]
direction LR
S1[tools/<br/>Train & Inference]
S2[demo/<br/>Demos]
end
%% External config input
Config --> Trainer
%% Training flow
Trainer -->|Load Model| Model
Trainer -->|Load Data| Data
Trainer -->|Setup| Parallel
%% Inference flow
Inference -->|Use Model| Model
I3 -->|Use| I1
%% Data dependency
Data -->|Data Pipeline| Model
%% Parallel dependency
Parallel -->|Distributed| Trainer
%% Scripts
Scripts -->|Execute| Trainer
Scripts -->|Execute| Inference
```
### 1. Configuration Module (config/)
- **model_config.py**: Defines model structure parameters (layers, heads, dimensions, etc.), managed through `ModelConfig`.
- **train_config.py**: Sets training parameters (batch size, training stages PT/SFT/DPO, optimizers, etc.), loaded by `TrainConfig`.
- **param_config.py**: Manages hyperparameters for training and inference.
- **schedule_config.py**: Controls learning rate strategies (cosine annealing) and training progress.
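To make the config layer concrete, `ModelConfig` might look like the dataclass sketched below. The field names here are illustrative assumptions, not the project's actual attributes:

```python
from dataclasses import dataclass

# Hypothetical sketch of a ModelConfig dataclass; field names and defaults
# are made up for illustration, not taken from the project.
@dataclass
class ModelConfig:
    vocab_size: int = 32000
    n_layers: int = 24
    n_heads: int = 16
    d_model: int = 1536
    max_seq_len: int = 2048

    @property
    def head_dim(self) -> int:
        # Each attention head gets an equal slice of the model dimension.
        return self.d_model // self.n_heads

cfg = ModelConfig(n_heads=12, d_model=1536)
print(cfg.head_dim)  # 1536 // 12 = 128
```

Centralizing structure parameters like this means the trainer and inference code can both build the same architecture from one object.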
### 2. Data Module (data/)
- **dataset.py**: Dataset handling and loading.
- **sampler.py**: Data sampling for different training stages.
- **serialization.py**: Data serialization and deserialization.
- **tokenizer.py**: Text tokenization and encoding.
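To illustrate the tokenizer's encode/decode contract, here is a deliberately tiny word-level tokenizer. The project's `tokenizer.py` presumably wraps a real subword (e.g. BPE) tokenizer; this sketch only shows the round trip:

```python
# Toy word-level tokenizer: maps whitespace-separated words to integer ids.
# Purely illustrative -- not the project's actual tokenizer implementation.
class ToyTokenizer:
    def __init__(self, corpus: str):
        words = sorted(set(corpus.split()))
        self.word_to_id = {w: i for i, w in enumerate(words)}
        self.id_to_word = {i: w for w, i in self.word_to_id.items()}

    def encode(self, text: str) -> list[int]:
        return [self.word_to_id[w] for w in text.split()]

    def decode(self, ids: list[int]) -> str:
        return " ".join(self.id_to_word[i] for i in ids)

tok = ToyTokenizer("hello world hello model")
ids = tok.encode("hello model")
assert tok.decode(ids) == "hello model"  # encode/decode round trip
```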
### 3. Model Module (model/)
- **transformer.py**: Transformer architecture implementation.
- **module.py**: Model components and layers.
### 4. Trainer Module (trainer/)
- **trainer.py**: Main training entry point.
- **train_context.py**: Training context management (model, optimizer, scheduler, metrics).
- **strategy.py**: Training strategies for PT/SFT/DPO stages.
- **schedule.py**: Learning rate scheduler.
- **train_callback.py**: Training callbacks (checkpoint, early stopping, etc.).
- **metric_util.py**: Metrics calculation and tracking.
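The stage-switching idea behind `strategy.py` can be sketched as a classic strategy pattern: each training stage exposes the same interface, and the trainer picks one by name. Class names and batch keys below are hypothetical:

```python
from abc import ABC, abstractmethod

# Hedged sketch of per-stage training strategies; the class names and the
# batch dictionary keys are assumptions for illustration only.
class TrainStrategy(ABC):
    @abstractmethod
    def compute_loss(self, batch: dict) -> float: ...

class PretrainStrategy(TrainStrategy):
    def compute_loss(self, batch: dict) -> float:
        # Plain causal LM loss over the whole sequence (placeholder value).
        return batch["lm_loss"]

class SFTStrategy(TrainStrategy):
    def compute_loss(self, batch: dict) -> float:
        # Loss counted only on response tokens; prompt tokens are masked out.
        return batch["lm_loss"] * batch["response_mask_ratio"]

STRATEGIES = {"PT": PretrainStrategy, "SFT": SFTStrategy}

def build_strategy(stage: str) -> TrainStrategy:
    return STRATEGIES[stage]()

loss = build_strategy("SFT").compute_loss({"lm_loss": 2.0, "response_mask_ratio": 0.5})
print(loss)  # 1.0
```

Keeping each stage behind one interface is what makes switching between PT, SFT, and DPO a configuration change rather than a code change.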
### 5. Inference Module (inference/)
- **generator.py**: Text generation with various methods (sync, batch, streaming).
- **core.py**: Inference core with KV cache optimization.
- **server.py**: API service for inference.
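The KV-cache optimization mentioned for `core.py` rests on one idea: during decoding, the keys and values of past tokens are stored once, and each step only computes and appends the new token's entries instead of re-encoding the whole prefix. A minimal sketch (real implementations store per-layer attention tensors, not scalars):

```python
# Minimal illustration of a KV cache: append-only storage of per-token
# keys/values so earlier positions are never recomputed during decoding.
class KVCache:
    def __init__(self):
        self.keys: list[float] = []
        self.values: list[float] = []

    def append(self, k: float, v: float) -> None:
        # One decoding step contributes exactly one new key/value pair.
        self.keys.append(k)
        self.values.append(v)

    def __len__(self) -> int:
        return len(self.keys)

cache = KVCache()
for step in range(3):                      # three decoding steps
    cache.append(k=step * 0.1, v=step * 0.2)  # K/V for the new token only
print(len(cache))  # 3 cached positions
```

This turns per-token decoding cost from quadratic in the prefix length to roughly linear, which is why it matters so much for interactive generation.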
### 6. Parallel Module (parallel/)
- **setup.py**: Distributed initialization for multi-GPU/multi-machine training.
- **module.py**: Parallel communication components.
### 7. Scripts (scripts/)
- **tools/**: Main scripts for training and inference (train.py, generate.py, etc.).
- **demo/**: Demo scripts for interactive chat, batch generation, etc.
## 3. Training Process
The typical training pipeline for large language models (LLMs) consists of three stages: **Pre-training (PT)**, **Supervised Fine-Tuning (SFT)**, and **Reinforcement Learning from Human Feedback (RLHF)**. This system supports the full pipeline end to end: modular strategies handle switching and state management between stages, so the model's capabilities evolve step by step from general language understanding to human-preference-aligned dialogue and instruction following.
### 1. Pre-training Stage (PT)
The pre-training stage builds the model's foundational language ability and general knowledge representation. It performs self-supervised learning on large-scale unlabeled corpora (typically hundreds of GB to several TB of text). The architecture is a standard Transformer decoder trained with a causal (autoregressive) language modeling objective, through which the model learns the vocabulary, grammar, semantics, and world knowledge embedded in text.
**Core Formula: Causal Language Modeling**
$$
L_{\text{PT}} = - \sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
$$
**Symbol Description:**
- $T$: Sequence length
- $x_t$: The $t$-th token in the sequence
- $x_{\lt t}$: All tokens preceding position $t$
- $\theta$: Model parameters
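The loss above can be checked numerically: it is simply the negative sum of the log-probabilities the model assigns to the true next tokens. A small sketch with made-up per-step probabilities (in real training these come from the model's softmax output):

```python
import math

# Numerical sketch of L_PT: negative sum of log-probabilities of each
# true next token. The probabilities here are invented for illustration.
def causal_lm_loss(next_token_probs: list[float]) -> float:
    # next_token_probs[t] = P(x_t | x_<t; theta)
    return -sum(math.log(p) for p in next_token_probs)

probs = [0.5, 0.25, 0.125]   # model's probability of each true token
loss = causal_lm_loss(probs)
print(round(loss, 4))  # -(ln 0.5 + ln 0.25 + ln 0.125) = 6 ln 2 ≈ 4.1589
```

Confident predictions (probabilities near 1) contribute almost nothing, while each unlikely true token adds a large penalty, which is exactly what drives the model toward better next-token prediction.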