docs: update documentation

Parent: feaa3fca36
Commit: bf7adb35b3
@@ -95,6 +95,7 @@ flowchart LR

- Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and a linear output head
- Supports weight tying (`tie_weight=True`) to reduce parameter count
- Uses Rotary Position Embedding (RoPE) to inject position information
- Supports loading from safetensors format with automatic model type detection from `config.json`
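To make the savings from `tie_weight=True` concrete, here is a back-of-the-envelope parameter count with purely hypothetical sizes (the real vocabulary size, hidden dimension, and per-layer count come from the model config, not from this sketch):

```python
# Hypothetical sizes, for illustration only
vocab_size, dim = 6400, 512
n_layers, params_per_layer = 8, 3_150_000  # made-up per-layer parameter count

embedding = vocab_size * dim    # token embedding matrix
output_head = vocab_size * dim  # an untied linear head has its own matrix

untied_total = embedding + output_head + n_layers * params_per_layer
tied_total = embedding + n_layers * params_per_layer  # tie_weight=True shares the matrix

saved = untied_total - tied_total
assert saved == vocab_size * dim  # tying saves exactly one vocab x dim matrix
print(saved)  # 3276800
```

The saving is one full `vocab_size × dim` matrix, which is why tying matters most for small models with large vocabularies.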
#### 2.2 Submodules (`module.py`)

- **`RotaryEmbedding`**: Generates the RoPE cos/sin cache
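A minimal pure-Python sketch of what a cos/sin cache looks like (the real `RotaryEmbedding` operates on tensors; the pairing convention and `base` value here are illustrative assumptions):

```python
import math

def rope_cache(max_len, head_dim, base=10000.0):
    """Precompute cos/sin tables, one row per position (sketch)."""
    inv_freq = [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]
    cos = [[math.cos(pos * f) for f in inv_freq] for pos in range(max_len)]
    sin = [[math.sin(pos * f) for f in inv_freq] for pos in range(max_len)]
    return cos, sin

def rotate(vec, cos_row, sin_row):
    """Apply the position rotation pairwise to a query/key vector."""
    out = []
    for i in range(len(vec) // 2):
        x, y = vec[2 * i], vec[2 * i + 1]
        out += [x * cos_row[i] - y * sin_row[i], x * sin_row[i] + y * cos_row[i]]
    return out

cos, sin = rope_cache(max_len=16, head_dim=4)
# Position 0 is the identity rotation.
assert rotate([1.0, 2.0, 3.0, 4.0], cos[0], sin[0]) == [1.0, 2.0, 3.0, 4.0]
```

Because the cache depends only on `max_len` and `head_dim`, it is computed once and reused for every forward pass.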

@@ -137,7 +138,12 @@ flowchart LR
#### 5.1 Inference Engine (`engine.py`)

- **`InferenceEngine`**: Unified inference interface supporting streaming and non-streaming generation
- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
- Manages the task queue (`waiting_queue`, `active_tasks`) and KV cache allocation
- **`GenerationRequest`**: Encapsulates generation parameters (`top_k`, `top_p`, `temperature`, `max_len`, `messages`, etc.)
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into a prompt string using the ChatML format
- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
- Supports continuous batching via the `max_batch_size` and `max_seq_len` parameters
- Uses separate model and tokenizer initialization for flexibility
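The exact template lives in `tokenizer.py`; the sketch below assumes the common ChatML markers (`<|im_start|>` / `<|im_end|>`) to show the shape of the conversion from messages to a prompt string:

```python
def apply_chat_template(messages):
    """Sketch of ChatML-style prompt construction (markers are assumed)."""
    prompt = ""
    for m in messages:
        prompt += f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
    prompt += "<|im_start|>assistant\n"  # cue the model to produce the reply
    return prompt

msgs = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
]
assert "<|im_start|>user\nHello<|im_end|>" in apply_chat_template(msgs)
```

The trailing open `assistant` turn is what makes the model continue the conversation rather than echo it.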

#### 5.2 Scheduler (`scheduler.py`)

- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
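A minimal sketch of such a task object with the four states named above (field names other than the states are illustrative, not the actual `scheduler.py` API):

```python
from enum import Enum, auto

class TaskStatus(Enum):
    PENDING = auto()
    RUNNING = auto()
    FINISHED = auto()
    ABORTED = auto()

class Task:
    """Sketch: a generation task that finishes when EOS is produced."""
    def __init__(self, prompt_ids):
        self.prompt_ids = prompt_ids
        self.output_ids = []
        self.status = TaskStatus.PENDING

    def start(self):
        self.status = TaskStatus.RUNNING

    def append(self, token_id, eos_id):
        self.output_ids.append(token_id)
        if token_id == eos_id:
            self.status = TaskStatus.FINISHED

t = Task([1, 2, 3])
t.start()
t.append(7, eos_id=0)
t.append(0, eos_id=0)
assert t.status is TaskStatus.FINISHED
```

Keeping the state on the task itself is what lets the scheduler drop finished tasks from a batch without touching the others.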

@@ -175,6 +175,17 @@ classDiagram

    +forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor
}

class MLA {
    +int n_heads
    +int n_kv_heads
    +int head_dim
    +Linear q_a_proj, q_b_proj, q_c_proj
    +Linear kv_a_proj, kv_b_proj, kv_c_proj
    +Linear o_proj
    +RMSNorm q_norm, k_norm
    +forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor
}
class MLP {
    +Linear up, gate, down
    +forward(x) Tensor
}

@@ -469,6 +480,7 @@ classDiagram

Transformer --> RotaryEmbedding : uses
Transformer --> Embedding : uses
DecoderBlock --> GQA : uses
DecoderBlock --> MLA : uses
DecoderBlock --> MLP : uses
DecoderBlock --> RMSNorm : uses
BpeTokenizer --> Tokenizer : inherits

@@ -483,8 +495,8 @@ classDiagram

| Module | Key Classes | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig, ModelParameter | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint, DataLoader | Dataset loading and management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, BpeTokenizer, ChatTemplate, BpeTrainer | Tokenizer |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest | Inference service with continuous batching |
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallelism |

@@ -503,6 +515,7 @@ classDiagram

| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Wait/notify mechanism (avoids busy-polling) for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via the decorator pattern |
| **Generator Pattern** | `_StreamingResult`, `_NonStreamingResult` | Event-based result notification for streaming/non-streaming generation |
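The `threading.Event` pattern in the table can be sketched in a few lines; the names `task_queue` and `_task_event` mirror the table, but the loop body is an illustrative simplification of what a scheduler thread does:

```python
import queue
import threading

task_queue: "queue.Queue[str]" = queue.Queue()
_task_event = threading.Event()  # name mirrors the doc; behavior is a sketch

def scheduler_loop(results):
    # Block here until a task arrives, instead of busy-polling the queue.
    _task_event.wait(timeout=2.0)
    while not task_queue.empty():
        results.append(task_queue.get())

results = []
worker = threading.Thread(target=scheduler_loop, args=(results,))
worker.start()
task_queue.put("request-1")
_task_event.set()  # wake the scheduler
worker.join()
assert results == ["request-1"]
```

`Event.wait` parks the thread until `set()` is called, so the scheduler consumes no CPU while the queue is empty.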
### Core Relationships
@@ -540,4 +553,32 @@

$$
L_{\text{DPO}} = -\mathbb{E}_{(x, y_w, y_l) \sim D} \left[ \log \sigma\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)} \right) \right]
$$
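A scalar sketch of this loss for a single (chosen, rejected) pair, using summed log-probabilities; the argument names are illustrative, not the trainer's API:

```python
import math

def dpo_loss(pi_w, pi_l, ref_w, ref_l, beta=0.1):
    """DPO loss for one preference pair from summed log-probs (sketch).

    pi_*  : log-prob of chosen (w) / rejected (l) response under the policy
    ref_* : log-prob under the frozen reference model
    """
    margin = beta * ((pi_w - ref_w) - (pi_l - ref_l))
    return -math.log(1.0 / (1.0 + math.exp(-margin)))  # -log(sigmoid(margin))

# When policy and reference agree, the margin is 0 and the loss is log 2.
loss = dpo_loss(pi_w=-5.0, pi_l=-6.0, ref_w=-5.0, ref_l=-6.0)
assert abs(loss - math.log(2.0)) < 1e-9
```

Training pushes the margin positive, i.e. the policy increases the chosen response's log-prob relative to the reference faster than the rejected one's.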

**GRPO:**

GRPO (Group Relative Policy Optimization) computes advantages from multiple responses to the same prompt, then optimizes a PPO-style clipped objective:

$$
\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon}
$$

Where $r_i$ is the reward of the $i$-th response, and $\mu$ and $\sigma$ are the mean and standard deviation of the group's rewards.

$$
L_{\text{GRPO}} = -\mathbb{E} \left[ \min\left( \frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)} \cdot A, \text{clip}\left(\frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)}, 1-\epsilon, 1+\epsilon\right) \cdot A \right) \right] + \lambda \cdot D_{KL}
$$

In this implementation, the rollout policy coincides with the current policy ($\pi_\theta = \pi_{\text{ref}}$ at sampling time), so the importance ratio equals 1 and the policy loss simplifies to:

$$
L_{\text{policy}} = -\mathbb{E}[A]
$$

The KL divergence term is approximated by a mean squared error between the log-probabilities:

$$
L_{KL} = \lambda \cdot \mathbb{E} \left[ (\log \pi_\theta - \log \pi_{\text{ref}})^2 \right]
$$

The final loss is the sum of the two terms: $L = L_{\text{policy}} + L_{KL}$
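The whole simplified objective fits in a few lines of plain Python; reward values and group size below are made up for illustration:

```python
from statistics import mean, pstdev

def grpo_advantages(rewards, eps=1e-4):
    """Group-normalized advantages: (r_i - mu) / (sigma + eps)."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

def grpo_loss(rewards, log_pi, log_ref, kl_coef=0.01):
    """Simplified loss with ratio 1: -E[A] + kl_coef * E[(log_pi - log_ref)^2]."""
    adv = grpo_advantages(rewards)
    policy = -mean(adv)
    kl = kl_coef * mean((p - q) ** 2 for p, q in zip(log_pi, log_ref))
    return policy + kl

rewards = [1.0, 0.0, 0.5, 0.5]  # one group of 4 responses to the same prompt
adv = grpo_advantages(rewards)
assert abs(sum(adv)) < 1e-6     # normalized advantages sum to ~0

# With identical policy and reference log-probs the KL term vanishes.
loss = grpo_loss(rewards, [-1.0, -1.2, -0.9, -1.1], [-1.0, -1.2, -0.9, -1.1])
assert abs(loss) < 1e-6
```

Because the advantages are normalized within each group, the policy loss averages to zero over a group and gradients come entirely from which responses sit above or below the group mean.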

Through this three-stage progressive training pipeline, the model evolves from a general language foundation into a specialized, well-aligned dialogue model.

@@ -140,3 +140,54 @@

$$
o_n = \sum_j \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}}\right)v_{j}
$$

In the expression above, only $k$ and $v$ carry a position index $j$, while $q$ does not. During incremental decoding, $q$ is therefore computed only for the last generated token, while the $k$ and $v$ of all previous positions are read from the cache. Note that rotary position encoding must be applied before the keys are written to the KV cache; otherwise the cached keys carry incorrect position information.
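A toy single-head sketch of this decoding loop (scalar lists instead of tensors; the vectors and shapes are invented for illustration):

```python
import math

def attend(q, K, V):
    """Attention for one query over all cached keys/values (toy sizes)."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]  # numerically stable softmax
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, V)) / z for j in range(len(V[0]))]

K_cache, V_cache = [], []
for k, v in [([1.0, 0.0], [2.0]), ([0.0, 1.0], [4.0])]:
    # Position encoding would be applied to k here, BEFORE it enters the cache.
    K_cache.append(k)
    V_cache.append(v)
    q = k  # toy query: the current token's vector
    out = attend(q, K_cache, V_cache)

assert len(K_cache) == 2  # the cache grows by one entry per decoded token
```

Each step attends the single new query against the whole cache, which is exactly why only $k$ and $v$ need to be stored.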

### 4. AutoModel Loading

The project now uses the **AutoModel** base class for flexible model loading and saving:
```python
from astrai.model import AutoModel

# Load model from checkpoint
model = AutoModel.from_pretrained("path/to/model")

# Save model to new directory
model.save_pretrained("path/to/save")
```

The Transformer model is registered via the `@AutoModel.register('transformer')` decorator, allowing easy extension to new model types. `from_pretrained` reads `config.json` to determine the model type and loads the weights from safetensors format.
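The registry mechanism can be sketched as follows; this is a simplified stand-in, not the actual `astrai` implementation (in particular, `from_config` replaces the real disk-reading `from_pretrained`):

```python
class AutoModel:
    """Sketch of a decorator-based model registry."""
    _registry = {}

    @classmethod
    def register(cls, name):
        def wrap(model_cls):
            cls._registry[name] = model_cls  # map model_type -> class
            return model_cls
        return wrap

    @classmethod
    def from_config(cls, config):
        # The real code reads config.json and safetensors weights from disk.
        return cls._registry[config["model_type"]](config)

@AutoModel.register("transformer")
class Transformer:
    def __init__(self, config):
        self.config = config

model = AutoModel.from_config({"model_type": "transformer"})
assert isinstance(model, Transformer)
```

Adding a new architecture then only requires defining the class and decorating it with a new name; no loader code changes.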
### 5. Continuous Batching Inference

The inference engine supports **continuous batching** for efficient batch processing:
```python
from astrai.inference import InferenceEngine, GenerationRequest

# Create inference engine with continuous batching
engine = InferenceEngine(
    model=model,
    tokenizer=tokenizer,
    max_batch_size=8,
    max_seq_len=4096,
)

# Use GenerationRequest with messages format
request = GenerationRequest(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ],
    temperature=0.8,
    top_p=0.95,
    top_k=50,
    max_len=1024,
    stream=True,
)

# Generate with streaming
for token in engine.generate_with_request(request):
    print(token, end="", flush=True)
```

Continuous batching allows dynamic batch composition: new requests can join at any time, and completed requests free their batch slots immediately.

@@ -62,6 +62,9 @@

| `--window_size` | Maximum input sequence length | model config max_len |
| `--stride` | Input sequence stride | - |
| `--dpo_beta` | DPO beta value | 0.1 |
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
| `--grpo_kl_coef` | GRPO KL coefficient | 0.01 |
| `--grpo_group_size` | GRPO group size | 4 |
| `--label_smoothing` | Label smoothing parameter | 0.1 |
| `--start_epoch` | Starting epoch | 0 |
| `--start_batch` | Starting batch | 0 |