chore: update docs, fix code formatting

This commit is contained in:
ViperEkura 2026-04-05 20:59:52 +08:00
parent 23ce4bc3ae
commit d2fe8afbd1
9 changed files with 141 additions and 66 deletions

.gitattributes

@@ -4,6 +4,7 @@
# Files that MUST use LF (Unix/Linux execution)
*.sh text eol=lf
*.py text eol=lf
*.md text eol=lf
Dockerfile text eol=lf
.dockerignore text eol=lf


@@ -47,7 +47,7 @@
### Features
- 🚀 **High Performance**: Optimized for both training and inference with efficient parallelization.
- 🔧 **Flexible**: Support for seq/sft/dpo training, customizable model architectures.
- 🔧 **Flexible**: Support for seq/sft/dpo/grpo training, customizable model architectures.
- 💡 **Easy to Use**: Simple API with comprehensive examples and demos.
- 📦 **Lightweight**: Minimal dependencies, easy to deploy.
- 🔬 **Research-Friendly**: Modular design, easy to experiment with new ideas.


@@ -48,7 +48,7 @@
### Features
- 🚀 **High Performance**: Optimized for both training and inference, with efficient parallelism.
- 🔧 **Flexible**: Supports seq/sft/dpo training modes, with customizable model architectures.
- 🔧 **Flexible**: Supports seq/sft/dpo/grpo training modes, with customizable model architectures.
- 💡 **Easy to Use**: Concise API with rich examples and demos.
- 📦 **Lightweight**: Few dependencies, simple to deploy.
- 🔬 **Research-Friendly**: Modular design, easy to experiment with new ideas.


@@ -6,10 +6,11 @@ This document describes the data flow of the AstrAI project (a training and infe
AstrAI adopts a modular design with the following main components:
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
- **Model Module** (`astrai/model/`): Transformer model and its submodules
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
- **Parallel Module** (`astrai/parallel/`): Distributed training support
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
@@ -42,9 +43,9 @@ flowchart LR
subgraph C[Inference]
direction TB
C1[Checkpoint] --> C2[ModelParameter]
C2 --> C3[Transformer + BpeTokenizer]
C3 --> C4[GenerationRequest + build_prompt]
C1[Checkpoint] --> C2[AutoModel]
C2 --> C3[Transformer + Tokenizer]
C3 --> C4[GenerationRequest + apply_chat_template]
C4 --> C5[InferenceEngine]
C5 --> C6[InferenceScheduler]
C6 --> C7[apply_sampling_strategies]
@@ -88,8 +89,9 @@ flowchart LR
### 2. Model Module
#### 2.1 Transformer (`transformer.py`)
- Core autoregressive decoder architecture
#### 2.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
- **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods
- **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`)
- Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
- Supports weight tying (`tie_weight=True`) to reduce parameter count
- Uses Rotary Position Embedding (RoPE) to inject position information
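The registration mechanism described above can be sketched as a small class-level registry. The sketch below only mirrors the names in this document (`register`, `get_model_class`); `AutoModelSketch` and `TransformerSketch` are illustrative stand-ins, not the real `astrai` classes:

```python
# Minimal sketch of the AutoModel registration pattern described above.
# AutoModelSketch / TransformerSketch are hypothetical stand-ins for the
# real astrai.model classes.

class AutoModelSketch:
    _registry: dict = {}

    @classmethod
    def register(cls, model_type: str):
        """Decorator mapping a model_type string to a model class."""
        def decorator(model_cls):
            cls._registry[model_type] = model_cls
            return model_cls
        return decorator

    @classmethod
    def get_model_class(cls, model_type: str):
        try:
            return cls._registry[model_type]
        except KeyError:
            raise KeyError(f"Unknown model_type: {model_type!r}") from None


@AutoModelSketch.register("transformer")
class TransformerSketch(AutoModelSketch):
    pass
```

A new architecture would be added the same way: define the class and decorate it with `@AutoModelSketch.register("my_model_type")`.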
@@ -122,22 +124,31 @@ flowchart LR
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers (such as `cosine`, `sgdr`)
- Scheduler is automatically created according to configuration and bound to optimizer
### 4. Inference Module
### 4. Factory Module
#### 4.1 Inference Engine (`engine.py`)
#### 4.1 Registry and BaseFactory (`factory.py`)
- **`Registry`**: Flexible registry for component classes with category and priority support
- **`BaseFactory`**: Generic factory class for component registration and creation
- Supports decorator-based registration pattern for extensible components
- Provides methods for registration, retrieval, and listing with filtering
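A registry with category and priority support, as described above, could look roughly like this. The method names and signatures here are assumptions for illustration; the actual `astrai.factory` API may differ:

```python
# Hedged sketch of a Registry with category/priority support and
# decorator-based registration; not the actual astrai.factory code.

class Registry:
    def __init__(self):
        self._items = {}  # name -> (category, priority, obj)

    def register(self, name, category="default", priority=0):
        """Decorator-based registration, as used by the factories."""
        def decorator(obj):
            self._items[name] = (category, priority, obj)
            return obj
        return decorator

    def get(self, name):
        """Retrieve a registered component by name."""
        return self._items[name][2]

    def list(self, category=None):
        """Registered names, optionally filtered, highest priority first."""
        items = [(n, c, p) for n, (c, p, _) in self._items.items()
                 if category is None or c == category]
        return [n for n, _, p in sorted(items, key=lambda t: -t[2])]
```

Factories such as `SchedulerFactory` can then delegate to a shared `Registry` instance while keeping their own category.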
### 5. Inference Module
#### 5.1 Inference Engine (`engine.py`)
- **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
- Manages task queue (`waiting_queue`, `active_tasks`) and KV cache allocation
#### 4.2 Scheduler (`scheduler.py`)
#### 5.2 Scheduler (`scheduler.py`)
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
- **`TaskStatus`**: Task state enumeration
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
- Continuous batching: new requests can join at any time, completed requests are released immediately
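The temperature / top-k / top-p filtering described above can be illustrated on a plain list of logits. This is a simplified sketch (the real `apply_sampling_strategies` operates on batched tensors), and tie handling in top-k is intentionally loose:

```python
import math

def apply_sampling_strategies(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Illustrative sketch: scale by temperature, mask all but the top-k
    logits, softmax, then keep the smallest nucleus whose mass >= top_p."""
    logits = [l / temperature for l in logits]
    if top_k > 0:
        kth = sorted(logits, reverse=True)[top_k - 1]
        logits = [l if l >= kth else float("-inf") for l in logits]
    # numerically stable softmax over the surviving logits
    m = max(l for l in logits if l != float("-inf"))
    exps = [math.exp(l - m) if l != float("-inf") else 0.0 for l in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    if top_p < 1.0:
        order = sorted(range(len(probs)), key=lambda i: -probs[i])
        kept, cum = set(), 0.0
        for i in order:
            kept.add(i)
            cum += probs[i]
            if cum >= top_p:
                break
        probs = [p if i in kept else 0.0 for i, p in enumerate(probs)]
        total = sum(probs)
        probs = [p / total for p in probs]
    return probs
```

The next token is then sampled from the returned distribution.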
#### 4.3 Request (`engine.py`)
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, query, history, etc.)
- **`build_prompt`** (from `chat_template.py`): Converts query and history into ChatML format prompt string
#### 5.3 Request (`engine.py`)
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
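The ChatML conversion described above can be sketched as follows. The exact control tokens and keyword arguments are assumptions based on the common ChatML convention; the real `apply_chat_template` in `tokenizer.py` may differ:

```python
def apply_chat_template(messages, add_generation_prompt=True):
    """Sketch of ChatML formatting: each message becomes an
    <|im_start|>role ... <|im_end|> block, optionally followed by an
    open assistant block for generation. Token choices are assumptions."""
    parts = []
    for msg in messages:
        parts.append(f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n")
    if add_generation_prompt:
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)
```

The resulting string is what the tokenizer encodes into `input_ids`.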
## Training Data Flow - Detailed Steps
@@ -176,11 +187,11 @@ flowchart LR
## Inference Data Flow - Detailed Steps
1. **Model Loading**
- Load `Transformer` model and tokenizer from checkpoint
- Load `Transformer` model from checkpoint via `AutoModel.from_pretrained()`
- Set model to evaluation mode (`model.eval()`), enable inference mode (`torch.inference_mode`)
2. **Prompt Construction and Encoding**
- User query and history are converted to ChatML format string through `build_prompt` function in chat_template module
- User messages (list of dict with role and content) are converted to ChatML format string through `apply_chat_template` method in tokenizer
- Tokenizer encodes prompt string to token ID sequence `input_ids`
- For batch generation, use `pad_sequence` for padding
@@ -207,5 +218,5 @@ flowchart LR
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache and sampling strategies. Clear interfaces between modules facilitate customization and extension.
> Document Update Time: 2026-03-30
> Document Update Time: 2026-04-05
> Corresponding Code Version: Refer to version number defined in `pyproject.toml`


@@ -36,10 +36,23 @@ classDiagram
+int batch_size
+int accumulation_steps
+float max_grad_norm
+int start_epoch
+int start_batch
+str ckpt_dir
+int ckpt_interval
+int random_seed
+int num_workers
+int prefetch_factor
+bool pin_memory
+int nprocs
+str backend
+str master_addr
+str master_port
+Callable parallel_wrapper
+Callable state_dict_fn
+List[int] device_ids
+str device_type
+dict extra_kwargs
+validate()
}
@@ -123,6 +136,16 @@ classDiagram
}
namespace astrai.model {
class AutoModel {
+ModelConfig config
+Dict _registry
+register(model_type) decorator
+get_model_class(model_type) Type
+from_pretrained(path, disable_random_init) nn.Module
+save_pretrained(save_directory)
+to(*args, **kwargs) Self
}
class Transformer {
+ModelConfig config
+RotaryEmbedding rotary_embeding
@@ -440,6 +463,8 @@ classDiagram
DatasetFactory ..> BaseDataset : creates
BaseSegmentFetcher --> MultiSegmentFetcher : used by
MultiSegmentFetcher --> BaseDataset : used by
AutoModel <|-- Transformer
AutoModel --> ModelConfig : contains
Transformer --> DecoderBlock : uses
Transformer --> RotaryEmbedding : uses
Transformer --> Embedding : uses
@@ -458,11 +483,12 @@ classDiagram
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig, ModelParameter | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint, DataLoader | Dataset loading and management |
| **astrai.model** | Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | Tokenizer, BpeTokenizer | Tokenizer |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest | Inference service with continuous batching |
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
### Design Patterns
@@ -470,12 +496,13 @@ classDiagram
|---------|---------|---------|
| **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
| **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory`, `BaseFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Singleton** | `TrainContext` | Training process global state management |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
### Core Relationships
@@ -487,6 +514,7 @@ classDiagram
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler)
9. **AutoModel Loading**: `AutoModel.from_pretrained()` dynamically loads model based on `config.json` model_type, uses `Registry` pattern for model type registration
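The config-driven lookup in step 9 can be sketched in a few lines. This only models the dispatch step; the real `AutoModel.from_pretrained()` also restores weights and handles `disable_random_init`, and `from_pretrained_sketch` is a hypothetical helper:

```python
import json
import os

def from_pretrained_sketch(path, registry):
    """Read model_type from config.json and instantiate the class
    registered under that name; weight loading is omitted."""
    with open(os.path.join(path, "config.json")) as f:
        config = json.load(f)
    model_cls = registry[config["model_type"]]
    return model_cls(config)
```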
## 3. Training Process


@@ -2,7 +2,21 @@
### 1. Model Architecture
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA (although KV cache is not currently implemented). The model is built by stacking 24 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that attends over all previous tokens to obtain the probability distribution of the next token.
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA (although KV cache is not currently implemented). The model is built by stacking 32 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that attends over all previous tokens to obtain the probability distribution of the next token.
The model now uses the **AutoModel** base class for flexible loading and saving:
```python
from astrai.model import AutoModel
# Load model from checkpoint
model = AutoModel.from_pretrained("path/to/model")
# Save model to new directory
model.save_pretrained("path/to/save")
```
The Transformer model is registered via `@AutoModel.register('transformer')` decorator, allowing easy extension for new model types.
```mermaid
flowchart TB


@@ -6,11 +6,12 @@
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--train_type` | Training type (seq, sft, dpo) | required |
| `--train_type` | Training type (seq, sft, dpo, grpo) | required |
| `--model_type` | Model type for AutoModel loading (e.g., transformer) | transformer |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Total training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--batch_size` | Batch size | 4 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
### Learning Rate Scheduling
@@ -42,7 +43,9 @@
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--random_seed` | Random seed | 3407 |
| `--num_workers` | DataLoader workers | 4 |
| `--num_workers` | DataLoader workers | 0 |
| `--prefetch_factor` | Prefetch factor for dataloader | None |
| `--pin_memory` | Enable pin_memory | False |
| `--no_pin_memory` | Disable pin_memory | - |
### Distributed Training
@@ -71,44 +74,58 @@
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `query` | Input text or text list | required |
| `history` | Conversation history | None |
| `system_prompt` | System prompt | None |
| `temperature` | Sampling temperature (higher = more random) | required |
| `top_p` | Nucleus sampling threshold | required |
| `top_k` | Top-k sampling count | required |
| `max_len` | Maximum generation length | model config max_len |
| `messages` | List of message dictionaries (role, content) | required |
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
| `top_p` | Nucleus sampling threshold | 1.0 |
| `top_k` | Top-k sampling count | 50 |
| `max_len` | Maximum generation length | 1024 |
| `stream` | Whether to stream output | False |
### Usage Example
```python
from astrai.config import ModelParameter
import torch
from astrai.model import AutoModel
from astrai.tokenize import Tokenizer
from astrai.inference import InferenceEngine, GenerationRequest
# Load model
param = ModelParameter.load("your_model_dir")
param.to(device="cuda", dtype=torch.bfloat16)
# Load model using AutoModel
model = AutoModel.from_pretrained("your_model_dir")
# Load tokenizer
tokenizer = Tokenizer("your_model_dir")
# Create engine with separate model and tokenizer
engine = InferenceEngine(
model=param.model,
tokenizer=param.tokenizer,
config=param.config,
model=model,
tokenizer=tokenizer,
)
# Build request
# Build request with messages format
request = GenerationRequest(
query="Hello",
history=[],
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
],
temperature=0.8,
top_p=0.95,
top_k=50,
max_len=1024,
)
# Generate (streaming)
for token in engine.generate_with_request(request):
print(token, end="", flush=True)
# Or use simple generate interface
result = engine.generate(
prompt="Hello",
stream=False,
max_tokens=1024,
temperature=0.8,
top_p=0.95,
top_k=50,
)
```
### Generation Modes


@@ -188,8 +188,13 @@ class InferenceEngine:
if stream:
return self._generate_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k,
abort_on_exception
prompts,
is_batch,
max_tokens,
temperature,
top_p,
top_k,
abort_on_exception,
)
else:
return self._generate_non_streaming(
@@ -307,16 +312,15 @@ class InferenceEngine:
Use this for emergency shutdown when graceful shutdown is not possible.
"""
import os
# Stop watching threads if any
if hasattr(self, 'stop_watching'):
if hasattr(self, "stop_watching"):
self.stop_watching()
# Unregister signal handlers
if hasattr(self, '_original_sigint'):
if hasattr(self, "_original_sigint"):
signal.signal(signal.SIGINT, self._original_sigint)
if hasattr(self, '_original_sigterm'):
if hasattr(self, "_original_sigterm"):
signal.signal(signal.SIGTERM, self._original_sigterm)
# Force stop scheduler