docs: 更新文档
This commit is contained in:
parent
feaa3fca36
commit
bf7adb35b3
|
|
@ -95,6 +95,7 @@ flowchart LR
|
|||
- Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
|
||||
- Supports weight tying (`tie_weight=True`) to reduce parameter count
|
||||
- Uses Rotary Position Embedding (RoPE) to inject position information
|
||||
- Supports loading from safetensors format with automatic model type detection from `config.json`
|
||||
|
||||
#### 2.2 Submodules (`module.py`)
|
||||
- **`RotaryEmbedding`**: Generates RoPE cos/sin cache
|
||||
|
|
@ -137,7 +138,12 @@ flowchart LR
|
|||
#### 5.1 Inference Engine (`engine.py`)
|
||||
- **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
|
||||
- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
|
||||
- Manages task queue (`waiting_queue`, `active_tasks`) and KV cache allocation
|
||||
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
|
||||
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
|
||||
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
|
||||
- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
|
||||
- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
|
||||
- Uses separate model and tokenizer initialization for flexibility
|
||||
|
||||
#### 5.2 Scheduler (`scheduler.py`)
|
||||
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
|
||||
|
|
|
|||
|
|
@ -175,6 +175,17 @@ classDiagram
|
|||
+forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor
|
||||
}
|
||||
|
||||
class MLA {
|
||||
+int n_heads
|
||||
+int n_kv_heads
|
||||
+int head_dim
|
||||
+Linear q_a_proj, q_b_proj, q_c_proj
|
||||
+Linear kv_a_proj, kv_b_proj, kv_c_proj
|
||||
+Linear o_proj
|
||||
+RMSNorm q_norm, k_norm
|
||||
+forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor
|
||||
}
|
||||
|
||||
class MLP {
|
||||
+Linear up, gate, down
|
||||
+forward(x) Tensor
|
||||
|
|
@ -469,6 +480,7 @@ classDiagram
|
|||
Transformer --> RotaryEmbedding : uses
|
||||
Transformer --> Embedding : uses
|
||||
DecoderBlock --> GQA : uses
|
||||
DecoderBlock --> MLA : uses
|
||||
DecoderBlock --> MLP : uses
|
||||
DecoderBlock --> RMSNorm : uses
|
||||
BpeTokenizer --> Tokenizer : inherits
|
||||
|
|
@ -483,8 +495,8 @@ classDiagram
|
|||
|--------|------------|-------------|
|
||||
| **astrai.config** | ModelConfig, TrainConfig, ModelParameter | Configuration management |
|
||||
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint, DataLoader | Dataset loading and management |
|
||||
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||
| **astrai.tokenize** | Tokenizer, BpeTokenizer | Tokenizer |
|
||||
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||
| **astrai.tokenize** | AutoTokenizer, BpeTokenizer, ChatTemplate, BpeTrainer | Tokenizer |
|
||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest | Inference service with continuous batching |
|
||||
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
||||
|
|
@ -503,6 +515,7 @@ classDiagram
|
|||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
|
||||
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
|
||||
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
|
||||
| **Generator Pattern** | `_StreamingResult`, `_NonStreamingResult` | Event-based result notification for streaming/non-streaming generation |
|
||||
|
||||
### Core Relationships
|
||||
|
||||
|
|
@ -540,4 +553,32 @@ $$
|
|||
L_{\text{DPO}} = -\mathbb{E}_{(x, y_w, y_l) \sim D} \left[ \log \sigma\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)} \right) \right]
|
||||
$$
|
||||
|
||||
**GRPO:**
|
||||
|
||||
GRPO (Group Relative Policy Optimization) computes advantages from multiple responses to the same prompt, then optimizes using a PPO-style clipped objective:
|
||||
|
||||
$$
|
||||
\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon}
|
||||
$$
|
||||
|
||||
Where $r_i$ is the reward for the $i$-th response, $\mu$ and $\sigma$ are the mean and standard deviation of group rewards.
|
||||
|
||||
$$
|
||||
L_{\text{GRPO}} = -\mathbb{E} \left[ \min\left( \frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)} \cdot A, \text{clip}\left(\frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)}, 1-\epsilon, 1+\epsilon\right) \cdot A \right) \right] + \lambda \cdot D_{KL}
|
||||
$$
|
||||
|
||||
In this implementation, an off-policy approach is used ($\pi_\theta = \pi_{\text{ref}}$), and the policy loss simplifies to:
|
||||
|
||||
$$
|
||||
L_{\text{policy}} = -\mathbb{E}[A]
|
||||
$$
|
||||
|
||||
The KL divergence term uses mean squared error approximation:
|
||||
|
||||
$$
|
||||
L_{KL} = \lambda \cdot \mathbb{E} \left[ (\log \pi_\theta - \log \pi_{\text{ref}})^2 \right]
|
||||
$$
|
||||
|
||||
The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$
|
||||
|
||||
Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence.
|
||||
|
|
|
|||
|
|
@ -139,4 +139,55 @@ $$
|
|||
o_n = \sum_j \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}}\right)v_{j}
|
||||
$$
|
||||
|
||||
In the above expression, only k and v have length indices, while $q$ does not. Therefore, during the calculation process, the input of $q$ is fixed as the last token from the previous input, while $k$ and $v$ need to be cached for parts of different lengths. Also, when caching, note that position encoding calculation should be performed before KV cache computation, otherwise there will be position encoding calculation errors.
|
||||
In the above expression, only k and v have length indices, while $q$ does not. Therefore, during the calculation process, the input of $q$ is fixed as the last token from the previous input, while $k$ and $v$ need to be cached for parts of different lengths. Also, when caching, note that position encoding calculation should be performed before KV cache computation, otherwise there will be position encoding calculation errors.
|
||||
|
||||
### 4. AutoModel Loading
|
||||
|
||||
The project now uses the **AutoModel** base class for flexible model loading and saving:
|
||||
|
||||
```python
|
||||
from astrai.model import AutoModel
|
||||
|
||||
# Load model from checkpoint
|
||||
model = AutoModel.from_pretrained("path/to/model")
|
||||
|
||||
# Save model to new directory
|
||||
model.save_pretrained("path/to/save")
|
||||
```
|
||||
|
||||
The Transformer model is registered via `@AutoModel.register('transformer')` decorator, allowing easy extension for new model types. The `from_pretrained` method automatically loads the `config.json` to determine the model type and uses safetensors format for weights.
|
||||
|
||||
### 5. Continuous Batching Inference
|
||||
|
||||
The inference engine supports **continuous batching** for efficient batch processing:
|
||||
|
||||
```python
|
||||
from astrai.inference import InferenceEngine, GenerationRequest
|
||||
|
||||
# Create inference engine with continuous batching
|
||||
engine = InferenceEngine(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
max_batch_size=8,
|
||||
max_seq_len=4096,
|
||||
)
|
||||
|
||||
# Use GenerationRequest with messages format
|
||||
request = GenerationRequest(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
],
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
top_k=50,
|
||||
max_len=1024,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Generate with streaming
|
||||
for token in engine.generate_with_request(request):
|
||||
print(token, end="", flush=True)
|
||||
```
|
||||
|
||||
The continuous batching feature allows dynamic batch composition where new requests can join at any time and completed requests are released immediately.
|
||||
|
|
@ -62,6 +62,9 @@
|
|||
| `--window_size` | Maximum input sequence length | model config max_len |
|
||||
| `--stride` | Input sequence stride | - |
|
||||
| `--dpo_beta` | DPO beta value | 0.1 |
|
||||
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
|
||||
| `--grpo_kl_coef` | GRPO KL coefficient | 0.01 |
|
||||
| `--grpo_group_size` | GRPO group size | 4 |
|
||||
| `--label_smoothing` | Label smoothing parameter | 0.1 |
|
||||
| `--start_epoch` | Starting epoch | 0 |
|
||||
| `--start_batch` | Starting batch | 0 |
|
||||
|
|
|
|||
Loading…
Reference in New Issue