Compare commits

...

10 Commits

18 changed files with 1079 additions and 416 deletions

View File

@ -84,6 +84,38 @@ python scripts/tools/train.py \
python scripts/tools/generate.py --param_path=/path/to/param_path
```
#### Start HTTP Server
Start the inference server with an OpenAI-compatible HTTP API:
```bash
python -m scripts.tools.server --port 8000 --device cuda
```
Make requests:
```bash
# Chat API (OpenAI compatible)
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 512
}'
# Streaming response
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "Tell a story"}],
"stream": true,
"max_tokens": 500
}'
# Health check
curl http://localhost:8000/health
```
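For programmatic access, any HTTP client works; the sketch below assumes the server started above is listening on `localhost:8000` and uses the `requests` package (not part of this repo) to call the chat endpoint:
```python
# Minimal client sketch; assumes the server started above is listening on localhost:8000.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 512,
    },
    timeout=60,
)
resp.raise_for_status()
# OpenAI-compatible response shape: choices[0].message.content holds the reply text.
print(resp.json()["choices"][0]["message"]["content"])
```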
#### Demo
Check out the demos in the `scripts/demo/` folder:

View File

@ -85,6 +85,38 @@ python scripts/tools/train.py \
python scripts/tools/generate.py --param_path=/path/to/param_path
```
#### Start HTTP Server
Start the inference server with an OpenAI-compatible HTTP API:
```bash
python -m scripts.tools.server --port 8000 --device cuda
```
Make requests:
```bash
# Chat API (OpenAI compatible)
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "你好"}],
"max_tokens": 512
}'
# Streaming response
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "讲个故事"}],
"stream": true,
"max_tokens": 500
}'
# Health check
curl http://localhost:8000/health
```
#### Demo
Check out the demos in the `scripts/demo/` folder:

View File

@ -12,6 +12,7 @@ AstrAI adopts a modular design with the following main components:
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
- **Parallel Module** (`astrai/parallel/`): Distributed training support
- **Serialization Module** (`astrai/serialization/`): HDF5 data loading, checkpoint management
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
@ -21,11 +22,11 @@ The data flow can generally be divided into two main lines: **Training Data Flow
flowchart LR
subgraph A[Data Preparation]
direction TB
A1[Raw Text] --> A2[BpeTokenizer]
A1[Raw Text] --> A2[AutoTokenizer]
A2 --> A3[Serialize to .h5 files]
A3 --> A4[BaseDataset]
A4 --> A5[ResumableDistributedSampler]
A5 --> A6[DataLoader]
A5 --> A6[PyTorch DataLoader]
end
subgraph B[Training]
@ -50,7 +51,7 @@ flowchart LR
C5 --> C6[InferenceScheduler]
C6 --> C7[apply_sampling_strategies]
C7 --> C8[Transformer Forward]
C8 --> C9[KV Cache]
C8 --> C9[KV Cache + Prefix Cache]
C9 --> C10{End Condition?}
C10 -->|No| C8
C10 -->|Yes| C11[Output Text]
@ -64,25 +65,18 @@ flowchart LR
### 1. Dataset Module
#### 1.1 Tokenizer (`tokenizer.py`)
- Implemented based on Byte-Level BPE (BPE)
- Supports special tokens: `<begin▁of▁sentence>`, `<end▁of▁sentence>`, `<▁pad▁>`, `<im▁start>`, `<im▁end>`
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs
- Learns vocabulary from corpus during training, saved as `.json` files
- `BpeTrainer` class handles vocabulary training from corpus
#### 1.2 Serialization (`serialization.py`)
#### 1.1 Serialization (`serialization.py`)
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
- **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
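The descriptions above imply a simple save/load round trip; the sketch below illustrates it, but the import path and exact signatures are assumptions for this sketch rather than taken from the source.
```python
# Illustrative round trip for the serialization helpers described above.
# The import path and exact signatures are assumptions.
import torch
from astrai.serialization import save_h5, load_h5  # assumed location

# Each key maps to a list of tensors (one tensor per segment).
data = {
    "sequence": [torch.randint(0, 32000, (4096,)) for _ in range(3)],
    "mask": [torch.ones(4096, dtype=torch.bool) for _ in range(3)],
}
save_h5("shard_000.h5", data)

# Returns Dict[str, List[Tensor]]; share_memory=True is described for sharing across workers.
loaded = load_h5("shard_000.h5", share_memory=True)
```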
#### 1.3 Dataset (`dataset.py`)
#### 1.2 Dataset (`dataset.py`)
- **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
- **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
- **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
- After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
#### 1.4 Sampler (`sampler.py`)
#### 1.3 Sampler (`sampler.py`)
- **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
- Records current epoch and iteration position, enabling training resume from breakpoints
- Supports shuffle and drop_last options
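The key idea behind resumable sampling is to reproduce the same per-epoch permutation and skip the indices already consumed; the single-process sketch below illustrates that idea only and is not the project's `ResumableDistributedSampler`.
```python
# Conceptual single-process sketch of resumable sampling (not the astrai implementation).
import torch
from torch.utils.data import Sampler


class MinimalResumableSampler(Sampler):
    def __init__(self, dataset_len: int, seed: int = 0):
        self.dataset_len = dataset_len
        self.seed = seed
        self.epoch = 0
        self.start_iteration = 0  # samples already consumed in the current epoch

    def set_state(self, epoch: int, iteration: int) -> None:
        # Restored from a checkpoint so training can resume mid-epoch.
        self.epoch, self.start_iteration = epoch, iteration

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)  # same permutation for the same epoch
        order = torch.randperm(self.dataset_len, generator=g).tolist()
        return iter(order[self.start_iteration:])

    def __len__(self) -> int:
        return self.dataset_len - self.start_iteration
```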
@ -99,7 +93,10 @@ flowchart LR
#### 2.2 Submodules (`module.py`)
- **`RotaryEmbedding`**: Generates RoPE cos/sin cache
- **`DecoderBlock`**: Contains multi-head attention (supports GQA), feedforward network (FFN), residual connections
- **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections
- **`GQA`**: Grouped Query Attention implementation
- **`MLA`**: Multi-head Latent Attention implementation (as in DeepSeek-V2)
- **`MLP`**: Feed-forward network with SiLU activation and gated mechanism
- **`RMSNorm`**: Layer normalization variant
- **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
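For reference, the "SiLU activation and gated mechanism" in `MLP` is the common SwiGLU-style block; the sketch below shows that generic structure, not necessarily the exact astrai layer:
```python
# Generic SiLU-gated feed-forward block: down(silu(gate(x)) * up(x)).
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedMLP(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
```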
@ -116,15 +113,29 @@ flowchart LR
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
#### 3.3 Strategy (`strategy.py`)
- **`BaseStrategy`**: Defines training strategy interface (such as `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`)
- **`BaseStrategy`**: Defines training strategy interface
- **`SEQStrategy`**: Standard next-token prediction training
- **`SFTStrategy`**: Supervised Fine-tuning with loss masking
- **`DPOStrategy`**: Direct Preference Optimization
- **`GRPOStrategy`**: Group Relative Policy Optimization
- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
- Created dynamically by `StrategyFactory` according to configuration
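The loss masking used by `SFTStrategy` boils down to computing next-token cross-entropy only on the positions the strategy cares about (typically assistant tokens); a generic sketch of that computation, independent of the `BaseStrategy` interface:
```python
# Generic next-token loss with masking; not tied to astrai's BaseStrategy interface.
import torch
import torch.nn.functional as F


def masked_next_token_loss(logits: torch.Tensor, labels: torch.Tensor,
                           loss_mask: torch.Tensor) -> torch.Tensor:
    """logits: [B, T, V], labels: [B, T], loss_mask: [B, T] bool (True = train on it)."""
    shift_logits = logits[:, :-1, :]          # position t predicts token t+1
    shift_labels = labels[:, 1:].clone()
    shift_mask = loss_mask[:, 1:]
    shift_labels[~shift_mask] = -100          # ignore_index drops masked positions
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=-100,
    )
```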
#### 3.4 Scheduler (`schedule.py`)
- **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers (such as `cosine`, `sgdr`)
- **`CosineScheduler`**: Cosine decay scheduler with warmup
- **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers
- Scheduler is automatically created according to configuration and bound to optimizer
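The warmup-plus-cosine-decay shape of `CosineScheduler` is standard; a small sketch of the learning-rate multiplier per step (parameter names here are illustrative, not the project's config keys):
```python
# Illustrative warmup + cosine decay multiplier; parameter names are not astrai config keys.
import math


def cosine_with_warmup(step: int, warmup_steps: int, decay_steps: int,
                       min_ratio: float = 0.1) -> float:
    if step < warmup_steps:
        return step / max(1, warmup_steps)            # linear warmup from 0 to 1
    progress = min(1.0, (step - warmup_steps) / max(1, decay_steps))
    return min_ratio + (1.0 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))
```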
#### 3.5 Callbacks (`train_callback.py`)
- **`TrainCallback`**: Protocol interface for trainer callbacks
- **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
- **`ProgressBarCallback`**: Displays training progress
- **`MetricLoggerCallback`**: Logs training metrics to JSON files
- **`GradientClippingCallback`**: Clips gradient norms
- **`SchedulerCallback`**: Steps learning rate scheduler
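A user-defined callback only needs to implement the hooks it cares about from the order listed in 3.2; the sketch below is illustrative (the fields available on `context` are an assumption and are not used here). Registration would go through `CallbackFactory` as described in section 4.
```python
# Illustrative custom callback; hook names follow section 3.2, context fields are assumed.
import time


class TimingCallback:
    def on_epoch_begin(self, context):
        self._epoch_start = time.time()

    def on_epoch_end(self, context):
        print(f"epoch took {time.time() - self._epoch_start:.1f}s")
```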
### 4. Factory Module
#### 4.1 Registry and BaseFactory (`factory.py`)
@ -133,9 +144,24 @@ flowchart LR
- Supports decorator-based registration pattern for extensible components
- Provides methods for registration, retrieval, and listing with filtering
### 5. Inference Module
### 5. Parallel Module
#### 5.1 Inference Engine (`engine.py`)
#### 5.1 Setup (`setup.py`)
- **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
- **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend)
- **`only_on_rank`**: Decorator to execute functions only on specific ranks
- **`get_rank`**: Returns current process rank in distributed group
- **`get_world_size`**: Returns total number of processes in distributed group
- **`get_current_device`**: Returns current device from environment
#### 5.2 Parallel Layers (`module.py`)
- **`ParallelModel`**: Base class for parallel models with process group
- **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
- **`RowParallelLinear`**: Row-parallel linear layer with output reduction
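The two layer types split the weight matrix along different axes; the single-process illustration below reproduces the math without a real process group (the concatenation stands in for an all-gather, the sum for an all-reduce):
```python
# Single-process illustration of column- vs row-parallel matmuls (no process group involved).
import torch

torch.manual_seed(0)
x = torch.randn(2, 8)              # [batch, in_features]
w = torch.randn(8, 6)              # full weight: in_features x out_features

# Column parallel: each "rank" owns a slice of the output features; slices are gathered.
y_col = torch.cat([x @ part for part in w.chunk(2, dim=1)], dim=1)

# Row parallel: each "rank" owns a slice of the input features; partial outputs are reduced.
y_row = sum(xp @ wp for xp, wp in zip(x.chunk(2, dim=1), w.chunk(2, dim=0)))

assert torch.allclose(y_col, x @ w, atol=1e-5)
assert torch.allclose(y_row, x @ w, atol=1e-5)
```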
### 6. Inference Module
#### 6.1 Inference Engine (`engine.py`)
- **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
@ -145,22 +171,38 @@ flowchart LR
- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
- Uses separate model and tokenizer initialization for flexibility
#### 5.2 Scheduler (`scheduler.py`)
#### 6.2 Scheduler (`scheduler.py`)
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
- **`TaskStatus`**: Task state enumeration
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse
- **`RadixNode`**: Tree node structure for prefix caching
- Continuous batching: new requests can join at any time, completed requests are released immediately
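Temperature, top-k, and top-p all operate on the logits before the multinomial draw; the sketch below shows the usual combination (details of the project's `apply_sampling_strategies` may differ):
```python
# Generic temperature / top-k / top-p sampling over a logits vector.
import torch


def sample_token(logits: torch.Tensor, temperature: float = 1.0,
                 top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
    if temperature <= 0:
        return logits.argmax(dim=-1)                      # greedy decoding
    logits = logits / temperature
    if top_k > 0:
        kth = torch.topk(logits, top_k).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        probs = torch.softmax(sorted_logits, dim=-1)
        # Keep the smallest prefix whose cumulative probability exceeds top_p.
        remove = probs.cumsum(dim=-1) - probs > top_p
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
    return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1).squeeze(-1)
```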
#### 5.3 Request (`engine.py`)
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
#### 6.3 Server (`server.py`)
- FastAPI-based HTTP inference server
- OpenAI-compatible `/v1/chat/completions` endpoint
- Health check and statistics endpoints
- Supports both streaming and non-streaming responses
### 7. Tokenizer Module
#### 7.1 Tokenizer (`tokenizer.py`)
- Implemented based on HuggingFace tokenizers library (Byte-Level BPE)
- **`AutoTokenizer`**: Auto-loading tokenizer class
- Supports special tokens: `<begin▁of▁sentence>`, `<end▁of▁sentence>`, `<▁pad▁>`, `<im▁start>`, `<im▁end>`
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs
- Uses `AutoTokenizer` for loading pre-trained tokenizers
#### 7.2 Chat Template (`chat_template.py`)
- **`ChatTemplate`**: Jinja2-based chat template with rendering support
- Handles multi-role message formatting (system, user, assistant)
- Supports dynamic prompts and generation prompts
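Since `apply_chat_template` renders messages into the ChatML-style format built from the special tokens in 7.1, a Jinja2 template for that format looks roughly like the sketch below (illustrative; not necessarily the project's bundled template):
```python
# Illustrative ChatML-style Jinja2 template using the special tokens from section 7.1.
from jinja2 import Template

CHATML = Template(
    "{% for m in messages %}"
    "<im▁start>{{ m.role }}\n{{ m.content }}<im▁end>\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<im▁start>assistant\n{% endif %}"
)

print(CHATML.render(
    messages=[{"role": "user", "content": "Hello"}],
    add_generation_prompt=True,
))
```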
## Training Data Flow - Detailed Steps
1. **Data Preparation**
- Raw text is converted to token ID sequences through BPE tokenizer
- Raw text is converted to token ID sequences through AutoTokenizer
- Token ID sequences (possibly with masks, labels, etc.) are saved by groups as `.h5` files
- Files can contain multiple segments, each segment corresponds to a tensor
@ -171,7 +213,7 @@ flowchart LR
3. **Sampling and Batch Loading**
- `ResumableDistributedSampler` generates index sequence based on current epoch and iteration position
- `DataLoader` uses sampler to get indices, calls dataset's `__getitem__` to get actual data
- PyTorch `DataLoader` uses sampler to get indices, calls dataset's `__getitem__` to get actual data
- Batch data shape is `[batch_size, window_size]` (or varies according to specific dataset type)
4. **Strategy Forward and Loss Calculation**
@ -202,7 +244,7 @@ flowchart LR
- For batch generation, use `pad_sequence` for padding
3. **Autoregressive Generation Loop**
- Initialize KV cache (optional)
- Initialize KV cache (optional) and prefix cache
- Loop until generating `max_len` tokens or encountering stop token:
- Input current `input_ids` (or cached new token) to model, obtain `logits`
- Apply `apply_sampling_strategies` (temperature, top-k, top-p) to `logits`
@ -222,7 +264,6 @@ flowchart LR
## Summary
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache and sampling strategies. Clear interfaces between modules facilitate customization and extension.
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache, prefix caching, and sampling strategies. Clear interfaces between modules facilitate customization and extension.
> Document Update Time: 2026-04-05
> Corresponding Code Version: Refer to version number defined in `pyproject.toml`
> Document Update Time: 2026-04-09

View File

@ -8,7 +8,7 @@ Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, su
```mermaid
classDiagram
namespace astrai.config {
namespace config {
class ModelConfig {
+int vocab_size
+int dim
@ -56,17 +56,9 @@ classDiagram
+validate()
}
class ModelParameter {
+nn.Module model
+BpeTokenizer tokenizer
+ModelConfig config
+save(instance, save_dir)
+load(load_dir, disable_init) ModelParameter
+to(*args, **kwargs)
}
}
namespace astrai.dataset {
namespace dataset {
class BaseDataset {
+int window_size
+int stride
@ -100,8 +92,8 @@ classDiagram
}
class MultiSegmentFetcher {
+Dict muti_fetchers
+List muti_keys
+Dict multi_fetchers
+List multi_keys
+key_fetch(begin_idx, end_idx, keys) Dict
+fetch_data(begin_idx, end_idx) Dict
}
@ -125,17 +117,9 @@ classDiagram
+save(save_dir)
+load(save_dir) Checkpoint
}
class DataLoader {
+Dataset dataset
+int batch_size
+Sampler sampler
+__iter__()
+__len__()
}
}
namespace astrai.model {
namespace model {
class AutoModel {
+ModelConfig config
+Dict _registry
@ -148,7 +132,7 @@ classDiagram
class Transformer {
+ModelConfig config
+RotaryEmbedding rotary_embeding
+RotaryEmbedding rotary_embedding
+Embedding embed_tokens
+ModuleList layers
+RMSNorm norm
@ -216,24 +200,46 @@ classDiagram
}
}
namespace astrai.tokenize {
class Tokenizer {
+encode(tokens, out_ids, add_special_tokens) List~int~
+decode(tokens, skip_special_tokens) str
+__len__() int
}
class BpeTokenizer {
namespace tokenize {
class AutoTokenizer {
+List~str~ stop_ids
+int bos_id
+int eos_id
+int pad_id
+int vocab_size
+encode(tokens, out_ids, add_special_tokens) List~int~
+decode(tokens, skip_special_tokens) str
+apply_chat_template(messages, tokenize) Union~str, List[int]~
+set_chat_template(template)
+load(path)
+from_pretrained(path) AutoTokenizer
+save_pretrained(save_path)
}
class ChatTemplate {
+String template_str
+render(messages, add_generation_prompt) str
+from_string(template) ChatTemplate
}
}
namespace astrai.trainer {
namespace factory {
class Registry {
+Dict _entries
+register(name, component_cls, category, priority)
+get(name) Type
+list_names() List~str~
}
class BaseFactory {
+Registry _registry
+register(name, category, priority) decorator
+create(name, *args, **kwargs) T
+list_registered() list
}
}
namespace trainer {
class Trainer {
+TrainConfig train_config
+List~TrainCallback~ callbacks
@ -337,6 +343,39 @@ classDiagram
+on_error(context)
}
class GradientClippingCallback {
+float max_grad_norm
+on_step_begin(context)
}
class SchedulerCallback {
+on_train_begin(context)
+on_batch_end(context)
}
class CheckpointCallback {
+str save_dir
+int interval
+_save_checkpoint(context)
+on_batch_end(context)
+on_train_end(context)
+on_error(context)
}
class ProgressBarCallback {
+int num_epoch
+on_epoch_begin(context)
+on_batch_end(context)
+on_epoch_end(context)
}
class MetricLoggerCallback {
+str log_dir
+int save_interval
+on_batch_end(context)
+on_train_end(context)
}
class CallbackFactory {
+Registry _registry
+register(name) decorator
@ -344,10 +383,17 @@ classDiagram
}
}
namespace astrai.inference {
namespace inference {
class InferenceEngine {
+ModelParameter parameter
+nn.Module model
+AutoTokenizer tokenizer
+InferenceScheduler scheduler
+int max_batch_size
+Optional int max_seq_len
+int max_prefix_len
+int cache_capacity
+Tensor kv_cache
+Tensor seq_mask
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]]
+get_stats() Dict
@ -356,10 +402,11 @@ classDiagram
class InferenceScheduler {
+nn.Module model
+Tokenizer tokenizer
+AutoTokenizer tokenizer
+ModelConfig config
+Tuple kv_cache
+Tensor seq_mask
+PrefixCacheManager prefix_cache
+List waiting_queue
+List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
@ -369,6 +416,24 @@ classDiagram
+get_stats() Dict
}
class PrefixCacheManager {
+RadixNode root
+int max_capacity
+List lru
+insert(token_ids, slot)
+find_longest_prefix(token_ids) Tuple[int, int]
+release(token_ids)
}
class RadixNode {
+Dict children
+int hash
+int slot
+int ref_count
+float last_access
+List token_sequence
}
class Task {
+str task_id
+List prompt_ids
@ -392,14 +457,6 @@ classDiagram
+str ABORTED
}
class apply_sampling_strategies {
+Tensor logits
+float temperature
+int top_k
+float top_p
+forward() Tensor
}
class Server {
+start()
+predict(request)
@ -410,19 +467,54 @@ classDiagram
+float top_p
+float temperature
+int max_len
+Union~str, List~str~~ query
+history Optional
+system_prompt Optional~str~
+List~Dict~ messages
+bool stream
}
class _Result {
+List~str~ tokens
+List~str~ results
+List~bool~ done_flags
+append(token, idx)
+get_results() List~str~
}
class ChatMessage {
+str role
+str content
}
class ChatCompletionRequest {
+List~ChatMessage~ messages
+float temperature
+float top_p
+int top_k
+int max_tokens
+bool stream
+Optional~str~ system_prompt
}
class CompletionResponse {
+str id
+str object
+int created
+str model
+List~Dict~ choices
}
}
namespace astrai.parallel {
namespace parallel {
class ParallelSetup {
+spawn_parallel_fn(fn, nprocs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type, device_ids)
}
class ParallelModel {
+dist.ProcessGroup process_group
+int rank
+int world_size
}
class ColumnParallelLinear {
+forward(x) Tensor
}
@ -433,9 +525,16 @@ classDiagram
}
%% Relationships
TrainConfig --> ModelConfig : contains
TrainConfig --> ModelConfig : uses
TrainConfig --> BaseDataset : uses
TrainConfig --> Transformer : uses
TrainConfig --> StrategyFactory : selects
StrategyFactory ..> BaseStrategy : creates
BaseStrategy <|-- SEQStrategy
BaseStrategy <|-- SFTStrategy
BaseStrategy <|-- DPOStrategy
BaseStrategy <|-- GRPOStrategy
DPOStrategy --> Transformer : uses
GRPOStrategy --> Transformer : uses
Trainer --> TrainConfig : configures
Trainer --> TrainContextBuilder : builds
Trainer --> TrainCallback : manages
@ -443,30 +542,27 @@ classDiagram
TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses
StrategyFactory ..> BaseStrategy : creates
BaseStrategy <|-- SEQStrategy
BaseStrategy <|-- SFTStrategy
BaseStrategy <|-- DPOStrategy
BaseStrategy <|-- GRPOStrategy
DPOStrategy --> Transformer : creates ref_model
GRPOStrategy --> Transformer : creates ref_model
AutoModel --> ModelConfig : contains
SchedulerFactory ..> BaseScheduler : creates
BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler
CallbackFactory ..> TrainCallback : creates
TrainCallback <|-- GradientClippingCallback
TrainCallback <|-- SchedulerCallback
TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
InferenceEngine --> InferenceScheduler : uses
InferenceScheduler --> Task : manages
InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> apply_sampling_strategies : uses
InferenceScheduler --> Transformer : uses
InferenceEngine --> Transformer : uses
InferenceEngine --> GenerationRequest : uses
Server --> InferenceEngine : uses
Server --> ChatMessage : uses
Server --> ChatCompletionRequest : uses
Server --> CompletionResponse : uses
ParallelSetup --> Trainer : enables
TrainConfig --> StrategyFactory : selects
ModelParameter --> Transformer : contains
ModelParameter --> BpeTokenizer : contains
ModelParameter --> ModelConfig : contains
BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset
@ -483,22 +579,34 @@ classDiagram
DecoderBlock --> MLA : uses
DecoderBlock --> MLP : uses
DecoderBlock --> RMSNorm : uses
BpeTokenizer --> Tokenizer : inherits
TrainContextBuilder --> ResumableDistributedSampler : creates
DataLoader --> BaseDataset : uses
ResumableDistributedSampler --> BaseDataset : samples
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses
InferenceScheduler --> PrefixCacheManager : uses
InferenceScheduler --> RadixNode : uses
Checkpoint ..> Checkpoint : saves/loads
TrainConfig --> DatasetFactory : selects
TrainConfig --> SchedulerFactory : selects
TrainConfig --> CallbackFactory : selects
AutoModel ..> AutoTokenizer : loads with
BaseFactory <|-- DatasetFactory
BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory
```
### Module Overview
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig, ModelParameter | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint, DataLoader | Dataset loading and management |
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint | Dataset loading and management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, BpeTokenizer, ChatTemplate, BpeTrainer | Tokenizer |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest | Inference service with continuous batching |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest, PrefixCacheManager, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
@ -515,7 +623,7 @@ classDiagram
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
| **Generator Pattern** | `_StreamingResult`, `_NonStreamingResult` | Event-based result notification for streaming/non-streaming generation |
| **Generator Pattern** | `_Result`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
### Core Relationships
@ -582,3 +690,5 @@ $$
The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$
Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence.
> Document Update Time: 2026-04-09

View File

@ -2,7 +2,7 @@
### 1. Model Architecture
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA (although KV cache is not currently implemented). The model is built by stacking 32 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking 32 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
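For a rough sense of the saving: the KV cache stores keys and values per KV head, so going from 24 heads (as plain MHA would use) to 4 KV heads at the same head dimension shrinks the cache to 4/24 ≈ 1/6 of the MHA size.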
The model now uses the **AutoModel** base class for flexible loading and saving:
@ -190,4 +190,110 @@ for token in engine.generate_with_request(request):
print(token, end="", flush=True)
```
The continuous batching feature allows dynamic batch composition where new requests can join at any time and completed requests are released immediately.
The continuous batching feature allows dynamic batch composition where new requests can join at any time and completed requests are released immediately.
## HTTP API Usage
The inference server provides HTTP endpoints for remote inference. Start the server first:
```bash
python -m scripts.tools.server --port 8000
```
### OpenAI-Compatible Endpoint
The server provides an OpenAI-compatible chat completion endpoint at `/v1/chat/completions`:
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"}
],
"temperature": 0.8,
"max_tokens": 2048,
"stream": false
}'
```
**Request Parameters:**
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `messages` | List[dict] | Required | Chat messages with role and content |
| `temperature` | float | 0.8 | Sampling temperature (0.0-2.0) |
| `top_p` | float | 0.95 | Nucleus sampling threshold |
| `top_k` | int | 50 | Top-k sampling parameter |
| `max_tokens` | int | 2048 | Maximum tokens to generate |
| `stream` | bool | false | Enable streaming response |
| `system_prompt` | str | None | System prompt override |
**Response (non-streaming):**
```json
{
"id": "chatcmpl-1234567890",
"object": "chat.completion",
"created": 1234567890,
"model": "astrai",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "Hello! I'm doing well..."},
"finish_reason": "stop"
}
]
}
```
### Streaming Response
Enable streaming for real-time token-by-token output:
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "Write a story"}],
"stream": true,
"max_tokens": 500
}'
```
The server uses Server-Sent Events (SSE) with content type `text/event-stream`.
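From Python, the SSE stream can be consumed line by line; the sketch below uses `requests` and assumes the chunks follow the OpenAI `chat.completion.chunk` shape, i.e. that the token text sits in `choices[0].delta.content` (an assumption here):
```python
# Minimal SSE streaming client sketch; the delta field layout is an assumption.
import json
import requests

with requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={"messages": [{"role": "user", "content": "Write a story"}],
          "stream": True, "max_tokens": 500},
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        # Assumed OpenAI-style chunk: choices[0].delta.content carries the new text.
        print(chunk["choices"][0].get("delta", {}).get("content", ""), end="", flush=True)
```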
### Simple Generation Endpoint
For basic text generation without chat format:
```bash
curl -X POST "http://localhost:8000/generate?query=Hello&max_len=1000" \
-H "Content-Type: application/json"
```
Or with conversation history:
```bash
curl -X POST "http://localhost:8000/generate" \
-H "Content-Type: application/json" \
-d '{
"query": "What is AI?",
"history": [["Hello", "Hi there!"], ["How are you?", "I'm doing well"]],
"temperature": 0.8,
"max_len": 2048
}'
```
### Health Check
Monitor server and model status:
```bash
curl http://localhost:8000/health
# {"status": "ok", "model_loaded": true, "engine_ready": true}
curl http://localhost:8000/stats
# {"requests_total": 10, "tokens_generated": 5000, ...}
```
> Document Update Time: 2026-04-09

View File

@ -136,4 +136,6 @@ result = engine.generate(
| Mode | Description |
|------|-------------|
| `stream=True` | Streaming output, yields token by token |
| `stream=False` | Non-streaming output, returns complete result |
| `stream=False` | Non-streaming output, returns complete result |
> Document Update Time: 2026-04-09

View File

@ -12,18 +12,19 @@ from astrai.inference import (
InferenceEngine,
)
from astrai.model import AutoModel, Transformer
from astrai.tokenize import BpeTokenizer
from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
from astrai.tokenize import AutoTokenizer
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
__all__ = [
"Transformer",
"ModelConfig",
"TrainConfig",
"DatasetFactory",
"BpeTokenizer",
"AutoTokenizer",
"GenerationRequest",
"InferenceEngine",
"Trainer",
"CallbackFactory",
"StrategyFactory",
"SchedulerFactory",
"BaseFactory",

View File

@ -72,15 +72,16 @@ class MultiSegmentFetcher:
Each key corresponds to a different type of data (e.g., "sequence", "mask").
"""
def __init__(self, muti_segments: Dict):
self.muti_keys = list(muti_segments.keys())
self.muti_fetchers = {
key: BaseSegmentFetcher(segments) for key, segments in muti_segments.items()
def __init__(self, multi_segments: Dict):
self.multi_keys = list(multi_segments.keys())
self.multi_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in multi_segments.items()
}
def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
len_list = [len(seg) for seg in self.muti_fetchers.values()]
len_list = [len(seg) for seg in self.multi_fetchers.values()]
return min(len_list)
def key_fetch(
@ -100,7 +101,7 @@ class MultiSegmentFetcher:
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.muti_fetchers[key]
fetcher = self.multi_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
@ -108,7 +109,7 @@ class MultiSegmentFetcher:
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
class BaseDataset(Dataset, ABC):

View File

@ -45,17 +45,31 @@ class GenerationRequest:
raise ValueError("temperature must be a non-negative number")
class _StreamingResult:
"""Streaming result holder with event-based notification."""
class _Result:
"""Unified result holder for streaming/non-streaming modes."""
def __init__(self):
self.tokens: List[str] = []
self._event = threading.Event()
def __init__(self, count: int = 1, stream: bool = False):
self._stream = stream
self._lock = threading.Lock()
self._event = threading.Event()
self.tokens: List[str] = []
self.results: List[str] = [""] * count if count > 1 else [""]
self.done_flags: List[bool] = [False] * count
self._completed_count = 0
def append(self, token: str):
def append(self, token: str, idx: int = 0):
with self._lock:
self.tokens.append(token)
if self._stream:
self.tokens.append(token)
else:
if token == "[DONE]":
if not self.done_flags[idx]:
self.done_flags[idx] = True
self._completed_count += 1
if self._completed_count == len(self.results):
self._event.set()
else:
self.results[idx] += token
self._event.set()
def pop_all(self) -> List[str]:
@ -69,35 +83,6 @@ class _StreamingResult:
def wait(self, timeout: float = None) -> bool:
return self._event.wait(timeout=timeout)
class _NonStreamingResult:
"""Non-streaming result holder with event-based completion notification."""
def __init__(self, count: int):
self.results: List[str] = [""] * count
self.done_flags: List[bool] = [False] * count
self._completed_count = 0
self._event = threading.Event()
self._lock = threading.Lock()
def append(self, idx: int, token: str):
with self._lock:
if token == "[DONE]":
if not self.done_flags[idx]:
self.done_flags[idx] = True
self._completed_count += 1
if self._completed_count == len(self.results):
self._event.set()
else:
self.results[idx] += token
def is_all_done(self) -> bool:
with self._lock:
return all(self.done_flags)
def wait(self, timeout: float = None) -> bool:
return self._event.wait(timeout=timeout)
def get_results(self) -> List[str]:
with self._lock:
return self.results.copy()
@ -112,6 +97,8 @@ class InferenceEngine:
tokenizer: AutoTokenizer,
max_batch_size: int = 1,
max_seq_len: Optional[int] = None,
max_prefix_len: int = 512,
cache_capacity: int = 1000,
):
"""
Initialize inference engine with separate model and tokenizer.
@ -122,6 +109,8 @@ class InferenceEngine:
config: Model configuration
max_batch_size: Maximum batch size for continuous batching
max_seq_len: Maximum sequence length (defaults to config.max_len)
max_prefix_len: Maximum prefix length for cache (default: 512)
cache_capacity: Maximum number of cached prefixes (default: 1000)
"""
self.model = model
self.tokenizer = tokenizer
@ -141,6 +130,8 @@ class InferenceEngine:
tokenizer=self.tokenizer,
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
max_prefix_len=max_prefix_len,
cache_capacity=cache_capacity,
device=device,
dtype=dtype,
)
@ -227,7 +218,7 @@ class InferenceEngine:
if is_batch:
raise NotImplementedError("Batch streaming is not implemented yet")
result = _StreamingResult()
result = _Result(stream=True)
task_id = self.scheduler.add_task(
prompt=prompts[0],
@ -266,7 +257,7 @@ class InferenceEngine:
top_k: int,
) -> Union[str, List[str]]:
"""Generate without streaming."""
result = _NonStreamingResult(len(prompts))
result = _Result(count=len(prompts))
for i, p in enumerate(prompts):
# Create closure to capture current index value using factory function

View File

@ -3,7 +3,7 @@
import threading
import time
import uuid
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
@ -12,6 +12,135 @@ from astrai.model.automodel import AutoModel
from astrai.tokenize import AutoTokenizer
class RadixNode:
"""Radix tree node for prefix cache."""
def __init__(self):
self.children: Dict[int, "RadixNode"] = {} # token_id -> child node
self.hash: Optional[int] = None # 64-bit hash of the prefix
self.slot: int = -1 # KV Cache slot, valid only for leaf nodes
self.ref_count: int = 0 # number of tasks referencing this prefix
self.last_access: float = 0.0 # timestamp for LRU
self.token_sequence: list = [] # full token sequence from root to this node
class PrefixCacheManager:
"""Prefix cache manager using Radix tree with LRU eviction."""
def __init__(self, max_capacity: int = 1000, base: int = 131, mod: int = 10**9 + 7):
self.root = RadixNode()
self.base = base
self.mod = mod
self.max_capacity = max_capacity
self.lru: List[Tuple[float, RadixNode]] = [] # (timestamp, node) for LRU
def insert(self, token_ids: Tuple[int, ...], slot: int) -> None:
"""Insert a prefix, increase ref_count if already exists, otherwise create new node."""
node = self.root
path = []
h = 0
for i, token_id in enumerate(token_ids):
if token_id not in node.children:
node.children[token_id] = RadixNode()
node = node.children[token_id]
h = (h * self.base + token_id) % self.mod
node.hash = h
path.append(token_id)
node.token_sequence = list(
path
) # store full sequence for exact verification
# Leaf node: set slot and increase ref_count
if node.slot == -1:
node.slot = slot
node.ref_count += 1
node.last_access = time.time()
self._update_lru(node)
self._evict_if_needed()
def find_longest_prefix(self, token_ids: List[int]) -> Optional[Tuple[int, int]]:
"""Find longest matching prefix, return (prefix_len, slot).
During traversal, compute hash per token and compare with node hash.
If hash matches, perform full token sequence verification to avoid
hash collision errors.
"""
node = self.root
best_len = 0
best_slot = -1
h = 0
for i, token_id in enumerate(token_ids):
if token_id not in node.children:
break
node = node.children[token_id]
h = (h * self.base + token_id) % self.mod
if node.hash == h: # hash matches
# Exact verification: compare full token sequence
if node.token_sequence == token_ids[: i + 1]:
best_len = i + 1
best_slot = node.slot
node.last_access = time.time()
self._update_lru(node)
if best_len > 0:
return (best_len, best_slot)
return None
def release(self, token_ids: Tuple[int, ...]) -> None:
"""Release reference to a prefix, decrease ref_count. If zero, mark as evictable."""
node = self.root
for token_id in token_ids:
if token_id not in node.children:
return
node = node.children[token_id]
if node.ref_count > 0:
node.ref_count -= 1
if node.ref_count == 0:
node.slot = -1 # slot can be reused
def _update_lru(self, node: RadixNode) -> None:
"""Update LRU list, move node to most recently used position."""
self.lru = [(ts, n) for (ts, n) in self.lru if n is not node]
self.lru.append((node.last_access, node))
def _evict_if_needed(self) -> None:
"""If cache entries exceed capacity, evict least recently used leaf nodes (ref_count must be 0)."""
if len(self.lru) <= self.max_capacity:
return
# Sort by timestamp
self.lru.sort(key=lambda x: x[0])
for ts, node in list(self.lru):  # iterate over a snapshot; self.lru is mutated below
if node.ref_count == 0:
# Remove leaf node from tree (need to recursively delete empty branches)
self._remove_node(node)
self.lru.remove((ts, node))
if len(self.lru) <= self.max_capacity:
break
def _remove_node(
self,
node: RadixNode,
parent: Optional[RadixNode] = None,
child_key: Optional[int] = None,
) -> None:
"""Remove node from tree, including empty parent nodes."""
# First, recursively remove all children (use distinct loop names so this node's
# own child_key parameter is not clobbered by the loop)
for key, child_node in list(node.children.items()):
self._remove_node(child_node, node, key)
# Clear the node's leaf properties
node.slot = -1
node.hash = None
node.token_sequence = []
node.children.clear()
# If this node has no children and has a parent, remove the reference from parent
if parent is not None and child_key is not None and len(node.children) == 0:
if child_key in parent.children:
del parent.children[child_key]
class TaskStatus:
"""Task state for continuous batching."""
@ -46,6 +175,7 @@ class Task:
self.input_tokens: int = 0
self.output_tokens: int = 0
self.slot: int = -1
self.prefix_len: int = 0 # prefix cache matched length
self.arrival_time = time.time()
self.finish_time: Optional[float] = None
@ -53,9 +183,10 @@ class Task:
def is_finished(self, stop_ids: List[int]) -> bool:
"""Check if task is finished."""
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
return self.output_tokens >= self.max_tokens
return (
bool(self.output_ids and self.output_ids[-1] in stop_ids)
or self.output_tokens >= self.max_tokens
)
def apply_sampling_strategies(
@ -104,6 +235,8 @@ class InferenceScheduler:
tokenizer: AutoTokenizer,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
max_prefix_len: int = 512,
cache_capacity: int = 1000,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
@ -113,9 +246,13 @@ class InferenceScheduler:
self.tokenizer = tokenizer
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len
self.max_prefix_len = max_prefix_len
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
# Initialize prefix cache
self.prefix_cache = PrefixCacheManager(max_capacity=cache_capacity)
num_kv_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads
n_layers = config.n_layers
@ -170,6 +307,10 @@ class InferenceScheduler:
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
# Truncate if exceeds max_prefix_len
if len(prompt_ids) > self.max_prefix_len:
prompt_ids = prompt_ids[: self.max_prefix_len]
task = Task(
task_id=task_id,
prompt_ids=prompt_ids,
@ -180,6 +321,16 @@ class InferenceScheduler:
stream_callback=stream_callback,
)
# Find longest matching prefix from cache
match = self.prefix_cache.find_longest_prefix(prompt_ids)
if match:
prefix_len, slot = match
task.prefix_len = prefix_len
task.slot = slot
else:
task.prefix_len = 0
task.slot = -1
with self._lock:
self.waiting_queue.append(task)
self._total_tasks += 1
@ -207,6 +358,11 @@ class InferenceScheduler:
slot = task.slot
if slot >= 0 and slot < len(self.active_tasks):
self.seq_mask[slot, :] = False
# Release prefix cache reference
if task.prefix_len > 0:
self.prefix_cache.release(tuple(task.prompt_ids[: task.prefix_len]))
task.slot = -1
self.active_tasks = [
@ -220,22 +376,51 @@ class InferenceScheduler:
return
with self._lock:
to_add = []
for _ in range(min(available_slots, len(self.waiting_queue))):
if self.waiting_queue:
task = self.waiting_queue.pop(0)
task.status = TaskStatus.RUNNING
to_add.append(task)
to_add = [
self.waiting_queue.pop(0)
for _ in range(min(available_slots, len(self.waiting_queue)))
]
for task in to_add:
for i in range(self.max_batch_size):
if all(t.slot != i for t in self.active_tasks):
task.slot = i
break
task.slot = self._allocate_slot()
task.status = TaskStatus.RUNNING
self.active_tasks.append(task)
def _allocate_slot(self) -> int:
"""Allocate an available slot for a task."""
for i in range(self.max_batch_size):
if not any(t.slot == i for t in self.active_tasks):
return i
return -1
def _execute_prefill(self, tasks: List[Task]) -> None:
"""Execute Prefill phase."""
"""Execute Prefill phase with incremental prefill support."""
if not tasks:
return
# Group tasks by prefix cache status
fully_cached, partial, full = [], [], []
for task in tasks:
total_len, prefix_len = len(task.prompt_ids), task.prefix_len
if prefix_len == total_len:
fully_cached.append(task)
elif prefix_len > 0:
partial.append(task)
else:
full.append(task)
# Handle fully cached tasks
for t in fully_cached:
t.input_tokens, t.output_tokens = len(t.prompt_ids), 0
if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True
if full:
self._execute_full_prefill(full)
if partial:
self._execute_partial_prefill(partial)
def _execute_full_prefill(self, tasks: List[Task]) -> None:
"""Execute full prefill for tasks without prefix cache."""
if not tasks:
return
@ -271,11 +456,59 @@ class InferenceScheduler:
for i, task in enumerate(tasks):
task.input_tokens = prompt_lens[i]
task.output_tokens = 0
# Insert new prefix into cache
self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
for task in tasks:
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
def _execute_partial_prefill(self, tasks: List[Task]) -> None:
"""Execute incremental prefill for tasks with partial prefix cache match."""
for task in tasks:
total_len = len(task.prompt_ids)
prefix_len = task.prefix_len
if prefix_len >= total_len:
task.input_tokens = total_len
task.output_tokens = 0
continue
# Get new tokens that need prefill
new_ids = task.prompt_ids[prefix_len:]
new_len = len(new_ids)
if new_len == 0:
task.input_tokens = total_len
task.output_tokens = 0
continue
# Build input for incremental prefill
input_ids = torch.tensor([new_ids], dtype=torch.long, device=self.device)
# Input mask should cover from position 0 to prefix_len + new_len
# The prefix part uses cached KV, new part needs computation
input_mask = torch.ones(
(1, prefix_len + new_len), dtype=torch.bool, device=self.device
)
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=prefix_len,
persistent_key_values=self.kv_cache,
)
task.input_tokens = total_len
task.output_tokens = 0
# Insert full prefix into cache (ref_count already increased in add_task)
self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
"""Execute Decode phase."""
if not tasks:

View File

@ -9,7 +9,7 @@ import json
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional
import torch
import uvicorn
@ -134,78 +134,6 @@ class CompletionResponse(BaseModel):
choices: List[Dict[str, Any]]
class StreamCompletionResponse(BaseModel):
id: str = "chatcmpl-default"
object: str = "chat.completion.chunk"
created: int = 0
model: str = "astrai"
choices: List[Dict[str, Any]]
def convert_messages_to_history(
messages: List[ChatMessage],
) -> tuple[Optional[str], Optional[List[Tuple[str, str]]]]:
"""Convert OpenAI-style messages to system_prompt and history."""
system_prompt = None
history: List[Tuple[str, str]] = []
user_buffer = []
assistant_buffer = []
for msg in messages:
if msg.role == "system":
system_prompt = msg.content
elif msg.role == "user":
if assistant_buffer:
# Flush previous pair
history.append(("".join(user_buffer), "".join(assistant_buffer)))
user_buffer = []
assistant_buffer = []
user_buffer.append(msg.content)
elif msg.role == "assistant":
assistant_buffer.append(msg.content)
else:
logger.warning(f"Unknown role {msg.role}")
return system_prompt, history if history else None
def convert_messages_to_prompt(
messages: List[ChatMessage], engine: InferenceEngine = None
) -> str:
"""Convert messages to prompt string.
Args:
messages: List of ChatMessage objects
engine: InferenceEngine instance for accessing tokenizer
Returns:
str: Formatted prompt string
"""
# Convert to dict format for chat template
msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
# Extract system prompt if present
system_prompt = None
filtered_messages = []
for msg in msg_dicts:
if msg["role"] == "system":
system_prompt = msg["content"]
else:
filtered_messages.append(msg)
# Use engine's tokenizer chat template if available
if engine is not None and engine.tokenizer is not None:
return engine.tokenizer.apply_chat_template(
filtered_messages, system_prompt=system_prompt, tokenize=False
)
# Fallback: simple concatenation (deprecated)
prompt_parts = []
for msg in filtered_messages:
prompt_parts.append(
f"<im▁start>{msg['role']}\n{msg['content']}<im▁end>"
)
return "\n".join(prompt_parts) + "\n<im▁start>assistant\n"
@app.get("/health")
async def health():
return {
@ -233,7 +161,12 @@ async def chat_completion(request: ChatCompletionRequest):
raise HTTPException(status_code=503, detail="Engine not initialized")
# Convert messages to prompt using engine's tokenizer
prompt = convert_messages_to_prompt(request.messages, engine=_engine)
# Extract system prompt if present, then apply chat template
# Apply chat template directly with messages
prompt = _engine.tokenizer.apply_chat_template(
[{"role": m.role, "content": m.content} for m in request.messages],
tokenize=False,
)
if request.stream:
# Streaming response (use synchronous generator)

View File

@ -30,6 +30,7 @@ def get_rotary_emb(
dim: int,
max_len: int,
base: float = 10000,
device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]:
"""
Get the rotary embedding for the given dimension and maximum length.
@ -37,12 +38,13 @@ def get_rotary_emb(
dim (int): The dimension of the input.
max_len (int): The maximum length of the input.
base (float, optional): The base for the frequency. Defaults to 10000.
device (optional): The device to create tensors on. Defaults to None.
Returns:
Tuple[Tensor, Tensor]: The cosine and sine caches of the rotary embedding.
"""
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim)
t = torch.arange(0, max_len, dtype=torch.float64)
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta)
return torch.cos(freqs).float(), torch.sin(freqs).float()
@ -83,10 +85,10 @@ class RotaryEmbedding(nn.Module):
self.max_len = max_len
self.base = base
self.max_len_cached = None
self._set_rotary_buffer(self.max_len)
self._set_rotary_buffer(self.max_len, None)
def _set_rotary_buffer(self, max_len: int):
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base)
def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None):
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device)
self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("sin_cached", sin_cached, persistent=False)
self.max_len_cached = max_len
@ -95,7 +97,7 @@ class RotaryEmbedding(nn.Module):
seq_len = x.size(1)
if self.max_len_cached < seq_len + start_pos:
self._set_rotary_buffer(seq_len + start_pos)
self._set_rotary_buffer(self.max_len_cached * 2, x.device)
cos = self.cos_cached[start_pos : start_pos + seq_len]
sin = self.sin_cached[start_pos : start_pos + seq_len]
@ -121,8 +123,7 @@ class RMSNorm(nn.Module):
self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor:
rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps)
return rms.to(x.dtype)
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
class MLP(nn.Module):
@ -257,7 +258,7 @@ class MLA(nn.Module):
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, eps=norm_eps)
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
# KV (k_nope, k_rope, v)
self.kv_b_proj = Linear(

View File

@ -1,15 +1,8 @@
from astrai.tokenize.chat_template import ChatTemplate, MessageType
from astrai.tokenize.tokenizer import (
AutoTokenizer,
BpeTokenizer,
)
from astrai.tokenize.trainer import BpeTrainer
from astrai.tokenize.tokenizer import AutoTokenizer
__all__ = [
"AutoTokenizer",
"BpeTokenizer",
"BpeTrainer",
"ChatTemplate",
"MessageType",
"HistoryType",
]

View File

@ -6,8 +6,7 @@ import json
from pathlib import Path
from typing import Dict, List, Optional, Union
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE
from tokenizers import Tokenizer
from astrai.tokenize.chat_template import ChatTemplate
@ -210,9 +209,9 @@ class AutoTokenizer:
Args:
messages: List of message dicts with 'role' and 'content'.
system_prompt: Optional system prompt string.
system_prompt: Optional system prompt string (auto-converted to first message).
tokenize: Whether to return token IDs (True) or raw string (False).
add_generation_prompt: Whether to add the generation prompt (default: False).
add_generation_prompt: Whether to add the generation prompt (default: True).
**kwargs: Additional variables to pass to the template.
Returns:
@ -226,10 +225,13 @@ class AutoTokenizer:
"Chat template not set. Use set_chat_template() to set a template first."
)
# Auto-convert system_prompt to first message if provided
if system_prompt:
messages = [{"role": "system", "content": system_prompt}] + list(messages)
# Render the template
rendered = self._chat_template.render(
messages=messages,
system_prompt=system_prompt,
add_generation_prompt=add_generation_prompt,
**kwargs,
)
@ -238,42 +240,3 @@ class AutoTokenizer:
return self.encode(rendered)
return rendered
class BpeTokenizer(AutoTokenizer):
"""BPE tokenizer implementation."""
def __init__(
self,
special_token_map: Dict[str, str] = None,
path: Optional[str] = None,
chat_template: Optional[str] = None,
):
special_token_map = special_token_map or {
"bos": "<begin▁of▁sentence>",
"eos": "<end▁of▁sentence>",
"pad": "<▁pad▁>",
"im_start": "<im▁start>",
"im_end": "<im▁end>",
}
self._tokenizer = None
self._init_tokenizer()
super().__init__(
path, special_token_map=special_token_map, chat_template=chat_template
)
def _init_tokenizer(self):
"""Initialize a new BPE tokenizer with default settings."""
model = BPE()
self._tokenizer = Tokenizer(model)
self._tokenizer.normalizer = normalizers.Sequence(
[normalizers.NFC(), normalizers.Strip()]
)
self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
[
pre_tokenizers.UnicodeScripts(),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
]
)
self._tokenizer.decoder = decoders.ByteLevel()
self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)

View File

@ -1,108 +0,0 @@
"""
BPE Tokenizer Trainer module.
Provides training functionality for BPE tokenizers.
"""
from typing import List, Union
from tokenizers import pre_tokenizers
from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
class BpeTrainer:
"""BPE tokenizer trainer."""
def __init__(self, tokenizer):
"""Initialize trainer with a tokenizer instance.
Args:
tokenizer: A BpeTokenizer instance
"""
self.tokenizer = tokenizer
def _prepare_trainer(
self,
vocab_size: int,
min_freq: int,
reserved_token_size: int,
max_token_length: int = 18,
):
"""Prepare the BPE trainer with proper configuration."""
assert reserved_token_size > len(self.tokenizer._special_tokens)
reserved_tokens = [
f"<|reserve{i:02d}|>"
for i in range(reserved_token_size - len(self.tokenizer._special_tokens))
]
detail_vocab_size = vocab_size - (
len(reserved_tokens) + len(self.tokenizer._special_tokens)
)
alphabet = pre_tokenizers.ByteLevel.alphabet()
min_size = len(alphabet) + len(self.tokenizer._control_tokens)
assert detail_vocab_size > min_size
trainer = BpeTrainerImpl(
vocab_size=detail_vocab_size,
min_frequency=min_freq,
limit_alphabet=detail_vocab_size // 6,
max_token_length=max_token_length,
special_tokens=self.tokenizer._control_tokens,
initial_alphabet=alphabet,
show_progress=True,
)
return trainer, reserved_tokens
def train(
self,
files: Union[str, List[str]],
vocab_size: int,
min_freq: int,
reserved_token_size: int = 100,
**kwargs,
):
"""Train tokenizer from files.
Args:
files: Path or list of paths to training files
vocab_size: Target vocabulary size
min_freq: Minimum frequency for tokens
reserved_token_size: Number of reserved tokens
**kwargs: Additional arguments
"""
trainer, reserved_tokens = self._prepare_trainer(
vocab_size, min_freq, reserved_token_size, **kwargs
)
self.tokenizer._tokenizer.train(files=files, trainer=trainer)
self.tokenizer._tokenizer.add_special_tokens(
self.tokenizer._special_tokens + reserved_tokens
)
def train_from_iterator(
self,
iterator,
vocab_size: int,
min_freq: int,
reserved_token_size: int = 100,
**kwargs,
):
"""Train tokenizer from iterator.
Args:
iterator: Iterator yielding training strings
vocab_size: Target vocabulary size
min_freq: Minimum frequency for tokens
reserved_token_size: Number of reserved tokens
**kwargs: Additional arguments
"""
trainer, reserved_tokens = self._prepare_trainer(
vocab_size, min_freq, reserved_token_size, **kwargs
)
self.tokenizer._tokenizer.train_from_iterator(
iterator=iterator, trainer=trainer
)
self.tokenizer._tokenizer.add_special_tokens(
self.tokenizer._special_tokens + reserved_tokens
)
__all__ = ["BpeTrainer"]

View File

@ -234,13 +234,13 @@ def train(
},
)
toltal_steps = len(dataset) * n_epoch // (batch_size * nprocs)
total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
scheduler_fn = partial(
create_scheduler,
**{
"schedule_type": "cosine",
"warmup_steps": warmup_steps,
"lr_decay_steps": toltal_steps - warmup_steps,
"lr_decay_steps": total_steps - warmup_steps,
},
)

View File

@ -7,12 +7,27 @@ import numpy as np
import pytest
import safetensors.torch as st
import torch
from tokenizers import pre_tokenizers
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from torch.utils.data import Dataset
from astrai.config.model_config import ModelConfig
from astrai.model.transformer import Transformer
from astrai.tokenize import BpeTokenizer, BpeTrainer
from astrai.tokenize import AutoTokenizer
def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
"""Create a simple tokenizer for testing purposes."""
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
trainer = trainers.BpeTrainer(
vocab_size=vocab_size, min_frequency=1, special_tokens=["<unk>", "<pad>"]
)
# Train on 256 single-character strings to build a minimal vocabulary
tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer)
auto_tokenizer = AutoTokenizer()
auto_tokenizer._tokenizer = tokenizer
auto_tokenizer._special_token_map = {"unk_token": "<unk>", "pad_token": "<pad>"}
return auto_tokenizer
class RandomDataset(Dataset):
@ -109,7 +124,7 @@ def base_test_env(request: pytest.FixtureRequest):
device = "cuda" if torch.cuda.is_available() else "cpu"
transformer_config = ModelConfig().load(config_path)
model = Transformer(transformer_config).to(device=device)
tokenizer = BpeTokenizer()
tokenizer = create_test_tokenizer()
yield {
"device": device,
@ -164,10 +179,7 @@ def test_env(request: pytest.FixtureRequest):
with open(config_path, "w") as f:
json.dump(config, f)
tokenizer = BpeTokenizer()
trainer = BpeTrainer(tokenizer)
sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
trainer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
tokenizer = create_test_tokenizer(vocab_size=config["vocab_size"])
tokenizer.save(tokenizer_path)
transformer_config = ModelConfig().load(config_path)

View File

@ -0,0 +1,320 @@
"""Tests for scheduler concurrency."""
import threading
import time
from unittest.mock import MagicMock, patch
import pytest
from astrai.inference.scheduler import (
InferenceScheduler,
PrefixCacheManager,
)
def test_prefix_cache_concurrent_insert_find():
"""Test concurrent insert and find operations."""
cache = PrefixCacheManager(max_capacity=100)
results = {"errors": [], "inserts": 0, "finds": 0}
def insert_worker():
try:
for i in range(50):
cache.insert((i,), slot=i % 10)
results["inserts"] += 1
except Exception as e:
results["errors"].append(str(e))
def find_worker():
try:
for i in range(50):
cache.find_longest_prefix([i])
results["finds"] += 1
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=insert_worker) for _ in range(3)]
threads += [threading.Thread(target=find_worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert results["inserts"] == 150
assert results["finds"] == 150
def test_prefix_cache_concurrent_release():
"""Test concurrent release operations."""
cache = PrefixCacheManager(max_capacity=100)
# Insert some prefixes
for i in range(10):
cache.insert((i,), slot=i)
results = {"errors": []}
def release_worker():
try:
for i in range(10):
cache.release((i,))
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=release_worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
def test_prefix_cache_concurrent_insert_release_find():
"""Test mixed concurrent operations."""
cache = PrefixCacheManager(max_capacity=50)
results = {"errors": []}
def worker(worker_id):
try:
for i in range(20):
token_ids = (worker_id * 100 + i,)
cache.insert(token_ids, slot=worker_id)
# Find after insert
cache.find_longest_prefix(list(token_ids))
# Release
cache.release(token_ids)
except Exception as e:
results["errors"].append(f"Worker {worker_id}: {str(e)}")
threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
@pytest.fixture
def mock_model_and_tokenizer():
"""Create mock model and tokenizer."""
mock_model = MagicMock()
mock_model.config = MagicMock()
mock_model.config.n_kv_heads = 8
mock_model.config.n_heads = 8
mock_model.config.dim = 128
mock_model.config.n_layers = 2
mock_model.config.max_len = 100
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1, 2, 3, 4, 5]
mock_tokenizer.decode.return_value = "token"
mock_tokenizer.stop_ids = [0]
mock_tokenizer.pad_id = None
return mock_model, mock_tokenizer
def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
"""Test concurrent add_task operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.scheduler.AutoModel"):
with patch("astrai.inference.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler(
model=mock_model,
tokenizer=mock_tokenizer,
max_batch_size=4,
device="cpu",
)
results = {"task_ids": [], "errors": []}
lock = threading.Lock()
def add_task_worker(worker_id):
try:
for i in range(10):
task_id = scheduler.add_task(f"prompt from worker {worker_id}-{i}")
with lock:
results["task_ids"].append(task_id)
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=add_task_worker, args=(i,)) for i in range(5)]
for t in threads:
t.start()
# Let some tasks be processed
time.sleep(0.1)
scheduler.stop()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["task_ids"]) == 50
def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
"""Test concurrent add and remove task operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.scheduler.AutoModel"):
with patch("astrai.inference.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler(
model=mock_model,
tokenizer=mock_tokenizer,
max_batch_size=4,
device="cpu",
)
results = {"added": [], "removed": [], "errors": []}
def add_worker():
try:
for i in range(20):
task_id = scheduler.add_task(f"prompt {i}")
results["added"].append(task_id)
time.sleep(0.001)
except Exception as e:
results["errors"].append(f"Add: {str(e)}")
def remove_worker():
try:
time.sleep(0.05) # Wait for some tasks to be added
for task_id in results["added"][:10]:
scheduler.remove_task(task_id)
results["removed"].append(task_id)
except Exception as e:
results["errors"].append(f"Remove: {str(e)}")
add_thread = threading.Thread(target=add_worker)
remove_thread = threading.Thread(target=remove_worker)
add_thread.start()
remove_thread.start()
time.sleep(0.2)
scheduler.stop()
add_thread.join()
remove_thread.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["added"]) == 20
def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
"""Test concurrent get_stats operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.scheduler.AutoModel"):
with patch("astrai.inference.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler(
model=mock_model,
tokenizer=mock_tokenizer,
max_batch_size=4,
device="cpu",
)
results = {"stats": [], "errors": []}
def add_tasks():
try:
for i in range(20):
scheduler.add_task(f"prompt {i}")
time.sleep(0.001)
except Exception as e:
results["errors"].append(f"Add: {str(e)}")
def get_stats():
try:
for _ in range(50):
stats = scheduler.get_stats()
results["stats"].append(stats)
time.sleep(0.001)
except Exception as e:
results["errors"].append(f"Get stats: {str(e)}")
add_thread = threading.Thread(target=add_tasks)
stats_thread = threading.Thread(target=get_stats)
add_thread.start()
stats_thread.start()
time.sleep(0.3)
scheduler.stop()
add_thread.join()
stats_thread.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["stats"]) == 50
# Verify stats are consistent
for stats in results["stats"]:
assert "total_tasks" in stats
assert stats["total_tasks"] >= 0
def test_prefix_cache_insert_same_prefix_concurrently():
"""Test inserting the same prefix concurrently."""
cache = PrefixCacheManager(max_capacity=100)
results = {"slot_values": [], "errors": []}
def insert_worker():
try:
# All workers try to insert the same prefix
cache.insert((1, 2, 3), slot=threading.current_thread().name)
node = cache.root.children.get(1)
if node:
node = node.children.get(2)
if node:
node = node.children.get(3)
if node:
results["slot_values"].append(node.slot)
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=insert_worker) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
# All inserts should succeed, final slot should be one of the values
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
# Check ref_count is correct (should be 10)
node = cache.root.children.get(1).children.get(2).children.get(3)
assert node.ref_count == 10, f"Expected ref_count=10, got {node.ref_count}"
def test_prefix_cache_ref_count_underflow_prevention():
"""Test that ref_count doesn't go negative."""
cache = PrefixCacheManager(max_capacity=100)
# Insert a prefix
cache.insert((1, 2, 3), slot=0)
# Release multiple times
for _ in range(5):
cache.release((1, 2, 3))
    # Over-releasing must be handled gracefully: the cache should clamp
    # ref_count at zero instead of letting it go negative.
node = cache.root.children.get(1).children.get(2).children.get(3)
# The ref_count should be 0, not negative
assert node.ref_count >= 0, f"ref_count went negative: {node.ref_count}"
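
Taken together, these tests pin down the thread-safety contract of the new scheduler pieces. A minimal single-threaded sketch of the same API surface (argument names mirror the tests above; `model` and `tokenizer` are assumed to be constructed elsewhere):

```python
from astrai.inference.scheduler import InferenceScheduler, PrefixCacheManager

# Prefix cache: insert / find_longest_prefix / release, keyed by token-id tuples.
cache = PrefixCacheManager(max_capacity=100)
cache.insert((1, 2, 3), slot=0)                # prefix [1, 2, 3] occupies KV slot 0
hit = cache.find_longest_prefix([1, 2, 3, 4])  # longest cached prefix of a new request
cache.release((1, 2, 3))                       # drop the reference when the request finishes

# Scheduler: enqueue prompts, poll stats, shut down.
# `model` and `tokenizer` are assumed to be loaded elsewhere (e.g. via AutoModel / AutoTokenizer).
scheduler = InferenceScheduler(model=model, tokenizer=tokenizer, max_batch_size=4, device="cpu")
task_id = scheduler.add_task("Hello")
print(scheduler.get_stats()["total_tasks"])
scheduler.remove_task(task_id)
scheduler.stop()
```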