From 3346c755846c4e536b165b5c9631f9c069128ac4 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 4 Apr 2026 15:49:46 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E5=B7=A5=E5=8E=82?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F=E7=9A=84=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/__init__.py | 2 +- astrai/core/factory.py | 105 ----------------- astrai/data/__init__.py | 1 - astrai/data/dataset.py | 4 +- astrai/factory.py | 187 +++++++++++++++++++++++++++++++ astrai/inference/generator.py | 4 +- astrai/trainer/__init__.py | 14 +-- astrai/trainer/schedule.py | 4 +- astrai/trainer/strategy.py | 4 +- astrai/trainer/train_callback.py | 27 ++++- astrai/trainer/trainer.py | 18 ++- 11 files changed, 228 insertions(+), 142 deletions(-) delete mode 100644 astrai/core/factory.py create mode 100644 astrai/factory.py diff --git a/astrai/__init__.py b/astrai/__init__.py index 4eeda4f..a8191f4 100644 --- a/astrai/__init__.py +++ b/astrai/__init__.py @@ -5,7 +5,7 @@ from astrai.config import ( ModelConfig, TrainConfig, ) -from astrai.core.factory import BaseFactory +from astrai.factory import BaseFactory from astrai.data import BpeTokenizer, DatasetFactory from astrai.inference.generator import ( BatchGenerator, diff --git a/astrai/core/factory.py b/astrai/core/factory.py deleted file mode 100644 index 2c61ac4..0000000 --- a/astrai/core/factory.py +++ /dev/null @@ -1,105 +0,0 @@ -"""Base factory class for extensible component registration.""" - -from abc import ABC -from typing import Callable, Dict, Generic, Type, TypeVar - -T = TypeVar("T") - - -class BaseFactory(ABC, Generic[T]): - """Generic factory class for component registration and creation. - - This base class provides a decorator-based registration pattern - for creating extensible component factories. - - Example usage: - class MyFactory(BaseFactory[MyBaseClass]): - pass - - @MyFactory.register("custom") - class CustomComponent(MyBaseClass): - ... - - component = MyFactory.create("custom", *args, **kwargs) - """ - - _registry: Dict[str, Type[T]] = {} - - @classmethod - def register(cls, name: str) -> Callable[[Type[T]], Type[T]]: - """Decorator to register a component class. - - Args: - name: Registration name for the component - - Returns: - Decorator function that registers the component class - - Raises: - TypeError: If the decorated class doesn't inherit from the base type - """ - - def decorator(component_cls: Type[T]) -> Type[T]: - cls._validate_component(component_cls) - cls._registry[name] = component_cls - return component_cls - - return decorator - - @classmethod - def create(cls, name: str, *args, **kwargs) -> T: - """Create a component instance by name. - - Args: - name: Registered name of the component - *args: Positional arguments passed to component constructor - **kwargs: Keyword arguments passed to component constructor - - Returns: - Component instance - - Raises: - ValueError: If the component name is not registered - """ - if name not in cls._registry: - raise ValueError( - f"Unknown component: '{name}'. " - f"Supported types: {sorted(cls._registry.keys())}" - ) - component_cls = cls._registry[name] - return component_cls(*args, **kwargs) - - @classmethod - def _validate_component(cls, component_cls: Type[T]) -> None: - """Validate that the component class is valid for this factory. - - Override this method in subclasses to add custom validation. - - Args: - component_cls: Component class to validate - - Raises: - TypeError: If the component class is invalid - """ - pass - - @classmethod - def list_registered(cls) -> list: - """List all registered component names. - - Returns: - List of registered component names - """ - return sorted(cls._registry.keys()) - - @classmethod - def is_registered(cls, name: str) -> bool: - """Check if a component name is registered. - - Args: - name: Component name to check - - Returns: - True if registered, False otherwise - """ - return name in cls._registry diff --git a/astrai/data/__init__.py b/astrai/data/__init__.py index 7cc418d..337788c 100644 --- a/astrai/data/__init__.py +++ b/astrai/data/__init__.py @@ -1,7 +1,6 @@ from astrai.data.dataset import ( BaseDataset, DatasetFactory, - DatasetFactory, DPODataset, GRPODataset, MultiSegmentFetcher, diff --git a/astrai/data/dataset.py b/astrai/data/dataset.py index 395ec2a..25a289b 100644 --- a/astrai/data/dataset.py +++ b/astrai/data/dataset.py @@ -8,7 +8,7 @@ import torch from torch import Tensor from torch.utils.data import Dataset -from astrai.core.factory import BaseFactory +from astrai.factory import BaseFactory from astrai.data.serialization import load_h5 @@ -181,8 +181,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]): dataset = DatasetFactory.create("custom", window_size, stride) """ - _registry: Dict[str, type] = {} - @classmethod def _validate_component(cls, dataset_cls: type) -> None: """Validate that the dataset class inherits from BaseDataset.""" diff --git a/astrai/factory.py b/astrai/factory.py new file mode 100644 index 0000000..6c88c15 --- /dev/null +++ b/astrai/factory.py @@ -0,0 +1,187 @@ +"""Base factory class for extensible component registration.""" + +from abc import ABC +from typing import Callable, Dict, Generic, Type, TypeVar, Optional, List, Tuple + +T = TypeVar("T") + + +class Registry: + """Flexible registry for component classes with category and priority support. + + This registry stores component classes with optional metadata (category, priority). + It provides methods for registration, retrieval, and listing with filtering. + """ + + def __init__(self): + self._entries = {} # name -> (component_cls, category, priority) + + def register( + self, + name: str, + component_cls: Type, + category: Optional[str] = None, + priority: int = 0, + ) -> None: + """Register a component class with optional category and priority.""" + if name in self._entries: + raise ValueError(f"Component '{name}' is already registered") + self._entries[name] = (component_cls, category, priority) + + def get(self, name: str) -> Type: + """Get component class by name.""" + if name not in self._entries: + raise KeyError(f"Component '{name}' not found in registry") + return self._entries[name][0] + + def get_with_metadata(self, name: str) -> Tuple[Type, Optional[str], int]: + """Get component class with its metadata.""" + entry = self._entries.get(name) + if entry is None: + raise KeyError(f"Component '{name}' not found in registry") + return entry + + def contains(self, name: str) -> bool: + """Check if a name is registered.""" + return name in self._entries + + def list_names(self) -> List[str]: + """Return list of registered component names.""" + return sorted(self._entries.keys()) + + def list_by_category(self, category: str) -> List[str]: + """Return names of components belonging to a specific category.""" + return sorted( + name for name, (_, cat, _) in self._entries.items() if cat == category + ) + + def list_by_priority(self, reverse: bool = False) -> List[str]: + """Return names sorted by priority (default ascending).""" + return sorted( + self._entries.keys(), + key=lambda name: self._entries[name][2], + reverse=reverse, + ) + + def entries(self) -> Dict[str, Tuple[Type, Optional[str], int]]: + """Return raw entries dictionary.""" + return self._entries.copy() + + +class BaseFactory(ABC, Generic[T]): + """Generic factory class for component registration and creation. + + This base class provides a decorator-based registration pattern + for creating extensible component factories. + + Example usage: + class MyFactory(BaseFactory[MyBaseClass]): + pass + + @MyFactory.register("custom") + class CustomComponent(MyBaseClass): + ... + + component = MyFactory.create("custom", *args, **kwargs) + """ + + _registry: Registry + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls._registry = Registry() + + @classmethod + def register( + cls, name: str, category: Optional[str] = None, priority: int = 0 + ) -> Callable[[Type[T]], Type[T]]: + """Decorator to register a component class with optional category and priority. + + Args: + name: Registration name for the component + category: Optional category for grouping components + priority: Priority for ordering (default 0) + + Returns: + Decorator function that registers the component class + + Raises: + TypeError: If the decorated class doesn't inherit from the base type + """ + + def decorator(component_cls: Type[T]) -> Type[T]: + cls._validate_component(component_cls) + cls._registry.register( + name, component_cls, category=category, priority=priority + ) + return component_cls + + return decorator + + @classmethod + def create(cls, name: str, *args, **kwargs) -> T: + """Create a component instance by name. + + Args: + name: Registered name of the component + *args: Positional arguments passed to component constructor + **kwargs: Keyword arguments passed to component constructor + + Returns: + Component instance + + Raises: + ValueError: If the component name is not registered + """ + if not cls._registry.contains(name): + raise ValueError( + f"Unknown component: '{name}'. " + f"Supported types: {sorted(cls._registry.list_names())}" + ) + component_cls = cls._registry.get(name) + return component_cls(*args, **kwargs) + + @classmethod + def _validate_component(cls, component_cls: Type[T]) -> None: + """Validate that the component class is valid for this factory. + + Override this method in subclasses to add custom validation. + + Args: + component_cls: Component class to validate + + Raises: + TypeError: If the component class is invalid + """ + pass + + @classmethod + def list_registered(cls) -> list: + """List all registered component names. + + Returns: + List of registered component names + """ + return cls._registry.list_names() + + @classmethod + def is_registered(cls, name: str) -> bool: + """Check if a component name is registered. + + Args: + name: Component name to check + + Returns: + True if registered, False otherwise + """ + return cls._registry.contains(name) + + @classmethod + def list_by_category(cls, category: str) -> List[str]: + """List registered component names in a category.""" + return cls._registry.list_by_category(category) + + @classmethod + def list_by_priority(cls, reverse: bool = False) -> List[str]: + """List registered component names sorted by priority.""" + return cls._registry.list_by_priority(reverse) diff --git a/astrai/inference/generator.py b/astrai/inference/generator.py index 67d76cf..397e5e8 100644 --- a/astrai/inference/generator.py +++ b/astrai/inference/generator.py @@ -6,7 +6,7 @@ from jinja2 import Template from torch import Tensor from astrai.config.param_config import ModelParameter -from astrai.core.factory import BaseFactory +from astrai.factory import BaseFactory from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager HistoryType = List[Tuple[str, str]] @@ -299,8 +299,6 @@ class GeneratorFactory(BaseFactory[GeneratorCore]): result = generator.generate(request) """ - _registry: Dict[str, type] = {} - @staticmethod def create(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore: """Create a generator based on request characteristics. diff --git a/astrai/trainer/__init__.py b/astrai/trainer/__init__.py index 05b7b0c..0e4485f 100644 --- a/astrai/trainer/__init__.py +++ b/astrai/trainer/__init__.py @@ -1,12 +1,8 @@ from astrai.trainer.schedule import BaseScheduler, SchedulerFactory from astrai.trainer.strategy import BaseStrategy, StrategyFactory from astrai.trainer.train_callback import ( - CheckpointCallback, - GradientClippingCallback, - MetricLoggerCallback, - ProgressBarCallback, - SchedulerCallback, TrainCallback, + CallbackFactory, ) from astrai.trainer.trainer import Trainer @@ -19,11 +15,7 @@ __all__ = [ # Scheduler factory "SchedulerFactory", "BaseScheduler", - # Callbacks + # Callback factory "TrainCallback", - "GradientClippingCallback", - "SchedulerCallback", - "CheckpointCallback", - "ProgressBarCallback", - "MetricLoggerCallback", + "CallbackFactory", ] diff --git a/astrai/trainer/schedule.py b/astrai/trainer/schedule.py index 0ca166a..9727bf0 100644 --- a/astrai/trainer/schedule.py +++ b/astrai/trainer/schedule.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Type from torch.optim.lr_scheduler import LRScheduler -from astrai.core.factory import BaseFactory +from astrai.factory import BaseFactory class BaseScheduler(LRScheduler, ABC): @@ -41,8 +41,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]): scheduler = SchedulerFactory.create("custom", optimizer, **kwargs) """ - _registry: Dict[str, Type[BaseScheduler]] = {} - @classmethod def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None: """Validate that the scheduler class inherits from BaseScheduler.""" diff --git a/astrai/trainer/strategy.py b/astrai/trainer/strategy.py index d6a27e8..fd3d1d5 100644 --- a/astrai/trainer/strategy.py +++ b/astrai/trainer/strategy.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP -from astrai.core.factory import BaseFactory +from astrai.factory import BaseFactory def unwrap_model(model: nn.Module) -> nn.Module: @@ -122,8 +122,6 @@ class StrategyFactory(BaseFactory["BaseStrategy"]): strategy = StrategyFactory.create("custom", model, device) """ - _registry: Dict[str, type] = {} - @classmethod def _validate_component(cls, strategy_cls: type) -> None: """Validate that the strategy class inherits from BaseStrategy.""" diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index cfa6283..6ca09ae 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -2,7 +2,7 @@ import json import os import time from pathlib import Path -from typing import Callable, List, Optional, Protocol +from typing import Callable, List, Optional, Protocol, runtime_checkable import torch.nn as nn from torch.nn.utils import clip_grad_norm_ @@ -21,8 +21,10 @@ from astrai.trainer.metric_util import ( ctx_get_lr, ) from astrai.trainer.train_context import TrainContext +from astrai.factory import BaseFactory +@runtime_checkable class TrainCallback(Protocol): """ Callback interface for trainer. @@ -56,6 +58,25 @@ class TrainCallback(Protocol): """Called when an error occurs during training.""" +class CallbackFactory(BaseFactory[TrainCallback]): + """Factory for registering and creating training callbacks. + + Example: + @CallbackFactory.register("my_callback") + class MyCallback(TrainCallback): + ... + + callback = CallbackFactory.create("my_callback", **kwargs) + """ + + @classmethod + def _validate_component(cls, callback_cls: type) -> None: + """Validate that the callback class inherits from TrainCallback.""" + if not issubclass(callback_cls, TrainCallback): + raise TypeError(f"{callback_cls.__name__} must inherit from TrainCallback") + + +@CallbackFactory.register("gradient_clipping") class GradientClippingCallback(TrainCallback): """ Gradient clipping callback for trainer. @@ -69,6 +90,7 @@ class GradientClippingCallback(TrainCallback): clip_grad_norm_(context.model.parameters(), self.max_grad_norm) +@CallbackFactory.register("scheduler") class SchedulerCallback(TrainCallback): """ Scheduler callback for trainer. @@ -87,6 +109,7 @@ class SchedulerCallback(TrainCallback): context.scheduler.step() +@CallbackFactory.register("checkpoint") class CheckpointCallback(TrainCallback): """ Checkpoint callback for trainer. @@ -135,6 +158,7 @@ class CheckpointCallback(TrainCallback): self._save_checkpoint(context) +@CallbackFactory.register("progress_bar") class ProgressBarCallback(TrainCallback): """ Progress bar callback for trainer. @@ -169,6 +193,7 @@ class ProgressBarCallback(TrainCallback): self.progress_bar.close() +@CallbackFactory.register("metric_logger") class MetricLoggerCallback(TrainCallback): def __init__( self, diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index f2a2389..189e1dd 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -5,12 +5,8 @@ from astrai.config import TrainConfig from astrai.data.serialization import Checkpoint from astrai.parallel.setup import spawn_parallel_fn from astrai.trainer.train_callback import ( - CheckpointCallback, - GradientClippingCallback, - MetricLoggerCallback, - ProgressBarCallback, - SchedulerCallback, TrainCallback, + CallbackFactory, ) from astrai.trainer.train_context import TrainContext, TrainContextBuilder @@ -28,13 +24,13 @@ class Trainer: ) def _get_default_callbacks(self) -> List[TrainCallback]: - train_config = self.train_config + cfg = self.train_config return [ - ProgressBarCallback(train_config.n_epoch), - CheckpointCallback(train_config.ckpt_dir, train_config.ckpt_interval), - MetricLoggerCallback(train_config.ckpt_dir, train_config.ckpt_interval), - GradientClippingCallback(train_config.max_grad_norm), - SchedulerCallback(), + CallbackFactory.create("progress_bar", cfg.n_epoch), + CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval), + CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval), + CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), + CallbackFactory.create("scheduler"), ] def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: