From 99b821ebf51d287f24095928caa148a5c0b3a3ee Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sat, 4 Apr 2026 18:11:36 +0800
Subject: [PATCH] docs: update documentation class diagrams, etc.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
assets/docs/dataflow.md | 67 ++---
assets/docs/design.md | 570 +++++++++++++++++++++++++++++-----------
2 files changed, 453 insertions(+), 184 deletions(-)
diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md
index 29ef4ea..fc84bbf 100644
--- a/assets/docs/dataflow.md
+++ b/assets/docs/dataflow.md
@@ -5,7 +5,7 @@ This document describes the data flow of the AstrAI project (a training and infe
## Overview
AstrAI adopts a modular design with the following main components:
-- **Data Module** (`astrai/data/`): Dataset, sampler, tokenizer, serialization tools
+- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
- **Model Module** (`astrai/model/`): Transformer model and its submodules
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
- **Inference Module** (`astrai/inference/`): Generation core, KV cache management, streaming generation
@@ -20,35 +20,39 @@ The data flow can generally be divided into two main lines: **Training Data Flow
flowchart LR
subgraph A[Data Preparation]
direction TB
- A1[Raw Text] --> A2[BBPE Tokenizer]
+ A1[Raw Text] --> A2[BpeTokenizer]
A2 --> A3[Serialize to .h5 files]
-    A3 --> A4[Dataset Loading BaseDataset]
-    A4 --> A5[Resumable Distributed Sampler ResumableDistributedSampler]
- A5 --> A6[DataLoader Batch Loading]
+ A3 --> A4[BaseDataset]
+ A4 --> A5[ResumableDistributedSampler]
+ A5 --> A6[DataLoader]
end
- subgraph B[Training Loop]
+ subgraph B[Training]
direction TB
-    B1[Batch Data] --> B2[Training Strategy BaseStrategy]
- B2 --> B3[Transformer Model]
- B3 --> B4[Output logits]
- B4 --> B5[Loss Calculation]
- B5 --> B6[Backpropagation]
- B6 --> B7[Optimizer Update]
- B7 --> B8[Learning Rate Scheduler]
- B8 --> B9[Checkpoint Save]
+ B1[Batch Data] --> B2[TrainContextBuilder]
+ B2 --> B3[TrainContext]
+ B3 --> B4[BaseStrategy]
+ B4 --> B5[Transformer]
+ B5 --> B6[Compute Loss]
+ B6 --> B7[Backward]
+ B7 --> B8[Optimizer]
+ B8 --> B9[LRScheduler]
+ B9 --> B10[CheckpointCallback]
end
- subgraph C[Inference Generation]
+ subgraph C[Inference]
direction TB
- C1[Checkpoint Loading] --> C2[Inference Model Loading]
-    C2 --> C3[Generation Core GeneratorCore]
-    C3 --> C4[Sampling Strategy Temperature/top-k/top-p]
- C4 --> C5[Generate Next Token]
- C5 --> C6[KV Cache Update]
- C6 --> C7{Max Length Reached?}
- C7 -->|No| C5
- C7 -->|Yes| C8[Output Generated Text]
+ C1[Checkpoint] --> C2[ModelParameter]
+ C2 --> C3[Transformer + BpeTokenizer]
+ C3 --> C4[GenerationRequest + build_prompt]
+ C4 --> C5[GeneratorFactory]
+ C5 --> C6[GeneratorCore]
+ C6 --> C7[apply_sampling_strategies]
+ C7 --> C8[Transformer Forward]
+ C8 --> C9[KVCacheManager]
+ C9 --> C10{End Condition?}
+ C10 -->|No| C8
+ C10 -->|Yes| C11[Output Text]
end
A --> B
@@ -57,13 +61,14 @@ flowchart LR
## Detailed Module Descriptions
-### 1. Data Module
+### 1. Dataset Module
#### 1.1 Tokenizer (`tokenizer.py`)
-- Implemented based on Byte-Level BPE (BBPE)
+- Implemented based on byte-level Byte Pair Encoding (BPE)
- Supports special tokens: `<|begin▁of▁sentence|>`, `<|end▁of▁sentence|>`, `<|▁pad▁|>`, `<|im▁start|>`, `<|im▁end|>`
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs
- Learns vocabulary from corpus during training, saved as `.json` files
+- The `BpeTrainer` class handles vocabulary training from a corpus
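The encode/decode contract above can be sketched with a toy byte-level tokenizer. This is an illustrative stand-in, not the real `BpeTokenizer` (which learns merges from a corpus via `BpeTrainer`); only the special-token spellings come from the doc.

```python
# Toy byte-level tokenizer: every string is encodable as raw UTF-8 bytes
# (ids 0-255), with special tokens assigned ids above the byte range.
# Illustrative only; the real BpeTokenizer additionally applies learned merges.
class ToyByteTokenizer:
    def __init__(self, special_tokens=("<|begin▁of▁sentence|>", "<|end▁of▁sentence|>")):
        self.special = {tok: 256 + i for i, tok in enumerate(special_tokens)}
        self.inv_special = {v: k for k, v in self.special.items()}

    def encode(self, text, add_special_tokens=False):
        ids = list(text.encode("utf-8"))  # byte-level: no out-of-vocabulary text
        if add_special_tokens:
            ids = [self.special["<|begin▁of▁sentence|>"]] + ids \
                + [self.special["<|end▁of▁sentence|>"]]
        return ids

    def decode(self, ids, skip_special_tokens=True):
        out, buf = [], bytearray()
        for i in ids:
            if i < 256:
                buf.append(i)
            else:  # flush pending bytes, then optionally emit the special token
                out.append(buf.decode("utf-8"))
                buf = bytearray()
                if not skip_special_tokens:
                    out.append(self.inv_special[i])
        out.append(buf.decode("utf-8"))
        return "".join(out)
```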
#### 1.2 Serialization (`serialization.py`)
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
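The grouped-HDF5 layout described above can be sketched with `h5py`, using numpy arrays in place of torch tensors. The function names mirror the doc, but the bodies are assumptions about the layout (one group per key, one dataset per list element), not the real `serialization.py`.

```python
# Hypothetical sketch of the save_h5 contract: each key maps to a list of
# tensors, stored as one HDF5 group with one numbered dataset per element.
import h5py
import numpy as np

def save_h5(path, grouped):
    with h5py.File(path, "w") as f:
        for key, tensors in grouped.items():
            grp = f.create_group(key)
            for i, t in enumerate(tensors):
                grp.create_dataset(str(i), data=t)

def load_h5(path):
    with h5py.File(path, "r") as f:
        # sort dataset names numerically so segment order is preserved
        return {key: [grp[name][()] for name in sorted(grp, key=int)]
                for key, grp in f.items()}
```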
@@ -108,7 +113,7 @@ flowchart LR
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
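The hook ordering above can be sketched as a plain loop. `RecordingCallback` and `run_training_loop` are illustrative names, not the real `Trainer`; the gradient-accumulation boundary is where `on_step_begin`/`on_step_end` fire.

```python
# Records which hooks fire, in order, over a minimal training loop.
class RecordingCallback:
    def __init__(self):
        self.events = []

    def __getattr__(self, name):
        # any on_* hook simply records its own name when invoked
        if name.startswith("on_"):
            return lambda ctx=None: self.events.append(name)
        raise AttributeError(name)

def run_training_loop(callback, n_epochs=1, n_batches=2, accumulation_steps=2):
    callback.on_train_begin()
    for _ in range(n_epochs):
        callback.on_epoch_begin()
        for b in range(n_batches):
            callback.on_batch_begin()
            # forward pass / loss calculation would happen here
            callback.on_batch_end()
            if (b + 1) % accumulation_steps == 0:  # accumulation boundary
                callback.on_step_begin()
                # optimizer update would happen here
                callback.on_step_end()
        callback.on_epoch_end()
    callback.on_train_end()
```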
#### 3.3 Strategy (`strategy.py`)
-- **`BaseStrategy`**: Defines training strategy interface (such as `SeqStrategy`, `SFTStrategy`, `DPOStrategy`)
+- **`BaseStrategy`**: Defines the training strategy interface, implemented by `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, and `GRPOStrategy`
- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
- Created dynamically by `StrategyFactory` according to configuration
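The decorator-registration mechanism attributed to `StrategyFactory` can be sketched as follows; the method bodies and the `"seq"` key are assumptions based on the description, not the actual implementation.

```python
# Decorator-based registry: strategies register themselves under a name,
# and create() instantiates the one matching the configured train_type.
class StrategyFactory:
    _registry = {}

    @classmethod
    def register(cls, name):
        def deco(strategy_cls):
            cls._registry[name] = strategy_cls
            return strategy_cls
        return deco

    @classmethod
    def create(cls, train_type, **kwargs):
        if train_type not in cls._registry:
            raise ValueError(f"unknown strategy: {train_type}")
        return cls._registry[train_type](**kwargs)

@StrategyFactory.register("seq")
class SEQStrategy:
    def __init__(self, label_smoothing=0.0):
        self.label_smoothing = label_smoothing

    def compute_loss(self, batch):
        ...  # model forward + cross-entropy would go here
```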
@@ -130,14 +135,16 @@ flowchart LR
#### 4.3 Generator (`generator.py`)
- **`GenerationRequest`**: Encapsulates generation request parameters (top_k, top_p, temperature, max_len, query, history, etc.)
-- **`build_prompt`**: Converts query and history into ChatML format prompt string
+- **`build_prompt`** (from `chat_template.py`): Converts query and history into a ChatML-format prompt string
+- **`GeneratorCore`**: Base generator providing the `generate_iterator` and `generate_loop` methods
+- **`LoopGenerator`**, **`StreamGenerator`**, **`BatchGenerator`**: Different generation modes (loop, streaming, batch)
- **`pad_sequence`**: Pads input IDs to consistent length
- Provides streaming and non-streaming generation interfaces
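A hedged sketch of what a ChatML-style `build_prompt` could look like, using the special-token spellings listed earlier; the real `chat_template.py` may differ in history handling and trailing-role conventions.

```python
# Assembles a ChatML-format prompt: optional system turn, alternating
# user/assistant history turns, the current query, then an open assistant
# turn for the model to continue from. Illustrative, not astrai's exact code.
IM_START, IM_END = "<|im▁start|>", "<|im▁end|>"

def build_prompt(query, history=None, system_prompt=None):
    parts = []
    if system_prompt:
        parts.append(f"{IM_START}system\n{system_prompt}{IM_END}\n")
    for user_msg, assistant_msg in history or []:
        parts.append(f"{IM_START}user\n{user_msg}{IM_END}\n")
        parts.append(f"{IM_START}assistant\n{assistant_msg}{IM_END}\n")
    parts.append(f"{IM_START}user\n{query}{IM_END}\n")
    parts.append(f"{IM_START}assistant\n")  # model continues from here
    return "".join(parts)
```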
## Training Data Flow - Detailed Steps
1. **Data Preparation**
- - Raw text is converted to token ID sequences through BBPE tokenizer
+   - Raw text is converted to token ID sequences by the BPE tokenizer
- Token ID sequences (possibly with masks, labels, etc.) are saved by groups as `.h5` files
- Files can contain multiple segments, each segment corresponds to a tensor
@@ -152,7 +159,7 @@ flowchart LR
- Batch data shape is `[batch_size, window_size]` (or varies according to specific dataset type)
4. **Strategy Forward and Loss Calculation**
- - Batch data is passed to strategy (such as `SeqStrategy`)
+   - Batch data is passed to the strategy (such as `SEQStrategy`)
- Strategy internally calls `Transformer` model, obtaining logits
- Calculate cross-entropy loss (or DPO loss, etc.) according to task type
- Return loss tensor
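The shifted next-token cross-entropy step can be sketched in numpy (the real strategy operates on torch tensors): logits at positions `0..T-2` are scored against tokens `1..T-1`.

```python
# Next-token cross-entropy with the standard one-position shift.
# Function name and the single-sequence shape are illustrative assumptions.
import numpy as np

def next_token_ce_loss(logits, input_ids):
    # logits: [T, V]; input_ids: [T]
    shift_logits = logits[:-1]      # position t predicts token t+1
    shift_labels = input_ids[1:]
    # numerically stable log-softmax
    z = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(len(shift_labels)), shift_labels]
    return nll.mean()
```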
@@ -174,7 +181,7 @@ flowchart LR
- Set model to evaluation mode (`model.eval()`), enable inference mode (`torch.inference_mode`)
2. **Prompt Construction and Encoding**
- - User query and history are converted to ChatML format string through `build_prompt`
+   - User query and history are converted to a ChatML-format string by the `build_prompt` function in the `chat_template` module
- Tokenizer encodes prompt string to token ID sequence `input_ids`
- For batch generation, use `pad_sequence` for padding
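The `pad_sequence` role described above can be sketched on plain lists; the real helper likely operates on tensors, and the padding side chosen here is an assumption.

```python
# Pads a batch of token-id lists to a common length with pad_id.
# Left padding keeps the most recent tokens aligned at the right edge,
# which is convenient for decoder-only generation; this is a sketch.
def pad_sequence(batch_ids, pad_id, padding_side="left"):
    max_len = max(len(ids) for ids in batch_ids)
    padded = []
    for ids in batch_ids:
        pad = [pad_id] * (max_len - len(ids))
        padded.append(pad + ids if padding_side == "left" else ids + pad)
    return padded
```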
diff --git a/assets/docs/design.md b/assets/docs/design.md
index c4f0952..5903c40 100644
--- a/assets/docs/design.md
+++ b/assets/docs/design.md
@@ -8,206 +8,468 @@ Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, su
```mermaid
classDiagram
- %% Configuration Classes
- class ModelConfig {
- +int vocab_size
- +int dim
- +int n_layers
- +float norm_eps
- +int dim_ffn
- +int max_len
- +float rope_theta
- +int n_heads
- +int n_kv_heads
- +bool use_qk_norm
- +bool use_gated_attention
- +load(config_path) ModelConfig
- +save(config_path)
+ namespace astrai.config {
+ class ModelConfig {
+ +int vocab_size
+ +int dim
+ +int n_layers
+ +float norm_eps
+ +int dim_ffn
+ +bool tie_weight
+ +int max_len
+ +float rope_theta
+ +int n_heads
+ +int n_kv_heads
+ +bool use_qk_norm
+ +bool use_gated_attention
+ +load(config_path) ModelConfig
+ +save(config_path)
+ }
+
+ class TrainConfig {
+ +nn.Module model
+ +str strategy
+ +Dataset dataset
+ +Callable optimizer_fn
+ +Callable scheduler_fn
+ +int n_epoch
+ +int batch_size
+ +int accumulation_steps
+ +float max_grad_norm
+ +str ckpt_dir
+ +int ckpt_interval
+ +int nprocs
+ +str backend
+ +validate()
+ }
+
+ class ModelParameter {
+ +nn.Module model
+ +BpeTokenizer tokenizer
+ +ModelConfig config
+ +save(instance, save_dir)
+ +load(load_dir, disable_init) ModelParameter
+ +to(*args, **kwargs)
+ }
}
- class TrainConfig {
- +nn.Module model
- +str strategy
- +Dataset dataset
- +Callable optimizer_fn
- +Callable scheduler_fn
- +int n_epoch
- +int batch_size
- +int accumulation_steps
- +float max_grad_norm
- +str ckpt_dir
- +int ckpt_interval
- +int nprocs
- +str backend
- +validate()
+ namespace astrai.dataset {
+ class BaseDataset {
+ +int window_size
+ +int stride
+ +MultiSegmentFetcher fetcher
+ +load(load_path)
+ +__getitem__(index)
+ +__len__()
+ }
+
+ class SEQDataset {
+ +__getitem__(index) Dict
+ }
+
+ class SFTDataset {
+ +__getitem__(index) Dict
+ }
+
+ class DPODataset {
+ +__getitem__(index) Dict
+ }
+
+ class GRPODataset {
+ +__getitem__(index) Dict
+ }
+
+ class BaseSegmentFetcher {
+ +List~Tensor~ segments
+ +List~int~ cum_lengths
+ +int total_length
+ +fetch_data(begin_idx, end_idx) Tensor
+ }
+
+ class MultiSegmentFetcher {
+ +Dict muti_fetchers
+ +List muti_keys
+ +key_fetch(begin_idx, end_idx, keys) Dict
+ +fetch_data(begin_idx, end_idx) Dict
+ }
+
+ class ResumableDistributedSampler {
+ +int start_epoch
+ +int start_iter
+ }
+
+ class DatasetFactory {
+ +Registry _registry
+ +register(name) decorator
+ +create(train_type, window_size, stride) BaseDataset
+ +load(train_type, load_path, window_size, stride) BaseDataset
+ }
+
+ class Checkpoint {
+ +dict state_dict
+ +int epoch
+ +int iteration
+ +save(save_dir)
+ +load(save_dir) Checkpoint
+ }
+
+ class DataLoader {
+ +Dataset dataset
+ +int batch_size
+ +Sampler sampler
+ +__iter__()
+ +__len__()
+ }
}
- %% Data Classes
- class Dataset {
- +__len__()
- +__getitem__()
+ namespace astrai.model {
+ class Transformer {
+ +ModelConfig config
+ +RotaryEmbedding rotary_embeding
+ +Embedding embed_tokens
+ +ModuleList layers
+ +RMSNorm norm
+ +Linear lm_head
+ +forward(input_ids, input_mask, persistent_key_values, start_pos) Dict
+ +load_state_dict(state_dict)
+ +state_dict()
+ }
+
+ class DecoderBlock {
+ +GQA attention
+ +RMSNorm input_norm
+ +MLP mlp
+ +RMSNorm post_attention_norm
+ +forward(x, rotary_emb, attention_mask, kv_cache, start_pos) Tensor
+ }
+
+ class GQA {
+ +int n_heads
+ +int n_kv_heads
+ +int head_dim
+ +Linear q_proj, k_proj, v_proj, o_proj
+ +RMSNorm q_norm, k_norm
+ +forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor
+ }
+
+ class MLP {
+ +Linear up, gate, down
+ +forward(x) Tensor
+ }
+
+ class RMSNorm {
+ +Parameter weight
+ +float norm_eps
+ +forward(x) Tensor
+ }
+
+ class Linear {
+ +Parameter weight
+ +Parameter bias
+ +forward(x) Tensor
+ }
+
+ class RotaryEmbedding {
+ +int dim
+ +int max_len
+ +float base
+ +forward(x, start_pos) Tuple~Tensor, Tensor~
+ }
+
+ class Embedding {
+ +Parameter weight
+ +forward(x) Tensor
+ }
}
- class Checkpoint {
- +dict state_dict
- +int epoch
- +int iteration
+ namespace astrai.tokenize {
+ class Tokenizer {
+ +encode(tokens, out_ids, add_special_tokens) List~int~
+ +decode(tokens, skip_special_tokens) str
+ +__len__() int
+ }
+
+ class BpeTokenizer {
+ +List~str~ stop_ids
+ +int bos_id
+ +int eos_id
+ +int pad_id
+ +encode(tokens, out_ids, add_special_tokens) List~int~
+ +decode(tokens, skip_special_tokens) str
+ }
}
- class Tokenizer {
- +encode(text) List[int]
- +decode(ids) str
+ namespace astrai.trainer {
+ class Trainer {
+ +TrainConfig train_config
+ +List~TrainCallback~ callbacks
+ +train(checkpoint)
+ +_build_context(checkpoint) TrainContext
+ +_get_default_callbacks() List~TrainCallback~
+ }
+
+ class TrainContext {
+ +nn.Module model
+ +BaseStrategy strategy
+ +DataLoader dataloader
+ +Optimizer optimizer
+ +LRScheduler scheduler
+ +Checkpoint checkpoint
+ +int epoch
+ +int iteration
+ +float loss
+ +int world_size
+ +int rank
+ }
+
+ class TrainContextBuilder {
+ +TrainConfig config
+ +with_checkpoint(checkpoint) TrainContextBuilder
+ +with_dataloader() TrainContextBuilder
+ +with_strategy() TrainContextBuilder
+ +build() TrainContext
+ }
+
+ class BaseStrategy {
+ +nn.Module model
+ +str device
+ +compute_loss(batch) Tensor
+ }
+
+ class StrategyFactory {
+ +Registry _registry
+ +register(name) decorator
+ +create(model, train_type, device, **kwargs) BaseStrategy
+ }
+
+ class SEQStrategy {
+ +float label_smoothing
+ +compute_loss(batch) Tensor
+ }
+
+ class SFTStrategy {
+ +float label_smoothing
+ +compute_loss(batch) Tensor
+ }
+
+ class DPOStrategy {
+ +nn.Module ref_model
+ +float beta
+ +str reduction
+ +compute_loss(batch) Tensor
+ }
+
+ class GRPOStrategy {
+ +nn.Module ref_model
+ +float clip_eps
+ +float kl_coef
+ +int group_size
+ +compute_loss(batch) Tensor
+ }
+
+ class BaseScheduler {
+ +get_lr() List~float~
+ +step()
+ }
+
+ class SchedulerFactory {
+ +Registry _registry
+ +register(name) decorator
+ +create(optimizer, schedule_type, **kwargs) BaseScheduler
+ }
+
+ class CosineScheduler {
+ +int warmup_steps
+ +int lr_decay_steps
+ +float min_rate
+ }
+
+ class SGDRScheduler {
+ +int warmup_steps
+ +int cycle_length
+ +float min_rate
+ +int t_mult
+ }
+
+ class TrainCallback {
+ +on_train_begin(context)
+ +on_train_end(context)
+ +on_epoch_begin(context)
+ +on_epoch_end(context)
+ +on_step_begin(context)
+ +on_step_end(context)
+ +on_batch_begin(context)
+ +on_batch_end(context)
+ +on_error(context)
+ }
+
+ class CallbackFactory {
+ +Registry _registry
+ +register(name) decorator
+ +create(name, **kwargs) TrainCallback
+ }
}
- %% Model Classes
- class Transformer {
- +forward(input_ids, mask) Dict
+ namespace astrai.inference {
+ class GeneratorCore {
+ +ModelParameter parameter
+ +generate_iterator(input_ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos) Tuple~Tensor, int~
+ +generate_loop(input_ids, ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos, callback) List~int~
+ }
+
+ class LoopGenerator {
+ +generate(request) str
+ }
+
+ class StreamGenerator {
+ +generate(request) Generator
+ }
+
+ class BatchGenerator {
+ +generate(request) List~str~
+ }
+
+ class EmbeddingEncoder {
+ +encode(sentence) Tensor
+ }
+
+ class KVCacheManager {
+ +int batch_size
+ +Tuple kv_cache
+ +Tensor seq_mask
+ +get_kvcache() Tuple
+ +get_seq_mask() Tensor
+ +update(active_mask)
+ +reset(full_reset)
+ }
+
+ class GeneratorFactory {
+ +create(parameter, request) GeneratorCore
+        +create_encoder(parameter) EmbeddingEncoder
+ }
+
+ class Server {
+ +start()
+ +predict(request)
+ }
+
+ class GenerationRequest {
+ +int top_k
+ +float top_p
+ +float temperature
+ +int max_len
+ +Union~str, List~str~~ query
+        +Optional history
+        +Optional~str~ system_prompt
+        +bool stream
+ }
}
- %% Trainer Classes
- class Trainer {
- +TrainConfig train_config
- +List~TrainCallback~ callbacks
- +train()
- +_build_context() TrainContext
- }
+ namespace astrai.parallel {
+ class ParallelSetup {
+ +spawn_parallel_fn(fn, nprocs)
+ +setup_parallel(rank, world_size, backend, master_addr, master_port, device_type, device_ids)
+ }
- class TrainContext {
- +nn.Module model
- +BaseStrategy strategy
- +DataLoader dataloader
- +Optimizer optimizer
- +LRScheduler scheduler
- +Checkpoint checkpoint
- +int epoch
- +int iteration
- }
+ class ColumnParallelLinear {
+ +forward(x) Tensor
+ }
- class TrainContextBuilder {
- +TrainConfig config
- +with_checkpoint(Checkpoint) TrainContextBuilder
- +with_dataloader() TrainContextBuilder
- +with_strategy() TrainContextBuilder
- +build() TrainContext
- }
-
- class BaseStrategy {
- +nn.Module model
- +str device
- +compute_loss(batch) Tensor
- }
-
- class StrategyFactory {
- +frozenset SUPPORTED_STRATEGIES
- +Dict STRATEGY_MAP
- +register(name) decorator
- +create(model, train_type, device) BaseStrategy
- +available_strategies() list
- }
-
- class SEQStrategy {
- +float label_smoothing
- +compute_loss(batch) Tensor
- }
-
- class SFTStrategy {
- +float label_smoothing
- +compute_loss(batch) Tensor
- }
-
- class DPOStrategy {
- +nn.Module ref_model
- +float beta
- +str reduction
- +compute_loss(batch) Tensor
- }
-
- class GRPOStrategy {
- +nn.Module ref_model
- +float clip_eps
- +float kl_coef
- +int group_size
- +compute_loss(batch) Tensor
- }
-
- class TrainCallback {
- +on_train_begin(trainer)
- +on_train_end(trainer)
- +on_epoch_begin(epoch, trainer)
- +on_epoch_end(epoch, trainer)
- +on_batch_begin(batch, trainer)
- +on_batch_end(batch, trainer)
- }
-
- class Schedule {
- +step()
- }
-
- %% Inference Classes
- class Generator {
- +generate(prompt, config) str
- +generate_batch(prompts, config) List[str]
- +stream_generate(prompt, config) Generator
- }
-
- class InferenceCore {
- +forward(input_ids) Dict
- +apply_sampling_strategies()
- }
-
- class Server {
- +start()
- +predict(request)
- }
-
- %% Parallel Classes
- class ParallelSetup {
- +spawn_parallel_fn(fn, nprocs)
+ class RowParallelLinear {
+ +forward(x) Tensor
+ }
}
%% Relationships
TrainConfig --> ModelConfig : contains
- TrainConfig --> Dataset : uses
+ TrainConfig --> BaseDataset : uses
TrainConfig --> Transformer : uses
Trainer --> TrainConfig : configures
Trainer --> TrainContextBuilder : builds
Trainer --> TrainCallback : manages
TrainContextBuilder --> TrainContext : creates
TrainContext --> Checkpoint : manages
+ TrainContext --> BaseStrategy : uses
+ TrainContext --> BaseScheduler : uses
StrategyFactory ..> BaseStrategy : creates
BaseStrategy <|-- SEQStrategy
BaseStrategy <|-- SFTStrategy
BaseStrategy <|-- DPOStrategy
BaseStrategy <|-- GRPOStrategy
- TrainContext --> BaseStrategy : uses
- Generator --> InferenceCore : uses
- Generator --> Transformer : uses
- Server --> Generator : uses
+ SchedulerFactory ..> BaseScheduler : creates
+ BaseScheduler <|-- CosineScheduler
+ BaseScheduler <|-- SGDRScheduler
+ CallbackFactory ..> TrainCallback : creates
+ GeneratorFactory ..> GeneratorCore : creates
+ GeneratorCore <|-- LoopGenerator
+ GeneratorCore <|-- StreamGenerator
+ GeneratorCore <|-- BatchGenerator
+ GeneratorCore <|-- EmbeddingEncoder
+ LoopGenerator --> KVCacheManager : uses
+ StreamGenerator --> KVCacheManager : uses
+ BatchGenerator --> KVCacheManager : uses
+ GeneratorCore --> Transformer : uses
+ DPOStrategy --> Transformer : creates ref_model
+ GRPOStrategy --> Transformer : creates ref_model
+ Server --> GeneratorFactory : uses
ParallelSetup --> Trainer : enables
TrainConfig --> StrategyFactory : selects
- TrainCallback <|-- CheckpointCallback
- TrainCallback <|-- MetricLoggerCallback
- TrainCallback <|-- SchedulerCallback
- TrainContext --> Schedule : uses
+ ModelParameter --> Transformer : contains
+ ModelParameter --> BpeTokenizer : contains
+ ModelParameter --> ModelConfig : contains
+ GeneratorFactory --> GenerationRequest : uses
+ BaseDataset <|-- SEQDataset
+ BaseDataset <|-- SFTDataset
+ BaseDataset <|-- DPODataset
+ BaseDataset <|-- GRPODataset
+ DatasetFactory ..> BaseDataset : creates
+    MultiSegmentFetcher --> BaseSegmentFetcher : uses
+    BaseDataset --> MultiSegmentFetcher : uses
+ Transformer --> DecoderBlock : uses
+ Transformer --> RotaryEmbedding : uses
+ Transformer --> Embedding : uses
+ DecoderBlock --> GQA : uses
+ DecoderBlock --> MLP : uses
+ DecoderBlock --> RMSNorm : uses
+    Tokenizer <|-- BpeTokenizer
+ TrainContextBuilder --> ResumableDistributedSampler : creates
+ DataLoader --> BaseDataset : uses
+ ResumableDistributedSampler --> BaseDataset : samples
```
-### Design Pattern Summary
+### Module Overview
+
+| Module | Components | Description |
+|--------|------------|-------------|
+| **astrai.config** | ModelConfig, TrainConfig, ModelParameter | Configuration management |
+| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint, DataLoader | Dataset loading and management |
+| **astrai.model** | Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
+| **astrai.tokenize** | Tokenizer, BpeTokenizer | Tokenizer |
+| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
+| **astrai.inference** | GeneratorCore, LoopGenerator, StreamGenerator, BatchGenerator, EmbeddingEncoder, KVCacheManager, GeneratorFactory, Server, GenerationRequest | Inference service |
+| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
+
+### Design Patterns
| Pattern | Classes | Purpose |
|---------|---------|---------|
| **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
| **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
-| **Factory** | `StrategyFactory` | Decorator registration mechanism, dynamically create training strategies |
-| **Observer** | `TrainCallback` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
+| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `GeneratorFactory`, `CallbackFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, generators, and callbacks |
+| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Singleton** | `TrainContext` | Training process global state management |
+| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
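The generic `Registry` row above can be sketched as follows; the category and priority semantics are assumptions inferred from the one-line description, not the actual `BaseFactory` code.

```python
# Generic registry keyed by (category, name); on a name collision the
# registration with the higher priority wins. Illustrative sketch only.
class Registry:
    def __init__(self):
        self._items = {}  # (category, name) -> (priority, obj)

    def register(self, category, name, priority=0):
        def deco(obj):
            key = (category, name)
            if key not in self._items or priority >= self._items[key][0]:
                self._items[key] = (priority, obj)
            return obj
        return deco

    def get(self, category, name):
        return self._items[(category, name)][1]
```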
### Core Relationships
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
-4. **Inference Flow**: `Server` → `Generator` → `InferenceCore` → `Transformer`
+4. **Inference Flow**: `Server` → `GeneratorFactory` → `GeneratorCore` → `Transformer`; supports multiple generators (`LoopGenerator`, `StreamGenerator`, `BatchGenerator`)
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
+6. **Dataset Loading**: `DatasetFactory` creates datasets (`SEQDataset`, `SFTDataset`, `DPODataset`, `GRPODataset`) and supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
+7. **Checkpoint Management**: `Checkpoint` handles model state serialization and deserialization via safetensors
+8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (`CosineScheduler`, `SGDRScheduler`)
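The inference dispatch in the flow above can be sketched as follows; the factory signature is simplified relative to the diagram's `create(parameter, request)`, and the dispatch rules (stream flag, list query) are assumptions based on the generator names.

```python
# Picks a generator by request shape: streaming requests get a
# StreamGenerator, list queries a BatchGenerator, otherwise a LoopGenerator.
# Class bodies are stubs; names follow the class diagram.
class LoopGenerator: ...
class StreamGenerator: ...
class BatchGenerator: ...

class GenerationRequest:
    def __init__(self, query, stream=False, **sampling):
        self.query = query
        self.stream = stream
        self.sampling = sampling  # top_k, top_p, temperature, max_len, ...

class GeneratorFactory:
    @staticmethod
    def create(request):
        if request.stream:
            return StreamGenerator()
        if isinstance(request.query, list):  # batch of prompts
            return BatchGenerator()
        return LoopGenerator()
```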
## 3. Training Process