From 99b821ebf51d287f24095928caa148a5c0b3a3ee Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 4 Apr 2026 18:11:36 +0800 Subject: [PATCH] =?UTF-8?q?docs:=20=20=E6=9B=B4=E6=96=B0=E6=96=87=E6=A1=A3?= =?UTF-8?q?=E7=B1=BB=E5=9B=BE=E7=AD=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- assets/docs/dataflow.md | 67 ++--- assets/docs/design.md | 570 +++++++++++++++++++++++++++++----------- 2 files changed, 453 insertions(+), 184 deletions(-) diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md index 29ef4ea..fc84bbf 100644 --- a/assets/docs/dataflow.md +++ b/assets/docs/dataflow.md @@ -5,7 +5,7 @@ This document describes the data flow of the AstrAI project (a training and infe ## Overview AstrAI adopts a modular design with the following main components: -- **Data Module** (`astrai/data/`): Dataset, sampler, tokenizer, serialization tools +- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools - **Model Module** (`astrai/model/`): Transformer model and its submodules - **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers - **Inference Module** (`astrai/inference/`): Generation core, KV cache management, streaming generation @@ -20,35 +20,39 @@ The data flow can generally be divided into two main lines: **Training Data Flow flowchart LR subgraph A[Data Preparation] direction TB - A1[Raw Text] --> A2[BBPE Tokenizer] + A1[Raw Text] --> A2[BpeTokenizer] A2 --> A3[Serialize to .h5 files] - A3 --> A4[Dataset Loading
BaseDataset]
-    A4 --> A5[Resumable Distributed Sampler<br/>ResumableDistributedSampler]
-    A5 --> A6[DataLoader Batch Loading]
+    A3 --> A4[BaseDataset]
+    A4 --> A5[ResumableDistributedSampler]
+    A5 --> A6[DataLoader]
   end
 
-  subgraph B[Training Loop]
+  subgraph B[Training]
     direction TB
-    B1[Batch Data] --> B2[Training Strategy<br/>BaseStrategy]
-    B2 --> B3[Transformer Model]
-    B3 --> B4[Output logits]
-    B4 --> B5[Loss Calculation]
-    B5 --> B6[Backpropagation]
-    B6 --> B7[Optimizer Update]
-    B7 --> B8[Learning Rate Scheduler]
-    B8 --> B9[Checkpoint Save]
+    B1[Batch Data] --> B2[TrainContextBuilder]
+    B2 --> B3[TrainContext]
+    B3 --> B4[BaseStrategy]
+    B4 --> B5[Transformer]
+    B5 --> B6[Compute Loss]
+    B6 --> B7[Backward]
+    B7 --> B8[Optimizer]
+    B8 --> B9[LRScheduler]
+    B9 --> B10[CheckpointCallback]
   end
 
-  subgraph C[Inference Generation]
+  subgraph C[Inference]
     direction TB
-    C1[Checkpoint Loading] --> C2[Inference Model Loading]
-    C2 --> C3[Generation Core<br/>GeneratorCore]
-    C3 --> C4[Sampling Strategy<br/>
Temperature/top-k/top-p] - C4 --> C5[Generate Next Token] - C5 --> C6[KV Cache Update] - C6 --> C7{Max Length Reached?} - C7 -->|No| C5 - C7 -->|Yes| C8[Output Generated Text] + C1[Checkpoint] --> C2[ModelParameter] + C2 --> C3[Transformer + BpeTokenizer] + C3 --> C4[GenerationRequest + build_prompt] + C4 --> C5[GeneratorFactory] + C5 --> C6[GeneratorCore] + C6 --> C7[apply_sampling_strategies] + C7 --> C8[Transformer Forward] + C8 --> C9[KVCacheManager] + C9 --> C10{End Condition?} + C10 -->|No| C8 + C10 -->|Yes| C11[Output Text] end A --> B @@ -57,13 +61,14 @@ flowchart LR ## Detailed Module Descriptions -### 1. Data Module +### 1. Dataset Module #### 1.1 Tokenizer (`tokenizer.py`) -- Implemented based on Byte-Level BPE (BBPE) +- Implemented based on Byte-Level BPE (BPE) - Supports special tokens: `<|begin▁of▁sentence|>`, `<|end▁of▁sentence|>`, `<|▁pad▁|>`, `<|im▁start|>`, `<|im▁end|>` - Provides `encode`/`decode` methods for mutual conversion between text and token IDs - Learns vocabulary from corpus during training, saved as `.json` files +- `BpeTrainer` class handles vocabulary training from corpus #### 1.2 Serialization (`serialization.py`) - **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors @@ -108,7 +113,7 @@ flowchart LR 1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. 
`on_epoch_end` #### 3.3 Strategy (`strategy.py`) -- **`BaseStrategy`**: Defines training strategy interface (such as `SeqStrategy`, `SFTStrategy`, `DPOStrategy`) +- **`BaseStrategy`**: Defines training strategy interface (such as `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`) - Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor - Created dynamically by `StrategyFactory` according to configuration @@ -130,14 +135,16 @@ flowchart LR #### 4.3 Generator (`generator.py`) - **`GenerationRequest`**: Encapsulates generation request parameters (top_k, top_p, temperature, max_len, query, history, etc.) -- **`build_prompt`**: Converts query and history into ChatML format prompt string +- **`build_prompt`** (from `chat_template.py`): Converts query and history into ChatML format prompt string +- **`GeneratorCore`**: Base generator with generate_iterator and generate_loop methods +- **`LoopGenerator`**, **`StreamGenerator`**, **`BatchGenerator`**: Different generation modes - **`pad_sequence`**: Pads input IDs to consistent length - Provides streaming and non-streaming generation interfaces ## Training Data Flow - Detailed Steps 1. **Data Preparation** - - Raw text is converted to token ID sequences through BBPE tokenizer + - Raw text is converted to token ID sequences through BPE tokenizer - Token ID sequences (possibly with masks, labels, etc.) are saved by groups as `.h5` files - Files can contain multiple segments, each segment corresponds to a tensor @@ -152,7 +159,7 @@ flowchart LR - Batch data shape is `[batch_size, window_size]` (or varies according to specific dataset type) 4. **Strategy Forward and Loss Calculation** - - Batch data is passed to strategy (such as `SeqStrategy`) + - Batch data is passed to strategy (such as `SEQStrategy`) - Strategy internally calls `Transformer` model, obtaining logits - Calculate cross-entropy loss (or DPO loss, etc.) 
according to task type - Return loss tensor @@ -174,7 +181,7 @@ flowchart LR - Set model to evaluation mode (`model.eval()`), enable inference mode (`torch.inference_mode`) 2. **Prompt Construction and Encoding** - - User query and history are converted to ChatML format string through `build_prompt` + - User query and history are converted to ChatML format string through `build_prompt` function in chat_template module - Tokenizer encodes prompt string to token ID sequence `input_ids` - For batch generation, use `pad_sequence` for padding diff --git a/assets/docs/design.md b/assets/docs/design.md index c4f0952..5903c40 100644 --- a/assets/docs/design.md +++ b/assets/docs/design.md @@ -8,206 +8,468 @@ Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, su ```mermaid classDiagram - %% Configuration Classes - class ModelConfig { - +int vocab_size - +int dim - +int n_layers - +float norm_eps - +int dim_ffn - +int max_len - +float rope_theta - +int n_heads - +int n_kv_heads - +bool use_qk_norm - +bool use_gated_attention - +load(config_path) ModelConfig - +save(config_path) + namespace astrai.config { + class ModelConfig { + +int vocab_size + +int dim + +int n_layers + +float norm_eps + +int dim_ffn + +bool tie_weight + +int max_len + +float rope_theta + +int n_heads + +int n_kv_heads + +bool use_qk_norm + +bool use_gated_attention + +load(config_path) ModelConfig + +save(config_path) + } + + class TrainConfig { + +nn.Module model + +str strategy + +Dataset dataset + +Callable optimizer_fn + +Callable scheduler_fn + +int n_epoch + +int batch_size + +int accumulation_steps + +float max_grad_norm + +str ckpt_dir + +int ckpt_interval + +int nprocs + +str backend + +validate() + } + + class ModelParameter { + +nn.Module model + +BpeTokenizer tokenizer + +ModelConfig config + +save(instance, save_dir) + +load(load_dir, disable_init) ModelParameter + +to(*args, **kwargs) + } } - class TrainConfig { - +nn.Module model - +str strategy - +Dataset dataset 
- +Callable optimizer_fn - +Callable scheduler_fn - +int n_epoch - +int batch_size - +int accumulation_steps - +float max_grad_norm - +str ckpt_dir - +int ckpt_interval - +int nprocs - +str backend - +validate() + namespace astrai.dataset { + class BaseDataset { + +int window_size + +int stride + +MultiSegmentFetcher fetcher + +load(load_path) + +__getitem__(index) + +__len__() + } + + class SEQDataset { + +__getitem__(index) Dict + } + + class SFTDataset { + +__getitem__(index) Dict + } + + class DPODataset { + +__getitem__(index) Dict + } + + class GRPODataset { + +__getitem__(index) Dict + } + + class BaseSegmentFetcher { + +List~Tensor~ segments + +List~int~ cum_lengths + +int total_length + +fetch_data(begin_idx, end_idx) Tensor + } + + class MultiSegmentFetcher { + +Dict muti_fetchers + +List muti_keys + +key_fetch(begin_idx, end_idx, keys) Dict + +fetch_data(begin_idx, end_idx) Dict + } + + class ResumableDistributedSampler { + +int start_epoch + +int start_iter + } + + class DatasetFactory { + +Registry _registry + +register(name) decorator + +create(train_type, window_size, stride) BaseDataset + +load(train_type, load_path, window_size, stride) BaseDataset + } + + class Checkpoint { + +dict state_dict + +int epoch + +int iteration + +save(save_dir) + +load(save_dir) Checkpoint + } + + class DataLoader { + +Dataset dataset + +int batch_size + +Sampler sampler + +__iter__() + +__len__() + } } - %% Data Classes - class Dataset { - +__len__() - +__getitem__() + namespace astrai.model { + class Transformer { + +ModelConfig config + +RotaryEmbedding rotary_embeding + +Embedding embed_tokens + +ModuleList layers + +RMSNorm norm + +Linear lm_head + +forward(input_ids, input_mask, persistent_key_values, start_pos) Dict + +load_state_dict(state_dict) + +state_dict() + } + + class DecoderBlock { + +GQA attention + +RMSNorm input_norm + +MLP mlp + +RMSNorm post_attention_norm + +forward(x, rotary_emb, attention_mask, kv_cache, start_pos) Tensor + } + + class GQA { + 
+int n_heads + +int n_kv_heads + +int head_dim + +Linear q_proj, k_proj, v_proj, o_proj + +RMSNorm q_norm, k_norm + +forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor + } + + class MLP { + +Linear up, gate, down + +forward(x) Tensor + } + + class RMSNorm { + +Parameter weight + +float norm_eps + +forward(x) Tensor + } + + class Linear { + +Parameter weight + +Parameter bias + +forward(x) Tensor + } + + class RotaryEmbedding { + +int dim + +int max_len + +float base + +forward(x, start_pos) Tuple~Tensor, Tensor~ + } + + class Embedding { + +Parameter weight + +forward(x) Tensor + } } - class Checkpoint { - +dict state_dict - +int epoch - +int iteration + namespace astrai.tokenize { + class Tokenizer { + +encode(tokens, out_ids, add_special_tokens) List~int~ + +decode(tokens, skip_special_tokens) str + +__len__() int + } + + class BpeTokenizer { + +List~str~ stop_ids + +int bos_id + +int eos_id + +int pad_id + +encode(tokens, out_ids, add_special_tokens) List~int~ + +decode(tokens, skip_special_tokens) str + } } - class Tokenizer { - +encode(text) List[int] - +decode(ids) str + namespace astrai.trainer { + class Trainer { + +TrainConfig train_config + +List~TrainCallback~ callbacks + +train(checkpoint) + +_build_context(checkpoint) TrainContext + +_get_default_callbacks() List~TrainCallback~ + } + + class TrainContext { + +nn.Module model + +BaseStrategy strategy + +DataLoader dataloader + +Optimizer optimizer + +LRScheduler scheduler + +Checkpoint checkpoint + +int epoch + +int iteration + +float loss + +int world_size + +int rank + } + + class TrainContextBuilder { + +TrainConfig config + +with_checkpoint(checkpoint) TrainContextBuilder + +with_dataloader() TrainContextBuilder + +with_strategy() TrainContextBuilder + +build() TrainContext + } + + class BaseStrategy { + +nn.Module model + +str device + +compute_loss(batch) Tensor + } + + class StrategyFactory { + +Registry _registry + +register(name) decorator + +create(model, train_type, device, **kwargs) 
BaseStrategy + } + + class SEQStrategy { + +float label_smoothing + +compute_loss(batch) Tensor + } + + class SFTStrategy { + +float label_smoothing + +compute_loss(batch) Tensor + } + + class DPOStrategy { + +nn.Module ref_model + +float beta + +str reduction + +compute_loss(batch) Tensor + } + + class GRPOStrategy { + +nn.Module ref_model + +float clip_eps + +float kl_coef + +int group_size + +compute_loss(batch) Tensor + } + + class BaseScheduler { + +get_lr() List~float~ + +step() + } + + class SchedulerFactory { + +Registry _registry + +register(name) decorator + +create(optimizer, schedule_type, **kwargs) BaseScheduler + } + + class CosineScheduler { + +int warmup_steps + +int lr_decay_steps + +float min_rate + } + + class SGDRScheduler { + +int warmup_steps + +int cycle_length + +float min_rate + +int t_mult + } + + class TrainCallback { + +on_train_begin(context) + +on_train_end(context) + +on_epoch_begin(context) + +on_epoch_end(context) + +on_step_begin(context) + +on_step_end(context) + +on_batch_begin(context) + +on_batch_end(context) + +on_error(context) + } + + class CallbackFactory { + +Registry _registry + +register(name) decorator + +create(name, **kwargs) TrainCallback + } } - %% Model Classes - class Transformer { - +forward(input_ids, mask) Dict + namespace astrai.inference { + class GeneratorCore { + +ModelParameter parameter + +generate_iterator(input_ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos) Tuple~Tensor, int~ + +generate_loop(input_ids, ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos, callback) List~int~ + } + + class LoopGenerator { + +generate(request) str + } + + class StreamGenerator { + +generate(request) Generator + } + + class BatchGenerator { + +generate(request) List~str~ + } + + class EmbeddingEncoder { + +encode(sentence) Tensor + } + + class KVCacheManager { + +int batch_size + +Tuple kv_cache + +Tensor seq_mask + +get_kvcache() Tuple + +get_seq_mask() Tensor + +update(active_mask) + 
+reset(full_reset) + } + + class GeneratorFactory { + +create(parameter, request) GeneratorCore + +create_encoder(parameter) EmbeddingEncoderCore + } + + class Server { + +start() + +predict(request) + } + + class GenerationRequest { + +int top_k + +float top_p + +float temperature + +int max_len + +Union~str, List~str~~ query + +history Optional + +system_prompt Optional~str~ + +stream bool + } } - %% Trainer Classes - class Trainer { - +TrainConfig train_config - +List~TrainCallback~ callbacks - +train() - +_build_context() TrainContext - } + namespace astrai.parallel { + class ParallelSetup { + +spawn_parallel_fn(fn, nprocs) + +setup_parallel(rank, world_size, backend, master_addr, master_port, device_type, device_ids) + } - class TrainContext { - +nn.Module model - +BaseStrategy strategy - +DataLoader dataloader - +Optimizer optimizer - +LRScheduler scheduler - +Checkpoint checkpoint - +int epoch - +int iteration - } + class ColumnParallelLinear { + +forward(x) Tensor + } - class TrainContextBuilder { - +TrainConfig config - +with_checkpoint(Checkpoint) TrainContextBuilder - +with_dataloader() TrainContextBuilder - +with_strategy() TrainContextBuilder - +build() TrainContext - } - - class BaseStrategy { - +nn.Module model - +str device - +compute_loss(batch) Tensor - } - - class StrategyFactory { - +frozenset SUPPORTED_STRATEGIES - +Dict STRATEGY_MAP - +register(name) decorator - +create(model, train_type, device) BaseStrategy - +available_strategies() list - } - - class SEQStrategy { - +float label_smoothing - +compute_loss(batch) Tensor - } - - class SFTStrategy { - +float label_smoothing - +compute_loss(batch) Tensor - } - - class DPOStrategy { - +nn.Module ref_model - +float beta - +str reduction - +compute_loss(batch) Tensor - } - - class GRPOStrategy { - +nn.Module ref_model - +float clip_eps - +float kl_coef - +int group_size - +compute_loss(batch) Tensor - } - - class TrainCallback { - +on_train_begin(trainer) - +on_train_end(trainer) - 
+on_epoch_begin(epoch, trainer) - +on_epoch_end(epoch, trainer) - +on_batch_begin(batch, trainer) - +on_batch_end(batch, trainer) - } - - class Schedule { - +step() - } - - %% Inference Classes - class Generator { - +generate(prompt, config) str - +generate_batch(prompts, config) List[str] - +stream_generate(prompt, config) Generator - } - - class InferenceCore { - +forward(input_ids) Dict - +apply_sampling_strategies() - } - - class Server { - +start() - +predict(request) - } - - %% Parallel Classes - class ParallelSetup { - +spawn_parallel_fn(fn, nprocs) + class RowParallelLinear { + +forward(x) Tensor + } } %% Relationships TrainConfig --> ModelConfig : contains - TrainConfig --> Dataset : uses + TrainConfig --> BaseDataset : uses TrainConfig --> Transformer : uses Trainer --> TrainConfig : configures Trainer --> TrainContextBuilder : builds Trainer --> TrainCallback : manages TrainContextBuilder --> TrainContext : creates TrainContext --> Checkpoint : manages + TrainContext --> BaseStrategy : uses + TrainContext --> BaseScheduler : uses StrategyFactory ..> BaseStrategy : creates BaseStrategy <|-- SEQStrategy BaseStrategy <|-- SFTStrategy BaseStrategy <|-- DPOStrategy BaseStrategy <|-- GRPOStrategy - TrainContext --> BaseStrategy : uses - Generator --> InferenceCore : uses - Generator --> Transformer : uses - Server --> Generator : uses + SchedulerFactory ..> BaseScheduler : creates + BaseScheduler <|-- CosineScheduler + BaseScheduler <|-- SGDRScheduler + CallbackFactory ..> TrainCallback : creates + GeneratorFactory ..> GeneratorCore : creates + GeneratorCore <|-- LoopGenerator + GeneratorCore <|-- StreamGenerator + GeneratorCore <|-- BatchGenerator + GeneratorCore <|-- EmbeddingEncoder + LoopGenerator --> KVCacheManager : uses + StreamGenerator --> KVCacheManager : uses + BatchGenerator --> KVCacheManager : uses + GeneratorCore --> Transformer : uses + DPOStrategy --> Transformer : creates ref_model + GRPOStrategy --> Transformer : creates ref_model + Server 
--> GeneratorFactory : uses ParallelSetup --> Trainer : enables TrainConfig --> StrategyFactory : selects - TrainCallback <|-- CheckpointCallback - TrainCallback <|-- MetricLoggerCallback - TrainCallback <|-- SchedulerCallback - TrainContext --> Schedule : uses + ModelParameter --> Transformer : contains + ModelParameter --> BpeTokenizer : contains + ModelParameter --> ModelConfig : contains + GeneratorFactory --> GenerationRequest : uses + BaseDataset <|-- SEQDataset + BaseDataset <|-- SFTDataset + BaseDataset <|-- DPODataset + BaseDataset <|-- GRPODataset + DatasetFactory ..> BaseDataset : creates + BaseSegmentFetcher --> MultiSegmentFetcher : used by + MultiSegmentFetcher --> BaseDataset : used by + Transformer --> DecoderBlock : uses + Transformer --> RotaryEmbedding : uses + Transformer --> Embedding : uses + DecoderBlock --> GQA : uses + DecoderBlock --> MLP : uses + DecoderBlock --> RMSNorm : uses + BpeTokenizer --> Tokenizer : inherits + TrainContextBuilder --> ResumableDistributedSampler : creates + DataLoader --> BaseDataset : uses + ResumableDistributedSampler --> BaseDataset : samples ``` -### Design Pattern Summary +### Module Overview + +| Module | Components | Description | +|--------|------------|-------------| +| **astrai.config** | ModelConfig, TrainConfig, ModelParameter | Configuration management | +| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint, DataLoader | Dataset loading and management | +| **astrai.model** | Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | +| **astrai.tokenize** | Tokenizer, BpeTokenizer | Tokenizer | +| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management | +| **astrai.inference** | GeneratorCore, 
LoopGenerator, StreamGenerator, BatchGenerator, EmbeddingEncoder, KVCacheManager, GeneratorFactory, Server, GenerationRequest | Inference service | +| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel | + +### Design Patterns | Pattern | Classes | Purpose | |---------|---------|---------| | **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO | | **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components | -| **Factory** | `StrategyFactory` | Decorator registration mechanism, dynamically create training strategies | -| **Observer** | `TrainCallback` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) | +| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `GeneratorFactory`, `CallbackFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, generators, and callbacks | +| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) | | **Singleton** | `TrainContext` | Training process global state management | +| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support | ### Core Relationships 1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references 2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss 3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type` -4. **Inference Flow**: `Server` → `Generator` → `InferenceCore` → `Transformer` +4. 
**Inference Flow**: `Server` → `GeneratorFactory` → `GeneratorCore` → `Transformer`, supports multiple generators (LoopGenerator, StreamGenerator, BatchGenerator) 5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer` +6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher` +7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors +8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler) ## 3. Training Process
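The strategy selection at the heart of the training process (Core Relationship 3 above: `StrategyFactory` creates the strategy instance from `train_type`) can be sketched as follows. This is a minimal illustration assuming a plain dict-based registry; the class and method names (`StrategyFactory`, `BaseStrategy`, `SEQStrategy`, `register`, `create`, `compute_loss`) follow the class diagram, while the bodies — including the averaged stand-in "loss" — are hypothetical, not AstrAI's actual implementation.

```python
# Sketch of the decorator-registration Factory pattern described in the
# design doc. Names follow the class diagram; bodies are illustrative only.

class BaseStrategy:
    """Training-strategy interface: receive a batch, return a loss value."""

    def __init__(self, model, device="cpu"):
        self.model = model
        self.device = device

    def compute_loss(self, batch):
        raise NotImplementedError


class StrategyFactory:
    _registry = {}

    @classmethod
    def register(cls, name):
        """Decorator mapping a train_type string to a strategy class."""
        def decorator(strategy_cls):
            cls._registry[name] = strategy_cls
            return strategy_cls
        return decorator

    @classmethod
    def create(cls, model, train_type, device="cpu", **kwargs):
        if train_type not in cls._registry:
            raise ValueError(f"unknown strategy: {train_type!r}")
        return cls._registry[train_type](model, device, **kwargs)


@StrategyFactory.register("seq")
class SEQStrategy(BaseStrategy):
    def __init__(self, model, device="cpu", label_smoothing=0.0):
        super().__init__(model, device)
        self.label_smoothing = label_smoothing

    def compute_loss(self, batch):
        # Real code would run a Transformer forward pass and cross-entropy
        # here; a batch average stands in for the loss in this sketch.
        return sum(batch) / len(batch)


strategy = StrategyFactory.create(model=None, train_type="seq")
print(type(strategy).__name__)            # SEQStrategy
print(strategy.compute_loss([1.0, 3.0]))  # 2.0
```

The same registration shape would apply to the other factories in the diagram (`SchedulerFactory`, `DatasetFactory`, `CallbackFactory`): one registry per factory, populated at import time by the decorator.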