From ff43a2fab8d00266335e7552bddfb9a6c9b3b481 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sun, 5 Apr 2026 00:17:35 +0800
Subject: [PATCH] docs: update design docs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 assets/docs/dataflow.md |  33 ++++++------
 assets/docs/design.md   | 105 +++++++++++++++++++++++-----------------
 assets/docs/params.md   |  24 ++++-----
 3 files changed, 89 insertions(+), 73 deletions(-)

diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md
index fc84bbf..3f2ce2c 100644
--- a/assets/docs/dataflow.md
+++ b/assets/docs/dataflow.md
@@ -8,7 +8,7 @@ AstrAI adopts a modular design with the following main components:
 - **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
 - **Model Module** (`astrai/model/`): Transformer model and its submodules
 - **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
-- **Inference Module** (`astrai/inference/`): Generation core, KV cache management, streaming generation
+- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching and streaming generation
 - **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
 - **Parallel Module** (`astrai/parallel/`): Distributed training support
 
@@ -45,11 +45,11 @@ flowchart LR
     C1[Checkpoint] --> C2[ModelParameter]
     C2 --> C3[Transformer + BpeTokenizer]
     C3 --> C4[GenerationRequest + build_prompt]
-    C4 --> C5[GeneratorFactory]
-    C5 --> C6[GeneratorCore]
+    C4 --> C5[InferenceEngine]
+    C5 --> C6[InferenceScheduler]
     C6 --> C7[apply_sampling_strategies]
     C7 --> C8[Transformer Forward]
-    C8 --> C9[KVCacheManager]
+    C8 --> C9[KV Cache]
     C9 --> C10{End Condition?}
     C10 -->|No| C8
     C10 -->|Yes| C11[Output Text]
@@ -124,22 +124,21 @@ flowchart LR
 
 ### 4. Inference Module
 
-#### 4.1 Generation Core (`core.py`)
-- **`GeneratorCore`**: Provides `generate_iterator` method, executes single-step generation
-- Applies sampling strategies (temperature, top-k, top-p) to filter logits
-- Supports KV cache to accelerate autoregressive generation
+#### 4.1 Inference Engine (`engine.py`)
+- **`InferenceEngine`**: Unified inference interface that supports streaming and non-streaming generation
+- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
+- Manages the task queues (`waiting_queue`, `active_tasks`) and KV cache allocation
 
-#### 4.2 KV Cache Management (`core.py`)
-- **`KVCacheManager`**: Manages K and V cache for each layer, supports batch generation and length extension
-- Cache shape is `[batch_size, n_kv_heads, seq_len, head_dim]`
+#### 4.2 Scheduler (`scheduler.py`)
+- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
+- **`TaskStatus`**: Task state enumeration
+- **`apply_sampling_strategies`**: Applies temperature, top-k, and top-p sampling to the logits (see the sketch below)
+- Continuous batching: new requests can join the batch at any step, and finished requests are released immediately, freeing their KV cache slots
 
-#### 4.3 Generator (`generator.py`)
-- **`GenerationRequest`**: Encapsulates generation request parameters (top_k, top_p, temperature, max_len, query, history, etc.)
+#### 4.3 Request (`engine.py`)
+- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, query, history, etc.)
 - **`build_prompt`** (from `chat_template.py`): Converts query and history into ChatML format prompt string
-- **`GeneratorCore`**: Base generator with generate_iterator and generate_loop methods
-- **`LoopGenerator`**, **`StreamGenerator`**, **`BatchGenerator`**: Different generation modes
-- **`pad_sequence`**: Pads input IDs to consistent length
-- Provides streaming and non-streaming generation interfaces
+- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
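+
+The sampling step in 4.2 can be pictured with the following sketch. It is
+illustrative only: the real `apply_sampling_strategies` lives in `scheduler.py`
+and its exact signature may differ.
+
+```python
+import torch
+
+def apply_sampling_strategies(logits: torch.Tensor, temperature: float,
+                              top_k: int, top_p: float) -> torch.Tensor:
+    """Filter next-token logits with temperature, top-k, and top-p."""
+    # Temperature: values < 1 sharpen the distribution, values > 1 flatten it
+    logits = logits / max(temperature, 1e-5)
+    # Top-k: mask out everything below the k-th largest logit
+    if top_k > 0:
+        kth = torch.topk(logits, top_k).values[..., -1, None]
+        logits = logits.masked_fill(logits < kth, float("-inf"))
+    # Top-p (nucleus): keep the smallest prefix whose probability mass >= top_p
+    if top_p < 1.0:
+        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
+        probs = torch.softmax(sorted_logits, dim=-1)
+        cum = probs.cumsum(dim=-1)
+        drop = (cum - probs) >= top_p  # mass before this token already covers top_p
+        sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
+        logits = sorted_logits.gather(-1, sorted_idx.argsort(dim=-1))
+    # The caller then samples: torch.multinomial(torch.softmax(logits, -1), 1)
+    return logits
+```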
 
 ## Training Data Flow - Detailed Steps
 
diff --git a/assets/docs/design.md b/assets/docs/design.md
index 5903c40..0260c7d 100644
--- a/assets/docs/design.md
+++ b/assets/docs/design.md
@@ -311,41 +311,59 @@ classDiagram
     }
 
     namespace astrai.inference {
-        class GeneratorCore {
+        class InferenceEngine {
             +ModelParameter parameter
-            +generate_iterator(input_ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos) Tuple~Tensor, int~
-            +generate_loop(input_ids, ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos, callback) List~int~
+            +InferenceScheduler scheduler
+            +generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+            +generate_with_request(request) Union[Generator, str, List[str]]
+            +get_stats() Dict
+            +shutdown()
         }
 
-        class LoopGenerator {
-            +generate(request) str
-        }
-
-        class StreamGenerator {
-            +generate(request) Generator
-        }
-
-        class BatchGenerator {
-            +generate(request) List~str~
-        }
-
-        class EmbeddingEncoder {
-            +encode(sentence) Tensor
-        }
-
-        class KVCacheManager {
-            +int batch_size
+        class InferenceScheduler {
+            +nn.Module model
+            +Tokenizer tokenizer
+            +ModelConfig config
             +Tuple kv_cache
             +Tensor seq_mask
-            +get_kvcache() Tuple
-            +get_seq_mask() Tensor
-            +update(active_mask)
-            +reset(full_reset)
+            +List waiting_queue
+            +List active_tasks
+            +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+            +remove_task(task_id)
+            +start()
+            +stop()
+            +get_stats() Dict
         }
 
-        class GeneratorFactory {
-            +create(parameter, request) GeneratorCore
-            +create_encoder(parameter) EmbeddingEncoderCore
+        class Task {
+            +str task_id
+            +List prompt_ids
+            +int max_tokens
+            +float temperature
+            +float top_p
+            +int top_k
+            +TaskStatus status
+            +List output_ids
+            +int input_tokens
+            +int output_tokens
+            +int slot
+            +Callable stream_callback
+            +is_finished(stop_ids) bool
+        }
+
+        class TaskStatus {
+            +str PENDING
+            +str RUNNING
+            +str FINISHED
+            +str ABORTED
+        }
+
+        class apply_sampling_strategies {
+            <<function>>
+            +apply_sampling_strategies(logits, temperature, top_k, top_p) Tensor
+        }
 
         class Server {
@@ -396,28 +414,25 @@ classDiagram
     BaseStrategy <|-- SFTStrategy
     BaseStrategy <|-- DPOStrategy
     BaseStrategy <|-- GRPOStrategy
+    DPOStrategy --> Transformer : creates ref_model
+    GRPOStrategy --> Transformer : creates ref_model
     SchedulerFactory ..> BaseScheduler : creates
     BaseScheduler <|-- CosineScheduler
     BaseScheduler <|-- SGDRScheduler
     CallbackFactory ..> TrainCallback : creates
-    GeneratorFactory ..> GeneratorCore : creates
-    GeneratorCore <|-- LoopGenerator
-    GeneratorCore <|-- StreamGenerator
-    GeneratorCore <|-- BatchGenerator
-    GeneratorCore <|-- EmbeddingEncoder
-    LoopGenerator --> KVCacheManager : uses
-    StreamGenerator --> KVCacheManager : uses
-    BatchGenerator --> KVCacheManager : uses
-    GeneratorCore --> Transformer : uses
-    DPOStrategy --> Transformer : creates ref_model
-    GRPOStrategy --> Transformer : creates ref_model
-    Server --> GeneratorFactory : uses
+    InferenceEngine --> InferenceScheduler : uses
+    InferenceScheduler --> Task : manages
+    InferenceScheduler --> TaskStatus : uses
+    InferenceScheduler --> apply_sampling_strategies : uses
+    InferenceScheduler --> Transformer : uses
+    InferenceEngine --> Transformer : uses
+    InferenceEngine --> GenerationRequest : uses
+    Server --> InferenceEngine : uses
     ParallelSetup --> Trainer : enables
     TrainConfig --> StrategyFactory : selects
     ModelParameter --> Transformer : contains
     ModelParameter --> BpeTokenizer : contains
     ModelParameter --> ModelConfig : contains
-    GeneratorFactory --> GenerationRequest : uses
     BaseDataset <|-- SEQDataset
     BaseDataset <|-- SFTDataset
     BaseDataset <|-- DPODataset
@@ -446,7 +461,7 @@ classDiagram
 | **astrai.model** | Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
 | **astrai.tokenize** | Tokenizer, BpeTokenizer | Tokenizer |
 | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
-| **astrai.inference** | GeneratorCore, LoopGenerator, StreamGenerator, BatchGenerator, EmbeddingEncoder, KVCacheManager, GeneratorFactory, Server, GenerationRequest | Inference service |
+| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest | Inference service with continuous batching |
 | **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
 
 ### Design Patterns
 
@@ -455,17 +470,19 @@
 |---------|---------|---------|
 | **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
 | **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
-| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `GeneratorFactory`, `CallbackFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, generators, and callbacks |
+| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
 | **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
 | **Singleton** | `TrainContext` | Training process global state management |
 | **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
+| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
+| **Event-Driven** | `threading.Event`, `_task_event` | Lets the scheduler thread sleep instead of busy-polling and wake when new tasks arrive (see the sketch after Core Relationships) |
 
 ### Core Relationships
 
 1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
 2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
 3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
-4. **Inference Flow**: `Server` → `GeneratorFactory` → `GeneratorCore` → `Transformer`, supports multiple generators (LoopGenerator, StreamGenerator, BatchGenerator)
+4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, supports continuous batching with both streaming and non-streaming generation
 5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
 6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
 7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
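+
+The Producer-Consumer and Event-Driven rows above can be condensed into the
+following pseudocode. This is a minimal sketch of the pattern, not the actual
+implementation: only `Task`, `TaskStatus`, `waiting_queue`, `active_tasks`, and
+`_task_event` come from the real classes; every other name, field, and signature
+here is assumed for illustration.
+
+```python
+import threading
+from dataclasses import dataclass, field
+from enum import Enum
+
+class TaskStatus(Enum):
+    PENDING = "pending"
+    RUNNING = "running"
+    FINISHED = "finished"
+    ABORTED = "aborted"
+
+@dataclass
+class Task:
+    task_id: str
+    prompt_ids: list
+    max_tokens: int
+    status: TaskStatus = TaskStatus.PENDING
+    output_ids: list = field(default_factory=list)
+
+class SchedulerSketch:
+    def __init__(self, max_batch_size=8):
+        self.waiting_queue = []              # producer side: incoming tasks
+        self.active_tasks = []               # consumer side: the running batch
+        self.max_batch_size = max_batch_size
+        self._task_event = threading.Event()
+        self._lock = threading.Lock()
+
+    def add_task(self, task):
+        # Producer: enqueue the task and wake the scheduler thread
+        with self._lock:
+            self.waiting_queue.append(task)
+        self._task_event.set()
+
+    def _step(self):
+        # Admit waiting tasks while there is room in the batch
+        with self._lock:
+            while self.waiting_queue and len(self.active_tasks) < self.max_batch_size:
+                task = self.waiting_queue.pop(0)
+                task.status = TaskStatus.RUNNING
+                self.active_tasks.append(task)
+        # One decode step for the whole batch; finished tasks leave immediately,
+        # which is what makes the batching "continuous"
+        for task in self.active_tasks:
+            task.output_ids.append(0)        # placeholder for forward + sampling
+            if len(task.output_ids) >= task.max_tokens:
+                task.status = TaskStatus.FINISHED
+        self.active_tasks = [t for t in self.active_tasks
+                             if t.status is TaskStatus.RUNNING]
+
+    def run(self):
+        # Consumer loop: sleep on the event while idle instead of busy-polling
+        while True:
+            if not self.active_tasks and not self.waiting_queue:
+                self._task_event.wait()
+                self._task_event.clear()
+            self._step()
+```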
diff --git a/assets/docs/params.md b/assets/docs/params.md
index 256bfa4..9d4bf4b 100644
--- a/assets/docs/params.md
+++ b/assets/docs/params.md
@@ -83,15 +83,15 @@
 ### Usage Example
 
 ```python
-from astrai.config.param_config import ModelParameter
-from astrai.inference.generator import StreamGenerator, GenerationRequest
+import torch
+
+from astrai.config import ModelParameter
+from astrai.inference import InferenceEngine, GenerationRequest
 
 # Load model
 param = ModelParameter.load("your_model_dir")
 param.to(device="cuda", dtype=torch.bfloat16)
 
-# Create generator
-generator = StreamGenerator(param)
+# Create engine
+engine = InferenceEngine(param)
 
 # Build request
 request = GenerationRequest(
@@ -102,14 +102,14 @@ request = GenerationRequest(
     top_k=50,
 )
 
-# Generate
-response = generator.generate(request)
+# Generate (streaming; assumes the request sets stream=True)
+for token in engine.generate_with_request(request):
+    print(token, end="", flush=True)
 ```
 
-### Three Types of Generators
-
-| Generator | Usage |
-|-----------|-------|
-| `StreamGenerator` | Streaming output, returns word by word |
-| `LoopGenerator` | Non-streaming output, returns at once |
-| `BatchGenerator` | Batch generation, processes multiple queries simultaneously |
\ No newline at end of file
+### Generation Modes
+
+| Mode | Description |
+|------|-------------|
+| `stream=True` | Streaming output, yields token by token |
+| `stream=False` | Non-streaming output, returns the complete result |
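+
+For non-streaming use, the engine can also be called directly. This is a sketch;
+the parameter names follow the `InferenceEngine.generate` signature listed in
+design.md, and the prompt string is illustrative:
+
+```python
+# Non-streaming: blocks until generation finishes, then returns the text
+text = engine.generate(
+    "Introduce yourself",
+    stream=False,
+    max_tokens=256,
+    temperature=0.8,
+    top_p=0.9,
+    top_k=50,
+)
+print(text)
+```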