Compare commits

10 Commits: a57a16430d ... 296db909aa

| Author | SHA1 | Date |
|---|---|---|
| | 296db909aa | |
| | a2ae742988 | |
| | 29beb174a5 | |
| | bbeaff4c60 | |
| | ab5e207f42 | |
| | b0eff02446 | |
| | 408f0cb513 | |
| | 64b78ecce3 | |
| | f2ffdf60d0 | |
| | ace8f6ee68 | |

Files changed:

- README.md (+32)
- README.md (+32)
@@ -84,6 +84,38 @@ python scripts/tools/train.py \
 python scripts/tools/generate.py --param_path=/path/to/param_path
 ```
+
+#### Start HTTP Server
+
+Start the inference server with an OpenAI-compatible HTTP API:
+
+```bash
+python -m scripts.tools.server --port 8000 --device cuda
+```
+
+Make requests:
+
+```bash
+# Chat API (OpenAI compatible)
+curl -X POST http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "messages": [{"role": "user", "content": "Hello"}],
+    "max_tokens": 512
+  }'
+
+# Streaming response
+curl -X POST http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "messages": [{"role": "user", "content": "Tell a story"}],
+    "stream": true,
+    "max_tokens": 500
+  }'
+
+# Health check
+curl http://localhost:8000/health
+```
+
 #### Demo
 
 Check out the demos in the `scripts/demo/` folder:
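For programmatic access, the endpoint added above can also be driven from Python. A minimal sketch — `build_chat_payload` and `extract_reply` are illustrative helper names, not functions from the repo; the payload fields mirror the curl examples:

```python
import json

def build_chat_payload(messages, max_tokens=512, temperature=0.8, stream=False):
    """Build the JSON body for POST /v1/chat/completions.

    Field names follow the curl examples; the defaults here are
    illustrative, not read from the server's code.
    """
    return {
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stream": stream,
    }

def extract_reply(response_json):
    """Pull the assistant text out of an OpenAI-style completion response."""
    return response_json["choices"][0]["message"]["content"]

payload = build_chat_payload([{"role": "user", "content": "Hello"}])
body = json.dumps(payload)  # what you would send with urllib or requests
```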
@@ -85,6 +85,38 @@ python scripts/tools/train.py \
 python scripts/tools/generate.py --param_path=/path/to/param_path
 ```
+
+#### Start HTTP Server
+
+Start the inference server with an OpenAI-compatible HTTP API:
+
+```bash
+python -m scripts.tools.server --port 8000 --device cuda
+```
+
+Make requests:
+
+```bash
+# Chat API (OpenAI compatible)
+curl -X POST http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "messages": [{"role": "user", "content": "Hello"}],
+    "max_tokens": 512
+  }'
+
+# Streaming response
+curl -X POST http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "messages": [{"role": "user", "content": "Tell a story"}],
+    "stream": true,
+    "max_tokens": 500
+  }'
+
+# Health check
+curl http://localhost:8000/health
+```
+
 #### Demo
 
 Check out the demos in the `scripts/demo/` folder:
@@ -12,6 +12,7 @@ AstrAI adopts a modular design with the following main components:
 - **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
+- **Factory Module** (`astrai/factory/`): Registry and BaseFactory for component registration
 - **Parallel Module** (`astrai/parallel/`): Distributed training support
 - **Serialization Module** (`astrai/serialization/`): HDF5 data loading, checkpoint management
 
 The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
@@ -21,11 +22,11 @@ The data flow can generally be divided into two main lines: **Training Data Flow
 flowchart LR
     subgraph A[Data Preparation]
         direction TB
-        A1[Raw Text] --> A2[BpeTokenizer]
+        A1[Raw Text] --> A2[AutoTokenizer]
         A2 --> A3[Serialize to .h5 files]
         A3 --> A4[BaseDataset]
         A4 --> A5[ResumableDistributedSampler]
-        A5 --> A6[DataLoader]
+        A5 --> A6[PyTorch DataLoader]
     end
 
     subgraph B[Training]
@@ -50,7 +51,7 @@ flowchart LR
     C5 --> C6[InferenceScheduler]
     C6 --> C7[apply_sampling_strategies]
     C7 --> C8[Transformer Forward]
-    C8 --> C9[KV Cache]
+    C8 --> C9[KV Cache + Prefix Cache]
     C9 --> C10{End Condition?}
     C10 -->|No| C8
     C10 -->|Yes| C11[Output Text]
@@ -64,25 +65,18 @@ flowchart LR
 
 ### 1. Dataset Module
 
-#### 1.1 Tokenizer (`tokenizer.py`)
-- Implemented based on Byte-Level BPE
-- Supports special tokens: `<|begin▁of▁sentence|>`, `<|end▁of▁sentence|>`, `<|▁pad▁|>`, `<|im▁start|>`, `<|im▁end|>`
-- Provides `encode`/`decode` methods for converting between text and token IDs
-- Learns a vocabulary from the corpus during training, saved as `.json` files
-- The `BpeTrainer` class handles vocabulary training from the corpus
-
-#### 1.2 Serialization (`serialization.py`)
+#### 1.1 Serialization (`serialization.py`)
 - **`save_h5`**: Saves multiple tensors by group as HDF5 (`.h5`) files; each key corresponds to a list of tensors
 - **`load_h5`**: Loads `.h5` files and returns `Dict[str, List[Tensor]]`; supports shared memory (`share_memory=True`)
 - **`Checkpoint` class**: Encapsulates the model state dict, training epoch, and iteration count; supports saving and loading in safetensors format
 
-#### 1.3 Dataset (`dataset.py`)
+#### 1.2 Dataset (`dataset.py`)
 - **`BaseDataset`**: Abstract base class that defines common logic for window sampling, stride, etc.
 - **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges across multiple segments
 - **`DatasetFactory`**: Factory pattern; supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
 - After the dataset is loaded, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
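The window sampling and stride logic that `BaseDataset` is described as implementing can be sketched in a few lines (illustrative only; the real class also handles masks and multi-segment boundaries):

```python
def window_indices(seq_len, window_size, stride):
    """Enumerate (begin, end) index pairs for sliding-window sampling.

    A sketch of the windowing idea described above, not the repo's code.
    """
    starts = range(0, max(seq_len - window_size, 0) + 1, stride)
    return [(s, s + window_size) for s in starts]

# A 10-token sequence with window 4 and stride 2 yields windows at 0, 2, 4, 6.
print(window_indices(10, 4, 2))
```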
 
-#### 1.4 Sampler (`sampler.py`)
+#### 1.3 Sampler (`sampler.py`)
 - **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
 - Records the current epoch and iteration position, enabling training to resume from breakpoints
 - Supports `shuffle` and `drop_last` options
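The resume idea can be sketched as seeding the shuffle with the epoch, so a restarted run reproduces the same order and skips ahead to the recorded iteration (a sketch, not the sampler's actual code; the distributed rank split is omitted):

```python
import random

def resumable_indices(num_samples, epoch, start_iter, seed=0, shuffle=True):
    """Yield one epoch's sample indices, skipping the first `start_iter`.

    Seeding the shuffle with (seed, epoch) makes the order reproducible,
    which is what allows resuming from a breakpoint.
    """
    indices = list(range(num_samples))
    if shuffle:
        random.Random(seed + epoch).shuffle(indices)
    return indices[start_iter:]
```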
@@ -99,7 +93,10 @@ flowchart LR
 
 #### 2.2 Submodules (`module.py`)
 - **`RotaryEmbedding`**: Generates the RoPE cos/sin cache
-- **`DecoderBlock`**: Contains multi-head attention (supports GQA), a feed-forward network (FFN), and residual connections
+- **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), a feed-forward network (FFN), and residual connections
+- **`GQA`**: Grouped Query Attention implementation
+- **`MLA`**: Multi-head Latent Attention implementation (as in DeepSeek-V2)
 - **`MLP`**: Feed-forward network with SiLU activation and a gated mechanism
 - **`RMSNorm`**: Layer normalization variant
 - **`Linear`**, **`Embedding`**: Custom linear and embedding layers supporting parallelism wrappers
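For reference, RMSNorm normalizes by the root mean square without subtracting the mean; a pure-Python sketch (the real module operates on tensors over the last dimension):

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """RMSNorm over a plain list: x / rms(x) * weight, no mean subtraction."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]
```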
@@ -116,15 +113,29 @@ flowchart LR
 1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
 
 #### 3.3 Strategy (`strategy.py`)
-- **`BaseStrategy`**: Defines the training strategy interface (such as `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`)
+- **`BaseStrategy`**: Defines the training strategy interface
+- **`SEQStrategy`**: Standard next-token prediction training
+- **`SFTStrategy`**: Supervised fine-tuning with loss masking
+- **`DPOStrategy`**: Direct Preference Optimization
+- **`GRPOStrategy`**: Group Relative Policy Optimization
 - A strategy receives batch data, executes the model forward pass and loss calculation, and returns the loss tensor
 - Created dynamically by `StrategyFactory` according to the configuration
 
 #### 3.4 Scheduler (`schedule.py`)
 - **`BaseScheduler`**: Abstract base class defining the learning rate scheduling interface
-- **`SchedulerFactory`**: Factory pattern; supports registration of various schedulers (such as `cosine`, `sgdr`)
+- **`CosineScheduler`**: Cosine decay scheduler with warmup
+- **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
+- **`SchedulerFactory`**: Factory pattern; supports registration of various schedulers
 - The scheduler is created automatically according to the configuration and bound to the optimizer
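The cosine-with-warmup schedule can be sketched with the standard formula (illustrative; the repo's `CosineScheduler` may differ in details such as `min_lr` handling):

```python
import math

def cosine_lr(step, max_lr, warmup_steps, total_steps, min_lr=0.0):
    """Linear warmup to max_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
```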
 
 #### 3.5 Callbacks (`train_callback.py`)
 - **`TrainCallback`**: Protocol interface for trainer callbacks
+- **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
+- **`ProgressBarCallback`**: Displays training progress
+- **`MetricLoggerCallback`**: Logs training metrics to JSON files
+- **`GradientClippingCallback`**: Clips gradient norms
+- **`SchedulerCallback`**: Steps the learning rate scheduler
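The hook order listed in 3.2 can be illustrated with a stand-in callback that records calls (a sketch; `RecordingCallback` and `run_training_loop` are not repo names):

```python
class RecordingCallback:
    """Minimal stand-in for the TrainCallback protocol; records hook order."""
    def __init__(self):
        self.calls = []
    def __getattr__(self, name):
        if name.startswith("on_"):
            return lambda context=None: self.calls.append(name)
        raise AttributeError(name)

def run_training_loop(callback, num_epochs=1, num_batches=2, accum_steps=2):
    """Drive the hook order listed above; forward/backward and the
    optimizer update are elided."""
    callback.on_train_begin()
    for _ in range(num_epochs):
        callback.on_epoch_begin()
        for b in range(num_batches):
            callback.on_batch_begin()
            # forward pass + loss calculation would happen here
            callback.on_batch_end()
            if (b + 1) % accum_steps == 0:  # gradient accumulation boundary
                callback.on_step_begin()
                # optimizer update would happen here
                callback.on_step_end()
        callback.on_epoch_end()
    callback.on_train_end()
```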
 
 ### 4. Factory Module
 
 #### 4.1 Registry and BaseFactory (`factory.py`)
@@ -133,9 +144,24 @@ flowchart LR
 - Supports a decorator-based registration pattern for extensible components
 - Provides methods for registration, retrieval, and listing with filtering
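The decorator-based registration pattern can be sketched as follows (method names mirror the class diagram but are illustrative):

```python
class Registry:
    """Name -> class registry with decorator-based registration."""
    def __init__(self):
        self._entries = {}
    def register(self, name):
        def decorator(cls):
            self._entries[name] = cls
            return cls
        return decorator
    def create(self, name, *args, **kwargs):
        return self._entries[name](*args, **kwargs)
    def list_names(self):
        return sorted(self._entries)

registry = Registry()

@registry.register("cosine")
class CosineScheduler:
    def __init__(self, max_lr):
        self.max_lr = max_lr
```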
 
-### 5. Inference Module
+### 5. Parallel Module
 
-#### 5.1 Inference Engine (`engine.py`)
+#### 5.1 Setup (`setup.py`)
+- **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
+- **`setup_parallel`**: Context manager for initializing the distributed process group (NCCL/CCL backend)
+- **`only_on_rank`**: Decorator to execute functions only on specific ranks
+- **`get_rank`**: Returns the current process rank in the distributed group
+- **`get_world_size`**: Returns the total number of processes in the distributed group
+- **`get_current_device`**: Returns the current device from the environment
+
+#### 5.2 Parallel Layers (`module.py`)
+- **`ParallelModel`**: Base class for parallel models with a process group
+- **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
+- **`RowParallelLinear`**: Row-parallel linear layer with output reduction
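The two sharding schemes can be checked against a plain matrix-vector product in a single process (a sketch of the math only; no `torch.distributed` involved):

```python
def matvec(x, w):
    """x: length-d vector, w: d x o matrix (list of rows) -> length-o vector."""
    o = len(w[0])
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(o)]

def column_parallel(x, w, parts):
    """Each 'rank' holds a slice of W's output columns; partial outputs are
    concatenated -- the gather in ColumnParallelLinear."""
    o = len(w[0])
    chunk = o // parts
    outs = []
    for r in range(parts):
        w_r = [row[r * chunk:(r + 1) * chunk] for row in w]
        outs.extend(matvec(x, w_r))
    return outs

def row_parallel(x, w, parts):
    """Each 'rank' holds a slice of W's rows and the matching slice of x;
    partial outputs are summed -- the reduction in RowParallelLinear."""
    chunk = len(x) // parts
    total = [0.0] * len(w[0])
    for r in range(parts):
        x_r = x[r * chunk:(r + 1) * chunk]
        w_r = w[r * chunk:(r + 1) * chunk]
        total = [a + b for a, b in zip(total, matvec(x_r, w_r))]
    return total
```

Both shardings reproduce the unsharded result exactly, which is the invariant tensor-parallel layers rely on.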
+
+### 6. Inference Module
+
+#### 6.1 Inference Engine (`engine.py`)
 - **`InferenceEngine`**: Unified inference interface; supports streaming and non-streaming generation
 - **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
 - **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
@@ -145,22 +171,38 @@ flowchart LR
 - Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
 - Uses separate model and tokenizer initialization for flexibility
 
-#### 5.2 Scheduler (`scheduler.py`)
+#### 6.2 Scheduler (`scheduler.py`)
 - **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
 - **`TaskStatus`**: Task state enumeration
 - **`apply_sampling_strategies`**: Applies temperature, top-k, and top-p sampling to logits
+- **`PrefixCacheManager`**: Radix-tree-based prefix cache with LRU eviction for efficient KV cache reuse
+- **`RadixNode`**: Tree node structure for prefix caching
 - Continuous batching: new requests can join at any time, and completed requests are released immediately
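The prefix-matching idea can be sketched with a token-level trie (the real `PrefixCacheManager` uses a compressed radix tree with slot bookkeeping and LRU eviction):

```python
class PrefixTrie:
    """Token-level trie sketch of prefix matching for KV cache reuse."""
    def __init__(self):
        self.root = {}

    def insert(self, token_ids):
        node = self.root
        for t in token_ids:
            node = node.setdefault(t, {})

    def longest_prefix(self, token_ids):
        """Return how many leading tokens are already cached."""
        node, matched = self.root, 0
        for t in token_ids:
            if t not in node:
                break
            node = node[t]
            matched += 1
        return matched
```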
 
-#### 5.3 Request (`engine.py`)
-- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
-- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
-- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into a prompt string using the ChatML format
-- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
+#### 6.3 Server (`server.py`)
+- FastAPI-based HTTP inference server
+- OpenAI-compatible `/v1/chat/completions` endpoint
+- Health check and statistics endpoints
+- Supports both streaming and non-streaming responses
+
+### 7. Tokenizer Module
+
+#### 7.1 Tokenizer (`tokenizer.py`)
+- Implemented on top of the HuggingFace tokenizers library (Byte-Level BPE)
+- **`AutoTokenizer`**: Auto-loading tokenizer class
+- Supports special tokens: `<|begin▁of▁sentence|>`, `<|end▁of▁sentence|>`, `<|▁pad▁|>`, `<|im▁start|>`, `<|im▁end|>`
+- Provides `encode`/`decode` methods for converting between text and token IDs
+- Uses `AutoTokenizer` for loading pre-trained tokenizers
+
+#### 7.2 Chat Template (`chat_template.py`)
+- **`ChatTemplate`**: Jinja2-based chat template with rendering support
+- Handles multi-role message formatting (system, user, assistant)
+- Supports dynamic prompts and generation prompts
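The ChatML-style rendering can be sketched with the special tokens listed in 7.1 (a sketch; the actual Jinja2 template may differ in whitespace and BOS/EOS handling):

```python
IM_START, IM_END = "<|im▁start|>", "<|im▁end|>"

def apply_chat_template(messages, add_generation_prompt=True):
    """Render role/content messages into a ChatML-style prompt string."""
    parts = []
    for msg in messages:
        parts.append(f"{IM_START}{msg['role']}\n{msg['content']}{IM_END}\n")
    if add_generation_prompt:
        parts.append(f"{IM_START}assistant\n")
    return "".join(parts)
```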
 
 ## Training Data Flow - Detailed Steps
 
 1. **Data Preparation**
-   - Raw text is converted to token ID sequences by the BPE tokenizer
+   - Raw text is converted to token ID sequences by the AutoTokenizer
    - Token ID sequences (possibly with masks, labels, etc.) are saved by group as `.h5` files
    - Files can contain multiple segments; each segment corresponds to a tensor
@@ -171,7 +213,7 @@ flowchart LR
 
 3. **Sampling and Batch Loading**
    - `ResumableDistributedSampler` generates the index sequence based on the current epoch and iteration position
-   - `DataLoader` uses the sampler to get indices and calls the dataset's `__getitem__` to fetch the actual data
+   - The PyTorch `DataLoader` uses the sampler to get indices and calls the dataset's `__getitem__` to fetch the actual data
    - Batch data shape is `[batch_size, window_size]` (or varies by dataset type)
 
 4. **Strategy Forward and Loss Calculation**
@@ -202,7 +244,7 @@ flowchart LR
    - For batch generation, use `pad_sequence` for padding
 
 3. **Autoregressive Generation Loop**
-   - Initialize the KV cache (optional)
+   - Initialize the KV cache (optional) and the prefix cache
    - Loop until `max_len` tokens are generated or a stop token is encountered:
      - Feed the current `input_ids` (or the cached new token) to the model to obtain `logits`
      - Apply `apply_sampling_strategies` (temperature, top-k, top-p) to the `logits`
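The sampling step can be sketched in pure Python over a list of logits (illustrative; the real `apply_sampling_strategies` works on tensors):

```python
import math

def apply_sampling_strategies(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Temperature scaling, then top-k and top-p filtering; returns a
    renormalized probability distribution with filtered entries zeroed."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    probs = [math.exp(l - m) for l in scaled]
    total = sum(probs)
    probs = [p / total for p in probs]
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = set(order)
    if top_k > 0:
        keep &= set(order[:top_k])
    if top_p < 1.0:
        kept, cum = set(), 0.0
        for i in order:
            if i in keep:
                kept.add(i)
                cum += probs[i]
                if cum >= top_p:
                    break
        keep = kept
    masked = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    z = sum(masked)
    return [p / z for p in masked]
```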
@@ -222,7 +264,6 @@ flowchart LR
 
 ## Summary
 
-The data flow design of AstrAI reflects modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunked loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using the KV cache and sampling strategies. Clear interfaces between modules facilitate customization and extension.
+The data flow design of AstrAI reflects modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunked loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using the KV cache, prefix caching, and sampling strategies. Clear interfaces between modules facilitate customization and extension.
 
-> Document Update Time: 2026-04-05
-> Corresponding Code Version: Refer to the version number defined in `pyproject.toml`
+> Document Update Time: 2026-04-09
@@ -8,7 +8,7 @@ Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, su
 
 ```mermaid
 classDiagram
-namespace astrai.config {
+namespace config {
     class ModelConfig {
         +int vocab_size
         +int dim
@@ -56,17 +56,9 @@ classDiagram
         +validate()
     }
 
-    class ModelParameter {
-        +nn.Module model
-        +BpeTokenizer tokenizer
-        +ModelConfig config
-        +save(instance, save_dir)
-        +load(load_dir, disable_init) ModelParameter
-        +to(*args, **kwargs)
-    }
 }
 
-namespace astrai.dataset {
+namespace dataset {
     class BaseDataset {
         +int window_size
         +int stride
@@ -100,8 +92,8 @@ classDiagram
     }
 
     class MultiSegmentFetcher {
-        +Dict muti_fetchers
-        +List muti_keys
+        +Dict multi_fetchers
+        +List multi_keys
         +key_fetch(begin_idx, end_idx, keys) Dict
         +fetch_data(begin_idx, end_idx) Dict
     }
@@ -125,17 +117,9 @@ classDiagram
         +save(save_dir)
         +load(save_dir) Checkpoint
     }
 
-    class DataLoader {
-        +Dataset dataset
-        +int batch_size
-        +Sampler sampler
-        +__iter__()
-        +__len__()
-    }
 }
 
-namespace astrai.model {
+namespace model {
     class AutoModel {
         +ModelConfig config
         +Dict _registry
@@ -148,7 +132,7 @@ classDiagram
 
     class Transformer {
         +ModelConfig config
-        +RotaryEmbedding rotary_embeding
+        +RotaryEmbedding rotary_embedding
         +Embedding embed_tokens
         +ModuleList layers
         +RMSNorm norm
@@ -216,24 +200,46 @@ classDiagram
     }
 }
 
-namespace astrai.tokenize {
-    class Tokenizer {
-        +encode(tokens, out_ids, add_special_tokens) List~int~
-        +decode(tokens, skip_special_tokens) str
-        +__len__() int
-    }
-
-    class BpeTokenizer {
+namespace tokenize {
+    class AutoTokenizer {
+        +List~str~ stop_ids
+        +int bos_id
+        +int eos_id
+        +int pad_id
+        +int vocab_size
+        +encode(tokens, out_ids, add_special_tokens) List~int~
+        +decode(tokens, skip_special_tokens) str
+        +apply_chat_template(messages, tokenize) Union~str, List[int]~
+        +set_chat_template(template)
+        +load(path)
+        +from_pretrained(path) AutoTokenizer
+        +save_pretrained(save_path)
+    }
+
+    class ChatTemplate {
+        +String template_str
+        +render(messages, add_generation_prompt) str
+        +from_string(template) ChatTemplate
+    }
 }
 
-namespace astrai.trainer {
+namespace factory {
+    class Registry {
+        +Dict _entries
+        +register(name, component_cls, category, priority)
+        +get(name) Type
+        +list_names() List~str~
+    }
+
+    class BaseFactory {
+        +Registry _registry
+        +register(name, category, priority) decorator
+        +create(name, *args, **kwargs) T
+        +list_registered() list
+    }
+}
+
+namespace trainer {
     class Trainer {
         +TrainConfig train_config
         +List~TrainCallback~ callbacks
@@ -337,6 +343,39 @@ classDiagram
         +on_error(context)
     }
 
+    class GradientClippingCallback {
+        +float max_grad_norm
+        +on_step_begin(context)
+    }
+
+    class SchedulerCallback {
+        +on_train_begin(context)
+        +on_batch_end(context)
+    }
+
+    class CheckpointCallback {
+        +str save_dir
+        +int interval
+        +_save_checkpoint(context)
+        +on_batch_end(context)
+        +on_train_end(context)
+        +on_error(context)
+    }
+
+    class ProgressBarCallback {
+        +int num_epoch
+        +on_epoch_begin(context)
+        +on_batch_end(context)
+        +on_epoch_end(context)
+    }
+
+    class MetricLoggerCallback {
+        +str log_dir
+        +int save_interval
+        +on_batch_end(context)
+        +on_train_end(context)
+    }
+
     class CallbackFactory {
         +Registry _registry
         +register(name) decorator
@@ -344,10 +383,17 @@ classDiagram
     }
 }
 
-namespace astrai.inference {
+namespace inference {
     class InferenceEngine {
-        +ModelParameter parameter
+        +nn.Module model
+        +AutoTokenizer tokenizer
+        +InferenceScheduler scheduler
         +int max_batch_size
         +Optional~int~ max_seq_len
+        +int max_prefix_len
+        +int cache_capacity
         +Tensor kv_cache
         +Tensor seq_mask
         +generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
         +generate_with_request(request) Union[Generator, str, List[str]]
         +get_stats() Dict
@@ -356,10 +402,11 @@ classDiagram
 
     class InferenceScheduler {
         +nn.Module model
-        +Tokenizer tokenizer
+        +AutoTokenizer tokenizer
         +ModelConfig config
         +Tuple kv_cache
         +Tensor seq_mask
+        +PrefixCacheManager prefix_cache
         +List waiting_queue
         +List active_tasks
         +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
@@ -369,6 +416,24 @@ classDiagram
         +get_stats() Dict
     }
 
+    class PrefixCacheManager {
+        +RadixNode root
+        +int max_capacity
+        +List lru
+        +insert(token_ids, slot)
+        +find_longest_prefix(token_ids) Tuple[int, int]
+        +release(token_ids)
+    }
+
+    class RadixNode {
+        +Dict children
+        +int hash
+        +int slot
+        +int ref_count
+        +float last_access
+        +List token_sequence
+    }
+
     class Task {
         +str task_id
         +List prompt_ids
@@ -392,14 +457,6 @@ classDiagram
         +str ABORTED
     }
 
-    class apply_sampling_strategies {
-        +Tensor logits
-        +float temperature
-        +int top_k
-        +float top_p
-        +forward() Tensor
-    }
-
     class Server {
         +start()
         +predict(request)
@@ -410,19 +467,54 @@ classDiagram
         +float top_p
         +float temperature
         +int max_len
-        +Union~str, List~str~~ query
-        +Optional history
+        +List~Dict~ messages
         +Optional~str~ system_prompt
         +bool stream
     }
 
+    class _Result {
+        +List~str~ tokens
+        +List~str~ results
+        +List~bool~ done_flags
+        +append(token, idx)
+        +get_results() List~str~
+    }
+
+    class ChatMessage {
+        +str role
+        +str content
+    }
+
+    class ChatCompletionRequest {
+        +List~ChatMessage~ messages
+        +float temperature
+        +float top_p
+        +int top_k
+        +int max_tokens
+        +bool stream
+        +Optional~str~ system_prompt
+    }
+
+    class CompletionResponse {
+        +str id
+        +str object
+        +int created
+        +str model
+        +List~Dict~ choices
+    }
 }
 
-namespace astrai.parallel {
+namespace parallel {
     class ParallelSetup {
         +spawn_parallel_fn(fn, nprocs)
         +setup_parallel(rank, world_size, backend, master_addr, master_port, device_type, device_ids)
     }
 
     class ParallelModel {
         +dist.ProcessGroup process_group
         +int rank
         +int world_size
     }
 
     class ColumnParallelLinear {
         +forward(x) Tensor
     }
@@ -433,9 +525,16 @@ classDiagram
     }
 
 %% Relationships
-TrainConfig --> ModelConfig : contains
+TrainConfig --> ModelConfig : uses
 TrainConfig --> BaseDataset : uses
 TrainConfig --> Transformer : uses
 TrainConfig --> StrategyFactory : selects
+StrategyFactory ..> BaseStrategy : creates
+BaseStrategy <|-- SEQStrategy
+BaseStrategy <|-- SFTStrategy
+BaseStrategy <|-- DPOStrategy
+BaseStrategy <|-- GRPOStrategy
+DPOStrategy --> Transformer : uses
+GRPOStrategy --> Transformer : uses
 Trainer --> TrainConfig : configures
 Trainer --> TrainContextBuilder : builds
 Trainer --> TrainCallback : manages
@@ -443,30 +542,27 @@ classDiagram
 TrainContext --> Checkpoint : manages
 TrainContext --> BaseStrategy : uses
 TrainContext --> BaseScheduler : uses
-StrategyFactory ..> BaseStrategy : creates
-BaseStrategy <|-- SEQStrategy
-BaseStrategy <|-- SFTStrategy
-BaseStrategy <|-- DPOStrategy
-BaseStrategy <|-- GRPOStrategy
-DPOStrategy --> Transformer : creates ref_model
-GRPOStrategy --> Transformer : creates ref_model
 AutoModel --> ModelConfig : contains
 SchedulerFactory ..> BaseScheduler : creates
 BaseScheduler <|-- CosineScheduler
 BaseScheduler <|-- SGDRScheduler
+CallbackFactory ..> TrainCallback : creates
+TrainCallback <|-- GradientClippingCallback
+TrainCallback <|-- SchedulerCallback
+TrainCallback <|-- CheckpointCallback
+TrainCallback <|-- ProgressBarCallback
+TrainCallback <|-- MetricLoggerCallback
 InferenceEngine --> InferenceScheduler : uses
 InferenceScheduler --> Task : manages
 InferenceScheduler --> TaskStatus : uses
 InferenceScheduler --> apply_sampling_strategies : uses
-InferenceScheduler --> Transformer : uses
+InferenceEngine --> Transformer : uses
 InferenceEngine --> GenerationRequest : uses
 Server --> InferenceEngine : uses
+Server --> ChatMessage : uses
+Server --> ChatCompletionRequest : uses
+Server --> CompletionResponse : uses
 ParallelSetup --> Trainer : enables
-TrainConfig --> StrategyFactory : selects
-ModelParameter --> Transformer : contains
-ModelParameter --> BpeTokenizer : contains
-ModelParameter --> ModelConfig : contains
 BaseDataset <|-- SEQDataset
 BaseDataset <|-- SFTDataset
 BaseDataset <|-- DPODataset
@@ -483,22 +579,34 @@ classDiagram
 DecoderBlock --> MLA : uses
 DecoderBlock --> MLP : uses
 DecoderBlock --> RMSNorm : uses
-BpeTokenizer --> Tokenizer : inherits
 TrainContextBuilder --> ResumableDistributedSampler : creates
-DataLoader --> BaseDataset : uses
 ResumableDistributedSampler --> BaseDataset : samples
 ParallelModel <|-- RowParallelLinear
 ParallelModel <|-- ColumnParallelLinear
+AutoTokenizer --> ChatTemplate : uses
+InferenceScheduler --> PrefixCacheManager : uses
+InferenceScheduler --> RadixNode : uses
 Checkpoint ..> Checkpoint : saves/loads
 TrainConfig --> DatasetFactory : selects
 TrainConfig --> SchedulerFactory : selects
 TrainConfig --> CallbackFactory : selects
+AutoModel ..> AutoTokenizer : loads with
+BaseFactory <|-- DatasetFactory
+BaseFactory <|-- StrategyFactory
+BaseFactory <|-- SchedulerFactory
+BaseFactory <|-- CallbackFactory
 ```
 
 ### Module Overview
 
 | Module | Components | Description |
 |--------|------------|-------------|
-| **astrai.config** | ModelConfig, TrainConfig, ModelParameter | Configuration management |
+| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
-| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint, DataLoader | Dataset loading and management |
+| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint | Dataset loading and management |
 | **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
-| **astrai.tokenize** | AutoTokenizer, BpeTokenizer, ChatTemplate, BpeTrainer | Tokenizer |
+| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
 | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
-| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest | Inference service with continuous batching |
+| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest, PrefixCacheManager, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
 | **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel training |
+| **astrai.factory** | Registry, BaseFactory | Generic component registration |
@@ -515,7 +623,7 @@ classDiagram
 | **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
 | **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
 | **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via the decorator pattern |
-| **Generator Pattern** | `_StreamingResult`, `_NonStreamingResult` | Event-based result notification for streaming/non-streaming generation |
+| **Generator Pattern** | `_Result`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
 
 ### Core Relationships
@@ -582,3 +690,5 @@ $$
 The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$
 
 Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly aligned dialogue model.
+
+> Document Update Time: 2026-04-09
@@ -2,7 +2,7 @@
 
 ### 1. Model Architecture
 
-This model uses the Transformer architecture with the GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA (although KV cache is not currently implemented). The model is built by stacking 32 Transformer blocks, for 1.0 billion parameters. The Transformer is an autoregressive model: it attends over all previous tokens to obtain the probability distribution of the next token.
+This model uses the Transformer architecture with the GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking 32 Transformer blocks, for 1.0 billion parameters. The Transformer is an autoregressive model: it attends over all previous tokens to obtain the probability distribution of the next token.
 
 The model now uses the **AutoModel** base class for flexible loading and saving:
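The memory saving from GQA follows directly from the head counts: the KV cache stores K and V per layer per kv head, so shrinking 24 query heads to 4 kv heads cuts the cache by a factor of 6. A sketch of the arithmetic (`head_dim` and the dtype size are illustrative assumptions, not values from the repo):

```python
def kv_cache_bytes(seq_len, n_layers, kv_heads, head_dim, dtype_bytes=2):
    """KV cache size: 2 (K and V) * layers * kv_heads * head_dim * seq_len."""
    return 2 * n_layers * kv_heads * head_dim * seq_len * dtype_bytes

mha = kv_cache_bytes(seq_len=4096, n_layers=32, kv_heads=24, head_dim=64)
gqa = kv_cache_bytes(seq_len=4096, n_layers=32, kv_heads=4, head_dim=64)
print(mha // gqa)  # the ratio is q_head / kv_head = 24 / 4 = 6
```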
@@ -191,3 +191,109 @@ for token in engine.generate_with_request(request):
 ```
 
 The continuous batching feature allows dynamic batch composition: new requests can join at any time, and completed requests are released immediately.
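One scheduler step of that behavior can be sketched as admit-decode-release (illustrative; the real scheduler also manages KV cache slots and threading):

```python
from collections import deque

def schedule_step(active, waiting, max_batch_size):
    """One continuous-batching step: admit waiting tasks while capacity
    allows, 'decode' every active task one token, release finished ones."""
    while waiting and len(active) < max_batch_size:
        active.append(waiting.popleft())
    for task in active:
        task["generated"] += 1  # stand-in for one decode step
    still_running = [t for t in active if t["generated"] < t["max_tokens"]]
    finished = [t for t in active if t["generated"] >= t["max_tokens"]]
    return still_running, finished
```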
+
+## HTTP API Usage
+
+The inference server provides HTTP endpoints for remote inference. Start the server first:
+
+```bash
+python -m scripts.tools.server --port 8000
+```
+
+### OpenAI-Compatible Endpoint
+
+The server provides an OpenAI-compatible chat completion endpoint at `/v1/chat/completions`:
+
+```bash
+curl -X POST http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "messages": [
+      {"role": "system", "content": "You are a helpful assistant."},
+      {"role": "user", "content": "Hello, how are you?"}
+    ],
+    "temperature": 0.8,
+    "max_tokens": 2048,
+    "stream": false
+  }'
+```
+
+**Request Parameters:**
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `messages` | List[dict] | Required | Chat messages with role and content |
+| `temperature` | float | 0.8 | Sampling temperature (0.0-2.0) |
+| `top_p` | float | 0.95 | Nucleus sampling threshold |
+| `top_k` | int | 50 | Top-k sampling parameter |
+| `max_tokens` | int | 2048 | Maximum tokens to generate |
+| `stream` | bool | false | Enable streaming response |
+| `system_prompt` | str | None | System prompt override |
+
+**Response (non-streaming):**
+
+```json
+{
+  "id": "chatcmpl-1234567890",
+  "object": "chat.completion",
+  "created": 1234567890,
+  "model": "astrai",
+  "choices": [
+    {
+      "index": 0,
+      "message": {"role": "assistant", "content": "Hello! I'm doing well..."},
+      "finish_reason": "stop"
+    }
+  ]
+}
+```
+
+### Streaming Response
+
+Enable streaming for real-time token-by-token output:
+
+```bash
+curl -X POST http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "messages": [{"role": "user", "content": "Write a story"}],
+    "stream": true,
+    "max_tokens": 500
+  }'
+```
+
+The server uses Server-Sent Events (SSE) with content type `text/event-stream`.
|
||||
|
||||
### Simple Generation Endpoint
|
||||
|
||||
For basic text generation without chat format:
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/generate?query=Hello&max_len=1000" \
|
||||
-H "Content-Type: application/json"
|
||||
```
|
||||
|
||||
Or with conversation history:
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/generate" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"query": "What is AI?",
|
||||
"history": [["Hello", "Hi there!"], ["How are you?", "I'm doing well"]],
|
||||
"temperature": 0.8,
|
||||
"max_len": 2048
|
||||
}'
|
||||
```
|
||||
|
||||
### Health Check
|
||||
|
||||
Monitor server and model status:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8000/health
|
||||
# {"status": "ok", "model_loaded": true, "engine_ready": true}
|
||||
|
||||
curl http://localhost:8000/stats
|
||||
# {"requests_total": 10, "tokens_generated": 5000, ...}
|
||||
```
|
||||
|
||||
> Document Update Time: 2026-04-09
|
||||
|
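On the client side, the streaming endpoint can be consumed by scanning the SSE body for `data:` lines. A minimal parser sketch — it assumes chunks follow the OpenAI `chat.completion.chunk` shape with a `delta.content` field and a final `data: [DONE]` sentinel, which this server is assumed to emit:

```python
import json

def parse_sse_chunks(raw_stream: str):
    """Yield content deltas from an OpenAI-style SSE response body."""
    for line in raw_stream.splitlines():
        line = line.strip()
        if not line.startswith("data:"):
            continue  # skip blank keep-alive lines and comments
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":  # assumed end-of-stream sentinel
            break
        chunk = json.loads(payload)
        delta = chunk["choices"][0].get("delta", {})
        if "content" in delta:
            yield delta["content"]

raw = (
    'data: {"choices": [{"delta": {"content": "Hel"}}]}\n\n'
    'data: {"choices": [{"delta": {"content": "lo"}}]}\n\n'
    "data: [DONE]\n"
)
print("".join(parse_sse_chunks(raw)))  # Hello
```

In a real client the same loop would iterate over response lines from the HTTP library instead of a prebuilt string.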
@ -137,3 +137,5 @@ result = engine.generate(
|------|-------------|
| `stream=True` | Streaming output, yields token by token |
| `stream=False` | Non-streaming output, returns complete result |

> Document Update Time: 2026-04-09
@ -12,18 +12,19 @@ from astrai.inference import (
    InferenceEngine,
)
from astrai.model import AutoModel, Transformer
from astrai.tokenize import BpeTokenizer
from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
from astrai.tokenize import AutoTokenizer
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer

__all__ = [
    "Transformer",
    "ModelConfig",
    "TrainConfig",
    "DatasetFactory",
    "BpeTokenizer",
    "AutoTokenizer",
    "GenerationRequest",
    "InferenceEngine",
    "Trainer",
    "CallbackFactory",
    "StrategyFactory",
    "SchedulerFactory",
    "BaseFactory",
@ -72,15 +72,16 @@ class MultiSegmentFetcher:
    Each key corresponds to a different type of data (e.g., "sequence", "mask").
    """

    def __init__(self, muti_segments: Dict):
        self.muti_keys = list(muti_segments.keys())
        self.muti_fetchers = {
            key: BaseSegmentFetcher(segments) for key, segments in muti_segments.items()
    def __init__(self, multi_segments: Dict):
        self.multi_keys = list(multi_segments.keys())
        self.multi_fetchers = {
            key: BaseSegmentFetcher(segments)
            for key, segments in multi_segments.items()
        }

    def __len__(self) -> int:
        """Returns the minimum length across all fetchers."""
        len_list = [len(seg) for seg in self.muti_fetchers.values()]
        len_list = [len(seg) for seg in self.multi_fetchers.values()]
        return min(len_list)

    def key_fetch(

@ -100,7 +101,7 @@ class MultiSegmentFetcher:
        keys = [keys] if isinstance(keys, str) else keys

        for key in keys:
            fetcher = self.muti_fetchers[key]
            fetcher = self.multi_fetchers[key]
            fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
            fetch_dict[key] = fetch_tensor

@ -108,7 +109,7 @@ class MultiSegmentFetcher:

    def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
        """Fetch all keys."""
        return self.key_fetch(begin_idx, end_idx, self.muti_keys)
        return self.key_fetch(begin_idx, end_idx, self.multi_keys)


class BaseDataset(Dataset, ABC):
@ -45,17 +45,31 @@ class GenerationRequest:
        raise ValueError("temperature must be a non-negative number")


class _StreamingResult:
    """Streaming result holder with event-based notification."""
class _Result:
    """Unified result holder for streaming/non-streaming modes."""

    def __init__(self):
        self.tokens: List[str] = []
        self._event = threading.Event()
    def __init__(self, count: int = 1, stream: bool = False):
        self._stream = stream
        self._lock = threading.Lock()
        self._event = threading.Event()
        self.tokens: List[str] = []
        self.results: List[str] = [""] * count if count > 1 else [""]
        self.done_flags: List[bool] = [False] * count
        self._completed_count = 0

    def append(self, token: str):
    def append(self, token: str, idx: int = 0):
        with self._lock:
            if self._stream:
                self.tokens.append(token)
            else:
                if token == "[DONE]":
                    if not self.done_flags[idx]:
                        self.done_flags[idx] = True
                        self._completed_count += 1
                        if self._completed_count == len(self.results):
                            self._event.set()
                else:
                    self.results[idx] += token
            self._event.set()

    def pop_all(self) -> List[str]:

@ -69,35 +83,6 @@ class _StreamingResult:
    def wait(self, timeout: float = None) -> bool:
        return self._event.wait(timeout=timeout)


class _NonStreamingResult:
    """Non-streaming result holder with event-based completion notification."""

    def __init__(self, count: int):
        self.results: List[str] = [""] * count
        self.done_flags: List[bool] = [False] * count
        self._completed_count = 0
        self._event = threading.Event()
        self._lock = threading.Lock()

    def append(self, idx: int, token: str):
        with self._lock:
            if token == "[DONE]":
                if not self.done_flags[idx]:
                    self.done_flags[idx] = True
                    self._completed_count += 1
                    if self._completed_count == len(self.results):
                        self._event.set()
            else:
                self.results[idx] += token

    def is_all_done(self) -> bool:
        with self._lock:
            return all(self.done_flags)

    def wait(self, timeout: float = None) -> bool:
        return self._event.wait(timeout=timeout)

    def get_results(self) -> List[str]:
        with self._lock:
            return self.results.copy()
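The lock-plus-event pattern in `_Result` can be exercised with toy producer threads. The class below is a simplified stand-in for the non-streaming path only, not the project's implementation: each worker appends text for its request index plus a `[DONE]` sentinel, and the event fires once every request has finished.

```python
import threading

class ToyResult:
    """Accumulates per-request text; the event fires when every index is done."""

    def __init__(self, count):
        self.results = [""] * count
        self.done = [False] * count
        self._lock = threading.Lock()
        self._event = threading.Event()

    def append(self, token, idx=0):
        with self._lock:
            if token == "[DONE]":
                self.done[idx] = True
                if all(self.done):
                    self._event.set()  # all requests complete -> wake waiters
            else:
                self.results[idx] += token

    def wait(self, timeout=None):
        return self._event.wait(timeout=timeout)

result = ToyResult(count=2)

def produce(idx, tokens):
    for tok in tokens:
        result.append(tok, idx)
    result.append("[DONE]", idx)  # sentinel marks this request finished

workers = [
    threading.Thread(target=produce, args=(0, ["Hi", "!"])),
    threading.Thread(target=produce, args=(1, ["Ok"])),
]
for w in workers:
    w.start()
assert result.wait(timeout=5.0)  # blocks until both sentinels arrived
for w in workers:
    w.join()
print(result.results)  # ['Hi!', 'Ok']
```

The caller never polls: it blocks on `wait()` and wakes exactly once, when the last `[DONE]` flips the event.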
@ -112,6 +97,8 @@ class InferenceEngine:
        tokenizer: AutoTokenizer,
        max_batch_size: int = 1,
        max_seq_len: Optional[int] = None,
        max_prefix_len: int = 512,
        cache_capacity: int = 1000,
    ):
        """
        Initialize inference engine with separate model and tokenizer.

@ -122,6 +109,8 @@ class InferenceEngine:
            config: Model configuration
            max_batch_size: Maximum batch size for continuous batching
            max_seq_len: Maximum sequence length (defaults to config.max_len)
            max_prefix_len: Maximum prefix length for cache (default: 512)
            cache_capacity: Maximum number of cached prefixes (default: 1000)
        """
        self.model = model
        self.tokenizer = tokenizer

@ -141,6 +130,8 @@ class InferenceEngine:
            tokenizer=self.tokenizer,
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_len,
            max_prefix_len=max_prefix_len,
            cache_capacity=cache_capacity,
            device=device,
            dtype=dtype,
        )

@ -227,7 +218,7 @@ class InferenceEngine:
        if is_batch:
            raise NotImplementedError("Batch streaming is not implemented yet")

        result = _StreamingResult()
        result = _Result(stream=True)

        task_id = self.scheduler.add_task(
            prompt=prompts[0],

@ -266,7 +257,7 @@ class InferenceEngine:
        top_k: int,
    ) -> Union[str, List[str]]:
        """Generate without streaming."""
        result = _NonStreamingResult(len(prompts))
        result = _Result(count=len(prompts))

        for i, p in enumerate(prompts):
            # Create closure to capture current index value using factory function
@ -3,7 +3,7 @@
import threading
import time
import uuid
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor
@ -12,6 +12,135 @@ from astrai.model.automodel import AutoModel
from astrai.tokenize import AutoTokenizer


class RadixNode:
    """Radix tree node for prefix cache."""

    def __init__(self):
        self.children: Dict[int, "RadixNode"] = {}  # token_id -> child node
        self.hash: Optional[int] = None  # 64-bit hash of the prefix
        self.slot: int = -1  # KV Cache slot, valid only for leaf nodes
        self.ref_count: int = 0  # number of tasks referencing this prefix
        self.last_access: float = 0.0  # timestamp for LRU
        self.token_sequence: list = []  # full token sequence from root to this node


class PrefixCacheManager:
    """Prefix cache manager using Radix tree with LRU eviction."""

    def __init__(self, max_capacity: int = 1000, base: int = 131, mod: int = 10**9 + 7):
        self.root = RadixNode()
        self.base = base
        self.mod = mod
        self.max_capacity = max_capacity
        self.lru: List[Tuple[float, RadixNode]] = []  # (timestamp, node) for LRU

    def insert(self, token_ids: Tuple[int, ...], slot: int) -> None:
        """Insert a prefix, increase ref_count if already exists, otherwise create new node."""
        node = self.root
        path = []
        h = 0
        for i, token_id in enumerate(token_ids):
            if token_id not in node.children:
                node.children[token_id] = RadixNode()
            node = node.children[token_id]
            h = (h * self.base + token_id) % self.mod
            node.hash = h
            path.append(token_id)
            node.token_sequence = list(path)  # store full sequence for exact verification

        # Leaf node: set slot and increase ref_count
        if node.slot == -1:
            node.slot = slot
        node.ref_count += 1
        node.last_access = time.time()
        self._update_lru(node)
        self._evict_if_needed()

    def find_longest_prefix(self, token_ids: List[int]) -> Optional[Tuple[int, int]]:
        """Find longest matching prefix, return (prefix_len, slot).

        During traversal, compute hash per token and compare with node hash.
        If hash matches, perform full token sequence verification to avoid
        hash collision errors.
        """
        node = self.root
        best_len = 0
        best_slot = -1
        h = 0

        for i, token_id in enumerate(token_ids):
            if token_id not in node.children:
                break
            node = node.children[token_id]
            h = (h * self.base + token_id) % self.mod
            if node.hash == h:  # hash matches
                # Exact verification: compare full token sequence
                if node.token_sequence == token_ids[: i + 1]:
                    best_len = i + 1
                    best_slot = node.slot
                    node.last_access = time.time()
                    self._update_lru(node)

        if best_len > 0:
            return (best_len, best_slot)
        return None

    def release(self, token_ids: Tuple[int, ...]) -> None:
        """Release reference to a prefix, decrease ref_count. If zero, mark as evictable."""
        node = self.root
        for token_id in token_ids:
            if token_id not in node.children:
                return
            node = node.children[token_id]
        if node.ref_count > 0:
            node.ref_count -= 1
        if node.ref_count == 0:
            node.slot = -1  # slot can be reused

    def _update_lru(self, node: RadixNode) -> None:
        """Update LRU list, move node to most recently used position."""
        self.lru = [(ts, n) for (ts, n) in self.lru if n is not node]
        self.lru.append((node.last_access, node))

    def _evict_if_needed(self) -> None:
        """If cache entries exceed capacity, evict least recently used leaf nodes (ref_count must be 0)."""
        if len(self.lru) <= self.max_capacity:
            return
        # Sort by timestamp
        self.lru.sort(key=lambda x: x[0])
        for ts, node in self.lru:
            if node.ref_count == 0:
                # Remove leaf node from tree (need to recursively delete empty branches)
                self._remove_node(node)
                self.lru.remove((ts, node))
                if len(self.lru) <= self.max_capacity:
                    break

    def _remove_node(
        self,
        node: RadixNode,
        parent: Optional[RadixNode] = None,
        child_key: Optional[int] = None,
    ) -> None:
        """Remove node from tree, including empty parent nodes."""
        # First, recursively remove all children
        for child_key, child_node in list(node.children.items()):
            self._remove_node(child_node, node, child_key)

        # Clear the node's leaf properties
        node.slot = -1
        node.hash = None
        node.token_sequence = []
        node.children.clear()

        # If this node has no children and has a parent, remove the reference from parent
        if parent is not None and child_key is not None and len(node.children) == 0:
            if child_key in parent.children:
                del parent.children[child_key]


class TaskStatus:
    """Task state for continuous batching."""
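The prefix hash above is a plain polynomial rolling hash: extending a prefix by one token is `h = (h * base + token_id) % mod`, so every prefix hash comes out of a single left-to-right pass, and the stored token sequence gives exact verification against collisions. A standalone sketch of the longest-prefix lookup (a flat dict instead of a radix tree, to keep it short; the slot number is arbitrary):

```python
BASE, MOD = 131, 10**9 + 7  # same constants as PrefixCacheManager's defaults

def prefix_hashes(token_ids):
    """Rolling hash of every prefix: h_i = (h_{i-1} * BASE + token_i) % MOD."""
    h, out = 0, []
    for tok in token_ids:
        h = (h * BASE + tok) % MOD
        out.append(h)
    return out

# Toy "cache": hash -> (token prefix, KV slot). Storing the prefix itself is
# what makes the exact-verification step possible.
cache = {}
cached_prompt, slot = [5, 9, 9, 3], 7
for n, h in enumerate(prefix_hashes(cached_prompt), start=1):
    cache[h] = (cached_prompt[:n], slot)

def find_longest_prefix(token_ids):
    best = None
    for n, h in enumerate(prefix_hashes(token_ids), start=1):
        entry = cache.get(h)
        if entry and entry[0] == token_ids[:n]:  # exact check, not just the hash
            best = (n, entry[1])
    return best

print(find_longest_prefix([5, 9, 9, 8, 2]))  # first 3 tokens match -> (3, 7)
print(find_longest_prefix([1, 2]))           # no cached prefix -> None
```

The radix-tree version walks children instead of probing a dict, but the hash recurrence and the exact-match guard are the same.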
@ -46,6 +175,7 @@ class Task:
        self.input_tokens: int = 0
        self.output_tokens: int = 0
        self.slot: int = -1
        self.prefix_len: int = 0  # prefix cache matched length
        self.arrival_time = time.time()
        self.finish_time: Optional[float] = None

@ -53,9 +183,10 @@ class Task:

    def is_finished(self, stop_ids: List[int]) -> bool:
        """Check if task is finished."""
        if self.output_ids and self.output_ids[-1] in stop_ids:
            return True
        return self.output_tokens >= self.max_tokens
        return (
            bool(self.output_ids and self.output_ids[-1] in stop_ids)
            or self.output_tokens >= self.max_tokens
        )


def apply_sampling_strategies(

@ -104,6 +235,8 @@ class InferenceScheduler:
        tokenizer: AutoTokenizer,
        max_batch_size: int = 16,
        max_seq_len: Optional[int] = None,
        max_prefix_len: int = 512,
        cache_capacity: int = 1000,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
@ -113,9 +246,13 @@ class InferenceScheduler:
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len or config.max_len
        self.max_prefix_len = max_prefix_len
        self.device = device or next(model.parameters()).device
        self.dtype = dtype or next(model.parameters()).dtype

        # Initialize prefix cache
        self.prefix_cache = PrefixCacheManager(max_capacity=cache_capacity)

        num_kv_heads = config.n_kv_heads
        head_dim = config.dim // config.n_heads
        n_layers = config.n_layers

@ -170,6 +307,10 @@ class InferenceScheduler:
        task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
        prompt_ids = self.tokenizer.encode(prompt)

        # Truncate if exceeds max_prefix_len
        if len(prompt_ids) > self.max_prefix_len:
            prompt_ids = prompt_ids[: self.max_prefix_len]

        task = Task(
            task_id=task_id,
            prompt_ids=prompt_ids,

@ -180,6 +321,16 @@ class InferenceScheduler:
            stream_callback=stream_callback,
        )

        # Find longest matching prefix from cache
        match = self.prefix_cache.find_longest_prefix(prompt_ids)
        if match:
            prefix_len, slot = match
            task.prefix_len = prefix_len
            task.slot = slot
        else:
            task.prefix_len = 0
            task.slot = -1

        with self._lock:
            self.waiting_queue.append(task)
            self._total_tasks += 1

@ -207,6 +358,11 @@ class InferenceScheduler:
        slot = task.slot
        if slot >= 0 and slot < len(self.active_tasks):
            self.seq_mask[slot, :] = False

        # Release prefix cache reference
        if task.prefix_len > 0:
            self.prefix_cache.release(tuple(task.prompt_ids[: task.prefix_len]))

        task.slot = -1

        self.active_tasks = [
@ -220,22 +376,51 @@ class InferenceScheduler:
            return

        with self._lock:
            to_add = []
            for _ in range(min(available_slots, len(self.waiting_queue))):
                if self.waiting_queue:
                    task = self.waiting_queue.pop(0)
                    task.status = TaskStatus.RUNNING
                    to_add.append(task)

            to_add = [
                self.waiting_queue.pop(0)
                for _ in range(min(available_slots, len(self.waiting_queue)))
            ]
            for task in to_add:
                for i in range(self.max_batch_size):
                    if all(t.slot != i for t in self.active_tasks):
                        task.slot = i
                        break
                task.slot = self._allocate_slot()
                task.status = TaskStatus.RUNNING
                self.active_tasks.append(task)

    def _allocate_slot(self) -> int:
        """Allocate an available slot for a task."""
        for i in range(self.max_batch_size):
            if not any(t.slot == i for t in self.active_tasks):
                return i
        return -1

    def _execute_prefill(self, tasks: List[Task]) -> None:
        """Execute Prefill phase."""
        """Execute Prefill phase with incremental prefill support."""
        if not tasks:
            return

        # Group tasks by prefix cache status
        fully_cached, partial, full = [], [], []
        for task in tasks:
            total_len, prefix_len = len(task.prompt_ids), task.prefix_len
            if prefix_len == total_len:
                fully_cached.append(task)
            elif prefix_len > 0:
                partial.append(task)
            else:
                full.append(task)

        # Handle fully cached tasks
        for t in fully_cached:
            t.input_tokens, t.output_tokens = len(t.prompt_ids), 0
            if t.slot >= 0:
                self.seq_mask[t.slot, : t.input_tokens] = True

        if full:
            self._execute_full_prefill(full)
        if partial:
            self._execute_partial_prefill(partial)

    def _execute_full_prefill(self, tasks: List[Task]) -> None:
        """Execute full prefill for tasks without prefix cache."""
        if not tasks:
            return

@ -271,11 +456,59 @@ class InferenceScheduler:
        for i, task in enumerate(tasks):
            task.input_tokens = prompt_lens[i]
            task.output_tokens = 0
            # Insert new prefix into cache
            self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)

        for task in tasks:
            if task.slot >= 0:
                self.seq_mask[task.slot, : task.input_tokens] = True

    def _execute_partial_prefill(self, tasks: List[Task]) -> None:
        """Execute incremental prefill for tasks with partial prefix cache match."""
        for task in tasks:
            total_len = len(task.prompt_ids)
            prefix_len = task.prefix_len

            if prefix_len >= total_len:
                task.input_tokens = total_len
                task.output_tokens = 0
                continue

            # Get new tokens that need prefill
            new_ids = task.prompt_ids[prefix_len:]
            new_len = len(new_ids)

            if new_len == 0:
                task.input_tokens = total_len
                task.output_tokens = 0
                continue

            # Build input for incremental prefill
            input_ids = torch.tensor([new_ids], dtype=torch.long, device=self.device)

            # Input mask should cover from position 0 to prefix_len + new_len
            # The prefix part uses cached KV, new part needs computation
            input_mask = torch.ones(
                (1, prefix_len + new_len), dtype=torch.bool, device=self.device
            )

            with torch.inference_mode():
                self.model(
                    input_ids,
                    input_mask=input_mask,
                    start_pos=prefix_len,
                    persistent_key_values=self.kv_cache,
                )

            task.input_tokens = total_len
            task.output_tokens = 0

            # Insert full prefix into cache (ref_count already increased in add_task)
            self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)

            if task.slot >= 0:
                self.seq_mask[task.slot, : task.input_tokens] = True

    def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
        """Execute Decode phase."""
        if not tasks:
@ -9,7 +9,7 @@ import json
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional

import torch
import uvicorn

@ -134,78 +134,6 @@ class CompletionResponse(BaseModel):
    choices: List[Dict[str, Any]]


class StreamCompletionResponse(BaseModel):
    id: str = "chatcmpl-default"
    object: str = "chat.completion.chunk"
    created: int = 0
    model: str = "astrai"
    choices: List[Dict[str, Any]]


def convert_messages_to_history(
    messages: List[ChatMessage],
) -> tuple[Optional[str], Optional[List[Tuple[str, str]]]]:
    """Convert OpenAI-style messages to system_prompt and history."""
    system_prompt = None
    history: List[Tuple[str, str]] = []
    user_buffer = []
    assistant_buffer = []
    for msg in messages:
        if msg.role == "system":
            system_prompt = msg.content
        elif msg.role == "user":
            if assistant_buffer:
                # Flush previous pair
                history.append(("".join(user_buffer), "".join(assistant_buffer)))
                user_buffer = []
                assistant_buffer = []
            user_buffer.append(msg.content)
        elif msg.role == "assistant":
            assistant_buffer.append(msg.content)
        else:
            logger.warning(f"Unknown role {msg.role}")
    return system_prompt, history if history else None


def convert_messages_to_prompt(
    messages: List[ChatMessage], engine: InferenceEngine = None
) -> str:
    """Convert messages to prompt string.

    Args:
        messages: List of ChatMessage objects
        engine: InferenceEngine instance for accessing tokenizer

    Returns:
        str: Formatted prompt string
    """
    # Convert to dict format for chat template
    msg_dicts = [{"role": m.role, "content": m.content} for m in messages]

    # Extract system prompt if present
    system_prompt = None
    filtered_messages = []
    for msg in msg_dicts:
        if msg["role"] == "system":
            system_prompt = msg["content"]
        else:
            filtered_messages.append(msg)

    # Use engine's tokenizer chat template if available
    if engine is not None and engine.tokenizer is not None:
        return engine.tokenizer.apply_chat_template(
            filtered_messages, system_prompt=system_prompt, tokenize=False
        )

    # Fallback: simple concatenation (deprecated)
    prompt_parts = []
    for msg in filtered_messages:
        prompt_parts.append(
            f"<|im▁start|>{msg['role']}\n{msg['content']}<|im▁end|>"
        )
    return "\n".join(prompt_parts) + "\n<|im▁start|>assistant\n"


@app.get("/health")
async def health():
    return {
@ -233,7 +161,12 @@ async def chat_completion(request: ChatCompletionRequest):
        raise HTTPException(status_code=503, detail="Engine not initialized")

    # Convert messages to prompt using engine's tokenizer
    prompt = convert_messages_to_prompt(request.messages, engine=_engine)
    # Apply chat template directly with messages
    prompt = _engine.tokenizer.apply_chat_template(
        [{"role": m.role, "content": m.content} for m in request.messages],
        tokenize=False,
    )

    if request.stream:
        # Streaming response (use synchronous generator)
@ -30,6 +30,7 @@ def get_rotary_emb(
    dim: int,
    max_len: int,
    base: float = 10000,
    device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Get the rotary embedding for the given dimension and maximum length.

@ -37,12 +38,13 @@ def get_rotary_emb(
        dim (int): The dimension of the input.
        max_len (int): The maximum length of the input.
        base (float, optional): The base for the frequency. Defaults to 10000.
        device (optional): The device to create tensors on. Defaults to None.
    Returns:
        Tensor: The rotary embedding tensor.
    """

    theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim)
    t = torch.arange(0, max_len, dtype=torch.float64)
    theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
    t = torch.arange(0, max_len, dtype=torch.float64, device=device)
    freqs = torch.outer(t, theta)

    return torch.cos(freqs).float(), torch.sin(freqs).float()

@ -83,10 +85,10 @@ class RotaryEmbedding(nn.Module):
        self.max_len = max_len
        self.base = base
        self.max_len_cached = None
        self._set_rotary_buffer(self.max_len)
        self._set_rotary_buffer(self.max_len, None)

    def _set_rotary_buffer(self, max_len: int):
        cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base)
    def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None):
        cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device)
        self.register_buffer("cos_cached", cos_cached, persistent=False)
        self.register_buffer("sin_cached", sin_cached, persistent=False)
        self.max_len_cached = max_len

@ -95,7 +97,7 @@ class RotaryEmbedding(nn.Module):
        seq_len = x.size(1)

        if self.max_len_cached < seq_len + start_pos:
            self._set_rotary_buffer(seq_len + start_pos)
            self._set_rotary_buffer(self.max_len_cached * 2, x.device)

        cos = self.cos_cached[start_pos : start_pos + seq_len]
        sin = self.sin_cached[start_pos : start_pos + seq_len]

@ -121,8 +123,7 @@ class RMSNorm(nn.Module):
        self.norm_eps = norm_eps

    def forward(self, x: Tensor) -> Tensor:
        rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps)
        return rms.to(x.dtype)
        return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)


class MLP(nn.Module):

@ -257,7 +258,7 @@ class MLA(nn.Module):

        self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
        self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
        self.kv_norm = RMSNorm(kv_lora_rank, eps=norm_eps)
        self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)

        # KV (k_nope, k_rope, v)
        self.kv_b_proj = Linear(
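The table built by `get_rotary_emb` is just `cos`/`sin` of an outer product between positions and per-pair frequencies `theta_i = base ** (-2i / dim)`. A dependency-free sketch of the same computation in pure Python (lists of lists standing in for the torch tensors):

```python
import math

def rotary_tables(dim, max_len, base=10000.0):
    """cos/sin tables of shape [max_len][dim // 2]:
    theta_i = base ** (-2i / dim), freqs = outer(positions, theta)."""
    theta = [base ** (-i / dim) for i in range(0, dim, 2)]
    cos = [[math.cos(t * th) for th in theta] for t in range(max_len)]
    sin = [[math.sin(t * th) for th in theta] for t in range(max_len)]
    return cos, sin

cos, sin = rotary_tables(dim=8, max_len=4)
print(len(cos), len(cos[0]))  # 4 positions x dim//2 = 4 frequencies
print(cos[0])                 # position 0 -> all cosines are exactly 1.0
```

The lazy regrowth in `forward` (`self.max_len_cached * 2`) simply rebuilds these tables with a larger `max_len` once `seq_len + start_pos` outruns the cached size.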
@ -1,15 +1,8 @@
from astrai.tokenize.chat_template import ChatTemplate, MessageType
from astrai.tokenize.tokenizer import (
    AutoTokenizer,
    BpeTokenizer,
)
from astrai.tokenize.trainer import BpeTrainer
from astrai.tokenize.tokenizer import AutoTokenizer

__all__ = [
    "AutoTokenizer",
    "BpeTokenizer",
    "BpeTrainer",
    "ChatTemplate",
    "MessageType",
    "HistoryType",
]
@ -6,8 +6,7 @@ import json
from pathlib import Path
from typing import Dict, List, Optional, Union

from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE
from tokenizers import Tokenizer

from astrai.tokenize.chat_template import ChatTemplate


@ -210,9 +209,9 @@ class AutoTokenizer:

    Args:
        messages: List of message dicts with 'role' and 'content'.
        system_prompt: Optional system prompt string.
        system_prompt: Optional system prompt string (auto-converted to first message).
        tokenize: Whether to return token IDs (True) or raw string (False).
        add_generation_prompt: Whether to add the generation prompt (default: False).
        add_generation_prompt: Whether to add the generation prompt (default: True).
        **kwargs: Additional variables to pass to the template.

    Returns:

@ -226,10 +225,13 @@ class AutoTokenizer:
                "Chat template not set. Use set_chat_template() to set a template first."
            )

        # Auto-convert system_prompt to first message if provided
        if system_prompt:
            messages = [{"role": "system", "content": system_prompt}] + list(messages)

        # Render the template
        rendered = self._chat_template.render(
            messages=messages,
            system_prompt=system_prompt,
            add_generation_prompt=add_generation_prompt,
            **kwargs,
        )

@ -238,42 +240,3 @@ class AutoTokenizer:
            return self.encode(rendered)

        return rendered


class BpeTokenizer(AutoTokenizer):
    """BPE tokenizer implementation."""

    def __init__(
        self,
        special_token_map: Dict[str, str] = None,
        path: Optional[str] = None,
        chat_template: Optional[str] = None,
    ):
        special_token_map = special_token_map or {
            "bos": "<|begin▁of▁sentence|>",
            "eos": "<|end▁of▁sentence|>",
            "pad": "<|▁pad▁|>",
            "im_start": "<|im▁start|>",
            "im_end": "<|im▁end|>",
        }
        self._tokenizer = None
        self._init_tokenizer()
        super().__init__(
            path, special_token_map=special_token_map, chat_template=chat_template
        )

    def _init_tokenizer(self):
        """Initialize a new BPE tokenizer with default settings."""
        model = BPE()
        self._tokenizer = Tokenizer(model)
        self._tokenizer.normalizer = normalizers.Sequence(
            [normalizers.NFC(), normalizers.Strip()]
        )
        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.UnicodeScripts(),
                pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
            ]
        )
        self._tokenizer.decoder = decoders.ByteLevel()
        self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
@ -1,108 +0,0 @@
"""
BPE Tokenizer Trainer module.

Provides training functionality for BPE tokenizers.
"""

from typing import List, Union

from tokenizers import pre_tokenizers
from tokenizers.trainers import BpeTrainer as BpeTrainerImpl


class BpeTrainer:
    """BPE tokenizer trainer."""

    def __init__(self, tokenizer):
        """Initialize trainer with a tokenizer instance.

        Args:
            tokenizer: A BpeTokenizer instance
        """
        self.tokenizer = tokenizer

    def _prepare_trainer(
        self,
        vocab_size: int,
        min_freq: int,
        reserved_token_size: int,
        max_token_length: int = 18,
    ):
        """Prepare the BPE trainer with proper configuration."""
        assert reserved_token_size > len(self.tokenizer._special_tokens)
        reserved_tokens = [
            f"<|reserve{i:02d}|>"
            for i in range(reserved_token_size - len(self.tokenizer._special_tokens))
        ]
        detail_vocab_size = vocab_size - (
            len(reserved_tokens) + len(self.tokenizer._special_tokens)
        )
        alphabet = pre_tokenizers.ByteLevel.alphabet()
        min_size = len(alphabet) + len(self.tokenizer._control_tokens)
        assert detail_vocab_size > min_size

        trainer = BpeTrainerImpl(
            vocab_size=detail_vocab_size,
            min_frequency=min_freq,
            limit_alphabet=detail_vocab_size // 6,
            max_token_length=max_token_length,
            special_tokens=self.tokenizer._control_tokens,
            initial_alphabet=alphabet,
            show_progress=True,
        )
        return trainer, reserved_tokens

    def train(
        self,
        files: Union[str, List[str]],
        vocab_size: int,
        min_freq: int,
        reserved_token_size: int = 100,
        **kwargs,
    ):
        """Train tokenizer from files.

        Args:
            files: Path or list of paths to training files
            vocab_size: Target vocabulary size
            min_freq: Minimum frequency for tokens
            reserved_token_size: Number of reserved tokens
            **kwargs: Additional arguments
        """
        trainer, reserved_tokens = self._prepare_trainer(
            vocab_size, min_freq, reserved_token_size, **kwargs
        )
        self.tokenizer._tokenizer.train(files=files, trainer=trainer)
        self.tokenizer._tokenizer.add_special_tokens(
            self.tokenizer._special_tokens + reserved_tokens
        )

    def train_from_iterator(
        self,
        iterator,
        vocab_size: int,
        min_freq: int,
        reserved_token_size: int = 100,
|
||||
**kwargs,
|
||||
):
|
||||
"""Train tokenizer from iterator.
|
||||
|
||||
Args:
|
||||
iterator: Iterator yielding training strings
|
||||
vocab_size: Target vocabulary size
|
||||
min_freq: Minimum frequency for tokens
|
||||
reserved_token_size: Number of reserved tokens
|
||||
**kwargs: Additional arguments
|
||||
"""
|
||||
trainer, reserved_tokens = self._prepare_trainer(
|
||||
vocab_size, min_freq, reserved_token_size, **kwargs
|
||||
)
|
||||
self.tokenizer._tokenizer.train_from_iterator(
|
||||
iterator=iterator, trainer=trainer
|
||||
)
|
||||
self.tokenizer._tokenizer.add_special_tokens(
|
||||
self.tokenizer._special_tokens + reserved_tokens
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["BpeTrainer"]
|
||||
|
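The vocabulary budget that `_prepare_trainer` carves up can be checked by hand. A sketch with illustrative numbers: the vocab size of 8192 is an assumption for the example, the special-token count of 5 matches the `BpeTokenizer` default map, and `_control_tokens` is simplified down to the byte alphabet check.

```python
# Sketch of the vocabulary budget arithmetic in _prepare_trainer above.
# All concrete sizes here are illustrative assumptions, not repository defaults.
vocab_size = 8192
reserved_token_size = 100
n_special = 5          # bos, eos, pad, im_start, im_end
alphabet_size = 256    # size of the byte-level alphabet

# Reserved slots beyond the named special tokens become "<|reserveNN|>" tokens
reserved_tokens = reserved_token_size - n_special

# The BPE trainer only gets the budget left after reserved + special tokens
detail_vocab_size = vocab_size - (reserved_tokens + n_special)

# Mirrors the assertion in the trainer: the learnable vocab must at least
# cover the initial alphabet
assert detail_vocab_size > alphabet_size

print(reserved_tokens, detail_vocab_size)  # 95 8092
```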
@@ -234,13 +234,13 @@ def train(
         },
     )

-    toltal_steps = len(dataset) * n_epoch // (batch_size * nprocs)
+    total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
     scheduler_fn = partial(
         create_scheduler,
         **{
             "schedule_type": "cosine",
             "warmup_steps": warmup_steps,
-            "lr_decay_steps": toltal_steps - warmup_steps,
+            "lr_decay_steps": total_steps - warmup_steps,
         },
     )

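For context on the corrected `total_steps` arithmetic: it divides the total number of training examples across epochs by the effective global batch (`batch_size * nprocs`), and the cosine scheduler then decays over `total_steps - warmup_steps`. A self-contained sketch of such a warmup-then-cosine schedule; the `cosine_lr` helper and all concrete numbers are illustrative assumptions, not the repository's `create_scheduler`:

```python
# Illustrative warmup + cosine-decay learning-rate schedule.
# cosine_lr is a hypothetical helper, not the project's create_scheduler.
import math


def cosine_lr(step, base_lr, warmup_steps, lr_decay_steps, min_lr=0.0):
    # Linear warmup over the first warmup_steps updates
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    # Cosine decay from base_lr down to min_lr over lr_decay_steps
    t = min(step - warmup_steps, lr_decay_steps) / lr_decay_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))


# Same shape as the diff: len(dataset) * n_epoch // (batch_size * nprocs)
total_steps = 10000 * 3 // (32 * 2)
warmup = 50
lrs = [cosine_lr(s, 3e-4, warmup, total_steps - warmup) for s in range(total_steps)]
```

The peak learning rate is reached exactly when warmup ends, after which the cosine term takes the rate back down over the remaining steps.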
@@ -7,12 +7,27 @@ import numpy as np
 import pytest
 import safetensors.torch as st
 import torch
-from tokenizers import pre_tokenizers
+from tokenizers import Tokenizer, models, pre_tokenizers, trainers
 from torch.utils.data import Dataset

 from astrai.config.model_config import ModelConfig
 from astrai.model.transformer import Transformer
-from astrai.tokenize import BpeTokenizer, BpeTrainer
+from astrai.tokenize import AutoTokenizer


+def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
+    """Create a simple tokenizer for testing purposes."""
+    tokenizer = Tokenizer(models.BPE())
+    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
+    trainer = trainers.BpeTrainer(
+        vocab_size=vocab_size, min_frequency=1, special_tokens=["<unk>", "<pad>"]
+    )
+    # Train on the 256 single-character strings chr(0)..chr(255)
+    tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer)
+    auto_tokenizer = AutoTokenizer()
+    auto_tokenizer._tokenizer = tokenizer
+    auto_tokenizer._special_token_map = {"unk_token": "<unk>", "pad_token": "<pad>"}
+    return auto_tokenizer
+
+
 class RandomDataset(Dataset):
@@ -109,7 +124,7 @@ def base_test_env(request: pytest.FixtureRequest):
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
transformer_config = ModelConfig().load(config_path)
|
||||
model = Transformer(transformer_config).to(device=device)
|
||||
tokenizer = BpeTokenizer()
|
||||
tokenizer = create_test_tokenizer()
|
||||
|
||||
yield {
|
||||
"device": device,
|
||||
|
|
@@ -164,10 +179,7 @@ def test_env(request: pytest.FixtureRequest):
     with open(config_path, "w") as f:
         json.dump(config, f)

-    tokenizer = BpeTokenizer()
-    trainer = BpeTrainer(tokenizer)
-    sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
-    trainer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
+    tokenizer = create_test_tokenizer(vocab_size=config["vocab_size"])
     tokenizer.save(tokenizer_path)

     transformer_config = ModelConfig().load(config_path)
@@ -0,0 +1,320 @@
"""Tests for scheduler concurrency."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from astrai.inference.scheduler import (
|
||||
InferenceScheduler,
|
||||
PrefixCacheManager,
|
||||
)
|
||||
|
||||
|
||||
def test_prefix_cache_concurrent_insert_find():
|
||||
"""Test concurrent insert and find operations."""
|
||||
cache = PrefixCacheManager(max_capacity=100)
|
||||
|
||||
results = {"errors": [], "inserts": 0, "finds": 0}
|
||||
|
||||
def insert_worker():
|
||||
try:
|
||||
for i in range(50):
|
||||
cache.insert((i,), slot=i % 10)
|
||||
results["inserts"] += 1
|
||||
except Exception as e:
|
||||
results["errors"].append(str(e))
|
||||
|
||||
def find_worker():
|
||||
try:
|
||||
for i in range(50):
|
||||
cache.find_longest_prefix([i])
|
||||
results["finds"] += 1
|
||||
except Exception as e:
|
||||
results["errors"].append(str(e))
|
||||
|
||||
threads = [threading.Thread(target=insert_worker) for _ in range(3)]
|
||||
threads += [threading.Thread(target=find_worker) for _ in range(3)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||
assert results["inserts"] == 150
|
||||
assert results["finds"] == 150
|
||||
|
||||
|
||||
def test_prefix_cache_concurrent_release():
|
||||
"""Test concurrent release operations."""
|
||||
cache = PrefixCacheManager(max_capacity=100)
|
||||
|
||||
# Insert some prefixes
|
||||
for i in range(10):
|
||||
cache.insert((i,), slot=i)
|
||||
|
||||
results = {"errors": []}
|
||||
|
||||
def release_worker():
|
||||
try:
|
||||
for i in range(10):
|
||||
cache.release((i,))
|
||||
except Exception as e:
|
||||
results["errors"].append(str(e))
|
||||
|
||||
threads = [threading.Thread(target=release_worker) for _ in range(3)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||
|
||||
|
||||
def test_prefix_cache_concurrent_insert_release_find():
|
||||
"""Test mixed concurrent operations."""
|
||||
cache = PrefixCacheManager(max_capacity=50)
|
||||
|
||||
results = {"errors": []}
|
||||
|
||||
def worker(worker_id):
|
||||
try:
|
||||
for i in range(20):
|
||||
token_ids = (worker_id * 100 + i,)
|
||||
cache.insert(token_ids, slot=worker_id)
|
||||
|
||||
# Find after insert
|
||||
cache.find_longest_prefix(list(token_ids))
|
||||
|
||||
# Release
|
||||
cache.release(token_ids)
|
||||
except Exception as e:
|
||||
results["errors"].append(f"Worker {worker_id}: {str(e)}")
|
||||
|
||||
threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_and_tokenizer():
|
||||
"""Create mock model and tokenizer."""
|
||||
mock_model = MagicMock()
|
||||
mock_model.config = MagicMock()
|
||||
mock_model.config.n_kv_heads = 8
|
||||
mock_model.config.n_heads = 8
|
||||
mock_model.config.dim = 128
|
||||
mock_model.config.n_layers = 2
|
||||
mock_model.config.max_len = 100
|
||||
|
||||
mock_tokenizer = MagicMock()
|
||||
mock_tokenizer.encode.return_value = [1, 2, 3, 4, 5]
|
||||
mock_tokenizer.decode.return_value = "token"
|
||||
mock_tokenizer.stop_ids = [0]
|
||||
mock_tokenizer.pad_id = None
|
||||
|
||||
return mock_model, mock_tokenizer
|
||||
|
||||
|
||||
def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
|
||||
"""Test concurrent add_task operations."""
|
||||
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
||||
|
||||
with patch("astrai.inference.scheduler.AutoModel"):
|
||||
with patch("astrai.inference.scheduler.AutoTokenizer"):
|
||||
scheduler = InferenceScheduler(
|
||||
model=mock_model,
|
||||
tokenizer=mock_tokenizer,
|
||||
max_batch_size=4,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
results = {"task_ids": [], "errors": []}
|
||||
lock = threading.Lock()
|
||||
|
||||
def add_task_worker(worker_id):
|
||||
try:
|
||||
for i in range(10):
|
||||
task_id = scheduler.add_task(f"prompt from worker {worker_id}-{i}")
|
||||
with lock:
|
||||
results["task_ids"].append(task_id)
|
||||
except Exception as e:
|
||||
results["errors"].append(str(e))
|
||||
|
||||
threads = [threading.Thread(target=add_task_worker, args=(i,)) for i in range(5)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
# Let some tasks be processed
|
||||
time.sleep(0.1)
|
||||
|
||||
scheduler.stop()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||
assert len(results["task_ids"]) == 50
|
||||
|
||||
|
||||
def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
|
||||
"""Test concurrent add and remove task operations."""
|
||||
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
||||
|
||||
with patch("astrai.inference.scheduler.AutoModel"):
|
||||
with patch("astrai.inference.scheduler.AutoTokenizer"):
|
||||
scheduler = InferenceScheduler(
|
||||
model=mock_model,
|
||||
tokenizer=mock_tokenizer,
|
||||
max_batch_size=4,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
results = {"added": [], "removed": [], "errors": []}
|
||||
|
||||
def add_worker():
|
||||
try:
|
||||
for i in range(20):
|
||||
task_id = scheduler.add_task(f"prompt {i}")
|
||||
results["added"].append(task_id)
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
results["errors"].append(f"Add: {str(e)}")
|
||||
|
||||
def remove_worker():
|
||||
try:
|
||||
time.sleep(0.05) # Wait for some tasks to be added
|
||||
for task_id in results["added"][:10]:
|
||||
scheduler.remove_task(task_id)
|
||||
results["removed"].append(task_id)
|
||||
except Exception as e:
|
||||
results["errors"].append(f"Remove: {str(e)}")
|
||||
|
||||
add_thread = threading.Thread(target=add_worker)
|
||||
remove_thread = threading.Thread(target=remove_worker)
|
||||
|
||||
add_thread.start()
|
||||
remove_thread.start()
|
||||
|
||||
time.sleep(0.2)
|
||||
scheduler.stop()
|
||||
|
||||
add_thread.join()
|
||||
remove_thread.join()
|
||||
|
||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||
assert len(results["added"]) == 20
|
||||
|
||||
|
||||
def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
|
||||
"""Test concurrent get_stats operations."""
|
||||
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
||||
|
||||
with patch("astrai.inference.scheduler.AutoModel"):
|
||||
with patch("astrai.inference.scheduler.AutoTokenizer"):
|
||||
scheduler = InferenceScheduler(
|
||||
model=mock_model,
|
||||
tokenizer=mock_tokenizer,
|
||||
max_batch_size=4,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
results = {"stats": [], "errors": []}
|
||||
|
||||
def add_tasks():
|
||||
try:
|
||||
for i in range(20):
|
||||
scheduler.add_task(f"prompt {i}")
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
results["errors"].append(f"Add: {str(e)}")
|
||||
|
||||
def get_stats():
|
||||
try:
|
||||
for _ in range(50):
|
||||
stats = scheduler.get_stats()
|
||||
results["stats"].append(stats)
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
results["errors"].append(f"Get stats: {str(e)}")
|
||||
|
||||
add_thread = threading.Thread(target=add_tasks)
|
||||
stats_thread = threading.Thread(target=get_stats)
|
||||
|
||||
add_thread.start()
|
||||
stats_thread.start()
|
||||
|
||||
time.sleep(0.3)
|
||||
scheduler.stop()
|
||||
|
||||
add_thread.join()
|
||||
stats_thread.join()
|
||||
|
||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||
assert len(results["stats"]) == 50
|
||||
|
||||
# Verify stats are consistent
|
||||
for stats in results["stats"]:
|
||||
assert "total_tasks" in stats
|
||||
assert stats["total_tasks"] >= 0
|
||||
|
||||
|
||||
def test_prefix_cache_insert_same_prefix_concurrently():
|
||||
"""Test inserting the same prefix concurrently."""
|
||||
cache = PrefixCacheManager(max_capacity=100)
|
||||
|
||||
results = {"slot_values": [], "errors": []}
|
||||
|
||||
def insert_worker():
|
||||
try:
|
||||
# All workers try to insert the same prefix
|
||||
cache.insert((1, 2, 3), slot=threading.current_thread().name)
|
||||
node = cache.root.children.get(1)
|
||||
if node:
|
||||
node = node.children.get(2)
|
||||
if node:
|
||||
node = node.children.get(3)
|
||||
if node:
|
||||
results["slot_values"].append(node.slot)
|
||||
except Exception as e:
|
||||
results["errors"].append(str(e))
|
||||
|
||||
threads = [threading.Thread(target=insert_worker) for _ in range(10)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# All inserts should succeed, final slot should be one of the values
|
||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||
# Check ref_count is correct (should be 10)
|
||||
node = cache.root.children.get(1).children.get(2).children.get(3)
|
||||
assert node.ref_count == 10, f"Expected ref_count=10, got {node.ref_count}"
|
||||
|
||||
|
||||
def test_prefix_cache_ref_count_underflow_prevention():
|
||||
"""Test that ref_count doesn't go negative."""
|
||||
cache = PrefixCacheManager(max_capacity=100)
|
||||
|
||||
# Insert a prefix
|
||||
cache.insert((1, 2, 3), slot=0)
|
||||
|
||||
# Release multiple times
|
||||
for _ in range(5):
|
||||
cache.release((1, 2, 3))
|
||||
|
||||
# Try to find it - should return None since ref_count would be negative
|
||||
# or handle it gracefully
|
||||
node = cache.root.children.get(1).children.get(2).children.get(3)
|
||||
# The ref_count should be 0, not negative
|
||||
assert node.ref_count >= 0, f"ref_count went negative: {node.ref_count}"
|
||||
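The invariants these tests exercise (atomic `ref_count` updates under concurrent inserts, clamping at zero on over-release) can be sketched with a coarse-lock prefix trie. `PrefixTrie` and `Node` below are hypothetical stand-ins written for illustration, not the real `PrefixCacheManager`; only the attribute names (`children`, `slot`, `ref_count`) mirror what the tests above touch.

```python
# Minimal thread-safe prefix trie with reference counting (illustrative sketch,
# NOT the real PrefixCacheManager implementation).
import threading


class Node:
    def __init__(self):
        self.children = {}
        self.slot = None
        self.ref_count = 0


class PrefixTrie:
    def __init__(self):
        self.root = Node()
        self._lock = threading.Lock()  # one coarse lock guards the whole trie

    def insert(self, token_ids, slot):
        with self._lock:
            node = self.root
            for t in token_ids:
                node = node.children.setdefault(t, Node())
            node.slot = slot          # last writer wins
            node.ref_count += 1       # atomic under the lock

    def release(self, token_ids):
        with self._lock:
            node = self.root
            for t in token_ids:
                node = node.children.get(t)
                if node is None:
                    return            # unknown prefix: nothing to release
            node.ref_count = max(0, node.ref_count - 1)  # clamp, never negative


trie = PrefixTrie()
threads = [threading.Thread(target=trie.insert, args=((1, 2, 3), i)) for i in range(10)]
for t in threads:
    t.start()
for t in threads:
    t.join()

node = trie.root.children[1].children[2].children[3]
print(node.ref_count)  # 10

for _ in range(15):               # more releases than inserts
    trie.release((1, 2, 3))
print(node.ref_count)  # 0 (clamped, never negative)
```

A per-node lock would allow more parallelism, but a single coarse lock is the simplest way to make the insert/find/release interleavings in the tests above safe.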