refactor: 优化工厂模式结构

This commit is contained in:
ViperEkura 2026-04-04 11:33:58 +08:00
parent 7861af12e4
commit e97536758f
9 changed files with 164 additions and 136 deletions

View File

@ -5,7 +5,8 @@ from astrai.config import (
ModelConfig, ModelConfig,
TrainConfig, TrainConfig,
) )
from astrai.data import BpeTokenizer, DatasetLoader from astrai.core.factory import BaseFactory
from astrai.data import BpeTokenizer, DatasetFactory
from astrai.inference.generator import ( from astrai.inference.generator import (
BatchGenerator, BatchGenerator,
EmbeddingEncoder, EmbeddingEncoder,
@ -21,7 +22,7 @@ __all__ = [
"Transformer", "Transformer",
"ModelConfig", "ModelConfig",
"TrainConfig", "TrainConfig",
"DatasetLoader", "DatasetFactory",
"BpeTokenizer", "BpeTokenizer",
"GenerationRequest", "GenerationRequest",
"LoopGenerator", "LoopGenerator",
@ -32,4 +33,5 @@ __all__ = [
"Trainer", "Trainer",
"StrategyFactory", "StrategyFactory",
"SchedulerFactory", "SchedulerFactory",
"BaseFactory",
] ]

105
astrai/core/factory.py Normal file
View File

@ -0,0 +1,105 @@
"""Base factory class for extensible component registration."""
from abc import ABC
from typing import Callable, Dict, Generic, Type, TypeVar
T = TypeVar("T")


class BaseFactory(ABC, Generic[T]):
    """Generic factory class for component registration and creation.

    This base class provides a decorator-based registration pattern
    for creating extensible component factories. Each subclass gets its
    own isolated registry (see ``__init_subclass__``).

    Example usage:
        class MyFactory(BaseFactory[MyBaseClass]):
            pass

        @MyFactory.register("custom")
        class CustomComponent(MyBaseClass):
            ...

        component = MyFactory.create("custom", *args, **kwargs)
    """

    # Registry for BaseFactory itself. Subclasses receive their own
    # fresh dict via __init_subclass__ below.
    _registry: Dict[str, Type[T]] = {}

    def __init_subclass__(cls, **kwargs) -> None:
        """Give every factory subclass its own registry.

        Without this hook, a subclass that does not explicitly define
        ``_registry`` would share the base class's mutable dict, so a
        registration in one factory would silently leak into all others.
        A subclass that defines its own ``_registry`` keeps it.
        """
        super().__init_subclass__(**kwargs)
        if "_registry" not in cls.__dict__:
            cls._registry = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type[T]], Type[T]]:
        """Decorator to register a component class.

        Args:
            name: Registration name for the component

        Returns:
            Decorator function that registers the component class

        Raises:
            TypeError: If the decorated class fails ``_validate_component``
        """
        def decorator(component_cls: Type[T]) -> Type[T]:
            cls._validate_component(component_cls)
            cls._registry[name] = component_cls
            return component_cls
        return decorator

    @classmethod
    def create(cls, name: str, *args, **kwargs) -> T:
        """Create a component instance by name.

        Args:
            name: Registered name of the component
            *args: Positional arguments passed to component constructor
            **kwargs: Keyword arguments passed to component constructor

        Returns:
            Component instance

        Raises:
            ValueError: If the component name is not registered
        """
        if name not in cls._registry:
            raise ValueError(
                f"Unknown component: '{name}'. "
                f"Supported types: {sorted(cls._registry.keys())}"
            )
        component_cls = cls._registry[name]
        return component_cls(*args, **kwargs)

    @classmethod
    def _validate_component(cls, component_cls: Type[T]) -> None:
        """Validate that the component class is valid for this factory.

        Override this method in subclasses to add custom validation
        (e.g. an ``issubclass`` check against the factory's base type).
        The default accepts any class.

        Args:
            component_cls: Component class to validate

        Raises:
            TypeError: If the component class is invalid
        """

    @classmethod
    def list_registered(cls) -> list:
        """List all registered component names.

        Returns:
            Sorted list of registered component names
        """
        return sorted(cls._registry.keys())

    @classmethod
    def is_registered(cls, name: str) -> bool:
        """Check if a component name is registered.

        Args:
            name: Component name to check

        Returns:
            True if registered, False otherwise
        """
        return name in cls._registry

View File

@ -1,7 +1,7 @@
from astrai.data.dataset import ( from astrai.data.dataset import (
BaseDataset, BaseDataset,
DatasetFactory, DatasetFactory,
DatasetLoader, DatasetFactory,
DPODataset, DPODataset,
GRPODataset, GRPODataset,
MultiSegmentFetcher, MultiSegmentFetcher,
@ -21,8 +21,8 @@ __all__ = [
"GRPODataset", "GRPODataset",
# Fetchers # Fetchers
"MultiSegmentFetcher", "MultiSegmentFetcher",
# Factory (DatasetLoader is alias for backward compatibility) # Factory (DatasetFactory is alias for backward compatibility)
"DatasetLoader", "DatasetFactory",
"DatasetFactory", "DatasetFactory",
# Tokenizer and sampler # Tokenizer and sampler
"BpeTokenizer", "BpeTokenizer",

View File

@ -8,6 +8,7 @@ import torch
from torch import Tensor from torch import Tensor
from torch.utils.data import Dataset from torch.utils.data import Dataset
from astrai.core.factory import BaseFactory
from astrai.data.serialization import load_h5 from astrai.data.serialization import load_h5
@ -165,7 +166,7 @@ class BaseDataset(Dataset, ABC):
return (self.total_samples - 1 - self.window_size) // self.stride + 1 return (self.total_samples - 1 - self.window_size) // self.stride + 1
class DatasetFactory: class DatasetFactory(BaseFactory["BaseDataset"]):
"""Factory class for creating dataset instances. """Factory class for creating dataset instances.
Supports decorator-based registration for extensible dataset types. Supports decorator-based registration for extensible dataset types.
@ -180,30 +181,16 @@ class DatasetFactory:
dataset = DatasetFactory.create("custom", window_size, stride) dataset = DatasetFactory.create("custom", window_size, stride)
""" """
SUPPORTED_TYPES = frozenset({"seq", "sft", "dpo", "grpo"}) _registry: Dict[str, type] = {}
DATASET_MAP: Dict[str, type] = {}
@classmethod @classmethod
def register(cls, name: str): def _validate_component(cls, dataset_cls: type) -> None:
"""Decorator to register a new dataset class. """Validate that the dataset class inherits from BaseDataset."""
if not issubclass(dataset_cls, BaseDataset):
Args: raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
name: Registration name for the dataset type
Returns:
Decorator function that registers the dataset class
"""
def decorator(dataset_cls: type) -> type:
if not issubclass(dataset_cls, BaseDataset):
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
cls.DATASET_MAP[name] = dataset_cls
return dataset_cls
return decorator
@classmethod @classmethod
def create(cls, train_type: str, window_size: int, stride: int) -> BaseDataset: def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
"""Create a dataset instance. """Create a dataset instance.
Args: Args:
@ -214,19 +201,7 @@ class DatasetFactory:
Returns: Returns:
Dataset instance Dataset instance
""" """
if train_type not in cls.SUPPORTED_TYPES: return super().create(train_type, window_size, stride)
raise ValueError(
f"Unknown dataset type: '{train_type}'. "
f"Supported types: {sorted(cls.SUPPORTED_TYPES)}"
)
if train_type not in cls.DATASET_MAP:
raise NotImplementedError(
f"Dataset type '{train_type}' is supported but not yet implemented."
)
dataset_cls = cls.DATASET_MAP[train_type]
return dataset_cls(window_size, stride)
@classmethod @classmethod
def load( def load(
@ -235,7 +210,7 @@ class DatasetFactory:
load_path: str, load_path: str,
window_size: int, window_size: int,
stride: Optional[int] = None, stride: Optional[int] = None,
) -> BaseDataset: ) -> "BaseDataset":
"""Create and load a dataset in one step. """Create and load a dataset in one step.
Args: Args:
@ -258,7 +233,7 @@ class DatasetFactory:
@classmethod @classmethod
def available_types(cls) -> list: def available_types(cls) -> list:
"""Return list of registered dataset type names.""" """Return list of registered dataset type names."""
return list(cls.DATASET_MAP.keys()) return cls.list_registered()
# ============== Dataset Classes ============== # ============== Dataset Classes ==============
@ -362,7 +337,3 @@ class GRPODataset(BaseDataset):
"masks": masks, "masks": masks,
"rewards": rewards, "rewards": rewards,
} }
# Backward compatibility alias
DatasetLoader = DatasetFactory

View File

@ -1,10 +1,11 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generator, List, Optional, Tuple, Union from typing import Dict, Generator, List, Optional, Tuple, Union
import torch import torch
from torch import Tensor from torch import Tensor
from astrai.config.param_config import ModelParameter from astrai.config.param_config import ModelParameter
from astrai.core.factory import BaseFactory
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
HistoryType = List[Tuple[str, str]] HistoryType = List[Tuple[str, str]]
@ -254,7 +255,7 @@ class EmbeddingEncoder(EmbeddingEncoderCore):
return super().encode(sentence) return super().encode(sentence)
class GeneratorFactory: class GeneratorFactory(BaseFactory[GeneratorCore]):
"""Factory class for creating generator instances. """Factory class for creating generator instances.
Provides smart generator selection based on request characteristics: Provides smart generator selection based on request characteristics:
@ -263,14 +264,14 @@ class GeneratorFactory:
- Single: Use LoopGenerator for single query non-streaming - Single: Use LoopGenerator for single query non-streaming
Example usage: Example usage:
generator = GeneratorFactory.create_generator(parameter, request) generator = GeneratorFactory.create(parameter, request)
result = generator.generate(request) result = generator.generate(request)
""" """
_registry: Dict[str, type] = {}
@staticmethod @staticmethod
def create_generator( def create(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
parameter: ModelParameter, request: GenerationRequest
) -> GeneratorCore:
"""Create a generator based on request characteristics. """Create a generator based on request characteristics.
Args: Args:

View File

@ -6,6 +6,8 @@ from typing import Any, Dict, List, Type
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
from astrai.core.factory import BaseFactory
class BaseScheduler(LRScheduler, ABC): class BaseScheduler(LRScheduler, ABC):
"""Base scheduler class for all other schedulers.""" """Base scheduler class for all other schedulers."""
@ -25,7 +27,7 @@ class BaseScheduler(LRScheduler, ABC):
super().load_state_dict(state_dict) super().load_state_dict(state_dict)
class SchedulerFactory: class SchedulerFactory(BaseFactory["BaseScheduler"]):
"""Factory class for creating learning rate schedulers. """Factory class for creating learning rate schedulers.
Supports decorator-based registration for extensible scheduler types. Supports decorator-based registration for extensible scheduler types.
@ -36,34 +38,21 @@ class SchedulerFactory:
class CustomScheduler(BaseScheduler): class CustomScheduler(BaseScheduler):
... ...
scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs) scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
""" """
SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {} _registry: Dict[str, Type[BaseScheduler]] = {}
@classmethod @classmethod
def register(cls, name: str): def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
"""Decorator to register a new scheduler class. """Validate that the scheduler class inherits from BaseScheduler."""
if not issubclass(scheduler_cls, BaseScheduler):
Args: raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
name: Registration name for the scheduler
Returns:
Decorator function that registers the scheduler class
"""
def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]:
if not issubclass(scheduler_cls, BaseScheduler):
raise TypeError(
f"{scheduler_cls.__name__} must inherit from BaseScheduler"
)
cls.SCHEDULER_MAP[name] = scheduler_cls
return scheduler_cls
return decorator
@classmethod @classmethod
def create(cls, optimizer, schedule_type: str = "none", **kwargs) -> BaseScheduler: def create(
cls, optimizer, schedule_type: str = "none", **kwargs
) -> "BaseScheduler":
"""Create a scheduler instance by type name. """Create a scheduler instance by type name.
Args: Args:
@ -73,23 +62,13 @@ class SchedulerFactory:
Returns: Returns:
Scheduler instance Scheduler instance
Raises:
ValueError: If schedule_type is not supported
""" """
if schedule_type not in cls.SCHEDULER_MAP: return super().create(schedule_type, optimizer, **kwargs)
raise ValueError(
f"Unknown schedule type: '{schedule_type}'. "
f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}"
)
scheduler_cls = cls.SCHEDULER_MAP[schedule_type]
return scheduler_cls(optimizer, **kwargs)
@classmethod @classmethod
def available_types(cls) -> list: def available_types(cls) -> list:
"""Return list of registered scheduler type names.""" """Return list of registered scheduler type names."""
return list(cls.SCHEDULER_MAP.keys()) return cls.list_registered()
# ----------- Scheduler implementations ----------- # ----------- Scheduler implementations -----------

View File

@ -10,6 +10,8 @@ import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.core.factory import BaseFactory
def unwrap_model(model: nn.Module) -> nn.Module: def unwrap_model(model: nn.Module) -> nn.Module:
"""Unwrap DDP wrapper if present to get the original model.""" """Unwrap DDP wrapper if present to get the original model."""
@ -106,7 +108,7 @@ class BaseStrategy(ABC):
return self.compute_loss(batch) return self.compute_loss(batch)
class StrategyFactory: class StrategyFactory(BaseFactory["BaseStrategy"]):
"""Factory class for creating training strategy instances. """Factory class for creating training strategy instances.
Supports decorator-based registration for extensible strategy types. Supports decorator-based registration for extensible strategy types.
@ -117,68 +119,36 @@ class StrategyFactory:
class CustomStrategy(BaseStrategy): class CustomStrategy(BaseStrategy):
... ...
strategy = StrategyFactory.create(model, "custom", device) strategy = StrategyFactory.create("custom", model, device)
""" """
SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"}) _registry: Dict[str, type] = {}
STRATEGY_MAP: Dict[str, type] = {}
@classmethod @classmethod
def register(cls, name: str): def _validate_component(cls, strategy_cls: type) -> None:
"""Decorator to register a new strategy class. """Validate that the strategy class inherits from BaseStrategy."""
if not issubclass(strategy_cls, BaseStrategy):
Args: raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
name: Registration name for the strategy
Returns:
Decorator function that registers the strategy class
"""
def decorator(strategy_cls: type) -> type:
if not issubclass(strategy_cls, BaseStrategy):
raise TypeError(
f"{strategy_cls.__name__} must inherit from BaseStrategy"
)
cls.STRATEGY_MAP[name] = strategy_cls
return strategy_cls
return decorator
@classmethod @classmethod
def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy: def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
"""Create a strategy instance based on training type. """Create a strategy instance based on training type.
Args: Args:
model: Model instance for the strategy
train_type: Type of training ("seq", "sft", "dpo", "grpo") train_type: Type of training ("seq", "sft", "dpo", "grpo")
model: Model instance for the strategy
device: Device to run the strategy on device: Device to run the strategy on
**kwargs: Additional arguments passed to strategy constructor **kwargs: Additional arguments passed to strategy constructor
Returns: Returns:
Strategy instance Strategy instance
Raises:
ValueError: If train_type is not supported
NotImplementedError: If train_type is in supported list but not implemented
""" """
if train_type not in cls.SUPPORTED_STRATEGIES: return super().create(train_type, model, device, **kwargs)
raise ValueError(
f"Unknown training strategy: '{train_type}'. "
f"Supported strategies: {sorted(cls.SUPPORTED_STRATEGIES)}"
)
if train_type not in cls.STRATEGY_MAP:
raise NotImplementedError(
f"Strategy '{train_type}' is supported but not yet implemented."
)
strategy_cls = cls.STRATEGY_MAP[train_type]
return strategy_cls(model, device, **kwargs)
@classmethod @classmethod
def available_strategies(cls) -> list: def available_strategies(cls) -> list:
"""Return list of registered strategy names.""" """Return list of registered strategy names."""
return list(cls.STRATEGY_MAP.keys()) return cls.list_registered()
# ============== Strategy Classes ============== # ============== Strategy Classes ==============

View File

@ -8,7 +8,7 @@ import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.config import ModelParameter, TrainConfig from astrai.config import ModelParameter, TrainConfig
from astrai.data import DatasetLoader from astrai.data import DatasetFactory
from astrai.parallel import get_rank from astrai.parallel import get_rank
from astrai.trainer import SchedulerFactory, Trainer from astrai.trainer import SchedulerFactory, Trainer
@ -205,7 +205,7 @@ def train(
strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing} strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}
dataset = DatasetLoader.load( dataset = DatasetFactory.load(
train_type=train_type, train_type=train_type,
load_path=data_root_path, load_path=data_root_path,
window_size=window_size, window_size=window_size,

View File

@ -23,7 +23,7 @@ def test_dataset_loader_random_paths(base_test_env):
save_h5(test_dir, f"data_{i}", dummy_data) save_h5(test_dir, f"data_{i}", dummy_data)
# Test loading with multiple paths # Test loading with multiple paths
loaded_dataset = DatasetLoader.load( loaded_dataset = DatasetFactory.load(
train_type="seq", train_type="seq",
load_path=test_dir, load_path=test_dir,
window_size=64, window_size=64,
@ -57,7 +57,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
save_h5(test_dir, "dpo_data", dummy_data) save_h5(test_dir, "dpo_data", dummy_data)
# Load DPO dataset # Load DPO dataset
dpo_dataset = DatasetLoader.load( dpo_dataset = DatasetFactory.load(
train_type="dpo", train_type="dpo",
load_path=test_dir, load_path=test_dir,
window_size=64, window_size=64,
@ -93,7 +93,7 @@ def test_sft_dataset_with_random_data(base_test_env):
save_h5(test_dir, "sft_data", dummy_data) save_h5(test_dir, "sft_data", dummy_data)
# Load SFT dataset # Load SFT dataset
sft_dataset = DatasetLoader.load( sft_dataset = DatasetFactory.load(
train_type="sft", train_type="sft",
load_path=test_dir, load_path=test_dir,
window_size=64, window_size=64,
@ -127,7 +127,7 @@ def test_dataset_with_custom_stride(base_test_env):
# Test with custom stride # Test with custom stride
custom_stride = 32 custom_stride = 32
dataset = DatasetLoader.load( dataset = DatasetFactory.load(
train_type="seq", load_path=test_dir, window_size=64, stride=custom_stride train_type="seq", load_path=test_dir, window_size=64, stride=custom_stride
) )
@ -136,7 +136,7 @@ def test_dataset_with_custom_stride(base_test_env):
# With stride 32 and window 64 on 200 length data, we should get more samples # With stride 32 and window 64 on 200 length data, we should get more samples
# than with default stride (which equals window size) # than with default stride (which equals window size)
default_stride_dataset = DatasetLoader.load( default_stride_dataset = DatasetFactory.load(
train_type="seq", train_type="seq",
load_path=test_dir, load_path=test_dir,
window_size=64, window_size=64,