refactor: optimize parameter passing, clean up import style

This commit is contained in:
ViperEkura 2026-04-03 22:06:32 +08:00
parent 3a7d98a950
commit 0852b852f8
51 changed files with 300 additions and 435 deletions

View File

@ -24,7 +24,6 @@ flowchart TB
C1[model_config.py<br/>Model Architecture]
C2[train_config.py<br/>Training Params]
C3[param_config.py<br/>Hyperparameters]
C4[schedule_config.py<br/>Scheduler Config]
end
subgraph Data["Data Module (data/)"]
@ -95,14 +94,13 @@ flowchart TB
### 1. Configuration Module (config/)
- **model_config.py**: Defines model structure parameters (layers, heads, dimensions, etc.), managed through `ModelConfig`.
- **train_config.py**: Sets training parameters (batch size, training stages PT/SFT/DPO, optimizers, etc.), loaded by `TrainConfig`.
- **train_config.py**: Sets training parameters (batch size, training stages SEQ/SFT/GRPO/DPO, optimizers, etc.), loaded by `TrainConfig`.
- **param_config.py**: Manages hyperparameters for training and inference.
- **schedule_config.py**: Controls learning rate strategies (cosine annealing) and training progress.
### 2. Data Module (data/)
- **dataset.py**: Dataset handling and loading.
- **sampler.py**: Data sampling for different training stages.
- **serialization.py**: Data serialization and deserialization.
- **serialization.py**: Data serialization and deserialization, checkpoint management.
- **tokenizer.py**: Text tokenization and encoding.
### 3. Model Module (model/)
@ -112,15 +110,15 @@ flowchart TB
### 4. Trainer Module (trainer/)
- **trainer.py**: Main training entry point.
- **train_context.py**: Training context management (model, optimizer, scheduler, metrics).
- **strategy.py**: Training strategies for PT/SFT/DPO stages.
- **schedule.py**: Learning rate scheduler.
- **strategy.py**: Training strategies for SEQ/SFT/GRPO/DPO stages via `StrategyFactory`.
- **schedule.py**: Learning rate scheduler implementation (cosine, SGDR, etc.).
- **train_callback.py**: Training callbacks (checkpoint, early stopping, etc.).
- **metric_util.py**: Metrics calculation and tracking.
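The cosine strategy implemented in `schedule.py` can be pictured as a learning-rate multiplier: linear warmup to 1.0, then cosine decay toward a floor. The sketch below is illustrative only; the parameter names (`warmup_steps`, `lr_decay_steps`, `min_rate`) mirror the scheduler kwargs used elsewhere in this codebase, but the actual implementation may differ.

```python
import math

def cosine_lr_multiplier(step: int, warmup_steps: int,
                         lr_decay_steps: int, min_rate: float) -> float:
    """Illustrative warmup + cosine-annealing LR multiplier.

    Linear warmup to 1.0 over warmup_steps, then cosine decay toward
    min_rate over lr_decay_steps, clamped at min_rate afterwards.
    """
    if step < warmup_steps:
        return step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    t = min((step - warmup_steps) / lr_decay_steps, 1.0)
    return min_rate + (1.0 - min_rate) * 0.5 * (1.0 + math.cos(math.pi * t))

# Multiplier reaches 1.0 right after warmup and decays to min_rate:
print(cosine_lr_multiplier(100, 100, 900, 0.05))   # 1.0
print(cosine_lr_multiplier(1000, 100, 900, 0.05))  # 0.05
```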
### 5. Inference Module (inference/)
- **generator.py**: Text generation with various methods (sync, batch, streaming).
- **core.py**: Inference core with KV cache optimization.
- **server.py**: API service for inference.
- **server.py**: API service for inference (FastAPI + Uvicorn).
### 6. Parallel Module (parallel/)
- **setup.py**: Distributed initialization for multi-GPU/multi-machine training.
@ -134,7 +132,7 @@ flowchart TB
The common training process for large language models (LLM) typically includes three stages: **Pre-training (PT)**, **Supervised Fine-Tuning (SFT)**, and **Reinforcement Learning from Human Feedback (RLHF)**. This system is designed to support seamless end-to-end flow, achieving efficient switching and state management of different training stages through modular strategies, ensuring the model's capabilities gradually evolve from general language understanding to human-preference-aligned dialogue and instruction execution.
### **2.1 Pre-training Stage**
### **2.1 Pre-training Stage (SEQ/PT)**
The pre-training stage aims to build the model's foundational language capabilities and general knowledge representation. This stage performs self-supervised learning on large-scale, unlabeled corpora (typically hundreds of GB to several TB of text). The model architecture is based on a standard Transformer decoder, trained with a causal language modeling objective, enabling the model to learn the vocabulary, grammar, semantics, and world knowledge embedded in text.
@ -152,7 +150,7 @@ $$
- $\theta$: Model parameters
- $P(x_t \mid x_{<t}; \theta)$: The probability of the model predicting the next token given the preceding context
The core of this stage lies in utilizing distributed parallel computing resources to achieve stable optimization of model parameters. The `PTStrategy` in the trainer module is specifically responsible for managing pre-training-specific data sampling, long sequence segmentation, and gradient accumulation logic. At the same time, the hardware adaptation module automatically selects the optimal parallel communication backend (such as NCCL) based on the runtime environment (such as NVIDIA GPU cluster) and performs computation graph optimization to maximize hardware utilization and training throughput.
The core of this stage lies in utilizing distributed parallel computing resources to achieve stable optimization of model parameters. The `SEQStrategy` (Pre-training) in the trainer module is specifically responsible for managing pre-training-specific data sampling, long sequence segmentation, and gradient accumulation logic. At the same time, the hardware adaptation module automatically selects the optimal parallel communication backend (such as NCCL) based on the runtime environment (such as NVIDIA GPU cluster) and performs computation graph optimization to maximize hardware utilization and training throughput.
Additionally, the system achieves zero-copy reading of massive data through the efficient memory-mapped loader (`MmapFileHandler`) in the data module, overcoming traditional IO bottlenecks.
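The causal language modeling objective above — minimizing the average negative log-likelihood of each next token given its preceding context — can be illustrated with a minimal numeric sketch. This is plain Python on toy probabilities, not the framework's actual loss code:

```python
import math

def causal_lm_loss(next_token_probs):
    """Average negative log-likelihood over a token sequence.

    next_token_probs[t] is the probability P(x_t | x_<t; theta) that the
    model assigns to the actual next token at position t.
    """
    nll = [-math.log(p) for p in next_token_probs]
    return sum(nll) / len(nll)

# Toy sequence of three predicted tokens:
loss = causal_lm_loss([0.5, 0.25, 0.125])  # 2 * ln(2) ≈ 1.3863
```

Higher probabilities on the observed tokens drive the loss toward zero, which is exactly what gradient descent on this objective pushes the model to do.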
@ -205,9 +203,9 @@ $$
- $\beta$: Temperature parameter (typically set to 0.1-0.5)
- Note: Implicitly learning reward function $r(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$
In this stage, the trainer module enables the `RLHFStrategy` (or similar `DPOStrategy` direct preference optimization strategy). This strategy manages a complex training loop containing the policy model (LLM to be optimized), reference model (usually an SFT model snapshot), and reward model. The system flow is as follows:
In this stage, the trainer module enables the `DPOStrategy` (Direct Preference Optimization) or `GRPOStrategy` (Group Relative Policy Optimization). This strategy manages a complex training loop containing the policy model (LLM to be optimized), reference model (usually an SFT model snapshot), and reward model. The system flow is as follows:
1. **Preference Data Collection and Reward Modeling**: First, by collecting human annotators' ranking preferences for multiple model-generated results for the same prompt, a separate reward model (RM) is trained. This model learns to output a scalar reward score for generated text to quantify the degree of alignment with human preferences.
2. **Policy Optimization**: Then, using the reward model as the optimization signal, the SFT model (as the policy) is fine-tuned through reinforcement learning algorithms. The goal of policy optimization is to maximize the expected cumulative reward obtained from the reward model, while constraining the output distribution of the policy model and reference model from deviating too much through a KL divergence penalty term, preventing mode collapse and maintaining generation diversity. The training context manager maintains the states of the policy model, reference model, and reward model (or value function model) simultaneously at this stage, and coordinates complex multi-stage gradient computations.
1. **Preference Data Collection and Reward Modeling**: First, by collecting human annotators' ranking preferences for multiple model-generated results for the same prompt, a separate reward model (RM) can be trained. This model learns to output a scalar reward score for generated text to quantify the degree of alignment with human preferences.
2. **Policy Optimization**: Then, using the reward model as the optimization signal, the SFT model (as the policy) is fine-tuned through reinforcement learning algorithms (DPO/GRPO). The goal of policy optimization is to maximize the expected cumulative reward obtained from the reward model, while constraining the output distribution of the policy model and reference model from deviating too much through a KL divergence penalty term, preventing mode collapse and maintaining generation diversity. The training context manager maintains the states of the policy model, reference model, and reward model (or value function model) simultaneously at this stage, and coordinates complex multi-stage gradient computations.
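The pairwise DPO objective described above, $-\log \sigma(\beta(\Delta_w - \Delta_l))$ where each $\Delta$ is a policy/reference log-ratio, can be sketched for a single preference pair. This is an illustrative plain-Python version, not the framework's tensorized `DPOStrategy` implementation:

```python
import math

def dpo_loss(policy_logp_chosen: float, ref_logp_chosen: float,
             policy_logp_rejected: float, ref_logp_rejected: float,
             beta: float = 0.1) -> float:
    """DPO loss for one preference pair.

    Computes -log sigmoid(beta * (log-ratio_chosen - log-ratio_rejected)),
    where each log-ratio is log pi_theta(y|x) - log pi_ref(y|x).
    """
    margin = (policy_logp_chosen - ref_logp_chosen) - (
        policy_logp_rejected - ref_logp_rejected
    )
    return -math.log(1.0 / (1.0 + math.exp(-beta * margin)))

# Equal log-ratios carry no preference signal, so the loss is ln 2:
print(dpo_loss(-10.0, -10.0, -12.0, -12.0))  # ≈ 0.6931
```

As the policy raises the chosen response's likelihood relative to the reference (and/or lowers the rejected one's), the margin grows and the loss falls below ln 2; β scales how sharply this margin is rewarded.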
Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly aligned dialogue agent. Through a unified `Trainer` interface and the strategy pattern, the system makes each training stage highly reusable at the code level and clearly decoupled at the process level, providing an efficient, flexible, and scalable engineering foundation for large-scale language model research and iteration.

View File

@ -5,17 +5,17 @@ from astrai.config import (
ModelConfig,
TrainConfig,
)
from astrai.model.transformer import Transformer
from astrai.data import DatasetLoader, BpeTokenizer
from astrai.data import BpeTokenizer, DatasetLoader
from astrai.inference.generator import (
GenerationRequest,
LoopGenerator,
StreamGenerator,
BatchGenerator,
EmbeddingEncoder,
GenerationRequest,
GeneratorFactory,
LoopGenerator,
StreamGenerator,
)
from astrai.trainer import Trainer, StrategyFactory, SchedulerFactory
from astrai.model.transformer import Transformer
from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
__all__ = [
"Transformer",

View File

@ -1,14 +1,7 @@
from astrai.config.model_config import ModelConfig
from astrai.config.param_config import BaseModelIO, ModelParameter
from astrai.config.schedule_config import (
ScheduleConfig,
CosineScheduleConfig,
SGDRScheduleConfig,
ScheduleConfigFactory,
)
from astrai.config.train_config import TrainConfig
__all__ = [
# Base I/O
"BaseModelIO",
@ -16,9 +9,4 @@ __all__ = [
# Model configuration
"ModelConfig",
"TrainConfig",
# Schedule configuration
"ScheduleConfig",
"CosineScheduleConfig",
"SGDRScheduleConfig",
"ScheduleConfigFactory",
]

View File

@ -1,5 +1,4 @@
import json
from dataclasses import asdict, dataclass
from typing import Optional, Self

View File

@ -1,13 +1,13 @@
import torch.nn as nn
import safetensors.torch as st
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Optional, Self, Union
from pathlib import Path
from typing import Self, Union
import safetensors.torch as st
import torch.nn as nn
from astrai.data.tokenizer import BpeTokenizer
from astrai.config.model_config import ModelConfig
from astrai.data.tokenizer import BpeTokenizer
from astrai.model.transformer import Transformer

View File

@ -1,149 +0,0 @@
from typing import Any, Dict, Type
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
@dataclass
class ScheduleConfig(ABC):
"""Base configuration class for learning rate schedulers.
Provides common validation and interface for all schedule types.
"""
schedule_type: str = field(
default="cosine",
metadata={
"help": "Type of learning rate schedule.",
"choices": ["cosine", "sgdr"],
},
)
warmup_steps: int = field(
default=1000, metadata={"help": "Number of warmup steps."}
)
min_rate: float = field(
default=0.05, metadata={"help": "Minimum learning rate multiplier."}
)
@abstractmethod
def get_kwargs(self) -> Dict[str, Any]:
"""Get configuration kwargs for scheduler creation."""
raise NotImplementedError
def validate(self) -> None:
"""Validate configuration parameters."""
if self.warmup_steps < 0:
raise ValueError(
f"warmup_steps must be non-negative, got {self.warmup_steps}"
)
if not 0 <= self.min_rate <= 1:
raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
@dataclass
class CosineScheduleConfig(ScheduleConfig):
"""Cosine annealing learning rate schedule configuration."""
total_steps: int = field(
default=None, metadata={"help": "Total training steps for cosine schedule."}
)
def __post_init__(self) -> None:
self.schedule_type = "cosine"
self.validate()
def get_kwargs(self) -> Dict[str, Any]:
if self.total_steps is None:
raise ValueError("total_steps must be specified for cosine schedule")
return {
"schedule_type": self.schedule_type,
"warmup_steps": self.warmup_steps,
"lr_decay_steps": self.total_steps - self.warmup_steps,
"min_rate": self.min_rate,
}
def validate(self) -> None:
super().validate()
if self.total_steps is not None and self.total_steps <= self.warmup_steps:
raise ValueError(
f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})"
)
@dataclass
class SGDRScheduleConfig(ScheduleConfig):
"""Stochastic Gradient Descent with Warm Restarts schedule configuration."""
cycle_length: int = field(
default=1000, metadata={"help": "Length of the first cycle in steps."}
)
t_mult: int = field(
default=2, metadata={"help": "Multiplier for cycle length growth."}
)
def __post_init__(self) -> None:
self.schedule_type = "sgdr"
self.validate()
def get_kwargs(self) -> Dict[str, Any]:
return {
"schedule_type": self.schedule_type,
"warmup_steps": self.warmup_steps,
"cycle_length": self.cycle_length,
"min_rate": self.min_rate,
"t_mult": self.t_mult,
}
def validate(self) -> None:
super().validate()
if self.cycle_length <= 0:
raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
if self.t_mult < 1:
raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
class ScheduleConfigFactory:
"""Factory class for creating ScheduleConfig instances.
Supports both direct instantiation and factory creation methods.
Example usage:
# Direct creation
config = CosineScheduleConfig(total_steps=10000)
# Factory method
config = ScheduleConfigFactory.create("cosine", total_steps=10000)
"""
CONFIG_MAP: Dict[str, Type[ScheduleConfig]] = {
"cosine": CosineScheduleConfig,
"sgdr": SGDRScheduleConfig,
}
@classmethod
def create(cls, schedule_type: str, **kwargs) -> ScheduleConfig:
"""Create a schedule config instance.
Args:
schedule_type: Type of schedule ("cosine", "sgdr")
**kwargs: Arguments passed to the config constructor
Returns:
ScheduleConfig instance
Raises:
ValueError: If schedule_type is not supported
"""
if schedule_type not in cls.CONFIG_MAP:
raise ValueError(
f"Unknown schedule type: '{schedule_type}'. "
f"Supported types: {sorted(cls.CONFIG_MAP.keys())}"
)
config_cls = cls.CONFIG_MAP[schedule_type]
return config_cls(**kwargs)
@classmethod
def available_types(cls) -> list:
"""Return list of available schedule type names."""
return list(cls.CONFIG_MAP.keys())

View File

@ -1,11 +1,11 @@
import torch.nn as nn
from torch.utils.data import Dataset
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from dataclasses import dataclass, field
from typing import Callable, List, Optional
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset
@dataclass
class TrainConfig:

View File

@ -1,16 +1,15 @@
from astrai.data.dataset import (
BaseDataset,
SEQDataset,
DatasetFactory,
DatasetLoader,
DPODataset,
SFTDataset,
GRPODataset,
MultiSegmentFetcher,
DatasetLoader,
DatasetFactory,
SEQDataset,
SFTDataset,
)
from astrai.data.tokenizer import BpeTokenizer
from astrai.data.sampler import ResumableDistributedSampler
from astrai.data.tokenizer import BpeTokenizer
__all__ = [
# Base classes

View File

@ -1,13 +1,14 @@
"""Dataset implementations with factory pattern for training."""
import torch
import bisect
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union
import torch
from torch import Tensor
from torch.utils.data import Dataset
from astrai.data.serialization import load_h5
from typing import List, Dict, Optional, Union
class BaseSegmentFetcher:

View File

@ -1,8 +1,8 @@
from typing import Optional
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, Sampler
from typing import Optional
class ResumableDistributedSampler(Sampler[int]):

View File

@ -1,13 +1,14 @@
import os
import h5py
import torch
import json
import safetensors.torch as st
import torch.distributed as dist
import os
from pathlib import Path
from torch import Tensor
from typing import Any, Dict, List
import h5py
import safetensors.torch as st
import torch
import torch.distributed as dist
from torch import Tensor
from astrai.parallel.setup import get_rank

View File

@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from tokenizers import Tokenizer, decoders, processors, normalizers, pre_tokenizers
from typing import List, Union
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
from typing import List, Union
class BaseTokenizer(ABC):

View File

@ -1,16 +1,15 @@
from astrai.inference.core import (
GeneratorCore,
EmbeddingEncoderCore,
GeneratorCore,
KVCacheManager,
)
from astrai.inference.generator import (
GenerationRequest,
LoopGenerator,
StreamGenerator,
BatchGenerator,
EmbeddingEncoder,
GenerationRequest,
GeneratorFactory,
LoopGenerator,
StreamGenerator,
)
__all__ = [

View File

@ -1,8 +1,9 @@
import torch
from typing import Any, Callable, List, Optional, Self, Tuple, Union
import torch
from torch import Tensor
from typing import Any, Callable, List, Tuple, Union, Optional, Self
from astrai.config import ModelParameter, ModelConfig
from astrai.config import ModelConfig, ModelParameter
def apply_sampling_strategies(

View File

@ -1,10 +1,11 @@
import torch
from dataclasses import dataclass
from torch import Tensor
from typing import List, Tuple, Union, Optional, Generator
from astrai.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
from astrai.config.param_config import ModelParameter
from typing import Generator, List, Optional, Tuple, Union
import torch
from torch import Tensor
from astrai.config.param_config import ModelParameter
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
HistoryType = List[Tuple[str, str]]

View File

@ -1,13 +1,15 @@
import torch
import uvicorn
import logging
from pathlib import Path
from typing import List, Optional, Dict, Any, Tuple
from typing import Any, Dict, List, Optional, Tuple
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GeneratorFactory, GenerationRequest
from astrai.inference.generator import GenerationRequest, GeneratorFactory
logger = logging.getLogger(__name__)

View File

@ -1,9 +1,9 @@
from astrai.model.module import (
GQA,
MLP,
DecoderBlock,
Linear,
RMSNorm,
MLP,
GQA,
DecoderBlock,
)
from astrai.model.transformer import Transformer

View File

@ -1,9 +1,9 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional, Tuple
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:

View File

@ -1,12 +1,13 @@
from typing import Any, Mapping, Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from typing import Any, Mapping, Optional, Tuple
from astrai.config.model_config import ModelConfig
from astrai.model.module import (
Embedding,
DecoderBlock,
Embedding,
Linear,
RMSNorm,
RotaryEmbedding,

View File

@ -1,14 +1,13 @@
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
from astrai.parallel.setup import (
get_world_size,
get_rank,
get_current_device,
get_rank,
get_world_size,
only_on_rank,
setup_parallel,
spawn_parallel_fn,
)
from astrai.parallel.module import RowParallelLinear, ColumnParallelLinear
__all__ = [
"get_world_size",
"get_rank",

View File

@ -1,10 +1,10 @@
from typing import Dict
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch import Tensor
from typing import Dict
class ParallelModel(nn.Module):

View File

@ -1,12 +1,12 @@
import os
from contextlib import contextmanager
from functools import wraps
from typing import Callable, List, Optional
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from functools import wraps
from contextlib import contextmanager
from typing import Callable, List, Optional
def get_current_device():
return os.environ["LOCAL_DEVICE"]

View File

@ -1,15 +1,14 @@
from astrai.trainer.trainer import Trainer
from astrai.trainer.strategy import StrategyFactory, BaseStrategy
from astrai.trainer.schedule import SchedulerFactory, BaseScheduler
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
from astrai.trainer.train_callback import (
TrainCallback,
GradientClippingCallback,
SchedulerCallback,
CheckpointCallback,
ProgressBarCallback,
GradientClippingCallback,
MetricLoggerCallback,
ProgressBarCallback,
SchedulerCallback,
TrainCallback,
)
from astrai.trainer.trainer import Trainer
__all__ = [
# Main trainer

View File

@ -1,6 +1,7 @@
import torch.nn as nn
from typing import Dict
import torch.nn as nn
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
"""Compute gradient norm for each parameter in the model."""

View File

@ -1,10 +1,10 @@
"""Learning rate scheduler implementations with factory pattern."""
import math
from abc import abstractmethod, ABC
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type
from torch.optim.lr_scheduler import LRScheduler
from astrai.config.schedule_config import ScheduleConfig
class BaseScheduler(LRScheduler, ABC):
@ -37,10 +37,6 @@ class SchedulerFactory:
...
scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)
# Or from config
config = CosineScheduleConfig(total_steps=10000)
scheduler = SchedulerFactory.load(optimizer, config)
"""
SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {}
@ -67,7 +63,7 @@ class SchedulerFactory:
return decorator
@classmethod
def create(cls, optimizer, schedule_type: str, **kwargs) -> BaseScheduler:
def create(cls, optimizer, schedule_type: str = "none", **kwargs) -> BaseScheduler:
"""Create a scheduler instance by type name.
Args:
@ -90,29 +86,13 @@ class SchedulerFactory:
scheduler_cls = cls.SCHEDULER_MAP[schedule_type]
return scheduler_cls(optimizer, **kwargs)
@staticmethod
def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
"""Create a scheduler from a ScheduleConfig object.
Args:
optimizer: PyTorch optimizer
schedule_config: ScheduleConfig instance
Returns:
Scheduler instance
"""
kwargs = schedule_config.get_kwargs()
schedule_type = kwargs.pop("schedule_type")
return SchedulerFactory.create(optimizer, schedule_type, **kwargs)
@classmethod
def available_types(cls) -> list:
"""Return list of registered scheduler type names."""
return list(cls.SCHEDULER_MAP.keys())
# ============== Scheduler Classes ==============
# All scheduler classes are registered at class definition time using the decorator
# ----------- Scheduler implementations -----------
@SchedulerFactory.register("cosine")

View File

@ -1,14 +1,14 @@
"""Training strategy implementations with factory pattern."""
import copy
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch import Tensor
from typing import Any, Callable, Dict, Union
from abc import ABC, abstractmethod
from torch.nn.parallel import DistributedDataParallel as DDP
def unwrap_model(model: nn.Module) -> nn.Module:

View File

@ -1,25 +1,25 @@
import os
import json
import os
import time
import torch.nn as nn
from pathlib import Path
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_
from typing import Callable, List, Optional, Protocol
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
from astrai.data.serialization import Checkpoint
from astrai.parallel import only_on_rank
from astrai.trainer.metric_util import (
ctx_get_grad_max,
ctx_get_grad_mean,
ctx_get_grad_min,
ctx_get_grad_nan_num,
ctx_get_grad_norm,
ctx_get_grad_std,
ctx_get_loss,
ctx_get_lr,
ctx_get_grad_max,
ctx_get_grad_min,
ctx_get_grad_norm,
ctx_get_grad_mean,
ctx_get_grad_std,
ctx_get_grad_nan_num,
)
from astrai.data.serialization import Checkpoint
from astrai.trainer.train_context import TrainContext

View File

@ -1,16 +1,16 @@
from dataclasses import dataclass, field
from typing import Optional, Self
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from astrai.config.train_config import TrainConfig
from astrai.data import ResumableDistributedSampler
from astrai.data.serialization import Checkpoint
from astrai.trainer.strategy import StrategyFactory, BaseStrategy
from astrai.config.train_config import TrainConfig
from astrai.parallel.setup import get_current_device, get_world_size, get_rank
from dataclasses import dataclass, field
from typing import Optional, Self
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
@dataclass

View File

@ -1,17 +1,18 @@
import logging
from typing import Optional, List
from typing import List, Optional
from astrai.config import TrainConfig
from astrai.trainer.train_callback import (
TrainCallback,
ProgressBarCallback,
CheckpointCallback,
MetricLoggerCallback,
GradientClippingCallback,
SchedulerCallback,
)
from astrai.trainer.train_context import TrainContext, TrainContextBuilder
from astrai.data.serialization import Checkpoint
from astrai.parallel.setup import spawn_parallel_fn
from astrai.trainer.train_callback import (
CheckpointCallback,
GradientClippingCallback,
MetricLoggerCallback,
ProgressBarCallback,
SchedulerCallback,
TrainCallback,
)
from astrai.trainer.train_context import TrainContext, TrainContextBuilder
logger = logging.getLogger(__name__)

View File

@ -1,4 +1,5 @@
from pathlib import Path
from huggingface_hub import snapshot_download
PROJECT_ROOT = Path(__file__).resolve().parents[2]

View File

@ -1,7 +1,9 @@
import torch
from pathlib import Path
import torch
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GeneratorFactory, GenerationRequest
from astrai.inference.generator import GenerationRequest, GeneratorFactory
PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")

View File

@ -1,7 +1,9 @@
import torch
from pathlib import Path
import torch
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GeneratorFactory, GenerationRequest
from astrai.inference.generator import GenerationRequest, GeneratorFactory
PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")

View File

@ -1,7 +1,9 @@
import torch
from pathlib import Path
import torch
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GeneratorFactory, GenerationRequest
from astrai.inference.generator import GenerationRequest, GeneratorFactory
PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")

View File

@ -1,6 +1,8 @@
import torch
from typing import Dict, Any
from dataclasses import dataclass
from typing import Any, Dict
import torch
from astrai.model.transformer import ModelConfig, Transformer

View File

@ -1,6 +1,7 @@
import torch
import json
import argparse
import json
import torch
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import BatchGenerator, GenerationRequest

View File

@ -1,11 +1,12 @@
import argparse
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
import tqdm
from torch import Tensor
from astrai.config.param_config import ModelParameter

View File

@ -1,5 +1,6 @@
import argparse
from pathlib import Path
from astrai.inference.server import run_server

View File

@ -1,15 +1,16 @@
import os
import argparse
import os
from functools import partial
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from functools import partial
from astrai.config import ModelParameter, TrainConfig
from astrai.data import DatasetLoader
from astrai.config import ModelParameter, TrainConfig, CosineScheduleConfig
from astrai.trainer import Trainer, SchedulerFactory
from astrai.parallel import get_rank
from astrai.trainer import SchedulerFactory, Trainer
def parse_args() -> argparse.Namespace:
@ -158,7 +159,7 @@ def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
def create_scheduler(
optimizer: optim.Optimizer, **kwargs
) -> optim.lr_scheduler.LRScheduler:
return SchedulerFactory.load(optimizer, **kwargs)
return SchedulerFactory.create(optimizer, **kwargs)
def prepare_checkpoint(model: nn.Module) -> dict:
@ -211,11 +212,6 @@ def train(
stride=stride,
)
schedule_config = CosineScheduleConfig(
warmup_steps=warmup_steps,
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
)
optimizer_fn = partial(
create_optimizer,
**{
@ -224,7 +220,16 @@ def train(
"weight_decay": adamw_weight_decay,
},
)
scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config})
total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
scheduler_fn = partial(
create_scheduler,
**{
"schedule_type": "cosine",
"warmup_steps": warmup_steps,
"lr_decay_steps": total_steps - warmup_steps,
},
)
train_config = TrainConfig(
model=model,

View File

@ -1,11 +1,12 @@
import os
import json
import numpy as np
import tempfile
import os
import shutil
import torch
import tempfile
import numpy as np
import pytest
import safetensors.torch as st
import torch
from tokenizers import pre_tokenizers
from torch.utils.data import Dataset

View File

@ -1,9 +1,10 @@
import torch
import tempfile
import torch.distributed as dist
import torch
import torch.distributed as dist
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from astrai.data.serialization import Checkpoint
from astrai.parallel.setup import get_rank, spawn_parallel_fn

View File

@ -1,8 +1,8 @@
import torch
import numpy as np
import torch
from astrai.data.serialization import save_h5
from astrai.data.dataset import *
from astrai.data.serialization import save_h5
def test_dataset_loader_random_paths(base_test_env):

View File

@ -1,5 +1,5 @@
from astrai.trainer import *
from astrai.data import *
from astrai.trainer import *
def test_random_sampler_consistency(random_dataset):

View File

@ -1,8 +1,10 @@
"""Shared fixtures for inference tests."""
import pytest
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from astrai.inference.server import app

View File

@ -1,9 +1,6 @@
"""Unit tests for the inference HTTP server."""
import pytest
from unittest.mock import MagicMock, patch
from fastapi.testclient import TestClient
from astrai.inference.server import app
def test_health_no_model(client, monkeypatch):

View File

@ -1,10 +1,12 @@
import os
import torch
from astrai.trainer import *
from astrai.config import *
from astrai.model import *
from astrai.data import *
from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore
from astrai.model import *
from astrai.trainer import *
def test_model_parameter(test_env):

View File

@ -1,11 +1,13 @@
import os
import json
import torch
import pytest
import os
import tempfile
import pytest
import safetensors.torch as st
from astrai.model.transformer import Transformer
import torch
from astrai.config.model_config import ModelConfig
from astrai.model.transformer import Transformer
@pytest.fixture

View File

@ -1,6 +1,9 @@
import pytest
import torch
from torch.utils.data import Dataset
import pytest
from astrai.config import TrainConfig
from astrai.trainer.schedule import SchedulerFactory
class TrainerDataset(Dataset):
@ -54,13 +57,11 @@ def create_train_config(
Returns:
TrainConfig instance configured for testing
"""
from astrai.config import TrainConfig
from astrai.config.schedule_config import CosineScheduleConfig
from astrai.trainer.schedule import SchedulerFactory
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
optimizer_fn = lambda m: torch.optim.AdamW(m.parameters(), lr=0.001)
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
scheduler_fn = lambda optim: SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)
return TrainConfig(
strategy=strategy,

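The `SchedulerFactory.create(optim, "cosine", warmup_steps=..., lr_decay_steps=..., min_rate=...)` calls above replace the old config-object API. As a minimal standalone sketch of what those parameters are assumed to mean (linear warmup, then cosine decay from 1.0 down to `min_rate`; `cosine_lr_multiplier` is illustrative, not the astrai implementation):

```python
import math

def cosine_lr_multiplier(step, warmup_steps, lr_decay_steps, min_rate):
    """Assumed shape of the 'cosine' schedule used in these tests:
    linear warmup over warmup_steps, then cosine decay from 1.0 to
    min_rate over lr_decay_steps (= total_steps - warmup_steps)."""
    if step < warmup_steps:
        # Linear ramp from 0 to 1 during warmup
        return step / max(1, warmup_steps)
    # Normalized progress through the decay phase, clamped to [0, 1]
    t = min(step - warmup_steps, lr_decay_steps) / max(1, lr_decay_steps)
    return min_rate + (1.0 - min_rate) * 0.5 * (1.0 + math.cos(math.pi * t))
```

With `warmup_steps=10, lr_decay_steps=10, min_rate=0.05` (the values used in the fixtures above), the multiplier ramps to 1.0 at step 10 and settles at 0.05 from step 20 onward.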
View File

@@ -6,10 +6,10 @@ from astrai.trainer import *
def test_callback_integration(base_test_env, random_dataset):
"""Test that all callbacks are properly integrated"""
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
scheduler_fn = lambda optim: SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)
train_config = TrainConfig(
model=base_test_env["model"],

View File

@@ -1,18 +1,20 @@
import os
import torch
import numpy as np
import torch
from astrai.config import *
from astrai.trainer import *
from astrai.data.serialization import Checkpoint
from astrai.trainer import *
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
"""Simulate early stopping behavior"""
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
scheduler_fn = lambda optim: SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)
train_config = TrainConfig(
strategy="seq",

View File

@@ -1,10 +1,9 @@
import torch
import numpy as np
import pytest
import torch
from astrai.config import *
from astrai.trainer.schedule import *
from astrai.data.dataset import *
from astrai.trainer.schedule import *
def test_schedule_factory_random_configs():
@@ -16,41 +15,57 @@ def test_schedule_factory_random_configs():
# Test multiple random configurations
for _ in range(5): # Test 5 random configurations
schedule_configs = [
CosineScheduleConfig(
warmup_steps=np.random.randint(50, 200),
total_steps=np.random.randint(1000, 5000),
min_rate=np.random.uniform(0.01, 0.1),
),
SGDRScheduleConfig(
warmup_steps=np.random.randint(50, 200),
cycle_length=np.random.randint(500, 2000),
t_mult=np.random.randint(1, 3),
min_rate=np.random.uniform(0.01, 0.1),
),
]
for config in schedule_configs:
# Validate configuration
config.validate()
# Create scheduler using factory
scheduler = SchedulerFactory.load(optimizer, config)
# Verify scheduler type
if isinstance(config, CosineScheduleConfig):
assert isinstance(scheduler, CosineScheduler)
assert scheduler.warmup_steps == config.warmup_steps
assert (
scheduler.lr_decay_steps == config.total_steps - config.warmup_steps
# Test multiple random configurations
cosine_params = {
"schedule_type": "cosine",
"warmup_steps": np.random.randint(50, 200),
"total_steps": np.random.randint(1000, 5000),
"min_rate": np.random.uniform(0.01, 0.1),
}
sgdr_params = {
"schedule_type": "sgdr",
"warmup_steps": np.random.randint(50, 200),
"cycle_length": np.random.randint(500, 2000),
"t_mult": np.random.randint(1, 3),
"min_rate": np.random.uniform(0.01, 0.1),
}
for params in [cosine_params, sgdr_params]:
schedule_type = params["schedule_type"]
# Convert parameters for scheduler constructor
if schedule_type == "cosine":
warmup_steps = params["warmup_steps"]
total_steps = params["total_steps"]
min_rate = params["min_rate"]
lr_decay_steps = total_steps - warmup_steps
scheduler = SchedulerFactory.create(
optimizer,
schedule_type,
warmup_steps=warmup_steps,
lr_decay_steps=lr_decay_steps,
min_rate=min_rate,
)
assert isinstance(scheduler, CosineScheduler)
assert scheduler.warmup_steps == warmup_steps
assert scheduler.lr_decay_steps == lr_decay_steps
assert scheduler.min_rate == min_rate
elif schedule_type == "sgdr":
warmup_steps = params["warmup_steps"]
cycle_length = params["cycle_length"]
t_mult = params["t_mult"]
min_rate = params["min_rate"]
scheduler = SchedulerFactory.create(
optimizer,
schedule_type,
warmup_steps=warmup_steps,
cycle_length=cycle_length,
t_mult=t_mult,
min_rate=min_rate,
)
assert scheduler.min_rate == config.min_rate
elif isinstance(config, SGDRScheduleConfig):
assert isinstance(scheduler, SGDRScheduler)
assert scheduler.warmup_steps == config.warmup_steps
assert scheduler.cycle_length == config.cycle_length
assert scheduler.t_mult == config.t_mult
assert scheduler.min_rate == config.min_rate
assert scheduler.warmup_steps == warmup_steps
assert scheduler.cycle_length == cycle_length
assert scheduler.t_mult == t_mult
assert scheduler.min_rate == min_rate
# Test scheduler state dict functionality
state_dict = scheduler.state_dict()
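The SGDR branch above passes `cycle_length` and `t_mult`, which in the usual warm-restarts scheme means each restart cycle is `t_mult` times longer than the last. A small standalone sketch of that bookkeeping, assuming the conventional semantics (this helper is illustrative, not the astrai `SGDRScheduler`):

```python
def sgdr_cycle(step, warmup_steps, cycle_length, t_mult):
    """Assumed SGDR restart bookkeeping: after warmup, training runs in
    cycles whose length grows by a factor of t_mult after each restart.
    Returns (cycle_index, offset_within_cycle)."""
    if step < warmup_steps:
        # Still in warmup; cycles have not started yet
        return 0, 0
    s = step - warmup_steps
    length, cycle = cycle_length, 0
    while s >= length:
        # Consume whole cycles, growing the length each restart
        s -= length
        length *= t_mult
        cycle += 1
    return cycle, s
```

For example with `warmup_steps=10, cycle_length=100, t_mult=2`, step 110 lands exactly at the first restart (cycle 1, offset 0), and the second cycle runs for 200 steps.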
@@ -76,16 +91,25 @@ def test_schedule_factory_edge_cases():
# Test edge cases for CosineScheduleConfig
edge_cases = [
# Minimal warmup and steps
CosineScheduleConfig(warmup_steps=1, total_steps=10, min_rate=0.01),
{"warmup_steps": 1, "total_steps": 10, "min_rate": 0.01},
# Large values
CosineScheduleConfig(warmup_steps=1000, total_steps=10000, min_rate=0.5),
{"warmup_steps": 1000, "total_steps": 10000, "min_rate": 0.5},
# Zero min_rate (edge case)
CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.0),
{"warmup_steps": 100, "total_steps": 1000, "min_rate": 0.0},
]
for config in edge_cases:
config.validate()
scheduler = SchedulerFactory.load(optimizer, config)
for params in edge_cases:
warmup_steps = params["warmup_steps"]
total_steps = params["total_steps"]
min_rate = params["min_rate"]
lr_decay_steps = total_steps - warmup_steps
scheduler = SchedulerFactory.create(
optimizer,
"cosine",
warmup_steps=warmup_steps,
lr_decay_steps=lr_decay_steps,
min_rate=min_rate,
)
assert scheduler is not None
# Test multiple steps
@@ -93,34 +117,24 @@ def test_schedule_factory_edge_cases():
scheduler.step()
def test_schedule_factory_invalid_configs():
"""Test scheduler factory with invalid configurations"""
# Test invalid configurations that should raise errors
invalid_configs = [
# Negative warmup steps
{"warmup_steps": -10, "total_steps": 1000, "min_rate": 0.1},
# Total steps less than warmup steps
{"warmup_steps": 500, "total_steps": 400, "min_rate": 0.1},
# Invalid min_rate
{"warmup_steps": 100, "total_steps": 1000, "min_rate": -0.1},
{"warmup_steps": 100, "total_steps": 1000, "min_rate": 1.1},
]
for kwargs in invalid_configs:
with pytest.raises(ValueError):
config = CosineScheduleConfig(**kwargs)
config.validate()
def test_schedule_factory_state_persistence():
"""Test scheduler state persistence (save/load)"""
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
scheduler = SchedulerFactory.load(optimizer, config)
# Create scheduler directly with parameters
warmup_steps = 100
total_steps = 1000
min_rate = 0.1
lr_decay_steps = total_steps - warmup_steps
scheduler = SchedulerFactory.create(
optimizer,
"cosine",
warmup_steps=warmup_steps,
lr_decay_steps=lr_decay_steps,
min_rate=min_rate,
)
# Take a few steps
for _ in range(5):
@@ -129,8 +143,14 @@ def test_schedule_factory_state_persistence():
# Save state
state_dict = scheduler.state_dict()
# Create new scheduler and load state
new_scheduler = SchedulerFactory.load(optimizer, config)
# Create new scheduler with same parameters
new_scheduler = SchedulerFactory.create(
optimizer,
"cosine",
warmup_steps=warmup_steps,
lr_decay_steps=lr_decay_steps,
min_rate=min_rate,
)
new_scheduler.load_state_dict(state_dict)
# Verify states match

View File

@@ -1,5 +1,3 @@
import torch
from astrai.data.dataset import *
from astrai.trainer import Trainer