docs: Update design document

parent 2b26f03bd3
commit ff43a2fab8
@@ -8,7 +8,7 @@ AstrAI adopts a modular design with the following main components:
 - **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
 - **Model Module** (`astrai/model/`): Transformer model and its submodules
 - **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
-- **Inference Module** (`astrai/inference/`): Generation core, KV cache management, streaming generation
+- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
 - **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
 - **Parallel Module** (`astrai/parallel/`): Distributed training support
 
@@ -45,11 +45,11 @@ flowchart LR
     C1[Checkpoint] --> C2[ModelParameter]
     C2 --> C3[Transformer + BpeTokenizer]
     C3 --> C4[GenerationRequest + build_prompt]
-    C4 --> C5[GeneratorFactory]
-    C5 --> C6[GeneratorCore]
+    C4 --> C5[InferenceEngine]
+    C5 --> C6[InferenceScheduler]
     C6 --> C7[apply_sampling_strategies]
     C7 --> C8[Transformer Forward]
-    C8 --> C9[KVCacheManager]
+    C8 --> C9[KV Cache]
     C9 --> C10{End Condition?}
     C10 -->|No| C8
     C10 -->|Yes| C11[Output Text]
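The new inference data flow above (request → engine → scheduler → sampling → forward → KV cache → end-condition loop) is the standard autoregressive decode loop. A minimal sketch, with a stubbed model in place of the real Transformer forward pass (`decode_loop` and `toy_forward` are illustrative names, not AstrAI API):

```python
def decode_loop(forward, prompt_ids, eos_id, max_new_tokens):
    """Generic autoregressive decode loop: feed tokens, pick the next one,
    repeat until EOS or the token budget runs out.

    `forward` stands in for the Transformer forward pass; a real engine
    would also thread a KV cache through it so each step only has to
    process the newest token.
    """
    ids = list(prompt_ids)
    out = []
    for _ in range(max_new_tokens):
        next_id = forward(ids)       # "Transformer Forward" (greedy here)
        ids.append(next_id)          # with a KV cache this step is O(1)
        out.append(next_id)
        if next_id == eos_id:        # "End Condition?" -> Yes: stop
            break
    return out

# Toy "model": counts down from 5 minus the sequence length to 0 (fake EOS).
toy_forward = lambda ids: max(0, 5 - len(ids))
```

For example, `decode_loop(toy_forward, [1, 2, 3], eos_id=0, max_new_tokens=10)` emits tokens until the fake EOS `0` appears.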
@@ -124,22 +124,21 @@ flowchart LR
 
 ### 4. Inference Module
 
-#### 4.1 Generation Core (`core.py`)
-- **`GeneratorCore`**: Provides `generate_iterator` method, executes single-step generation
-- Applies sampling strategies (temperature, top-k, top-p) to filter logits
-- Supports KV cache to accelerate autoregressive generation
+#### 4.1 Inference Engine (`engine.py`)
+- **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
+- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
+- Manages task queue (`waiting_queue`, `active_tasks`) and KV cache allocation
 
-#### 4.2 KV Cache Management (`core.py`)
-- **`KVCacheManager`**: Manages K and V cache for each layer, supports batch generation and length extension
-- Cache shape is `[batch_size, n_kv_heads, seq_len, head_dim]`
+#### 4.2 Scheduler (`scheduler.py`)
+- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
+- **`TaskStatus`**: Task state enumeration
+- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
+- Continuous batching: new requests can join at any time, completed requests are released immediately
 
-#### 4.3 Generator (`generator.py`)
-- **`GenerationRequest`**: Encapsulates generation request parameters (top_k, top_p, temperature, max_len, query, history, etc.)
+#### 4.3 Request (`engine.py`)
+- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, query, history, etc.)
 - **`build_prompt`** (from `chat_template.py`): Converts query and history into ChatML format prompt string
-- **`GeneratorCore`**: Base generator with generate_iterator and generate_loop methods
-- **`LoopGenerator`**, **`StreamGenerator`**, **`BatchGenerator`**: Different generation modes
-- **`pad_sequence`**: Pads input IDs to consistent length
-- Provides streaming and non-streaming generation interfaces
+- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
 
 ## Training Data Flow - Detailed Steps
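The continuous batching described in this hunk — waiting tasks admitted into free batch slots each step, finished tasks released immediately — can be sketched in a few lines. This is a hypothetical simplification that only borrows the hunk's `waiting_queue`/`active_tasks` names; it is not the actual `InferenceScheduler`:

```python
from collections import deque

class MiniScheduler:
    """Toy continuous-batching step: admit from waiting_queue into free
    batch slots, advance every active task one token, release finished
    tasks immediately so their slots are reusable next step."""

    def __init__(self, max_batch):
        self.max_batch = max_batch
        self.waiting_queue = deque()
        self.active_tasks = []
        self.finished = []

    def add_task(self, task):
        # New requests can join at any time.
        self.waiting_queue.append(task)

    def step(self):
        # Admission: fill free batch slots from the waiting queue.
        while self.waiting_queue and len(self.active_tasks) < self.max_batch:
            self.active_tasks.append(self.waiting_queue.popleft())
        # One decode step for the whole batch (here: pop a precomputed token).
        for task in self.active_tasks:
            task["output"].append(task["tokens"].pop(0))
        # Release: completed requests leave immediately, freeing their slot.
        self.finished += [t for t in self.active_tasks if not t["tokens"]]
        self.active_tasks = [t for t in self.active_tasks if t["tokens"]]
```

With `max_batch=2` and three tasks of lengths 1, 2, 1, the third task is admitted on the second step into the slot the first task freed, so all three finish after two steps.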
@@ -311,41 +311,59 @@ classDiagram
     }

     namespace astrai.inference {
-        class GeneratorCore {
+        class InferenceEngine {
             +ModelParameter parameter
-            +generate_iterator(input_ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos) Tuple~Tensor, int~
-            +generate_loop(input_ids, ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos, callback) List~int~
+            +InferenceScheduler scheduler
+            +generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+            +generate_with_request(request) Union[Generator, str, List[str]]
+            +get_stats() Dict
+            +shutdown()
         }

-        class LoopGenerator {
-            +generate(request) str
-        }
-
-        class StreamGenerator {
-            +generate(request) Generator
-        }
-
-        class BatchGenerator {
-            +generate(request) List~str~
-        }
-
-        class EmbeddingEncoder {
-            +encode(sentence) Tensor
-        }
-
-        class KVCacheManager {
-            +int batch_size
+        class InferenceScheduler {
+            +nn.Module model
+            +Tokenizer tokenizer
+            +ModelConfig config
             +Tuple kv_cache
             +Tensor seq_mask
-            +get_kvcache() Tuple
-            +get_seq_mask() Tensor
-            +update(active_mask)
-            +reset(full_reset)
+            +List waiting_queue
+            +List active_tasks
+            +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+            +remove_task(task_id)
+            +start()
+            +stop()
+            +get_stats() Dict
         }

-        class GeneratorFactory {
-            +create(parameter, request) GeneratorCore
-            +create_encoder(parameter) EmbeddingEncoderCore
+        class Task {
+            +str task_id
+            +List prompt_ids
+            +int max_tokens
+            +float temperature
+            +float top_p
+            +int top_k
+            +TaskStatus status
+            +List output_ids
+            +int input_tokens
+            +int output_tokens
+            +int slot
+            +Callable stream_callback
+            +is_finished(stop_ids) bool
+        }
+
+        class TaskStatus {
+            +str PENDING
+            +str RUNNING
+            +str FINISHED
+            +str ABORTED
+        }
+
+        class apply_sampling_strategies {
+            +Tensor logits
+            +float temperature
+            +int top_k
+            +float top_p
+            +forward() Tensor
         }
     }

     class Server {
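The `apply_sampling_strategies` entry added to the class diagram lists logits, temperature, top_k, and top_p. The usual composition of these three filters can be sketched in plain Python over a list of logits; the real code presumably operates on torch tensors, so this is an assumption about the semantics, not AstrAI's implementation:

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def apply_sampling_strategies(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Temperature-scale the logits, then keep only the top-k tokens and/or
    the smallest nucleus whose cumulative probability reaches top_p, and
    renormalise the surviving probabilities."""
    probs = softmax([x / max(temperature, 1e-6) for x in logits])
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    keep = set(order)
    if top_k > 0:
        keep &= set(order[:top_k])              # top-k filter
    if top_p < 1.0:
        cum, nucleus = 0.0, []
        for i in order:                         # nucleus (top-p) filter
            nucleus.append(i)
            cum += probs[i]
            if cum >= top_p:
                break
        keep &= set(nucleus)
    kept = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    total = sum(kept)
    return [p / total for p in kept]            # renormalised distribution
```

With `top_k=1` (or a small `top_p`) the filtered distribution collapses onto the single most likely token, which reduces sampling to greedy decoding.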
@@ -396,28 +414,25 @@ classDiagram
     BaseStrategy <|-- SFTStrategy
     BaseStrategy <|-- DPOStrategy
     BaseStrategy <|-- GRPOStrategy
+    DPOStrategy --> Transformer : creates ref_model
+    GRPOStrategy --> Transformer : creates ref_model
     SchedulerFactory ..> BaseScheduler : creates
     BaseScheduler <|-- CosineScheduler
     BaseScheduler <|-- SGDRScheduler
     CallbackFactory ..> TrainCallback : creates
-    GeneratorFactory ..> GeneratorCore : creates
-    GeneratorCore <|-- LoopGenerator
-    GeneratorCore <|-- StreamGenerator
-    GeneratorCore <|-- BatchGenerator
-    GeneratorCore <|-- EmbeddingEncoder
-    LoopGenerator --> KVCacheManager : uses
-    StreamGenerator --> KVCacheManager : uses
-    BatchGenerator --> KVCacheManager : uses
-    GeneratorCore --> Transformer : uses
-    DPOStrategy --> Transformer : creates ref_model
-    GRPOStrategy --> Transformer : creates ref_model
-    Server --> GeneratorFactory : uses
+    InferenceEngine --> InferenceScheduler : uses
+    InferenceScheduler --> Task : manages
+    InferenceScheduler --> TaskStatus : uses
+    InferenceScheduler --> apply_sampling_strategies : uses
+    InferenceScheduler --> Transformer : uses
+    InferenceEngine --> Transformer : uses
+    InferenceEngine --> GenerationRequest : uses
+    Server --> InferenceEngine : uses
     ParallelSetup --> Trainer : enables
     TrainConfig --> StrategyFactory : selects
     ModelParameter --> Transformer : contains
     ModelParameter --> BpeTokenizer : contains
     ModelParameter --> ModelConfig : contains
-    GeneratorFactory --> GenerationRequest : uses
     BaseDataset <|-- SEQDataset
     BaseDataset <|-- SFTDataset
     BaseDataset <|-- DPODataset
@@ -446,7 +461,7 @@ classDiagram
 | **astrai.model** | Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
 | **astrai.tokenize** | Tokenizer, BpeTokenizer | Tokenizer |
 | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
-| **astrai.inference** | GeneratorCore, LoopGenerator, StreamGenerator, BatchGenerator, EmbeddingEncoder, KVCacheManager, GeneratorFactory, Server, GenerationRequest | Inference service |
+| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest | Inference service with continuous batching |
 | **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |

 ### Design Patterns
@@ -455,17 +470,19 @@ classDiagram
 |---------|---------|---------|
 | **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
 | **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
-| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `GeneratorFactory`, `CallbackFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, generators, and callbacks |
+| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
 | **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
 | **Singleton** | `TrainContext` | Training process global state management |
 | **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
+| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
+| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |

 ### Core Relationships

 1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
 2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
 3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
-4. **Inference Flow**: `Server` → `GeneratorFactory` → `GeneratorCore` → `Transformer`, supports multiple generators (LoopGenerator, StreamGenerator, BatchGenerator)
+4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, supports continuous batching with streaming/non-streaming
 5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
 6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
 7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
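The Producer-Consumer and Event-Driven rows added above combine `threading.Event` with the task queue so the scheduler thread sleeps until work arrives instead of busy-waiting. A minimal sketch of that pairing (the class and its members are hypothetical, only the `_task_event` name comes from the diff):

```python
import threading
from collections import deque

class EventDrivenQueue:
    """Producer-consumer skeleton: add_task() is the producer, a worker
    thread running run() is the consumer, and an Event replaces polling."""

    def __init__(self):
        self._queue = deque()
        self._task_event = threading.Event()   # signalled when work exists
        self._stop = False
        self.processed = []

    def add_task(self, item):
        self._queue.append(item)
        self._task_event.set()                 # wake the scheduler thread

    def stop(self):
        self._stop = True
        self._task_event.set()                 # wake the loop so it can exit

    def run(self):
        while not self._stop:
            self._task_event.wait()            # sleep until a producer signals
            self._task_event.clear()
            while self._queue:
                self.processed.append(self._queue.popleft())
```

Because producers always `set()` after appending, a task enqueued between `clear()` and the drain loop is never lost: either the drain picks it up, or the event is set again and the next `wait()` returns immediately.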
@@ -83,15 +83,15 @@
 ### Usage Example

 ```python
-from astrai.config.param_config import ModelParameter
-from astrai.inference.generator import StreamGenerator, GenerationRequest
+from astrai.config import ModelParameter
+from astrai.inference import InferenceEngine, GenerationRequest

 # Load model
 param = ModelParameter.load("your_model_dir")
 param.to(device="cuda", dtype=torch.bfloat16)

-# Create generator
-generator = StreamGenerator(param)
+# Create engine
+engine = InferenceEngine(param)

 # Build request
 request = GenerationRequest(
@@ -102,14 +102,14 @@ request = GenerationRequest(
     top_k=50,
 )

-# Generate
-response = generator.generate(request)
+# Generate (streaming)
+for token in engine.generate_with_request(request):
+    print(token, end="", flush=True)
 ```

-### Three Types of Generators
+### Generation Modes

-| Generator | Usage |
-|-----------|-------|
-| `StreamGenerator` | Streaming output, returns word by word |
-| `LoopGenerator` | Non-streaming output, returns at once |
-| `BatchGenerator` | Batch generation, processes multiple queries simultaneously |
+| Mode | Description |
+|------|-------------|
+| `stream=True` | Streaming output, yields token by token |
+| `stream=False` | Non-streaming output, returns complete result |
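The mode table distinguishes `stream=True` (yield token by token) from `stream=False` (return the complete result). The usual way to serve both modes from one decode path is a small dispatch like the following sketch; `generate` and `token_source` are illustrative names, not the actual `InferenceEngine` code:

```python
def generate(token_source, stream=False):
    """Dispatch between streaming and non-streaming modes over the same
    underlying token iterator (`token_source` stands in for the decode loop)."""
    def _stream():
        for tok in token_source:
            yield tok                      # stream=True: token by token
    if stream:
        return _stream()
    return "".join(token_source)           # stream=False: complete result
```

A caller either iterates the returned generator (`for tok in generate(src, stream=True)`) or takes the joined string directly.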