docs: update design documentation

ViperEkura 2026-04-09 20:05:54 +08:00
parent a2ae742988
commit 296db909aa
4 changed files with 252 additions and 97 deletions


@ -12,6 +12,7 @@ AstrAI adopts a modular design with the following main components:
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
- **Parallel Module** (`astrai/parallel/`): Distributed training support
- **Serialization Module** (`astrai/serialization/`): HDF5 data loading, checkpoint management

The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
@ -21,11 +22,11 @@ The data flow can generally be divided into two main lines: **Training Data Flow
flowchart LR
subgraph A[Data Preparation]
direction TB
A1[Raw Text] --> A2[AutoTokenizer]
A2 --> A3[Serialize to .h5 files]
A3 --> A4[BaseDataset]
A4 --> A5[ResumableDistributedSampler]
A5 --> A6[PyTorch DataLoader]
end
subgraph B[Training]
@ -50,7 +51,7 @@ flowchart LR
C5 --> C6[InferenceScheduler]
C6 --> C7[apply_sampling_strategies]
C7 --> C8[Transformer Forward]
C8 --> C9[KV Cache + Prefix Cache]
C9 --> C10{End Condition?}
C10 -->|No| C8
C10 -->|Yes| C11[Output Text]
@ -64,25 +65,18 @@ flowchart LR
### 1. Dataset Module

#### 1.1 Serialization (`serialization.py`)
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`); each key corresponds to a list of tensors
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
- **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
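The grouped-storage scheme above can be sketched with `h5py` (an illustrative sketch only; the actual `save_h5`/`load_h5` implementation and signatures may differ):

```python
# Illustrative sketch of grouped tensor storage in HDF5, assuming h5py and
# torch; function names mirror the doc but are not the real astrai API.
import h5py
import torch

def save_h5_sketch(path, data):
    """data: Dict[str, List[Tensor]] -> one HDF5 group per key."""
    with h5py.File(path, "w") as f:
        for key, tensors in data.items():
            grp = f.create_group(key)
            for i, t in enumerate(tensors):
                grp.create_dataset(str(i), data=t.numpy())

def load_h5_sketch(path, share_memory=False):
    out = {}
    with h5py.File(path, "r") as f:
        for key in f:
            tensors = [torch.from_numpy(f[key][name][()])
                       for name in sorted(f[key], key=int)]
            if share_memory:
                # share_memory_() lets DataLoader workers read without copies
                tensors = [t.share_memory_() for t in tensors]
            out[key] = tensors
    return out
```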
#### 1.2 Dataset (`dataset.py`)
- **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
- **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
- **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
- After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
#### 1.3 Sampler (`sampler.py`)
- **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
- Records current epoch and iteration position, enabling training resume from breakpoints
- Supports shuffle and drop_last options
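The resumable idea can be sketched as follows (a minimal illustration, not the actual `ResumableDistributedSampler` interface): seeding the shuffle with `(seed, epoch)` makes the permutation reproducible, so a resumed run can regenerate the same order and skip the already-consumed positions.

```python
# Illustrative resumable, distributed-aware sampler; class and attribute
# names are assumptions, not the real astrai implementation.
import random

class ResumableSamplerSketch:
    def __init__(self, num_samples, rank=0, world_size=1, shuffle=True, seed=0):
        self.num_samples = num_samples
        self.rank, self.world_size = rank, world_size
        self.shuffle, self.seed = shuffle, seed
        self.epoch, self.start_iter = 0, 0  # restored from a checkpoint on resume

    def set_state(self, epoch, start_iter):
        self.epoch, self.start_iter = epoch, start_iter

    def __iter__(self):
        indices = list(range(self.num_samples))
        if self.shuffle:
            # deterministic per-epoch permutation -> resumable mid-epoch
            random.Random(self.seed + self.epoch).shuffle(indices)
        shard = indices[self.rank::self.world_size]  # per-rank shard
        yield from shard[self.start_iter:]           # skip consumed positions
```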
@ -99,7 +93,10 @@ flowchart LR
#### 2.2 Submodules (`module.py`)
- **`RotaryEmbedding`**: Generates RoPE cos/sin cache
- **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feed-forward network (FFN), residual connections
- **`GQA`**: Grouped Query Attention implementation
- **`MLA`**: Multi-head Latent Attention implementation (as in DeepSeek-V2)
- **`MLP`**: Feed-forward network with SiLU activation and gating mechanism
- **`RMSNorm`**: Layer normalization variant
- **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
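The SiLU-gated feed-forward block can be sketched like this (SwiGLU-style; layer names and the bias-free choice are illustrative assumptions, not necessarily the actual `MLP` implementation):

```python
# Sketch of a SiLU-gated FFN: SiLU(gate_proj(x)) elementwise-scales
# up_proj(x) before projecting back down to the model dimension.
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)  # gate branch
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)    # value branch
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)  # back to dim

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
```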
@ -116,15 +113,29 @@ flowchart LR
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
#### 3.3 Strategy (`strategy.py`)
- **`BaseStrategy`**: Defines the training strategy interface
- **`SEQStrategy`**: Standard next-token prediction training
- **`SFTStrategy`**: Supervised fine-tuning with loss masking
- **`DPOStrategy`**: Direct Preference Optimization
- **`GRPOStrategy`**: Group Relative Policy Optimization
- Strategy receives batch data, executes model forward pass and loss calculation, returns loss tensor
- Created dynamically by `StrategyFactory` according to configuration
#### 3.4 Scheduler (`schedule.py`)
- **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
- **`CosineScheduler`**: Cosine decay scheduler with warmup
- **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers
- Scheduler is automatically created according to configuration and bound to optimizer
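The cosine-with-warmup schedule can be written as a pure function (an illustrative sketch; the actual `CosineScheduler` interface may differ):

```python
# Linear warmup to max_lr, then cosine decay toward min_lr.
import math

def cosine_lr(step, max_lr, warmup_steps, total_steps, min_lr=0.0):
    if step < warmup_steps:
        return max_lr * step / max(1, warmup_steps)  # linear warmup
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
```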
#### 3.5 Callbacks (`train_callback.py`)
- **`TrainCallback`**: Protocol interface for trainer callbacks
- **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
- **`ProgressBarCallback`**: Displays training progress
- **`MetricLoggerCallback`**: Logs training metrics to JSON files
- **`GradientClippingCallback`**: Clips gradient norms
- **`SchedulerCallback`**: Steps learning rate scheduler
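The dispatch pattern behind the lifecycle in 3.2 can be sketched like this (illustrative; the real `TrainCallback` protocol and context object may differ): the trainer invokes each named hook on every registered callback, skipping callbacks that don't implement it.

```python
# Minimal callback dispatch sketch: hooks are optional per callback.
def dispatch(callbacks, hook, context):
    for cb in callbacks:
        fn = getattr(cb, hook, None)
        if fn is not None:
            fn(context)

class LoggingCallback:
    def on_epoch_begin(self, context):
        print(f"epoch {context['epoch']} begin")
```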
### 4. Factory Module

#### 4.1 Registry and BaseFactory (`factory.py`)
@ -133,9 +144,24 @@ flowchart LR
- Supports decorator-based registration pattern for extensible components
- Provides methods for registration, retrieval, and listing with filtering
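The decorator-based registration pattern can be sketched as follows (illustrative; the real `Registry`/`BaseFactory` also track category and priority):

```python
# A class decorates itself into the registry at import time; factories then
# create instances by name.
class RegistrySketch:
    def __init__(self):
        self._entries = {}

    def register(self, name):
        def decorator(cls):
            self._entries[name] = cls  # map name -> component class
            return cls
        return decorator

    def create(self, name, *args, **kwargs):
        return self._entries[name](*args, **kwargs)

schedulers = RegistrySketch()

@schedulers.register("cosine")
class CosineSchedulerSketch:
    def __init__(self, max_lr):
        self.max_lr = max_lr
```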
### 5. Parallel Module

#### 5.1 Setup (`setup.py`)
- **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
- **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend)
- **`only_on_rank`**: Decorator to execute functions only on specific ranks
- **`get_rank`**: Returns current process rank in distributed group
- **`get_world_size`**: Returns total number of processes in distributed group
- **`get_current_device`**: Returns current device from environment
#### 5.2 Parallel Layers (`module.py`)
- **`ParallelModel`**: Base class for parallel models with process group
- **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
- **`RowParallelLinear`**: Row-parallel linear layer with output reduction
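Column parallelism can be illustrated in a single process (a simulation only: the concat stands in for `dist.all_gather`, and this is not the real `ColumnParallelLinear`): the weight's output dimension is split across ranks, each rank computes its shard, and the shards are gathered.

```python
# Each "rank" owns a contiguous slice of output features; gathering the
# partial outputs reproduces the full linear layer.
import torch

def column_parallel_forward(x, full_weight, world_size):
    out_dim = full_weight.size(0)
    shard = out_dim // world_size
    partials = [x @ full_weight[r * shard:(r + 1) * shard].T
                for r in range(world_size)]
    return torch.cat(partials, dim=-1)  # stands in for all_gather over ranks
```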
### 6. Inference Module
#### 6.1 Inference Engine (`engine.py`)
- **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
@ -145,22 +171,38 @@ flowchart LR
- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
- Uses separate model and tokenizer initialization for flexibility
#### 6.2 Scheduler (`scheduler.py`)
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
- **`TaskStatus`**: Task state enumeration
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse
- **`RadixNode`**: Tree node structure for prefix caching
- Continuous batching: new requests can join at any time, completed requests are released immediately
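The temperature / top-k / top-p combination that `apply_sampling_strategies` is described as performing can be sketched like this (an illustrative version of the standard filtering recipe; the actual implementation may differ in detail):

```python
# Scale logits by temperature, mask everything outside the top-k set, then
# mask everything outside the top-p (nucleus) set, and sample.
import torch

def sample_sketch(logits, temperature=1.0, top_k=0, top_p=1.0):
    logits = logits / max(temperature, 1e-5)
    if top_k > 0:
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        sorted_remove = cum_probs > top_p
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False  # always keep the most likely token
        remove = sorted_remove.scatter(-1, sorted_idx, sorted_remove)
        logits = logits.masked_fill(remove, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)
```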
#### 6.3 Server (`server.py`)
- FastAPI-based HTTP inference server
- OpenAI-compatible `/v1/chat/completions` endpoint
- Health check and statistics endpoints
- Supports both streaming and non-streaming responses
### 7. Tokenizer Module
#### 7.1 Tokenizer (`tokenizer.py`)
- Implemented based on HuggingFace tokenizers library (Byte-Level BPE)
- **`AutoTokenizer`**: Auto-loading tokenizer class
- Supports special tokens: `<begin▁of▁sentence>`, `<end▁of▁sentence>`, `<▁pad▁>`, `<im▁start>`, `<im▁end>`
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs
- Uses `AutoTokenizer` for loading pre-trained tokenizers
#### 7.2 Chat Template (`chat_template.py`)
- **`ChatTemplate`**: Jinja2-based chat template with rendering support
- Handles multi-role message formatting (system, user, assistant)
- Supports dynamic prompts and generation prompts
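A Jinja2 ChatML-style template of the kind described above can be sketched as follows (the template text is illustrative, built from the `<im▁start>`/`<im▁end>` special tokens listed in 7.1, and is not necessarily the project's actual template):

```python
# Render a messages list into a ChatML-style prompt string with Jinja2.
from jinja2 import Template

CHATML = Template(
    "{% for m in messages %}"
    "<im▁start>{{ m.role }}\n{{ m.content }}<im▁end>\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<im▁start>assistant\n{% endif %}"
)

def render_chat(messages, add_generation_prompt=False):
    return CHATML.render(messages=messages,
                         add_generation_prompt=add_generation_prompt)
```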
## Training Data Flow - Detailed Steps
1. **Data Preparation**
   - Raw text is converted to token ID sequences through `AutoTokenizer`
   - Token ID sequences (possibly with masks, labels, etc.) are saved by groups as `.h5` files
   - Files can contain multiple segments; each segment corresponds to a tensor
@ -171,7 +213,7 @@ flowchart LR
3. **Sampling and Batch Loading**
   - `ResumableDistributedSampler` generates the index sequence based on current epoch and iteration position
   - PyTorch `DataLoader` uses the sampler to get indices and calls the dataset's `__getitem__` to get actual data
   - Batch data shape is `[batch_size, window_size]` (or varies according to the specific dataset type)

4. **Strategy Forward and Loss Calculation**
@ -202,7 +244,7 @@ flowchart LR
   - For batch generation, use `pad_sequence` for padding

3. **Autoregressive Generation Loop**
   - Initialize KV cache (optional) and prefix cache
   - Loop until `max_len` tokens are generated or a stop token is encountered:
     - Input current `input_ids` (or the cached new token) to the model, obtain `logits`
     - Apply `apply_sampling_strategies` (temperature, top-k, top-p) to the `logits`
@ -222,7 +264,6 @@ flowchart LR
## Summary

The data flow design of AstrAI reflects modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache, prefix caching, and sampling strategies. Clear interfaces between modules facilitate customization and extension.

> Document Update Time: 2026-04-09


@ -8,7 +8,7 @@ Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, su
```mermaid
classDiagram
namespace config {
class ModelConfig {
+int vocab_size
+int dim
@ -56,17 +56,9 @@ classDiagram
+validate()
}
}
namespace dataset {
class BaseDataset {
+int window_size
+int stride
@ -125,17 +117,9 @@ classDiagram
+save(save_dir)
+load(save_dir) Checkpoint
}
}
namespace model {
class AutoModel {
+ModelConfig config
+Dict _registry
@ -216,24 +200,46 @@ classDiagram
}
}
namespace tokenize {
class AutoTokenizer {
+List~str~ stop_ids
+int bos_id
+int eos_id
+int pad_id
+vocab_size int
+encode(tokens, out_ids, add_special_tokens) List~int~
+decode(tokens, skip_special_tokens) str
+apply_chat_template(messages, tokenize) Union~str, List[int]~
+set_chat_template(template)
+load(path)
+from_pretrained(path) AutoTokenizer
+save_pretrained(save_path)
}
class ChatTemplate {
+String template_str
+render(messages, add_generation_prompt) str
+from_string(template) ChatTemplate
} }
}
namespace factory {
class Registry {
+Dict _entries
+register(name, component_cls, category, priority)
+get(name) Type
+list_names() List~str~
}
class BaseFactory {
+Registry _registry
+register(name, category, priority) decorator
+create(name, *args, **kwargs) T
+list_registered() list
}
}
namespace trainer {
class Trainer {
+TrainConfig train_config
+List~TrainCallback~ callbacks
@ -337,6 +343,39 @@ classDiagram
+on_error(context)
}
class GradientClippingCallback {
+float max_grad_norm
+on_step_begin(context)
}
class SchedulerCallback {
+on_train_begin(context)
+on_batch_end(context)
}
class CheckpointCallback {
+str save_dir
+int interval
+_save_checkpoint(context)
+on_batch_end(context)
+on_train_end(context)
+on_error(context)
}
class ProgressBarCallback {
+int num_epoch
+on_epoch_begin(context)
+on_batch_end(context)
+on_epoch_end(context)
}
class MetricLoggerCallback {
+str log_dir
+int save_interval
+on_batch_end(context)
+on_train_end(context)
}
class CallbackFactory {
+Registry _registry
+register(name) decorator
@ -344,10 +383,17 @@ classDiagram
}
}
namespace inference {
class InferenceEngine {
+nn.Module model
+AutoTokenizer tokenizer
+InferenceScheduler scheduler
+int max_batch_size
+Optional int max_seq_len
+int max_prefix_len
+int cache_capacity
+Tensor kv_cache
+Tensor seq_mask
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]]
+get_stats() Dict
@ -356,10 +402,11 @@ classDiagram
class InferenceScheduler {
+nn.Module model
+AutoTokenizer tokenizer
+ModelConfig config
+Tuple kv_cache
+Tensor seq_mask
+PrefixCacheManager prefix_cache
+List waiting_queue
+List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
@ -369,6 +416,24 @@ classDiagram
+get_stats() Dict
}
class PrefixCacheManager {
+RadixNode root
+int max_capacity
+List lru
+insert(token_ids, slot)
+find_longest_prefix(token_ids) Tuple[int, int]
+release(token_ids)
}
class RadixNode {
+Dict children
+int hash
+int slot
+int ref_count
+float last_access
+List token_sequence
}
class Task {
+str task_id
+List prompt_ids
@ -392,14 +457,6 @@ classDiagram
+str ABORTED
}
class Server {
+start()
+predict(request)
@ -410,19 +467,54 @@ classDiagram
+float top_p
+float temperature
+int max_len
+List~Dict~ messages
+stream bool
}
class _Result {
+List~str~ tokens
+List~str~ results
+List~bool~ done_flags
+append(token, idx)
+get_results() List~str~
}
class ChatMessage {
+str role
+str content
}
class ChatCompletionRequest {
+List~ChatMessage~ messages
+float temperature
+float top_p
+int top_k
+int max_tokens
+bool stream
+Optional~str~ system_prompt
}
class CompletionResponse {
+str id
+str object
+int created
+str model
+List~Dict~ choices
}
}
namespace parallel {
class ParallelSetup {
+spawn_parallel_fn(fn, nprocs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type, device_ids)
}
class ParallelModel {
+dist.ProcessGroup process_group
+int rank
+int world_size
}
class ColumnParallelLinear {
+forward(x) Tensor
}
@ -433,9 +525,16 @@ classDiagram
}
%% Relationships
TrainConfig --> ModelConfig : uses
TrainConfig --> BaseDataset : uses
TrainConfig --> StrategyFactory : selects
StrategyFactory ..> BaseStrategy : creates
BaseStrategy <|-- SEQStrategy
BaseStrategy <|-- SFTStrategy
BaseStrategy <|-- DPOStrategy
BaseStrategy <|-- GRPOStrategy
DPOStrategy --> Transformer : uses
GRPOStrategy --> Transformer : uses
Trainer --> TrainConfig : configures
Trainer --> TrainContextBuilder : builds
Trainer --> TrainCallback : manages
@ -443,30 +542,27 @@ classDiagram
TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses
AutoModel --> ModelConfig : contains
SchedulerFactory ..> BaseScheduler : creates
BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler
CallbackFactory ..> TrainCallback : creates
TrainCallback <|-- GradientClippingCallback
TrainCallback <|-- SchedulerCallback
TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
InferenceEngine --> InferenceScheduler : uses
InferenceScheduler --> Task : manages
InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> Transformer : uses
InferenceEngine --> Transformer : uses
InferenceEngine --> GenerationRequest : uses
Server --> InferenceEngine : uses
Server --> ChatMessage : uses
Server --> ChatCompletionRequest : uses
Server --> CompletionResponse : uses
ParallelSetup --> Trainer : enables
BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset
@ -483,22 +579,34 @@ classDiagram
DecoderBlock --> MLA : uses
DecoderBlock --> MLP : uses
DecoderBlock --> RMSNorm : uses
TrainContextBuilder --> ResumableDistributedSampler : creates
ResumableDistributedSampler --> BaseDataset : samples
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses
InferenceScheduler --> PrefixCacheManager : uses
InferenceScheduler --> RadixNode : uses
Checkpoint ..> Checkpoint : saves/loads
TrainConfig --> DatasetFactory : selects
TrainConfig --> SchedulerFactory : selects
TrainConfig --> CallbackFactory : selects
AutoModel ..> AutoTokenizer : loads with
BaseFactory <|-- DatasetFactory
BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory
```
### Module Overview
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint | Dataset loading and management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest, PrefixCacheManager, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
@ -515,7 +623,7 @@ classDiagram
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
| **Generator Pattern** | `_Result`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
### Core Relationships
@ -582,3 +690,5 @@ $$
The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$

Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly aligned dialogue intelligence.
> Document Update Time: 2026-04-09


@ -2,7 +2,7 @@
### 1. Model Architecture

This model uses the Transformer architecture with a GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking 32 layers of Transformer blocks, with 1.0 billion parameters. The Transformer is an autoregressive model that attends over all previous tokens to obtain the probability distribution of the next token.

The model now uses the **AutoModel** base class for flexible loading and saving:
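The KV cache saving from GQA follows directly from the head counts: the cache scales with the number of KV heads, so going from 24 (MHA) to 4 KV heads shrinks it 6x. A back-of-the-envelope sketch (head_dim, sequence length, and fp16 element size below are illustrative assumptions, not the model's published figures):

```python
# Per-layer, per-token KV cache is 2 (K and V) * num_kv_heads * head_dim
# elements; total cache scales linearly with num_kv_heads.
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem

mha = kv_cache_bytes(num_layers=32, num_kv_heads=24, head_dim=64, seq_len=4096)
gqa = kv_cache_bytes(num_layers=32, num_kv_heads=4, head_dim=64, seq_len=4096)
ratio = mha / gqa  # 24 / 4 = 6x smaller cache
```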
@ -295,3 +295,5 @@ curl http://localhost:8000/health
curl http://localhost:8000/stats
# {"requests_total": 10, "tokens_generated": 5000, ...}
```
> Document Update Time: 2026-04-09


@ -137,3 +137,5 @@ result = engine.generate(
|------|-------------|
| `stream=True` | Streaming output, yields token by token |
| `stream=False` | Non-streaming output, returns complete result |
> Document Update Time: 2026-04-09