From 0852b852f828f2387667f692d4f55c344ddbd7ae Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 3 Apr 2026 22:06:32 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E4=BC=A0=E9=80=92=EF=BC=8C=E6=B8=85=E7=90=86=E5=AF=BC?= =?UTF-8?q?=E5=85=A5=E6=A0=B7=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- assets/docs/design.md | 22 ++-- astrai/__init__.py | 12 +-- astrai/config/__init__.py | 12 --- astrai/config/model_config.py | 1 - astrai/config/param_config.py | 10 +- astrai/config/schedule_config.py | 149 -------------------------- astrai/config/train_config.py | 10 +- astrai/data/__init__.py | 11 +- astrai/data/dataset.py | 7 +- astrai/data/sampler.py | 4 +- astrai/data/serialization.py | 15 +-- astrai/data/tokenizer.py | 5 +- astrai/inference/__init__.py | 9 +- astrai/inference/core.py | 7 +- astrai/inference/generator.py | 11 +- astrai/inference/server.py | 10 +- astrai/model/__init__.py | 6 +- astrai/model/module.py | 4 +- astrai/model/transformer.py | 7 +- astrai/parallel/__init__.py | 7 +- astrai/parallel/module.py | 6 +- astrai/parallel/setup.py | 8 +- astrai/trainer/__init__.py | 15 ++- astrai/trainer/metric_util.py | 3 +- astrai/trainer/schedule.py | 28 +---- astrai/trainer/strategy.py | 8 +- astrai/trainer/train_callback.py | 24 ++--- astrai/trainer/train_context.py | 12 +-- astrai/trainer/trainer.py | 21 ++-- scripts/demo/download.py | 1 + scripts/demo/generate_ar.py | 6 +- scripts/demo/generate_batch.py | 6 +- scripts/demo/stream_chat.py | 6 +- scripts/tools/benchmark.py | 6 +- scripts/tools/generate.py | 5 +- scripts/tools/perplexity.py | 5 +- scripts/tools/server.py | 1 + scripts/tools/train.py | 27 +++-- tests/conftest.py | 9 +- tests/data/test_checkpoint.py | 5 +- tests/data/test_dataset.py | 4 +- tests/data/test_sampler.py | 2 +- tests/inference/conftest.py | 4 +- tests/inference/test_server.py | 3 - tests/module/test_module.py | 6 +- tests/module/test_tie_weight.py | 10 +- tests/trainer/conftest.py | 13 +-- tests/trainer/test_callbacks.py | 6 +- tests/trainer/test_early_stopping.py | 12 ++- tests/trainer/test_train_strategy.py | 152 +++++++++++++++------------ tests/trainer/test_trainer.py | 2 - 51 files changed, 300 insertions(+), 435 deletions(-) delete mode 100644 astrai/config/schedule_config.py diff --git a/assets/docs/design.md b/assets/docs/design.md index 170731e..b9fa3d3 100644 --- a/assets/docs/design.md +++ b/assets/docs/design.md @@ -24,7 +24,6 @@ flowchart TB C1[model_config.py
Model Architecture] C2[train_config.py
Training Params] C3[param_config.py
Hyperparameters] - C4[schedule_config.py
Scheduler Config] end subgraph Data["Data Module (data/)"] @@ -95,14 +94,13 @@ flowchart TB ### 1. Configuration Module (config/) - **model_config.py**: Defines model structure parameters (layers, heads, dimensions, etc.), managed through `ModelConfig`. -- **train_config.py**: Sets training parameters (batch size, training stages PT/SFT/DPO, optimizers, etc.), loaded by `TrainConfig`. +- **train_config.py**: Sets training parameters (batch size, training stages SEQ/SFT/GRPO/DPO, optimizers, etc.), loaded by `TrainConfig`. - **param_config.py**: Manages hyperparameters for training and inference. -- **schedule_config.py**: Controls learning rate strategies (cosine annealing) and training progress. ### 2. Data Module (data/) - **dataset.py**: Dataset handling and loading. - **sampler.py**: Data sampling for different training stages. -- **serialization.py**: Data serialization and deserialization. +- **serialization.py**: Data serialization and deserialization, checkpoint management. - **tokenizer.py**: Text tokenization and encoding. ### 3. Model Module (model/) @@ -112,15 +110,15 @@ flowchart TB ### 4. Trainer Module (trainer/) - **trainer.py**: Main training entry point. - **train_context.py**: Training context management (model, optimizer, scheduler, metrics). -- **strategy.py**: Training strategies for PT/SFT/DPO stages. -- **schedule.py**: Learning rate scheduler. +- **strategy.py**: Training strategies for SEQ/SFT/GRPO/DPO stages via `StrategyFactory`. +- **schedule.py**: Learning rate scheduler implementation (cosine, SGDR, etc.). - **train_callback.py**: Training callbacks (checkpoint, early stopping, etc.). - **metric_util.py**: Metrics calculation and tracking. ### 5. Inference Module (inference/) - **generator.py**: Text generation with various methods (sync, batch, streaming). - **core.py**: Inference core with KV cache optimization. -- **server.py**: API service for inference. +- **server.py**: API service for inference (FastAPI + Uvicorn). ### 6. Parallel Module (parallel/) - **setup.py**: Distributed initialization for multi-GPU/multi-machine training. @@ -134,7 +132,7 @@ flowchart TB The common training process for large language models (LLM) typically includes three stages: **Pre-training (PT)**, **Supervised Fine-Tuning (SFT)**, and **Reinforcement Learning from Human Feedback (RLHF)**. This system is designed to support seamless end-to-end flow, achieving efficient switching and state management of different training stages through modular strategies, ensuring the model's capabilities gradually evolve from general language understanding to human-preference-aligned dialogue and instruction execution. -### **2.1 Pre-training Stage** +### **2.1 Pre-training Stage (SEQ/PT)** The pre-training stage aims to build the model's foundational language capabilities and general knowledge representation. This stage performs self-supervised learning on large-scale, unlabeled corpora (typically covering hundreds of GB to TB of text data). The model architecture is based on the standard Transformer Decoder, trained through masked language modeling objectives (such as causal language modeling), enabling the model to learn vocabulary, grammar, semantics, and world knowledge embedded in text. 
@@ -152,7 +150,7 @@ $$ - $\theta$: Model parameters - $P(x_t \mid x_{ Dict[str, Any]: - """Get configuration kwargs for scheduler creation.""" - raise NotImplementedError - - def validate(self) -> None: - """Validate configuration parameters.""" - if self.warmup_steps < 0: - raise ValueError( - f"warmup_steps must be non-negative, got {self.warmup_steps}" - ) - if not 0 <= self.min_rate <= 1: - raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}") - - -@dataclass -class CosineScheduleConfig(ScheduleConfig): - """Cosine annealing learning rate schedule configuration.""" - - total_steps: int = field( - default=None, metadata={"help": "Total training steps for cosine schedule."} - ) - - def __post_init__(self) -> None: - self.schedule_type = "cosine" - self.validate() - - def get_kwargs(self) -> Dict[str, Any]: - if self.total_steps is None: - raise ValueError("total_steps must be specified for cosine schedule") - - return { - "schedule_type": self.schedule_type, - "warmup_steps": self.warmup_steps, - "lr_decay_steps": self.total_steps - self.warmup_steps, - "min_rate": self.min_rate, - } - - def validate(self) -> None: - super().validate() - if self.total_steps is not None and self.total_steps <= self.warmup_steps: - raise ValueError( - f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})" - ) - - -@dataclass -class SGDRScheduleConfig(ScheduleConfig): - """Stochastic Gradient Descent with Warm Restarts schedule configuration.""" - - cycle_length: int = field( - default=1000, metadata={"help": "Length of the first cycle in steps."} - ) - t_mult: int = field( - default=2, metadata={"help": "Multiplier for cycle length growth."} - ) - - def __post_init__(self) -> None: - self.schedule_type = "sgdr" - self.validate() - - def get_kwargs(self) -> Dict[str, Any]: - return { - "schedule_type": self.schedule_type, - "warmup_steps": self.warmup_steps, - "cycle_length": self.cycle_length, - "min_rate": self.min_rate, - "t_mult": self.t_mult, - } - - def validate(self) -> None: - super().validate() - if self.cycle_length <= 0: - raise ValueError(f"cycle_length must be positive, got {self.cycle_length}") - if self.t_mult < 1: - raise ValueError(f"t_mult must be >= 1, got {self.t_mult}") - - -class ScheduleConfigFactory: - """Factory class for creating ScheduleConfig instances. - - Supports both direct instantiation and factory creation methods. - - Example usage: - # Direct creation - config = CosineScheduleConfig(total_steps=10000) - - # Factory method - config = ScheduleConfigFactory.create("cosine", total_steps=10000) - """ - - CONFIG_MAP: Dict[str, Type[ScheduleConfig]] = { - "cosine": CosineScheduleConfig, - "sgdr": SGDRScheduleConfig, - } - - @classmethod - def create(cls, schedule_type: str, **kwargs) -> ScheduleConfig: - """Create a schedule config instance. - - Args: - schedule_type: Type of schedule ("cosine", "sgdr") - **kwargs: Arguments passed to the config constructor - - Returns: - ScheduleConfig instance - - Raises: - ValueError: If schedule_type is not supported - """ - if schedule_type not in cls.CONFIG_MAP: - raise ValueError( - f"Unknown schedule type: '{schedule_type}'. 
" - f"Supported types: {sorted(cls.CONFIG_MAP.keys())}" - ) - - config_cls = cls.CONFIG_MAP[schedule_type] - return config_cls(**kwargs) - - @classmethod - def available_types(cls) -> list: - """Return list of available schedule type names.""" - return list(cls.CONFIG_MAP.keys()) diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py index f4b3eb6..8cb9dc3 100644 --- a/astrai/config/train_config.py +++ b/astrai/config/train_config.py @@ -1,11 +1,11 @@ -import torch.nn as nn -from torch.utils.data import Dataset -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LRScheduler - from dataclasses import dataclass, field from typing import Callable, List, Optional +import torch.nn as nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler +from torch.utils.data import Dataset + @dataclass class TrainConfig: diff --git a/astrai/data/__init__.py b/astrai/data/__init__.py index 6a526cb..02c33a5 100644 --- a/astrai/data/__init__.py +++ b/astrai/data/__init__.py @@ -1,16 +1,15 @@ from astrai.data.dataset import ( BaseDataset, - SEQDataset, + DatasetFactory, + DatasetLoader, DPODataset, - SFTDataset, GRPODataset, MultiSegmentFetcher, - DatasetLoader, - DatasetFactory, + SEQDataset, + SFTDataset, ) - -from astrai.data.tokenizer import BpeTokenizer from astrai.data.sampler import ResumableDistributedSampler +from astrai.data.tokenizer import BpeTokenizer __all__ = [ # Base classes diff --git a/astrai/data/dataset.py b/astrai/data/dataset.py index 60e70d9..258c19f 100644 --- a/astrai/data/dataset.py +++ b/astrai/data/dataset.py @@ -1,13 +1,14 @@ """Dataset implementations with factory pattern for training.""" -import torch import bisect - from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Union + +import torch from torch import Tensor from torch.utils.data import Dataset + from astrai.data.serialization import load_h5 -from typing import List, Dict, Optional, Union class BaseSegmentFetcher: diff --git a/astrai/data/sampler.py b/astrai/data/sampler.py index 82162c5..cd12512 100644 --- a/astrai/data/sampler.py +++ b/astrai/data/sampler.py @@ -1,8 +1,8 @@ +from typing import Optional + import torch import torch.distributed as dist - from torch.utils.data import Dataset, Sampler -from typing import Optional class ResumableDistributedSampler(Sampler[int]): diff --git a/astrai/data/serialization.py b/astrai/data/serialization.py index 4f52687..d5a99c6 100644 --- a/astrai/data/serialization.py +++ b/astrai/data/serialization.py @@ -1,13 +1,14 @@ -import os -import h5py -import torch import json -import safetensors.torch as st -import torch.distributed as dist - +import os from pathlib import Path -from torch import Tensor from typing import Any, Dict, List + +import h5py +import safetensors.torch as st +import torch +import torch.distributed as dist +from torch import Tensor + from astrai.parallel.setup import get_rank diff --git a/astrai/data/tokenizer.py b/astrai/data/tokenizer.py index 1b1ef61..96cab47 100644 --- a/astrai/data/tokenizer.py +++ b/astrai/data/tokenizer.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod -from tokenizers import Tokenizer, decoders, processors, normalizers, pre_tokenizers +from typing import List, Union + +from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors from tokenizers.models import BPE from tokenizers.trainers import BpeTrainer as BpeTrainerImpl -from typing import List, Union class BaseTokenizer(ABC): diff --git 
a/astrai/inference/__init__.py b/astrai/inference/__init__.py index a63d77e..6675be6 100644 --- a/astrai/inference/__init__.py +++ b/astrai/inference/__init__.py @@ -1,16 +1,15 @@ from astrai.inference.core import ( - GeneratorCore, EmbeddingEncoderCore, + GeneratorCore, KVCacheManager, ) - from astrai.inference.generator import ( - GenerationRequest, - LoopGenerator, - StreamGenerator, BatchGenerator, EmbeddingEncoder, + GenerationRequest, GeneratorFactory, + LoopGenerator, + StreamGenerator, ) __all__ = [ diff --git a/astrai/inference/core.py b/astrai/inference/core.py index 185b935..b0952f7 100644 --- a/astrai/inference/core.py +++ b/astrai/inference/core.py @@ -1,8 +1,9 @@ -import torch +from typing import Any, Callable, List, Optional, Self, Tuple, Union +import torch from torch import Tensor -from typing import Any, Callable, List, Tuple, Union, Optional, Self -from astrai.config import ModelParameter, ModelConfig + +from astrai.config import ModelConfig, ModelParameter def apply_sampling_strategies( diff --git a/astrai/inference/generator.py b/astrai/inference/generator.py index 2974c59..1c8b043 100644 --- a/astrai/inference/generator.py +++ b/astrai/inference/generator.py @@ -1,10 +1,11 @@ -import torch from dataclasses import dataclass -from torch import Tensor -from typing import List, Tuple, Union, Optional, Generator -from astrai.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager -from astrai.config.param_config import ModelParameter +from typing import Generator, List, Optional, Tuple, Union +import torch +from torch import Tensor + +from astrai.config.param_config import ModelParameter +from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager HistoryType = List[Tuple[str, str]] diff --git a/astrai/inference/server.py b/astrai/inference/server.py index b846f8f..e273257 100644 --- a/astrai/inference/server.py +++ b/astrai/inference/server.py @@ -1,13 +1,15 @@ -import torch -import uvicorn import logging from pathlib import Path -from typing import List, Optional, Dict, Any, Tuple +from typing import Any, Dict, List, Optional, Tuple + +import torch +import uvicorn from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse from pydantic import BaseModel, Field + from astrai.config.param_config import ModelParameter -from astrai.inference.generator import GeneratorFactory, GenerationRequest +from astrai.inference.generator import GenerationRequest, GeneratorFactory logger = logging.getLogger(__name__) diff --git a/astrai/model/__init__.py b/astrai/model/__init__.py index dccfdc3..516ccd3 100644 --- a/astrai/model/__init__.py +++ b/astrai/model/__init__.py @@ -1,9 +1,9 @@ from astrai.model.module import ( + GQA, + MLP, + DecoderBlock, Linear, RMSNorm, - MLP, - GQA, - DecoderBlock, ) from astrai.model.transformer import Transformer diff --git a/astrai/model/module.py b/astrai/model/module.py index e29eb5f..0d22166 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -1,9 +1,9 @@ +from typing import Optional, Tuple + import torch import torch.nn as nn import torch.nn.functional as F - from torch import Tensor -from typing import Optional, Tuple def repeat_kv(x: Tensor, n_rep: int) -> Tensor: diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 0608d23..79746bb 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -1,12 +1,13 @@ +from typing import Any, Mapping, Optional, Tuple + import torch import torch.nn as nn - from torch import Tensor 
-from typing import Any, Mapping, Optional, Tuple + from astrai.config.model_config import ModelConfig from astrai.model.module import ( - Embedding, DecoderBlock, + Embedding, Linear, RMSNorm, RotaryEmbedding, diff --git a/astrai/parallel/__init__.py b/astrai/parallel/__init__.py index 562eff4..03f13bf 100644 --- a/astrai/parallel/__init__.py +++ b/astrai/parallel/__init__.py @@ -1,14 +1,13 @@ +from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear from astrai.parallel.setup import ( - get_world_size, - get_rank, get_current_device, + get_rank, + get_world_size, only_on_rank, setup_parallel, spawn_parallel_fn, ) -from astrai.parallel.module import RowParallelLinear, ColumnParallelLinear - __all__ = [ "get_world_size", "get_rank", diff --git a/astrai/parallel/module.py b/astrai/parallel/module.py index 3dd94e9..8e12493 100644 --- a/astrai/parallel/module.py +++ b/astrai/parallel/module.py @@ -1,10 +1,10 @@ +from typing import Dict + import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F -import torch.distributed as dist - from torch import Tensor -from typing import Dict class ParallelModel(nn.Module): diff --git a/astrai/parallel/setup.py b/astrai/parallel/setup.py index 111fde4..9a17f9c 100644 --- a/astrai/parallel/setup.py +++ b/astrai/parallel/setup.py @@ -1,12 +1,12 @@ import os +from contextlib import contextmanager +from functools import wraps +from typing import Callable, List, Optional + import torch import torch.distributed as dist import torch.multiprocessing as mp -from functools import wraps -from contextlib import contextmanager -from typing import Callable, List, Optional - def get_current_device(): return os.environ["LOCAL_DEVICE"] diff --git a/astrai/trainer/__init__.py b/astrai/trainer/__init__.py index 6165d57..05b7b0c 100644 --- a/astrai/trainer/__init__.py +++ b/astrai/trainer/__init__.py @@ -1,15 +1,14 @@ -from astrai.trainer.trainer import Trainer -from astrai.trainer.strategy import StrategyFactory, BaseStrategy -from astrai.trainer.schedule import SchedulerFactory, BaseScheduler - +from astrai.trainer.schedule import BaseScheduler, SchedulerFactory +from astrai.trainer.strategy import BaseStrategy, StrategyFactory from astrai.trainer.train_callback import ( - TrainCallback, - GradientClippingCallback, - SchedulerCallback, CheckpointCallback, - ProgressBarCallback, + GradientClippingCallback, MetricLoggerCallback, + ProgressBarCallback, + SchedulerCallback, + TrainCallback, ) +from astrai.trainer.trainer import Trainer __all__ = [ # Main trainer diff --git a/astrai/trainer/metric_util.py b/astrai/trainer/metric_util.py index 920b2a6..bea32b6 100644 --- a/astrai/trainer/metric_util.py +++ b/astrai/trainer/metric_util.py @@ -1,6 +1,7 @@ -import torch.nn as nn from typing import Dict +import torch.nn as nn + def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]: """Compute gradient norm for each parameter in the model.""" diff --git a/astrai/trainer/schedule.py b/astrai/trainer/schedule.py index f0eacbd..339a702 100644 --- a/astrai/trainer/schedule.py +++ b/astrai/trainer/schedule.py @@ -1,10 +1,10 @@ """Learning rate scheduler implementations with factory pattern.""" import math -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod from typing import Any, Dict, List, Type + from torch.optim.lr_scheduler import LRScheduler -from astrai.config.schedule_config import ScheduleConfig class BaseScheduler(LRScheduler, ABC): @@ -37,10 +37,6 @@ class SchedulerFactory: ... 
scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs) - - # Or from config - config = CosineScheduleConfig(total_steps=10000) - scheduler = SchedulerFactory.load(optimizer, config) """ SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {} @@ -67,7 +63,7 @@ class SchedulerFactory: return decorator @classmethod - def create(cls, optimizer, schedule_type: str, **kwargs) -> BaseScheduler: + def create(cls, optimizer, schedule_type: str = "none", **kwargs) -> BaseScheduler: """Create a scheduler instance by type name. Args: @@ -90,29 +86,13 @@ class SchedulerFactory: scheduler_cls = cls.SCHEDULER_MAP[schedule_type] return scheduler_cls(optimizer, **kwargs) - @staticmethod - def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler: - """Create a scheduler from a ScheduleConfig object. - - Args: - optimizer: PyTorch optimizer - schedule_config: ScheduleConfig instance - - Returns: - Scheduler instance - """ - kwargs = schedule_config.get_kwargs() - schedule_type = kwargs.pop("schedule_type") - return SchedulerFactory.create(optimizer, schedule_type, **kwargs) - @classmethod def available_types(cls) -> list: """Return list of registered scheduler type names.""" return list(cls.SCHEDULER_MAP.keys()) -# ============== Scheduler Classes ============== -# All scheduler classes are registered at class definition time using the decorator +# ----------- Scheduler implementations ----------- @SchedulerFactory.register("cosine") diff --git a/astrai/trainer/strategy.py b/astrai/trainer/strategy.py index 4d973cf..445cd86 100644 --- a/astrai/trainer/strategy.py +++ b/astrai/trainer/strategy.py @@ -1,14 +1,14 @@ """Training strategy implementations with factory pattern.""" import copy +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Union + import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn.parallel import DistributedDataParallel as DDP - from torch import Tensor -from typing import Any, Callable, Dict, Union -from abc import ABC, abstractmethod +from torch.nn.parallel import DistributedDataParallel as DDP def unwrap_model(model: nn.Module) -> nn.Module: diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 997cd48..cfa6283 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -1,25 +1,25 @@ -import os import json +import os import time -import torch.nn as nn - from pathlib import Path -from tqdm import tqdm -from torch.nn.utils import clip_grad_norm_ from typing import Callable, List, Optional, Protocol +import torch.nn as nn +from torch.nn.utils import clip_grad_norm_ +from tqdm import tqdm + +from astrai.data.serialization import Checkpoint from astrai.parallel import only_on_rank from astrai.trainer.metric_util import ( + ctx_get_grad_max, + ctx_get_grad_mean, + ctx_get_grad_min, + ctx_get_grad_nan_num, + ctx_get_grad_norm, + ctx_get_grad_std, ctx_get_loss, ctx_get_lr, - ctx_get_grad_max, - ctx_get_grad_min, - ctx_get_grad_norm, - ctx_get_grad_mean, - ctx_get_grad_std, - ctx_get_grad_nan_num, ) -from astrai.data.serialization import Checkpoint from astrai.trainer.train_context import TrainContext diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 7f65b02..864cf52 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -1,16 +1,16 @@ +from dataclasses import dataclass, field +from typing import Optional, Self + import torch.nn as nn from torch.optim import Optimizer from torch.optim.lr_scheduler 
import LRScheduler from torch.utils.data import DataLoader +from astrai.config.train_config import TrainConfig from astrai.data import ResumableDistributedSampler from astrai.data.serialization import Checkpoint -from astrai.trainer.strategy import StrategyFactory, BaseStrategy -from astrai.config.train_config import TrainConfig -from astrai.parallel.setup import get_current_device, get_world_size, get_rank - -from dataclasses import dataclass, field -from typing import Optional, Self +from astrai.parallel.setup import get_current_device, get_rank, get_world_size +from astrai.trainer.strategy import BaseStrategy, StrategyFactory @dataclass diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index 74c35a6..f2a2389 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -1,17 +1,18 @@ import logging -from typing import Optional, List +from typing import List, Optional + from astrai.config import TrainConfig -from astrai.trainer.train_callback import ( - TrainCallback, - ProgressBarCallback, - CheckpointCallback, - MetricLoggerCallback, - GradientClippingCallback, - SchedulerCallback, -) -from astrai.trainer.train_context import TrainContext, TrainContextBuilder from astrai.data.serialization import Checkpoint from astrai.parallel.setup import spawn_parallel_fn +from astrai.trainer.train_callback import ( + CheckpointCallback, + GradientClippingCallback, + MetricLoggerCallback, + ProgressBarCallback, + SchedulerCallback, + TrainCallback, +) +from astrai.trainer.train_context import TrainContext, TrainContextBuilder logger = logging.getLogger(__name__) diff --git a/scripts/demo/download.py b/scripts/demo/download.py index ba7aaf3..790c5a1 100644 --- a/scripts/demo/download.py +++ b/scripts/demo/download.py @@ -1,4 +1,5 @@ from pathlib import Path + from huggingface_hub import snapshot_download PROJECT_ROOT = Path(__file__).resolve().parents[2] diff --git a/scripts/demo/generate_ar.py b/scripts/demo/generate_ar.py index d1dfbaf..441a695 100644 --- a/scripts/demo/generate_ar.py +++ b/scripts/demo/generate_ar.py @@ -1,7 +1,9 @@ -import torch from pathlib import Path + +import torch + from astrai.config.param_config import ModelParameter -from astrai.inference.generator import GeneratorFactory, GenerationRequest +from astrai.inference.generator import GenerationRequest, GeneratorFactory PROJECT_ROOT = Path(__file__).resolve().parents[2] PARAMETER_ROOT = Path(PROJECT_ROOT, "params") diff --git a/scripts/demo/generate_batch.py b/scripts/demo/generate_batch.py index d813341..7b87a1f 100644 --- a/scripts/demo/generate_batch.py +++ b/scripts/demo/generate_batch.py @@ -1,7 +1,9 @@ -import torch from pathlib import Path + +import torch + from astrai.config.param_config import ModelParameter -from astrai.inference.generator import GeneratorFactory, GenerationRequest +from astrai.inference.generator import GenerationRequest, GeneratorFactory PROJECT_ROOT = Path(__file__).resolve().parents[2] PARAMETER_ROOT = Path(PROJECT_ROOT, "params") diff --git a/scripts/demo/stream_chat.py b/scripts/demo/stream_chat.py index 937823f..9a84708 100644 --- a/scripts/demo/stream_chat.py +++ b/scripts/demo/stream_chat.py @@ -1,7 +1,9 @@ -import torch from pathlib import Path + +import torch + from astrai.config.param_config import ModelParameter -from astrai.inference.generator import GeneratorFactory, GenerationRequest +from astrai.inference.generator import GenerationRequest, GeneratorFactory PROJECT_ROOT = Path(__file__).resolve().parents[2] PARAMETER_ROOT = Path(PROJECT_ROOT, "params") 
diff --git a/scripts/tools/benchmark.py b/scripts/tools/benchmark.py index 211cef5..3fa3076 100644 --- a/scripts/tools/benchmark.py +++ b/scripts/tools/benchmark.py @@ -1,6 +1,8 @@ -import torch -from typing import Dict, Any from dataclasses import dataclass +from typing import Any, Dict + +import torch + from astrai.model.transformer import ModelConfig, Transformer diff --git a/scripts/tools/generate.py b/scripts/tools/generate.py index 3d10489..26f3db3 100644 --- a/scripts/tools/generate.py +++ b/scripts/tools/generate.py @@ -1,6 +1,7 @@ -import torch -import json import argparse +import json + +import torch from astrai.config.param_config import ModelParameter from astrai.inference.generator import BatchGenerator, GenerationRequest diff --git a/scripts/tools/perplexity.py b/scripts/tools/perplexity.py index 3fe02f4..a67a231 100644 --- a/scripts/tools/perplexity.py +++ b/scripts/tools/perplexity.py @@ -1,11 +1,12 @@ +import argparse import json + import torch import torch.nn as nn import torch.nn.functional as F -import argparse import tqdm - from torch import Tensor + from astrai.config.param_config import ModelParameter diff --git a/scripts/tools/server.py b/scripts/tools/server.py index 31f3235..fc8151d 100644 --- a/scripts/tools/server.py +++ b/scripts/tools/server.py @@ -1,5 +1,6 @@ import argparse from pathlib import Path + from astrai.inference.server import run_server diff --git a/scripts/tools/train.py b/scripts/tools/train.py index 86260ed..03e238d 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -1,15 +1,16 @@ -import os import argparse +import os +from functools import partial + import torch import torch.nn as nn import torch.optim as optim from torch.nn.parallel import DistributedDataParallel as DDP -from functools import partial +from astrai.config import ModelParameter, TrainConfig from astrai.data import DatasetLoader -from astrai.config import ModelParameter, TrainConfig, CosineScheduleConfig -from astrai.trainer import Trainer, SchedulerFactory from astrai.parallel import get_rank +from astrai.trainer import SchedulerFactory, Trainer def parse_args() -> argparse.Namespace: @@ -158,7 +159,7 @@ def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer: def create_scheduler( optimizer: optim.Optimizer, **kwargs ) -> optim.lr_scheduler.LRScheduler: - return SchedulerFactory.load(optimizer, **kwargs) + return SchedulerFactory.create(optimizer, **kwargs) def prepare_checkpoint(model: nn.Module) -> dict: @@ -211,11 +212,6 @@ def train( stride=stride, ) - schedule_config = CosineScheduleConfig( - warmup_steps=warmup_steps, - total_steps=len(dataset) * n_epoch // (batch_size * nprocs), - ) - optimizer_fn = partial( create_optimizer, **{ @@ -224,7 +220,16 @@ def train( "weight_decay": adamw_weight_decay, }, ) - scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config}) + + total_steps = len(dataset) * n_epoch // (batch_size * nprocs) + scheduler_fn = partial( + create_scheduler, + **{ + "schedule_type": "cosine", + "warmup_steps": warmup_steps, + "lr_decay_steps": total_steps - warmup_steps, + }, + ) train_config = TrainConfig( model=model, diff --git a/tests/conftest.py b/tests/conftest.py index 5280650..ed44218 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,12 @@ -import os import json -import numpy as np -import tempfile +import os import shutil -import torch +import tempfile + +import numpy as np import pytest import safetensors.torch as st +import torch from tokenizers import pre_tokenizers from
torch.utils.data import Dataset diff --git a/tests/data/test_checkpoint.py b/tests/data/test_checkpoint.py index 33ccb66..f261a13 100644 --- a/tests/data/test_checkpoint.py +++ b/tests/data/test_checkpoint.py @@ -1,9 +1,10 @@ -import torch import tempfile -import torch.distributed as dist +import torch +import torch.distributed as dist from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR + from astrai.data.serialization import Checkpoint from astrai.parallel.setup import get_rank, spawn_parallel_fn diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 27a077b..b54013d 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -1,8 +1,8 @@ -import torch import numpy as np +import torch -from astrai.data.serialization import save_h5 from astrai.data.dataset import * +from astrai.data.serialization import save_h5 def test_dataset_loader_random_paths(base_test_env): diff --git a/tests/data/test_sampler.py b/tests/data/test_sampler.py index 2c461c1..8ac821c 100644 --- a/tests/data/test_sampler.py +++ b/tests/data/test_sampler.py @@ -1,5 +1,5 @@ -from astrai.trainer import * from astrai.data import * +from astrai.trainer import * def test_random_sampler_consistency(random_dataset): diff --git a/tests/inference/conftest.py b/tests/inference/conftest.py index 37b6b53..5b0144c 100644 --- a/tests/inference/conftest.py +++ b/tests/inference/conftest.py @@ -1,8 +1,10 @@ """Shared fixtures for inference tests.""" -import pytest from unittest.mock import MagicMock, patch + +import pytest from fastapi.testclient import TestClient + from astrai.inference.server import app diff --git a/tests/inference/test_server.py b/tests/inference/test_server.py index 9b07901..1e1bb92 100644 --- a/tests/inference/test_server.py +++ b/tests/inference/test_server.py @@ -1,9 +1,6 @@ """Unit tests for the inference HTTP server.""" import pytest -from unittest.mock import MagicMock, patch -from fastapi.testclient import TestClient -from astrai.inference.server import app def test_health_no_model(client, monkeypatch): diff --git a/tests/module/test_module.py b/tests/module/test_module.py index 6304989..19e63ca 100644 --- a/tests/module/test_module.py +++ b/tests/module/test_module.py @@ -1,10 +1,12 @@ import os + import torch -from astrai.trainer import * + from astrai.config import * -from astrai.model import * from astrai.data import * from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore +from astrai.model import * +from astrai.trainer import * def test_model_parameter(test_env): diff --git a/tests/module/test_tie_weight.py b/tests/module/test_tie_weight.py index 63f71f0..143fd68 100644 --- a/tests/module/test_tie_weight.py +++ b/tests/module/test_tie_weight.py @@ -1,11 +1,13 @@ -import os import json -import torch -import pytest +import os import tempfile + +import pytest import safetensors.torch as st -from astrai.model.transformer import Transformer +import torch + from astrai.config.model_config import ModelConfig +from astrai.model.transformer import Transformer @pytest.fixture diff --git a/tests/trainer/conftest.py b/tests/trainer/conftest.py index e4bd006..2a086a0 100644 --- a/tests/trainer/conftest.py +++ b/tests/trainer/conftest.py @@ -1,6 +1,9 @@ +import pytest import torch from torch.utils.data import Dataset -import pytest + +from astrai.config import TrainConfig +from astrai.trainer.schedule import SchedulerFactory class TrainerDataset(Dataset): @@ -54,13 +57,11 @@ def create_train_config( Returns: TrainConfig instance 
configured for testing """ - from astrai.config import TrainConfig - from astrai.config.schedule_config import CosineScheduleConfig - from astrai.trainer.schedule import SchedulerFactory - schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) optimizer_fn = lambda m: torch.optim.AdamW(m.parameters(), lr=0.001) - scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) + scheduler_fn = lambda optim: SchedulerFactory.create( + optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 + ) return TrainConfig( strategy=strategy, diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index 58eda2e..afd1c57 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -6,10 +6,10 @@ from astrai.trainer import * def test_callback_integration(base_test_env, random_dataset): """Test that all callbacks are properly integrated""" - schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) - optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) - scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) + scheduler_fn = lambda optim: SchedulerFactory.create( + optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 + ) train_config = TrainConfig( model=base_test_env["model"], diff --git a/tests/trainer/test_early_stopping.py b/tests/trainer/test_early_stopping.py index a530cdb..58d2bb5 100644 --- a/tests/trainer/test_early_stopping.py +++ b/tests/trainer/test_early_stopping.py @@ -1,18 +1,20 @@ import os -import torch + import numpy as np +import torch + from astrai.config import * -from astrai.trainer import * from astrai.data.serialization import Checkpoint +from astrai.trainer import * def test_early_stopping_simulation(base_test_env, early_stopping_dataset): """Simulate early stopping behavior""" - schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) - optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) - scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) + scheduler_fn = lambda optim: SchedulerFactory.create( + optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 + ) train_config = TrainConfig( strategy="seq", diff --git a/tests/trainer/test_train_strategy.py b/tests/trainer/test_train_strategy.py index 339476e..7138312 100644 --- a/tests/trainer/test_train_strategy.py +++ b/tests/trainer/test_train_strategy.py @@ -1,10 +1,9 @@ -import torch import numpy as np -import pytest +import torch from astrai.config import * -from astrai.trainer.schedule import * from astrai.data.dataset import * +from astrai.trainer.schedule import * def test_schedule_factory_random_configs(): @@ -16,41 +15,57 @@ def test_schedule_factory_random_configs(): # Test multiple random configurations for _ in range(5): # Test 5 random configurations - schedule_configs = [ - CosineScheduleConfig( - warmup_steps=np.random.randint(50, 200), - total_steps=np.random.randint(1000, 5000), - min_rate=np.random.uniform(0.01, 0.1), - ), - SGDRScheduleConfig( - warmup_steps=np.random.randint(50, 200), - cycle_length=np.random.randint(500, 2000), - t_mult=np.random.randint(1, 3), - min_rate=np.random.uniform(0.01, 0.1), - ), - ] - - for config in schedule_configs: - # Validate configuration - config.validate() - - # Create scheduler using factory - scheduler = SchedulerFactory.load(optimizer, config) - - # Verify scheduler type - if isinstance(config, CosineScheduleConfig): - assert isinstance(scheduler, CosineScheduler) - 
assert scheduler.warmup_steps == config.warmup_steps - assert ( - scheduler.lr_decay_steps == config.total_steps - config.warmup_steps + # Test multiple random configurations + cosine_params = { + "schedule_type": "cosine", + "warmup_steps": np.random.randint(50, 200), + "total_steps": np.random.randint(1000, 5000), + "min_rate": np.random.uniform(0.01, 0.1), + } + sgdr_params = { + "schedule_type": "sgdr", + "warmup_steps": np.random.randint(50, 200), + "cycle_length": np.random.randint(500, 2000), + "t_mult": np.random.randint(1, 3), + "min_rate": np.random.uniform(0.01, 0.1), + } + for params in [cosine_params, sgdr_params]: + schedule_type = params["schedule_type"] + # Convert parameters for scheduler constructor + if schedule_type == "cosine": + warmup_steps = params["warmup_steps"] + total_steps = params["total_steps"] + min_rate = params["min_rate"] + lr_decay_steps = total_steps - warmup_steps + scheduler = SchedulerFactory.create( + optimizer, + schedule_type, + warmup_steps=warmup_steps, + lr_decay_steps=lr_decay_steps, + min_rate=min_rate, + ) + assert isinstance(scheduler, CosineScheduler) + assert scheduler.warmup_steps == warmup_steps + assert scheduler.lr_decay_steps == lr_decay_steps + assert scheduler.min_rate == min_rate + elif schedule_type == "sgdr": + warmup_steps = params["warmup_steps"] + cycle_length = params["cycle_length"] + t_mult = params["t_mult"] + min_rate = params["min_rate"] + scheduler = SchedulerFactory.create( + optimizer, + schedule_type, + warmup_steps=warmup_steps, + cycle_length=cycle_length, + t_mult=t_mult, + min_rate=min_rate, ) - assert scheduler.min_rate == config.min_rate - elif isinstance(config, SGDRScheduleConfig): assert isinstance(scheduler, SGDRScheduler) - assert scheduler.warmup_steps == config.warmup_steps - assert scheduler.cycle_length == config.cycle_length - assert scheduler.t_mult == config.t_mult - assert scheduler.min_rate == config.min_rate + assert scheduler.warmup_steps == warmup_steps + assert scheduler.cycle_length == cycle_length + assert scheduler.t_mult == t_mult + assert scheduler.min_rate == min_rate # Test scheduler state dict functionality state_dict = scheduler.state_dict() @@ -76,16 +91,25 @@ def test_schedule_factory_edge_cases(): # Test edge cases for CosineScheduleConfig edge_cases = [ # Minimal warmup and steps - CosineScheduleConfig(warmup_steps=1, total_steps=10, min_rate=0.01), + {"warmup_steps": 1, "total_steps": 10, "min_rate": 0.01}, # Large values - CosineScheduleConfig(warmup_steps=1000, total_steps=10000, min_rate=0.5), + {"warmup_steps": 1000, "total_steps": 10000, "min_rate": 0.5}, # Zero min_rate (edge case) - CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.0), + {"warmup_steps": 100, "total_steps": 1000, "min_rate": 0.0}, ] - for config in edge_cases: - config.validate() - scheduler = SchedulerFactory.load(optimizer, config) + for params in edge_cases: + warmup_steps = params["warmup_steps"] + total_steps = params["total_steps"] + min_rate = params["min_rate"] + lr_decay_steps = total_steps - warmup_steps + scheduler = SchedulerFactory.create( + optimizer, + "cosine", + warmup_steps=warmup_steps, + lr_decay_steps=lr_decay_steps, + min_rate=min_rate, + ) assert scheduler is not None # Test multiple steps @@ -93,34 +117,24 @@ def test_schedule_factory_edge_cases(): scheduler.step() -def test_schedule_factory_invalid_configs(): - """Test scheduler factory with invalid configurations""" - - # Test invalid configurations that should raise errors - invalid_configs = [ - # Negative 
warmup steps - {"warmup_steps": -10, "total_steps": 1000, "min_rate": 0.1}, - # Total steps less than warmup steps - {"warmup_steps": 500, "total_steps": 400, "min_rate": 0.1}, - # Invalid min_rate - {"warmup_steps": 100, "total_steps": 1000, "min_rate": -0.1}, - {"warmup_steps": 100, "total_steps": 1000, "min_rate": 1.1}, - ] - - for kwargs in invalid_configs: - with pytest.raises(ValueError): - config = CosineScheduleConfig(**kwargs) - config.validate() - - def test_schedule_factory_state_persistence(): """Test scheduler state persistence (save/load)""" model = torch.nn.Linear(10, 2) optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) - config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1) - scheduler = SchedulerFactory.load(optimizer, config) + # Create scheduler directly with parameters + warmup_steps = 100 + total_steps = 1000 + min_rate = 0.1 + lr_decay_steps = total_steps - warmup_steps + scheduler = SchedulerFactory.create( + optimizer, + "cosine", + warmup_steps=warmup_steps, + lr_decay_steps=lr_decay_steps, + min_rate=min_rate, + ) # Take a few steps for _ in range(5): @@ -129,8 +143,14 @@ def test_schedule_factory_state_persistence(): # Save state state_dict = scheduler.state_dict() - # Create new scheduler and load state - new_scheduler = SchedulerFactory.load(optimizer, config) + # Create new scheduler with same parameters + new_scheduler = SchedulerFactory.create( + optimizer, + "cosine", + warmup_steps=warmup_steps, + lr_decay_steps=lr_decay_steps, + min_rate=min_rate, + ) new_scheduler.load_state_dict(state_dict) # Verify states match diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b0fad1f..c17b5c7 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1,5 +1,3 @@ -import torch - from astrai.data.dataset import * from astrai.trainer import Trainer
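Usage sketch for reviewers (illustrative only, not part of the diff): schedulers are now built directly through `SchedulerFactory.create` with explicit keyword arguments instead of a `ScheduleConfig` object passed to `SchedulerFactory.load`. The optimizer, step counts, and `min_rate` below are assumed example values; `lr_decay_steps` is `total_steps - warmup_steps`, mirroring the conversion done in `scripts/tools/train.py`.

    import torch
    from astrai.trainer import SchedulerFactory

    # Assumed example model/optimizer; any optimizer works.
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # "cosine" is one of the registered scheduler types ("sgdr" is the other
    # shown in the tests); kwargs go straight to the scheduler constructor.
    scheduler = SchedulerFactory.create(
        optimizer,
        "cosine",
        warmup_steps=100,
        lr_decay_steps=900,
        min_rate=0.05,
    )

    for _ in range(5):
        optimizer.step()
        scheduler.step()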