docs: 更新设计文档
This commit is contained in:
parent
2b26f03bd3
commit
ff43a2fab8
|
|
@ -8,7 +8,7 @@ AstrAI adopts a modular design with the following main components:
|
|||
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
|
||||
- **Model Module** (`astrai/model/`): Transformer model and its submodules
|
||||
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
|
||||
- **Inference Module** (`astrai/inference/`): Generation core, KV cache management, streaming generation
|
||||
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
|
||||
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
|
||||
- **Parallel Module** (`astrai/parallel/`): Distributed training support
|
||||
|
||||
|
|
@ -45,11 +45,11 @@ flowchart LR
|
|||
C1[Checkpoint] --> C2[ModelParameter]
|
||||
C2 --> C3[Transformer + BpeTokenizer]
|
||||
C3 --> C4[GenerationRequest + build_prompt]
|
||||
C4 --> C5[GeneratorFactory]
|
||||
C5 --> C6[GeneratorCore]
|
||||
C4 --> C5[InferenceEngine]
|
||||
C5 --> C6[InferenceScheduler]
|
||||
C6 --> C7[apply_sampling_strategies]
|
||||
C7 --> C8[Transformer Forward]
|
||||
C8 --> C9[KVCacheManager]
|
||||
C8 --> C9[KV Cache]
|
||||
C9 --> C10{End Condition?}
|
||||
C10 -->|No| C8
|
||||
C10 -->|Yes| C11[Output Text]
|
||||
|
|
@ -124,22 +124,21 @@ flowchart LR
|
|||
|
||||
### 4. Inference Module
|
||||
|
||||
#### 4.1 Generation Core (`core.py`)
|
||||
- **`GeneratorCore`**: Provides `generate_iterator` method, executes single-step generation
|
||||
- Applies sampling strategies (temperature, top-k, top-p) to filter logits
|
||||
- Supports KV cache to accelerate autoregressive generation
|
||||
#### 4.1 Inference Engine (`engine.py`)
|
||||
- **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
|
||||
- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
|
||||
- Manages task queue (`waiting_queue`, `active_tasks`) and KV cache allocation
|
||||
|
||||
#### 4.2 KV Cache Management (`core.py`)
|
||||
- **`KVCacheManager`**: Manages K and V cache for each layer, supports batch generation and length extension
|
||||
- Cache shape is `[batch_size, n_kv_heads, seq_len, head_dim]`
|
||||
#### 4.2 Scheduler (`scheduler.py`)
|
||||
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
|
||||
- **`TaskStatus`**: Task state enumeration
|
||||
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
|
||||
- Continuous batching: new requests can join at any time, completed requests are released immediately
|
||||
|
||||
#### 4.3 Generator (`generator.py`)
|
||||
- **`GenerationRequest`**: Encapsulates generation request parameters (top_k, top_p, temperature, max_len, query, history, etc.)
|
||||
#### 4.3 Request (`engine.py`)
|
||||
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, query, history, etc.)
|
||||
- **`build_prompt`** (from `chat_template.py`): Converts query and history into ChatML format prompt string
|
||||
- **`GeneratorCore`**: Base generator with generate_iterator and generate_loop methods
|
||||
- **`LoopGenerator`**, **`StreamGenerator`**, **`BatchGenerator`**: Different generation modes
|
||||
- **`pad_sequence`**: Pads input IDs to consistent length
|
||||
- Provides streaming and non-streaming generation interfaces
|
||||
- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
|
||||
|
||||
## Training Data Flow - Detailed Steps
|
||||
|
||||
|
|
|
|||
|
|
@ -311,41 +311,59 @@ classDiagram
|
|||
}
|
||||
|
||||
namespace astrai.inference {
|
||||
class GeneratorCore {
|
||||
class InferenceEngine {
|
||||
+ModelParameter parameter
|
||||
+generate_iterator(input_ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos) Tuple~Tensor, int~
|
||||
+generate_loop(input_ids, ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos, callback) List~int~
|
||||
+InferenceScheduler scheduler
|
||||
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
|
||||
+generate_with_request(request) Union[Generator, str, List[str]]
|
||||
+get_stats() Dict
|
||||
+shutdown()
|
||||
}
|
||||
|
||||
class LoopGenerator {
|
||||
+generate(request) str
|
||||
}
|
||||
|
||||
class StreamGenerator {
|
||||
+generate(request) Generator
|
||||
}
|
||||
|
||||
class BatchGenerator {
|
||||
+generate(request) List~str~
|
||||
}
|
||||
|
||||
class EmbeddingEncoder {
|
||||
+encode(sentence) Tensor
|
||||
}
|
||||
|
||||
class KVCacheManager {
|
||||
+int batch_size
|
||||
class InferenceScheduler {
|
||||
+nn.Module model
|
||||
+Tokenizer tokenizer
|
||||
+ModelConfig config
|
||||
+Tuple kv_cache
|
||||
+Tensor seq_mask
|
||||
+get_kvcache() Tuple
|
||||
+get_seq_mask() Tensor
|
||||
+update(active_mask)
|
||||
+reset(full_reset)
|
||||
+List waiting_queue
|
||||
+List active_tasks
|
||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||
+remove_task(task_id)
|
||||
+start()
|
||||
+stop()
|
||||
+get_stats() Dict
|
||||
}
|
||||
|
||||
class GeneratorFactory {
|
||||
+create(parameter, request) GeneratorCore
|
||||
+create_encoder(parameter) EmbeddingEncoderCore
|
||||
class Task {
|
||||
+str task_id
|
||||
+List prompt_ids
|
||||
+int max_tokens
|
||||
+float temperature
|
||||
+float top_p
|
||||
+int top_k
|
||||
+TaskStatus status
|
||||
+List output_ids
|
||||
+int input_tokens
|
||||
+int output_tokens
|
||||
+int slot
|
||||
+Callable stream_callback
|
||||
+is_finished(stop_ids) bool
|
||||
}
|
||||
|
||||
class TaskStatus {
|
||||
+str PENDING
|
||||
+str RUNNING
|
||||
+str FINISHED
|
||||
+str ABORTED
|
||||
}
|
||||
|
||||
class apply_sampling_strategies {
|
||||
+Tensor logits
|
||||
+float temperature
|
||||
+int top_k
|
||||
+float top_p
|
||||
+forward() Tensor
|
||||
}
|
||||
|
||||
class Server {
|
||||
|
|
@ -396,28 +414,25 @@ classDiagram
|
|||
BaseStrategy <|-- SFTStrategy
|
||||
BaseStrategy <|-- DPOStrategy
|
||||
BaseStrategy <|-- GRPOStrategy
|
||||
DPOStrategy --> Transformer : creates ref_model
|
||||
GRPOStrategy --> Transformer : creates ref_model
|
||||
SchedulerFactory ..> BaseScheduler : creates
|
||||
BaseScheduler <|-- CosineScheduler
|
||||
BaseScheduler <|-- SGDRScheduler
|
||||
CallbackFactory ..> TrainCallback : creates
|
||||
GeneratorFactory ..> GeneratorCore : creates
|
||||
GeneratorCore <|-- LoopGenerator
|
||||
GeneratorCore <|-- StreamGenerator
|
||||
GeneratorCore <|-- BatchGenerator
|
||||
GeneratorCore <|-- EmbeddingEncoder
|
||||
LoopGenerator --> KVCacheManager : uses
|
||||
StreamGenerator --> KVCacheManager : uses
|
||||
BatchGenerator --> KVCacheManager : uses
|
||||
GeneratorCore --> Transformer : uses
|
||||
DPOStrategy --> Transformer : creates ref_model
|
||||
GRPOStrategy --> Transformer : creates ref_model
|
||||
Server --> GeneratorFactory : uses
|
||||
InferenceEngine --> InferenceScheduler : uses
|
||||
InferenceScheduler --> Task : manages
|
||||
InferenceScheduler --> TaskStatus : uses
|
||||
InferenceScheduler --> apply_sampling_strategies : uses
|
||||
InferenceScheduler --> Transformer : uses
|
||||
InferenceEngine --> Transformer : uses
|
||||
InferenceEngine --> GenerationRequest : uses
|
||||
Server --> InferenceEngine : uses
|
||||
ParallelSetup --> Trainer : enables
|
||||
TrainConfig --> StrategyFactory : selects
|
||||
ModelParameter --> Transformer : contains
|
||||
ModelParameter --> BpeTokenizer : contains
|
||||
ModelParameter --> ModelConfig : contains
|
||||
GeneratorFactory --> GenerationRequest : uses
|
||||
BaseDataset <|-- SEQDataset
|
||||
BaseDataset <|-- SFTDataset
|
||||
BaseDataset <|-- DPODataset
|
||||
|
|
@ -446,7 +461,7 @@ classDiagram
|
|||
| **astrai.model** | Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||
| **astrai.tokenize** | Tokenizer, BpeTokenizer | Tokenizer |
|
||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
||||
| **astrai.inference** | GeneratorCore, LoopGenerator, StreamGenerator, BatchGenerator, EmbeddingEncoder, KVCacheManager, GeneratorFactory, Server, GenerationRequest | Inference service |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest | Inference service with continuous batching |
|
||||
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
||||
|
||||
### Design Patterns
|
||||
|
|
@ -455,17 +470,19 @@ classDiagram
|
|||
|---------|---------|---------|
|
||||
| **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
|
||||
| **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
|
||||
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `GeneratorFactory`, `CallbackFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, generators, and callbacks |
|
||||
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
|
||||
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
|
||||
| **Singleton** | `TrainContext` | Training process global state management |
|
||||
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
|
||||
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
|
||||
|
||||
### Core Relationships
|
||||
|
||||
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
|
||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
|
||||
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
||||
4. **Inference Flow**: `Server` → `GeneratorFactory` → `GeneratorCore` → `Transformer`, supports multiple generators (LoopGenerator, StreamGenerator, BatchGenerator)
|
||||
4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, supports continuous batching with streaming/non-streaming
|
||||
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
|
||||
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
||||
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
||||
|
|
|
|||
|
|
@ -83,15 +83,15 @@
|
|||
### Usage Example
|
||||
|
||||
```python
|
||||
from astrai.config.param_config import ModelParameter
|
||||
from astrai.inference.generator import StreamGenerator, GenerationRequest
|
||||
from astrai.config import ModelParameter
|
||||
from astrai.inference import InferenceEngine, GenerationRequest
|
||||
|
||||
# Load model
|
||||
param = ModelParameter.load("your_model_dir")
|
||||
param.to(device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Create generator
|
||||
generator = StreamGenerator(param)
|
||||
# Create engine
|
||||
engine = InferenceEngine(param)
|
||||
|
||||
# Build request
|
||||
request = GenerationRequest(
|
||||
|
|
@ -102,14 +102,14 @@ request = GenerationRequest(
|
|||
top_k=50,
|
||||
)
|
||||
|
||||
# Generate
|
||||
response = generator.generate(request)
|
||||
# Generate (streaming)
|
||||
for token in engine.generate_with_request(request):
|
||||
print(token, end="", flush=True)
|
||||
```
|
||||
|
||||
### Three Types of Generators
|
||||
### Generation Modes
|
||||
|
||||
| Generator | Usage |
|
||||
|-----------|-------|
|
||||
| `StreamGenerator` | Streaming output, returns word by word |
|
||||
| `LoopGenerator` | Non-streaming output, returns at once |
|
||||
| `BatchGenerator` | Batch generation, processes multiple queries simultaneously |
|
||||
| Mode | Description |
|
||||
|------|-------------|
|
||||
| `stream=True` | Streaming output, yields token by token |
|
||||
| `stream=False` | Non-streaming output, returns complete result |
|
||||
Loading…
Reference in New Issue