refactor: 优化工厂模式结构

This commit is contained in:
ViperEkura 2026-04-04 11:33:58 +08:00
parent 7861af12e4
commit e97536758f
9 changed files with 164 additions and 136 deletions

View File

@ -5,7 +5,8 @@ from astrai.config import (
ModelConfig,
TrainConfig,
)
from astrai.data import BpeTokenizer, DatasetLoader
from astrai.core.factory import BaseFactory
from astrai.data import BpeTokenizer, DatasetFactory
from astrai.inference.generator import (
BatchGenerator,
EmbeddingEncoder,
@ -21,7 +22,7 @@ __all__ = [
"Transformer",
"ModelConfig",
"TrainConfig",
"DatasetLoader",
"DatasetFactory",
"BpeTokenizer",
"GenerationRequest",
"LoopGenerator",
@ -32,4 +33,5 @@ __all__ = [
"Trainer",
"StrategyFactory",
"SchedulerFactory",
"BaseFactory",
]

105
astrai/core/factory.py Normal file
View File

@ -0,0 +1,105 @@
"""Base factory class for extensible component registration."""
from abc import ABC
from typing import Callable, Dict, Generic, Type, TypeVar

T = TypeVar("T")


class BaseFactory(ABC, Generic[T]):
    """Generic factory class for component registration and creation.

    This base class provides a decorator-based registration pattern
    for creating extensible component factories. Every subclass receives
    its own isolated registry (see ``__init_subclass__``), so registering
    a component in one factory never leaks into another.

    Example usage:
        class MyFactory(BaseFactory[MyBaseClass]):
            pass

        @MyFactory.register("custom")
        class CustomComponent(MyBaseClass):
            ...

        component = MyFactory.create("custom", *args, **kwargs)
    """

    # Maps registration name -> component class. Declared here for typing;
    # each subclass gets its own dict via __init_subclass__.
    _registry: Dict[str, Type[T]] = {}

    def __init_subclass__(cls, **kwargs) -> None:
        """Give each subclass an isolated registry.

        Without this hook, a subclass that does not declare its own
        ``_registry`` would mutate the single dict defined on BaseFactory,
        silently sharing registrations across ALL factories. Subclasses
        that explicitly declare ``_registry`` in their body are left alone.
        """
        super().__init_subclass__(**kwargs)
        if "_registry" not in cls.__dict__:
            cls._registry = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type[T]], Type[T]]:
        """Decorator to register a component class.

        Args:
            name: Registration name for the component

        Returns:
            Decorator function that registers the component class

        Raises:
            TypeError: If the decorated class doesn't inherit from the base type
        """

        def decorator(component_cls: Type[T]) -> Type[T]:
            # Subclass hook may raise TypeError before we store anything.
            cls._validate_component(component_cls)
            cls._registry[name] = component_cls
            return component_cls

        return decorator

    @classmethod
    def create(cls, name: str, *args, **kwargs) -> T:
        """Create a component instance by name.

        Args:
            name: Registered name of the component
            *args: Positional arguments passed to component constructor
            **kwargs: Keyword arguments passed to component constructor

        Returns:
            Component instance

        Raises:
            ValueError: If the component name is not registered
        """
        if name not in cls._registry:
            raise ValueError(
                f"Unknown component: '{name}'. "
                f"Supported types: {sorted(cls._registry.keys())}"
            )
        component_cls = cls._registry[name]
        return component_cls(*args, **kwargs)

    @classmethod
    def _validate_component(cls, component_cls: Type[T]) -> None:
        """Validate that the component class is valid for this factory.

        Override this method in subclasses to add custom validation.

        Args:
            component_cls: Component class to validate

        Raises:
            TypeError: If the component class is invalid
        """
        pass

    @classmethod
    def list_registered(cls) -> list:
        """List all registered component names.

        Returns:
            Sorted list of registered component names
        """
        return sorted(cls._registry.keys())

    @classmethod
    def is_registered(cls, name: str) -> bool:
        """Check if a component name is registered.

        Args:
            name: Component name to check

        Returns:
            True if registered, False otherwise
        """
        return name in cls._registry

View File

@ -1,7 +1,7 @@
from astrai.data.dataset import (
BaseDataset,
DatasetFactory,
DatasetLoader,
DatasetFactory,
DPODataset,
GRPODataset,
MultiSegmentFetcher,
@ -21,8 +21,8 @@ __all__ = [
"GRPODataset",
# Fetchers
"MultiSegmentFetcher",
# Factory (DatasetLoader is alias for backward compatibility)
"DatasetLoader",
# Factory (DatasetFactory is alias for backward compatibility)
"DatasetFactory",
"DatasetFactory",
# Tokenizer and sampler
"BpeTokenizer",

View File

@ -8,6 +8,7 @@ import torch
from torch import Tensor
from torch.utils.data import Dataset
from astrai.core.factory import BaseFactory
from astrai.data.serialization import load_h5
@ -165,7 +166,7 @@ class BaseDataset(Dataset, ABC):
return (self.total_samples - 1 - self.window_size) // self.stride + 1
class DatasetFactory:
class DatasetFactory(BaseFactory["BaseDataset"]):
"""Factory class for creating dataset instances.
Supports decorator-based registration for extensible dataset types.
@ -180,30 +181,16 @@ class DatasetFactory:
dataset = DatasetFactory.create("custom", window_size, stride)
"""
SUPPORTED_TYPES = frozenset({"seq", "sft", "dpo", "grpo"})
DATASET_MAP: Dict[str, type] = {}
_registry: Dict[str, type] = {}
@classmethod
def register(cls, name: str):
"""Decorator to register a new dataset class.
Args:
name: Registration name for the dataset type
Returns:
Decorator function that registers the dataset class
"""
def decorator(dataset_cls: type) -> type:
def _validate_component(cls, dataset_cls: type) -> None:
"""Validate that the dataset class inherits from BaseDataset."""
if not issubclass(dataset_cls, BaseDataset):
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
cls.DATASET_MAP[name] = dataset_cls
return dataset_cls
return decorator
@classmethod
def create(cls, train_type: str, window_size: int, stride: int) -> BaseDataset:
def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
"""Create a dataset instance.
Args:
@ -214,19 +201,7 @@ class DatasetFactory:
Returns:
Dataset instance
"""
if train_type not in cls.SUPPORTED_TYPES:
raise ValueError(
f"Unknown dataset type: '{train_type}'. "
f"Supported types: {sorted(cls.SUPPORTED_TYPES)}"
)
if train_type not in cls.DATASET_MAP:
raise NotImplementedError(
f"Dataset type '{train_type}' is supported but not yet implemented."
)
dataset_cls = cls.DATASET_MAP[train_type]
return dataset_cls(window_size, stride)
return super().create(train_type, window_size, stride)
@classmethod
def load(
@ -235,7 +210,7 @@ class DatasetFactory:
load_path: str,
window_size: int,
stride: Optional[int] = None,
) -> BaseDataset:
) -> "BaseDataset":
"""Create and load a dataset in one step.
Args:
@ -258,7 +233,7 @@ class DatasetFactory:
@classmethod
def available_types(cls) -> list:
"""Return list of registered dataset type names."""
return list(cls.DATASET_MAP.keys())
return cls.list_registered()
# ============== Dataset Classes ==============
@ -362,7 +337,3 @@ class GRPODataset(BaseDataset):
"masks": masks,
"rewards": rewards,
}
# Backward compatibility alias
DatasetLoader = DatasetFactory

View File

@ -1,10 +1,11 @@
from dataclasses import dataclass
from typing import Generator, List, Optional, Tuple, Union
from typing import Dict, Generator, List, Optional, Tuple, Union
import torch
from torch import Tensor
from astrai.config.param_config import ModelParameter
from astrai.core.factory import BaseFactory
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
HistoryType = List[Tuple[str, str]]
@ -254,7 +255,7 @@ class EmbeddingEncoder(EmbeddingEncoderCore):
return super().encode(sentence)
class GeneratorFactory:
class GeneratorFactory(BaseFactory[GeneratorCore]):
"""Factory class for creating generator instances.
Provides smart generator selection based on request characteristics:
@ -263,14 +264,14 @@ class GeneratorFactory:
- Single: Use LoopGenerator for single query non-streaming
Example usage:
generator = GeneratorFactory.create_generator(parameter, request)
generator = GeneratorFactory.create(parameter, request)
result = generator.generate(request)
"""
_registry: Dict[str, type] = {}
@staticmethod
def create_generator(
parameter: ModelParameter, request: GenerationRequest
) -> GeneratorCore:
def create(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
"""Create a generator based on request characteristics.
Args:

View File

@ -6,6 +6,8 @@ from typing import Any, Dict, List, Type
from torch.optim.lr_scheduler import LRScheduler
from astrai.core.factory import BaseFactory
class BaseScheduler(LRScheduler, ABC):
"""Base scheduler class for all other schedulers."""
@ -25,7 +27,7 @@ class BaseScheduler(LRScheduler, ABC):
super().load_state_dict(state_dict)
class SchedulerFactory:
class SchedulerFactory(BaseFactory["BaseScheduler"]):
"""Factory class for creating learning rate schedulers.
Supports decorator-based registration for extensible scheduler types.
@ -36,34 +38,21 @@ class SchedulerFactory:
class CustomScheduler(BaseScheduler):
...
scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
"""
SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {}
_registry: Dict[str, Type[BaseScheduler]] = {}
@classmethod
def register(cls, name: str):
"""Decorator to register a new scheduler class.
Args:
name: Registration name for the scheduler
Returns:
Decorator function that registers the scheduler class
"""
def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]:
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
"""Validate that the scheduler class inherits from BaseScheduler."""
if not issubclass(scheduler_cls, BaseScheduler):
raise TypeError(
f"{scheduler_cls.__name__} must inherit from BaseScheduler"
)
cls.SCHEDULER_MAP[name] = scheduler_cls
return scheduler_cls
return decorator
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
@classmethod
def create(cls, optimizer, schedule_type: str = "none", **kwargs) -> BaseScheduler:
def create(
cls, optimizer, schedule_type: str = "none", **kwargs
) -> "BaseScheduler":
"""Create a scheduler instance by type name.
Args:
@ -73,23 +62,13 @@ class SchedulerFactory:
Returns:
Scheduler instance
Raises:
ValueError: If schedule_type is not supported
"""
if schedule_type not in cls.SCHEDULER_MAP:
raise ValueError(
f"Unknown schedule type: '{schedule_type}'. "
f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}"
)
scheduler_cls = cls.SCHEDULER_MAP[schedule_type]
return scheduler_cls(optimizer, **kwargs)
return super().create(schedule_type, optimizer, **kwargs)
@classmethod
def available_types(cls) -> list:
"""Return list of registered scheduler type names."""
return list(cls.SCHEDULER_MAP.keys())
return cls.list_registered()
# ----------- Scheduler implementations -----------

View File

@ -10,6 +10,8 @@ import torch.nn.functional as F
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.core.factory import BaseFactory
def unwrap_model(model: nn.Module) -> nn.Module:
"""Unwrap DDP wrapper if present to get the original model."""
@ -106,7 +108,7 @@ class BaseStrategy(ABC):
return self.compute_loss(batch)
class StrategyFactory:
class StrategyFactory(BaseFactory["BaseStrategy"]):
"""Factory class for creating training strategy instances.
Supports decorator-based registration for extensible strategy types.
@ -117,68 +119,36 @@ class StrategyFactory:
class CustomStrategy(BaseStrategy):
...
strategy = StrategyFactory.create(model, "custom", device)
strategy = StrategyFactory.create("custom", model, device)
"""
SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"})
STRATEGY_MAP: Dict[str, type] = {}
_registry: Dict[str, type] = {}
@classmethod
def register(cls, name: str):
"""Decorator to register a new strategy class.
Args:
name: Registration name for the strategy
Returns:
Decorator function that registers the strategy class
"""
def decorator(strategy_cls: type) -> type:
def _validate_component(cls, strategy_cls: type) -> None:
"""Validate that the strategy class inherits from BaseStrategy."""
if not issubclass(strategy_cls, BaseStrategy):
raise TypeError(
f"{strategy_cls.__name__} must inherit from BaseStrategy"
)
cls.STRATEGY_MAP[name] = strategy_cls
return strategy_cls
return decorator
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
@classmethod
def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy:
def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
"""Create a strategy instance based on training type.
Args:
model: Model instance for the strategy
train_type: Type of training ("seq", "sft", "dpo", "grpo")
model: Model instance for the strategy
device: Device to run the strategy on
**kwargs: Additional arguments passed to strategy constructor
Returns:
Strategy instance
Raises:
ValueError: If train_type is not supported
NotImplementedError: If train_type is in supported list but not implemented
"""
if train_type not in cls.SUPPORTED_STRATEGIES:
raise ValueError(
f"Unknown training strategy: '{train_type}'. "
f"Supported strategies: {sorted(cls.SUPPORTED_STRATEGIES)}"
)
if train_type not in cls.STRATEGY_MAP:
raise NotImplementedError(
f"Strategy '{train_type}' is supported but not yet implemented."
)
strategy_cls = cls.STRATEGY_MAP[train_type]
return strategy_cls(model, device, **kwargs)
return super().create(train_type, model, device, **kwargs)
@classmethod
def available_strategies(cls) -> list:
"""Return list of registered strategy names."""
return list(cls.STRATEGY_MAP.keys())
return cls.list_registered()
# ============== Strategy Classes ==============

View File

@ -8,7 +8,7 @@ import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.config import ModelParameter, TrainConfig
from astrai.data import DatasetLoader
from astrai.data import DatasetFactory
from astrai.parallel import get_rank
from astrai.trainer import SchedulerFactory, Trainer
@ -205,7 +205,7 @@ def train(
strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}
dataset = DatasetLoader.load(
dataset = DatasetFactory.load(
train_type=train_type,
load_path=data_root_path,
window_size=window_size,

View File

@ -23,7 +23,7 @@ def test_dataset_loader_random_paths(base_test_env):
save_h5(test_dir, f"data_{i}", dummy_data)
# Test loading with multiple paths
loaded_dataset = DatasetLoader.load(
loaded_dataset = DatasetFactory.load(
train_type="seq",
load_path=test_dir,
window_size=64,
@ -57,7 +57,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
save_h5(test_dir, "dpo_data", dummy_data)
# Load DPO dataset
dpo_dataset = DatasetLoader.load(
dpo_dataset = DatasetFactory.load(
train_type="dpo",
load_path=test_dir,
window_size=64,
@ -93,7 +93,7 @@ def test_sft_dataset_with_random_data(base_test_env):
save_h5(test_dir, "sft_data", dummy_data)
# Load SFT dataset
sft_dataset = DatasetLoader.load(
sft_dataset = DatasetFactory.load(
train_type="sft",
load_path=test_dir,
window_size=64,
@ -127,7 +127,7 @@ def test_dataset_with_custom_stride(base_test_env):
# Test with custom stride
custom_stride = 32
dataset = DatasetLoader.load(
dataset = DatasetFactory.load(
train_type="seq", load_path=test_dir, window_size=64, stride=custom_stride
)
@ -136,7 +136,7 @@ def test_dataset_with_custom_stride(base_test_env):
# With stride 32 and window 64 on 200 length data, we should get more samples
# than with default stride (which equals window size)
default_stride_dataset = DatasetLoader.load(
default_stride_dataset = DatasetFactory.load(
train_type="seq",
load_path=test_dir,
window_size=64,