docs: Update design documents

commit ff43a2fab8 (parent 2b26f03bd3)
Author: ViperEkura
Date:   2026-04-05 00:17:35 +08:00
3 changed files with 89 additions and 73 deletions


@@ -8,7 +8,7 @@ AstrAI adopts a modular design with the following main components:
 - **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
 - **Model Module** (`astrai/model/`): Transformer model and its submodules
 - **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
-- **Inference Module** (`astrai/inference/`): Generation core, KV cache management, streaming generation
+- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
 - **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
 - **Parallel Module** (`astrai/parallel/`): Distributed training support
@@ -45,11 +45,11 @@ flowchart LR
     C1[Checkpoint] --> C2[ModelParameter]
     C2 --> C3[Transformer + BpeTokenizer]
     C3 --> C4[GenerationRequest + build_prompt]
-    C4 --> C5[GeneratorFactory]
-    C5 --> C6[GeneratorCore]
+    C4 --> C5[InferenceEngine]
+    C5 --> C6[InferenceScheduler]
     C6 --> C7[apply_sampling_strategies]
     C7 --> C8[Transformer Forward]
-    C8 --> C9[KVCacheManager]
+    C8 --> C9[KV Cache]
     C9 --> C10{End Condition?}
     C10 -->|No| C8
     C10 -->|Yes| C11[Output Text]
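The decode cycle in the updated flowchart (apply sampling → Transformer forward → KV cache → end condition) can be sketched as below. `forward`, `sample`, and the toy callables are illustrative stand-ins, not AstrAI APIs; a real implementation would feed only the newest token back in and rely on the KV cache instead of re-running the full prefix.

```python
from typing import Callable, List

def decode_loop(
    forward: Callable[[List[int]], List[float]],  # toy stand-in for the Transformer forward pass
    sample: Callable[[List[float]], int],         # stand-in for apply_sampling_strategies
    prompt_ids: List[int],
    eos_id: int,
    max_tokens: int,
) -> List[int]:
    """Autoregressive decode: forward, sample, check end condition, repeat."""
    output: List[int] = []
    ids = list(prompt_ids)
    for _ in range(max_tokens):
        logits = forward(ids)     # C8: Transformer forward (KV cache would make this incremental)
        next_id = sample(logits)  # C7: apply sampling strategies to the logits
        output.append(next_id)
        ids.append(next_id)
        if next_id == eos_id:     # C10: end condition reached
            break
    return output

# Toy usage: the "model" always prefers token 2, which is also the EOS token.
toy_forward = lambda ids: [0.0, 0.1, 0.9]
toy_sample = lambda logits: max(range(len(logits)), key=logits.__getitem__)
print(decode_loop(toy_forward, toy_sample, [1], eos_id=2, max_tokens=8))  # → [2]
```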
@@ -124,22 +124,21 @@ flowchart LR
 ### 4. Inference Module
 
-#### 4.1 Generation Core (`core.py`)
-- **`GeneratorCore`**: Provides the `generate_iterator` method, executes single-step generation
-- Applies sampling strategies (temperature, top-k, top-p) to filter logits
-- Supports KV cache to accelerate autoregressive generation
+#### 4.1 Inference Engine (`engine.py`)
+- **`InferenceEngine`**: Unified inference interface; supports streaming and non-streaming generation
+- **`InferenceScheduler`**: Continuous-batching scheduler with dynamic batch composition
+- Manages the task queue (`waiting_queue`, `active_tasks`) and KV cache allocation
 
-#### 4.2 KV Cache Management (`core.py`)
-- **`KVCacheManager`**: Manages the K and V caches for each layer; supports batch generation and length extension
-- Cache shape is `[batch_size, n_kv_heads, seq_len, head_dim]`
+#### 4.2 Scheduler (`scheduler.py`)
+- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
+- **`TaskStatus`**: Task state enumeration
+- **`apply_sampling_strategies`**: Applies temperature, top-k, and top-p sampling to logits
+- Continuous batching: new requests can join at any time; completed requests are released immediately
 
-#### 4.3 Generator (`generator.py`)
-- **`GenerationRequest`**: Encapsulates generation request parameters (top_k, top_p, temperature, max_len, query, history, etc.)
+#### 4.3 Request (`engine.py`)
+- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, query, history, etc.)
 - **`build_prompt`** (from `chat_template.py`): Converts query and history into a ChatML-format prompt string
-- **`GeneratorCore`**: Base generator with `generate_iterator` and `generate_loop` methods
-- **`LoopGenerator`**, **`StreamGenerator`**, **`BatchGenerator`**: Different generation modes
-- **`pad_sequence`**: Pads input IDs to a consistent length
-- Provides streaming and non-streaming generation interfaces
+- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
 
 ## Training Data Flow - Detailed Steps
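As a rough illustration of what an `apply_sampling_strategies`-style function does (temperature scaling, then top-k and top-p filtering, then sampling), here is a pure-Python sketch. The name, signature, and list-based logits are simplified for this example and do not match AstrAI's actual tensor-based implementation.

```python
import math
import random
from typing import List

def apply_sampling(logits: List[float], temperature: float = 1.0,
                   top_k: int = 0, top_p: float = 1.0) -> int:
    """Temperature scaling -> top-k filter -> top-p (nucleus) filter -> sample."""
    scaled = [l / max(temperature, 1e-6) for l in logits]
    # Softmax over the scaled logits
    m = max(scaled)
    exps = [math.exp(l - m) for l in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank token indices by probability, highest first
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]  # top-k: keep only the k most likely tokens
    # top-p: keep the smallest prefix whose cumulative mass reaches top_p
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    # Renormalize over the kept tokens and draw one
    mass = sum(probs[i] for i in kept)
    r = random.random() * mass
    for i in kept:
        r -= probs[i]
        if r <= 0:
            return i
    return kept[-1]

# With top_k=1 the call is deterministic: it returns the argmax token.
print(apply_sampling([1.0, 3.0, 2.0], temperature=0.8, top_k=1))  # → 1
```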


@@ -311,41 +311,59 @@ classDiagram
     }
     namespace astrai.inference {
-        class GeneratorCore {
+        class InferenceEngine {
             +ModelParameter parameter
-            +generate_iterator(input_ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos) Tuple~Tensor, int~
-            +generate_loop(input_ids, ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos, callback) List~int~
+            +InferenceScheduler scheduler
+            +generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+            +generate_with_request(request) Union[Generator, str, List[str]]
+            +get_stats() Dict
+            +shutdown()
         }
-        class LoopGenerator {
-            +generate(request) str
-        }
-        class StreamGenerator {
-            +generate(request) Generator
-        }
-        class BatchGenerator {
-            +generate(request) List~str~
-        }
-        class EmbeddingEncoder {
-            +encode(sentence) Tensor
-        }
-        class KVCacheManager {
-            +int batch_size
+        class InferenceScheduler {
+            +nn.Module model
+            +Tokenizer tokenizer
+            +ModelConfig config
             +Tuple kv_cache
             +Tensor seq_mask
-            +get_kvcache() Tuple
-            +get_seq_mask() Tensor
-            +update(active_mask)
-            +reset(full_reset)
+            +List waiting_queue
+            +List active_tasks
+            +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+            +remove_task(task_id)
+            +start()
+            +stop()
+            +get_stats() Dict
         }
-        class GeneratorFactory {
-            +create(parameter, request) GeneratorCore
-            +create_encoder(parameter) EmbeddingEncoderCore
+        class Task {
+            +str task_id
+            +List prompt_ids
+            +int max_tokens
+            +float temperature
+            +float top_p
+            +int top_k
+            +TaskStatus status
+            +List output_ids
+            +int input_tokens
+            +int output_tokens
+            +int slot
+            +Callable stream_callback
+            +is_finished(stop_ids) bool
+        }
+        class TaskStatus {
+            +str PENDING
+            +str RUNNING
+            +str FINISHED
+            +str ABORTED
+        }
+        class apply_sampling_strategies {
+            +Tensor logits
+            +float temperature
+            +int top_k
+            +float top_p
+            +forward() Tensor
         }
         class Server {
@@ -396,28 +414,25 @@ classDiagram
     BaseStrategy <|-- SFTStrategy
     BaseStrategy <|-- DPOStrategy
     BaseStrategy <|-- GRPOStrategy
-    DPOStrategy --> Transformer : creates ref_model
-    GRPOStrategy --> Transformer : creates ref_model
     SchedulerFactory ..> BaseScheduler : creates
     BaseScheduler <|-- CosineScheduler
     BaseScheduler <|-- SGDRScheduler
     CallbackFactory ..> TrainCallback : creates
-    GeneratorFactory ..> GeneratorCore : creates
-    GeneratorCore <|-- LoopGenerator
-    GeneratorCore <|-- StreamGenerator
-    GeneratorCore <|-- BatchGenerator
-    GeneratorCore <|-- EmbeddingEncoder
-    LoopGenerator --> KVCacheManager : uses
-    StreamGenerator --> KVCacheManager : uses
-    BatchGenerator --> KVCacheManager : uses
-    GeneratorCore --> Transformer : uses
-    Server --> GeneratorFactory : uses
+    InferenceEngine --> InferenceScheduler : uses
+    InferenceScheduler --> Task : manages
+    InferenceScheduler --> TaskStatus : uses
+    InferenceScheduler --> apply_sampling_strategies : uses
+    InferenceScheduler --> Transformer : uses
+    InferenceEngine --> Transformer : uses
+    InferenceEngine --> GenerationRequest : uses
+    Server --> InferenceEngine : uses
+    DPOStrategy --> Transformer : creates ref_model
+    GRPOStrategy --> Transformer : creates ref_model
     ParallelSetup --> Trainer : enables
     TrainConfig --> StrategyFactory : selects
     ModelParameter --> Transformer : contains
     ModelParameter --> BpeTokenizer : contains
     ModelParameter --> ModelConfig : contains
-    GeneratorFactory --> GenerationRequest : uses
     BaseDataset <|-- SEQDataset
     BaseDataset <|-- SFTDataset
     BaseDataset <|-- DPODataset
@@ -446,7 +461,7 @@ classDiagram
 | **astrai.model** | Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
 | **astrai.tokenize** | Tokenizer, BpeTokenizer | Tokenizer |
 | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
-| **astrai.inference** | GeneratorCore, LoopGenerator, StreamGenerator, BatchGenerator, EmbeddingEncoder, KVCacheManager, GeneratorFactory, Server, GenerationRequest | Inference service |
+| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest | Inference service with continuous batching |
 | **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
 
 ### Design Patterns
@@ -455,17 +470,19 @@ classDiagram
 |---------|---------|---------|
 | **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
 | **Builder** | `TrainContextBuilder` | Chain-building of the training context, step-by-step initialization of components |
-| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `GeneratorFactory`, `CallbackFactory` | Decorator registration mechanism, dynamically creates training strategies, schedulers, datasets, generators, and callbacks |
+| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory` | Decorator registration mechanism, dynamically creates training strategies, schedulers, datasets, and callbacks |
 | **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training-process monitoring (checkpoint, early stopping, metrics) |
 | **Singleton** | `TrainContext` | Training-process global state management |
 | **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
+| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task-queue management |
+| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
 
 ### Core Relationships
 
 1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`; holds model, dataset, optimizer, and other references
 2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
 3. **Strategy Selection**: `StrategyFactory` creates the corresponding strategy instance based on `train_type`
-4. **Inference Flow**: `Server` → `GeneratorFactory` → `GeneratorCore` → `Transformer`, supports multiple generators (LoopGenerator, StreamGenerator, BatchGenerator)
+4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, supports continuous batching with streaming/non-streaming generation
 5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
 6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset); supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
 7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
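The producer-consumer pattern behind the scheduler (callers enqueue tasks into `waiting_queue`, a worker thread admits them into `active_tasks` in dynamic batches, and a `threading.Event` avoids busy-waiting) can be sketched with a toy. `ToyScheduler` and its instant "decode step" are hypothetical stand-ins for illustration only, not AstrAI's `InferenceScheduler`.

```python
import threading
import time
from collections import deque

class ToyScheduler:
    """Producer-consumer sketch of continuous batching with an event-driven worker."""

    def __init__(self, max_batch: int = 4):
        self.waiting_queue = deque()   # producer side: submitted tasks
        self.active_tasks = []         # consumer side: current dynamic batch
        self.results = {}
        self.max_batch = max_batch
        self._task_event = threading.Event()  # the `_task_event` idea: sleep until work arrives
        self._stop = False
        self._lock = threading.Lock()
        self._worker = threading.Thread(target=self._run, daemon=True)

    def start(self):
        self._worker.start()

    def add_task(self, task_id: str, prompt: str) -> str:
        with self._lock:
            self.waiting_queue.append((task_id, prompt))
        self._task_event.set()  # wake the worker instead of having it poll
        return task_id

    def stop(self):
        self._stop = True
        self._task_event.set()
        self._worker.join()

    def _run(self):
        while not self._stop:
            # Timeout keeps shutdown robust even if the event was just cleared.
            self._task_event.wait(timeout=0.05)
            with self._lock:
                # Dynamic batch composition: admit up to max_batch waiting tasks.
                while self.waiting_queue and len(self.active_tasks) < self.max_batch:
                    self.active_tasks.append(self.waiting_queue.popleft())
                if not self.waiting_queue and not self._stop:
                    self._task_event.clear()
            # "Decode step": this toy finishes every active task immediately;
            # a real scheduler would run one forward pass and release only finished tasks.
            for task_id, prompt in self.active_tasks:
                self.results[task_id] = prompt.upper()
            self.active_tasks.clear()

sched = ToyScheduler()
sched.start()
sched.add_task("t1", "hello")
time.sleep(0.2)
sched.stop()
print(sched.results)  # → {'t1': 'HELLO'}
```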


@@ -83,15 +83,15 @@
 ### Usage Example
 
 ```python
-from astrai.config.param_config import ModelParameter
-from astrai.inference.generator import StreamGenerator, GenerationRequest
+from astrai.config import ModelParameter
+from astrai.inference import InferenceEngine, GenerationRequest
 
 # Load model
 param = ModelParameter.load("your_model_dir")
 param.to(device="cuda", dtype=torch.bfloat16)
 
-# Create generator
-generator = StreamGenerator(param)
+# Create engine
+engine = InferenceEngine(param)
 
 # Build request
 request = GenerationRequest(
@@ -102,14 +102,14 @@ request = GenerationRequest(
     top_k=50,
 )
 
-# Generate
-response = generator.generate(request)
+# Generate (streaming)
+for token in engine.generate_with_request(request):
+    print(token, end="", flush=True)
 ```
 
-### Three Types of Generators
+### Generation Modes
 
-| Generator | Usage |
-|-----------|-------|
-| `StreamGenerator` | Streaming output, returns word by word |
-| `LoopGenerator` | Non-streaming output, returns at once |
-| `BatchGenerator` | Batch generation, processes multiple queries simultaneously |
+| Mode | Description |
+|------|-------------|
+| `stream=True` | Streaming output, yields token by token |
+| `stream=False` | Non-streaming output, returns complete result |
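The dual-mode interface in the table (one entry point that either yields tokens or returns the complete text) can be illustrated with a minimal standalone function. `generate` here is a toy over pre-split tokens, not the real `InferenceEngine.generate`.

```python
from typing import Generator, List, Union

def generate(tokens: List[str], stream: bool = False) -> Union[Generator[str, None, None], str]:
    """Toy dual-mode generation: stream=True yields tokens, stream=False joins them."""
    def _iter() -> Generator[str, None, None]:
        for t in tokens:
            yield t  # streaming: emit each token as soon as it is "produced"
    if stream:
        return _iter()
    return "".join(_iter())  # non-streaming: drain the same iterator internally

print(generate(["He", "llo"], stream=False))       # → Hello
print(list(generate(["He", "llo"], stream=True)))  # → ['He', 'llo']
```

Implementing the non-streaming mode by draining the streaming iterator keeps a single decode path for both modes, which is the design choice the table implies.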