refactor: optimize parameter passing, clean up import style

This commit is contained in:
parent 3a7d98a950
commit 0852b852f8
@@ -24,7 +24,6 @@ flowchart TB
C1[model_config.py<br/>Model Architecture]
C2[train_config.py<br/>Training Params]
C3[param_config.py<br/>Hyperparameters]
C4[schedule_config.py<br/>Scheduler Config]
end

subgraph Data["Data Module (data/)"]
@@ -95,14 +94,13 @@ flowchart TB

### 1. Configuration Module (config/)
- **model_config.py**: Defines model structure parameters (layers, heads, dimensions, etc.), managed through `ModelConfig`.
- **train_config.py**: Sets training parameters (batch size, training stages PT/SFT/DPO, optimizers, etc.), loaded by `TrainConfig`.
- **train_config.py**: Sets training parameters (batch size, training stages SEQ/SFT/GRPO/DPO, optimizers, etc.), loaded by `TrainConfig`.
- **param_config.py**: Manages hyperparameters for training and inference.
- **schedule_config.py**: Controls learning rate strategies (cosine annealing) and training progress.

### 2. Data Module (data/)
- **dataset.py**: Dataset handling and loading.
- **sampler.py**: Data sampling for different training stages.
- **serialization.py**: Data serialization and deserialization.
- **serialization.py**: Data serialization and deserialization, checkpoint management.
- **tokenizer.py**: Text tokenization and encoding.

### 3. Model Module (model/)
@@ -112,15 +110,15 @@ flowchart TB
### 4. Trainer Module (trainer/)
- **trainer.py**: Main training entry point.
- **train_context.py**: Training context management (model, optimizer, scheduler, metrics).
- **strategy.py**: Training strategies for PT/SFT/DPO stages.
- **schedule.py**: Learning rate scheduler.
- **strategy.py**: Training strategies for SEQ/SFT/GRPO/DPO stages via `StrategyFactory`.
- **schedule.py**: Learning rate scheduler implementation (cosine, SGDR, etc.).
- **train_callback.py**: Training callbacks (checkpoint, early stopping, etc.).
- **metric_util.py**: Metrics calculation and tracking.

### 5. Inference Module (inference/)
- **generator.py**: Text generation with various methods (sync, batch, streaming).
- **core.py**: Inference core with KV cache optimization.
- **server.py**: API service for inference.
- **server.py**: API service for inference (FastAPI + Uvicorn).

### 6. Parallel Module (parallel/)
- **setup.py**: Distributed initialization for multi-GPU/multi-machine training.
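As an aside, the registry-plus-factory dispatch that `StrategyFactory` provides can be sketched minimally as follows. This is an illustrative sketch only: apart from the `StrategyFactory`, `BaseStrategy`, and `SEQStrategy` names mentioned in the docs above, every detail here is an assumption, not code from this repository.

```python
from abc import ABC, abstractmethod
from typing import Dict, Type

class BaseStrategy(ABC):
    """Per-stage training behavior behind a common interface."""

    @abstractmethod
    def compute_loss(self, batch) -> float: ...

class StrategyFactory:
    """Registry-based factory: strategies self-register by stage name."""

    STRATEGY_MAP: Dict[str, Type[BaseStrategy]] = {}

    @classmethod
    def register(cls, name: str):
        def decorator(strategy_cls):
            cls.STRATEGY_MAP[name] = strategy_cls
            return strategy_cls
        return decorator

    @classmethod
    def create(cls, name: str, **kwargs) -> BaseStrategy:
        if name not in cls.STRATEGY_MAP:
            raise ValueError(f"Unknown strategy: {name!r}")
        return cls.STRATEGY_MAP[name](**kwargs)

@StrategyFactory.register("seq")
class SEQStrategy(BaseStrategy):
    def compute_loss(self, batch) -> float:
        return 0.0  # placeholder; the real trainer computes a causal LM loss

strategy = StrategyFactory.create("seq")
```

Swapping training stages then reduces to passing a different stage name to `create`, which is what keeps the stages decoupled at the process level.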
@@ -134,7 +132,7 @@ flowchart TB

The common training process for large language models (LLMs) typically includes three stages: **Pre-training (PT)**, **Supervised Fine-Tuning (SFT)**, and **Reinforcement Learning from Human Feedback (RLHF)**. This system is designed to support a seamless end-to-end flow, achieving efficient switching and state management across training stages through modular strategies, so that the model's capabilities gradually evolve from general language understanding to human-preference-aligned dialogue and instruction execution.

### **2.1 Pre-training Stage**
### **2.1 Pre-training Stage (SEQ/PT)**

The pre-training stage aims to build the model's foundational language capabilities and general knowledge representation. This stage performs self-supervised learning on large-scale, unlabeled corpora (typically covering hundreds of GB to TB of text data). The model architecture is based on the standard Transformer decoder, trained with an autoregressive (causal) language modeling objective, enabling the model to learn vocabulary, grammar, semantics, and world knowledge embedded in text.
@@ -152,7 +150,7 @@
- $\theta$: Model parameters
- $P(x_t \mid x_{<t}; \theta)$: The probability of the model predicting the next token given the preceding context
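This next-token objective amounts to a shifted cross-entropy over the vocabulary. A minimal PyTorch sketch (illustrative only; not code from this commit):

```python
import torch
import torch.nn.functional as F

def causal_lm_loss(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Mean negative log-likelihood of the next token.

    logits: (batch, seq_len, vocab) model outputs; tokens: (batch, seq_len) ids.
    """
    # Position t predicts token t+1: drop the last logit, shift targets right.
    shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
    shift_targets = tokens[:, 1:].reshape(-1)
    return F.cross_entropy(shift_logits, shift_targets)

logits = torch.randn(2, 8, 100)           # toy batch: 2 sequences, vocab of 100
tokens = torch.randint(0, 100, (2, 8))
loss = causal_lm_loss(logits, tokens)     # scalar tensor
```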

The core of this stage lies in utilizing distributed parallel computing resources to achieve stable optimization of model parameters. The `PTStrategy` in the trainer module is specifically responsible for managing pre-training-specific data sampling, long sequence segmentation, and gradient accumulation logic. At the same time, the hardware adaptation module automatically selects the optimal parallel communication backend (such as NCCL) based on the runtime environment (such as NVIDIA GPU cluster) and performs computation graph optimization to maximize hardware utilization and training throughput.
The core of this stage lies in utilizing distributed parallel computing resources to achieve stable optimization of model parameters. The `SEQStrategy` (Pre-training) in the trainer module is specifically responsible for managing pre-training-specific data sampling, long sequence segmentation, and gradient accumulation logic. At the same time, the hardware adaptation module automatically selects the optimal parallel communication backend (such as NCCL) based on the runtime environment (such as NVIDIA GPU cluster) and performs computation graph optimization to maximize hardware utilization and training throughput.

Additionally, the system achieves zero-copy reading of massive data through the efficient memory-mapped loader (`MmapFileHandler`) in the data module, overcoming traditional IO bottlenecks.
@@ -205,9 +203,9 @@
- $\beta$: Temperature parameter (typically set to 0.1-0.5)
- Note: Implicitly learning reward function $r(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$

In this stage, the trainer module enables the `RLHFStrategy` (or similar `DPOStrategy` direct preference optimization strategy). This strategy manages a complex training loop containing the policy model (LLM to be optimized), reference model (usually an SFT model snapshot), and reward model. The system flow is as follows:
In this stage, the trainer module enables the `DPOStrategy` (Direct Preference Optimization) or `GRPOStrategy` (Group Relative Policy Optimization). This strategy manages a complex training loop containing the policy model (LLM to be optimized), reference model (usually an SFT model snapshot), and reward model. The system flow is as follows:

1. **Preference Data Collection and Reward Modeling**: First, by collecting human annotators' ranking preferences for multiple model-generated results for the same prompt, a separate reward model (RM) is trained. This model learns to output a scalar reward score for generated text to quantify the degree of alignment with human preferences.
2. **Policy Optimization**: Then, using the reward model as the optimization signal, the SFT model (as the policy) is fine-tuned through reinforcement learning algorithms. The goal of policy optimization is to maximize the expected cumulative reward obtained from the reward model, while a KL divergence penalty term constrains the output distributions of the policy model and reference model from deviating too far apart, preventing mode collapse and maintaining generation diversity. The training context manager maintains the states of the policy model, reference model, and reward model (or value function model) simultaneously at this stage, and coordinates complex multi-stage gradient computations.
1. **Preference Data Collection and Reward Modeling**: First, by collecting human annotators' ranking preferences for multiple model-generated results for the same prompt, a separate reward model (RM) can be trained. This model learns to output a scalar reward score for generated text to quantify the degree of alignment with human preferences.
2. **Policy Optimization**: Then, using the reward model as the optimization signal, the SFT model (as the policy) is fine-tuned through reinforcement learning algorithms (DPO/GRPO). The goal of policy optimization is to maximize the expected cumulative reward obtained from the reward model, while a KL divergence penalty term constrains the output distributions of the policy model and reference model from deviating too far apart, preventing mode collapse and maintaining generation diversity. The training context manager maintains the states of the policy model, reference model, and reward model (or value function model) simultaneously at this stage, and coordinates complex multi-stage gradient computations.

Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence. Through its unified `Trainer` interface and strategy-pattern design, the system makes each training stage highly reusable at the code level and clearly decoupled at the process level, providing an efficient, flexible, and scalable engineering foundation for large-scale language model research and iteration.
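A minimal sketch of the DPO loss matching the implicit-reward formula above (an assumed form with hypothetical helper names; not this repository's implementation):

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """DPO loss from summed response log-probs under policy and frozen reference.

    Each argument is a tensor of per-example log pi(y | x) values.
    """
    # Implicit rewards: beta * log(pi_theta / pi_ref) for each response.
    chosen_reward = beta * (policy_chosen - ref_chosen)
    rejected_reward = beta * (policy_rejected - ref_rejected)
    # Maximize the margin between chosen and rejected implicit rewards.
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()

loss = dpo_loss(torch.tensor([-10.0]), torch.tensor([-12.0]),
                torch.tensor([-11.0]), torch.tensor([-11.5]))
```

Note how the reference model enters only through the log-ratio, which is what replaces the explicit KL penalty of PPO-style RLHF.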
@@ -5,17 +5,17 @@ from astrai.config import (
    ModelConfig,
    TrainConfig,
)
from astrai.model.transformer import Transformer
from astrai.data import DatasetLoader, BpeTokenizer
from astrai.data import BpeTokenizer, DatasetLoader
from astrai.inference.generator import (
    GenerationRequest,
    LoopGenerator,
    StreamGenerator,
    BatchGenerator,
    EmbeddingEncoder,
    GenerationRequest,
    GeneratorFactory,
    LoopGenerator,
    StreamGenerator,
)
from astrai.trainer import Trainer, StrategyFactory, SchedulerFactory
from astrai.model.transformer import Transformer
from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer

__all__ = [
    "Transformer",

@@ -1,14 +1,7 @@
from astrai.config.model_config import ModelConfig
from astrai.config.param_config import BaseModelIO, ModelParameter
from astrai.config.schedule_config import (
    ScheduleConfig,
    CosineScheduleConfig,
    SGDRScheduleConfig,
    ScheduleConfigFactory,
)
from astrai.config.train_config import TrainConfig


__all__ = [
    # Base I/O
    "BaseModelIO",

@@ -16,9 +9,4 @@ __all__ = [
    # Model configuration
    "ModelConfig",
    "TrainConfig",
    # Schedule configuration
    "ScheduleConfig",
    "CosineScheduleConfig",
    "SGDRScheduleConfig",
    "ScheduleConfigFactory",
]

@@ -1,5 +1,4 @@
import json

from dataclasses import asdict, dataclass
from typing import Optional, Self

@@ -1,13 +1,13 @@
import torch.nn as nn
import safetensors.torch as st

from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Optional, Self, Union
from pathlib import Path
from typing import Self, Union

import safetensors.torch as st
import torch.nn as nn

from astrai.data.tokenizer import BpeTokenizer
from astrai.config.model_config import ModelConfig
from astrai.data.tokenizer import BpeTokenizer
from astrai.model.transformer import Transformer
@@ -1,149 +0,0 @@
from typing import Any, Dict, Type
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class ScheduleConfig(ABC):
    """Base configuration class for learning rate schedulers.

    Provides common validation and interface for all schedule types.
    """

    schedule_type: str = field(
        default="cosine",
        metadata={
            "help": "Type of learning rate schedule.",
            "choices": ["cosine", "sgdr"],
        },
    )
    warmup_steps: int = field(
        default=1000, metadata={"help": "Number of warmup steps."}
    )
    min_rate: float = field(
        default=0.05, metadata={"help": "Minimum learning rate multiplier."}
    )

    @abstractmethod
    def get_kwargs(self) -> Dict[str, Any]:
        """Get configuration kwargs for scheduler creation."""
        raise NotImplementedError

    def validate(self) -> None:
        """Validate configuration parameters."""
        if self.warmup_steps < 0:
            raise ValueError(
                f"warmup_steps must be non-negative, got {self.warmup_steps}"
            )
        if not 0 <= self.min_rate <= 1:
            raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")


@dataclass
class CosineScheduleConfig(ScheduleConfig):
    """Cosine annealing learning rate schedule configuration."""

    total_steps: int = field(
        default=None, metadata={"help": "Total training steps for cosine schedule."}
    )

    def __post_init__(self) -> None:
        self.schedule_type = "cosine"
        self.validate()

    def get_kwargs(self) -> Dict[str, Any]:
        if self.total_steps is None:
            raise ValueError("total_steps must be specified for cosine schedule")

        return {
            "schedule_type": self.schedule_type,
            "warmup_steps": self.warmup_steps,
            "lr_decay_steps": self.total_steps - self.warmup_steps,
            "min_rate": self.min_rate,
        }

    def validate(self) -> None:
        super().validate()
        if self.total_steps is not None and self.total_steps <= self.warmup_steps:
            raise ValueError(
                f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})"
            )


@dataclass
class SGDRScheduleConfig(ScheduleConfig):
    """Stochastic Gradient Descent with Warm Restarts schedule configuration."""

    cycle_length: int = field(
        default=1000, metadata={"help": "Length of the first cycle in steps."}
    )
    t_mult: int = field(
        default=2, metadata={"help": "Multiplier for cycle length growth."}
    )

    def __post_init__(self) -> None:
        self.schedule_type = "sgdr"
        self.validate()

    def get_kwargs(self) -> Dict[str, Any]:
        return {
            "schedule_type": self.schedule_type,
            "warmup_steps": self.warmup_steps,
            "cycle_length": self.cycle_length,
            "min_rate": self.min_rate,
            "t_mult": self.t_mult,
        }

    def validate(self) -> None:
        super().validate()
        if self.cycle_length <= 0:
            raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
        if self.t_mult < 1:
            raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")


class ScheduleConfigFactory:
    """Factory class for creating ScheduleConfig instances.

    Supports both direct instantiation and factory creation methods.

    Example usage:
        # Direct creation
        config = CosineScheduleConfig(total_steps=10000)

        # Factory method
        config = ScheduleConfigFactory.create("cosine", total_steps=10000)
    """

    CONFIG_MAP: Dict[str, Type[ScheduleConfig]] = {
        "cosine": CosineScheduleConfig,
        "sgdr": SGDRScheduleConfig,
    }

    @classmethod
    def create(cls, schedule_type: str, **kwargs) -> ScheduleConfig:
        """Create a schedule config instance.

        Args:
            schedule_type: Type of schedule ("cosine", "sgdr")
            **kwargs: Arguments passed to the config constructor

        Returns:
            ScheduleConfig instance

        Raises:
            ValueError: If schedule_type is not supported
        """
        if schedule_type not in cls.CONFIG_MAP:
            raise ValueError(
                f"Unknown schedule type: '{schedule_type}'. "
                f"Supported types: {sorted(cls.CONFIG_MAP.keys())}"
            )

        config_cls = cls.CONFIG_MAP[schedule_type]
        return config_cls(**kwargs)

    @classmethod
    def available_types(cls) -> list:
        """Return list of available schedule type names."""
        return list(cls.CONFIG_MAP.keys())
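For reference, the warmup-plus-cosine-decay multiplier implied by `CosineScheduleConfig`'s kwargs (`warmup_steps`, `lr_decay_steps`, `min_rate`) can be sketched as follows. This is an assumed form; the actual scheduler lives in `trainer/schedule.py` and may differ in detail.

```python
import math

def cosine_rate(step: int, warmup_steps: int, lr_decay_steps: int,
                min_rate: float) -> float:
    """Learning-rate multiplier: linear warmup, then cosine decay to min_rate."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # 0 -> 1 during warmup
    progress = min(1.0, (step - warmup_steps) / max(1, lr_decay_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # 1 -> 0
    return min_rate + (1.0 - min_rate) * cosine          # 1 -> min_rate

# With total_steps=1000 and warmup_steps=100, get_kwargs() above would yield
# lr_decay_steps=900; the multiplier rises to 1.0 at step 100 and decays to 0.05.
rates = [cosine_rate(s, 100, 900, 0.05) for s in (0, 100, 1000)]
```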
@@ -1,11 +1,11 @@
import torch.nn as nn
from torch.utils.data import Dataset
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from dataclasses import dataclass, field
from typing import Callable, List, Optional

import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset


@dataclass
class TrainConfig:

@@ -1,16 +1,15 @@
from astrai.data.dataset import (
    BaseDataset,
    SEQDataset,
    DatasetFactory,
    DatasetLoader,
    DPODataset,
    SFTDataset,
    GRPODataset,
    MultiSegmentFetcher,
    DatasetLoader,
    DatasetFactory,
    SEQDataset,
    SFTDataset,
)

from astrai.data.tokenizer import BpeTokenizer
from astrai.data.sampler import ResumableDistributedSampler
from astrai.data.tokenizer import BpeTokenizer

__all__ = [
    # Base classes
@@ -1,13 +1,14 @@
"""Dataset implementations with factory pattern for training."""

import torch
import bisect

from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union

import torch
from torch import Tensor
from torch.utils.data import Dataset

from astrai.data.serialization import load_h5
from typing import List, Dict, Optional, Union


class BaseSegmentFetcher:

@@ -1,8 +1,8 @@
from typing import Optional

import torch
import torch.distributed as dist

from torch.utils.data import Dataset, Sampler
from typing import Optional


class ResumableDistributedSampler(Sampler[int]):
@@ -1,13 +1,14 @@
import os
import h5py
import torch
import json
import safetensors.torch as st
import torch.distributed as dist

import os
from pathlib import Path
from torch import Tensor
from typing import Any, Dict, List

import h5py
import safetensors.torch as st
import torch
import torch.distributed as dist
from torch import Tensor

from astrai.parallel.setup import get_rank

@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from tokenizers import Tokenizer, decoders, processors, normalizers, pre_tokenizers
from typing import List, Union

from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
from typing import List, Union


class BaseTokenizer(ABC):
@@ -1,16 +1,15 @@
from astrai.inference.core import (
    GeneratorCore,
    EmbeddingEncoderCore,
    GeneratorCore,
    KVCacheManager,
)

from astrai.inference.generator import (
    GenerationRequest,
    LoopGenerator,
    StreamGenerator,
    BatchGenerator,
    EmbeddingEncoder,
    GenerationRequest,
    GeneratorFactory,
    LoopGenerator,
    StreamGenerator,
)

__all__ = [

@@ -1,8 +1,9 @@
import torch
from typing import Any, Callable, List, Optional, Self, Tuple, Union

import torch
from torch import Tensor
from typing import Any, Callable, List, Tuple, Union, Optional, Self
from astrai.config import ModelParameter, ModelConfig

from astrai.config import ModelConfig, ModelParameter


def apply_sampling_strategies(
@@ -1,10 +1,11 @@
import torch
from dataclasses import dataclass
from torch import Tensor
from typing import List, Tuple, Union, Optional, Generator
from astrai.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
from astrai.config.param_config import ModelParameter
from typing import Generator, List, Optional, Tuple, Union

import torch
from torch import Tensor

from astrai.config.param_config import ModelParameter
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager

HistoryType = List[Tuple[str, str]]

@@ -1,13 +1,15 @@
import torch
import uvicorn
import logging
from pathlib import Path
from typing import List, Optional, Dict, Any, Tuple
from typing import Any, Dict, List, Optional, Tuple

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GeneratorFactory, GenerationRequest
from astrai.inference.generator import GenerationRequest, GeneratorFactory

logger = logging.getLogger(__name__)
@@ -1,9 +1,9 @@
from astrai.model.module import (
    GQA,
    MLP,
    DecoderBlock,
    Linear,
    RMSNorm,
    MLP,
    GQA,
    DecoderBlock,
)
from astrai.model.transformer import Transformer

@@ -1,9 +1,9 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from typing import Optional, Tuple


def repeat_kv(x: Tensor, n_rep: int) -> Tensor:

@@ -1,12 +1,13 @@
from typing import Any, Mapping, Optional, Tuple

import torch
import torch.nn as nn

from torch import Tensor
from typing import Any, Mapping, Optional, Tuple

from astrai.config.model_config import ModelConfig
from astrai.model.module import (
    Embedding,
    DecoderBlock,
    Embedding,
    Linear,
    RMSNorm,
    RotaryEmbedding,
@@ -1,14 +1,13 @@
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
from astrai.parallel.setup import (
    get_world_size,
    get_rank,
    get_current_device,
    get_rank,
    get_world_size,
    only_on_rank,
    setup_parallel,
    spawn_parallel_fn,
)

from astrai.parallel.module import RowParallelLinear, ColumnParallelLinear

__all__ = [
    "get_world_size",
    "get_rank",

@@ -1,10 +1,10 @@
from typing import Dict

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from torch import Tensor
from typing import Dict


class ParallelModel(nn.Module):

@@ -1,12 +1,12 @@
import os
from contextlib import contextmanager
from functools import wraps
from typing import Callable, List, Optional

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from functools import wraps
from contextlib import contextmanager
from typing import Callable, List, Optional


def get_current_device():
    return os.environ["LOCAL_DEVICE"]
@@ -1,15 +1,14 @@
from astrai.trainer.trainer import Trainer
from astrai.trainer.strategy import StrategyFactory, BaseStrategy
from astrai.trainer.schedule import SchedulerFactory, BaseScheduler

from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
from astrai.trainer.train_callback import (
    TrainCallback,
    GradientClippingCallback,
    SchedulerCallback,
    CheckpointCallback,
    ProgressBarCallback,
    GradientClippingCallback,
    MetricLoggerCallback,
    ProgressBarCallback,
    SchedulerCallback,
    TrainCallback,
)
from astrai.trainer.trainer import Trainer

__all__ = [
    # Main trainer

@@ -1,6 +1,7 @@
import torch.nn as nn
from typing import Dict

import torch.nn as nn


def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
    """Compute gradient norm for each parameter in the model."""
@@ -1,10 +1,10 @@
"""Learning rate scheduler implementations with factory pattern."""

import math
from abc import abstractmethod, ABC
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type

from torch.optim.lr_scheduler import LRScheduler
from astrai.config.schedule_config import ScheduleConfig


class BaseScheduler(LRScheduler, ABC):

@@ -37,10 +37,6 @@ class SchedulerFactory:
        ...

        scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)

        # Or from config
        config = CosineScheduleConfig(total_steps=10000)
        scheduler = SchedulerFactory.load(optimizer, config)
    """

    SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {}

@@ -67,7 +63,7 @@
        return decorator

    @classmethod
    def create(cls, optimizer, schedule_type: str, **kwargs) -> BaseScheduler:
    def create(cls, optimizer, schedule_type: str = "none", **kwargs) -> BaseScheduler:
        """Create a scheduler instance by type name.

        Args:

@@ -90,29 +86,13 @@
        scheduler_cls = cls.SCHEDULER_MAP[schedule_type]
        return scheduler_cls(optimizer, **kwargs)

    @staticmethod
    def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
        """Create a scheduler from a ScheduleConfig object.

        Args:
            optimizer: PyTorch optimizer
            schedule_config: ScheduleConfig instance

        Returns:
            Scheduler instance
        """
        kwargs = schedule_config.get_kwargs()
        schedule_type = kwargs.pop("schedule_type")
        return SchedulerFactory.create(optimizer, schedule_type, **kwargs)

    @classmethod
    def available_types(cls) -> list:
        """Return list of registered scheduler type names."""
        return list(cls.SCHEDULER_MAP.keys())


# ============== Scheduler Classes ==============
# All scheduler classes are registered at class definition time using the decorator
# ----------- Scheduler implementations -----------


@SchedulerFactory.register("cosine")
@@ -1,14 +1,14 @@
"""Training strategy implementations with factory pattern."""

import copy
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

from torch import Tensor
from typing import Any, Callable, Dict, Union
from abc import ABC, abstractmethod
from torch.nn.parallel import DistributedDataParallel as DDP


def unwrap_model(model: nn.Module) -> nn.Module:

@@ -1,25 +1,25 @@
import os
import json
import os
import time
import torch.nn as nn

from pathlib import Path
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_
from typing import Callable, List, Optional, Protocol

import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm

from astrai.data.serialization import Checkpoint
from astrai.parallel import only_on_rank
from astrai.trainer.metric_util import (
    ctx_get_grad_max,
    ctx_get_grad_mean,
    ctx_get_grad_min,
    ctx_get_grad_nan_num,
    ctx_get_grad_norm,
    ctx_get_grad_std,
    ctx_get_loss,
    ctx_get_lr,
    ctx_get_grad_max,
    ctx_get_grad_min,
    ctx_get_grad_norm,
    ctx_get_grad_mean,
    ctx_get_grad_std,
    ctx_get_grad_nan_num,
)
from astrai.data.serialization import Checkpoint
from astrai.trainer.train_context import TrainContext
@@ -1,16 +1,16 @@
from dataclasses import dataclass, field
from typing import Optional, Self

import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader

from astrai.config.train_config import TrainConfig
from astrai.data import ResumableDistributedSampler
from astrai.data.serialization import Checkpoint
from astrai.trainer.strategy import StrategyFactory, BaseStrategy
from astrai.config.train_config import TrainConfig
from astrai.parallel.setup import get_current_device, get_world_size, get_rank

from dataclasses import dataclass, field
from typing import Optional, Self
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.trainer.strategy import BaseStrategy, StrategyFactory


@dataclass

@@ -1,17 +1,18 @@
import logging
from typing import Optional, List
from typing import List, Optional

from astrai.config import TrainConfig
from astrai.trainer.train_callback import (
    TrainCallback,
    ProgressBarCallback,
    CheckpointCallback,
    MetricLoggerCallback,
    GradientClippingCallback,
    SchedulerCallback,
)
from astrai.trainer.train_context import TrainContext, TrainContextBuilder
from astrai.data.serialization import Checkpoint
from astrai.parallel.setup import spawn_parallel_fn
from astrai.trainer.train_callback import (
    CheckpointCallback,
    GradientClippingCallback,
    MetricLoggerCallback,
    ProgressBarCallback,
    SchedulerCallback,
    TrainCallback,
)
from astrai.trainer.train_context import TrainContext, TrainContextBuilder

logger = logging.getLogger(__name__)

@@ -1,4 +1,5 @@
from pathlib import Path

from huggingface_hub import snapshot_download

PROJECT_ROOT = Path(__file__).resolve().parents[2]
@ -1,7 +1,9 @@
|
|||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from astrai.config.param_config import ModelParameter
|
||||
from astrai.inference.generator import GeneratorFactory, GenerationRequest
|
||||
from astrai.inference.generator import GenerationRequest, GeneratorFactory
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||
|
|
|
|||
|
|
@@ -1,7 +1,9 @@
-import torch
 from pathlib import Path
 
+import torch
+
 from astrai.config.param_config import ModelParameter
-from astrai.inference.generator import GeneratorFactory, GenerationRequest
+from astrai.inference.generator import GenerationRequest, GeneratorFactory
+
 PROJECT_ROOT = Path(__file__).resolve().parents[2]
 PARAMETER_ROOT = Path(PROJECT_ROOT, "params")

@@ -1,7 +1,9 @@
-import torch
 from pathlib import Path
 
+import torch
+
 from astrai.config.param_config import ModelParameter
-from astrai.inference.generator import GeneratorFactory, GenerationRequest
+from astrai.inference.generator import GenerationRequest, GeneratorFactory
+
 PROJECT_ROOT = Path(__file__).resolve().parents[2]
 PARAMETER_ROOT = Path(PROJECT_ROOT, "params")

@@ -1,6 +1,8 @@
-import torch
-from typing import Dict, Any
 from dataclasses import dataclass
+from typing import Any, Dict
+
+import torch
 
 from astrai.model.transformer import ModelConfig, Transformer
 

@@ -1,6 +1,7 @@
-import torch
-import json
 import argparse
+import json
+
+import torch
 
 from astrai.config.param_config import ModelParameter
 from astrai.inference.generator import BatchGenerator, GenerationRequest

@@ -1,11 +1,12 @@
 import argparse
 import json
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import argparse
 import tqdm
 
 from torch import Tensor
 
 from astrai.config.param_config import ModelParameter

@@ -1,5 +1,6 @@
 import argparse
 from pathlib import Path
+
 from astrai.inference.server import run_server
 
 

@@ -1,15 +1,16 @@
-import os
 import argparse
+import os
+from functools import partial
 
 import torch
 import torch.nn as nn
 import torch.optim as optim
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from functools import partial
+from astrai.config import ModelParameter, TrainConfig
 from astrai.data import DatasetLoader
-from astrai.config import ModelParameter, TrainConfig, CosineScheduleConfig
-from astrai.trainer import Trainer, SchedulerFactory
 from astrai.parallel import get_rank
+from astrai.trainer import SchedulerFactory, Trainer
 
 
 def parse_args() -> argparse.Namespace:
@@ -158,7 +159,7 @@ def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
 def create_scheduler(
     optimizer: optim.Optimizer, **kwargs
 ) -> optim.lr_scheduler.LRScheduler:
-    return SchedulerFactory.load(optimizer, **kwargs)
+    return SchedulerFactory.create(optimizer, **kwargs)
 
 
 def prepare_checkpoint(model: nn.Module) -> dict:
@@ -211,11 +212,6 @@ def train(
         stride=stride,
     )
 
-    schedule_config = CosineScheduleConfig(
-        warmup_steps=warmup_steps,
-        total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
-    )
-
     optimizer_fn = partial(
         create_optimizer,
         **{
@@ -224,7 +220,16 @@ def train(
             "weight_decay": adamw_weight_decay,
         },
     )
-    scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config})
+
+    total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
+    scheduler_fn = partial(
+        create_scheduler,
+        **{
+            "scheduler": "cosine",
+            "warmup_steps": warmup_steps,
+            "lr_decay_steps": total_steps - warmup_steps,
+        },
+    )
 
     train_config = TrainConfig(
         model=model,

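The training-script hunks above drop the `CosineScheduleConfig` object (and `SchedulerFactory.load`) in favor of plain keyword arguments passed to `SchedulerFactory.create(optimizer, "cosine", warmup_steps=..., lr_decay_steps=...)`. As a rough, self-contained sketch of what such a warmup-plus-cosine schedule computes — the function name and the exact `min_rate` semantics here are assumptions, not astrai's actual implementation:

```python
import math

def cosine_lr_rate(step, warmup_steps, lr_decay_steps, min_rate=0.0):
    """LR multiplier at a given step (assumed semantics, not astrai's code).

    Linear warmup from 0 to 1 over `warmup_steps`, then cosine decay
    from 1 down to `min_rate` over `lr_decay_steps`.
    """
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = min(1.0, (step - warmup_steps) / max(1, lr_decay_steps))
    return min_rate + (1.0 - min_rate) * 0.5 * (1.0 + math.cos(math.pi * progress))
```

Under these assumptions the multiplier ramps linearly to 1.0 during warmup and then anneals toward `min_rate`, which is consistent with the `lr_decay_steps = total_steps - warmup_steps` arithmetic in the hunk above.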
@@ -1,11 +1,12 @@
-import os
 import json
-import numpy as np
-import tempfile
+import os
 import shutil
-import torch
+import tempfile
+
+import numpy as np
 import pytest
+import safetensors.torch as st
+import torch
 from tokenizers import pre_tokenizers
 from torch.utils.data import Dataset
 

@@ -1,9 +1,10 @@
-import torch
 import tempfile
-import torch.distributed as dist
+
+import torch
+import torch.distributed as dist
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
 from astrai.data.serialization import Checkpoint
 from astrai.parallel.setup import get_rank, spawn_parallel_fn
 

@@ -1,8 +1,8 @@
-import torch
 import numpy as np
+import torch
 
-from astrai.data.serialization import save_h5
 from astrai.data.dataset import *
+from astrai.data.serialization import save_h5
 
 
 def test_dataset_loader_random_paths(base_test_env):

@@ -1,5 +1,5 @@
-from astrai.trainer import *
 from astrai.data import *
+from astrai.trainer import *
 
 
 def test_random_sampler_consistency(random_dataset):

@@ -1,8 +1,10 @@
 """Shared fixtures for inference tests."""
 
-import pytest
 from unittest.mock import MagicMock, patch
+
+import pytest
 from fastapi.testclient import TestClient
+
 from astrai.inference.server import app
 
 

@@ -1,9 +1,6 @@
 """Unit tests for the inference HTTP server."""
 
-import pytest
-from unittest.mock import MagicMock, patch
-from fastapi.testclient import TestClient
 from astrai.inference.server import app
 
 
 def test_health_no_model(client, monkeypatch):

@@ -1,10 +1,12 @@
 import os
 
 import torch
-from astrai.trainer import *
+
 from astrai.config import *
-from astrai.model import *
 from astrai.data import *
 from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore
+from astrai.model import *
+from astrai.trainer import *
 
 
 def test_model_parameter(test_env):

@@ -1,11 +1,13 @@
-import os
 import json
-import torch
-import pytest
+import os
 import tempfile
 
+import pytest
 import safetensors.torch as st
-from astrai.model.transformer import Transformer
+import torch
+
 from astrai.config.model_config import ModelConfig
+from astrai.model.transformer import Transformer
 
 
 @pytest.fixture

@@ -1,6 +1,9 @@
+import pytest
 import torch
 from torch.utils.data import Dataset
-import pytest
+
+from astrai.config import TrainConfig
+from astrai.trainer.schedule import SchedulerFactory
 
 
 class TrainerDataset(Dataset):
@@ -54,13 +57,11 @@ def create_train_config(
     Returns:
         TrainConfig instance configured for testing
     """
-    from astrai.config import TrainConfig
-    from astrai.config.schedule_config import CosineScheduleConfig
-    from astrai.trainer.schedule import SchedulerFactory
-
-    schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
     optimizer_fn = lambda m: torch.optim.AdamW(m.parameters(), lr=0.001)
-    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
+    scheduler_fn = lambda optim: SchedulerFactory.create(
+        optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
+    )
 
     return TrainConfig(
         strategy=strategy,

@@ -6,10 +6,10 @@ from astrai.trainer import *
 
 def test_callback_integration(base_test_env, random_dataset):
     """Test that all callbacks are properly integrated"""
-    schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
-
     optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
-    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
+    scheduler_fn = lambda optim: SchedulerFactory.create(
+        optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
+    )
 
     train_config = TrainConfig(
         model=base_test_env["model"],

@@ -1,18 +1,20 @@
 import os
-import torch
 
 import numpy as np
+import torch
 
 from astrai.config import *
-from astrai.trainer import *
 from astrai.data.serialization import Checkpoint
+from astrai.trainer import *
 
 
 def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
     """Simulate early stopping behavior"""
 
-    schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
-
     optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
-    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
+    scheduler_fn = lambda optim: SchedulerFactory.create(
+        optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
+    )
 
     train_config = TrainConfig(
         strategy="seq",

@@ -1,10 +1,9 @@
-import torch
 import numpy as np
 import pytest
+import torch
 
-from astrai.config import *
-from astrai.trainer.schedule import *
 from astrai.data.dataset import *
+from astrai.trainer.schedule import *
 
 
 def test_schedule_factory_random_configs():
@@ -16,41 +15,57 @@ def test_schedule_factory_random_configs():
 
-    # Test multiple random configurations
-    for _ in range(5):  # Test 5 random configurations
-        schedule_configs = [
-            CosineScheduleConfig(
-                warmup_steps=np.random.randint(50, 200),
-                total_steps=np.random.randint(1000, 5000),
-                min_rate=np.random.uniform(0.01, 0.1),
-            ),
-            SGDRScheduleConfig(
-                warmup_steps=np.random.randint(50, 200),
-                cycle_length=np.random.randint(500, 2000),
-                t_mult=np.random.randint(1, 3),
-                min_rate=np.random.uniform(0.01, 0.1),
-            ),
-        ]
-
-        for config in schedule_configs:
-            # Validate configuration
-            config.validate()
-
-            # Create scheduler using factory
-            scheduler = SchedulerFactory.load(optimizer, config)
-
-            # Verify scheduler type
-            if isinstance(config, CosineScheduleConfig):
-                assert isinstance(scheduler, CosineScheduler)
-                assert scheduler.warmup_steps == config.warmup_steps
-                assert (
-                    scheduler.lr_decay_steps == config.total_steps - config.warmup_steps
-                )
-                assert scheduler.min_rate == config.min_rate
-            elif isinstance(config, SGDRScheduleConfig):
-                assert isinstance(scheduler, SGDRScheduler)
-                assert scheduler.warmup_steps == config.warmup_steps
-                assert scheduler.cycle_length == config.cycle_length
-                assert scheduler.t_mult == config.t_mult
-                assert scheduler.min_rate == config.min_rate
+    # Test multiple random configurations
+    cosine_params = {
+        "schedule_type": "cosine",
+        "warmup_steps": np.random.randint(50, 200),
+        "total_steps": np.random.randint(1000, 5000),
+        "min_rate": np.random.uniform(0.01, 0.1),
+    }
+    sgdr_params = {
+        "schedule_type": "sgdr",
+        "warmup_steps": np.random.randint(50, 200),
+        "cycle_length": np.random.randint(500, 2000),
+        "t_mult": np.random.randint(1, 3),
+        "min_rate": np.random.uniform(0.01, 0.1),
+    }
+    for params in [cosine_params, sgdr_params]:
+        schedule_type = params["schedule_type"]
+        # Convert parameters for scheduler constructor
+        if schedule_type == "cosine":
+            warmup_steps = params["warmup_steps"]
+            total_steps = params["total_steps"]
+            min_rate = params["min_rate"]
+            lr_decay_steps = total_steps - warmup_steps
+            scheduler = SchedulerFactory.create(
+                optimizer,
+                schedule_type,
+                warmup_steps=warmup_steps,
+                lr_decay_steps=lr_decay_steps,
+                min_rate=min_rate,
+            )
+            assert isinstance(scheduler, CosineScheduler)
+            assert scheduler.warmup_steps == warmup_steps
+            assert scheduler.lr_decay_steps == lr_decay_steps
+            assert scheduler.min_rate == min_rate
+        elif schedule_type == "sgdr":
+            warmup_steps = params["warmup_steps"]
+            cycle_length = params["cycle_length"]
+            t_mult = params["t_mult"]
+            min_rate = params["min_rate"]
+            scheduler = SchedulerFactory.create(
+                optimizer,
+                schedule_type,
+                warmup_steps=warmup_steps,
+                cycle_length=cycle_length,
+                t_mult=t_mult,
+                min_rate=min_rate,
+            )
+            assert scheduler.warmup_steps == warmup_steps
+            assert scheduler.cycle_length == cycle_length
+            assert scheduler.t_mult == t_mult
+            assert scheduler.min_rate == min_rate
 
         # Test scheduler state dict functionality
         state_dict = scheduler.state_dict()
@@ -76,16 +91,25 @@ def test_schedule_factory_edge_cases():
     # Test edge cases for CosineScheduleConfig
     edge_cases = [
         # Minimal warmup and steps
-        CosineScheduleConfig(warmup_steps=1, total_steps=10, min_rate=0.01),
+        {"warmup_steps": 1, "total_steps": 10, "min_rate": 0.01},
         # Large values
-        CosineScheduleConfig(warmup_steps=1000, total_steps=10000, min_rate=0.5),
+        {"warmup_steps": 1000, "total_steps": 10000, "min_rate": 0.5},
         # Zero min_rate (edge case)
-        CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.0),
+        {"warmup_steps": 100, "total_steps": 1000, "min_rate": 0.0},
     ]
 
-    for config in edge_cases:
-        config.validate()
-        scheduler = SchedulerFactory.load(optimizer, config)
+    for params in edge_cases:
+        warmup_steps = params["warmup_steps"]
+        total_steps = params["total_steps"]
+        min_rate = params["min_rate"]
+        lr_decay_steps = total_steps - warmup_steps
+        scheduler = SchedulerFactory.create(
+            optimizer,
+            "cosine",
+            warmup_steps=warmup_steps,
+            lr_decay_steps=lr_decay_steps,
+            min_rate=min_rate,
+        )
         assert scheduler is not None
 
         # Test multiple steps
@@ -93,34 +117,24 @@ def test_schedule_factory_edge_cases():
             scheduler.step()
 
 
-def test_schedule_factory_invalid_configs():
-    """Test scheduler factory with invalid configurations"""
-
-    # Test invalid configurations that should raise errors
-    invalid_configs = [
-        # Negative warmup steps
-        {"warmup_steps": -10, "total_steps": 1000, "min_rate": 0.1},
-        # Total steps less than warmup steps
-        {"warmup_steps": 500, "total_steps": 400, "min_rate": 0.1},
-        # Invalid min_rate
-        {"warmup_steps": 100, "total_steps": 1000, "min_rate": -0.1},
-        {"warmup_steps": 100, "total_steps": 1000, "min_rate": 1.1},
-    ]
-
-    for kwargs in invalid_configs:
-        with pytest.raises(ValueError):
-            config = CosineScheduleConfig(**kwargs)
-            config.validate()
-
-
 def test_schedule_factory_state_persistence():
     """Test scheduler state persistence (save/load)"""
 
     model = torch.nn.Linear(10, 2)
     optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
 
-    config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
-    scheduler = SchedulerFactory.load(optimizer, config)
+    # Create scheduler directly with parameters
+    warmup_steps = 100
+    total_steps = 1000
+    min_rate = 0.1
+    lr_decay_steps = total_steps - warmup_steps
+    scheduler = SchedulerFactory.create(
+        optimizer,
+        "cosine",
+        warmup_steps=warmup_steps,
+        lr_decay_steps=lr_decay_steps,
+        min_rate=min_rate,
+    )
 
     # Take a few steps
     for _ in range(5):
@@ -129,8 +143,14 @@ def test_schedule_factory_state_persistence():
     # Save state
     state_dict = scheduler.state_dict()
 
-    # Create new scheduler and load state
-    new_scheduler = SchedulerFactory.load(optimizer, config)
+    # Create new scheduler with same parameters
+    new_scheduler = SchedulerFactory.create(
+        optimizer,
+        "cosine",
+        warmup_steps=warmup_steps,
+        lr_decay_steps=lr_decay_steps,
+        min_rate=min_rate,
+    )
     new_scheduler.load_state_dict(state_dict)
 
     # Verify states match

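The schedule tests above also exercise an `sgdr` schedule: warmup followed by cosine annealing with warm restarts, where each cycle is `t_mult` times the length of the previous one. A hypothetical sketch of that rate computation — the real `SGDRScheduler` in astrai may differ in details such as warmup shape and restart boundaries:

```python
import math

def sgdr_rate(step, warmup_steps, cycle_length, t_mult=2, min_rate=0.0):
    """LR multiplier for warmup + cosine warm restarts (assumed semantics).

    After warmup, the rate follows a cosine curve from 1 down to
    `min_rate` within each cycle, restarting at 1; every new cycle is
    `t_mult` times longer than the one before it.
    """
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    t = step - warmup_steps
    length = cycle_length
    # Walk forward through cycles until `t` falls inside the current one
    while t >= length:
        t -= length
        length *= t_mult
    progress = t / length
    return min_rate + (1.0 - min_rate) * 0.5 * (1.0 + math.cos(math.pi * progress))
```

With `t_mult=2` and `cycle_length=500`, the rate jumps back to 1.0 at step `warmup_steps + 500`, then again 1000 steps later, matching the warm-restart pattern the test's parameter ranges imply.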
@@ -1,5 +1,3 @@
-import torch
-
 from astrai.data.dataset import *
 from astrai.trainer import Trainer
 
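The state-persistence test above saves a scheduler's `state_dict()` and restores it into a freshly constructed scheduler built with the same keyword arguments. A toy illustration of that round-trip pattern — this class is an invented stand-in, not astrai's `CosineScheduler`:

```python
import math

class ToyCosineScheduler:
    """Invented stand-in showing the state_dict save/load round-trip."""

    def __init__(self, warmup_steps, lr_decay_steps, min_rate=0.0):
        self.warmup_steps = warmup_steps
        self.lr_decay_steps = lr_decay_steps
        self.min_rate = min_rate
        self.last_step = 0  # the only mutable state worth persisting

    def rate(self):
        # Recompute the multiplier from the step counter (warmup, then cosine decay)
        step = self.last_step
        if step < self.warmup_steps:
            return step / max(1, self.warmup_steps)
        progress = min(1.0, (step - self.warmup_steps) / max(1, self.lr_decay_steps))
        return self.min_rate + (1.0 - self.min_rate) * 0.5 * (1.0 + math.cos(math.pi * progress))

    def step(self):
        self.last_step += 1

    def state_dict(self):
        return {"last_step": self.last_step}

    def load_state_dict(self, state):
        self.last_step = state["last_step"]
```

Because the hyperparameters are reconstructed from the same arguments and only the step counter lives in the state dict, a restored scheduler produces the same rate as the original, which is exactly what the "Verify states match" assertions check.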