refactor: 优化工厂模式结构
This commit is contained in:
parent
7861af12e4
commit
e97536758f
|
|
@ -5,7 +5,8 @@ from astrai.config import (
|
|||
ModelConfig,
|
||||
TrainConfig,
|
||||
)
|
||||
from astrai.data import BpeTokenizer, DatasetLoader
|
||||
from astrai.core.factory import BaseFactory
|
||||
from astrai.data import BpeTokenizer, DatasetFactory
|
||||
from astrai.inference.generator import (
|
||||
BatchGenerator,
|
||||
EmbeddingEncoder,
|
||||
|
|
@ -21,7 +22,7 @@ __all__ = [
|
|||
"Transformer",
|
||||
"ModelConfig",
|
||||
"TrainConfig",
|
||||
"DatasetLoader",
|
||||
"DatasetFactory",
|
||||
"BpeTokenizer",
|
||||
"GenerationRequest",
|
||||
"LoopGenerator",
|
||||
|
|
@ -32,4 +33,5 @@ __all__ = [
|
|||
"Trainer",
|
||||
"StrategyFactory",
|
||||
"SchedulerFactory",
|
||||
"BaseFactory",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,105 @@
|
|||
"""Base factory class for extensible component registration."""
|
||||
|
||||
from abc import ABC
|
||||
from typing import Callable, Dict, Generic, Type, TypeVar
|
||||
|
||||
T = TypeVar("T")


class BaseFactory(ABC, Generic[T]):
    """Generic factory class for component registration and creation.

    This base class provides a decorator-based registration pattern
    for creating extensible component factories.

    Example usage:
        class MyFactory(BaseFactory[MyBaseClass]):
            pass

        @MyFactory.register("custom")
        class CustomComponent(MyBaseClass):
            ...

        component = MyFactory.create("custom", *args, **kwargs)
    """

    # Mapping of registration name -> component class for THIS factory.
    # Each subclass receives its own dict via __init_subclass__ below.
    _registry: Dict[str, Type[T]] = {}

    def __init_subclass__(cls, **kwargs) -> None:
        """Give every factory subclass an independent registry.

        Without this hook, the single dict defined on ``BaseFactory`` would
        be shared by every subclass, so components registered on one
        factory would leak into all the others. A subclass that declares
        its own ``_registry`` explicitly is left untouched for
        backward compatibility.
        """
        super().__init_subclass__(**kwargs)
        if "_registry" not in cls.__dict__:
            cls._registry = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type[T]], Type[T]]:
        """Decorator to register a component class.

        Args:
            name: Registration name for the component

        Returns:
            Decorator function that registers the component class

        Raises:
            TypeError: If the decorated class doesn't inherit from the base type
        """

        def decorator(component_cls: Type[T]) -> Type[T]:
            # Subclasses may veto invalid components in _validate_component.
            cls._validate_component(component_cls)
            cls._registry[name] = component_cls
            return component_cls

        return decorator

    @classmethod
    def create(cls, name: str, *args, **kwargs) -> T:
        """Create a component instance by name.

        Args:
            name: Registered name of the component
            *args: Positional arguments passed to component constructor
            **kwargs: Keyword arguments passed to component constructor

        Returns:
            Component instance

        Raises:
            ValueError: If the component name is not registered
        """
        if name not in cls._registry:
            raise ValueError(
                f"Unknown component: '{name}'. "
                f"Supported types: {sorted(cls._registry.keys())}"
            )
        component_cls = cls._registry[name]
        return component_cls(*args, **kwargs)

    @classmethod
    def _validate_component(cls, component_cls: Type[T]) -> None:
        """Validate that the component class is valid for this factory.

        Override this method in subclasses to add custom validation.
        The base implementation accepts anything.

        Args:
            component_cls: Component class to validate

        Raises:
            TypeError: If the component class is invalid
        """
        pass

    @classmethod
    def list_registered(cls) -> list:
        """List all registered component names.

        Returns:
            Sorted list of registered component names
        """
        return sorted(cls._registry.keys())

    @classmethod
    def is_registered(cls, name: str) -> bool:
        """Check if a component name is registered.

        Args:
            name: Component name to check

        Returns:
            True if registered, False otherwise
        """
        return name in cls._registry
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
from astrai.data.dataset import (
|
||||
BaseDataset,
|
||||
DatasetFactory,
|
||||
DatasetLoader,
|
||||
DatasetFactory,
|
||||
DPODataset,
|
||||
GRPODataset,
|
||||
MultiSegmentFetcher,
|
||||
|
|
@ -21,8 +21,8 @@ __all__ = [
|
|||
"GRPODataset",
|
||||
# Fetchers
|
||||
"MultiSegmentFetcher",
|
||||
# Factory (DatasetLoader is alias for backward compatibility)
|
||||
"DatasetLoader",
|
||||
# Factory (DatasetFactory is alias for backward compatibility)
|
||||
"DatasetFactory",
|
||||
"DatasetFactory",
|
||||
# Tokenizer and sampler
|
||||
"BpeTokenizer",
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import torch
|
|||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from astrai.core.factory import BaseFactory
|
||||
from astrai.data.serialization import load_h5
|
||||
|
||||
|
||||
|
|
@ -165,7 +166,7 @@ class BaseDataset(Dataset, ABC):
|
|||
return (self.total_samples - 1 - self.window_size) // self.stride + 1
|
||||
|
||||
|
||||
class DatasetFactory:
|
||||
class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||
"""Factory class for creating dataset instances.
|
||||
|
||||
Supports decorator-based registration for extensible dataset types.
|
||||
|
|
@ -180,30 +181,16 @@ class DatasetFactory:
|
|||
dataset = DatasetFactory.create("custom", window_size, stride)
|
||||
"""
|
||||
|
||||
SUPPORTED_TYPES = frozenset({"seq", "sft", "dpo", "grpo"})
|
||||
DATASET_MAP: Dict[str, type] = {}
|
||||
_registry: Dict[str, type] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str):
|
||||
"""Decorator to register a new dataset class.
|
||||
|
||||
Args:
|
||||
name: Registration name for the dataset type
|
||||
|
||||
Returns:
|
||||
Decorator function that registers the dataset class
|
||||
"""
|
||||
|
||||
def decorator(dataset_cls: type) -> type:
|
||||
def _validate_component(cls, dataset_cls: type) -> None:
|
||||
"""Validate that the dataset class inherits from BaseDataset."""
|
||||
if not issubclass(dataset_cls, BaseDataset):
|
||||
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
||||
cls.DATASET_MAP[name] = dataset_cls
|
||||
return dataset_cls
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def create(cls, train_type: str, window_size: int, stride: int) -> BaseDataset:
|
||||
def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
|
||||
"""Create a dataset instance.
|
||||
|
||||
Args:
|
||||
|
|
@ -214,19 +201,7 @@ class DatasetFactory:
|
|||
Returns:
|
||||
Dataset instance
|
||||
"""
|
||||
if train_type not in cls.SUPPORTED_TYPES:
|
||||
raise ValueError(
|
||||
f"Unknown dataset type: '{train_type}'. "
|
||||
f"Supported types: {sorted(cls.SUPPORTED_TYPES)}"
|
||||
)
|
||||
|
||||
if train_type not in cls.DATASET_MAP:
|
||||
raise NotImplementedError(
|
||||
f"Dataset type '{train_type}' is supported but not yet implemented."
|
||||
)
|
||||
|
||||
dataset_cls = cls.DATASET_MAP[train_type]
|
||||
return dataset_cls(window_size, stride)
|
||||
return super().create(train_type, window_size, stride)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
|
|
@ -235,7 +210,7 @@ class DatasetFactory:
|
|||
load_path: str,
|
||||
window_size: int,
|
||||
stride: Optional[int] = None,
|
||||
) -> BaseDataset:
|
||||
) -> "BaseDataset":
|
||||
"""Create and load a dataset in one step.
|
||||
|
||||
Args:
|
||||
|
|
@ -258,7 +233,7 @@ class DatasetFactory:
|
|||
@classmethod
|
||||
def available_types(cls) -> list:
|
||||
"""Return list of registered dataset type names."""
|
||||
return list(cls.DATASET_MAP.keys())
|
||||
return cls.list_registered()
|
||||
|
||||
|
||||
# ============== Dataset Classes ==============
|
||||
|
|
@ -362,7 +337,3 @@ class GRPODataset(BaseDataset):
|
|||
"masks": masks,
|
||||
"rewards": rewards,
|
||||
}
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
DatasetLoader = DatasetFactory
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
from typing import Dict, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.config.param_config import ModelParameter
|
||||
from astrai.core.factory import BaseFactory
|
||||
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
|
||||
|
||||
HistoryType = List[Tuple[str, str]]
|
||||
|
|
@ -254,7 +255,7 @@ class EmbeddingEncoder(EmbeddingEncoderCore):
|
|||
return super().encode(sentence)
|
||||
|
||||
|
||||
class GeneratorFactory:
|
||||
class GeneratorFactory(BaseFactory[GeneratorCore]):
|
||||
"""Factory class for creating generator instances.
|
||||
|
||||
Provides smart generator selection based on request characteristics:
|
||||
|
|
@ -263,14 +264,14 @@ class GeneratorFactory:
|
|||
- Single: Use LoopGenerator for single query non-streaming
|
||||
|
||||
Example usage:
|
||||
generator = GeneratorFactory.create_generator(parameter, request)
|
||||
generator = GeneratorFactory.create(parameter, request)
|
||||
result = generator.generate(request)
|
||||
"""
|
||||
|
||||
_registry: Dict[str, type] = {}
|
||||
|
||||
@staticmethod
|
||||
def create_generator(
|
||||
parameter: ModelParameter, request: GenerationRequest
|
||||
) -> GeneratorCore:
|
||||
def create(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
|
||||
"""Create a generator based on request characteristics.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ from typing import Any, Dict, List, Type
|
|||
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from astrai.core.factory import BaseFactory
|
||||
|
||||
|
||||
class BaseScheduler(LRScheduler, ABC):
|
||||
"""Base scheduler class for all other schedulers."""
|
||||
|
|
@ -25,7 +27,7 @@ class BaseScheduler(LRScheduler, ABC):
|
|||
super().load_state_dict(state_dict)
|
||||
|
||||
|
||||
class SchedulerFactory:
|
||||
class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
||||
"""Factory class for creating learning rate schedulers.
|
||||
|
||||
Supports decorator-based registration for extensible scheduler types.
|
||||
|
|
@ -36,34 +38,21 @@ class SchedulerFactory:
|
|||
class CustomScheduler(BaseScheduler):
|
||||
...
|
||||
|
||||
scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)
|
||||
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
|
||||
"""
|
||||
|
||||
SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {}
|
||||
_registry: Dict[str, Type[BaseScheduler]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str):
|
||||
"""Decorator to register a new scheduler class.
|
||||
|
||||
Args:
|
||||
name: Registration name for the scheduler
|
||||
|
||||
Returns:
|
||||
Decorator function that registers the scheduler class
|
||||
"""
|
||||
|
||||
def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]:
|
||||
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
|
||||
"""Validate that the scheduler class inherits from BaseScheduler."""
|
||||
if not issubclass(scheduler_cls, BaseScheduler):
|
||||
raise TypeError(
|
||||
f"{scheduler_cls.__name__} must inherit from BaseScheduler"
|
||||
)
|
||||
cls.SCHEDULER_MAP[name] = scheduler_cls
|
||||
return scheduler_cls
|
||||
|
||||
return decorator
|
||||
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
||||
|
||||
@classmethod
|
||||
def create(cls, optimizer, schedule_type: str = "none", **kwargs) -> BaseScheduler:
|
||||
def create(
|
||||
cls, optimizer, schedule_type: str = "none", **kwargs
|
||||
) -> "BaseScheduler":
|
||||
"""Create a scheduler instance by type name.
|
||||
|
||||
Args:
|
||||
|
|
@ -73,23 +62,13 @@ class SchedulerFactory:
|
|||
|
||||
Returns:
|
||||
Scheduler instance
|
||||
|
||||
Raises:
|
||||
ValueError: If schedule_type is not supported
|
||||
"""
|
||||
if schedule_type not in cls.SCHEDULER_MAP:
|
||||
raise ValueError(
|
||||
f"Unknown schedule type: '{schedule_type}'. "
|
||||
f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}"
|
||||
)
|
||||
|
||||
scheduler_cls = cls.SCHEDULER_MAP[schedule_type]
|
||||
return scheduler_cls(optimizer, **kwargs)
|
||||
return super().create(schedule_type, optimizer, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def available_types(cls) -> list:
|
||||
"""Return list of registered scheduler type names."""
|
||||
return list(cls.SCHEDULER_MAP.keys())
|
||||
return cls.list_registered()
|
||||
|
||||
|
||||
# ----------- Scheduler implementations -----------
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ import torch.nn.functional as F
|
|||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from astrai.core.factory import BaseFactory
|
||||
|
||||
|
||||
def unwrap_model(model: nn.Module) -> nn.Module:
|
||||
"""Unwrap DDP wrapper if present to get the original model."""
|
||||
|
|
@ -106,7 +108,7 @@ class BaseStrategy(ABC):
|
|||
return self.compute_loss(batch)
|
||||
|
||||
|
||||
class StrategyFactory:
|
||||
class StrategyFactory(BaseFactory["BaseStrategy"]):
|
||||
"""Factory class for creating training strategy instances.
|
||||
|
||||
Supports decorator-based registration for extensible strategy types.
|
||||
|
|
@ -117,68 +119,36 @@ class StrategyFactory:
|
|||
class CustomStrategy(BaseStrategy):
|
||||
...
|
||||
|
||||
strategy = StrategyFactory.create(model, "custom", device)
|
||||
strategy = StrategyFactory.create("custom", model, device)
|
||||
"""
|
||||
|
||||
SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"})
|
||||
STRATEGY_MAP: Dict[str, type] = {}
|
||||
_registry: Dict[str, type] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str):
|
||||
"""Decorator to register a new strategy class.
|
||||
|
||||
Args:
|
||||
name: Registration name for the strategy
|
||||
|
||||
Returns:
|
||||
Decorator function that registers the strategy class
|
||||
"""
|
||||
|
||||
def decorator(strategy_cls: type) -> type:
|
||||
def _validate_component(cls, strategy_cls: type) -> None:
|
||||
"""Validate that the strategy class inherits from BaseStrategy."""
|
||||
if not issubclass(strategy_cls, BaseStrategy):
|
||||
raise TypeError(
|
||||
f"{strategy_cls.__name__} must inherit from BaseStrategy"
|
||||
)
|
||||
cls.STRATEGY_MAP[name] = strategy_cls
|
||||
return strategy_cls
|
||||
|
||||
return decorator
|
||||
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
||||
|
||||
@classmethod
|
||||
def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy:
|
||||
def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
|
||||
"""Create a strategy instance based on training type.
|
||||
|
||||
Args:
|
||||
model: Model instance for the strategy
|
||||
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
||||
model: Model instance for the strategy
|
||||
device: Device to run the strategy on
|
||||
**kwargs: Additional arguments passed to strategy constructor
|
||||
|
||||
Returns:
|
||||
Strategy instance
|
||||
|
||||
Raises:
|
||||
ValueError: If train_type is not supported
|
||||
NotImplementedError: If train_type is in supported list but not implemented
|
||||
"""
|
||||
if train_type not in cls.SUPPORTED_STRATEGIES:
|
||||
raise ValueError(
|
||||
f"Unknown training strategy: '{train_type}'. "
|
||||
f"Supported strategies: {sorted(cls.SUPPORTED_STRATEGIES)}"
|
||||
)
|
||||
|
||||
if train_type not in cls.STRATEGY_MAP:
|
||||
raise NotImplementedError(
|
||||
f"Strategy '{train_type}' is supported but not yet implemented."
|
||||
)
|
||||
|
||||
strategy_cls = cls.STRATEGY_MAP[train_type]
|
||||
return strategy_cls(model, device, **kwargs)
|
||||
return super().create(train_type, model, device, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def available_strategies(cls) -> list:
|
||||
"""Return list of registered strategy names."""
|
||||
return list(cls.STRATEGY_MAP.keys())
|
||||
return cls.list_registered()
|
||||
|
||||
|
||||
# ============== Strategy Classes ==============
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import torch.optim as optim
|
|||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from astrai.config import ModelParameter, TrainConfig
|
||||
from astrai.data import DatasetLoader
|
||||
from astrai.data import DatasetFactory
|
||||
from astrai.parallel import get_rank
|
||||
from astrai.trainer import SchedulerFactory, Trainer
|
||||
|
||||
|
|
@ -205,7 +205,7 @@ def train(
|
|||
|
||||
strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}
|
||||
|
||||
dataset = DatasetLoader.load(
|
||||
dataset = DatasetFactory.load(
|
||||
train_type=train_type,
|
||||
load_path=data_root_path,
|
||||
window_size=window_size,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ def test_dataset_loader_random_paths(base_test_env):
|
|||
save_h5(test_dir, f"data_{i}", dummy_data)
|
||||
|
||||
# Test loading with multiple paths
|
||||
loaded_dataset = DatasetLoader.load(
|
||||
loaded_dataset = DatasetFactory.load(
|
||||
train_type="seq",
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
|
|
@ -57,7 +57,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
|
|||
save_h5(test_dir, "dpo_data", dummy_data)
|
||||
|
||||
# Load DPO dataset
|
||||
dpo_dataset = DatasetLoader.load(
|
||||
dpo_dataset = DatasetFactory.load(
|
||||
train_type="dpo",
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
|
|
@ -93,7 +93,7 @@ def test_sft_dataset_with_random_data(base_test_env):
|
|||
save_h5(test_dir, "sft_data", dummy_data)
|
||||
|
||||
# Load SFT dataset
|
||||
sft_dataset = DatasetLoader.load(
|
||||
sft_dataset = DatasetFactory.load(
|
||||
train_type="sft",
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
|
|
@ -127,7 +127,7 @@ def test_dataset_with_custom_stride(base_test_env):
|
|||
|
||||
# Test with custom stride
|
||||
custom_stride = 32
|
||||
dataset = DatasetLoader.load(
|
||||
dataset = DatasetFactory.load(
|
||||
train_type="seq", load_path=test_dir, window_size=64, stride=custom_stride
|
||||
)
|
||||
|
||||
|
|
@ -136,7 +136,7 @@ def test_dataset_with_custom_stride(base_test_env):
|
|||
|
||||
# With stride 32 and window 64 on 200 length data, we should get more samples
|
||||
# than with default stride (which equals window size)
|
||||
default_stride_dataset = DatasetLoader.load(
|
||||
default_stride_dataset = DatasetFactory.load(
|
||||
train_type="seq",
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
|
|
|
|||
Loading…
Reference in New Issue