diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md
index 18e7416..158cc48 100644
--- a/assets/docs/dataflow.md
+++ b/assets/docs/dataflow.md
@@ -95,6 +95,7 @@ flowchart LR
 - Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
 - Supports weight tying (`tie_weight=True`) to reduce parameter count
 - Uses Rotary Position Embedding (RoPE) to inject position information
+- Supports loading from safetensors format with automatic model type detection from `config.json`
 
 #### 2.2 Submodules (`module.py`)
 - **`RotaryEmbedding`**: Generates RoPE cos/sin cache
@@ -137,7 +138,12 @@
 #### 5.1 Inference Engine (`engine.py`)
 - **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
 - **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
-- Manages task queue (`waiting_queue`, `active_tasks`) and KV cache allocation
+- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
+- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
+- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into a prompt string using the ChatML format
+- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
+- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
+- Uses separate model and tokenizer initialization for flexibility
 
 #### 5.2 Scheduler (`scheduler.py`)
 - **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
diff --git a/assets/docs/design.md b/assets/docs/design.md
index 0ca3b93..c7d929a 100644
--- a/assets/docs/design.md
+++ b/assets/docs/design.md
@@ -175,6 +175,17 @@ classDiagram
         +forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor
     }
 
+    class MLA {
+        +int n_heads
+        +int n_kv_heads
+        +int head_dim
+        +Linear q_a_proj, q_b_proj, q_c_proj
+        +Linear kv_a_proj, kv_b_proj, kv_c_proj
+        +Linear o_proj
+        +RMSNorm q_norm, k_norm
+        +forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor
+    }
+
     class MLP {
         +Linear up, gate, down
         +forward(x) Tensor
@@ -469,6 +480,7 @@ classDiagram
     Transformer --> RotaryEmbedding : uses
     Transformer --> Embedding : uses
     DecoderBlock --> GQA : uses
+    DecoderBlock --> MLA : uses
     DecoderBlock --> MLP : uses
     DecoderBlock --> RMSNorm : uses
     BpeTokenizer --> Tokenizer : inherits
@@ -483,8 +495,8 @@ classDiagram
 |--------|------------|-------------|
 | **astrai.config** | ModelConfig, TrainConfig, ModelParameter | Configuration management |
 | **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint, DataLoader | Dataset loading and management |
-| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
-| **astrai.tokenize** | Tokenizer, BpeTokenizer | Tokenizer |
+| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
+| **astrai.tokenize** | AutoTokenizer, BpeTokenizer, ChatTemplate, BpeTrainer | Tokenizer |
 | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
 | **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest | Inference service with continuous batching |
 | **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
@@ -503,6 +515,7 @@ classDiagram
 | **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
 | **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
 | **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
+| **Generator Pattern** | `_StreamingResult`, `_NonStreamingResult` | Event-based result notification for streaming/non-streaming generation |
 
 ### Core Relationships
 
@@ -540,4 +553,32 @@
 $$
 L_{\text{DPO}} = -\mathbb{E}_{(x, y_w, y_l) \sim D} \left[ \log \sigma\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)} \right) \right]
 $$
 
+**GRPO:**
+
+GRPO (Group Relative Policy Optimization) computes advantages from multiple responses to the same prompt, then optimizes using a PPO-style clipped objective:
+
+$$
+\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon}
+$$
+
+where $r_i$ is the reward for the $i$-th response, and $\mu$ and $\sigma$ are the mean and standard deviation of the group's rewards.
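The group-normalized advantage above can be sketched in a few lines of plain Python (an illustrative example, not astrai's actual GRPO code; the function name is hypothetical):

```python
from statistics import mean, pstdev

def group_relative_advantages(rewards, eps=1e-8):
    """Compute A_i = (r_i - mu) / (sigma + eps) over one group of rewards."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std of the group's rewards
    return [(r - mu) / (sigma + eps) for r in rewards]

# Four sampled responses to the same prompt, scored by a reward model
advantages = group_relative_advantages([1.0, 2.0, 3.0, 4.0])
```

Within a group the advantages sum to (approximately) zero, so above-average responses are pushed up exactly as much as below-average responses are pushed down.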
+
+$$
+L_{\text{GRPO}} = -\mathbb{E} \left[ \min\left( \frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)} \cdot A, \text{clip}\left(\frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)}, 1-\epsilon, 1+\epsilon\right) \cdot A \right) \right] + \lambda \cdot D_{KL}
+$$
+
+In this implementation, an on-policy approach is used ($\pi_\theta = \pi_{\text{ref}}$), so the probability ratio equals 1 and the policy loss simplifies to:
+
+$$
+L_{\text{policy}} = -\mathbb{E}[A]
+$$
+
+The KL divergence term uses a mean-squared-error approximation of the log-probability gap:
+
+$$
+L_{KL} = \lambda \cdot \mathbb{E} \left[ (\log \pi_\theta - \log \pi_{\text{ref}})^2 \right]
+$$
+
+The final loss is the sum of the two terms: $L = L_{\text{policy}} + L_{KL}$
+
 Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence.
diff --git a/assets/docs/introduction.md b/assets/docs/introduction.md
index 4f8604d..1f75e78 100644
--- a/assets/docs/introduction.md
+++ b/assets/docs/introduction.md
@@ -139,4 +139,55 @@
 $$
 o_n = \sum_j \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}}\right)v_{j}
 $$
 
-In the above expression, only k and v have length indices, while $q$ does not. Therefore, during the calculation process, the input of $q$ is fixed as the last token from the previous input, while $k$ and $v$ need to be cached for parts of different lengths. Also, when caching, note that position encoding calculation should be performed before KV cache computation, otherwise there will be position encoding calculation errors.
\ No newline at end of file
+In the above expression, only $k$ and $v$ carry a length index, while $q$ does not. Therefore, during generation the input of $q$ is just the last token of the previous step, while $k$ and $v$ must be cached for the variable-length prefix.
Also, when caching, note that position encoding must be applied before the key/value tensors are written to the KV cache; otherwise the cached entries carry incorrect position information.
+
+### 4. AutoModel Loading
+
+The project now uses the **AutoModel** base class for flexible model loading and saving:
+
+```python
+from astrai.model import AutoModel
+
+# Load model from checkpoint
+model = AutoModel.from_pretrained("path/to/model")
+
+# Save model to new directory
+model.save_pretrained("path/to/save")
+```
+
+The Transformer model is registered via the `@AutoModel.register('transformer')` decorator, allowing easy extension to new model types. The `from_pretrained` method automatically reads `config.json` to determine the model type and loads weights in safetensors format.
+
+### 5. Continuous Batching Inference
+
+The inference engine supports **continuous batching** for efficient batch processing:
+
+```python
+from astrai.inference import InferenceEngine, GenerationRequest
+
+# Create inference engine with continuous batching
+engine = InferenceEngine(
+    model=model,
+    tokenizer=tokenizer,
+    max_batch_size=8,
+    max_seq_len=4096,
+)
+
+# Use GenerationRequest with messages format
+request = GenerationRequest(
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Hello"},
+    ],
+    temperature=0.8,
+    top_p=0.95,
+    top_k=50,
+    max_len=1024,
+    stream=True,
+)
+
+# Generate with streaming
+for token in engine.generate_with_request(request):
+    print(token, end="", flush=True)
+```
+
+Continuous batching allows dynamic batch composition: new requests can join the batch at any time, and completed requests are released immediately.
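The admit/decode/release cycle described above can be illustrated with a toy scheduler loop. This is a minimal sketch with hypothetical names, not astrai's actual `InferenceScheduler`, which additionally manages KV cache slots and thread synchronization:

```python
from collections import deque

class ToyScheduler:
    """Illustrative continuous-batching loop (hypothetical, simplified)."""

    def __init__(self, max_batch_size):
        self.max_batch_size = max_batch_size
        self.waiting = deque()   # requests not yet admitted
        self.active = []         # requests currently in the batch

    def submit(self, request):
        self.waiting.append(request)

    def step(self):
        # Admit new requests whenever batch slots are free
        while self.waiting and len(self.active) < self.max_batch_size:
            self.active.append(self.waiting.popleft())
        # One decode step: each active request produces one token
        for req in self.active:
            req["generated"] += 1
        # Release finished requests immediately, freeing their slots
        finished = [r for r in self.active if r["generated"] >= r["max_len"]]
        self.active = [r for r in self.active if r["generated"] < r["max_len"]]
        return finished
```

With `max_batch_size=2` and three queued requests, the third request is admitted as soon as the first one finishes, rather than waiting for the whole batch to drain as static batching would.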
\ No newline at end of file
diff --git a/assets/docs/params.md b/assets/docs/params.md
index 72a246b..e1f8dbc 100644
--- a/assets/docs/params.md
+++ b/assets/docs/params.md
@@ -62,6 +62,9 @@
 | `--window_size` | Maximum input sequence length | model config max_len |
 | `--stride` | Input sequence stride | - |
 | `--dpo_beta` | DPO beta value | 0.1 |
+| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
+| `--grpo_kl_coef` | GRPO KL coefficient | 0.01 |
+| `--grpo_group_size` | GRPO group size | 4 |
 | `--label_smoothing` | Label smoothing parameter | 0.1 |
 | `--start_epoch` | Starting epoch | 0 |
 | `--start_batch` | Starting batch | 0 |
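For reference, the three new GRPO flags map onto the hyperparameters of the GRPO loss ($\epsilon$ for clipping, $\lambda$ for the KL penalty, and the number of responses per prompt). A minimal argparse sketch with the table's defaults (the parser itself is hypothetical, not astrai's actual CLI):

```python
import argparse

# Illustrative parser; defaults mirror the params table above.
parser = argparse.ArgumentParser(description="GRPO training options (sketch)")
parser.add_argument("--grpo_clip_eps", type=float, default=0.2,
                    help="clip epsilon in the PPO-style objective")
parser.add_argument("--grpo_kl_coef", type=float, default=0.01,
                    help="weight of the KL penalty term")
parser.add_argument("--grpo_group_size", type=int, default=4,
                    help="responses sampled per prompt")

args = parser.parse_args([])  # empty argv -> table defaults
```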