docs: update documentation class diagrams, etc.

ViperEkura 2026-04-04 18:11:36 +08:00
parent c94a246c71
commit 99b821ebf5
2 changed files with 453 additions and 184 deletions


@@ -5,7 +5,7 @@ This document describes the data flow of the AstrAI project (a training and infe
## Overview
AstrAI adopts a modular design with the following main components:
- **Data Module** (`astrai/data/`): Dataset, sampler, tokenizer, serialization tools
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
- **Model Module** (`astrai/model/`): Transformer model and its submodules
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
- **Inference Module** (`astrai/inference/`): Generation core, KV cache management, streaming generation
@@ -20,35 +20,39 @@ The data flow can generally be divided into two main lines: **Training Data Flow
flowchart LR
subgraph A[Data Preparation]
direction TB
A1[Raw Text] --> A2[BBPE Tokenizer]
A1[Raw Text] --> A2[BpeTokenizer]
A2 --> A3[Serialize to .h5 files]
A3 --> A4[Dataset Loading<br/>BaseDataset]
A4 --> A5[Resumable Distributed Sampler<br/>ResumableDistributedSampler]
A5 --> A6[DataLoader Batch Loading]
A3 --> A4[BaseDataset]
A4 --> A5[ResumableDistributedSampler]
A5 --> A6[DataLoader]
end
subgraph B[Training Loop]
subgraph B[Training]
direction TB
B1[Batch Data] --> B2[Training Strategy<br/>BaseStrategy]
B2 --> B3[Transformer Model]
B3 --> B4[Output logits]
B4 --> B5[Loss Calculation]
B5 --> B6[Backpropagation]
B6 --> B7[Optimizer Update]
B7 --> B8[Learning Rate Scheduler]
B8 --> B9[Checkpoint Save]
B1[Batch Data] --> B2[TrainContextBuilder]
B2 --> B3[TrainContext]
B3 --> B4[BaseStrategy]
B4 --> B5[Transformer]
B5 --> B6[Compute Loss]
B6 --> B7[Backward]
B7 --> B8[Optimizer]
B8 --> B9[LRScheduler]
B9 --> B10[CheckpointCallback]
end
subgraph C[Inference Generation]
subgraph C[Inference]
direction TB
C1[Checkpoint Loading] --> C2[Inference Model Loading]
C2 --> C3[Generation Core<br/>GeneratorCore]
C3 --> C4[Sampling Strategy<br/>Temperature/top-k/top-p]
C4 --> C5[Generate Next Token]
C5 --> C6[KV Cache Update]
C6 --> C7{Max Length Reached?}
C7 -->|No| C5
C7 -->|Yes| C8[Output Generated Text]
C1[Checkpoint] --> C2[ModelParameter]
C2 --> C3[Transformer + BpeTokenizer]
C3 --> C4[GenerationRequest + build_prompt]
C4 --> C5[GeneratorFactory]
C5 --> C6[GeneratorCore]
C6 --> C7[apply_sampling_strategies]
C7 --> C8[Transformer Forward]
C8 --> C9[KVCacheManager]
C9 --> C10{End Condition?}
C10 -->|No| C8
C10 -->|Yes| C11[Output Text]
end
A --> B
@@ -57,13 +61,14 @@ flowchart LR
## Detailed Module Descriptions
### 1. Data Module
### 1. Dataset Module
#### 1.1 Tokenizer (`tokenizer.py`)
- Implemented based on Byte-Level BPE (BBPE)
- Implemented based on Byte-Level BPE (BPE)
- Supports special tokens: `<begin▁of▁sentence>`, `<end▁of▁sentence>`, `<▁pad▁>`, `<im▁start>`, `<im▁end>`
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs
- Learns vocabulary from corpus during training, saved as `.json` files
- `BpeTrainer` class handles vocabulary training from corpus
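The byte-level foundation that the BPE merges are learned on top of can be sketched as follows. This is a minimal illustration, not the project's `BpeTokenizer`: the real tokenizer additionally applies trained merge rules and a vocabulary loaded from `.json`, and the special-token IDs below are illustrative assumptions.

```python
# Minimal sketch of the byte-level layer under a BPE tokenizer: every UTF-8
# byte is a base token (IDs 0-255), so any text round-trips without <unk>.
# Special-token IDs here are assumptions, not the project's actual values.
SPECIAL_TOKENS = {"<begin▁of▁sentence>": 256, "<end▁of▁sentence>": 257, "<▁pad▁>": 258}

def encode(text: str, add_special_tokens: bool = True) -> list[int]:
    """Encode text as raw UTF-8 byte IDs, optionally wrapped in BOS/EOS."""
    ids = list(text.encode("utf-8"))
    if add_special_tokens:
        ids = [SPECIAL_TOKENS["<begin▁of▁sentence>"]] + ids + [SPECIAL_TOKENS["<end▁of▁sentence>"]]
    return ids

def decode(ids: list[int]) -> str:
    """Decode byte-range IDs back to text; special IDs (>= 256) are skipped."""
    return bytes(i for i in ids if i < 256).decode("utf-8", errors="replace")
```

Because the base alphabet is bytes rather than characters, Chinese and English text encode through the same path, which matches the project's bilingual goal.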
#### 1.2 Serialization (`serialization.py`)
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
@@ -108,7 +113,7 @@ flowchart LR
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
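The hook order above can be made concrete with a skeleton loop. This is an illustration of where the hooks fire (with two batches and gradient accumulation of 2, so a single optimizer step), not the real `Trainer` body:

```python
# Sketch of the callback firing order in one epoch. `on_step_*` brackets the
# optimizer update, which only happens once enough gradients are accumulated.
class RecordingCallback:
    def __init__(self):
        self.calls = []
    def __getattr__(self, name):
        if name.startswith("on_"):
            return lambda ctx=None: self.calls.append(name)
        raise AttributeError(name)

def run_epoch(cb, batches, accumulation_steps=2):
    cb.on_train_begin()
    cb.on_epoch_begin()
    for i, _batch in enumerate(batches):
        cb.on_batch_begin()
        # forward pass and loss calculation would happen here
        cb.on_batch_end()
        if (i + 1) % accumulation_steps == 0:  # gradients accumulated
            cb.on_step_begin()
            # optimizer.step(); optimizer.zero_grad()
            cb.on_step_end()
    cb.on_epoch_end()
    cb.on_train_end()

cb = RecordingCallback()
run_epoch(cb, batches=[0, 1])
```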
#### 3.3 Strategy (`strategy.py`)
- **`BaseStrategy`**: Defines training strategy interface (such as `SeqStrategy`, `SFTStrategy`, `DPOStrategy`)
- **`BaseStrategy`**: Defines training strategy interface (such as `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`)
- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
- Created dynamically by `StrategyFactory` according to configuration
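The decorator-registration mechanism behind `StrategyFactory` can be sketched as below. Names mirror the documented interface; the real factory also wires the model, device, and extra kwargs into the strategy constructor:

```python
# Minimal sketch of the decorator-registration factory pattern: strategy
# classes register themselves under a name, and create() looks them up.
class BaseStrategy:
    def compute_loss(self, batch):
        raise NotImplementedError

class StrategyFactory:
    _registry: dict = {}

    @classmethod
    def register(cls, name):
        def decorator(strategy_cls):
            cls._registry[name] = strategy_cls
            return strategy_cls
        return decorator

    @classmethod
    def create(cls, train_type, **kwargs) -> BaseStrategy:
        if train_type not in cls._registry:
            raise ValueError(f"unknown strategy: {train_type}")
        return cls._registry[train_type](**kwargs)

@StrategyFactory.register("sft")
class SFTStrategy(BaseStrategy):
    def __init__(self, label_smoothing: float = 0.0):
        self.label_smoothing = label_smoothing

strategy = StrategyFactory.create("sft", label_smoothing=0.1)
```

The same registration shape underlies the other factories in the diagram (`SchedulerFactory`, `DatasetFactory`, `CallbackFactory`).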
@@ -130,14 +135,16 @@ flowchart LR
#### 4.3 Generator (`generator.py`)
- **`GenerationRequest`**: Encapsulates generation request parameters (top_k, top_p, temperature, max_len, query, history, etc.)
- **`build_prompt`**: Converts query and history into ChatML format prompt string
- **`build_prompt`** (from `chat_template.py`): Converts query and history into ChatML format prompt string
- **`GeneratorCore`**: Base generator with `generate_iterator` and `generate_loop` methods
- **`LoopGenerator`**, **`StreamGenerator`**, **`BatchGenerator`**: Different generation modes
- **`pad_sequence`**: Pads input IDs to consistent length
- Provides streaming and non-streaming generation interfaces
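A hedged sketch of how `build_prompt` could fold a query and history into a ChatML-style string follows. The exact role names and delimiter layout in the real `chat_template` module may differ; `<im▁start>`/`<im▁end>` follow the special tokens listed earlier:

```python
# Illustrative ChatML-style prompt construction: system prompt first, then
# alternating user/assistant turns from history, then the current query,
# ending with an open assistant turn for the model to continue.
def build_prompt(query, history=None, system_prompt=None):
    parts = []
    if system_prompt:
        parts.append(f"<im▁start>system\n{system_prompt}<im▁end>")
    for user_msg, assistant_msg in history or []:
        parts.append(f"<im▁start>user\n{user_msg}<im▁end>")
        parts.append(f"<im▁start>assistant\n{assistant_msg}<im▁end>")
    parts.append(f"<im▁start>user\n{query}<im▁end>")
    parts.append("<im▁start>assistant\n")  # model continues from here
    return "".join(parts)

prompt = build_prompt("What is GQA?", history=[("Hi", "Hello!")])
```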
## Training Data Flow - Detailed Steps
1. **Data Preparation**
- Raw text is converted to token ID sequences through BBPE tokenizer
- Raw text is converted to token ID sequences through BPE tokenizer
- Token ID sequences (possibly with masks, labels, etc.) are saved by groups as `.h5` files
- Files can contain multiple segments, each segment corresponds to a tensor
@@ -152,7 +159,7 @@ flowchart LR
- Batch data shape is `[batch_size, window_size]` (or varies according to specific dataset type)
4. **Strategy Forward and Loss Calculation**
- Batch data is passed to strategy (such as `SeqStrategy`)
- Batch data is passed to strategy (such as `SEQStrategy`)
- Strategy internally calls `Transformer` model, obtaining logits
- Calculate cross-entropy loss (or DPO loss, etc.) according to task type
- Return loss tensor
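The shift-by-one cross-entropy in step 4 can be sketched in NumPy. This is a stand-in for what a SEQ/SFT-style strategy computes from the logits; the real strategies work on torch tensors and add label smoothing and masking:

```python
import numpy as np

# Next-token cross-entropy: the logits at position t are scored against the
# token actually observed at position t+1, so logits and labels are shifted
# by one before averaging the negative log-likelihood.
def next_token_ce(logits: np.ndarray, input_ids: np.ndarray) -> float:
    # logits: [batch, seq, vocab]; input_ids: [batch, seq]
    shift_logits = logits[:, :-1, :]   # predictions for positions 1..T-1
    shift_labels = input_ids[:, 1:]    # the tokens observed at those positions
    # numerically stable log-softmax over the vocab dimension
    z = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    picked = np.take_along_axis(log_probs, shift_labels[..., None], axis=-1)
    return float(-picked.mean())
```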
@@ -174,7 +181,7 @@ flowchart LR
- Set model to evaluation mode (`model.eval()`), enable inference mode (`torch.inference_mode`)
2. **Prompt Construction and Encoding**
- User query and history are converted to ChatML format string through `build_prompt`
- User query and history are converted to a ChatML format string through the `build_prompt` function in the `chat_template` module
- Tokenizer encodes prompt string to token ID sequence `input_ids`
- For batch generation, use `pad_sequence` for padding
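Padding for batch generation can be sketched as below. The real `pad_sequence` operates on tensors; plain lists are shown here, left-padding and the `pad_id` default are assumptions (left padding keeps the most recent tokens right-aligned, the usual choice for decoder-only generation):

```python
# Illustrative left-padding of variable-length prompts to a uniform length,
# returning both the padded IDs and an attention mask (1 = real token).
def pad_sequence(batch_ids, pad_id=0):
    max_len = max(len(ids) for ids in batch_ids)
    padded = [[pad_id] * (max_len - len(ids)) + ids for ids in batch_ids]
    mask = [[0] * (max_len - len(ids)) + [1] * len(ids) for ids in batch_ids]
    return padded, mask
```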


@@ -8,206 +8,468 @@ Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, su
```mermaid
classDiagram
%% Configuration Classes
class ModelConfig {
+int vocab_size
+int dim
+int n_layers
+float norm_eps
+int dim_ffn
+int max_len
+float rope_theta
+int n_heads
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+load(config_path) ModelConfig
+save(config_path)
namespace astrai.config {
class ModelConfig {
+int vocab_size
+int dim
+int n_layers
+float norm_eps
+int dim_ffn
+bool tie_weight
+int max_len
+float rope_theta
+int n_heads
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+load(config_path) ModelConfig
+save(config_path)
}
class TrainConfig {
+nn.Module model
+str strategy
+Dataset dataset
+Callable optimizer_fn
+Callable scheduler_fn
+int n_epoch
+int batch_size
+int accumulation_steps
+float max_grad_norm
+str ckpt_dir
+int ckpt_interval
+int nprocs
+str backend
+validate()
}
class ModelParameter {
+nn.Module model
+BpeTokenizer tokenizer
+ModelConfig config
+save(instance, save_dir)
+load(load_dir, disable_init) ModelParameter
+to(*args, **kwargs)
}
}
class TrainConfig {
+nn.Module model
+str strategy
+Dataset dataset
+Callable optimizer_fn
+Callable scheduler_fn
+int n_epoch
+int batch_size
+int accumulation_steps
+float max_grad_norm
+str ckpt_dir
+int ckpt_interval
+int nprocs
+str backend
+validate()
namespace astrai.dataset {
class BaseDataset {
+int window_size
+int stride
+MultiSegmentFetcher fetcher
+load(load_path)
+__getitem__(index)
+__len__()
}
class SEQDataset {
+__getitem__(index) Dict
}
class SFTDataset {
+__getitem__(index) Dict
}
class DPODataset {
+__getitem__(index) Dict
}
class GRPODataset {
+__getitem__(index) Dict
}
class BaseSegmentFetcher {
+List~Tensor~ segments
+List~int~ cum_lengths
+int total_length
+fetch_data(begin_idx, end_idx) Tensor
}
class MultiSegmentFetcher {
+Dict muti_fetchers
+List muti_keys
+key_fetch(begin_idx, end_idx, keys) Dict
+fetch_data(begin_idx, end_idx) Dict
}
class ResumableDistributedSampler {
+int start_epoch
+int start_iter
}
class DatasetFactory {
+Registry _registry
+register(name) decorator
+create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride) BaseDataset
}
class Checkpoint {
+dict state_dict
+int epoch
+int iteration
+save(save_dir)
+load(save_dir) Checkpoint
}
class DataLoader {
+Dataset dataset
+int batch_size
+Sampler sampler
+__iter__()
+__len__()
}
}
%% Data Classes
class Dataset {
+__len__()
+__getitem__()
namespace astrai.model {
class Transformer {
+ModelConfig config
+RotaryEmbedding rotary_embeding
+Embedding embed_tokens
+ModuleList layers
+RMSNorm norm
+Linear lm_head
+forward(input_ids, input_mask, persistent_key_values, start_pos) Dict
+load_state_dict(state_dict)
+state_dict()
}
class DecoderBlock {
+GQA attention
+RMSNorm input_norm
+MLP mlp
+RMSNorm post_attention_norm
+forward(x, rotary_emb, attention_mask, kv_cache, start_pos) Tensor
}
class GQA {
+int n_heads
+int n_kv_heads
+int head_dim
+Linear q_proj, k_proj, v_proj, o_proj
+RMSNorm q_norm, k_norm
+forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor
}
class MLP {
+Linear up, gate, down
+forward(x) Tensor
}
class RMSNorm {
+Parameter weight
+float norm_eps
+forward(x) Tensor
}
class Linear {
+Parameter weight
+Parameter bias
+forward(x) Tensor
}
class RotaryEmbedding {
+int dim
+int max_len
+float base
+forward(x, start_pos) Tuple~Tensor, Tensor~
}
class Embedding {
+Parameter weight
+forward(x) Tensor
}
}
class Checkpoint {
+dict state_dict
+int epoch
+int iteration
namespace astrai.tokenize {
class Tokenizer {
+encode(tokens, out_ids, add_special_tokens) List~int~
+decode(tokens, skip_special_tokens) str
+__len__() int
}
class BpeTokenizer {
+List~str~ stop_ids
+int bos_id
+int eos_id
+int pad_id
+encode(tokens, out_ids, add_special_tokens) List~int~
+decode(tokens, skip_special_tokens) str
}
}
class Tokenizer {
+encode(text) List[int]
+decode(ids) str
namespace astrai.trainer {
class Trainer {
+TrainConfig train_config
+List~TrainCallback~ callbacks
+train(checkpoint)
+_build_context(checkpoint) TrainContext
+_get_default_callbacks() List~TrainCallback~
}
class TrainContext {
+nn.Module model
+BaseStrategy strategy
+DataLoader dataloader
+Optimizer optimizer
+LRScheduler scheduler
+Checkpoint checkpoint
+int epoch
+int iteration
+float loss
+int world_size
+int rank
}
class TrainContextBuilder {
+TrainConfig config
+with_checkpoint(checkpoint) TrainContextBuilder
+with_dataloader() TrainContextBuilder
+with_strategy() TrainContextBuilder
+build() TrainContext
}
class BaseStrategy {
+nn.Module model
+str device
+compute_loss(batch) Tensor
}
class StrategyFactory {
+Registry _registry
+register(name) decorator
+create(model, train_type, device, **kwargs) BaseStrategy
}
class SEQStrategy {
+float label_smoothing
+compute_loss(batch) Tensor
}
class SFTStrategy {
+float label_smoothing
+compute_loss(batch) Tensor
}
class DPOStrategy {
+nn.Module ref_model
+float beta
+str reduction
+compute_loss(batch) Tensor
}
class GRPOStrategy {
+nn.Module ref_model
+float clip_eps
+float kl_coef
+int group_size
+compute_loss(batch) Tensor
}
class BaseScheduler {
+get_lr() List~float~
+step()
}
class SchedulerFactory {
+Registry _registry
+register(name) decorator
+create(optimizer, schedule_type, **kwargs) BaseScheduler
}
class CosineScheduler {
+int warmup_steps
+int lr_decay_steps
+float min_rate
}
class SGDRScheduler {
+int warmup_steps
+int cycle_length
+float min_rate
+int t_mult
}
class TrainCallback {
+on_train_begin(context)
+on_train_end(context)
+on_epoch_begin(context)
+on_epoch_end(context)
+on_step_begin(context)
+on_step_end(context)
+on_batch_begin(context)
+on_batch_end(context)
+on_error(context)
}
class CallbackFactory {
+Registry _registry
+register(name) decorator
+create(name, **kwargs) TrainCallback
}
}
%% Model Classes
class Transformer {
+forward(input_ids, mask) Dict
namespace astrai.inference {
class GeneratorCore {
+ModelParameter parameter
+generate_iterator(input_ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos) Tuple~Tensor, int~
+generate_loop(input_ids, ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos, callback) List~int~
}
class LoopGenerator {
+generate(request) str
}
class StreamGenerator {
+generate(request) Generator
}
class BatchGenerator {
+generate(request) List~str~
}
class EmbeddingEncoder {
+encode(sentence) Tensor
}
class KVCacheManager {
+int batch_size
+Tuple kv_cache
+Tensor seq_mask
+get_kvcache() Tuple
+get_seq_mask() Tensor
+update(active_mask)
+reset(full_reset)
}
class GeneratorFactory {
+create(parameter, request) GeneratorCore
+create_encoder(parameter) EmbeddingEncoderCore
}
class Server {
+start()
+predict(request)
}
class GenerationRequest {
+int top_k
+float top_p
+float temperature
+int max_len
+Union~str, List~str~~ query
+Optional history
+Optional~str~ system_prompt
+bool stream
}
}
%% Trainer Classes
class Trainer {
+TrainConfig train_config
+List~TrainCallback~ callbacks
+train()
+_build_context() TrainContext
}
namespace astrai.parallel {
class ParallelSetup {
+spawn_parallel_fn(fn, nprocs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type, device_ids)
}
class TrainContext {
+nn.Module model
+BaseStrategy strategy
+DataLoader dataloader
+Optimizer optimizer
+LRScheduler scheduler
+Checkpoint checkpoint
+int epoch
+int iteration
}
class ColumnParallelLinear {
+forward(x) Tensor
}
class TrainContextBuilder {
+TrainConfig config
+with_checkpoint(Checkpoint) TrainContextBuilder
+with_dataloader() TrainContextBuilder
+with_strategy() TrainContextBuilder
+build() TrainContext
}
class BaseStrategy {
+nn.Module model
+str device
+compute_loss(batch) Tensor
}
class StrategyFactory {
+frozenset SUPPORTED_STRATEGIES
+Dict STRATEGY_MAP
+register(name) decorator
+create(model, train_type, device) BaseStrategy
+available_strategies() list
}
class SEQStrategy {
+float label_smoothing
+compute_loss(batch) Tensor
}
class SFTStrategy {
+float label_smoothing
+compute_loss(batch) Tensor
}
class DPOStrategy {
+nn.Module ref_model
+float beta
+str reduction
+compute_loss(batch) Tensor
}
class GRPOStrategy {
+nn.Module ref_model
+float clip_eps
+float kl_coef
+int group_size
+compute_loss(batch) Tensor
}
class TrainCallback {
+on_train_begin(trainer)
+on_train_end(trainer)
+on_epoch_begin(epoch, trainer)
+on_epoch_end(epoch, trainer)
+on_batch_begin(batch, trainer)
+on_batch_end(batch, trainer)
}
class Schedule {
+step()
}
%% Inference Classes
class Generator {
+generate(prompt, config) str
+generate_batch(prompts, config) List[str]
+stream_generate(prompt, config) Generator
}
class InferenceCore {
+forward(input_ids) Dict
+apply_sampling_strategies()
}
class Server {
+start()
+predict(request)
}
%% Parallel Classes
class ParallelSetup {
+spawn_parallel_fn(fn, nprocs)
class RowParallelLinear {
+forward(x) Tensor
}
}
%% Relationships
TrainConfig --> ModelConfig : contains
TrainConfig --> Dataset : uses
TrainConfig --> BaseDataset : uses
TrainConfig --> Transformer : uses
Trainer --> TrainConfig : configures
Trainer --> TrainContextBuilder : builds
Trainer --> TrainCallback : manages
TrainContextBuilder --> TrainContext : creates
TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses
StrategyFactory ..> BaseStrategy : creates
BaseStrategy <|-- SEQStrategy
BaseStrategy <|-- SFTStrategy
BaseStrategy <|-- DPOStrategy
BaseStrategy <|-- GRPOStrategy
TrainContext --> BaseStrategy : uses
Generator --> InferenceCore : uses
Generator --> Transformer : uses
Server --> Generator : uses
SchedulerFactory ..> BaseScheduler : creates
BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler
CallbackFactory ..> TrainCallback : creates
GeneratorFactory ..> GeneratorCore : creates
GeneratorCore <|-- LoopGenerator
GeneratorCore <|-- StreamGenerator
GeneratorCore <|-- BatchGenerator
GeneratorCore <|-- EmbeddingEncoder
LoopGenerator --> KVCacheManager : uses
StreamGenerator --> KVCacheManager : uses
BatchGenerator --> KVCacheManager : uses
GeneratorCore --> Transformer : uses
DPOStrategy --> Transformer : creates ref_model
GRPOStrategy --> Transformer : creates ref_model
Server --> GeneratorFactory : uses
ParallelSetup --> Trainer : enables
TrainConfig --> StrategyFactory : selects
TrainCallback <|-- CheckpointCallback
TrainCallback <|-- MetricLoggerCallback
TrainCallback <|-- SchedulerCallback
TrainContext --> Schedule : uses
ModelParameter --> Transformer : contains
ModelParameter --> BpeTokenizer : contains
ModelParameter --> ModelConfig : contains
GeneratorFactory --> GenerationRequest : uses
BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset
DatasetFactory ..> BaseDataset : creates
BaseSegmentFetcher --> MultiSegmentFetcher : used by
MultiSegmentFetcher --> BaseDataset : used by
Transformer --> DecoderBlock : uses
Transformer --> RotaryEmbedding : uses
Transformer --> Embedding : uses
DecoderBlock --> GQA : uses
DecoderBlock --> MLP : uses
DecoderBlock --> RMSNorm : uses
BpeTokenizer --> Tokenizer : inherits
TrainContextBuilder --> ResumableDistributedSampler : creates
DataLoader --> BaseDataset : uses
ResumableDistributedSampler --> BaseDataset : samples
```
### Design Pattern Summary
### Module Overview
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig, ModelParameter | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint, DataLoader | Dataset loading and management |
| **astrai.model** | Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | Tokenizer, BpeTokenizer | Tokenizer |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | GeneratorCore, LoopGenerator, StreamGenerator, BatchGenerator, EmbeddingEncoder, KVCacheManager, GeneratorFactory, Server, GenerationRequest | Inference service |
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
### Design Patterns
| Pattern | Classes | Purpose |
|---------|---------|---------|
| **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
| **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
| **Factory** | `StrategyFactory` | Decorator registration mechanism, dynamically create training strategies |
| **Observer** | `TrainCallback` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `GeneratorFactory`, `CallbackFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, generators, and callbacks |
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Singleton** | `TrainContext` | Training process global state management |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
### Core Relationships
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `Server` → `Generator` → `InferenceCore` → `Transformer`
4. **Inference Flow**: `Server` → `GeneratorFactory` → `GeneratorCore` → `Transformer`, supports multiple generators (LoopGenerator, StreamGenerator, BatchGenerator)
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler)
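The inference flow above ends in token sampling; the temperature/top-k/top-p filtering (`apply_sampling_strategies` in the diagrams) can be sketched as below. The function name comes from the diagram, but this signature and the 1-D NumPy layout are assumptions, a stand-in for the torch implementation:

```python
import numpy as np

# Illustrative sampling filter over the final-position logits (1-D, [vocab]):
# scale by temperature, mask all but the top-k logits, then apply nucleus
# (top-p) filtering on the resulting distribution; returns the probabilities
# the next token would be sampled from.
def apply_sampling_strategies(logits, temperature=1.0, top_k=0, top_p=1.0):
    logits = logits / max(temperature, 1e-6)
    if top_k > 0:
        kth = np.sort(logits)[-top_k]              # k-th largest logit
        logits = np.where(logits < kth, -np.inf, logits)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    if top_p < 1.0:
        order = np.argsort(probs)[::-1]            # tokens by descending prob
        cum = np.cumsum(probs[order])
        cut = np.searchsorted(cum, top_p) + 1      # smallest set covering top_p
        mask = np.zeros_like(probs)
        mask[order[:cut]] = 1.0
        probs = probs * mask
        probs /= probs.sum()                       # renormalize survivors
    return probs
```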
## 3. Training Process