feat: 优化工厂模式的实现

This commit is contained in:
ViperEkura 2026-04-04 15:49:46 +08:00
parent aa5e03d7f6
commit 3346c75584
11 changed files with 228 additions and 142 deletions

View File

@ -5,7 +5,7 @@ from astrai.config import (
ModelConfig,
TrainConfig,
)
from astrai.core.factory import BaseFactory
from astrai.factory import BaseFactory
from astrai.data import BpeTokenizer, DatasetFactory
from astrai.inference.generator import (
BatchGenerator,

View File

@ -1,105 +0,0 @@
"""Base factory class for extensible component registration."""
from abc import ABC
from typing import Callable, Dict, Generic, Type, TypeVar
T = TypeVar("T")


class BaseFactory(ABC, Generic[T]):
    """Generic factory class for component registration and creation.

    This base class provides a decorator-based registration pattern
    for creating extensible component factories.

    Every subclass gets its own registry dict. A subclass may still declare
    ``_registry = {}`` explicitly; if it does not, one is created for it so
    that registrations never leak into the base class or sibling factories.

    Example usage:
        class MyFactory(BaseFactory[MyBaseClass]):
            pass

        @MyFactory.register("custom")
        class CustomComponent(MyBaseClass):
            ...

        component = MyFactory.create("custom", *args, **kwargs)
    """

    # Registry of the base class itself; subclasses never share this dict.
    _registry: Dict[str, Type[T]] = {}

    def __init_subclass__(cls, **kwargs) -> None:
        """Give each subclass a private registry unless it declares its own."""
        super().__init_subclass__(**kwargs)
        # Only look at the subclass' own namespace: an inherited dict would
        # be shared with the parent, which is exactly what must be avoided.
        if "_registry" not in cls.__dict__:
            cls._registry = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type[T]], Type[T]]:
        """Decorator to register a component class.

        Args:
            name: Registration name for the component

        Returns:
            Decorator function that registers the component class

        Raises:
            TypeError: If ``_validate_component`` rejects the decorated class
        """

        def decorator(component_cls: Type[T]) -> Type[T]:
            # Validate first so a rejected class is never stored.
            cls._validate_component(component_cls)
            cls._registry[name] = component_cls
            return component_cls

        return decorator

    @classmethod
    def create(cls, name: str, *args, **kwargs) -> T:
        """Create a component instance by name.

        Args:
            name: Registered name of the component
            *args: Positional arguments passed to component constructor
            **kwargs: Keyword arguments passed to component constructor

        Returns:
            Component instance

        Raises:
            ValueError: If the component name is not registered
        """
        if name not in cls._registry:
            raise ValueError(
                f"Unknown component: '{name}'. "
                f"Supported types: {sorted(cls._registry)}"
            )
        component_cls = cls._registry[name]
        return component_cls(*args, **kwargs)

    @classmethod
    def _validate_component(cls, component_cls: Type[T]) -> None:
        """Validate that the component class is valid for this factory.

        Override this method in subclasses to add custom validation.
        The default implementation accepts every class.

        Args:
            component_cls: Component class to validate

        Raises:
            TypeError: If the component class is invalid
        """
        pass

    @classmethod
    def list_registered(cls) -> list:
        """List all registered component names.

        Returns:
            Alphabetically sorted list of registered component names
        """
        return sorted(cls._registry)

    @classmethod
    def is_registered(cls, name: str) -> bool:
        """Check if a component name is registered.

        Args:
            name: Component name to check

        Returns:
            True if registered, False otherwise
        """
        return name in cls._registry

View File

@ -1,7 +1,6 @@
from astrai.data.dataset import (
BaseDataset,
DatasetFactory,
DatasetFactory,
DPODataset,
GRPODataset,
MultiSegmentFetcher,

View File

@ -8,7 +8,7 @@ import torch
from torch import Tensor
from torch.utils.data import Dataset
from astrai.core.factory import BaseFactory
from astrai.factory import BaseFactory
from astrai.data.serialization import load_h5
@ -181,8 +181,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
dataset = DatasetFactory.create("custom", window_size, stride)
"""
_registry: Dict[str, type] = {}
@classmethod
def _validate_component(cls, dataset_cls: type) -> None:
"""Validate that the dataset class inherits from BaseDataset."""

187
astrai/factory.py Normal file
View File

@ -0,0 +1,187 @@
"""Base factory class for extensible component registration."""
from abc import ABC
from typing import Callable, Dict, Generic, Type, TypeVar, Optional, List, Tuple
T = TypeVar("T")
class Registry:
    """Name-keyed store of component classes with category/priority metadata.

    Each entry is kept as a ``(component_cls, category, priority)`` tuple.
    The class offers registration, lookup, and several filtered or ordered
    listings of the registered names.
    """

    def __init__(self):
        # name -> (component_cls, category, priority)
        self._entries = {}

    def register(
        self,
        name: str,
        component_cls: Type,
        category: Optional[str] = None,
        priority: int = 0,
    ) -> None:
        """Store ``component_cls`` under ``name``.

        Raises:
            ValueError: If ``name`` has already been registered.
        """
        if name in self._entries:
            raise ValueError(f"Component '{name}' is already registered")
        record = (component_cls, category, priority)
        self._entries[name] = record

    def get(self, name: str) -> Type:
        """Return the class registered under ``name``.

        Raises:
            KeyError: If ``name`` is not registered.
        """
        # Stored records are always tuples, so None reliably means "absent".
        record = self._entries.get(name)
        if record is None:
            raise KeyError(f"Component '{name}' not found in registry")
        return record[0]

    def get_with_metadata(self, name: str) -> Tuple[Type, Optional[str], int]:
        """Return the full ``(class, category, priority)`` record for ``name``.

        Raises:
            KeyError: If ``name`` is not registered.
        """
        if name not in self._entries:
            raise KeyError(f"Component '{name}' not found in registry")
        return self._entries[name]

    def contains(self, name: str) -> bool:
        """Tell whether ``name`` has been registered."""
        return name in self._entries

    def list_names(self) -> List[str]:
        """Return every registered name in sorted order."""
        names = list(self._entries)
        names.sort()
        return names

    def list_by_category(self, category: str) -> List[str]:
        """Return the sorted names whose category equals ``category``."""
        matching = [
            name
            for name, record in self._entries.items()
            if record[1] == category
        ]
        matching.sort()
        return matching

    def list_by_priority(self, reverse: bool = False) -> List[str]:
        """Return names ordered by priority (ascending unless ``reverse``).

        Names with equal priority keep their registration order.
        """
        ranked = sorted(
            self._entries.items(),
            key=lambda item: item[1][2],
            reverse=reverse,
        )
        return [name for name, _ in ranked]

    def entries(self) -> Dict[str, Tuple[Type, Optional[str], int]]:
        """Return a shallow copy of the underlying entries mapping."""
        return dict(self._entries)
class BaseFactory(ABC, Generic[T]):
    """Generic factory class for component registration and creation.

    This base class provides a decorator-based registration pattern
    for creating extensible component factories. A fresh ``Registry`` is
    attached to every subclass when it is defined, so each factory keeps
    its own set of registrations. ``BaseFactory`` itself only carries the
    ``_registry`` annotation and has no registry of its own.

    Example usage:
        class MyFactory(BaseFactory[MyBaseClass]):
            pass

        @MyFactory.register("custom")
        class CustomComponent(MyBaseClass):
            ...

        component = MyFactory.create("custom", *args, **kwargs)
    """

    # Assigned per subclass in __init_subclass__.
    _registry: Registry

    def __init_subclass__(cls, **kwargs):
        """Attach a new, empty registry to each newly defined subclass."""
        super().__init_subclass__(**kwargs)
        cls._registry = Registry()

    @classmethod
    def register(
        cls, name: str, category: Optional[str] = None, priority: int = 0
    ) -> Callable[[Type[T]], Type[T]]:
        """Decorator to register a component class with optional category and priority.

        Args:
            name: Registration name for the component
            category: Optional category for grouping components
            priority: Priority for ordering (default 0)

        Returns:
            Decorator function that registers the component class

        Raises:
            TypeError: If ``_validate_component`` rejects the decorated class
            ValueError: If ``name`` is already registered in this factory
        """

        def decorator(component_cls: Type[T]) -> Type[T]:
            # Validate before storing so a rejected class is never registered.
            cls._validate_component(component_cls)
            cls._registry.register(
                name, component_cls, category=category, priority=priority
            )
            return component_cls

        return decorator

    @classmethod
    def create(cls, name: str, *args, **kwargs) -> T:
        """Create a component instance by name.

        Args:
            name: Registered name of the component
            *args: Positional arguments passed to component constructor
            **kwargs: Keyword arguments passed to component constructor

        Returns:
            Component instance

        Raises:
            ValueError: If the component name is not registered
        """
        if not cls._registry.contains(name):
            # list_names() already returns the names sorted.
            raise ValueError(
                f"Unknown component: '{name}'. "
                f"Supported types: {cls._registry.list_names()}"
            )
        component_cls = cls._registry.get(name)
        return component_cls(*args, **kwargs)

    @classmethod
    def _validate_component(cls, component_cls: Type[T]) -> None:
        """Validate that the component class is valid for this factory.

        Override this method in subclasses to add custom validation.
        The default implementation accepts every class.

        Args:
            component_cls: Component class to validate

        Raises:
            TypeError: If the component class is invalid
        """
        pass

    @classmethod
    def list_registered(cls) -> List[str]:
        """List all registered component names.

        Returns:
            Sorted list of registered component names
        """
        return cls._registry.list_names()

    @classmethod
    def is_registered(cls, name: str) -> bool:
        """Check if a component name is registered.

        Args:
            name: Component name to check

        Returns:
            True if registered, False otherwise
        """
        return cls._registry.contains(name)

    @classmethod
    def list_by_category(cls, category: str) -> List[str]:
        """List registered component names in a category.

        Args:
            category: Category to filter by

        Returns:
            Sorted names of the components registered with that category
        """
        return cls._registry.list_by_category(category)

    @classmethod
    def list_by_priority(cls, reverse: bool = False) -> List[str]:
        """List registered component names sorted by priority.

        Args:
            reverse: If True, order from highest to lowest priority

        Returns:
            Component names ordered by their registered priority
        """
        return cls._registry.list_by_priority(reverse=reverse)

View File

@ -6,7 +6,7 @@ from jinja2 import Template
from torch import Tensor
from astrai.config.param_config import ModelParameter
from astrai.core.factory import BaseFactory
from astrai.factory import BaseFactory
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
HistoryType = List[Tuple[str, str]]
@ -299,8 +299,6 @@ class GeneratorFactory(BaseFactory[GeneratorCore]):
result = generator.generate(request)
"""
_registry: Dict[str, type] = {}
@staticmethod
def create(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
"""Create a generator based on request characteristics.

View File

@ -1,12 +1,8 @@
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
from astrai.trainer.train_callback import (
CheckpointCallback,
GradientClippingCallback,
MetricLoggerCallback,
ProgressBarCallback,
SchedulerCallback,
TrainCallback,
CallbackFactory,
)
from astrai.trainer.trainer import Trainer
@ -19,11 +15,7 @@ __all__ = [
# Scheduler factory
"SchedulerFactory",
"BaseScheduler",
# Callbacks
# Callback factory
"TrainCallback",
"GradientClippingCallback",
"SchedulerCallback",
"CheckpointCallback",
"ProgressBarCallback",
"MetricLoggerCallback",
"CallbackFactory",
]

View File

@ -6,7 +6,7 @@ from typing import Any, Dict, List, Type
from torch.optim.lr_scheduler import LRScheduler
from astrai.core.factory import BaseFactory
from astrai.factory import BaseFactory
class BaseScheduler(LRScheduler, ABC):
@ -41,8 +41,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
"""
_registry: Dict[str, Type[BaseScheduler]] = {}
@classmethod
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
"""Validate that the scheduler class inherits from BaseScheduler."""

View File

@ -10,7 +10,7 @@ import torch.nn.functional as F
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.core.factory import BaseFactory
from astrai.factory import BaseFactory
def unwrap_model(model: nn.Module) -> nn.Module:
@ -122,8 +122,6 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
strategy = StrategyFactory.create("custom", model, device)
"""
_registry: Dict[str, type] = {}
@classmethod
def _validate_component(cls, strategy_cls: type) -> None:
"""Validate that the strategy class inherits from BaseStrategy."""

View File

@ -2,7 +2,7 @@ import json
import os
import time
from pathlib import Path
from typing import Callable, List, Optional, Protocol
from typing import Callable, List, Optional, Protocol, runtime_checkable
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
@ -21,8 +21,10 @@ from astrai.trainer.metric_util import (
ctx_get_lr,
)
from astrai.trainer.train_context import TrainContext
from astrai.factory import BaseFactory
@runtime_checkable
class TrainCallback(Protocol):
"""
Callback interface for trainer.
@ -56,6 +58,25 @@ class TrainCallback(Protocol):
"""Called when an error occurs during training."""
class CallbackFactory(BaseFactory[TrainCallback]):
    """Factory used to register training callbacks by name and build them.

    Example:
        @CallbackFactory.register("my_callback")
        class MyCallback(TrainCallback):
            ...

        callback = CallbackFactory.create("my_callback", **kwargs)
    """

    @classmethod
    def _validate_component(cls, callback_cls: type) -> None:
        """Ensure the class being registered passes the TrainCallback check.

        Args:
            callback_cls: Class about to be registered as a callback

        Raises:
            TypeError: If ``issubclass(callback_cls, TrainCallback)`` is false
        """
        if issubclass(callback_cls, TrainCallback):
            return
        raise TypeError(f"{callback_cls.__name__} must inherit from TrainCallback")
@CallbackFactory.register("gradient_clipping")
class GradientClippingCallback(TrainCallback):
"""
Gradient clipping callback for trainer.
@ -69,6 +90,7 @@ class GradientClippingCallback(TrainCallback):
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
@CallbackFactory.register("scheduler")
class SchedulerCallback(TrainCallback):
"""
Scheduler callback for trainer.
@ -87,6 +109,7 @@ class SchedulerCallback(TrainCallback):
context.scheduler.step()
@CallbackFactory.register("checkpoint")
class CheckpointCallback(TrainCallback):
"""
Checkpoint callback for trainer.
@ -135,6 +158,7 @@ class CheckpointCallback(TrainCallback):
self._save_checkpoint(context)
@CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback):
"""
Progress bar callback for trainer.
@ -169,6 +193,7 @@ class ProgressBarCallback(TrainCallback):
self.progress_bar.close()
@CallbackFactory.register("metric_logger")
class MetricLoggerCallback(TrainCallback):
def __init__(
self,

View File

@ -5,12 +5,8 @@ from astrai.config import TrainConfig
from astrai.data.serialization import Checkpoint
from astrai.parallel.setup import spawn_parallel_fn
from astrai.trainer.train_callback import (
CheckpointCallback,
GradientClippingCallback,
MetricLoggerCallback,
ProgressBarCallback,
SchedulerCallback,
TrainCallback,
CallbackFactory,
)
from astrai.trainer.train_context import TrainContext, TrainContextBuilder
@ -28,13 +24,13 @@ class Trainer:
)
def _get_default_callbacks(self) -> List[TrainCallback]:
train_config = self.train_config
cfg = self.train_config
return [
ProgressBarCallback(train_config.n_epoch),
CheckpointCallback(train_config.ckpt_dir, train_config.ckpt_interval),
MetricLoggerCallback(train_config.ckpt_dir, train_config.ckpt_interval),
GradientClippingCallback(train_config.max_grad_norm),
SchedulerCallback(),
CallbackFactory.create("progress_bar", cfg.n_epoch),
CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
CallbackFactory.create("scheduler"),
]
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: