feat: 优化工厂模式的实现
This commit is contained in:
parent
aa5e03d7f6
commit
3346c75584
|
|
@ -5,7 +5,7 @@ from astrai.config import (
|
|||
ModelConfig,
|
||||
TrainConfig,
|
||||
)
|
||||
from astrai.core.factory import BaseFactory
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.data import BpeTokenizer, DatasetFactory
|
||||
from astrai.inference.generator import (
|
||||
BatchGenerator,
|
||||
|
|
|
|||
|
|
@ -1,105 +0,0 @@
|
|||
"""Base factory class for extensible component registration."""
|
||||
|
||||
from abc import ABC
|
||||
from typing import Callable, Dict, Generic, Type, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BaseFactory(ABC, Generic[T]):
|
||||
"""Generic factory class for component registration and creation.
|
||||
|
||||
This base class provides a decorator-based registration pattern
|
||||
for creating extensible component factories.
|
||||
|
||||
Example usage:
|
||||
class MyFactory(BaseFactory[MyBaseClass]):
|
||||
pass
|
||||
|
||||
@MyFactory.register("custom")
|
||||
class CustomComponent(MyBaseClass):
|
||||
...
|
||||
|
||||
component = MyFactory.create("custom", *args, **kwargs)
|
||||
"""
|
||||
|
||||
_registry: Dict[str, Type[T]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str) -> Callable[[Type[T]], Type[T]]:
|
||||
"""Decorator to register a component class.
|
||||
|
||||
Args:
|
||||
name: Registration name for the component
|
||||
|
||||
Returns:
|
||||
Decorator function that registers the component class
|
||||
|
||||
Raises:
|
||||
TypeError: If the decorated class doesn't inherit from the base type
|
||||
"""
|
||||
|
||||
def decorator(component_cls: Type[T]) -> Type[T]:
|
||||
cls._validate_component(component_cls)
|
||||
cls._registry[name] = component_cls
|
||||
return component_cls
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def create(cls, name: str, *args, **kwargs) -> T:
|
||||
"""Create a component instance by name.
|
||||
|
||||
Args:
|
||||
name: Registered name of the component
|
||||
*args: Positional arguments passed to component constructor
|
||||
**kwargs: Keyword arguments passed to component constructor
|
||||
|
||||
Returns:
|
||||
Component instance
|
||||
|
||||
Raises:
|
||||
ValueError: If the component name is not registered
|
||||
"""
|
||||
if name not in cls._registry:
|
||||
raise ValueError(
|
||||
f"Unknown component: '{name}'. "
|
||||
f"Supported types: {sorted(cls._registry.keys())}"
|
||||
)
|
||||
component_cls = cls._registry[name]
|
||||
return component_cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, component_cls: Type[T]) -> None:
|
||||
"""Validate that the component class is valid for this factory.
|
||||
|
||||
Override this method in subclasses to add custom validation.
|
||||
|
||||
Args:
|
||||
component_cls: Component class to validate
|
||||
|
||||
Raises:
|
||||
TypeError: If the component class is invalid
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def list_registered(cls) -> list:
|
||||
"""List all registered component names.
|
||||
|
||||
Returns:
|
||||
List of registered component names
|
||||
"""
|
||||
return sorted(cls._registry.keys())
|
||||
|
||||
@classmethod
|
||||
def is_registered(cls, name: str) -> bool:
|
||||
"""Check if a component name is registered.
|
||||
|
||||
Args:
|
||||
name: Component name to check
|
||||
|
||||
Returns:
|
||||
True if registered, False otherwise
|
||||
"""
|
||||
return name in cls._registry
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
from astrai.data.dataset import (
|
||||
BaseDataset,
|
||||
DatasetFactory,
|
||||
DatasetFactory,
|
||||
DPODataset,
|
||||
GRPODataset,
|
||||
MultiSegmentFetcher,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import torch
|
|||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from astrai.core.factory import BaseFactory
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.data.serialization import load_h5
|
||||
|
||||
|
||||
|
|
@ -181,8 +181,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
dataset = DatasetFactory.create("custom", window_size, stride)
|
||||
"""
|
||||
|
||||
_registry: Dict[str, type] = {}
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, dataset_cls: type) -> None:
|
||||
"""Validate that the dataset class inherits from BaseDataset."""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,187 @@
|
|||
"""Base factory class for extensible component registration."""
|
||||
|
||||
from abc import ABC
|
||||
from typing import Callable, Dict, Generic, Type, TypeVar, Optional, List, Tuple
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Registry:
    """Name-indexed store of component classes with category and priority metadata.

    Every entry maps a unique string name to a
    ``(component_cls, category, priority)`` tuple. The registry offers
    registration, lookup, and listing filtered by category or ordered by
    priority.
    """

    def __init__(self) -> None:
        # name -> (component_cls, category, priority)
        self._entries: Dict[str, Tuple[Type, Optional[str], int]] = {}

    def __contains__(self, name: object) -> bool:
        """Support ``name in registry`` as a synonym for ``contains(name)``."""
        return name in self._entries

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.list_names()!r})"

    def register(
        self,
        name: str,
        component_cls: Type,
        category: Optional[str] = None,
        priority: int = 0,
    ) -> None:
        """Register a component class with optional category and priority.

        Args:
            name: Unique registration name
            component_cls: Class to store under ``name``
            category: Optional grouping label used by ``list_by_category``
            priority: Ordering key used by ``list_by_priority``

        Raises:
            TypeError: If ``name`` is not a string
            ValueError: If ``name`` is already registered
        """
        # A non-string key would be stored without complaint and only fail
        # later, when list_names() sorts it together with string names.
        if not isinstance(name, str):
            raise TypeError(
                f"Component name must be a string, got {type(name).__name__}"
            )
        if name in self._entries:
            raise ValueError(f"Component '{name}' is already registered")
        self._entries[name] = (component_cls, category, priority)

    def get(self, name: str) -> Type:
        """Get component class by name.

        Raises:
            KeyError: If ``name`` is not registered
        """
        try:
            return self._entries[name][0]
        except KeyError:
            raise KeyError(f"Component '{name}' not found in registry") from None

    def get_with_metadata(self, name: str) -> Tuple[Type, Optional[str], int]:
        """Get the ``(component_cls, category, priority)`` tuple for ``name``.

        Raises:
            KeyError: If ``name`` is not registered
        """
        try:
            return self._entries[name]
        except KeyError:
            raise KeyError(f"Component '{name}' not found in registry") from None

    def contains(self, name: str) -> bool:
        """Check if a name is registered."""
        return name in self._entries

    def list_names(self) -> List[str]:
        """Return the registered component names in sorted order."""
        return sorted(self._entries)

    def list_by_category(self, category: str) -> List[str]:
        """Return the sorted names of components belonging to ``category``."""
        return sorted(
            name for name, (_, cat, _) in self._entries.items() if cat == category
        )

    def list_by_priority(self, reverse: bool = False) -> List[str]:
        """Return names sorted by priority (default ascending).

        The sort is stable, so entries with equal priority keep their
        registration order in both directions.
        """
        return sorted(
            self._entries,
            key=lambda name: self._entries[name][2],
            reverse=reverse,
        )

    def entries(self) -> Dict[str, Tuple[Type, Optional[str], int]]:
        """Return a shallow copy of the entries dictionary.

        Mutating the returned dict does not affect the registry.
        """
        return self._entries.copy()
|
||||
|
||||
|
||||
class BaseFactory(ABC, Generic[T]):
    """Generic factory class for component registration and creation.

    This base class provides a decorator-based registration pattern
    for creating extensible component factories. Each subclass is given
    its own ``Registry`` when it is defined, so registrations never leak
    between factories. ``BaseFactory`` itself has no registry and is meant
    to be subclassed.

    Example usage:
        class MyFactory(BaseFactory[MyBaseClass]):
            pass

        @MyFactory.register("custom")
        class CustomComponent(MyBaseClass):
            ...

        component = MyFactory.create("custom", *args, **kwargs)
    """

    _registry: Registry

    def __init_subclass__(cls, **kwargs) -> None:
        """Attach a fresh, empty registry to every newly defined subclass.

        This runs for subclasses of concrete factories as well, so a derived
        factory starts empty and does not inherit its parent's registrations.
        """
        super().__init_subclass__(**kwargs)
        cls._registry = Registry()

    @classmethod
    def register(
        cls, name: str, category: Optional[str] = None, priority: int = 0
    ) -> Callable[[Type[T]], Type[T]]:
        """Decorator to register a component class with optional category and priority.

        Args:
            name: Registration name for the component
            category: Optional category for grouping components
            priority: Priority for ordering (default 0)

        Returns:
            Decorator function that registers the component class and
            returns it unchanged

        Raises:
            TypeError: If ``_validate_component`` rejects the decorated class
            ValueError: If ``name`` is already registered on this factory
        """

        def decorator(component_cls: Type[T]) -> Type[T]:
            # Validate first so a rejected class never reaches the registry.
            cls._validate_component(component_cls)
            cls._registry.register(
                name, component_cls, category=category, priority=priority
            )
            return component_cls

        return decorator

    @classmethod
    def create(cls, name: str, *args, **kwargs) -> T:
        """Create a component instance by name.

        Args:
            name: Registered name of the component
            *args: Positional arguments passed to component constructor
            **kwargs: Keyword arguments passed to component constructor

        Returns:
            Component instance

        Raises:
            ValueError: If the component name is not registered
        """
        try:
            component_cls = cls._registry.get(name)
        except KeyError:
            # Translate the registry's KeyError into the factory's ValueError.
            # list_names() is already sorted, so it is not sorted again.
            raise ValueError(
                f"Unknown component: '{name}'. "
                f"Supported types: {cls._registry.list_names()}"
            ) from None
        return component_cls(*args, **kwargs)

    @classmethod
    def _validate_component(cls, component_cls: Type[T]) -> None:
        """Validate that the component class is valid for this factory.

        The base implementation accepts everything. Override this method in
        subclasses to add custom validation.

        Args:
            component_cls: Component class to validate

        Raises:
            TypeError: If the component class is invalid (raised by overrides)
        """
        pass

    @classmethod
    def list_registered(cls) -> List[str]:
        """List all registered component names.

        Returns:
            Registered component names in sorted order
        """
        return cls._registry.list_names()

    @classmethod
    def is_registered(cls, name: str) -> bool:
        """Check if a component name is registered.

        Args:
            name: Component name to check

        Returns:
            True if registered, False otherwise
        """
        return cls._registry.contains(name)

    @classmethod
    def list_by_category(cls, category: str) -> List[str]:
        """List registered component names in a category, sorted by name."""
        return cls._registry.list_by_category(category)

    @classmethod
    def list_by_priority(cls, reverse: bool = False) -> List[str]:
        """List registered component names sorted by priority.

        Args:
            reverse: If True, order from highest to lowest priority
        """
        return cls._registry.list_by_priority(reverse)
|
||||
|
|
@ -6,7 +6,7 @@ from jinja2 import Template
|
|||
from torch import Tensor
|
||||
|
||||
from astrai.config.param_config import ModelParameter
|
||||
from astrai.core.factory import BaseFactory
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
|
||||
|
||||
HistoryType = List[Tuple[str, str]]
|
||||
|
|
@ -299,8 +299,6 @@ class GeneratorFactory(BaseFactory[GeneratorCore]):
|
|||
result = generator.generate(request)
|
||||
"""
|
||||
|
||||
_registry: Dict[str, type] = {}
|
||||
|
||||
@staticmethod
|
||||
def create(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
|
||||
"""Create a generator based on request characteristics.
|
||||
|
|
|
|||
|
|
@ -1,12 +1,8 @@
|
|||
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
|
||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||
from astrai.trainer.train_callback import (
|
||||
CheckpointCallback,
|
||||
GradientClippingCallback,
|
||||
MetricLoggerCallback,
|
||||
ProgressBarCallback,
|
||||
SchedulerCallback,
|
||||
TrainCallback,
|
||||
CallbackFactory,
|
||||
)
|
||||
from astrai.trainer.trainer import Trainer
|
||||
|
||||
|
|
@ -19,11 +15,7 @@ __all__ = [
|
|||
# Scheduler factory
|
||||
"SchedulerFactory",
|
||||
"BaseScheduler",
|
||||
# Callbacks
|
||||
# Callback factory
|
||||
"TrainCallback",
|
||||
"GradientClippingCallback",
|
||||
"SchedulerCallback",
|
||||
"CheckpointCallback",
|
||||
"ProgressBarCallback",
|
||||
"MetricLoggerCallback",
|
||||
"CallbackFactory",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Any, Dict, List, Type
|
|||
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from astrai.core.factory import BaseFactory
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
class BaseScheduler(LRScheduler, ABC):
|
||||
|
|
@ -41,8 +41,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
|||
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
|
||||
"""
|
||||
|
||||
_registry: Dict[str, Type[BaseScheduler]] = {}
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
|
||||
"""Validate that the scheduler class inherits from BaseScheduler."""
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import torch.nn.functional as F
|
|||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from astrai.core.factory import BaseFactory
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
def unwrap_model(model: nn.Module) -> nn.Module:
|
||||
|
|
@ -122,8 +122,6 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
|
|||
strategy = StrategyFactory.create("custom", model, device)
|
||||
"""
|
||||
|
||||
_registry: Dict[str, type] = {}
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, strategy_cls: type) -> None:
|
||||
"""Validate that the strategy class inherits from BaseStrategy."""
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import json
|
|||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Protocol
|
||||
from typing import Callable, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
|
|
@ -21,8 +21,10 @@ from astrai.trainer.metric_util import (
|
|||
ctx_get_lr,
|
||||
)
|
||||
from astrai.trainer.train_context import TrainContext
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TrainCallback(Protocol):
|
||||
"""
|
||||
Callback interface for trainer.
|
||||
|
|
@ -56,6 +58,25 @@ class TrainCallback(Protocol):
|
|||
"""Called when an error occurs during training."""
|
||||
|
||||
|
||||
class CallbackFactory(BaseFactory[TrainCallback]):
    """Factory for registering and creating training callbacks.

    Example:
        @CallbackFactory.register("my_callback")
        class MyCallback(TrainCallback):
            ...

        callback = CallbackFactory.create("my_callback", **kwargs)
    """

    @classmethod
    def _validate_component(cls, callback_cls: type) -> None:
        """Validate that ``callback_cls`` is a class satisfying TrainCallback.

        TrainCallback is a ``runtime_checkable`` Protocol, so ``issubclass``
        performs a structural check: a class providing the protocol's
        methods passes even without inheriting from it explicitly.

        Args:
            callback_cls: Class being registered

        Raises:
            TypeError: If ``callback_cls`` is not a class or does not satisfy
                TrainCallback
        """
        # Test for a class first: issubclass() on a non-class raises its own,
        # less helpful TypeError, and such objects may lack __name__.
        is_callback = isinstance(callback_cls, type) and issubclass(
            callback_cls, TrainCallback
        )
        if not is_callback:
            label = getattr(callback_cls, "__name__", repr(callback_cls))
            raise TypeError(f"{label} must inherit from TrainCallback")
|
||||
|
||||
|
||||
@CallbackFactory.register("gradient_clipping")
|
||||
class GradientClippingCallback(TrainCallback):
|
||||
"""
|
||||
Gradient clipping callback for trainer.
|
||||
|
|
@ -69,6 +90,7 @@ class GradientClippingCallback(TrainCallback):
|
|||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||
|
||||
|
||||
@CallbackFactory.register("scheduler")
|
||||
class SchedulerCallback(TrainCallback):
|
||||
"""
|
||||
Scheduler callback for trainer.
|
||||
|
|
@ -87,6 +109,7 @@ class SchedulerCallback(TrainCallback):
|
|||
context.scheduler.step()
|
||||
|
||||
|
||||
@CallbackFactory.register("checkpoint")
|
||||
class CheckpointCallback(TrainCallback):
|
||||
"""
|
||||
Checkpoint callback for trainer.
|
||||
|
|
@ -135,6 +158,7 @@ class CheckpointCallback(TrainCallback):
|
|||
self._save_checkpoint(context)
|
||||
|
||||
|
||||
@CallbackFactory.register("progress_bar")
|
||||
class ProgressBarCallback(TrainCallback):
|
||||
"""
|
||||
Progress bar callback for trainer.
|
||||
|
|
@ -169,6 +193,7 @@ class ProgressBarCallback(TrainCallback):
|
|||
self.progress_bar.close()
|
||||
|
||||
|
||||
@CallbackFactory.register("metric_logger")
|
||||
class MetricLoggerCallback(TrainCallback):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -5,12 +5,8 @@ from astrai.config import TrainConfig
|
|||
from astrai.data.serialization import Checkpoint
|
||||
from astrai.parallel.setup import spawn_parallel_fn
|
||||
from astrai.trainer.train_callback import (
|
||||
CheckpointCallback,
|
||||
GradientClippingCallback,
|
||||
MetricLoggerCallback,
|
||||
ProgressBarCallback,
|
||||
SchedulerCallback,
|
||||
TrainCallback,
|
||||
CallbackFactory,
|
||||
)
|
||||
from astrai.trainer.train_context import TrainContext, TrainContextBuilder
|
||||
|
||||
|
|
@ -28,13 +24,13 @@ class Trainer:
|
|||
)
|
||||
|
||||
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||
train_config = self.train_config
|
||||
cfg = self.train_config
|
||||
return [
|
||||
ProgressBarCallback(train_config.n_epoch),
|
||||
CheckpointCallback(train_config.ckpt_dir, train_config.ckpt_interval),
|
||||
MetricLoggerCallback(train_config.ckpt_dir, train_config.ckpt_interval),
|
||||
GradientClippingCallback(train_config.max_grad_norm),
|
||||
SchedulerCallback(),
|
||||
CallbackFactory.create("progress_bar", cfg.n_epoch),
|
||||
CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval),
|
||||
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
|
||||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||
CallbackFactory.create("scheduler"),
|
||||
]
|
||||
|
||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||
|
|
|
|||
Loading…
Reference in New Issue