docs: update documentation class diagrams and related docs

parent c94a246c71
commit 99b821ebf5
@@ -5,7 +5,7 @@ This document describes the data flow of the AstrAI project (a training and infe
 
 ## Overview
 
 AstrAI adopts a modular design with the following main components:
 
-- **Data Module** (`astrai/data/`): Dataset, sampler, tokenizer, serialization tools
+- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
 - **Model Module** (`astrai/model/`): Transformer model and its submodules
 - **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
 - **Inference Module** (`astrai/inference/`): Generation core, KV cache management, streaming generation
@@ -20,35 +20,39 @@ The data flow can generally be divided into two main lines: **Training Data Flow
 flowchart LR
     subgraph A[Data Preparation]
         direction TB
-        A1[Raw Text] --> A2[BBPE Tokenizer]
+        A1[Raw Text] --> A2[BpeTokenizer]
         A2 --> A3[Serialize to .h5 files]
-        A3 --> A4[Dataset Loading<br/>BaseDataset]
+        A3 --> A4[BaseDataset]
-        A4 --> A5[Resumable Distributed Sampler<br/>ResumableDistributedSampler]
+        A4 --> A5[ResumableDistributedSampler]
-        A5 --> A6[DataLoader Batch Loading]
+        A5 --> A6[DataLoader]
     end
 
-    subgraph B[Training Loop]
+    subgraph B[Training]
         direction TB
-        B1[Batch Data] --> B2[Training Strategy<br/>BaseStrategy]
+        B1[Batch Data] --> B2[TrainContextBuilder]
-        B2 --> B3[Transformer Model]
+        B2 --> B3[TrainContext]
-        B3 --> B4[Output logits]
+        B3 --> B4[BaseStrategy]
-        B4 --> B5[Loss Calculation]
+        B4 --> B5[Transformer]
-        B5 --> B6[Backpropagation]
+        B5 --> B6[Compute Loss]
-        B6 --> B7[Optimizer Update]
+        B6 --> B7[Backward]
-        B7 --> B8[Learning Rate Scheduler]
+        B7 --> B8[Optimizer]
-        B8 --> B9[Checkpoint Save]
+        B8 --> B9[LRScheduler]
+        B9 --> B10[CheckpointCallback]
     end
 
-    subgraph C[Inference Generation]
+    subgraph C[Inference]
         direction TB
-        C1[Checkpoint Loading] --> C2[Inference Model Loading]
+        C1[Checkpoint] --> C2[ModelParameter]
-        C2 --> C3[Generation Core<br/>GeneratorCore]
+        C2 --> C3[Transformer + BpeTokenizer]
-        C3 --> C4[Sampling Strategy<br/>Temperature/top-k/top-p]
+        C3 --> C4[GenerationRequest + build_prompt]
-        C4 --> C5[Generate Next Token]
+        C4 --> C5[GeneratorFactory]
-        C5 --> C6[KV Cache Update]
+        C5 --> C6[GeneratorCore]
-        C6 --> C7{Max Length Reached?}
+        C6 --> C7[apply_sampling_strategies]
-        C7 -->|No| C5
+        C7 --> C8[Transformer Forward]
-        C7 -->|Yes| C8[Output Generated Text]
+        C8 --> C9[KVCacheManager]
+        C9 --> C10{End Condition?}
+        C10 -->|No| C8
+        C10 -->|Yes| C11[Output Text]
     end
 
 A --> B
@@ -57,13 +61,14 @@ flowchart LR
 
 ## Detailed Module Descriptions
 
-### 1. Data Module
+### 1. Dataset Module
 
 #### 1.1 Tokenizer (`tokenizer.py`)
-- Implemented based on Byte-Level BPE (BBPE)
+- Implemented based on Byte-Level BPE (BPE)
 - Supports special tokens: `<|begin▁of▁sentence|>`, `<|end▁of▁sentence|>`, `<|▁pad▁|>`, `<|im▁start|>`, `<|im▁end|>`
 - Provides `encode`/`decode` methods for converting between text and token IDs
 - Learns vocabulary from the corpus during training; vocabularies are saved as `.json` files
+- The `BpeTrainer` class handles vocabulary training from the corpus
 
 #### 1.2 Serialization (`serialization.py`)
 - **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`); each key corresponds to a list of tensors
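The byte-level BPE mechanism described in §1.1 can be sketched as a toy round trip. This is illustrative only, not AstrAI's `BpeTokenizer`: the merge-table representation and function names are assumptions.

```python
# Toy byte-level BPE: every raw byte is a base token (IDs 0-255); learned
# merges map frequent adjacent pairs to new IDs (256, 257, ...).

def bpe_encode(text, merges):
    ids = list(text.encode("utf-8"))  # byte-level: start from raw UTF-8 bytes
    while True:
        # pick the earliest-learned merge (lowest new ID) present in ids
        best = None
        for i in range(len(ids) - 1):
            pair = (ids[i], ids[i + 1])
            if pair in merges and (best is None or merges[pair] < merges[best]):
                best = pair
        if best is None:
            return ids
        new_id, out, i = merges[best], [], 0
        while i < len(ids):
            if i < len(ids) - 1 and (ids[i], ids[i + 1]) == best:
                out.append(new_id)
                i += 2
            else:
                out.append(ids[i])
                i += 1
        ids = out

def bpe_decode(ids, merges):
    # invert merges, expanding merged IDs back into byte sequences
    inv = {v: k for k, v in merges.items()}
    out_bytes, stack = [], list(reversed(ids))
    while stack:
        t = stack.pop()
        if t in inv:
            a, b = inv[t]
            stack.extend([b, a])  # a comes off the stack first
        else:
            out_bytes.append(t)
    return bytes(out_bytes).decode("utf-8")

merges = {(104, 105): 256}  # merge the byte pair b"hi" into token 256
ids = bpe_encode("hi hi", merges)
```

Because decoding inverts the merge table, `bpe_decode(ids, merges)` recovers the original text exactly.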
@@ -108,7 +113,7 @@ flowchart LR
 
 1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
 
 #### 3.3 Strategy (`strategy.py`)
-- **`BaseStrategy`**: Defines the training strategy interface (such as `SeqStrategy`, `SFTStrategy`, `DPOStrategy`)
+- **`BaseStrategy`**: Defines the training strategy interface (such as `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`)
 - A strategy receives batch data, runs the model forward pass and loss calculation, and returns the loss tensor
 - Created dynamically by `StrategyFactory` according to configuration
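The hook order listed above can be exercised with a minimal driver. The hook names follow the document's lifecycle list; the `Recorder` class and the loop shape (two batches, a gradient-accumulation boundary every two batches) are illustrative assumptions.

```python
class Recorder:
    """Records which hooks fire, standing in for a real TrainCallback."""
    def __init__(self):
        self.calls = []
    def __getattr__(self, name):
        # any on_* hook becomes a recording no-op
        if name.startswith("on_"):
            return lambda *a, **k: self.calls.append(name)
        raise AttributeError(name)

def run_epoch(cb, n_batches=2, accum_steps=2):
    # fire hooks in the documented order for one epoch
    cb.on_train_begin()
    cb.on_epoch_begin()
    for i in range(n_batches):
        cb.on_batch_begin()
        # ... forward pass and loss calculation would happen here ...
        cb.on_batch_end()
        if (i + 1) % accum_steps == 0:  # gradient accumulation boundary
            cb.on_step_begin()
            # ... optimizer update would happen here ...
            cb.on_step_end()
    cb.on_epoch_end()
    cb.on_train_end()

rec = Recorder()
run_epoch(rec)
```

With accumulation, `on_step_*` fires once per optimizer update rather than once per batch, which is exactly the distinction the numbered list draws between steps 5 and 7-9.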
@@ -130,14 +135,16 @@ flowchart LR
 
 #### 4.3 Generator (`generator.py`)
 - **`GenerationRequest`**: Encapsulates generation request parameters (top_k, top_p, temperature, max_len, query, history, etc.)
-- **`build_prompt`**: Converts query and history into a ChatML-format prompt string
+- **`build_prompt`** (from `chat_template.py`): Converts query and history into a ChatML-format prompt string
+- **`GeneratorCore`**: Base generator with `generate_iterator` and `generate_loop` methods
+- **`LoopGenerator`**, **`StreamGenerator`**, **`BatchGenerator`**: Different generation modes
 - **`pad_sequence`**: Pads input IDs to a consistent length
 - Provides streaming and non-streaming generation interfaces
 
 ## Training Data Flow - Detailed Steps
 
 1. **Data Preparation**
-   - Raw text is converted to token ID sequences by the BBPE tokenizer
+   - Raw text is converted to token ID sequences by the BPE tokenizer
    - Token ID sequences (possibly with masks, labels, etc.) are saved by groups as `.h5` files
    - Files can contain multiple segments; each segment corresponds to a tensor
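The sampling parameters carried by `GenerationRequest` (temperature, top_k, top_p) are typically applied as successive filters over the logits. A pure-Python sketch of that common pipeline follows; it is not AstrAI's actual `apply_sampling_strategies`, whose details are unknown here.

```python
import math

def filter_logits(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return renormalized probabilities after temperature / top-k / top-p."""
    scaled = [x / temperature for x in logits]        # temperature scaling
    hi = max(scaled)
    exps = [math.exp(x - hi) for x in scaled]         # stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    keep = set(order) if top_k <= 0 else set(order[:top_k])
    # nucleus (top-p): smallest high-probability prefix whose mass reaches top_p
    cum, nucleus = 0.0, set()
    for i in order:
        if i in keep:
            nucleus.add(i)
            cum += probs[i]
            if cum >= top_p:
                break
    masses = [p if i in nucleus else 0.0 for i, p in enumerate(probs)]
    z = sum(masses)
    return [m / z for m in masses]

probs = filter_logits([2.0, 1.0, 0.1], top_k=2)  # third candidate is filtered out
```

The surviving candidates are renormalized so the result is again a distribution to sample the next token from.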
@@ -152,7 +159,7 @@ flowchart LR
 
    - Batch data shape is `[batch_size, window_size]` (or varies according to the specific dataset type)
 
 4. **Strategy Forward and Loss Calculation**
-   - Batch data is passed to a strategy (such as `SeqStrategy`)
+   - Batch data is passed to a strategy (such as `SEQStrategy`)
    - The strategy internally calls the `Transformer` model to obtain logits
    - Cross-entropy loss (or DPO loss, etc.) is calculated according to the task type
    - The loss tensor is returned
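Step 4's cross-entropy calculation reduces, per token, to `-log softmax(logits)[target]`. A dependency-free sketch (the real strategies compute this with torch over batched tensors):

```python
import math

def cross_entropy(logits, target):
    """Token-level cross-entropy from raw logits, via log-sum-exp."""
    hi = max(logits)
    log_z = hi + math.log(sum(math.exp(x - hi) for x in logits))  # log of softmax denominator
    return log_z - logits[target]

def sequence_loss(logit_rows, targets):
    # mean over positions, as a language-model loss does
    return sum(cross_entropy(r, t) for r, t in zip(logit_rows, targets)) / len(targets)

loss = cross_entropy([0.0, 0.0], 0)  # uniform over 2 classes -> ln 2
```

The log-sum-exp form avoids overflow for large logits, which is why frameworks compute the loss from logits directly rather than from explicit softmax probabilities.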
@@ -174,7 +181,7 @@ flowchart LR
 
    - Set the model to evaluation mode (`model.eval()`) and enable inference mode (`torch.inference_mode`)
 
 2. **Prompt Construction and Encoding**
-   - User query and history are converted to a ChatML-format string through `build_prompt`
+   - User query and history are converted to a ChatML-format string through the `build_prompt` function in the `chat_template` module
    - The tokenizer encodes the prompt string into the token ID sequence `input_ids`
    - For batch generation, `pad_sequence` is used for padding
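A plausible shape for the ChatML assembly that `build_prompt` performs, using the `<|im▁start|>`/`<|im▁end|>` special tokens listed in the tokenizer section. The signature and turn layout are assumptions, not AstrAI's actual implementation.

```python
def build_prompt(query, history=None, system_prompt=None):
    """Assemble a ChatML-style prompt from system prompt, history, and query."""
    parts = []
    if system_prompt:
        parts.append(f"<|im▁start|>system\n{system_prompt}<|im▁end|>")
    for user_turn, assistant_turn in history or []:
        parts.append(f"<|im▁start|>user\n{user_turn}<|im▁end|>")
        parts.append(f"<|im▁start|>assistant\n{assistant_turn}<|im▁end|>")
    parts.append(f"<|im▁start|>user\n{query}<|im▁end|>")
    # open the assistant turn so the model continues from here
    parts.append("<|im▁start|>assistant\n")
    return "\n".join(parts)

prompt = build_prompt("Hi", system_prompt="Be brief.")
```

Leaving the final assistant turn unclosed is what makes generation a pure continuation: the model stops when it emits the end-of-turn token itself.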
@@ -8,13 +8,14 @@ Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, su
 
 ```mermaid
 classDiagram
-    %% Configuration Classes
+    namespace astrai.config {
     class ModelConfig {
         +int vocab_size
         +int dim
         +int n_layers
         +float norm_eps
         +int dim_ffn
+        +bool tie_weight
         +int max_len
         +float rope_theta
         +int n_heads
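The `ModelConfig` fields in the diagram map naturally onto a dataclass. The defaults, the example values, and the `validate()` checks below are assumptions for illustration, not AstrAI's real config.

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    vocab_size: int
    dim: int
    n_layers: int
    n_heads: int
    dim_ffn: int
    max_len: int
    norm_eps: float = 1e-5
    rope_theta: float = 10000.0
    tie_weight: bool = True  # share embedding and lm_head weights

    def validate(self):
        # attention heads must evenly split the hidden dimension
        assert self.dim % self.n_heads == 0, "dim must be divisible by n_heads"
        assert self.vocab_size > 0 and self.max_len > 0

cfg = ModelConfig(vocab_size=32000, dim=2048, n_layers=24, n_heads=16,
                  dim_ffn=5632, max_len=4096)
cfg.validate()
```

Validating once at construction time catches shape mismatches before any tensor is allocated.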
@@ -42,34 +43,169 @@ classDiagram
         +validate()
     }
 
-    %% Data Classes
-    class Dataset {
+    class ModelParameter {
+        +nn.Module model
+        +BpeTokenizer tokenizer
+        +ModelConfig config
+        +save(instance, save_dir)
+        +load(load_dir, disable_init) ModelParameter
+        +to(*args, **kwargs)
+    }
+    }
+
+    namespace astrai.dataset {
+    class BaseDataset {
+        +int window_size
+        +int stride
+        +MultiSegmentFetcher fetcher
+        +load(load_path)
+        +__getitem__(index)
         +__len__()
-        +__getitem__()
     }
+
+    class SEQDataset {
+        +__getitem__(index) Dict
+    }
+
+    class SFTDataset {
+        +__getitem__(index) Dict
+    }
+
+    class DPODataset {
+        +__getitem__(index) Dict
+    }
+
+    class GRPODataset {
+        +__getitem__(index) Dict
+    }
+
+    class BaseSegmentFetcher {
+        +List~Tensor~ segments
+        +List~int~ cum_lengths
+        +int total_length
+        +fetch_data(begin_idx, end_idx) Tensor
+    }
+
+    class MultiSegmentFetcher {
+        +Dict muti_fetchers
+        +List muti_keys
+        +key_fetch(begin_idx, end_idx, keys) Dict
+        +fetch_data(begin_idx, end_idx) Dict
+    }
+
+    class ResumableDistributedSampler {
+        +int start_epoch
+        +int start_iter
+    }
+
+    class DatasetFactory {
+        +Registry _registry
+        +register(name) decorator
+        +create(train_type, window_size, stride) BaseDataset
+        +load(train_type, load_path, window_size, stride) BaseDataset
+    }
 
     class Checkpoint {
         +dict state_dict
         +int epoch
         +int iteration
+        +save(save_dir)
+        +load(save_dir) Checkpoint
     }
 
-    class Tokenizer {
-        +encode(text) List[int]
-        +decode(ids) str
+    class DataLoader {
+        +Dataset dataset
+        +int batch_size
+        +Sampler sampler
+        +__iter__()
+        +__len__()
+    }
     }
 
-    %% Model Classes
+    namespace astrai.model {
     class Transformer {
-        +forward(input_ids, mask) Dict
+        +ModelConfig config
+        +RotaryEmbedding rotary_embeding
+        +Embedding embed_tokens
+        +ModuleList layers
+        +RMSNorm norm
+        +Linear lm_head
+        +forward(input_ids, input_mask, persistent_key_values, start_pos) Dict
+        +load_state_dict(state_dict)
+        +state_dict()
     }
 
-    %% Trainer Classes
+    class DecoderBlock {
+        +GQA attention
+        +RMSNorm input_norm
+        +MLP mlp
+        +RMSNorm post_attention_norm
+        +forward(x, rotary_emb, attention_mask, kv_cache, start_pos) Tensor
+    }
+
+    class GQA {
+        +int n_heads
+        +int n_kv_heads
+        +int head_dim
+        +Linear q_proj, k_proj, v_proj, o_proj
+        +RMSNorm q_norm, k_norm
+        +forward(x, rotary_emb, mask, kv_cache, start_pos) Tensor
+    }
+
+    class MLP {
+        +Linear up, gate, down
+        +forward(x) Tensor
+    }
+
+    class RMSNorm {
+        +Parameter weight
+        +float norm_eps
+        +forward(x) Tensor
+    }
+
+    class Linear {
+        +Parameter weight
+        +Parameter bias
+        +forward(x) Tensor
+    }
+
+    class RotaryEmbedding {
+        +int dim
+        +int max_len
+        +float base
+        +forward(x, start_pos) Tuple~Tensor, Tensor~
+    }
+
+    class Embedding {
+        +Parameter weight
+        +forward(x) Tensor
+    }
+    }
+
+    namespace astrai.tokenize {
+    class Tokenizer {
+        +encode(tokens, out_ids, add_special_tokens) List~int~
+        +decode(tokens, skip_special_tokens) str
+        +__len__() int
+    }
+
+    class BpeTokenizer {
+        +List~str~ stop_ids
+        +int bos_id
+        +int eos_id
+        +int pad_id
+        +encode(tokens, out_ids, add_special_tokens) List~int~
+        +decode(tokens, skip_special_tokens) str
+    }
+    }
+
+    namespace astrai.trainer {
     class Trainer {
         +TrainConfig train_config
         +List~TrainCallback~ callbacks
-        +train()
+        +train(checkpoint)
-        +_build_context() TrainContext
+        +_build_context(checkpoint) TrainContext
+        +_get_default_callbacks() List~TrainCallback~
     }
 
     class TrainContext {
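The `BaseSegmentFetcher` entry above (segments, `cum_lengths`, `fetch_data`) suggests cumulative-offset indexing across logically concatenated segments. A pure-Python sketch, with lists standing in for tensors; the real class's internals are assumptions here.

```python
import bisect

class BaseSegmentFetcher:
    """Slice a logical [begin_idx, end_idx) range across segment boundaries."""
    def __init__(self, segments):
        self.segments = segments
        self.cum_lengths = []
        total = 0
        for seg in segments:
            total += len(seg)
            self.cum_lengths.append(total)  # running end offset of each segment
        self.total_length = total

    def fetch_data(self, begin_idx, end_idx):
        out, idx = [], begin_idx
        while idx < end_idx:
            # first segment whose end offset is strictly past idx
            seg_i = bisect.bisect_right(self.cum_lengths, idx)
            seg_start = 0 if seg_i == 0 else self.cum_lengths[seg_i - 1]
            take = min(end_idx, self.cum_lengths[seg_i]) - idx
            local = idx - seg_start
            out.extend(self.segments[seg_i][local:local + take])
            idx += take
        return out

fetcher = BaseSegmentFetcher([[1, 2, 3], [4, 5], [6, 7, 8, 9]])
window = fetcher.fetch_data(2, 7)  # spans all three segments
```

This is what lets a sliding-window dataset index `[begin, begin + window_size)` without ever physically concatenating the `.h5` segments.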
@@ -81,11 +217,14 @@ classDiagram
         +Checkpoint checkpoint
         +int epoch
         +int iteration
+        +float loss
+        +int world_size
+        +int rank
     }
 
     class TrainContextBuilder {
         +TrainConfig config
-        +with_checkpoint(Checkpoint) TrainContextBuilder
+        +with_checkpoint(checkpoint) TrainContextBuilder
         +with_dataloader() TrainContextBuilder
         +with_strategy() TrainContextBuilder
         +build() TrainContext
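The `with_*` methods returning `TrainContextBuilder` signal a chained builder. A minimal sketch, with a dict standing in for the real `TrainContext` and placeholder objects standing in for the components the real builder wires up:

```python
class TrainContextBuilder:
    def __init__(self, config):
        self._ctx = {"config": config}

    def with_checkpoint(self, checkpoint):
        self._ctx["checkpoint"] = checkpoint
        return self  # returning self is what enables chaining

    def with_dataloader(self):
        self._ctx["dataloader"] = object()  # placeholder component
        return self

    def with_strategy(self):
        self._ctx["strategy"] = object()  # placeholder component
        return self

    def build(self):
        return dict(self._ctx)  # the assembled context

ctx = (TrainContextBuilder({"lr": 3e-4})
       .with_checkpoint(None)
       .with_dataloader()
       .with_strategy()
       .build())
```

The payoff of the pattern is ordering: each `with_*` step can depend on what earlier steps set up (e.g. a sampler resuming from the checkpoint) before `build()` freezes the context.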
@@ -98,11 +237,9 @@ classDiagram
     }
 
     class StrategyFactory {
-        +frozenset SUPPORTED_STRATEGIES
-        +Dict STRATEGY_MAP
+        +Registry _registry
         +register(name) decorator
-        +create(model, train_type, device) BaseStrategy
+        +create(model, train_type, device, **kwargs) BaseStrategy
-        +available_strategies() list
     }
 
     class SEQStrategy {
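The `_registry` plus `register(name) decorator` pattern shared by the factories can be sketched as follows. The `Registry` internals are assumptions; per the design-pattern table, the project's own version also supports categories and priorities.

```python
class Registry:
    """Decorator-based class registry, as used by the *Factory classes."""
    def __init__(self):
        self._items = {}

    def register(self, name):
        def decorator(cls):
            self._items[name] = cls
            return cls  # leave the class usable as-is
        return decorator

    def create(self, name, **kwargs):
        return self._items[name](**kwargs)

registry = Registry()

@registry.register("seq")
class SEQStrategy:
    def __init__(self, model=None):
        self.model = model

strategy = registry.create("seq")
```

Registration happens as a side effect of class definition, so adding a new strategy never requires editing the factory itself.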
@@ -130,29 +267,85 @@ classDiagram
         +compute_loss(batch) Tensor
     }
 
-    class TrainCallback {
-        +on_train_begin(trainer)
-        +on_train_end(trainer)
-        +on_epoch_begin(epoch, trainer)
-        +on_epoch_end(epoch, trainer)
-        +on_batch_begin(batch, trainer)
-        +on_batch_end(batch, trainer)
-    }
-
-    class Schedule {
+    class BaseScheduler {
+        +get_lr() List~float~
         +step()
     }
 
-    %% Inference Classes
-    class Generator {
-        +generate(prompt, config) str
-        +generate_batch(prompts, config) List[str]
-        +stream_generate(prompt, config) Generator
+    class SchedulerFactory {
+        +Registry _registry
+        +register(name) decorator
+        +create(optimizer, schedule_type, **kwargs) BaseScheduler
     }
 
-    class InferenceCore {
-        +forward(input_ids) Dict
-        +apply_sampling_strategies()
+    class CosineScheduler {
+        +int warmup_steps
+        +int lr_decay_steps
+        +float min_rate
+    }
+
+    class SGDRScheduler {
+        +int warmup_steps
+        +int cycle_length
+        +float min_rate
+        +int t_mult
+    }
+
+    class TrainCallback {
+        +on_train_begin(context)
+        +on_train_end(context)
+        +on_epoch_begin(context)
+        +on_epoch_end(context)
+        +on_step_begin(context)
+        +on_step_end(context)
+        +on_batch_begin(context)
+        +on_batch_end(context)
+        +on_error(context)
+    }
+
+    class CallbackFactory {
+        +Registry _registry
+        +register(name) decorator
+        +create(name, **kwargs) TrainCallback
+    }
+    }
+
+    namespace astrai.inference {
+    class GeneratorCore {
+        +ModelParameter parameter
+        +generate_iterator(input_ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos) Tuple~Tensor, int~
+        +generate_loop(input_ids, ids, temperature, top_k, top_p, attn_mask, kv_caches, start_pos, callback) List~int~
+    }
+
+    class LoopGenerator {
+        +generate(request) str
+    }
+
+    class StreamGenerator {
+        +generate(request) Generator
+    }
+
+    class BatchGenerator {
+        +generate(request) List~str~
+    }
+
+    class EmbeddingEncoder {
+        +encode(sentence) Tensor
+    }
+
+    class KVCacheManager {
+        +int batch_size
+        +Tuple kv_cache
+        +Tensor seq_mask
+        +get_kvcache() Tuple
+        +get_seq_mask() Tensor
+        +update(active_mask)
+        +reset(full_reset)
+    }
+
+    class GeneratorFactory {
+        +create(parameter, request) GeneratorCore
+        +create_encoder(parameter) EmbeddingEncoderCore
     }
 
     class Server {
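The `CosineScheduler` fields (warmup_steps, lr_decay_steps, min_rate) match the common warmup-then-cosine-decay schedule. The exact formula AstrAI uses is an assumption; this is the standard variant:

```python
import math

def cosine_lr(step, base_lr, warmup_steps, lr_decay_steps, min_rate=0.1):
    """Linear warmup to base_lr, then cosine decay to base_lr * min_rate."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps  # linear warmup
    # progress through the decay phase, clamped to [0, 1]
    t = min(step - warmup_steps, lr_decay_steps) / lr_decay_steps
    floor = base_lr * min_rate
    return floor + 0.5 * (base_lr - floor) * (1 + math.cos(math.pi * t))

lr_warm = cosine_lr(0, 1.0, warmup_steps=10, lr_decay_steps=100)    # still warming up
lr_peak = cosine_lr(10, 1.0, warmup_steps=10, lr_decay_steps=100)   # warmup just ended
```

After `warmup_steps + lr_decay_steps` the rate stays clamped at the `min_rate` floor; the SGDR variant instead restarts the cosine every `cycle_length` steps, stretching cycles by `t_mult`.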
@@ -160,54 +353,123 @@ classDiagram
     +predict(request)
     }
 
-    %% Parallel Classes
+    class GenerationRequest {
+        +int top_k
+        +float top_p
+        +float temperature
+        +int max_len
+        +Union~str, List~str~~ query
+        +history Optional
+        +system_prompt Optional~str~
+        +stream bool
+    }
+    }
+
+    namespace astrai.parallel {
     class ParallelSetup {
         +spawn_parallel_fn(fn, nprocs)
+        +setup_parallel(rank, world_size, backend, master_addr, master_port, device_type, device_ids)
+    }
+
+    class ColumnParallelLinear {
+        +forward(x) Tensor
+    }
+
+    class RowParallelLinear {
+        +forward(x) Tensor
     }
+    }
 
     %% Relationships
     TrainConfig --> ModelConfig : contains
-    TrainConfig --> Dataset : uses
+    TrainConfig --> BaseDataset : uses
     TrainConfig --> Transformer : uses
     Trainer --> TrainConfig : configures
     Trainer --> TrainContextBuilder : builds
     Trainer --> TrainCallback : manages
     TrainContextBuilder --> TrainContext : creates
     TrainContext --> Checkpoint : manages
+    TrainContext --> BaseStrategy : uses
+    TrainContext --> BaseScheduler : uses
     StrategyFactory ..> BaseStrategy : creates
     BaseStrategy <|-- SEQStrategy
     BaseStrategy <|-- SFTStrategy
     BaseStrategy <|-- DPOStrategy
     BaseStrategy <|-- GRPOStrategy
-    TrainContext --> BaseStrategy : uses
-    Generator --> InferenceCore : uses
-    Generator --> Transformer : uses
-    Server --> Generator : uses
+    SchedulerFactory ..> BaseScheduler : creates
+    BaseScheduler <|-- CosineScheduler
+    BaseScheduler <|-- SGDRScheduler
+    CallbackFactory ..> TrainCallback : creates
+    GeneratorFactory ..> GeneratorCore : creates
+    GeneratorCore <|-- LoopGenerator
+    GeneratorCore <|-- StreamGenerator
+    GeneratorCore <|-- BatchGenerator
+    GeneratorCore <|-- EmbeddingEncoder
+    LoopGenerator --> KVCacheManager : uses
+    StreamGenerator --> KVCacheManager : uses
+    BatchGenerator --> KVCacheManager : uses
+    GeneratorCore --> Transformer : uses
+    DPOStrategy --> Transformer : creates ref_model
+    GRPOStrategy --> Transformer : creates ref_model
+    Server --> GeneratorFactory : uses
     ParallelSetup --> Trainer : enables
     TrainConfig --> StrategyFactory : selects
-    TrainCallback <|-- CheckpointCallback
-    TrainCallback <|-- MetricLoggerCallback
-    TrainCallback <|-- SchedulerCallback
-    TrainContext --> Schedule : uses
+    ModelParameter --> Transformer : contains
+    ModelParameter --> BpeTokenizer : contains
+    ModelParameter --> ModelConfig : contains
+    GeneratorFactory --> GenerationRequest : uses
+    BaseDataset <|-- SEQDataset
+    BaseDataset <|-- SFTDataset
+    BaseDataset <|-- DPODataset
+    BaseDataset <|-- GRPODataset
+    DatasetFactory ..> BaseDataset : creates
+    BaseSegmentFetcher --> MultiSegmentFetcher : used by
+    MultiSegmentFetcher --> BaseDataset : used by
+    Transformer --> DecoderBlock : uses
+    Transformer --> RotaryEmbedding : uses
+    Transformer --> Embedding : uses
+    DecoderBlock --> GQA : uses
+    DecoderBlock --> MLP : uses
+    DecoderBlock --> RMSNorm : uses
+    BpeTokenizer --> Tokenizer : inherits
+    TrainContextBuilder --> ResumableDistributedSampler : creates
+    DataLoader --> BaseDataset : uses
+    ResumableDistributedSampler --> BaseDataset : samples
 ```
 
-### Design Pattern Summary
+### Module Overview
+
+| Module | Components | Description |
+|--------|------------|-------------|
+| **astrai.config** | ModelConfig, TrainConfig, ModelParameter | Configuration management |
+| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint, DataLoader | Dataset loading and management |
+| **astrai.model** | Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
+| **astrai.tokenize** | Tokenizer, BpeTokenizer | Tokenizer |
+| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
+| **astrai.inference** | GeneratorCore, LoopGenerator, StreamGenerator, BatchGenerator, EmbeddingEncoder, KVCacheManager, GeneratorFactory, Server, GenerationRequest | Inference service |
+| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallelism |
+
+### Design Patterns
 
 | Pattern | Classes | Purpose |
 |---------|---------|---------|
 | **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training-strategy switching; supports SEQ/SFT/DPO/GRPO |
 | **Builder** | `TrainContextBuilder` | Chained construction of the training context, initializing components step by step |
-| **Factory** | `StrategyFactory` | Decorator registration mechanism, dynamically create training strategies |
+| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `GeneratorFactory`, `CallbackFactory` | Decorator registration mechanism; dynamically creates training strategies, schedulers, datasets, generators, and callbacks |
-| **Observer** | `TrainCallback` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
+| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for monitoring the training process (checkpointing, early stopping, metrics) |
 | **Singleton** | `TrainContext` | Global state management for the training process |
+| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
 
 ### Core Relationships
 
 1. **Configuration → Training**: `TrainConfig` contains `ModelConfig` and holds references to the model, dataset, optimizer, etc.
 2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, using a `BaseStrategy` to compute the loss
 3. **Strategy Selection**: `StrategyFactory` creates the corresponding strategy instance based on `train_type`
-4. **Inference Flow**: `Server` → `Generator` → `InferenceCore` → `Transformer`
+4. **Inference Flow**: `Server` → `GeneratorFactory` → `GeneratorCore` → `Transformer`, supporting multiple generators (LoopGenerator, StreamGenerator, BatchGenerator)
 5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
+6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset); HDF5 loading is handled via `BaseSegmentFetcher` and `MultiSegmentFetcher`
+7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
+8. **Scheduler Support**: `SchedulerFactory` creates learning-rate schedulers (CosineScheduler, SGDRScheduler)
 
 ## 3. Training Process