docs: update documentation class diagrams, etc.
This commit is contained in:
parent c94a246c71
commit 99b821ebf5
@@ -5,7 +5,7 @@ This document describes the data flow of the AstrAI project (a training and infe
 ## Overview
 
 AstrAI adopts a modular design with the following main components:
 
-- **Data Module** (`astrai/data/`): Dataset, sampler, tokenizer, serialization tools
+- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
 - **Model Module** (`astrai/model/`): Transformer model and its submodules
 - **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
 - **Inference Module** (`astrai/inference/`): Generation core, KV cache management, streaming generation
@@ -20,35 +20,39 @@ The data flow can generally be divided into two main lines: **Training Data Flow
 flowchart LR
     subgraph A[Data Preparation]
         direction TB
-        A1[Raw Text] --> A2[BBPE Tokenizer]
+        A1[Raw Text] --> A2[BpeTokenizer]
         A2 --> A3[Serialize to .h5 files]
-        A3 --> A4[Dataset Loading<br/>BaseDataset]
-        A4 --> A5[Resumable Distributed Sampler<br/>ResumableDistributedSampler]
-        A5 --> A6[DataLoader Batch Loading]
+        A3 --> A4[BaseDataset]
+        A4 --> A5[ResumableDistributedSampler]
+        A5 --> A6[DataLoader]
     end
 
-    subgraph B[Training Loop]
+    subgraph B[Training]
         direction TB
-        B1[Batch Data] --> B2[Training Strategy<br/>BaseStrategy]
-        B2 --> B3[Transformer Model]
-        B3 --> B4[Output logits]
-        B4 --> B5[Loss Calculation]
-        B5 --> B6[Backpropagation]
-        B6 --> B7[Optimizer Update]
-        B7 --> B8[Learning Rate Scheduler]
-        B8 --> B9[Checkpoint Save]
+        B1[Batch Data] --> B2[TrainContextBuilder]
+        B2 --> B3[TrainContext]
+        B3 --> B4[BaseStrategy]
+        B4 --> B5[Transformer]
+        B5 --> B6[Compute Loss]
+        B6 --> B7[Backward]
+        B7 --> B8[Optimizer]
+        B8 --> B9[LRScheduler]
+        B9 --> B10[CheckpointCallback]
     end
 
-    subgraph C[Inference Generation]
+    subgraph C[Inference]
         direction TB
-        C1[Checkpoint Loading] --> C2[Inference Model Loading]
-        C2 --> C3[Generation Core<br/>GeneratorCore]
-        C3 --> C4[Sampling Strategy<br/>Temperature/top-k/top-p]
-        C4 --> C5[Generate Next Token]
-        C5 --> C6[KV Cache Update]
-        C6 --> C7{Max Length Reached?}
-        C7 -->|No| C5
-        C7 -->|Yes| C8[Output Generated Text]
+        C1[Checkpoint] --> C2[ModelParameter]
+        C2 --> C3[Transformer + BpeTokenizer]
+        C3 --> C4[GenerationRequest + build_prompt]
+        C4 --> C5[GeneratorFactory]
+        C5 --> C6[GeneratorCore]
+        C6 --> C7[apply_sampling_strategies]
+        C7 --> C8[Transformer Forward]
+        C8 --> C9[KVCacheManager]
+        C9 --> C10{End Condition?}
+        C10 -->|No| C8
+        C10 -->|Yes| C11[Output Text]
     end
 
     A --> B
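The sampling stage named in the inference branch (`apply_sampling_strategies`) is conventionally temperature scaling followed by top-k and top-p (nucleus) filtering. A minimal pure-Python sketch of that convention (illustrative only; `sample_next` and its signature are not from the AstrAI source):

```python
import math
import random

def sample_next(logits, temperature=1.0, top_k=0, top_p=1.0, rng=random):
    """Temperature + top-k + top-p (nucleus) sampling over raw logits."""
    scaled = [l / max(temperature, 1e-8) for l in logits]
    # numerically stable softmax
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = sorted(((i, e / total) for i, e in enumerate(exps)),
                   key=lambda t: t[1], reverse=True)
    if top_k > 0:
        probs = probs[:top_k]              # keep the k most likely tokens
    if top_p < 1.0:
        kept, cum = [], 0.0
        for tok, p in probs:
            kept.append((tok, p))
            cum += p
            if cum >= top_p:               # smallest set covering top_p mass
                break
        probs = kept
    total = sum(p for _, p in probs)       # renormalise the surviving mass
    r = rng.random() * total
    for tok, p in probs:
        r -= p
        if r <= 0:
            return tok
    return probs[-1][0]
```

With `top_k=1` this degenerates to greedy decoding, which is a convenient sanity check.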
@@ -57,13 +61,14 @@ flowchart LR
 ## Detailed Module Descriptions
 
-### 1. Data Module
+### 1. Dataset Module
 
 #### 1.1 Tokenizer (`tokenizer.py`)
-- Implemented based on Byte-Level BPE (BBPE)
+- Implemented based on Byte-Level BPE (BPE)
 - Supports special tokens: `<|begin▁of▁sentence|>`, `<|end▁of▁sentence|>`, `<|▁pad▁|>`, `<|im▁start|>`, `<|im▁end|>`
 - Provides `encode`/`decode` methods for mutual conversion between text and token IDs
 - Learns vocabulary from corpus during training, saved as `.json` files
+- `BpeTrainer` class handles vocabulary training from corpus
 
 #### 1.2 Serialization (`serialization.py`)
 - **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
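The byte-level idea behind the tokenizer in section 1.1 can be illustrated with a toy round trip. This is not the project's `BpeTokenizer` (which additionally applies learned BPE merges loaded from `.json`); it only shows why byte-level vocabularies handle mixed Chinese/English text without out-of-vocabulary failures:

```python
class ByteTokenizer:
    """Toy byte-level tokenizer: raw UTF-8 bytes form the base vocabulary,
    with bos/eos/pad assigned ids above the byte range. Illustrative only;
    the real BpeTokenizer also applies learned BPE merges."""

    def __init__(self):
        self.bos_id, self.eos_id, self.pad_id = 256, 257, 258

    def encode(self, text, add_special_tokens=True):
        ids = list(text.encode("utf-8"))   # one id per UTF-8 byte
        if add_special_tokens:
            ids = [self.bos_id] + ids + [self.eos_id]
        return ids

    def decode(self, ids, skip_special_tokens=True):
        data = bytes(i for i in ids if i < 256)  # drop ids outside byte range
        return data.decode("utf-8", errors="replace")
```

Because every byte is in-vocabulary, any string round-trips losslessly, which is the property that makes byte-level BPE attractive for bilingual corpora.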
@@ -108,7 +113,7 @@ flowchart LR
 1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
 
 #### 3.3 Strategy (`strategy.py`)
-- **`BaseStrategy`**: Defines training strategy interface (such as `SeqStrategy`, `SFTStrategy`, `DPOStrategy`)
+- **`BaseStrategy`**: Defines training strategy interface (such as `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`)
 - Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
 - Created dynamically by `StrategyFactory` according to configuration
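The decorator-registration mechanism these bullets describe can be sketched as follows. The class and method names mirror the documented `StrategyFactory` API, but the body is a generic sketch of the pattern, not the AstrAI implementation:

```python
class StrategyFactory:
    """Decorator-based registry: strategies self-register under a name
    and are instantiated dynamically from the configured train_type."""
    _registry = {}

    @classmethod
    def register(cls, name):
        def decorator(strategy_cls):
            cls._registry[name] = strategy_cls   # record name -> class
            return strategy_cls                  # leave the class untouched
        return decorator

    @classmethod
    def create(cls, train_type, **kwargs):
        if train_type not in cls._registry:
            raise KeyError(f"unknown strategy: {train_type}")
        return cls._registry[train_type](**kwargs)

@StrategyFactory.register("sft")
class SFTStrategy:
    def __init__(self, lr=1e-5):
        self.lr = lr
```

Callers then write `StrategyFactory.create("sft", lr=2e-5)` and never import concrete strategy classes, which is what lets the train type live in configuration.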
@@ -130,14 +135,16 @@ flowchart LR
 #### 4.3 Generator (`generator.py`)
 - **`GenerationRequest`**: Encapsulates generation request parameters (top_k, top_p, temperature, max_len, query, history, etc.)
-- **`build_prompt`**: Converts query and history into ChatML format prompt string
+- **`build_prompt`** (from `chat_template.py`): Converts query and history into ChatML format prompt string
+- **`GeneratorCore`**: Base generator with generate_iterator and generate_loop methods
+- **`LoopGenerator`**, **`StreamGenerator`**, **`BatchGenerator`**: Different generation modes
 - **`pad_sequence`**: Pads input IDs to consistent length
 - Provides streaming and non-streaming generation interfaces
 
 ## Training Data Flow - Detailed Steps
 
 1. **Data Preparation**
-   - Raw text is converted to token ID sequences through BBPE tokenizer
+   - Raw text is converted to token ID sequences through BPE tokenizer
    - Token ID sequences (possibly with masks, labels, etc.) are saved by groups as `.h5` files
    - Files can contain multiple segments, each segment corresponds to a tensor
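The window/stride slicing implied by `BaseDataset`'s `window_size` and `stride` fields can be sketched in a few lines. These helper names are hypothetical, but the arithmetic is the standard way a flat token stream is turned into fixed-length training windows:

```python
def count_windows(total_length, window_size, stride):
    """Number of full windows a flat token stream yields (a __len__ analogue)."""
    if total_length < window_size:
        return 0
    return (total_length - window_size) // stride + 1

def get_window(tokens, index, window_size, stride):
    """The index-th training window (a __getitem__(index) analogue)."""
    begin = index * stride
    return tokens[begin:begin + window_size]
```

With `stride < window_size`, consecutive windows overlap, so every token can appear both mid-context and near a window boundary.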
@@ -152,7 +159,7 @@ flowchart LR
    - Batch data shape is `[batch_size, window_size]` (or varies according to specific dataset type)
 
 4. **Strategy Forward and Loss Calculation**
-   - Batch data is passed to strategy (such as `SeqStrategy`)
+   - Batch data is passed to strategy (such as `SEQStrategy`)
    - Strategy internally calls `Transformer` model, obtaining logits
    - Calculate cross-entropy loss (or DPO loss, etc.) according to task type
    - Return loss tensor
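The cross-entropy step above can be made concrete with a scalar-Python sketch: for next-token prediction the logits at position t are scored against the token at position t+1 (the usual shift-by-one). This is illustrative arithmetic, not the strategy's actual tensorised implementation:

```python
import math

def softmax(logits):
    m = max(logits)                         # subtract max for stability
    exps = [math.exp(l - m) for l in logits]
    s = sum(exps)
    return [e / s for e in exps]

def next_token_ce(logits_seq, token_ids):
    """Mean cross-entropy of predicting token t+1 from position t.
    logits_seq[t] holds the vocab logits produced after tokens[:t+1]."""
    losses = []
    for t in range(len(token_ids) - 1):
        probs = softmax(logits_seq[t])
        losses.append(-math.log(probs[token_ids[t + 1]]))
    return sum(losses) / len(losses)
```

A useful check: with uniform logits over a vocabulary of size V the loss is exactly ln V, which is also the expected loss of an untrained model.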
@@ -174,7 +181,7 @@ flowchart LR
    - Set model to evaluation mode (`model.eval()`), enable inference mode (`torch.inference_mode`)
 
 2. **Prompt Construction and Encoding**
-   - User query and history are converted to ChatML format string through `build_prompt`
+   - User query and history are converted to ChatML format string through `build_prompt` function in chat_template module
    - Tokenizer encodes prompt string to token ID sequence `input_ids`
    - For batch generation, use `pad_sequence` for padding
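A plausible shape for the ChatML assembly described here, using the `<|im▁start|>`/`<|im▁end|>` special tokens listed in section 1.1. This is a hypothetical re-implementation for illustration; the project's actual `build_prompt` lives in `chat_template.py` and may differ in details:

```python
IM_START, IM_END = "<|im▁start|>", "<|im▁end|>"

def build_prompt(query, history=(), system_prompt=None):
    """Assemble a ChatML-style prompt from system/history/query turns
    (hypothetical sketch of the chat_template.py behaviour)."""
    parts = []
    if system_prompt:
        parts.append(f"{IM_START}system\n{system_prompt}{IM_END}\n")
    for user_msg, assistant_msg in history:
        parts.append(f"{IM_START}user\n{user_msg}{IM_END}\n")
        parts.append(f"{IM_START}assistant\n{assistant_msg}{IM_END}\n")
    parts.append(f"{IM_START}user\n{query}{IM_END}\n")
    parts.append(f"{IM_START}assistant\n")   # model continues from here
    return "".join(parts)
```

Leaving the final assistant turn open-ended is what makes the model's generated tokens land inside the assistant block.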
@@ -8,13 +8,14 @@ Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, su
 ```mermaid
 classDiagram
     %% Configuration Classes
+    namespace astrai.config {
     class ModelConfig {
         +int vocab_size
         +int dim
         +int n_layers
         +float norm_eps
         +int dim_ffn
         +bool tie_weight
         +int max_len
         +float rope_theta
         +int n_heads
@@ -42,34 +43,169 @@ classDiagram
         +validate()
     }
 
     %% Data Classes
-    class Dataset {
+    class ModelParameter {
+        +nn.Module model
+        +BpeTokenizer tokenizer
+        +ModelConfig config
+        +save(instance, save_dir)
+        +load(load_dir, disable_init) ModelParameter
+        +to(*args, **kwargs)
     }
+    }
 
+    namespace astrai.dataset {
+    class BaseDataset {
+        +int window_size
+        +int stride
+        +MultiSegmentFetcher fetcher
+        +load(load_path)
+        +__getitem__(index)
         +__len__()
-        +__getitem__()
     }
 
+    class SEQDataset {
+        +__getitem__(index) Dict
+    }
 
+    class SFTDataset {
+        +__getitem__(index) Dict
+    }
 
+    class DPODataset {
+        +__getitem__(index) Dict
+    }
 
+    class GRPODataset {
+        +__getitem__(index) Dict
+    }
 
+    class BaseSegmentFetcher {
+        +List~Tensor~ segments
+        +List~int~ cum_lengths
+        +int total_length
+        +fetch_data(begin_idx, end_idx) Tensor
+    }
 
+    class MultiSegmentFetcher {
+        +Dict muti_fetchers
+        +List muti_keys
+        +key_fetch(begin_idx, end_idx, keys) Dict
+        +fetch_data(begin_idx, end_idx) Dict
+    }
 
+    class ResumableDistributedSampler {
+        +int start_epoch
+        +int start_iter
+    }
 
+    class DatasetFactory {
+        +Registry _registry
+        +register(name) decorator
+        +create(train_type, window_size, stride) BaseDataset
+        +load(train_type, load_path, window_size, stride) BaseDataset
+    }
 
     class Checkpoint {
         +dict state_dict
         +int epoch
         +int iteration
         +save(save_dir)
         +load(save_dir) Checkpoint
     }
 
-    class Tokenizer {
-        +encode(text) List[int]
-        +decode(ids) str
+    class DataLoader {
+        +Dataset dataset
+        +int batch_size
+        +Sampler sampler
+        +__iter__()
+        +__len__()
     }
+    }
 
     %% Model Classes
+    namespace astrai.model {
     class Transformer {
-        +forward(input_ids, mask) Dict
+        +ModelConfig config
+        +RotaryEmbedding rotary_embeding
+        +Embedding embed_tokens
+        +ModuleList layers
+        +RMSNorm norm
+        +Linear lm_head
+        +forward(input_ids, input_mask, persistent_key_values, start_pos) Dict
         +load_state_dict(state_dict)
         +state_dict()
     }
 
-    %% Trainer Classes
+    class DecoderBlock {
+        +GQA attention
+        +RMSNorm input_norm
+        +MLP mlp
+        +RMSNorm post_attention_norm
+        +forward(x, rotary_emb, attention_mask, kv_cache, start_pos) Tensor
+    }
 
+    class GQA {
+        +int n_heads
+        +int n_kv_heads
+        +int head_dim
+        +Linear q_proj, k_proj, v_proj, o_proj
+        +RMSNorm q_norm, k_norm
+        +forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor
+    }
 
+    class MLP {
+        +Linear up, gate, down
+        +forward(x) Tensor
+    }
 
+    class RMSNorm {
+        +Parameter weight
+        +float norm_eps
+        +forward(x) Tensor
+    }
 
+    class Linear {
+        +Parameter weight
+        +Parameter bias
+        +forward(x) Tensor
+    }
 
+    class RotaryEmbedding {
+        +int dim
+        +int max_len
+        +float base
+        +forward(x, start_pos) Tuple~Tensor, Tensor~
+    }
 
+    class Embedding {
+        +Parameter weight
+        +forward(x) Tensor
+    }
+    }
 
+    namespace astrai.tokenize {
+    class Tokenizer {
+        +encode(tokens, out_ids, add_special_tokens) List~int~
+        +decode(tokens, skip_special_tokens) str
+        +__len__() int
+    }
 
+    class BpeTokenizer {
+        +List~str~ stop_ids
+        +int bos_id
+        +int eos_id
+        +int pad_id
+        +encode(tokens, out_ids, add_special_tokens) List~int~
+        +decode(tokens, skip_special_tokens) str
+    }
+    }
 
+    namespace astrai.trainer {
     class Trainer {
         +TrainConfig train_config
         +List~TrainCallback~ callbacks
-        +train()
-        +_build_context() TrainContext
+        +train(checkpoint)
+        +_build_context(checkpoint) TrainContext
+        +_get_default_callbacks() List~TrainCallback~
     }
 
     class TrainContext {
@@ -81,11 +217,14 @@ classDiagram
         +Checkpoint checkpoint
         +int epoch
         +int iteration
+        +float loss
+        +int world_size
+        +int rank
     }
 
     class TrainContextBuilder {
         +TrainConfig config
-        +with_checkpoint(Checkpoint) TrainContextBuilder
+        +with_checkpoint(checkpoint) TrainContextBuilder
         +with_dataloader() TrainContextBuilder
         +with_strategy() TrainContextBuilder
         +build() TrainContext
@@ -98,11 +237,9 @@ classDiagram
     }
 
     class StrategyFactory {
-        +frozenset SUPPORTED_STRATEGIES
-        +Dict STRATEGY_MAP
+        +Registry _registry
+        +register(name) decorator
-        +create(model, train_type, device) BaseStrategy
-        +available_strategies() list
+        +create(model, train_type, device, **kwargs) BaseStrategy
     }
 
     class SEQStrategy {
@@ -130,29 +267,85 @@ classDiagram
         +compute_loss(batch) Tensor
     }
 
-    class TrainCallback {
-        +on_train_begin(trainer)
-        +on_train_end(trainer)
-        +on_epoch_begin(epoch, trainer)
-        +on_epoch_end(epoch, trainer)
-        +on_batch_begin(batch, trainer)
-        +on_batch_end(batch, trainer)
-    }
 
-    class Schedule {
+    class BaseScheduler {
         +get_lr() List~float~
         +step()
     }
 
-    %% Inference Classes
-    class Generator {
-        +generate(prompt, config) str
-        +generate_batch(prompts, config) List[str]
-        +stream_generate(prompt, config) Generator
+    class SchedulerFactory {
+        +Registry _registry
+        +register(name) decorator
+        +create(optimizer, schedule_type, **kwargs) BaseScheduler
     }
 
-    class InferenceCore {
-        +forward(input_ids) Dict
-        +apply_sampling_strategies()
+    class CosineScheduler {
+        +int warmup_steps
+        +int lr_decay_steps
+        +float min_rate
     }
 
+    class SGDRScheduler {
+        +int warmup_steps
+        +int cycle_length
+        +float min_rate
+        +int t_mult
+    }
 
+    class TrainCallback {
+        +on_train_begin(context)
+        +on_train_end(context)
+        +on_epoch_begin(context)
+        +on_epoch_end(context)
+        +on_step_begin(context)
+        +on_step_end(context)
+        +on_batch_begin(context)
+        +on_batch_end(context)
+        +on_error(context)
+    }
 
+    class CallbackFactory {
+        +Registry _registry
+        +register(name) decorator
+        +create(name, **kwargs) TrainCallback
+    }
+    }
 
+    namespace astrai.inference {
+    class GeneratorCore {
+        +ModelParameter parameter
+        +generate_iterator(input_ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos) Tuple~Tensor, int~
+        +generate_loop(input_ids, ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos, callback) List~int~
+    }
 
+    class LoopGenerator {
+        +generate(request) str
+    }
 
+    class StreamGenerator {
+        +generate(request) Generator
+    }
 
+    class BatchGenerator {
+        +generate(request) List~str~
+    }
 
+    class EmbeddingEncoder {
+        +encode(sentence) Tensor
+    }
 
+    class KVCacheManager {
+        +int batch_size
+        +Tuple kv_cache
+        +Tensor seq_mask
+        +get_kvcache() Tuple
+        +get_seq_mask() Tensor
+        +update(active_mask)
+        +reset(full_reset)
+    }
 
+    class GeneratorFactory {
+        +create(parameter, request) GeneratorCore
+        +create_encoder(parameter) EmbeddingEncoderCore
+    }
 
     class Server {
@@ -160,54 +353,123 @@ classDiagram
         +predict(request)
     }
 
-    %% Parallel Classes
+    class GenerationRequest {
+        +int top_k
+        +float top_p
+        +float temperature
+        +int max_len
+        +Union~str, List~str~~ query
+        +history Optional
+        +system_prompt Optional~str~
+        +stream bool
+    }
+    }
 
+    namespace astrai.parallel {
     class ParallelSetup {
+        +spawn_parallel_fn(fn, nprocs)
+        +setup_parallel(rank, world_size, backend, master_addr, master_port, device_type, device_ids)
     }
 
     class ColumnParallelLinear {
         +forward(x) Tensor
     }
 
     class RowParallelLinear {
         +forward(x) Tensor
     }
+    }
 
     %% Relationships
     TrainConfig --> ModelConfig : contains
-    TrainConfig --> Dataset : uses
+    TrainConfig --> BaseDataset : uses
     TrainConfig --> Transformer : uses
     Trainer --> TrainConfig : configures
     Trainer --> TrainContextBuilder : builds
     Trainer --> TrainCallback : manages
     TrainContextBuilder --> TrainContext : creates
     TrainContext --> Checkpoint : manages
+    TrainContext --> BaseStrategy : uses
+    TrainContext --> BaseScheduler : uses
+    StrategyFactory ..> BaseStrategy : creates
     BaseStrategy <|-- SEQStrategy
     BaseStrategy <|-- SFTStrategy
     BaseStrategy <|-- DPOStrategy
+    BaseStrategy <|-- GRPOStrategy
-    TrainContext --> BaseStrategy : uses
-    Generator --> InferenceCore : uses
-    Generator --> Transformer : uses
-    Server --> Generator : uses
+    SchedulerFactory ..> BaseScheduler : creates
+    BaseScheduler <|-- CosineScheduler
+    BaseScheduler <|-- SGDRScheduler
+    CallbackFactory ..> TrainCallback : creates
+    GeneratorFactory ..> GeneratorCore : creates
+    GeneratorCore <|-- LoopGenerator
+    GeneratorCore <|-- StreamGenerator
+    GeneratorCore <|-- BatchGenerator
+    GeneratorCore <|-- EmbeddingEncoder
+    LoopGenerator --> KVCacheManager : uses
+    StreamGenerator --> KVCacheManager : uses
+    BatchGenerator --> KVCacheManager : uses
+    GeneratorCore --> Transformer : uses
+    DPOStrategy --> Transformer : creates ref_model
+    GRPOStrategy --> Transformer : creates ref_model
+    Server --> GeneratorFactory : uses
     ParallelSetup --> Trainer : enables
     TrainConfig --> StrategyFactory : selects
     TrainCallback <|-- CheckpointCallback
     TrainCallback <|-- MetricLoggerCallback
     TrainCallback <|-- SchedulerCallback
-    TrainContext --> Schedule : uses
+    ModelParameter --> Transformer : contains
+    ModelParameter --> BpeTokenizer : contains
+    ModelParameter --> ModelConfig : contains
+    GeneratorFactory --> GenerationRequest : uses
+    BaseDataset <|-- SEQDataset
+    BaseDataset <|-- SFTDataset
+    BaseDataset <|-- DPODataset
+    BaseDataset <|-- GRPODataset
+    DatasetFactory ..> BaseDataset : creates
+    BaseSegmentFetcher --> MultiSegmentFetcher : used by
+    MultiSegmentFetcher --> BaseDataset : used by
+    Transformer --> DecoderBlock : uses
+    Transformer --> RotaryEmbedding : uses
+    Transformer --> Embedding : uses
+    DecoderBlock --> GQA : uses
+    DecoderBlock --> MLP : uses
+    DecoderBlock --> RMSNorm : uses
+    BpeTokenizer --> Tokenizer : inherits
+    TrainContextBuilder --> ResumableDistributedSampler : creates
+    DataLoader --> BaseDataset : uses
+    ResumableDistributedSampler --> BaseDataset : samples
 ```
 
-### Design Pattern Summary
+### Module Overview
 
+| Module | Components | Description |
+|--------|------------|-------------|
+| **astrai.config** | ModelConfig, TrainConfig, ModelParameter | Configuration management |
+| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint, DataLoader | Dataset loading and management |
+| **astrai.model** | Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
+| **astrai.tokenize** | Tokenizer, BpeTokenizer | Tokenizer |
+| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
+| **astrai.inference** | GeneratorCore, LoopGenerator, StreamGenerator, BatchGenerator, EmbeddingEncoder, KVCacheManager, GeneratorFactory, Server, GenerationRequest | Inference service |
+| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
 
+### Design Patterns
 
 | Pattern | Classes | Purpose |
 |---------|---------|---------|
 | **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
 | **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
-| **Factory** | `StrategyFactory` | Decorator registration mechanism, dynamically create training strategies |
-| **Observer** | `TrainCallback` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
+| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `GeneratorFactory`, `CallbackFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, generators, and callbacks |
+| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
 | **Singleton** | `TrainContext` | Training process global state management |
+| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
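The Builder row's "chain-building" can be made concrete with a small sketch. The class names echo the documented `TrainContextBuilder`/`TrainContext` API, but the body is a generic illustration of the pattern, not the AstrAI implementation:

```python
class TrainContext:
    def __init__(self, checkpoint=None, dataloader=None, strategy=None):
        self.checkpoint = checkpoint
        self.dataloader = dataloader
        self.strategy = strategy

class TrainContextBuilder:
    """Builder: each with_* step stores one component and returns self,
    so the context is initialised step by step in a fluent chain."""
    def __init__(self, config):
        self.config = config
        self._checkpoint = None
        self._dataloader = None
        self._strategy = None

    def with_checkpoint(self, checkpoint):
        self._checkpoint = checkpoint
        return self

    def with_dataloader(self, dataloader):
        self._dataloader = dataloader
        return self

    def with_strategy(self, strategy):
        self._strategy = strategy
        return self

    def build(self):
        return TrainContext(self._checkpoint, self._dataloader, self._strategy)
```

Usage reads as a single chain, e.g. `TrainContextBuilder(cfg).with_checkpoint(ckpt).with_strategy(strategy).build()`, and any step can be omitted when resuming is not needed.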
 
 ### Core Relationships
 
 1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
 2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
 3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
-4. **Inference Flow**: `Server` → `Generator` → `InferenceCore` → `Transformer`
+4. **Inference Flow**: `Server` → `GeneratorFactory` → `GeneratorCore` → `Transformer`, supports multiple generators (LoopGenerator, StreamGenerator, BatchGenerator)
 5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
+6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
+7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
+8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler)
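`CosineScheduler`'s fields (`warmup_steps`, `lr_decay_steps`, `min_rate`) fit the common warmup-then-cosine-decay formulation. The formula below is that common convention, assumed rather than read from the source:

```python
import math

def cosine_lr(step, base_lr, warmup_steps, lr_decay_steps, min_rate):
    """Linear warmup for warmup_steps, then cosine decay over
    lr_decay_steps down to base_lr * min_rate (assumed formulation)."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps     # linear ramp-up
    t = min(step - warmup_steps, lr_decay_steps) / lr_decay_steps
    cosine = 0.5 * (1 + math.cos(math.pi * t))         # 1 -> 0 over the decay
    return base_lr * (min_rate + (1 - min_rate) * cosine)
```

The rate peaks at `base_lr` exactly when warmup ends and flattens at `base_lr * min_rate` once `lr_decay_steps` have elapsed; SGDR-style schedulers restart this cosine every `cycle_length` steps instead of flattening.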
 
 ## 3. Training Process