diff --git a/.gitattributes b/.gitattributes index 60472b8..3d150ac 100644 --- a/.gitattributes +++ b/.gitattributes @@ -4,6 +4,7 @@ # Files that MUST use LF (Unix/Linux execution) *.sh text eol=lf *.py text eol=lf +*.md text eol=lf Dockerfile text eol=lf .dockerignore text eol=lf diff --git a/README.md b/README.md index 395aa89..ce5df5d 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ ### Features - 🚀 **High Performance**: Optimized for both training and inference with efficient parallelization. -- 🔧 **Flexible**: Support for seq/sft/dpo training, customizable model architectures. +- 🔧 **Flexible**: Support for seq/sft/dpo/grpo training, customizable model architectures. - 💡 **Easy to Use**: Simple API with comprehensive examples and demos. - 📦 **Lightweight**: Minimal dependencies, easy to deploy. - 🔬 **Research‑Friendly**: Modular design, easy to experiment with new ideas. diff --git a/assets/docs/README-zh-CN.md b/assets/docs/README-zh-CN.md index 4c87f4c..313f3c9 100644 --- a/assets/docs/README-zh-CN.md +++ b/assets/docs/README-zh-CN.md @@ -48,7 +48,7 @@ ### 特性 - 🚀 **高性能**: 训练与推理双向优化,高效并行。 -- 🔧 **灵活**: 支持 seq/sft/dpo 多种训练方式,可定制模型架构。 +- 🔧 **灵活**: 支持 seq/sft/dpo/grpo 多种训练方式,可定制模型架构。 - 💡 **易用**: 简洁的 API 与丰富的示例、演示。 - 📦 **轻量**: 依赖少,部署简单。 - 🔬 **研究友好**: 模块化设计,便于实验新想法。 diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md index 3f2ce2c..18e7416 100644 --- a/assets/docs/dataflow.md +++ b/assets/docs/dataflow.md @@ -6,10 +6,11 @@ This document describes the data flow of the AstrAI project (a training and infe AstrAI adopts a modular design with the following main components: - **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools -- **Model Module** (`astrai/model/`): Transformer model and its submodules +- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules - **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers - **Inference Module** (`astrai/inference/`): 
Inference engine with continuous batching, streaming generation - **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations +- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration - **Parallel Module** (`astrai/parallel/`): Distributed training support The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**. @@ -42,9 +43,9 @@ flowchart LR subgraph C[Inference] direction TB - C1[Checkpoint] --> C2[ModelParameter] - C2 --> C3[Transformer + BpeTokenizer] - C3 --> C4[GenerationRequest + build_prompt] + C1[Checkpoint] --> C2[AutoModel] + C2 --> C3[Transformer + Tokenizer] + C3 --> C4[GenerationRequest + apply_chat_template] C4 --> C5[InferenceEngine] C5 --> C6[InferenceScheduler] C6 --> C7[apply_sampling_strategies] @@ -88,8 +89,9 @@ flowchart LR ### 2. Model Module -#### 2.1 Transformer (`transformer.py`) -- Core autoregressive decoder architecture +#### 2.1 Transformer / AutoModel (`transformer.py`, `automodel.py`) +- **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods +- **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`) - Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head - Supports weight tying (`tie_weight=True`) to reduce parameter count - Uses Rotary Position Embedding (RoPE) to inject position information @@ -122,22 +124,31 @@ flowchart LR - **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers (such as `cosine`, `sgdr`) - Scheduler is automatically created according to configuration and bound to optimizer -### 4. Inference Module +### 4. 
Factory Module -#### 4.1 Inference Engine (`engine.py`) +#### 4.1 Registry and BaseFactory (`factory.py`) +- **`Registry`**: Flexible registry for component classes with category and priority support +- **`BaseFactory`**: Generic factory class for component registration and creation +- Supports decorator-based registration pattern for extensible components +- Provides methods for registration, retrieval, and listing with filtering + +### 5. Inference Module + +#### 5.1 Inference Engine (`engine.py`) - **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation - **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition - Manages task queue (`waiting_queue`, `active_tasks`) and KV cache allocation -#### 4.2 Scheduler (`scheduler.py`) +#### 5.2 Scheduler (`scheduler.py`) - **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED) - **`TaskStatus`**: Task state enumeration - **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits - Continuous batching: new requests can join at any time, completed requests are released immediately -#### 4.3 Request (`engine.py`) -- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, query, history, etc.) -- **`build_prompt`** (from `chat_template.py`): Converts query and history into ChatML format prompt string +#### 5.3 Request (`engine.py`) +- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.) 
+- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content` +- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format - Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces ## Training Data Flow - Detailed Steps @@ -176,11 +187,11 @@ flowchart LR ## Inference Data Flow - Detailed Steps 1. **Model Loading** - - Load `Transformer` model and tokenizer from checkpoint + - Load `Transformer` model from checkpoint via `AutoModel.from_pretrained()`; the tokenizer is loaded separately via `Tokenizer(...)` - Set model to evaluation mode (`model.eval()`), enable inference mode (`torch.inference_mode`) 2. **Prompt Construction and Encoding** - - User query and history are converted to ChatML format string through `build_prompt` function in chat_template module + - User messages (a list of dicts, each with `role` and `content`) are converted to a ChatML format string by the tokenizer's `apply_chat_template` method - Tokenizer encodes prompt string to token ID sequence `input_ids` - For batch generation, use `pad_sequence` for padding @@ -207,5 +218,5 @@ flowchart LR The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache and sampling strategies. Clear interfaces between modules facilitate customization and extension. 
-> Document Update Time: 2026-03-30 +> Document Update Time: 2026-04-05 > Corresponding Code Version: Refer to version number defined in `pyproject.toml` \ No newline at end of file diff --git a/assets/docs/design.md b/assets/docs/design.md index 0260c7d..0ca3b93 100644 --- a/assets/docs/design.md +++ b/assets/docs/design.md @@ -36,10 +36,23 @@ classDiagram +int batch_size +int accumulation_steps +float max_grad_norm + +int start_epoch + +int start_batch +str ckpt_dir +int ckpt_interval + +int random_seed + +int num_workers + +int prefetch_factor + +bool pin_memory +int nprocs +str backend + +str master_addr + +str master_port + +Callable parallel_wrapper + +Callable state_dict_fn + +List[int] device_ids + +str device_type + +dict extra_kwargs +validate() } @@ -123,6 +136,16 @@ classDiagram } namespace astrai.model { + class AutoModel { + +ModelConfig config + +Dict _registry + +register(model_type) decorator + +get_model_class(model_type) Type + +from_pretrained(path, disable_random_init) nn.Module + +save_pretrained(save_directory) + +to(*args, **kwargs) Self + } + class Transformer { +ModelConfig config +RotaryEmbedding rotary_embeding @@ -440,6 +463,8 @@ classDiagram DatasetFactory ..> BaseDataset : creates BaseSegmentFetcher --> MultiSegmentFetcher : used by MultiSegmentFetcher --> BaseDataset : used by + AutoModel <|-- Transformer + AutoModel --> ModelConfig : contains Transformer --> DecoderBlock : uses Transformer --> RotaryEmbedding : uses Transformer --> Embedding : uses @@ -458,11 +483,12 @@ classDiagram |--------|------------|-------------| | **astrai.config** | ModelConfig, TrainConfig, ModelParameter | Configuration management | | **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint, DataLoader | Dataset loading and management | -| **astrai.model** | Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, 
Embedding | Neural network model | +| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.tokenize** | Tokenizer, BpeTokenizer | Tokenizer | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management | | **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest | Inference service with continuous batching | | **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel | +| **astrai.factory** | Registry, BaseFactory | Generic component registration | ### Design Patterns @@ -470,12 +496,13 @@ classDiagram |---------|---------|---------| | **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO | | **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components | -| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks | +| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory`, `BaseFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks | | **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) | | **Singleton** | `TrainContext` | Training process global state management | | **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support | | **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, 
`active_tasks` | Continuous batching with dynamic task queue management | | **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module | +| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern | ### Core Relationships @@ -487,6 +514,7 @@ classDiagram 6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher` 7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors 8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler) +9. **AutoModel Loading**: `AutoModel.from_pretrained()` dynamically loads model based on `config.json` model_type, uses `Registry` pattern for model type registration ## 3. Training Process diff --git a/assets/docs/introduction.md b/assets/docs/introduction.md index 9d5e4a6..4f8604d 100644 --- a/assets/docs/introduction.md +++ b/assets/docs/introduction.md @@ -2,7 +2,21 @@ ### 1. Model Architecture -This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA (although KV cache is not currently implemented). The model is built by stacking 24 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token. +This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA (the KV cache is managed by the inference scheduler). The model is built by stacking 32 layers of Transformer blocks, with 1.0 billion parameters. 
Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token. + +The model now uses the **AutoModel** base class for flexible loading and saving: + +```python +from astrai.model import AutoModel + +# Load model from checkpoint +model = AutoModel.from_pretrained("path/to/model") + +# Save model to new directory +model.save_pretrained("path/to/save") +``` + +The Transformer model is registered via `@AutoModel.register('transformer')` decorator, allowing easy extension for new model types. ```mermaid flowchart TB diff --git a/assets/docs/params.md b/assets/docs/params.md index 6c6f550..72a246b 100644 --- a/assets/docs/params.md +++ b/assets/docs/params.md @@ -6,11 +6,12 @@ | Parameter | Description | Default Value | |-----------|-------------|---------------| -| `--train_type` | Training type (seq, sft, dpo) | required | +| `--train_type` | Training type (seq, sft, dpo, grpo) | required | +| `--model_type` | Model type for AutoModel loading (e.g., transformer) | transformer | | `--data_root_path` | Dataset root directory | required | | `--param_path` | Model parameters or checkpoint path | required | | `--n_epoch` | Total training epochs | 1 | -| `--batch_size` | Batch size | 1 | +| `--batch_size` | Batch size | 4 | | `--accumulation_steps` | Gradient accumulation steps | 1 | ### Learning Rate Scheduling @@ -42,7 +43,9 @@ | Parameter | Description | Default Value | |-----------|-------------|---------------| | `--random_seed` | Random seed | 3407 | -| `--num_workers` | DataLoader workers | 4 | +| `--num_workers` | DataLoader workers | 0 | +| `--prefetch_factor` | Prefetch factor for dataloader | None | +| `--pin_memory` | Enable pin_memory | False | | `--no_pin_memory` | Disable pin_memory | - | ### Distributed Training @@ -71,44 +74,58 @@ | Parameter | Description | Default Value | |-----------|-------------|---------------| -| `query` | Input text or text list | required 
| -| `history` | Conversation history | None | -| `system_prompt` | System prompt | None | -| `temperature` | Sampling temperature (higher = more random) | required | -| `top_p` | Nucleus sampling threshold | required | -| `top_k` | Top-k sampling count | required | -| `max_len` | Maximum generation length | model config max_len | +| `messages` | List of message dictionaries (role, content) | required | +| `temperature` | Sampling temperature (higher = more random) | 1.0 | +| `top_p` | Nucleus sampling threshold | 1.0 | +| `top_k` | Top-k sampling count | 50 | +| `max_len` | Maximum generation length | 1024 | | `stream` | Whether to stream output | False | ### Usage Example ```python -from astrai.config import ModelParameter +import torch +from astrai.model import AutoModel +from astrai.tokenize import Tokenizer from astrai.inference import InferenceEngine, GenerationRequest -# Load model -param = ModelParameter.load("your_model_dir") -param.to(device="cuda", dtype=torch.bfloat16) +# Load model using AutoModel +model = AutoModel.from_pretrained("your_model_dir") + +# Load tokenizer +tokenizer = Tokenizer("your_model_dir") # Create engine with separate model and tokenizer engine = InferenceEngine( - model=param.model, - tokenizer=param.tokenizer, - config=param.config, + model=model, + tokenizer=tokenizer, ) -# Build request +# Build request with messages format request = GenerationRequest( - query="Hello", - history=[], + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, + ], temperature=0.8, top_p=0.95, top_k=50, + max_len=1024, ) # Generate (streaming) for token in engine.generate_with_request(request): print(token, end="", flush=True) + +# Or use simple generate interface +result = engine.generate( + prompt="Hello", + stream=False, + max_tokens=1024, + temperature=0.8, + top_p=0.95, + top_k=50, +) ``` ### Generation Modes diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 
5f06ca4..996a16c 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -163,7 +163,7 @@ class InferenceEngine: self.save_state("./inference_state") except Exception: pass - + self.shutdown() return False @@ -178,9 +178,9 @@ class InferenceEngine: abort_on_exception: bool = True, ) -> Union[Generator[str, None, None], str, List[str]]: """Unified generation interface. - + Args: - abort_on_exception: If True, abort the generation when consumer + abort_on_exception: If True, abort the generation when consumer stops iterating (GeneratorExit/StopIteration). Default: True. """ is_batch = isinstance(prompt, list) @@ -188,8 +188,13 @@ class InferenceEngine: if stream: return self._generate_streaming( - prompts, is_batch, max_tokens, temperature, top_p, top_k, - abort_on_exception + prompts, + is_batch, + max_tokens, + temperature, + top_p, + top_k, + abort_on_exception, ) else: return self._generate_non_streaming( @@ -223,9 +228,9 @@ class InferenceEngine: abort_on_exception: bool = True, ) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]: """Generate with streaming output. - + Args: - abort_on_exception: If True, abort the task when generator is + abort_on_exception: If True, abort the task when generator is stopped early by consumer (GeneratorExit/StopIteration). """ if is_batch: @@ -292,54 +297,53 @@ class InferenceEngine: def shutdown(self) -> None: """Shutdown the engine and release all resources.""" - + # Stop scheduler first self.scheduler.stop() - + if torch.cuda.is_available(): torch.cuda.empty_cache() - + gc.collect() def force_stop(self) -> None: """ Force stop the engine immediately without saving state. - + Use this for emergency shutdown when graceful shutdown is not possible. 
""" - import os - + # Stop watching threads if any - if hasattr(self, 'stop_watching'): + if hasattr(self, "stop_watching"): self.stop_watching() - + # Unregister signal handlers - if hasattr(self, '_original_sigint'): + if hasattr(self, "_original_sigint"): signal.signal(signal.SIGINT, self._original_sigint) - if hasattr(self, '_original_sigterm'): + if hasattr(self, "_original_sigterm"): signal.signal(signal.SIGTERM, self._original_sigterm) - + # Force stop scheduler self.scheduler._running = False - + if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() - + gc.collect() @classmethod def create_and_run(cls, model, tokenizer, **kwargs): """ Create engine, run generation, and shutdown automatically. - + This is a convenience method for simple scripts. - + Args: model: The model to use tokenizer: The tokenizer to use **kwargs: Arguments passed to generate() - + Returns: Generated text result """ diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 0e6b2dc..b8ce5e9 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -380,7 +380,7 @@ class InferenceScheduler: self._running = False if hasattr(self, "_loop_thread"): self._loop_thread.join(timeout=1.0) - + # Clear KV cache to free GPU memory if self.kv_cache is not None: k_cache, v_cache = self.kv_cache @@ -388,10 +388,10 @@ class InferenceScheduler: k_cache.detach() if v_cache is not None: v_cache.detach() - + # Clear seq mask self.seq_mask.detach() - + # Clear task lists self.waiting_queue.clear() self.active_tasks.clear()