refactor: optimize parameter passing, clean up import style

This commit is contained in:
ViperEkura 2026-04-03 22:06:32 +08:00
parent 3a7d98a950
commit 0852b852f8
51 changed files with 300 additions and 435 deletions

View File

@ -24,7 +24,6 @@ flowchart TB
C1[model_config.py<br/>Model Architecture] C1[model_config.py<br/>Model Architecture]
C2[train_config.py<br/>Training Params] C2[train_config.py<br/>Training Params]
C3[param_config.py<br/>Hyperparameters] C3[param_config.py<br/>Hyperparameters]
C4[schedule_config.py<br/>Scheduler Config]
end end
subgraph Data["Data Module (data/)"] subgraph Data["Data Module (data/)"]
@ -95,14 +94,13 @@ flowchart TB
### 1. Configuration Module (config/) ### 1. Configuration Module (config/)
- **model_config.py**: Defines model structure parameters (layers, heads, dimensions, etc.), managed through `ModelConfig`. - **model_config.py**: Defines model structure parameters (layers, heads, dimensions, etc.), managed through `ModelConfig`.
- **train_config.py**: Sets training parameters (batch size, training stages PT/SFT/DPO, optimizers, etc.), loaded by `TrainConfig`. - **train_config.py**: Sets training parameters (batch size, training stages SEQ/SFT/GRPO/DPO, optimizers, etc.), loaded by `TrainConfig`.
- **param_config.py**: Manages hyperparameters for training and inference. - **param_config.py**: Manages hyperparameters for training and inference.
- **schedule_config.py**: Controls learning rate strategies (cosine annealing) and training progress.
### 2. Data Module (data/) ### 2. Data Module (data/)
- **dataset.py**: Dataset handling and loading. - **dataset.py**: Dataset handling and loading.
- **sampler.py**: Data sampling for different training stages. - **sampler.py**: Data sampling for different training stages.
- **serialization.py**: Data serialization and deserialization. - **serialization.py**: Data serialization and deserialization, checkpoint management.
- **tokenizer.py**: Text tokenization and encoding. - **tokenizer.py**: Text tokenization and encoding.
### 3. Model Module (model/) ### 3. Model Module (model/)
@ -112,15 +110,15 @@ flowchart TB
### 4. Trainer Module (trainer/) ### 4. Trainer Module (trainer/)
- **trainer.py**: Main training entry point. - **trainer.py**: Main training entry point.
- **train_context.py**: Training context management (model, optimizer, scheduler, metrics). - **train_context.py**: Training context management (model, optimizer, scheduler, metrics).
- **strategy.py**: Training strategies for PT/SFT/DPO stages. - **strategy.py**: Training strategies for SEQ/SFT/GRPO/DPO stages via `StrategyFactory`.
- **schedule.py**: Learning rate scheduler. - **schedule.py**: Learning rate scheduler implementation (cosine, SGDR, etc.).
- **train_callback.py**: Training callbacks (checkpoint, early stopping, etc.). - **train_callback.py**: Training callbacks (checkpoint, early stopping, etc.).
- **metric_util.py**: Metrics calculation and tracking. - **metric_util.py**: Metrics calculation and tracking.
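The factory wiring described above is easiest to see with `SchedulerFactory`. Below is a minimal sketch assembled from calls that appear later in this diff; the model, optimizer, and argument values are illustrative placeholders, not astrai's actual training setup:

```python
import torch

from astrai.trainer import SchedulerFactory

model = torch.nn.Linear(10, 2)  # placeholder for a real Transformer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Schedulers are looked up by registered name ("cosine", "sgdr", ...);
# for the cosine scheduler, lr_decay_steps = total_steps - warmup_steps.
scheduler = SchedulerFactory.create(
    optimizer, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)

for _ in range(5):
    optimizer.step()
    scheduler.step()
```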
### 5. Inference Module (inference/) ### 5. Inference Module (inference/)
- **generator.py**: Text generation with various methods (sync, batch, streaming). - **generator.py**: Text generation with various methods (sync, batch, streaming).
- **core.py**: Inference core with KV cache optimization. - **core.py**: Inference core with KV cache optimization.
- **server.py**: API service for inference. - **server.py**: API service for inference (FastAPI + Uvicorn).
### 6. Parallel Module (parallel/) ### 6. Parallel Module (parallel/)
- **setup.py**: Distributed initialization for multi-GPU/multi-machine training. - **setup.py**: Distributed initialization for multi-GPU/multi-machine training.
@ -134,7 +132,7 @@ flowchart TB
The common training process for large language models (LLMs) typically includes three stages: **Pre-training (PT)**, **Supervised Fine-Tuning (SFT)**, and **Reinforcement Learning from Human Feedback (RLHF)**. This system is designed to support a seamless end-to-end flow: modular strategies provide efficient switching and state management across training stages, so the model's capabilities gradually evolve from general language understanding to human-preference-aligned dialogue and instruction execution. The common training process for large language models (LLMs) typically includes three stages: **Pre-training (PT)**, **Supervised Fine-Tuning (SFT)**, and **Reinforcement Learning from Human Feedback (RLHF)**. This system is designed to support a seamless end-to-end flow: modular strategies provide efficient switching and state management across training stages, so the model's capabilities gradually evolve from general language understanding to human-preference-aligned dialogue and instruction execution.
### **2.1 Pre-training Stage** ### **2.1 Pre-training Stage (SEQ/PT)**
The pre-training stage aims to build the model's foundational language capabilities and general knowledge representation. This stage performs self-supervised learning on large-scale, unlabeled corpora (typically hundreds of GB to several TB of text data). The model architecture is based on the standard Transformer Decoder, trained with a causal (autoregressive) language modeling objective, enabling the model to learn the vocabulary, grammar, semantics, and world knowledge embedded in text. The pre-training stage aims to build the model's foundational language capabilities and general knowledge representation. This stage performs self-supervised learning on large-scale, unlabeled corpora (typically hundreds of GB to several TB of text data). The model architecture is based on the standard Transformer Decoder, trained with a causal (autoregressive) language modeling objective, enabling the model to learn the vocabulary, grammar, semantics, and world knowledge embedded in text.
@ -152,7 +150,7 @@ $$
- $\theta$: Model parameters - $\theta$: Model parameters
- $P(x_t \mid x_{<t}; \theta)$: The probability of the model predicting the next token given the preceding context - $P(x_t \mid x_{<t}; \theta)$: The probability of the model predicting the next token given the preceding context
The core of this stage lies in utilizing distributed parallel computing resources to achieve stable optimization of model parameters. The `PTStrategy` in the trainer module is specifically responsible for managing pre-training-specific data sampling, long sequence segmentation, and gradient accumulation logic. At the same time, the hardware adaptation module automatically selects the optimal parallel communication backend (such as NCCL) based on the runtime environment (such as NVIDIA GPU cluster) and performs computation graph optimization to maximize hardware utilization and training throughput. The core of this stage lies in utilizing distributed parallel computing resources to achieve stable optimization of model parameters. The `SEQStrategy` (Pre-training) in the trainer module is specifically responsible for managing pre-training-specific data sampling, long sequence segmentation, and gradient accumulation logic. At the same time, the hardware adaptation module automatically selects the optimal parallel communication backend (such as NCCL) based on the runtime environment (such as NVIDIA GPU cluster) and performs computation graph optimization to maximize hardware utilization and training throughput.
Additionally, the system achieves zero-copy reading of massive data through the efficient memory-mapped loader (`MmapFileHandler`) in the data module, overcoming traditional IO bottlenecks. Additionally, the system achieves zero-copy reading of massive data through the efficient memory-mapped loader (`MmapFileHandler`) in the data module, overcoming traditional IO bottlenecks.
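As a concrete illustration of the objective above (not astrai's actual training code), the per-batch causal LM loss reduces to next-token cross-entropy over shifted logits:

```python
import torch
import torch.nn.functional as F

def causal_lm_loss(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Next-token cross-entropy: -sum_t log P(x_t | x_<t; theta).

    logits: (batch, seq_len, vocab_size); input_ids: (batch, seq_len).
    """
    shift_logits = logits[:, :-1, :]   # prediction at position t uses tokens < t
    shift_labels = input_ids[:, 1:]    # target at position t is the next token
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )
```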
@ -205,9 +203,9 @@ $$
- $\beta$: Temperature parameter (typically set to 0.1-0.5) - $\beta$: Temperature parameter (typically set to 0.1-0.5)
- Note: Implicitly learning reward function $r(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$ - Note: Implicitly learning reward function $r(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$
In this stage, the trainer module enables the `RLHFStrategy` (or the similar `DPOStrategy` for direct preference optimization). This strategy manages a complex training loop containing the policy model (the LLM to be optimized), a reference model (usually an SFT model snapshot), and a reward model. The system flow is as follows: In this stage, the trainer module enables the `DPOStrategy` (Direct Preference Optimization) or `GRPOStrategy` (Group Relative Policy Optimization). These strategies manage a complex training loop involving the policy model (the LLM to be optimized), a reference model (usually an SFT model snapshot), and a reward model. The system flow is as follows:
1. **Preference Data Collection and Reward Modeling**: First, human annotators rank multiple model-generated responses to the same prompt; from these preference rankings, a separate reward model (RM) is trained. This model learns to output a scalar reward score for generated text, quantifying its alignment with human preferences. 1. **Preference Data Collection and Reward Modeling**: First, human annotators rank multiple model-generated responses to the same prompt; from these preference rankings, a separate reward model (RM) can be trained. This model learns to output a scalar reward score for generated text, quantifying its alignment with human preferences.
2. **Policy Optimization**: Then, using the reward model as the optimization signal, the SFT model (as the policy) is fine-tuned through reinforcement learning algorithms. The goal of policy optimization is to maximize the expected cumulative reward from the reward model, while a KL-divergence penalty term keeps the policy model's output distribution from deviating too far from the reference model, preventing mode collapse and maintaining generation diversity. At this stage, the training context manager simultaneously maintains the states of the policy model, reference model, and reward model (or value-function model), and coordinates the complex multi-stage gradient computations. 2. **Policy Optimization**: Then, using the reward model as the optimization signal, the SFT model (as the policy) is fine-tuned through reinforcement learning algorithms (DPO/GRPO). The goal of policy optimization is to maximize the expected cumulative reward from the reward model, while a KL-divergence penalty term keeps the policy model's output distribution from deviating too far from the reference model, preventing mode collapse and maintaining generation diversity. At this stage, the training context manager simultaneously maintains the states of the policy model, reference model, and reward model (or value-function model), and coordinates the complex multi-stage gradient computations.
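For reference, here is a minimal sketch of the DPO objective implied by the implicit reward $r(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$ above. The inputs are summed per-sequence log-probabilities; the tensor names are assumptions for illustration, not astrai's API:

```python
import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps: torch.Tensor,    # log pi_theta(y_w | x)
    policy_rejected_logps: torch.Tensor,  # log pi_theta(y_l | x)
    ref_chosen_logps: torch.Tensor,       # log pi_ref(y_w | x)
    ref_rejected_logps: torch.Tensor,     # log pi_ref(y_l | x)
    beta: float = 0.1,
) -> torch.Tensor:
    # Implicit rewards for the preferred and dispreferred completions
    chosen_reward = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_reward = beta * (policy_rejected_logps - ref_rejected_logps)
    # Maximize the log-sigmoid margin between chosen and rejected rewards
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()
```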
Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly aligned dialogue agent. Through its unified `Trainer` interface and strategy-pattern design, the system makes each training stage highly reusable at the code level and clearly decoupled at the process level, providing an efficient, flexible, and scalable engineering foundation for large-scale language model research and iteration. Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly aligned dialogue agent. Through its unified `Trainer` interface and strategy-pattern design, the system makes each training stage highly reusable at the code level and clearly decoupled at the process level, providing an efficient, flexible, and scalable engineering foundation for large-scale language model research and iteration.

View File

@ -5,17 +5,17 @@ from astrai.config import (
ModelConfig, ModelConfig,
TrainConfig, TrainConfig,
) )
from astrai.model.transformer import Transformer from astrai.data import BpeTokenizer, DatasetLoader
from astrai.data import DatasetLoader, BpeTokenizer
from astrai.inference.generator import ( from astrai.inference.generator import (
GenerationRequest,
LoopGenerator,
StreamGenerator,
BatchGenerator, BatchGenerator,
EmbeddingEncoder, EmbeddingEncoder,
GenerationRequest,
GeneratorFactory, GeneratorFactory,
LoopGenerator,
StreamGenerator,
) )
from astrai.trainer import Trainer, StrategyFactory, SchedulerFactory from astrai.model.transformer import Transformer
from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
__all__ = [ __all__ = [
"Transformer", "Transformer",

View File

@ -1,14 +1,7 @@
from astrai.config.model_config import ModelConfig from astrai.config.model_config import ModelConfig
from astrai.config.param_config import BaseModelIO, ModelParameter from astrai.config.param_config import BaseModelIO, ModelParameter
from astrai.config.schedule_config import (
ScheduleConfig,
CosineScheduleConfig,
SGDRScheduleConfig,
ScheduleConfigFactory,
)
from astrai.config.train_config import TrainConfig from astrai.config.train_config import TrainConfig
__all__ = [ __all__ = [
# Base I/O # Base I/O
"BaseModelIO", "BaseModelIO",
@ -16,9 +9,4 @@ __all__ = [
# Model configuration # Model configuration
"ModelConfig", "ModelConfig",
"TrainConfig", "TrainConfig",
# Schedule configuration
"ScheduleConfig",
"CosineScheduleConfig",
"SGDRScheduleConfig",
"ScheduleConfigFactory",
] ]

View File

@ -1,5 +1,4 @@
import json import json
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from typing import Optional, Self from typing import Optional, Self

View File

@ -1,13 +1,13 @@
import torch.nn as nn
import safetensors.torch as st
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, Self, Union
from pathlib import Path from pathlib import Path
from typing import Self, Union
import safetensors.torch as st
import torch.nn as nn
from astrai.data.tokenizer import BpeTokenizer
from astrai.config.model_config import ModelConfig from astrai.config.model_config import ModelConfig
from astrai.data.tokenizer import BpeTokenizer
from astrai.model.transformer import Transformer from astrai.model.transformer import Transformer

View File

@ -1,149 +0,0 @@
from typing import Any, Dict, Type
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
@dataclass
class ScheduleConfig(ABC):
"""Base configuration class for learning rate schedulers.
Provides common validation and interface for all schedule types.
"""
schedule_type: str = field(
default="cosine",
metadata={
"help": "Type of learning rate schedule.",
"choices": ["cosine", "sgdr"],
},
)
warmup_steps: int = field(
default=1000, metadata={"help": "Number of warmup steps."}
)
min_rate: float = field(
default=0.05, metadata={"help": "Minimum learning rate multiplier."}
)
@abstractmethod
def get_kwargs(self) -> Dict[str, Any]:
"""Get configuration kwargs for scheduler creation."""
raise NotImplementedError
def validate(self) -> None:
"""Validate configuration parameters."""
if self.warmup_steps < 0:
raise ValueError(
f"warmup_steps must be non-negative, got {self.warmup_steps}"
)
if not 0 <= self.min_rate <= 1:
raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
@dataclass
class CosineScheduleConfig(ScheduleConfig):
"""Cosine annealing learning rate schedule configuration."""
total_steps: int = field(
default=None, metadata={"help": "Total training steps for cosine schedule."}
)
def __post_init__(self) -> None:
self.schedule_type = "cosine"
self.validate()
def get_kwargs(self) -> Dict[str, Any]:
if self.total_steps is None:
raise ValueError("total_steps must be specified for cosine schedule")
return {
"schedule_type": self.schedule_type,
"warmup_steps": self.warmup_steps,
"lr_decay_steps": self.total_steps - self.warmup_steps,
"min_rate": self.min_rate,
}
def validate(self) -> None:
super().validate()
if self.total_steps is not None and self.total_steps <= self.warmup_steps:
raise ValueError(
f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})"
)
@dataclass
class SGDRScheduleConfig(ScheduleConfig):
"""Stochastic Gradient Descent with Warm Restarts schedule configuration."""
cycle_length: int = field(
default=1000, metadata={"help": "Length of the first cycle in steps."}
)
t_mult: int = field(
default=2, metadata={"help": "Multiplier for cycle length growth."}
)
def __post_init__(self) -> None:
self.schedule_type = "sgdr"
self.validate()
def get_kwargs(self) -> Dict[str, Any]:
return {
"schedule_type": self.schedule_type,
"warmup_steps": self.warmup_steps,
"cycle_length": self.cycle_length,
"min_rate": self.min_rate,
"t_mult": self.t_mult,
}
def validate(self) -> None:
super().validate()
if self.cycle_length <= 0:
raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
if self.t_mult < 1:
raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
class ScheduleConfigFactory:
"""Factory class for creating ScheduleConfig instances.
Supports both direct instantiation and factory creation methods.
Example usage:
# Direct creation
config = CosineScheduleConfig(total_steps=10000)
# Factory method
config = ScheduleConfigFactory.create("cosine", total_steps=10000)
"""
CONFIG_MAP: Dict[str, Type[ScheduleConfig]] = {
"cosine": CosineScheduleConfig,
"sgdr": SGDRScheduleConfig,
}
@classmethod
def create(cls, schedule_type: str, **kwargs) -> ScheduleConfig:
"""Create a schedule config instance.
Args:
schedule_type: Type of schedule ("cosine", "sgdr")
**kwargs: Arguments passed to the config constructor
Returns:
ScheduleConfig instance
Raises:
ValueError: If schedule_type is not supported
"""
if schedule_type not in cls.CONFIG_MAP:
raise ValueError(
f"Unknown schedule type: '{schedule_type}'. "
f"Supported types: {sorted(cls.CONFIG_MAP.keys())}"
)
config_cls = cls.CONFIG_MAP[schedule_type]
return config_cls(**kwargs)
@classmethod
def available_types(cls) -> list:
"""Return list of available schedule type names."""
return list(cls.CONFIG_MAP.keys())
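With `schedule_config.py` removed, the config-object indirection goes away: scheduler parameters are now passed directly to `SchedulerFactory.create`. A before/after sketch of the migration, assembled from the call sites changed elsewhere in this commit (concrete values illustrative):

```python
import torch

from astrai.trainer import SchedulerFactory

model = torch.nn.Linear(10, 2)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Before this commit: build a config object, then SchedulerFactory.load(...)
# config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
# scheduler = SchedulerFactory.load(optimizer, config)

# After this commit: pass the kwargs directly
# (lr_decay_steps = total_steps - warmup_steps)
scheduler = SchedulerFactory.create(
    optimizer,
    "cosine",
    warmup_steps=100,
    lr_decay_steps=900,
    min_rate=0.1,
)
```

The test changes at the end of this diff follow exactly this pattern.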

View File

@ -1,11 +1,11 @@
import torch.nn as nn
from torch.utils.data import Dataset
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Callable, List, Optional from typing import Callable, List, Optional
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset
@dataclass @dataclass
class TrainConfig: class TrainConfig:

View File

@ -1,16 +1,15 @@
from astrai.data.dataset import ( from astrai.data.dataset import (
BaseDataset, BaseDataset,
SEQDataset, DatasetFactory,
DatasetLoader,
DPODataset, DPODataset,
SFTDataset,
GRPODataset, GRPODataset,
MultiSegmentFetcher, MultiSegmentFetcher,
DatasetLoader, SEQDataset,
DatasetFactory, SFTDataset,
) )
from astrai.data.tokenizer import BpeTokenizer
from astrai.data.sampler import ResumableDistributedSampler from astrai.data.sampler import ResumableDistributedSampler
from astrai.data.tokenizer import BpeTokenizer
__all__ = [ __all__ = [
# Base classes # Base classes

View File

@ -1,13 +1,14 @@
"""Dataset implementations with factory pattern for training.""" """Dataset implementations with factory pattern for training."""
import torch
import bisect import bisect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union
import torch
from torch import Tensor from torch import Tensor
from torch.utils.data import Dataset from torch.utils.data import Dataset
from astrai.data.serialization import load_h5 from astrai.data.serialization import load_h5
from typing import List, Dict, Optional, Union
class BaseSegmentFetcher: class BaseSegmentFetcher:

View File

@ -1,8 +1,8 @@
from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.utils.data import Dataset, Sampler from torch.utils.data import Dataset, Sampler
from typing import Optional
class ResumableDistributedSampler(Sampler[int]): class ResumableDistributedSampler(Sampler[int]):

View File

@ -1,13 +1,14 @@
import os
import h5py
import torch
import json import json
import safetensors.torch as st import os
import torch.distributed as dist
from pathlib import Path from pathlib import Path
from torch import Tensor
from typing import Any, Dict, List from typing import Any, Dict, List
import h5py
import safetensors.torch as st
import torch
import torch.distributed as dist
from torch import Tensor
from astrai.parallel.setup import get_rank from astrai.parallel.setup import get_rank

View File

@ -1,8 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from tokenizers import Tokenizer, decoders, processors, normalizers, pre_tokenizers from typing import List, Union
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer as BpeTrainerImpl from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
from typing import List, Union
class BaseTokenizer(ABC): class BaseTokenizer(ABC):

View File

@ -1,16 +1,15 @@
from astrai.inference.core import ( from astrai.inference.core import (
GeneratorCore,
EmbeddingEncoderCore, EmbeddingEncoderCore,
GeneratorCore,
KVCacheManager, KVCacheManager,
) )
from astrai.inference.generator import ( from astrai.inference.generator import (
GenerationRequest,
LoopGenerator,
StreamGenerator,
BatchGenerator, BatchGenerator,
EmbeddingEncoder, EmbeddingEncoder,
GenerationRequest,
GeneratorFactory, GeneratorFactory,
LoopGenerator,
StreamGenerator,
) )
__all__ = [ __all__ = [

View File

@ -1,8 +1,9 @@
import torch from typing import Any, Callable, List, Optional, Self, Tuple, Union
import torch
from torch import Tensor from torch import Tensor
from typing import Any, Callable, List, Tuple, Union, Optional, Self
from astrai.config import ModelParameter, ModelConfig from astrai.config import ModelConfig, ModelParameter
def apply_sampling_strategies( def apply_sampling_strategies(

View File

@ -1,10 +1,11 @@
import torch
from dataclasses import dataclass from dataclasses import dataclass
from torch import Tensor from typing import Generator, List, Optional, Tuple, Union
from typing import List, Tuple, Union, Optional, Generator
from astrai.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
from astrai.config.param_config import ModelParameter
import torch
from torch import Tensor
from astrai.config.param_config import ModelParameter
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
HistoryType = List[Tuple[str, str]] HistoryType = List[Tuple[str, str]]

View File

@ -1,13 +1,15 @@
import torch
import uvicorn
import logging import logging
from pathlib import Path from pathlib import Path
from typing import List, Optional, Dict, Any, Tuple from typing import Any, Dict, List, Optional, Tuple
import torch
import uvicorn
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from astrai.config.param_config import ModelParameter from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GeneratorFactory, GenerationRequest from astrai.inference.generator import GenerationRequest, GeneratorFactory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,9 +1,9 @@
from astrai.model.module import ( from astrai.model.module import (
GQA,
MLP,
DecoderBlock,
Linear, Linear,
RMSNorm, RMSNorm,
MLP,
GQA,
DecoderBlock,
) )
from astrai.model.transformer import Transformer from astrai.model.transformer import Transformer

View File

@ -1,9 +1,9 @@
from typing import Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from typing import Optional, Tuple
def repeat_kv(x: Tensor, n_rep: int) -> Tensor: def repeat_kv(x: Tensor, n_rep: int) -> Tensor:

View File

@ -1,12 +1,13 @@
from typing import Any, Mapping, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
from typing import Any, Mapping, Optional, Tuple
from astrai.config.model_config import ModelConfig from astrai.config.model_config import ModelConfig
from astrai.model.module import ( from astrai.model.module import (
Embedding,
DecoderBlock, DecoderBlock,
Embedding,
Linear, Linear,
RMSNorm, RMSNorm,
RotaryEmbedding, RotaryEmbedding,

View File

@ -1,14 +1,13 @@
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
from astrai.parallel.setup import ( from astrai.parallel.setup import (
get_world_size,
get_rank,
get_current_device, get_current_device,
get_rank,
get_world_size,
only_on_rank, only_on_rank,
setup_parallel, setup_parallel,
spawn_parallel_fn, spawn_parallel_fn,
) )
from astrai.parallel.module import RowParallelLinear, ColumnParallelLinear
__all__ = [ __all__ = [
"get_world_size", "get_world_size",
"get_rank", "get_rank",

View File

@ -1,10 +1,10 @@
from typing import Dict
import torch import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.distributed as dist
from torch import Tensor from torch import Tensor
from typing import Dict
class ParallelModel(nn.Module): class ParallelModel(nn.Module):

View File

@ -1,12 +1,12 @@
import os import os
from contextlib import contextmanager
from functools import wraps
from typing import Callable, List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
from functools import wraps
from contextlib import contextmanager
from typing import Callable, List, Optional
def get_current_device(): def get_current_device():
return os.environ["LOCAL_DEVICE"] return os.environ["LOCAL_DEVICE"]

View File

@ -1,15 +1,14 @@
from astrai.trainer.trainer import Trainer from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
from astrai.trainer.strategy import StrategyFactory, BaseStrategy from astrai.trainer.strategy import BaseStrategy, StrategyFactory
from astrai.trainer.schedule import SchedulerFactory, BaseScheduler
from astrai.trainer.train_callback import ( from astrai.trainer.train_callback import (
TrainCallback,
GradientClippingCallback,
SchedulerCallback,
CheckpointCallback, CheckpointCallback,
ProgressBarCallback, GradientClippingCallback,
MetricLoggerCallback, MetricLoggerCallback,
ProgressBarCallback,
SchedulerCallback,
TrainCallback,
) )
from astrai.trainer.trainer import Trainer
__all__ = [ __all__ = [
# Main trainer # Main trainer

View File

@ -1,6 +1,7 @@
import torch.nn as nn
from typing import Dict from typing import Dict
import torch.nn as nn
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]: def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
"""Compute gradient norm for each parameter in the model.""" """Compute gradient norm for each parameter in the model."""

View File

@ -1,10 +1,10 @@
"""Learning rate scheduler implementations with factory pattern.""" """Learning rate scheduler implementations with factory pattern."""
import math import math
from abc import abstractmethod, ABC from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type from typing import Any, Dict, List, Type
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
from astrai.config.schedule_config import ScheduleConfig
class BaseScheduler(LRScheduler, ABC): class BaseScheduler(LRScheduler, ABC):
@ -37,10 +37,6 @@ class SchedulerFactory:
... ...
scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs) scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)
# Or from config
config = CosineScheduleConfig(total_steps=10000)
scheduler = SchedulerFactory.load(optimizer, config)
""" """
SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {} SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {}
@ -67,7 +63,7 @@ class SchedulerFactory:
return decorator return decorator
@classmethod @classmethod
def create(cls, optimizer, schedule_type: str, **kwargs) -> BaseScheduler: def create(cls, optimizer, schedule_type: str = "none", **kwargs) -> BaseScheduler:
"""Create a scheduler instance by type name. """Create a scheduler instance by type name.
Args: Args:
@ -90,29 +86,13 @@ class SchedulerFactory:
scheduler_cls = cls.SCHEDULER_MAP[schedule_type] scheduler_cls = cls.SCHEDULER_MAP[schedule_type]
return scheduler_cls(optimizer, **kwargs) return scheduler_cls(optimizer, **kwargs)
@staticmethod
def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
"""Create a scheduler from a ScheduleConfig object.
Args:
optimizer: PyTorch optimizer
schedule_config: ScheduleConfig instance
Returns:
Scheduler instance
"""
kwargs = schedule_config.get_kwargs()
schedule_type = kwargs.pop("schedule_type")
return SchedulerFactory.create(optimizer, schedule_type, **kwargs)
@classmethod @classmethod
def available_types(cls) -> list: def available_types(cls) -> list:
"""Return list of registered scheduler type names.""" """Return list of registered scheduler type names."""
return list(cls.SCHEDULER_MAP.keys()) return list(cls.SCHEDULER_MAP.keys())
# ============== Scheduler Classes ============== # ----------- Scheduler implementations -----------
# All scheduler classes are registered at class definition time using the decorator
@SchedulerFactory.register("cosine") @SchedulerFactory.register("cosine")
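The registration decorator shown above is also the extension point for custom schedulers. A hypothetical sketch, assuming `BaseScheduler` follows PyTorch's `LRScheduler` contract (a `get_lr()` returning one rate per param group); the `"constant"` name and class are illustrative, not part of astrai:

```python
from typing import List

from astrai.trainer import BaseScheduler, SchedulerFactory

@SchedulerFactory.register("constant")
class ConstantScheduler(BaseScheduler):
    """Keeps the learning rate at its initial value (illustrative only)."""

    def get_lr(self) -> List[float]:
        return list(self.base_lrs)

# scheduler = SchedulerFactory.create(optimizer, "constant")
```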

View File

@ -1,14 +1,14 @@
"""Training strategy implementations with factory pattern.""" """Training strategy implementations with factory pattern."""
import copy import copy
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch import Tensor from torch import Tensor
from typing import Any, Callable, Dict, Union from torch.nn.parallel import DistributedDataParallel as DDP
from abc import ABC, abstractmethod
def unwrap_model(model: nn.Module) -> nn.Module: def unwrap_model(model: nn.Module) -> nn.Module:

View File

@ -1,25 +1,25 @@
import os
import json import json
import os
import time import time
import torch.nn as nn
from pathlib import Path from pathlib import Path
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_
from typing import Callable, List, Optional, Protocol from typing import Callable, List, Optional, Protocol
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
from astrai.data.serialization import Checkpoint
from astrai.parallel import only_on_rank from astrai.parallel import only_on_rank
from astrai.trainer.metric_util import ( from astrai.trainer.metric_util import (
ctx_get_grad_max,
ctx_get_grad_mean,
ctx_get_grad_min,
ctx_get_grad_nan_num,
ctx_get_grad_norm,
ctx_get_grad_std,
ctx_get_loss, ctx_get_loss,
ctx_get_lr, ctx_get_lr,
ctx_get_grad_max,
ctx_get_grad_min,
ctx_get_grad_norm,
ctx_get_grad_mean,
ctx_get_grad_std,
ctx_get_grad_nan_num,
) )
from astrai.data.serialization import Checkpoint
from astrai.trainer.train_context import TrainContext from astrai.trainer.train_context import TrainContext

View File

@ -1,16 +1,16 @@
from dataclasses import dataclass, field
from typing import Optional, Self
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from astrai.config.train_config import TrainConfig
from astrai.data import ResumableDistributedSampler from astrai.data import ResumableDistributedSampler
from astrai.data.serialization import Checkpoint from astrai.data.serialization import Checkpoint
from astrai.trainer.strategy import StrategyFactory, BaseStrategy from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.config.train_config import TrainConfig from astrai.trainer.strategy import BaseStrategy, StrategyFactory
from astrai.parallel.setup import get_current_device, get_world_size, get_rank
from dataclasses import dataclass, field
from typing import Optional, Self
@dataclass @dataclass

View File

@ -1,17 +1,18 @@
import logging import logging
from typing import Optional, List from typing import List, Optional
from astrai.config import TrainConfig from astrai.config import TrainConfig
from astrai.trainer.train_callback import (
TrainCallback,
ProgressBarCallback,
CheckpointCallback,
MetricLoggerCallback,
GradientClippingCallback,
SchedulerCallback,
)
from astrai.trainer.train_context import TrainContext, TrainContextBuilder
from astrai.data.serialization import Checkpoint from astrai.data.serialization import Checkpoint
from astrai.parallel.setup import spawn_parallel_fn from astrai.parallel.setup import spawn_parallel_fn
from astrai.trainer.train_callback import (
CheckpointCallback,
GradientClippingCallback,
MetricLoggerCallback,
ProgressBarCallback,
SchedulerCallback,
TrainCallback,
)
from astrai.trainer.train_context import TrainContext, TrainContextBuilder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,4 +1,5 @@
from pathlib import Path from pathlib import Path
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
PROJECT_ROOT = Path(__file__).resolve().parents[2] PROJECT_ROOT = Path(__file__).resolve().parents[2]

View File

@ -1,7 +1,9 @@
import torch
from pathlib import Path from pathlib import Path
import torch
from astrai.config.param_config import ModelParameter from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GeneratorFactory, GenerationRequest from astrai.inference.generator import GenerationRequest, GeneratorFactory
PROJECT_ROOT = Path(__file__).resolve().parents[2] PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params") PARAMETER_ROOT = Path(PROJECT_ROOT, "params")

View File

@ -1,7 +1,9 @@
import torch
from pathlib import Path from pathlib import Path
import torch
from astrai.config.param_config import ModelParameter from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GeneratorFactory, GenerationRequest from astrai.inference.generator import GenerationRequest, GeneratorFactory
PROJECT_ROOT = Path(__file__).resolve().parents[2] PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params") PARAMETER_ROOT = Path(PROJECT_ROOT, "params")

View File

@ -1,7 +1,9 @@
import torch
from pathlib import Path from pathlib import Path
import torch
from astrai.config.param_config import ModelParameter from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GeneratorFactory, GenerationRequest from astrai.inference.generator import GenerationRequest, GeneratorFactory
PROJECT_ROOT = Path(__file__).resolve().parents[2] PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params") PARAMETER_ROOT = Path(PROJECT_ROOT, "params")

View File

@ -1,6 +1,8 @@
import torch
from typing import Dict, Any
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict
import torch
from astrai.model.transformer import ModelConfig, Transformer from astrai.model.transformer import ModelConfig, Transformer

View File

@ -1,6 +1,7 @@
import torch
import json
import argparse import argparse
import json
import torch
from astrai.config.param_config import ModelParameter from astrai.config.param_config import ModelParameter
from astrai.inference.generator import BatchGenerator, GenerationRequest from astrai.inference.generator import BatchGenerator, GenerationRequest

View File

@ -1,11 +1,12 @@
import argparse
import json import json
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import argparse
import tqdm import tqdm
from torch import Tensor from torch import Tensor
from astrai.config.param_config import ModelParameter from astrai.config.param_config import ModelParameter

View File

@ -1,5 +1,6 @@
import argparse import argparse
from pathlib import Path from pathlib import Path
from astrai.inference.server import run_server from astrai.inference.server import run_server

View File

@ -1,15 +1,16 @@
import os
import argparse import argparse
import os
from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from functools import partial from astrai.config import ModelParameter, TrainConfig
from astrai.data import DatasetLoader from astrai.data import DatasetLoader
from astrai.config import ModelParameter, TrainConfig, CosineScheduleConfig
from astrai.trainer import Trainer, SchedulerFactory
from astrai.parallel import get_rank from astrai.parallel import get_rank
from astrai.trainer import SchedulerFactory, Trainer
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:
@ -158,7 +159,7 @@ def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
def create_scheduler( def create_scheduler(
optimizer: optim.Optimizer, **kwargs optimizer: optim.Optimizer, **kwargs
) -> optim.lr_scheduler.LRScheduler: ) -> optim.lr_scheduler.LRScheduler:
return SchedulerFactory.load(optimizer, **kwargs) return SchedulerFactory.create(optimizer, **kwargs)
def prepare_checkpoint(model: nn.Module) -> dict: def prepare_checkpoint(model: nn.Module) -> dict:
@ -211,11 +212,6 @@ def train(
stride=stride, stride=stride,
) )
schedule_config = CosineScheduleConfig(
warmup_steps=warmup_steps,
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
)
optimizer_fn = partial( optimizer_fn = partial(
create_optimizer, create_optimizer,
**{ **{
@ -224,7 +220,16 @@ def train(
"weight_decay": adamw_weight_decay, "weight_decay": adamw_weight_decay,
}, },
) )
scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config})
total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
scheduler_fn = partial(
create_scheduler,
**{
"scheduler": "cosine",
"warmup_steps": warmup_steps,
"lr_decay_steps": toltal_steps - warmup_steps,
},
)
train_config = TrainConfig( train_config = TrainConfig(
model=model, model=model,

View File

@ -1,11 +1,12 @@
import os
import json import json
import numpy as np import os
import tempfile
import shutil import shutil
import torch import tempfile
import numpy as np
import pytest import pytest
import safetensors.torch as st import safetensors.torch as st
import torch
from tokenizers import pre_tokenizers from tokenizers import pre_tokenizers
from torch.utils.data import Dataset from torch.utils.data import Dataset

View File

@ -1,9 +1,10 @@
import torch
import tempfile import tempfile
import torch.distributed as dist
import torch
import torch.distributed as dist
from torch.optim import AdamW from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
from astrai.data.serialization import Checkpoint from astrai.data.serialization import Checkpoint
from astrai.parallel.setup import get_rank, spawn_parallel_fn from astrai.parallel.setup import get_rank, spawn_parallel_fn

View File

@ -1,8 +1,8 @@
import torch
import numpy as np import numpy as np
import torch
from astrai.data.serialization import save_h5
from astrai.data.dataset import * from astrai.data.dataset import *
from astrai.data.serialization import save_h5
def test_dataset_loader_random_paths(base_test_env): def test_dataset_loader_random_paths(base_test_env):

View File

@ -1,5 +1,5 @@
from astrai.trainer import *
from astrai.data import * from astrai.data import *
from astrai.trainer import *
def test_random_sampler_consistency(random_dataset): def test_random_sampler_consistency(random_dataset):

View File

@ -1,8 +1,10 @@
"""Shared fixtures for inference tests.""" """Shared fixtures for inference tests."""
import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from astrai.inference.server import app from astrai.inference.server import app

View File

@ -1,9 +1,6 @@
"""Unit tests for the inference HTTP server.""" """Unit tests for the inference HTTP server."""
import pytest import pytest
from unittest.mock import MagicMock, patch
from fastapi.testclient import TestClient
from astrai.inference.server import app
def test_health_no_model(client, monkeypatch): def test_health_no_model(client, monkeypatch):

View File

@ -1,10 +1,12 @@
import os import os
import torch import torch
from astrai.trainer import *
from astrai.config import * from astrai.config import *
from astrai.model import *
from astrai.data import * from astrai.data import *
from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore
from astrai.model import *
from astrai.trainer import *
def test_model_parameter(test_env): def test_model_parameter(test_env):

View File

@ -1,11 +1,13 @@
import os
import json import json
import torch import os
import pytest
import tempfile import tempfile
import pytest
import safetensors.torch as st import safetensors.torch as st
from astrai.model.transformer import Transformer import torch
from astrai.config.model_config import ModelConfig from astrai.config.model_config import ModelConfig
from astrai.model.transformer import Transformer
@pytest.fixture @pytest.fixture

View File

@ -1,6 +1,9 @@
import pytest
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
import pytest
from astrai.config import TrainConfig
from astrai.trainer.schedule import SchedulerFactory
class TrainerDataset(Dataset): class TrainerDataset(Dataset):
@ -54,13 +57,11 @@ def create_train_config(
Returns: Returns:
TrainConfig instance configured for testing TrainConfig instance configured for testing
""" """
from astrai.config import TrainConfig
from astrai.config.schedule_config import CosineScheduleConfig
from astrai.trainer.schedule import SchedulerFactory
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
optimizer_fn = lambda m: torch.optim.AdamW(m.parameters(), lr=0.001) optimizer_fn = lambda m: torch.optim.AdamW(m.parameters(), lr=0.001)
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) scheduler_fn = lambda optim: SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)
return TrainConfig( return TrainConfig(
strategy=strategy, strategy=strategy,

View File

@ -6,10 +6,10 @@ from astrai.trainer import *
def test_callback_integration(base_test_env, random_dataset): def test_callback_integration(base_test_env, random_dataset):
"""Test that all callbacks are properly integrated""" """Test that all callbacks are properly integrated"""
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) scheduler_fn = lambda optim: SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)
train_config = TrainConfig( train_config = TrainConfig(
model=base_test_env["model"], model=base_test_env["model"],

View File

@ -1,18 +1,20 @@
import os import os
import torch
import numpy as np import numpy as np
import torch
from astrai.config import * from astrai.config import *
from astrai.trainer import *
from astrai.data.serialization import Checkpoint from astrai.data.serialization import Checkpoint
from astrai.trainer import *
def test_early_stopping_simulation(base_test_env, early_stopping_dataset): def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
"""Simulate early stopping behavior""" """Simulate early stopping behavior"""
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) scheduler_fn = lambda optim: SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)
train_config = TrainConfig( train_config = TrainConfig(
strategy="seq", strategy="seq",

View File

@ -1,10 +1,9 @@
import torch
import numpy as np import numpy as np
import pytest import torch
from astrai.config import * from astrai.config import *
from astrai.trainer.schedule import *
from astrai.data.dataset import * from astrai.data.dataset import *
from astrai.trainer.schedule import *
def test_schedule_factory_random_configs(): def test_schedule_factory_random_configs():
@ -16,41 +15,57 @@ def test_schedule_factory_random_configs():
# Test multiple random configurations # Test multiple random configurations
for _ in range(5): # Test 5 random configurations for _ in range(5): # Test 5 random configurations
schedule_configs = [ # Test multiple random configurations
CosineScheduleConfig( cosine_params = {
warmup_steps=np.random.randint(50, 200), "schedule_type": "cosine",
total_steps=np.random.randint(1000, 5000), "warmup_steps": np.random.randint(50, 200),
min_rate=np.random.uniform(0.01, 0.1), "total_steps": np.random.randint(1000, 5000),
), "min_rate": np.random.uniform(0.01, 0.1),
SGDRScheduleConfig( }
warmup_steps=np.random.randint(50, 200), sgdr_params = {
cycle_length=np.random.randint(500, 2000), "schedule_type": "sgdr",
t_mult=np.random.randint(1, 3), "warmup_steps": np.random.randint(50, 200),
min_rate=np.random.uniform(0.01, 0.1), "cycle_length": np.random.randint(500, 2000),
), "t_mult": np.random.randint(1, 3),
] "min_rate": np.random.uniform(0.01, 0.1),
}
for config in schedule_configs: for params in [cosine_params, sgdr_params]:
# Validate configuration schedule_type = params["schedule_type"]
config.validate() # Convert parameters for scheduler constructor
if schedule_type == "cosine":
# Create scheduler using factory warmup_steps = params["warmup_steps"]
scheduler = SchedulerFactory.load(optimizer, config) total_steps = params["total_steps"]
min_rate = params["min_rate"]
# Verify scheduler type lr_decay_steps = total_steps - warmup_steps
if isinstance(config, CosineScheduleConfig): scheduler = SchedulerFactory.create(
assert isinstance(scheduler, CosineScheduler) optimizer,
assert scheduler.warmup_steps == config.warmup_steps schedule_type,
assert ( warmup_steps=warmup_steps,
scheduler.lr_decay_steps == config.total_steps - config.warmup_steps lr_decay_steps=lr_decay_steps,
min_rate=min_rate,
)
assert isinstance(scheduler, CosineScheduler)
assert scheduler.warmup_steps == warmup_steps
assert scheduler.lr_decay_steps == lr_decay_steps
assert scheduler.min_rate == min_rate
elif schedule_type == "sgdr":
warmup_steps = params["warmup_steps"]
cycle_length = params["cycle_length"]
t_mult = params["t_mult"]
min_rate = params["min_rate"]
scheduler = SchedulerFactory.create(
optimizer,
schedule_type,
warmup_steps=warmup_steps,
cycle_length=cycle_length,
t_mult=t_mult,
min_rate=min_rate,
) )
assert scheduler.min_rate == config.min_rate
elif isinstance(config, SGDRScheduleConfig):
assert isinstance(scheduler, SGDRScheduler) assert isinstance(scheduler, SGDRScheduler)
assert scheduler.warmup_steps == config.warmup_steps assert scheduler.warmup_steps == warmup_steps
assert scheduler.cycle_length == config.cycle_length assert scheduler.cycle_length == cycle_length
assert scheduler.t_mult == config.t_mult assert scheduler.t_mult == t_mult
assert scheduler.min_rate == config.min_rate assert scheduler.min_rate == min_rate
# Test scheduler state dict functionality # Test scheduler state dict functionality
state_dict = scheduler.state_dict() state_dict = scheduler.state_dict()
@ -76,16 +91,25 @@ def test_schedule_factory_edge_cases():
# Test edge cases for CosineScheduleConfig # Test edge cases for CosineScheduleConfig
edge_cases = [ edge_cases = [
# Minimal warmup and steps # Minimal warmup and steps
CosineScheduleConfig(warmup_steps=1, total_steps=10, min_rate=0.01), {"warmup_steps": 1, "total_steps": 10, "min_rate": 0.01},
# Large values # Large values
CosineScheduleConfig(warmup_steps=1000, total_steps=10000, min_rate=0.5), {"warmup_steps": 1000, "total_steps": 10000, "min_rate": 0.5},
# Zero min_rate (edge case) # Zero min_rate (edge case)
CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.0), {"warmup_steps": 100, "total_steps": 1000, "min_rate": 0.0},
] ]
for config in edge_cases: for params in edge_cases:
config.validate() warmup_steps = params["warmup_steps"]
scheduler = SchedulerFactory.load(optimizer, config) total_steps = params["total_steps"]
min_rate = params["min_rate"]
lr_decay_steps = total_steps - warmup_steps
scheduler = SchedulerFactory.create(
optimizer,
"cosine",
warmup_steps=warmup_steps,
lr_decay_steps=lr_decay_steps,
min_rate=min_rate,
)
assert scheduler is not None assert scheduler is not None
# Test multiple steps # Test multiple steps
@ -93,34 +117,24 @@ def test_schedule_factory_edge_cases():
scheduler.step() scheduler.step()
def test_schedule_factory_invalid_configs():
"""Test scheduler factory with invalid configurations"""
# Test invalid configurations that should raise errors
invalid_configs = [
# Negative warmup steps
{"warmup_steps": -10, "total_steps": 1000, "min_rate": 0.1},
# Total steps less than warmup steps
{"warmup_steps": 500, "total_steps": 400, "min_rate": 0.1},
# Invalid min_rate
{"warmup_steps": 100, "total_steps": 1000, "min_rate": -0.1},
{"warmup_steps": 100, "total_steps": 1000, "min_rate": 1.1},
]
for kwargs in invalid_configs:
with pytest.raises(ValueError):
config = CosineScheduleConfig(**kwargs)
config.validate()
def test_schedule_factory_state_persistence(): def test_schedule_factory_state_persistence():
"""Test scheduler state persistence (save/load)""" """Test scheduler state persistence (save/load)"""
model = torch.nn.Linear(10, 2) model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1) # Create scheduler directly with parameters
scheduler = SchedulerFactory.load(optimizer, config) warmup_steps = 100
total_steps = 1000
min_rate = 0.1
lr_decay_steps = total_steps - warmup_steps
scheduler = SchedulerFactory.create(
optimizer,
"cosine",
warmup_steps=warmup_steps,
lr_decay_steps=lr_decay_steps,
min_rate=min_rate,
)
# Take a few steps # Take a few steps
for _ in range(5): for _ in range(5):
@ -129,8 +143,14 @@ def test_schedule_factory_state_persistence():
# Save state # Save state
state_dict = scheduler.state_dict() state_dict = scheduler.state_dict()
# Create new scheduler and load state # Create new scheduler with same parameters
new_scheduler = SchedulerFactory.load(optimizer, config) new_scheduler = SchedulerFactory.create(
optimizer,
"cosine",
warmup_steps=warmup_steps,
lr_decay_steps=lr_decay_steps,
min_rate=min_rate,
)
new_scheduler.load_state_dict(state_dict) new_scheduler.load_state_dict(state_dict)
# Verify states match # Verify states match

View File

@ -1,5 +1,3 @@
import torch
from astrai.data.dataset import * from astrai.data.dataset import *
from astrai.trainer import Trainer from astrai.trainer import Trainer