docs: update design documentation

ViperEkura 2026-04-05 00:17:35 +08:00
parent 2b26f03bd3
commit ff43a2fab8
3 changed files with 89 additions and 73 deletions

View File

@@ -8,7 +8,7 @@ AstrAI adopts a modular design with the following main components:
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
- **Model Module** (`astrai/model/`): Transformer model and its submodules
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
- **Inference Module** (`astrai/inference/`): Generation core, KV cache management, streaming generation
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
- **Parallel Module** (`astrai/parallel/`): Distributed training support
@@ -45,11 +45,11 @@ flowchart LR
C1[Checkpoint] --> C2[ModelParameter]
C2 --> C3[Transformer + BpeTokenizer]
C3 --> C4[GenerationRequest + build_prompt]
C4 --> C5[GeneratorFactory]
C5 --> C6[GeneratorCore]
C4 --> C5[InferenceEngine]
C5 --> C6[InferenceScheduler]
C6 --> C7[apply_sampling_strategies]
C7 --> C8[Transformer Forward]
C8 --> C9[KVCacheManager]
C8 --> C9[KV Cache]
C9 --> C10{End Condition?}
C10 -->|No| C8
C10 -->|Yes| C11[Output Text]
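The C7→C10 loop above is a standard KV-cached autoregressive decode: the prompt is processed once, then each step feeds only the newest token back through the model. A minimal sketch, assuming a model forward that returns logits plus an updated cache, and borrowing `apply_sampling_strategies` from the inference module (both signatures are illustrative, not AstrAI's exact API; a sketch of the sampling helper appears under 4.2 below):

```python
import torch

# Sketch of the C7-C10 decode loop. The `model(ids, kv_cache=...)` interface
# and the `apply_sampling_strategies` signature are assumptions for illustration.
@torch.no_grad()
def decode_loop(model, input_ids, eos_id, max_new_tokens=256,
                temperature=1.0, top_k=50, top_p=0.9):
    ids = input_ids                      # [1, prompt_len] token ids
    kv_cache = None                      # populated by the first forward pass
    generated = []
    for _ in range(max_new_tokens):
        # With a KV cache only the newest token needs a forward pass (C8-C9).
        step_input = ids if kv_cache is None else ids[:, -1:]
        logits, kv_cache = model(step_input, kv_cache=kv_cache)
        next_id = apply_sampling_strategies(logits[:, -1, :],
                                            temperature, top_k, top_p)
        if next_id.item() == eos_id:     # end condition (C10)
            break
        generated.append(next_id.item())
        ids = torch.cat([ids, next_id.view(1, 1)], dim=1)
    return generated
```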
@@ -124,22 +124,21 @@ flowchart LR
### 4. Inference Module
#### 4.1 Generation Core (`core.py`)
- **`GeneratorCore`**: Provides `generate_iterator` method, executes single-step generation
- Applies sampling strategies (temperature, top-k, top-p) to filter logits
- Supports KV cache to accelerate autoregressive generation
#### 4.1 Inference Engine (`engine.py`)
- **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
- Manages task queue (`waiting_queue`, `active_tasks`) and KV cache allocation
#### 4.2 KV Cache Management (`core.py`)
- **`KVCacheManager`**: Manages K and V cache for each layer, supports batch generation and length extension
- Cache shape is `[batch_size, n_kv_heads, seq_len, head_dim]`
#### 4.2 Scheduler (`scheduler.py`)
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
- **`TaskStatus`**: Task state enumeration
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
- Continuous batching: new requests can join at any time, completed requests are released immediately
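A minimal sketch of `apply_sampling_strategies` as described above: temperature scaling, then top-k truncation, then top-p (nucleus) filtering, then sampling. The actual signature and batching behavior in `scheduler.py` may differ:

```python
import torch

# Illustrative sampling pipeline; logits has shape [batch, vocab_size].
def apply_sampling_strategies(logits, temperature=1.0, top_k=0, top_p=1.0):
    logits = logits / max(temperature, 1e-5)       # temperature scaling
    if top_k > 0:                                  # keep only the k largest logits
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p < 1.0:                                # nucleus filtering
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        probs = torch.softmax(sorted_logits, dim=-1)
        # drop tokens whose cumulative probability (excluding themselves)
        # already exceeds top_p; the top-1 token is always kept
        mask = probs.cumsum(dim=-1) - probs > top_p
        sorted_logits = sorted_logits.masked_fill(mask, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(
            -1, sorted_idx, sorted_logits)
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # [batch, 1] sampled ids
```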
#### 4.3 Generator (`generator.py`)
- **`GenerationRequest`**: Encapsulates generation request parameters (top_k, top_p, temperature, max_len, query, history, etc.)
#### 4.3 Request (`engine.py`)
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, query, history, etc.)
- **`build_prompt`** (from `chat_template.py`): Converts query and history into ChatML format prompt string
- **`GeneratorCore`**: Base generator with generate_iterator and generate_loop methods
- **`LoopGenerator`**, **`StreamGenerator`**, **`BatchGenerator`**: Different generation modes
- **`pad_sequence`**: Pads input IDs to consistent length
- Provides streaming and non-streaming generation interfaces
- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
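For illustration, a ChatML-style `build_prompt` could look like the sketch below; the exact template in `chat_template.py` (system prompt handling, role names) is an assumption:

```python
# Hypothetical ChatML prompt builder; history is a list of (user, assistant) turns.
def build_prompt(query: str, history: list[tuple[str, str]] | None = None) -> str:
    parts = []
    for user_turn, assistant_turn in history or []:
        parts.append(f"<|im_start|>user\n{user_turn}<|im_end|>\n")
        parts.append(f"<|im_start|>assistant\n{assistant_turn}<|im_end|>\n")
    parts.append(f"<|im_start|>user\n{query}<|im_end|>\n")
    parts.append("<|im_start|>assistant\n")  # the model continues from here
    return "".join(parts)
```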
## Training Data Flow - Detailed Steps

View File

@@ -311,41 +311,59 @@ classDiagram
}
namespace astrai.inference {
class GeneratorCore {
class InferenceEngine {
+ModelParameter parameter
+generate_iterator(input_ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos) Tuple~Tensor, int~
+generate_loop(input_ids, ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos, callback) List~int~
+InferenceScheduler scheduler
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]]
+get_stats() Dict
+shutdown()
}
class LoopGenerator {
+generate(request) str
}
class StreamGenerator {
+generate(request) Generator
}
class BatchGenerator {
+generate(request) List~str~
}
class EmbeddingEncoder {
+encode(sentence) Tensor
}
class KVCacheManager {
+int batch_size
class InferenceScheduler {
+nn.Module model
+Tokenizer tokenizer
+ModelConfig config
+Tuple kv_cache
+Tensor seq_mask
+get_kvcache() Tuple
+get_seq_mask() Tensor
+update(active_mask)
+reset(full_reset)
+List waiting_queue
+List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id)
+start()
+stop()
+get_stats() Dict
}
class GeneratorFactory {
+create(parameter, request) GeneratorCore
+create_encoder(parameter) EmbeddingEncoderCore
class Task {
+str task_id
+List prompt_ids
+int max_tokens
+float temperature
+float top_p
+int top_k
+TaskStatus status
+List output_ids
+int input_tokens
+int output_tokens
+int slot
+Callable stream_callback
+is_finished(stop_ids) bool
}
class TaskStatus {
+str PENDING
+str RUNNING
+str FINISHED
+str ABORTED
}
class apply_sampling_strategies {
+Tensor logits
+float temperature
+int top_k
+float top_p
+forward() Tensor
}
class Server {
@@ -396,28 +414,25 @@ classDiagram
BaseStrategy <|-- SFTStrategy
BaseStrategy <|-- DPOStrategy
BaseStrategy <|-- GRPOStrategy
DPOStrategy --> Transformer : creates ref_model
GRPOStrategy --> Transformer : creates ref_model
SchedulerFactory ..> BaseScheduler : creates
BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler
CallbackFactory ..> TrainCallback : creates
GeneratorFactory ..> GeneratorCore : creates
GeneratorCore <|-- LoopGenerator
GeneratorCore <|-- StreamGenerator
GeneratorCore <|-- BatchGenerator
GeneratorCore <|-- EmbeddingEncoder
LoopGenerator --> KVCacheManager : uses
StreamGenerator --> KVCacheManager : uses
BatchGenerator --> KVCacheManager : uses
GeneratorCore --> Transformer : uses
DPOStrategy --> Transformer : creates ref_model
GRPOStrategy --> Transformer : creates ref_model
Server --> GeneratorFactory : uses
InferenceEngine --> InferenceScheduler : uses
InferenceScheduler --> Task : manages
InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> apply_sampling_strategies : uses
InferenceScheduler --> Transformer : uses
InferenceEngine --> Transformer : uses
InferenceEngine --> GenerationRequest : uses
Server --> InferenceEngine : uses
ParallelSetup --> Trainer : enables
TrainConfig --> StrategyFactory : selects
ModelParameter --> Transformer : contains
ModelParameter --> BpeTokenizer : contains
ModelParameter --> ModelConfig : contains
GeneratorFactory --> GenerationRequest : uses
BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset
@@ -446,7 +461,7 @@ classDiagram
| **astrai.model** | Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | Tokenizer, BpeTokenizer | Tokenizer |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | GeneratorCore, LoopGenerator, StreamGenerator, BatchGenerator, EmbeddingEncoder, KVCacheManager, GeneratorFactory, Server, GenerationRequest | Inference service |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest | Inference service with continuous batching |
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
### Design Patterns
@@ -455,17 +470,19 @@ classDiagram
|---------|---------|---------|
| **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
| **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `GeneratorFactory`, `CallbackFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, generators, and callbacks |
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Singleton** | `TrainContext` | Training process global state management |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
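The last two rows work together: producers enqueue tasks and signal an event, while the scheduler thread drains the queue and steps the active batch. A minimal sketch reusing the `waiting_queue` / `active_tasks` / `_task_event` names; the loop body and `Task.step` are illustrative assumptions, not the real `InferenceScheduler`:

```python
import threading
from collections import deque

class MiniScheduler:
    def __init__(self):
        self.waiting_queue = deque()          # tasks waiting for a slot
        self.active_tasks = []                # tasks currently being decoded
        self._task_event = threading.Event()  # wakes the scheduler thread
        self._stop = False

    def add_task(self, task):                 # producer side (request handlers)
        self.waiting_queue.append(task)
        self._task_event.set()

    def run(self):                            # consumer side (scheduler thread)
        while not self._stop:
            if not self.waiting_queue and not self.active_tasks:
                self._task_event.wait()       # idle without busy-spinning
                self._task_event.clear()
            while self.waiting_queue:         # continuous batching: admit anytime
                self.active_tasks.append(self.waiting_queue.popleft())
            # one decode step per task; finished tasks are released immediately
            self.active_tasks = [t for t in self.active_tasks if not t.step()]
```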
### Core Relationships
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `Server` → `GeneratorFactory` → `GeneratorCore` → `Transformer`, supports multiple generators (LoopGenerator, StreamGenerator, BatchGenerator)
4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, supports continuous batching with streaming and non-streaming generation
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
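For relationship 7, a minimal sketch of safetensors-based (de)serialization; the real `Checkpoint` likely wraps more metadata (config, tokenizer, optimizer state):

```python
import torch
from safetensors.torch import save_file, load_file

def save_checkpoint(model: torch.nn.Module, path: str) -> None:
    # safetensors requires contiguous tensors and rejects tensors that share
    # memory (e.g. tied embeddings), which a real Checkpoint must handle
    tensors = {k: v.contiguous() for k, v in model.state_dict().items()}
    save_file(tensors, path)

def load_checkpoint(model: torch.nn.Module, path: str) -> None:
    model.load_state_dict(load_file(path))
```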

View File

@@ -83,15 +83,15 @@
### Usage Example
```python
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import StreamGenerator, GenerationRequest
import torch

from astrai.config import ModelParameter
from astrai.inference import InferenceEngine, GenerationRequest
# Load model
param = ModelParameter.load("your_model_dir")
param.to(device="cuda", dtype=torch.bfloat16)
# Create generator
generator = StreamGenerator(param)
# Create engine
engine = InferenceEngine(param)
# Build request
request = GenerationRequest(
@@ -102,14 +102,14 @@ request = GenerationRequest(
top_k=50,
)
# Generate
response = generator.generate(request)
# Generate (streaming)
for token in engine.generate_with_request(request):
print(token, end="", flush=True)
```
### Three Types of Generators
### Generation Modes
| Generator | Usage |
|-----------|-------|
| `StreamGenerator` | Streaming output, returns word by word |
| `LoopGenerator` | Non-streaming output, returns at once |
| `BatchGenerator` | Batch generation, processes multiple queries simultaneously |
| Mode | Description |
|------|-------------|
| `stream=True` | Streaming output, yields token by token |
| `stream=False` | Non-streaming output, returns complete result |
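Both modes go through the same entry point. A short sketch, assuming the `engine` from the usage example above and the `generate` signature from the class diagram:

```python
# prompt would normally come from build_prompt(query, history)
prompt = "Hello, who are you?"

# stream=True: yields tokens as they are generated
for token in engine.generate(prompt, stream=True, max_tokens=128):
    print(token, end="", flush=True)

# stream=False: blocks until generation finishes, returns the full string
print(engine.generate(prompt, stream=False, max_tokens=128))
```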