chore: update docs, fix code formatting

parent 23ce4bc3ae
commit d2fe8afbd1
@@ -4,6 +4,7 @@
 # Files that MUST use LF (Unix/Linux execution)
 *.sh text eol=lf
 *.py text eol=lf
+*.md text eol=lf
 Dockerfile text eol=lf
 .dockerignore text eol=lf

@@ -47,7 +47,7 @@
 ### Features
 
 - 🚀 **High Performance**: Optimized for both training and inference with efficient parallelization.
-- 🔧 **Flexible**: Support for seq/sft/dpo training, customizable model architectures.
+- 🔧 **Flexible**: Support for seq/sft/dpo/grpo training, customizable model architectures.
 - 💡 **Easy to Use**: Simple API with comprehensive examples and demos.
 - 📦 **Lightweight**: Minimal dependencies, easy to deploy.
 - 🔬 **Research-Friendly**: Modular design, easy to experiment with new ideas.

@@ -48,7 +48,7 @@
 ### Features
 
 - 🚀 **High Performance**: Optimized for both training and inference, with efficient parallelism.
-- 🔧 **Flexible**: Supports seq/sft/dpo training modes, customizable model architectures.
+- 🔧 **Flexible**: Supports seq/sft/dpo/grpo training modes, customizable model architectures.
 - 💡 **Easy to Use**: Clean API with rich examples and demos.
 - 📦 **Lightweight**: Few dependencies, simple deployment.
 - 🔬 **Research-Friendly**: Modular design, easy to experiment with new ideas.

@@ -6,10 +6,11 @@ This document describes the data flow of the AstrAI project (a training and inference
 
 AstrAI adopts a modular design with the following main components:
 - **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
-- **Model Module** (`astrai/model/`): Transformer model and its submodules
+- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
 - **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
 - **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
 - **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
+- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
 - **Parallel Module** (`astrai/parallel/`): Distributed training support
 
 The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.

@@ -42,9 +43,9 @@ flowchart LR
 
     subgraph C[Inference]
         direction TB
-        C1[Checkpoint] --> C2[ModelParameter]
-        C2 --> C3[Transformer + BpeTokenizer]
-        C3 --> C4[GenerationRequest + build_prompt]
+        C1[Checkpoint] --> C2[AutoModel]
+        C2 --> C3[Transformer + Tokenizer]
+        C3 --> C4[GenerationRequest + apply_chat_template]
         C4 --> C5[InferenceEngine]
         C5 --> C6[InferenceScheduler]
         C6 --> C7[apply_sampling_strategies]

@@ -88,8 +89,9 @@ flowchart LR
 
 ### 2. Model Module
 
-#### 2.1 Transformer (`transformer.py`)
-- Core autoregressive decoder architecture
+#### 2.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
+- **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods
+- **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`)
 - Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
 - Supports weight tying (`tie_weight=True`) to reduce parameter count
 - Uses Rotary Position Embedding (RoPE) to inject position information

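The weight tying mentioned in the hunk above can be illustrated with a minimal NumPy sketch; shapes and names here are illustrative, not AstrAI's actual code:

```python
import numpy as np

# Minimal sketch of weight tying: the output projection reuses the embedding
# matrix instead of owning a separate one. Shapes are illustrative.
vocab_size, d_model = 8, 4
rng = np.random.default_rng(0)
emb = rng.standard_normal((vocab_size, d_model))  # shared embedding table

def embed(token_ids):
    return emb[token_ids]        # input side: look up rows of the table

def output_head(hidden):
    return hidden @ emb.T        # output side: the SAME matrix, transposed

hidden = embed([1, 2])           # pretend the decoder blocks are identity here
logits = output_head(hidden)
assert logits.shape == (2, vocab_size)
# An untied head would need a second (vocab_size, d_model) matrix;
# tying removes that parameter block entirely.
```

With `tie_weight=True`, the gradient flows into the one shared matrix from both the embedding lookup and the output projection.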
@@ -122,22 +124,31 @@ flowchart LR
 - **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers (such as `cosine`, `sgdr`)
 - Scheduler is automatically created according to configuration and bound to optimizer
 
-### 4. Inference Module
+### 4. Factory Module
 
-#### 4.1 Inference Engine (`engine.py`)
+#### 4.1 Registry and BaseFactory (`factory.py`)
+- **`Registry`**: Flexible registry for component classes with category and priority support
+- **`BaseFactory`**: Generic factory class for component registration and creation
+- Supports decorator-based registration pattern for extensible components
+- Provides methods for registration, retrieval, and listing with filtering
 
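A decorator-based registry of the kind described in this hunk might look roughly like the sketch below; names and signatures are assumed, not the actual `factory.py`:

```python
# Sketch of a decorator-based registry with category and priority support.
# Names and signatures are illustrative, not AstrAI's actual factory.py.
class Registry:
    def __init__(self):
        self._items = {}  # name -> (cls, category, priority)

    def register(self, name, category="default", priority=0):
        def decorator(cls):
            self._items[name] = (cls, category, priority)
            return cls
        return decorator

    def get(self, name):
        return self._items[name][0]

    def list(self, category=None):
        # List registered names, optionally filtered by category,
        # highest priority first.
        entries = [(n, c, p) for n, (_, c, p) in self._items.items()
                   if category is None or c == category]
        return [n for n, _, p in sorted(entries, key=lambda e: -e[2])]

registry = Registry()

@registry.register("cosine", category="scheduler", priority=1)
class CosineScheduler:
    pass

@registry.register("sgdr", category="scheduler")
class SGDRScheduler:
    pass

assert registry.get("cosine") is CosineScheduler
assert registry.list(category="scheduler") == ["cosine", "sgdr"]
```

The decorator form lets new components self-register at import time, which is what makes the factories extensible without editing a central switch statement.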
+### 5. Inference Module
+
+#### 5.1 Inference Engine (`engine.py`)
 - **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
 - **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
 - Manages task queue (`waiting_queue`, `active_tasks`) and KV cache allocation
 
-#### 4.2 Scheduler (`scheduler.py`)
+#### 5.2 Scheduler (`scheduler.py`)
 - **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
 - **`TaskStatus`**: Task state enumeration
 - **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
 - Continuous batching: new requests can join at any time, completed requests are released immediately
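The temperature / top-k / top-p chain that `apply_sampling_strategies` performs on logits can be sketched as follows — a simplified NumPy stand-in, not the actual implementation:

```python
import numpy as np

def apply_sampling_sketch(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Simplified sketch of temperature, top-k, and top-p (nucleus) filtering.

    Not AstrAI's actual apply_sampling_strategies; defaults and shapes are
    assumed for illustration. Returns a probability distribution over tokens.
    """
    logits = np.asarray(logits, dtype=np.float64) / temperature
    # Top-k: keep only the k highest-scoring tokens.
    if top_k > 0:
        kth = np.sort(logits)[-top_k]
        logits = np.where(logits < kth, -np.inf, logits)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    # Top-p: keep the smallest set of tokens whose cumulative mass reaches top_p.
    if top_p < 1.0:
        order = np.argsort(-probs)
        cum = np.cumsum(probs[order])
        cutoff = order[cum > top_p][1:]  # always keep the token that crosses p
        probs[cutoff] = 0.0
        probs /= probs.sum()
    return probs

probs = apply_sampling_sketch([2.0, 1.0, 0.1, -1.0],
                              temperature=0.8, top_k=3, top_p=0.9)
assert abs(probs.sum() - 1.0) < 1e-9
assert probs[-1] == 0.0  # lowest-scoring token filtered out by top-k
```

In the real scheduler this runs per decoding step; the resulting distribution is then sampled to pick the next token for each active task.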
 
-#### 4.3 Request (`engine.py`)
-- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, query, history, etc.)
-- **`build_prompt`** (from `chat_template.py`): Converts query and history into ChatML format prompt string
+#### 5.3 Request (`engine.py`)
+- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
+- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
+- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
 - Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
 
 ## Training Data Flow - Detailed Steps
 
@@ -176,11 +187,11 @@ flowchart LR
 ## Inference Data Flow - Detailed Steps
 
 1. **Model Loading**
-   - Load `Transformer` model and tokenizer from checkpoint
+   - Load `Transformer` model from checkpoint via `AutoModel.from_pretrained()`
    - Set model to evaluation mode (`model.eval()`), enable inference mode (`torch.inference_mode`)
 
 2. **Prompt Construction and Encoding**
-   - User query and history are converted to ChatML format string through `build_prompt` function in chat_template module
+   - User messages (list of dict with role and content) are converted to ChatML format string through `apply_chat_template` method in tokenizer
    - Tokenizer encodes prompt string to token ID sequence `input_ids`
    - For batch generation, use `pad_sequence` for padding
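The ChatML conversion in step 2 can be sketched like this — a hand-rolled stand-in for `apply_chat_template`, with the exact special tokens assumed rather than confirmed from AstrAI's tokenizer:

```python
# Sketch of ChatML prompt construction from a messages list.
# Stand-in for apply_chat_template; the special tokens shown are the
# standard ChatML ones and are assumed here, not confirmed.
def build_chatml_prompt(messages):
    parts = []
    for msg in messages:
        # Each message becomes an <|im_start|>role ... <|im_end|> block.
        parts.append(f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n")
    # Leave the assistant turn open so the model continues from here.
    parts.append("<|im_start|>assistant\n")
    return "".join(parts)

prompt = build_chatml_prompt([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
])
assert prompt.endswith("<|im_start|>assistant\n")
assert "<|im_start|>user\nHello<|im_end|>" in prompt
```

The resulting string is then what the tokenizer encodes into `input_ids` in the next step.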
 
@@ -207,5 +218,5 @@
 The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache and sampling strategies. Clear interfaces between modules facilitate customization and extension.
 
-> Document Update Time: 2026-03-30
+> Document Update Time: 2026-04-05
 > Corresponding Code Version: Refer to version number defined in `pyproject.toml`

@@ -36,10 +36,23 @@ classDiagram
         +int batch_size
         +int accumulation_steps
         +float max_grad_norm
+        +int start_epoch
+        +int start_batch
         +str ckpt_dir
         +int ckpt_interval
+        +int random_seed
+        +int num_workers
+        +int prefetch_factor
+        +bool pin_memory
         +int nprocs
         +str backend
+        +str master_addr
+        +str master_port
+        +Callable parallel_wrapper
+        +Callable state_dict_fn
+        +List[int] device_ids
+        +str device_type
+        +dict extra_kwargs
         +validate()
     }

@@ -123,6 +136,16 @@ classDiagram
     }
 
     namespace astrai.model {
+        class AutoModel {
+            +ModelConfig config
+            +Dict _registry
+            +register(model_type) decorator
+            +get_model_class(model_type) Type
+            +from_pretrained(path, disable_random_init) nn.Module
+            +save_pretrained(save_directory)
+            +to(*args, **kwargs) Self
+        }
+
         class Transformer {
             +ModelConfig config
             +RotaryEmbedding rotary_embeding

@@ -440,6 +463,8 @@ classDiagram
     DatasetFactory ..> BaseDataset : creates
     BaseSegmentFetcher --> MultiSegmentFetcher : used by
     MultiSegmentFetcher --> BaseDataset : used by
+    AutoModel <|-- Transformer
+    AutoModel --> ModelConfig : contains
     Transformer --> DecoderBlock : uses
     Transformer --> RotaryEmbedding : uses
     Transformer --> Embedding : uses

@@ -458,11 +483,12 @@ classDiagram
 |--------|------------|-------------|
 | **astrai.config** | ModelConfig, TrainConfig, ModelParameter | Configuration management |
 | **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint, DataLoader | Dataset loading and management |
-| **astrai.model** | Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
+| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
 | **astrai.tokenize** | Tokenizer, BpeTokenizer | Tokenizer |
 | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
 | **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest | Inference service with continuous batching |
 | **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
+| **astrai.factory** | Registry, BaseFactory | Generic component registration |
 
 ### Design Patterns
 
@@ -470,12 +496,13 @@ classDiagram
 |---------|---------|---------|
 | **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
 | **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
-| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
+| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory`, `BaseFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
 | **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
 | **Singleton** | `TrainContext` | Training process global state management |
 | **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
 | **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
 | **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
+| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
 
 ### Core Relationships
 
@@ -487,6 +514,7 @@ classDiagram
 6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
 7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
 8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler)
+9. **AutoModel Loading**: `AutoModel.from_pretrained()` dynamically loads model based on `config.json` model_type, uses `Registry` pattern for model type registration
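The dynamic loading in item 9 could be sketched as follows — a simplified stand-in where only the dispatch on `config.json`'s `model_type` is shown; the real `AutoModel` would also build the module and load its weights:

```python
import json
import os
import tempfile

# Simplified sketch of AutoModel-style dispatch on config.json's model_type.
# Class and method names mirror the document; the bodies are illustrative only.
class AutoModelSketch:
    _registry = {}

    @classmethod
    def register(cls, model_type):
        def decorator(model_cls):
            cls._registry[model_type] = model_cls
            return model_cls
        return decorator

    @classmethod
    def from_pretrained(cls, path):
        # Read model_type from the checkpoint's config.json, then
        # instantiate the registered class (real code also loads weights).
        with open(os.path.join(path, "config.json")) as f:
            config = json.load(f)
        return cls._registry[config["model_type"]](config)

@AutoModelSketch.register("transformer")
class TransformerSketch:
    def __init__(self, config):
        self.config = config

with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, "config.json"), "w") as f:
        json.dump({"model_type": "transformer", "n_layer": 32}, f)
    model = AutoModelSketch.from_pretrained(d)

assert isinstance(model, TransformerSketch)
```

Registering new model classes under new `model_type` strings is then enough for checkpoints of that type to load through the same entry point.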
 
 ## 3. Training Process
 
@@ -2,7 +2,21 @@
 
 ### 1. Model Architecture
 
-This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA (although KV cache is not currently implemented). The model is built by stacking 24 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
+This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA (although KV cache is not currently implemented). The model is built by stacking 32 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
+
+The model now uses the **AutoModel** base class for flexible loading and saving:
+
+```python
+from astrai.model import AutoModel
+
+# Load model from checkpoint
+model = AutoModel.from_pretrained("path/to/model")
+
+# Save model to new directory
+model.save_pretrained("path/to/save")
+```
+
+The Transformer model is registered via `@AutoModel.register('transformer')` decorator, allowing easy extension for new model types.
 
 ```mermaid
 flowchart TB
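The GQA head grouping stated in the hunk above (q_head=24, kv_head=4) means each key/value head is shared by 6 query heads; a minimal sketch of the grouping arithmetic:

```python
# Minimal sketch of GQA head grouping with q_head=24, kv_head=4 (the values
# stated in the document). Each KV head serves a contiguous group of Q heads.
q_head, kv_head = 24, 4
group_size = q_head // kv_head  # 6 query heads share one KV head

# Map each query head index to the KV head it attends with.
kv_for_q = [q // group_size for q in range(q_head)]

assert group_size == 6
assert kv_for_q[0] == 0 and kv_for_q[5] == 0  # first group of 6
assert kv_for_q[6] == 1                       # next group
assert kv_for_q[23] == 3                      # last KV head
# A KV cache would store kv_head heads instead of q_head heads,
# i.e. a 6x reduction in cached K/V tensors per layer.
```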
 
@@ -6,11 +6,12 @@
 
 | Parameter | Description | Default Value |
 |-----------|-------------|---------------|
-| `--train_type` | Training type (seq, sft, dpo) | required |
+| `--train_type` | Training type (seq, sft, dpo, grpo) | required |
+| `--model_type` | Model type for AutoModel loading (e.g., transformer) | transformer |
 | `--data_root_path` | Dataset root directory | required |
 | `--param_path` | Model parameters or checkpoint path | required |
 | `--n_epoch` | Total training epochs | 1 |
-| `--batch_size` | Batch size | 1 |
+| `--batch_size` | Batch size | 4 |
 | `--accumulation_steps` | Gradient accumulation steps | 1 |
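Declared with argparse, the flags in the table above might look roughly like this; the actual AstrAI entry point and its parser are not shown in this document, so this is only a sketch mirroring the documented names and defaults:

```python
import argparse

# Sketch of an argparse parser mirroring the documented training flags.
# The real AstrAI entry point may declare these differently.
parser = argparse.ArgumentParser(description="AstrAI training (sketch)")
parser.add_argument("--train_type", required=True,
                    choices=["seq", "sft", "dpo", "grpo"])
parser.add_argument("--model_type", default="transformer",
                    help="Model type for AutoModel loading")
parser.add_argument("--data_root_path", required=True)
parser.add_argument("--param_path", required=True)
parser.add_argument("--n_epoch", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--accumulation_steps", type=int, default=1)

args = parser.parse_args([
    "--train_type", "sft",
    "--data_root_path", "data/",
    "--param_path", "ckpt/",
])
assert args.batch_size == 4 and args.model_type == "transformer"
```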
 
 ### Learning Rate Scheduling
 
@@ -42,7 +43,9 @@
 | Parameter | Description | Default Value |
 |-----------|-------------|---------------|
 | `--random_seed` | Random seed | 3407 |
-| `--num_workers` | DataLoader workers | 4 |
+| `--num_workers` | DataLoader workers | 0 |
+| `--prefetch_factor` | Prefetch factor for dataloader | None |
+| `--pin_memory` | Enable pin_memory | False |
 | `--no_pin_memory` | Disable pin_memory | - |
 
 ### Distributed Training
 
@@ -71,44 +74,58 @@
 
 | Parameter | Description | Default Value |
 |-----------|-------------|---------------|
-| `query` | Input text or text list | required |
-| `history` | Conversation history | None |
-| `system_prompt` | System prompt | None |
-| `temperature` | Sampling temperature (higher = more random) | required |
-| `top_p` | Nucleus sampling threshold | required |
-| `top_k` | Top-k sampling count | required |
-| `max_len` | Maximum generation length | model config max_len |
+| `messages` | List of message dictionaries (role, content) | required |
+| `temperature` | Sampling temperature (higher = more random) | 1.0 |
+| `top_p` | Nucleus sampling threshold | 1.0 |
+| `top_k` | Top-k sampling count | 50 |
+| `max_len` | Maximum generation length | 1024 |
 | `stream` | Whether to stream output | False |
 
 ### Usage Example
 
 ```python
-from astrai.config import ModelParameter
+import torch
+from astrai.model import AutoModel
+from astrai.tokenize import Tokenizer
 from astrai.inference import InferenceEngine, GenerationRequest
 
-# Load model
-param = ModelParameter.load("your_model_dir")
-param.to(device="cuda", dtype=torch.bfloat16)
+# Load model using AutoModel
+model = AutoModel.from_pretrained("your_model_dir")
+
+# Load tokenizer
+tokenizer = Tokenizer("your_model_dir")
 
 # Create engine with separate model and tokenizer
 engine = InferenceEngine(
-    model=param.model,
-    tokenizer=param.tokenizer,
-    config=param.config,
+    model=model,
+    tokenizer=tokenizer,
 )
 
-# Build request
+# Build request with messages format
 request = GenerationRequest(
-    query="Hello",
-    history=[],
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Hello"},
+    ],
     temperature=0.8,
     top_p=0.95,
     top_k=50,
+    max_len=1024,
 )
 
 # Generate (streaming)
 for token in engine.generate_with_request(request):
     print(token, end="", flush=True)
+
+# Or use simple generate interface
+result = engine.generate(
+    prompt="Hello",
+    stream=False,
+    max_tokens=1024,
+    temperature=0.8,
+    top_p=0.95,
+    top_k=50,
+)
 ```
 
 ### Generation Modes
 
@@ -188,8 +188,13 @@ class InferenceEngine:
 
         if stream:
             return self._generate_streaming(
-                prompts, is_batch, max_tokens, temperature, top_p, top_k,
-                abort_on_exception
+                prompts,
+                is_batch,
+                max_tokens,
+                temperature,
+                top_p,
+                top_k,
+                abort_on_exception,
             )
         else:
             return self._generate_non_streaming(
@@ -307,16 +312,15 @@ class InferenceEngine:
 
         Use this for emergency shutdown when graceful shutdown is not possible.
         """
-        import os
-
         # Stop watching threads if any
-        if hasattr(self, 'stop_watching'):
+        if hasattr(self, "stop_watching"):
             self.stop_watching()
 
         # Unregister signal handlers
-        if hasattr(self, '_original_sigint'):
+        if hasattr(self, "_original_sigint"):
             signal.signal(signal.SIGINT, self._original_sigint)
-        if hasattr(self, '_original_sigterm'):
+        if hasattr(self, "_original_sigterm"):
             signal.signal(signal.SIGTERM, self._original_sigterm)
 
         # Force stop scheduler