docs: update documentation class diagram, etc.

ViperEkura 2026-04-04 18:11:36 +08:00
parent c94a246c71
commit 99b821ebf5
2 changed files with 453 additions and 184 deletions


@@ -5,7 +5,7 @@ This document describes the data flow of the AstrAI project (a training and infe

## Overview

AstrAI adopts a modular design with the following main components:

- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
- **Model Module** (`astrai/model/`): Transformer model and its submodules
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
- **Inference Module** (`astrai/inference/`): Generation core, KV cache management, streaming generation
@@ -20,35 +20,39 @@ The data flow can generally be divided into two main lines: **Training Data Flow

flowchart LR
    subgraph A[Data Preparation]
        direction TB
        A1[Raw Text] --> A2[BpeTokenizer]
        A2 --> A3[Serialize to .h5 files]
        A3 --> A4[BaseDataset]
        A4 --> A5[ResumableDistributedSampler]
        A5 --> A6[DataLoader]
    end
    subgraph B[Training]
        direction TB
        B1[Batch Data] --> B2[TrainContextBuilder]
        B2 --> B3[TrainContext]
        B3 --> B4[BaseStrategy]
        B4 --> B5[Transformer]
        B5 --> B6[Compute Loss]
        B6 --> B7[Backward]
        B7 --> B8[Optimizer]
        B8 --> B9[LRScheduler]
        B9 --> B10[CheckpointCallback]
    end
    subgraph C[Inference]
        direction TB
        C1[Checkpoint] --> C2[ModelParameter]
        C2 --> C3[Transformer + BpeTokenizer]
        C3 --> C4[GenerationRequest + build_prompt]
        C4 --> C5[GeneratorFactory]
        C5 --> C6[GeneratorCore]
        C6 --> C7[apply_sampling_strategies]
        C7 --> C8[Transformer Forward]
        C8 --> C9[KVCacheManager]
        C9 --> C10{End Condition?}
        C10 -->|No| C8
        C10 -->|Yes| C11[Output Text]
    end
    A --> B
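The data-preparation lane above ends in fixed-size training windows. A minimal sketch of `window_size`/`stride` indexing as a windowed dataset might implement it (the semantics are assumed from the diagram; the real `BaseDataset` reads tensors from `.h5` segments):

```python
# Hypothetical sliding-window dataset over a flat token stream.
# window_size: tokens per sample; stride: step between window starts.
class WindowDataset:
    def __init__(self, tokens, window_size, stride):
        self.tokens = tokens
        self.window_size = window_size
        self.stride = stride

    def __len__(self):
        if len(self.tokens) < self.window_size:
            return 0
        return (len(self.tokens) - self.window_size) // self.stride + 1

    def __getitem__(self, index):
        start = index * self.stride
        return self.tokens[start : start + self.window_size]
```

With `stride < window_size` the windows overlap, which is a common way to increase sample count from a fixed corpus.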
@@ -57,13 +61,14 @@ flowchart LR

## Detailed Module Descriptions

### 1. Dataset Module

#### 1.1 Tokenizer (`tokenizer.py`)

- Implemented based on Byte-Level BPE
- Supports special tokens: `<begin▁of▁sentence>`, `<end▁of▁sentence>`, `<▁pad▁>`, `<im▁start>`, `<im▁end>`
- Provides `encode`/`decode` methods for converting between text and token IDs
- Learns the vocabulary from a corpus during training; saved as `.json` files
- The `BpeTrainer` class handles vocabulary training from a corpus
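To illustrate the `encode`/`decode` contract, here is a minimal byte-level sketch. The toy vocabulary, id offsets, and special-token ids are hypothetical; the real `BpeTokenizer` applies learned merge rules from its `.json` vocabulary:

```python
# Toy byte-level tokenizer: every UTF-8 byte becomes one token id,
# shifted past three illustrative special-token ids (no BPE merges).
BOS_ID, EOS_ID, PAD_ID = 0, 1, 2
OFFSET = 3  # byte values 0..255 map to ids 3..258

def encode(text: str, add_special_tokens: bool = True) -> list[int]:
    """Map text to token ids at the byte level."""
    ids = [b + OFFSET for b in text.encode("utf-8")]
    return [BOS_ID] + ids + [EOS_ID] if add_special_tokens else ids

def decode(ids: list[int], skip_special_tokens: bool = True) -> str:
    """Invert encode; special ids are dropped when requested."""
    payload = [i - OFFSET for i in ids if i >= OFFSET]
    return bytes(payload).decode("utf-8", errors="replace")
```

Because the fallback alphabet is bytes rather than characters, any input round-trips without out-of-vocabulary failures, which is the main appeal of byte-level BPE.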
#### 1.2 Serialization (`serialization.py`)

- **`save_h5`**: Saves multiple tensors by group as HDF5 (`.h5`) files; each key corresponds to a list of tensors
@@ -108,7 +113,7 @@ flowchart LR

1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
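The hook order above can be sketched as an Observer-style dispatch. The trainer loop here is simplified to one epoch, and the hook names follow the lifecycle list (the real callbacks receive the shared training context):

```python
# A callback that records every hook call, so the lifecycle order
# can be inspected after a run.
class TraceCallback:
    def __init__(self):
        self.events = []

    def __getattr__(self, name):
        # Any on_* attribute resolves to a recorder for that hook.
        if name.startswith("on_"):
            return lambda context=None: self.events.append(name)
        raise AttributeError(name)

def run_one_epoch(callbacks, batches, context=None):
    for cb in callbacks: cb.on_epoch_begin(context)
    for batch in batches:
        for cb in callbacks: cb.on_batch_begin(context)
        # forward pass / loss calculation happens here
        for cb in callbacks: cb.on_batch_end(context)
        # gradient accumulation boundary reached
        for cb in callbacks: cb.on_step_begin(context)
        # optimizer update happens here
        for cb in callbacks: cb.on_step_end(context)
    for cb in callbacks: cb.on_epoch_end(context)
```

Keeping the loop ignorant of what each callback does is what lets checkpointing, logging, and scheduling be added or removed without touching the trainer.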
#### 3.3 Strategy (`strategy.py`)

- **`BaseStrategy`**: Defines the training strategy interface (implementations include `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`)
- A strategy receives batch data, executes the model forward pass and loss calculation, and returns the loss tensor
- Strategies are created dynamically by `StrategyFactory` according to the configuration
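The decorator-registration mechanism behind `StrategyFactory` can be sketched as follows (class and parameter names are illustrative, not the project's exact signatures):

```python
# Factory with a class-level registry populated by a decorator.
class StrategyFactory:
    _registry: dict = {}

    @classmethod
    def register(cls, name):
        def decorator(strategy_cls):
            cls._registry[name] = strategy_cls
            return strategy_cls
        return decorator

    @classmethod
    def create(cls, train_type, **kwargs):
        if train_type not in cls._registry:
            raise KeyError(f"unknown train_type: {train_type!r}")
        return cls._registry[train_type](**kwargs)

@StrategyFactory.register("seq")
class SEQStrategy:
    def __init__(self, model=None):
        self.model = model
```

Registering at class-definition time means adding a new training mode only requires defining the class with the decorator; no central if/else dispatch needs editing.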
@@ -130,14 +135,16 @@ flowchart LR

#### 4.3 Generator (`generator.py`)

- **`GenerationRequest`**: Encapsulates generation request parameters (top_k, top_p, temperature, max_len, query, history, etc.)
- **`build_prompt`** (from `chat_template.py`): Converts the query and history into a ChatML-format prompt string
- **`GeneratorCore`**: Base generator with `generate_iterator` and `generate_loop` methods
- **`LoopGenerator`**, **`StreamGenerator`**, **`BatchGenerator`**: Different generation modes
- **`pad_sequence`**: Pads input IDs to a consistent length
- Provides streaming and non-streaming generation interfaces
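A sketch of how `build_prompt` might assemble a ChatML-style prompt, using the special tokens listed in the tokenizer section. The exact template, the `system_prompt` parameter, and the trailing assistant header are assumptions based on the ChatML convention:

```python
# Hypothetical ChatML-style prompt assembly using the tokenizer's
# <im▁start>/<im▁end> special tokens.
def build_prompt(query, history=None, system_prompt=None):
    parts = []
    if system_prompt:
        parts.append(f"<im▁start>system\n{system_prompt}<im▁end>")
    for user_turn, assistant_turn in history or []:
        parts.append(f"<im▁start>user\n{user_turn}<im▁end>")
        parts.append(f"<im▁start>assistant\n{assistant_turn}<im▁end>")
    parts.append(f"<im▁start>user\n{query}<im▁end>")
    parts.append("<im▁start>assistant\n")  # model continues from here
    return "\n".join(parts)
```

Leaving the final assistant turn open-ended is what prompts the model to generate the next reply rather than a new user turn.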
## Training Data Flow - Detailed Steps

1. **Data Preparation**
   - Raw text is converted to token ID sequences by the BPE tokenizer
   - Token ID sequences (possibly with masks, labels, etc.) are saved by group as `.h5` files
   - Files can contain multiple segments; each segment corresponds to a tensor
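Multiple stored segments can be addressed as one flat stream by keeping cumulative lengths, which is what the fetcher classes' `cum_lengths`/`total_length` fields suggest. A list-based sketch (the real fetchers slice tensors loaded from HDF5):

```python
# Hypothetical fetcher: concatenates a [begin_idx, end_idx) slice that
# may span several segments, located via bisect on cumulative lengths.
import bisect

class SegmentFetcher:
    def __init__(self, segments):
        self.segments = segments
        self.cum_lengths = []
        total = 0
        for seg in segments:
            total += len(seg)
            self.cum_lengths.append(total)
        self.total_length = total

    def fetch_data(self, begin_idx, end_idx):
        out = []
        pos = begin_idx
        while pos < end_idx:
            seg_i = bisect.bisect_right(self.cum_lengths, pos)
            seg_start = self.cum_lengths[seg_i - 1] if seg_i else 0
            seg = self.segments[seg_i]
            take = min(end_idx, seg_start + len(seg)) - pos
            out.extend(seg[pos - seg_start : pos - seg_start + take])
            pos += take
        return out
```

The bisect lookup makes random access O(log n) in the number of segments, so sliding-window sampling stays cheap even with many files.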
@@ -152,7 +159,7 @@ flowchart LR

   - Batch data shape is `[batch_size, window_size]` (or varies by dataset type)
4. **Strategy Forward and Loss Calculation**
   - Batch data is passed to a strategy (such as `SEQStrategy`)
   - The strategy internally calls the `Transformer` model to obtain logits
   - Cross-entropy loss (or DPO loss, etc.) is calculated according to the task type
   - The loss tensor is returned
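The cross-entropy step can be written out in pure Python for illustration (the project computes this on torch tensors; `logits_per_pos` is one logits row per target position):

```python
# Mean negative log-likelihood of each target id under its logits row.
import math

def softmax(logits):
    m = max(logits)  # subtract max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def cross_entropy(logits_per_pos, target_ids):
    nll = 0.0
    for logits, target in zip(logits_per_pos, target_ids):
        nll -= math.log(softmax(logits)[target])
    return nll / len(target_ids)
```

With uniform logits over a vocabulary of size V the loss is log(V), which is a handy sanity check that training starts near the expected random-guessing baseline.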
@@ -174,7 +181,7 @@ flowchart LR

   - Set the model to evaluation mode (`model.eval()`) and enable inference mode (`torch.inference_mode`)

2. **Prompt Construction and Encoding**
   - The user query and history are converted to a ChatML-format string by the `build_prompt` function in the `chat_template` module
   - The tokenizer encodes the prompt string into a token ID sequence, `input_ids`
   - For batch generation, `pad_sequence` is used for padding
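The padding step is a small transform; a right-padding sketch over plain lists (the real `pad_sequence` produces tensors, and the pad id comes from the tokenizer):

```python
# Right-pad a batch of id sequences to the length of the longest one,
# returning the padded batch and a matching attention mask.
def pad_sequence(batch_ids, pad_id):
    max_len = max(len(ids) for ids in batch_ids)
    padded = [ids + [pad_id] * (max_len - len(ids)) for ids in batch_ids]
    mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in batch_ids]
    return padded, mask
```

The mask lets attention ignore pad positions, so sequences of different lengths can share one batch without contaminating each other.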


@@ -8,13 +8,14 @@ Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, su

```mermaid
classDiagram
    namespace astrai.config {
        class ModelConfig {
            +int vocab_size
            +int dim
            +int n_layers
            +float norm_eps
            +int dim_ffn
            +bool tie_weight
            +int max_len
            +float rope_theta
            +int n_heads
        %% @@ -42,34 +43,169 @@
            +validate()
        }
        class ModelParameter {
            +nn.Module model
            +BpeTokenizer tokenizer
            +ModelConfig config
            +save(instance, save_dir)
            +load(load_dir, disable_init) ModelParameter
            +to(*args, **kwargs)
        }
    }
    namespace astrai.dataset {
        class BaseDataset {
            +int window_size
            +int stride
            +MultiSegmentFetcher fetcher
            +load(load_path)
            +__getitem__(index)
            +__len__()
        }
        class SEQDataset {
            +__getitem__(index) Dict
        }
        class SFTDataset {
            +__getitem__(index) Dict
        }
        class DPODataset {
            +__getitem__(index) Dict
        }
        class GRPODataset {
            +__getitem__(index) Dict
        }
        class BaseSegmentFetcher {
            +List~Tensor~ segments
            +List~int~ cum_lengths
            +int total_length
            +fetch_data(begin_idx, end_idx) Tensor
        }
        class MultiSegmentFetcher {
            +Dict muti_fetchers
            +List muti_keys
            +key_fetch(begin_idx, end_idx, keys) Dict
            +fetch_data(begin_idx, end_idx) Dict
        }
        class ResumableDistributedSampler {
            +int start_epoch
            +int start_iter
        }
        class DatasetFactory {
            +Registry _registry
            +register(name) decorator
            +create(train_type, window_size, stride) BaseDataset
            +load(train_type, load_path, window_size, stride) BaseDataset
        }
        class Checkpoint {
            +dict state_dict
            +int epoch
            +int iteration
            +save(save_dir)
            +load(save_dir) Checkpoint
        }
        class DataLoader {
            +Dataset dataset
            +int batch_size
            +Sampler sampler
            +__iter__()
            +__len__()
        }
    }
    namespace astrai.model {
        class Transformer {
            +ModelConfig config
            +RotaryEmbedding rotary_embeding
            +Embedding embed_tokens
            +ModuleList layers
            +RMSNorm norm
            +Linear lm_head
            +forward(input_ids, input_mask, persistent_key_values, start_pos) Dict
            +load_state_dict(state_dict)
            +state_dict()
        }
        class DecoderBlock {
            +GQA attention
            +RMSNorm input_norm
            +MLP mlp
            +RMSNorm post_attention_norm
            +forward(x, rotary_emb, attention_mask, kv_cache, start_pos) Tensor
        }
        class GQA {
            +int n_heads
            +int n_kv_heads
            +int head_dim
            +Linear q_proj, k_proj, v_proj, o_proj
            +RMSNorm q_norm, k_norm
            +forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor
        }
        class MLP {
            +Linear up, gate, down
            +forward(x) Tensor
        }
        class RMSNorm {
            +Parameter weight
            +float norm_eps
            +forward(x) Tensor
        }
        class Linear {
            +Parameter weight
            +Parameter bias
            +forward(x) Tensor
        }
        class RotaryEmbedding {
            +int dim
            +int max_len
            +float base
            +forward(x, start_pos) Tuple~Tensor, Tensor~
        }
        class Embedding {
            +Parameter weight
            +forward(x) Tensor
        }
    }
    namespace astrai.tokenize {
        class Tokenizer {
            +encode(tokens, out_ids, add_special_tokens) List~int~
            +decode(tokens, skip_special_tokens) str
            +__len__() int
        }
        class BpeTokenizer {
            +List~str~ stop_ids
            +int bos_id
            +int eos_id
            +int pad_id
            +encode(tokens, out_ids, add_special_tokens) List~int~
            +decode(tokens, skip_special_tokens) str
        }
    }
    namespace astrai.trainer {
        class Trainer {
            +TrainConfig train_config
            +List~TrainCallback~ callbacks
            +train(checkpoint)
            +_build_context(checkpoint) TrainContext
            +_get_default_callbacks() List~TrainCallback~
        }
        class TrainContext {
        %% @@ -81,11 +217,14 @@
            +Checkpoint checkpoint
            +int epoch
            +int iteration
            +float loss
            +int world_size
            +int rank
        }
        class TrainContextBuilder {
            +TrainConfig config
            +with_checkpoint(checkpoint) TrainContextBuilder
            +with_dataloader() TrainContextBuilder
            +with_strategy() TrainContextBuilder
            +build() TrainContext
        %% @@ -98,11 +237,9 @@
        }
        class StrategyFactory {
            +Registry _registry
            +register(name) decorator
            +create(model, train_type, device, **kwargs) BaseStrategy
        }
        class SEQStrategy {
        %% @@ -130,29 +267,85 @@
            +compute_loss(batch) Tensor
        }
        class BaseScheduler {
            +get_lr() List~float~
            +step()
        }
        class SchedulerFactory {
            +Registry _registry
            +register(name) decorator
            +create(optimizer, schedule_type, **kwargs) BaseScheduler
        }
        class CosineScheduler {
            +int warmup_steps
            +int lr_decay_steps
            +float min_rate
        }
        class SGDRScheduler {
            +int warmup_steps
            +int cycle_length
            +float min_rate
            +int t_mult
        }
        class TrainCallback {
            +on_train_begin(context)
            +on_train_end(context)
            +on_epoch_begin(context)
            +on_epoch_end(context)
            +on_step_begin(context)
            +on_step_end(context)
            +on_batch_begin(context)
            +on_batch_end(context)
            +on_error(context)
        }
        class CallbackFactory {
            +Registry _registry
            +register(name) decorator
            +create(name, **kwargs) TrainCallback
        }
    }
    namespace astrai.inference {
        class GeneratorCore {
            +ModelParameter parameter
            +generate_iterator(input_ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos) Tuple~Tensor, int~
            +generate_loop(input_ids, ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos, callback) List~int~
        }
        class LoopGenerator {
            +generate(request) str
        }
        class StreamGenerator {
            +generate(request) Generator
        }
        class BatchGenerator {
            +generate(request) List~str~
        }
        class EmbeddingEncoder {
            +encode(sentence) Tensor
        }
        class KVCacheManager {
            +int batch_size
            +Tuple kv_cache
            +Tensor seq_mask
            +get_kvcache() Tuple
            +get_seq_mask() Tensor
            +update(active_mask)
            +reset(full_reset)
        }
        class GeneratorFactory {
            +create(parameter, request) GeneratorCore
            +create_encoder(parameter) EmbeddingEncoderCore
        }
        class Server {
        %% @@ -160,54 +353,123 @@
            +predict(request)
        }
        class GenerationRequest {
            +int top_k
            +float top_p
            +float temperature
            +int max_len
            +Union~str, List~str~~ query
            +Optional history
            +Optional~str~ system_prompt
            +bool stream
        }
    }
    namespace astrai.parallel {
        class ParallelSetup {
            +spawn_parallel_fn(fn, nprocs)
            +setup_parallel(rank, world_size, backend, master_addr, master_port, device_type, device_ids)
        }
        class ColumnParallelLinear {
            +forward(x) Tensor
        }
        class RowParallelLinear {
            +forward(x) Tensor
        }
    }
    %% Relationships
    TrainConfig --> ModelConfig : contains
    TrainConfig --> BaseDataset : uses
    TrainConfig --> Transformer : uses
    Trainer --> TrainConfig : configures
    Trainer --> TrainContextBuilder : builds
    Trainer --> TrainCallback : manages
    TrainContextBuilder --> TrainContext : creates
    TrainContext --> Checkpoint : manages
    TrainContext --> BaseStrategy : uses
    TrainContext --> BaseScheduler : uses
    StrategyFactory ..> BaseStrategy : creates
    BaseStrategy <|-- SEQStrategy
    BaseStrategy <|-- SFTStrategy
    BaseStrategy <|-- DPOStrategy
    BaseStrategy <|-- GRPOStrategy
    SchedulerFactory ..> BaseScheduler : creates
    BaseScheduler <|-- CosineScheduler
    BaseScheduler <|-- SGDRScheduler
    CallbackFactory ..> TrainCallback : creates
    GeneratorFactory ..> GeneratorCore : creates
    GeneratorCore <|-- LoopGenerator
    GeneratorCore <|-- StreamGenerator
    GeneratorCore <|-- BatchGenerator
    GeneratorCore <|-- EmbeddingEncoder
    LoopGenerator --> KVCacheManager : uses
    StreamGenerator --> KVCacheManager : uses
    BatchGenerator --> KVCacheManager : uses
    GeneratorCore --> Transformer : uses
    DPOStrategy --> Transformer : creates ref_model
    GRPOStrategy --> Transformer : creates ref_model
    Server --> GeneratorFactory : uses
    ParallelSetup --> Trainer : enables
    TrainConfig --> StrategyFactory : selects
    ModelParameter --> Transformer : contains
    ModelParameter --> BpeTokenizer : contains
    ModelParameter --> ModelConfig : contains
    GeneratorFactory --> GenerationRequest : uses
    BaseDataset <|-- SEQDataset
    BaseDataset <|-- SFTDataset
    BaseDataset <|-- DPODataset
    BaseDataset <|-- GRPODataset
    DatasetFactory ..> BaseDataset : creates
    MultiSegmentFetcher --> BaseSegmentFetcher : composes
    BaseDataset --> MultiSegmentFetcher : uses
    Tokenizer <|-- BpeTokenizer
    TrainContextBuilder --> ResumableDistributedSampler : creates
    DataLoader --> BaseDataset : uses
    ResumableDistributedSampler --> BaseDataset : samples
```
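The diagram's `generate_iterator` signature threads `temperature`/`top_k`/`top_p` through each decoding step. A pure-Python sketch of that filtering (illustrative only; the project applies these on torch tensors, and tie-breaking details may differ):

```python
# Temperature scaling, then top-k truncation, then nucleus (top-p)
# truncation, then sampling from the renormalized survivors.
import math, random

def sample_next_token(logits, temperature=1.0, top_k=0, top_p=1.0, rng=random):
    scaled = [x / max(temperature, 1e-6) for x in logits]
    m = max(scaled)
    probs = [math.exp(x - m) for x in scaled]
    total = sum(probs)
    probs = [p / total for p in probs]
    # rank token ids by probability, highest first
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    if top_k > 0:
        order = order[:top_k]
    # top-p: keep the smallest prefix whose cumulative mass reaches top_p
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    r = rng.random() * mass
    for i in kept:
        r -= probs[i]
        if r <= 0:
            return i
    return kept[-1]
```

Setting `top_k=1` (or a very small `top_p`) degenerates to greedy decoding, while higher temperature flattens the distribution before truncation.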
### Module Overview

| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig, ModelParameter | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint, DataLoader | Dataset loading and management |
| **astrai.model** | Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | Tokenizer, BpeTokenizer | Tokenization |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | GeneratorCore, LoopGenerator, StreamGenerator, BatchGenerator, EmbeddingEncoder, KVCacheManager, GeneratorFactory, Server, GenerationRequest | Inference service |
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallelism |

### Design Patterns

| Pattern | Classes | Purpose |
|---------|---------|---------|
| **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible switching between training strategies (SEQ/SFT/DPO/GRPO) |
| **Builder** | `TrainContextBuilder` | Chained construction of the training context, initializing components step by step |
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `GeneratorFactory`, `CallbackFactory` | Decorator-based registration; dynamically creates strategies, schedulers, datasets, generators, and callbacks |
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback hooks for monitoring the training process (checkpointing, early stopping, metrics) |
| **Singleton** | `TrainContext` | Global state management for the training process |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
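The Builder row can be made concrete with a simplified chained `TrainContextBuilder`. Component wiring is stubbed here (the real builder constructs the dataloader, strategy, optimizer, and so on from the config):

```python
# Builder pattern: each with_* step configures one component and
# returns self, so calls can be chained before build().
class TrainContext:
    def __init__(self, checkpoint=None, dataloader=None, strategy=None):
        self.checkpoint = checkpoint
        self.dataloader = dataloader
        self.strategy = strategy

class TrainContextBuilder:
    def __init__(self, config):
        self.config = config
        self._checkpoint = None
        self._dataloader = None
        self._strategy = None

    def with_checkpoint(self, checkpoint):
        self._checkpoint = checkpoint
        return self

    def with_dataloader(self):
        self._dataloader = ["batch0", "batch1"]  # stub for a real DataLoader
        return self

    def with_strategy(self):
        self._strategy = self.config.get("train_type", "seq")
        return self

    def build(self):
        return TrainContext(self._checkpoint, self._dataloader, self._strategy)

ctx = TrainContextBuilder({"train_type": "sft"}).with_dataloader().with_strategy().build()
```

Skipping a `with_*` step simply leaves that component unset, which is how optional pieces such as a resume checkpoint fit naturally into the chain.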
### Core Relationships

1. **Configuration → Training**: `TrainConfig` contains `ModelConfig` and holds references to the model, dataset, optimizer, etc.
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`; uses `BaseStrategy` to compute the loss
3. **Strategy Selection**: `StrategyFactory` creates the corresponding strategy instance based on `train_type`
4. **Inference Flow**: `Server` → `GeneratorFactory` → `GeneratorCore` → `Transformer`; supports multiple generators (`LoopGenerator`, `StreamGenerator`, `BatchGenerator`)
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (`SEQDataset`, `SFTDataset`, `DPODataset`, `GRPODataset`); HDF5 loading is supported via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (`CosineScheduler`, `SGDRScheduler`)
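For item 8, the curve implied by `CosineScheduler`'s `warmup_steps`/`lr_decay_steps`/`min_rate` fields can be sketched as below. The formula follows the standard warmup-plus-cosine-decay recipe, and treating `min_rate` as a floor fraction of the base learning rate is an assumption:

```python
# Linear warmup to base_lr, then cosine decay toward base_lr * min_rate.
import math

def cosine_lr(step, base_lr, warmup_steps, lr_decay_steps, min_rate=0.1):
    if step < warmup_steps:  # linear warmup phase
        return base_lr * (step + 1) / warmup_steps
    t = min(step - warmup_steps, lr_decay_steps) / lr_decay_steps
    cos = 0.5 * (1.0 + math.cos(math.pi * t))  # 1 -> 0 over the decay window
    return base_lr * (min_rate + (1.0 - min_rate) * cos)
```

An SGDR-style scheduler would instead restart this cosine at every `cycle_length` boundary, optionally stretching each cycle by `t_mult`.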
## 3. Training Process ## 3. Training Process