## Model Introduction

### 1. Model Architecture

This model uses the Transformer architecture with a grouped-query attention (GQA) mechanism (q_head=24, kv_head=4), which saves KV-cache memory compared to traditional MHA. The model is built by stacking 32 Transformer blocks, for a total of 1.0 billion parameters. The Transformer is used autoregressively: it attends over all preceding tokens to produce the probability distribution of the next token.
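
To make the memory saving concrete, the sketch below shows grouped-query attention with these head counts in plain PyTorch; the tensor names and shapes are illustrative and are not taken from the project's code:

```python
import torch

# Illustrative sketch (not the project's actual code): with GQA, the 24 query
# heads share only 4 key/value heads, so each KV head serves 6 query heads.
bsz, seq_len, head_dim = 2, 16, 64
n_q_heads, n_kv_heads = 24, 4
group_size = n_q_heads // n_kv_heads  # 6 query heads per KV head

q = torch.randn(bsz, n_q_heads, seq_len, head_dim)
k = torch.randn(bsz, n_kv_heads, seq_len, head_dim)
v = torch.randn(bsz, n_kv_heads, seq_len, head_dim)

# Expand K/V so that each group of query heads sees its shared KV head.
k = k.repeat_interleave(group_size, dim=1)  # [bsz, 24, seq_len, head_dim]
v = v.repeat_interleave(group_size, dim=1)

scores = q @ k.transpose(-2, -1) / head_dim**0.5
out = torch.softmax(scores, dim=-1) @ v     # [bsz, 24, seq_len, head_dim]

# Only the 4-head K/V tensors need to be cached, cutting KV-cache memory by 6x.
```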

```mermaid
flowchart TB
    subgraph Layers["Transformer Layers"]
        direction TB
        A[Input Embedding] --> B["Transformer Block<br/>Layer 1"]
        B --> C["Transformer Block<br/>Layer ..."]
        C --> D["Transformer Block<br/>Layer 32"]
        D --> E[RMSNorm]
        E --> F[Linear]
        F --> G[SoftMax]
    end

    subgraph TransformerBlock["Transformer Block"]
        direction TB
        H[x] --> I[RMSNorm]
        I --> J[Linear → Q/K/V]
        J --> K[Q]
        J --> L[K]
        J --> M[V]
        K --> N[RoPE]
        L --> O[RoPE]
        N --> P["Q @ K^T / sqrt(d)"]
        O --> P
        P --> Q[Masked SoftMax]
        Q --> R[S @ V]
        M --> R
        R --> S[Linear]
        S --> T[+]
        H --> T
        T --> U[RMSNorm]
        U --> V[Linear]
        V --> W[SiLU]
        V --> X[×]
        W --> X
        X --> Y[Linear]
        Y --> Z[+]
        T --> Z
        Z --> AA[x']
    end

    classDef main fill:#e6f3ff,stroke:#0066cc;
    classDef block fill:#fff2e6,stroke:#cc6600;
    class Layers main;
    class TransformerBlock block;
```

What is an autoregressive model? After a sentence is split into tokens, the model predicts the probability distribution of the next token: given the context (the sequence of tokens that have already appeared), it assigns a probability to each candidate for the next token.

#### 1. Autoregression

In autoregressive modeling, when a sentence is tokenized into a sequence of tokens, the model learns to predict what comes next. Given a sequence of tokens as input, the model calculates a probability distribution over all possible next tokens. This distribution tells us how likely each potential next token is, given the current context.

For instance, if the input sequence contains tokens representing a question, the model might predict that certain response tokens have higher probabilities than others. The sampling process then selects one token from this distribution (controlled by parameters like top_k, top_p, and temperature) to serve as the next token in the sequence.

Once a token is selected, it is appended to the input sequence, and the model repeats this process. The updated sequence is fed back into the model to predict the next token. This iterative process continues until either a special end-of-sequence token is generated or the maximum sequence length is reached. These stopping conditions are essential; without them, the model would keep generating tokens indefinitely, eventually exhausting available memory.
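
A minimal decode-loop sketch with temperature and top-k sampling (top-p is omitted for brevity); `model` and `tokenizer` stand in for the project's actual model and tokenizer objects:

```python
import torch

def sample_next_token(logits, temperature=0.8, top_k=50):
    # Illustrative sampler: temperature scaling followed by top-k sampling.
    logits = logits / temperature
    topk_vals, topk_idx = torch.topk(logits, top_k)
    probs = torch.softmax(topk_vals, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return topk_idx[choice].item()

# `model` and `tokenizer` are placeholders for the trained model and its tokenizer.
input_ids = tokenizer.encode("Hello")
eos_token_id = tokenizer.eos_token_id
for _ in range(1024):                                  # maximum number of new tokens
    logits = model(torch.tensor([input_ids]))[0, -1]   # logits for the next position
    next_id = sample_next_token(logits)
    if next_id == eos_token_id:                        # stop on the end-of-sequence token
        break
    input_ids.append(next_id)                          # feed the extended sequence back in
```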

#### 2. Causal Mask

Transformers use the attention mechanism. The input shape is generally [bsz, seq_len] and the output shape is [bsz, seq_len, n_dim]. To train next-token prediction, the inputs and targets are offset by one position, so during training the targets are simply the input sequence shifted by one:

```
sequence  : [[1, 2, 3, 4, 5, 6]]
input_ids : [[1, 2, 3, 4, 5]]
target_ids: [[2, 3, 4, 5, 6]]
```

The attention score calculation formula is:

$$ s_{ij} = \text{softmax}\left(\frac{q_i^T k_j}{\sqrt{d_k}} + mask_{ij}\right) $$

Here, the attention score $s_{ij}$ represents how strongly the model attends to token $j$ when processing token $i$, based on the similarity of their query and key vectors.

For decoder-only models, a mask must be added during the attention calculation to prevent the model from "stealing" information from future positions. The mask is applied to the scores before the softmax. It is typically lower triangular: for a sequence of length n its shape is [n, n], with 0 on and below the diagonal and -inf above it. Below is such a causal mask for a sequence of length 5:

```
[[0, -inf, -inf, -inf, -inf],
 [0,    0, -inf, -inf, -inf],
 [0,    0,    0, -inf, -inf],
 [0,    0,    0,    0, -inf],
 [0,    0,    0,    0,    0]]
```

In this matrix, 0 marks positions that can be attended to, while -inf marks positions that should not be attended to. After the softmax, the attention scores at positions where $j > i$ become 0 (since $e^{-\infty} = 0$), so the model cannot see future information.
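
A minimal PyTorch sketch of building this mask and applying it in the attention computation (illustrative only; tensor shapes and names are not taken from the project's code):

```python
import torch

seq_len, d_k = 5, 64
q = torch.randn(1, seq_len, d_k)
k = torch.randn(1, seq_len, d_k)
v = torch.randn(1, seq_len, d_k)

# 0 on and below the diagonal, -inf above it: position i may only attend to j <= i.
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

scores = q @ k.transpose(-2, -1) / d_k**0.5   # [1, seq_len, seq_len]
scores = scores + mask                        # future positions become -inf
attn = torch.softmax(scores, dim=-1)          # -inf entries turn into probability 0
out = attn @ v
```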

#### 3. Rotary Position Embedding

Rotary Position Embedding (RoPE) is a position encoding method that addresses the Transformer's lack of built-in sequence position information. Unlike traditional position encodings (such as sinusoidal position encodings), RoPE embeds position information directly into the Query (Q) and Key (K) vectors, allowing the model to handle relative position relationships in sequences more naturally.

$$ q_i = R_i W_q x_i $$
$$ k_j = R_j W_k x_j $$
$$ q_i^T k_j = (R_i W_q x_i)^T( R_j W_k x_j) = x_i^T W_q^T R_{j-i} W_k x_j $$

The matrix $R_{j-i}$ depends only on the relative distance between the two tokens and controls how attention attenuates with that distance: the larger the absolute value of $i - j$, the stronger the attenuation. Because the model learns relative rather than absolute position relationships, it adapts more easily to longer sequences.
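
For illustration, here is one common way of applying RoPE to the query and key vectors, rotating consecutive channel pairs by a position-dependent angle. This is a generic sketch of the technique, not the project's implementation:

```python
import torch

def apply_rope(x, positions, base=10000.0):
    # Illustrative RoPE sketch: rotate each (even, odd) channel pair of x by an
    # angle proportional to the token position and the pair's frequency.
    # x: [seq_len, head_dim], positions: [seq_len]
    head_dim = x.shape[-1]
    inv_freq = base ** (-torch.arange(0, head_dim, 2).float() / head_dim)  # [head_dim/2]
    angles = positions[:, None].float() * inv_freq[None, :]                # [seq_len, head_dim/2]
    cos, sin = angles.cos(), angles.sin()
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    rotated = torch.empty_like(x)
    rotated[..., 0::2] = x_even * cos - x_odd * sin
    rotated[..., 1::2] = x_even * sin + x_odd * cos
    return rotated

seq_len, head_dim = 8, 64
q = torch.randn(seq_len, head_dim)
k = torch.randn(seq_len, head_dim)
pos = torch.arange(seq_len)
q_rot, k_rot = apply_rope(q, pos), apply_rope(k, pos)  # position-aware Q and K
```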

## KV Cache Implementation

According to the attention calculation formula:

$$
\begin{align*}
o_i &= \sum_j s_{ij} v_{j} \newline
s_{ij} &= \text{softmax}\left( \frac{q_{i}^T k_{j}}{\sqrt{d_k}} \right)
\end{align*}
$$

Since the model is autoregressive, at each decoding step we only need the output for the last position of the sequence: the index $i$ is fixed to the last element, and we compute $o_{n}$:

$$
\begin{align*}
o_n &= \sum_j s_{j} v_{j} \newline
s_j &= \text{softmax}\left(\frac{q_n^T k_{j}}{\sqrt{d_k}} \right)
\end{align*}
$$

If we expand the expression:

$$
o_n = \sum_j \text{softmax}\left(\frac{q_n^T k_{j}}{\sqrt{d_k}}\right)v_{j}
$$

In this expression, only $k$ and $v$ are indexed over the sequence length, while $q$ is not. Therefore, during decoding only the query of the newly generated (last) token needs to be computed, while the keys and values of all previous tokens are cached and reused as the sequence grows. Also note that rotary position encoding must be applied before the keys are written to the KV cache; otherwise the cached keys carry incorrect position information.
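
A simplified single-head sketch of one decoding step with a KV cache (the weights and shapes are placeholders, not the project's actual modules):

```python
import torch

# Toy single-head attention state: cached keys/values from previous steps.
d_k = 64
k_cache = torch.zeros(0, d_k)   # grows by one row per generated token
v_cache = torch.zeros(0, d_k)

def decode_step(x_new, w_q, w_k, w_v):
    # x_new: [d_model] hidden state of the newest token only.
    global k_cache, v_cache
    q = w_q @ x_new                           # query for the last position only
    k = w_k @ x_new                           # (RoPE would be applied to q/k here,
    v = w_v @ x_new                           #  before k is appended to the cache)
    k_cache = torch.cat([k_cache, k[None]])   # append new key
    v_cache = torch.cat([v_cache, v[None]])   # append new value
    scores = (k_cache @ q) / d_k**0.5         # [cached_len]
    attn = torch.softmax(scores, dim=-1)
    return attn @ v_cache                     # o_n: output for the last position

d_model = 64
w_q, w_k, w_v = (torch.randn(d_k, d_model) for _ in range(3))
out = decode_step(torch.randn(d_model), w_q, w_k, w_v)
```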

### 4. AutoModel Loading

The project now uses the **AutoModel** base class for flexible model loading and saving:

```python
from astrai.model import AutoModel

# Load model from checkpoint
model = AutoModel.from_pretrained("path/to/model")

# Save model to new directory
model.save_pretrained("path/to/save")
```

The Transformer model is registered via the `@AutoModel.register('transformer')` decorator, allowing easy extension for new model types. The `from_pretrained` method automatically loads `config.json` to determine the model type and uses the safetensors format for weights.
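
Registering an additional architecture would then look roughly like the sketch below; the class name and body are hypothetical, and only the decorator form is taken from the project:

```python
from astrai.model import AutoModel

# Hypothetical sketch (class name and body invented for illustration):
# register a new architecture under its own key so that from_pretrained
# can dispatch to it based on the model type recorded in config.json.
@AutoModel.register('my_transformer')
class MyTransformer(AutoModel):
    ...
```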

### 5. Continuous Batching Inference

The inference engine supports **continuous batching** for efficient batch processing:

```python
from astrai.inference import InferenceEngine, GenerationRequest

# Create inference engine with continuous batching
engine = InferenceEngine(
    model=model,
    tokenizer=tokenizer,
    max_batch_size=8,
    max_seq_len=4096,
)

# Use GenerationRequest with messages format
request = GenerationRequest(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ],
    temperature=0.8,
    top_p=0.95,
    top_k=50,
    max_len=1024,
    stream=True,
)

# Generate with streaming
for token in engine.generate_with_request(request):
    print(token, end="", flush=True)
```

The continuous batching feature allows dynamic batch composition where new requests can join at any time and completed requests are released immediately.

## HTTP API Usage

The inference server provides HTTP endpoints for remote inference. Start the server first:

```bash
python -m scripts.tools.server --port 8000
```

### OpenAI-Compatible Endpoint

The server provides an OpenAI-compatible chat completion endpoint at `/v1/chat/completions`:

```bash
curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "system", "content": "You are a helpful assistant."},
      {"role": "user", "content": "Hello, how are you?"}
    ],
    "temperature": 0.8,
    "max_tokens": 2048,
    "stream": false
  }'
```

**Request Parameters:**

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `messages` | List[dict] | Required | Chat messages with role and content |
| `temperature` | float | 0.8 | Sampling temperature (0.0-2.0) |
| `top_p` | float | 0.95 | Nucleus sampling threshold |
| `top_k` | int | 50 | Top-k sampling parameter |
| `max_tokens` | int | 2048 | Maximum tokens to generate |
| `stream` | bool | false | Enable streaming response |
| `system_prompt` | str | None | System prompt override |

**Response (non-streaming):**

```json
{
  "id": "chatcmpl-1234567890",
  "object": "chat.completion",
  "created": 1234567890,
  "model": "astrai",
  "choices": [
    {
      "index": 0,
      "message": {"role": "assistant", "content": "Hello! I'm doing well..."},
      "finish_reason": "stop"
    }
  ]
}
```
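
Because the endpoint follows the OpenAI chat-completions format, it can also be called from Python. Below is a minimal sketch using the `requests` library, mirroring the curl example above:

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"},
        ],
        "temperature": 0.8,
        "max_tokens": 2048,
        "stream": False,
    },
    timeout=60,
)
resp.raise_for_status()
# Extract the assistant message from the OpenAI-style response shown above.
print(resp.json()["choices"][0]["message"]["content"])
```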

### Streaming Response

Enable streaming for real-time token-by-token output:

```bash
curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [{"role": "user", "content": "Write a story"}],
    "stream": true,
    "max_tokens": 500
  }'
```

The server uses Server-Sent Events (SSE) with content type `text/event-stream`.
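
The stream can likewise be consumed from Python. The sketch below uses `requests` and assumes OpenAI-style `data: {...}` SSE lines terminated by `data: [DONE]`; adjust the parsing if the server's chunk format differs:

```python
import json
import requests

with requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={"messages": [{"role": "user", "content": "Write a story"}],
          "stream": True, "max_tokens": 500},
    stream=True,
    timeout=300,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue                      # skip keep-alives and blank separator lines
        payload = line[len("data: "):]
        if payload == "[DONE]":           # assumed end-of-stream sentinel
            break
        chunk = json.loads(payload)
        # Assumed OpenAI-style delta format; adjust to the server's actual schema.
        delta = chunk["choices"][0].get("delta", {}).get("content", "")
        print(delta, end="", flush=True)
```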

### Simple Generation Endpoint

For basic text generation without chat format:

```bash
curl -X POST "http://localhost:8000/generate?query=Hello&max_len=1000" \
  -H "Content-Type: application/json"
```

Or with conversation history:

```bash
curl -X POST "http://localhost:8000/generate" \
  -H "Content-Type: application/json" \
  -d '{
    "query": "What is AI?",
    "history": [["Hello", "Hi there!"], ["How are you?", "I am doing well"]],
    "temperature": 0.8,
    "max_len": 2048
  }'
```

### Health Check

Monitor server and model status:

```bash
curl http://localhost:8000/health
# {"status": "ok", "model_loaded": true, "engine_ready": true}

curl http://localhost:8000/stats
# {"requests_total": 10, "tokens_generated": 5000, ...}
```

> Document Update Time: 2026-04-09