From e97536758f4bd537beae87f9f532d7eca9d3fe65 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 4 Apr 2026 11:33:58 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E5=B7=A5?= =?UTF-8?q?=E5=8E=82=E6=A8=A1=E5=BC=8F=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/__init__.py | 6 +- astrai/core/factory.py | 105 ++++++++++++++++++++++++++++++++++ astrai/data/__init__.py | 6 +- astrai/data/dataset.py | 51 ++++------------- astrai/inference/generator.py | 13 +++-- astrai/trainer/schedule.py | 49 +++++----------- astrai/trainer/strategy.py | 56 +++++------------- scripts/tools/train.py | 4 +- tests/data/test_dataset.py | 10 ++-- 9 files changed, 164 insertions(+), 136 deletions(-) create mode 100644 astrai/core/factory.py diff --git a/astrai/__init__.py b/astrai/__init__.py index 167defd..4eeda4f 100644 --- a/astrai/__init__.py +++ b/astrai/__init__.py @@ -5,7 +5,8 @@ from astrai.config import ( ModelConfig, TrainConfig, ) -from astrai.data import BpeTokenizer, DatasetLoader +from astrai.core.factory import BaseFactory +from astrai.data import BpeTokenizer, DatasetFactory from astrai.inference.generator import ( BatchGenerator, EmbeddingEncoder, @@ -21,7 +22,7 @@ __all__ = [ "Transformer", "ModelConfig", "TrainConfig", - "DatasetLoader", + "DatasetFactory", "BpeTokenizer", "GenerationRequest", "LoopGenerator", @@ -32,4 +33,5 @@ __all__ = [ "Trainer", "StrategyFactory", "SchedulerFactory", + "BaseFactory", ] diff --git a/astrai/core/factory.py b/astrai/core/factory.py new file mode 100644 index 0000000..2c61ac4 --- /dev/null +++ b/astrai/core/factory.py @@ -0,0 +1,105 @@ +"""Base factory class for extensible component registration.""" + +from abc import ABC +from typing import Callable, Dict, Generic, Type, TypeVar + +T = TypeVar("T") + + +class BaseFactory(ABC, Generic[T]): + """Generic factory class for component registration and creation. 
+ + This base class provides a decorator-based registration pattern + for creating extensible component factories. + + Example usage: + class MyFactory(BaseFactory[MyBaseClass]): + pass + + @MyFactory.register("custom") + class CustomComponent(MyBaseClass): + ... + + component = MyFactory.create("custom", *args, **kwargs) + """ + + _registry: Dict[str, Type[T]] = {} + + @classmethod + def register(cls, name: str) -> Callable[[Type[T]], Type[T]]: + """Decorator to register a component class. + + Args: + name: Registration name for the component + + Returns: + Decorator function that registers the component class + + Raises: + TypeError: If the decorated class doesn't inherit from the base type + """ + + def decorator(component_cls: Type[T]) -> Type[T]: + cls._validate_component(component_cls) + cls._registry[name] = component_cls + return component_cls + + return decorator + + @classmethod + def create(cls, name: str, *args, **kwargs) -> T: + """Create a component instance by name. + + Args: + name: Registered name of the component + *args: Positional arguments passed to component constructor + **kwargs: Keyword arguments passed to component constructor + + Returns: + Component instance + + Raises: + ValueError: If the component name is not registered + """ + if name not in cls._registry: + raise ValueError( + f"Unknown component: '{name}'. " + f"Supported types: {sorted(cls._registry.keys())}" + ) + component_cls = cls._registry[name] + return component_cls(*args, **kwargs) + + @classmethod + def _validate_component(cls, component_cls: Type[T]) -> None: + """Validate that the component class is valid for this factory. + + Override this method in subclasses to add custom validation. + + Args: + component_cls: Component class to validate + + Raises: + TypeError: If the component class is invalid + """ + pass + + @classmethod + def list_registered(cls) -> list: + """List all registered component names. 
+ + Returns: + List of registered component names + """ + return sorted(cls._registry.keys()) + + @classmethod + def is_registered(cls, name: str) -> bool: + """Check if a component name is registered. + + Args: + name: Component name to check + + Returns: + True if registered, False otherwise + """ + return name in cls._registry diff --git a/astrai/data/__init__.py b/astrai/data/__init__.py index 02c33a5..7cc418d 100644 --- a/astrai/data/__init__.py +++ b/astrai/data/__init__.py @@ -1,7 +1,7 @@ from astrai.data.dataset import ( BaseDataset, DatasetFactory, - DatasetLoader, + DatasetFactory, DPODataset, GRPODataset, MultiSegmentFetcher, @@ -21,8 +21,8 @@ __all__ = [ "GRPODataset", # Fetchers "MultiSegmentFetcher", - # Factory (DatasetLoader is alias for backward compatibility) - "DatasetLoader", + # Factory + "DatasetFactory", "DatasetFactory", # Tokenizer and sampler "BpeTokenizer", diff --git a/astrai/data/dataset.py b/astrai/data/dataset.py index 258c19f..395ec2a 100644 --- a/astrai/data/dataset.py +++ b/astrai/data/dataset.py @@ -8,6 +8,7 @@ import torch from torch import Tensor from torch.utils.data import Dataset +from astrai.core.factory import BaseFactory from astrai.data.serialization import load_h5 @@ -165,7 +166,7 @@ class BaseDataset(Dataset, ABC): return (self.total_samples - 1 - self.window_size) // self.stride + 1 -class DatasetFactory: +class DatasetFactory(BaseFactory["BaseDataset"]): """Factory class for creating dataset instances. Supports decorator-based registration for extensible dataset types. @@ -180,30 +181,16 @@ class DatasetFactory: dataset = DatasetFactory.create("custom", window_size, stride) """ - SUPPORTED_TYPES = frozenset({"seq", "sft", "dpo", "grpo"}) - DATASET_MAP: Dict[str, type] = {} + _registry: Dict[str, type] = {} @classmethod - def register(cls, name: str): - """Decorator to register a new dataset class. 
- - Args: - name: Registration name for the dataset type - - Returns: - Decorator function that registers the dataset class - """ - - def decorator(dataset_cls: type) -> type: - if not issubclass(dataset_cls, BaseDataset): - raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset") - cls.DATASET_MAP[name] = dataset_cls - return dataset_cls - - return decorator + def _validate_component(cls, dataset_cls: type) -> None: + """Validate that the dataset class inherits from BaseDataset.""" + if not issubclass(dataset_cls, BaseDataset): + raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset") @classmethod - def create(cls, train_type: str, window_size: int, stride: int) -> BaseDataset: + def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset": """Create a dataset instance. Args: @@ -214,19 +201,7 @@ class DatasetFactory: Returns: Dataset instance """ - if train_type not in cls.SUPPORTED_TYPES: - raise ValueError( - f"Unknown dataset type: '{train_type}'. " - f"Supported types: {sorted(cls.SUPPORTED_TYPES)}" - ) - - if train_type not in cls.DATASET_MAP: - raise NotImplementedError( - f"Dataset type '{train_type}' is supported but not yet implemented." - ) - - dataset_cls = cls.DATASET_MAP[train_type] - return dataset_cls(window_size, stride) + return super().create(train_type, window_size, stride) @classmethod def load( @@ -235,7 +210,7 @@ class DatasetFactory: load_path: str, window_size: int, stride: Optional[int] = None, - ) -> BaseDataset: + ) -> "BaseDataset": """Create and load a dataset in one step. 
Args: @@ -258,7 +233,7 @@ class DatasetFactory: @classmethod def available_types(cls) -> list: """Return list of registered dataset type names.""" - return list(cls.DATASET_MAP.keys()) + return cls.list_registered() # ============== Dataset Classes ============== @@ -362,7 +337,3 @@ class GRPODataset(BaseDataset): "masks": masks, "rewards": rewards, } - - -# Backward compatibility alias -DatasetLoader = DatasetFactory diff --git a/astrai/inference/generator.py b/astrai/inference/generator.py index a3600e3..a99f7e3 100644 --- a/astrai/inference/generator.py +++ b/astrai/inference/generator.py @@ -1,10 +1,11 @@ from dataclasses import dataclass -from typing import Generator, List, Optional, Tuple, Union +from typing import Dict, Generator, List, Optional, Tuple, Union import torch from torch import Tensor from astrai.config.param_config import ModelParameter +from astrai.core.factory import BaseFactory from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager HistoryType = List[Tuple[str, str]] @@ -254,7 +255,7 @@ class EmbeddingEncoder(EmbeddingEncoderCore): return super().encode(sentence) -class GeneratorFactory: +class GeneratorFactory(BaseFactory[GeneratorCore]): """Factory class for creating generator instances. Provides smart generator selection based on request characteristics: @@ -263,14 +264,14 @@ class GeneratorFactory: - Single: Use LoopGenerator for single query non-streaming Example usage: - generator = GeneratorFactory.create_generator(parameter, request) + generator = GeneratorFactory.create(parameter, request) result = generator.generate(request) """ + _registry: Dict[str, type] = {} + @staticmethod - def create_generator( - parameter: ModelParameter, request: GenerationRequest - ) -> GeneratorCore: + def create(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore: """Create a generator based on request characteristics. 
Args: diff --git a/astrai/trainer/schedule.py b/astrai/trainer/schedule.py index 339a702..0ca166a 100644 --- a/astrai/trainer/schedule.py +++ b/astrai/trainer/schedule.py @@ -6,6 +6,8 @@ from typing import Any, Dict, List, Type from torch.optim.lr_scheduler import LRScheduler +from astrai.core.factory import BaseFactory + class BaseScheduler(LRScheduler, ABC): """Base scheduler class for all other schedulers.""" @@ -25,7 +27,7 @@ class BaseScheduler(LRScheduler, ABC): super().load_state_dict(state_dict) -class SchedulerFactory: +class SchedulerFactory(BaseFactory["BaseScheduler"]): """Factory class for creating learning rate schedulers. Supports decorator-based registration for extensible scheduler types. @@ -36,34 +38,21 @@ class SchedulerFactory: class CustomScheduler(BaseScheduler): ... - scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs) + scheduler = SchedulerFactory.create(optimizer, schedule_type="custom", **kwargs) """ - SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {} + _registry: Dict[str, Type[BaseScheduler]] = {} @classmethod - def register(cls, name: str): - """Decorator to register a new scheduler class. 
- - Args: - name: Registration name for the scheduler - - Returns: - Decorator function that registers the scheduler class - """ - - def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]: - if not issubclass(scheduler_cls, BaseScheduler): - raise TypeError( - f"{scheduler_cls.__name__} must inherit from BaseScheduler" - ) - cls.SCHEDULER_MAP[name] = scheduler_cls - return scheduler_cls - - return decorator + def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None: + """Validate that the scheduler class inherits from BaseScheduler.""" + if not issubclass(scheduler_cls, BaseScheduler): + raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler") @classmethod - def create(cls, optimizer, schedule_type: str = "none", **kwargs) -> BaseScheduler: + def create( + cls, optimizer, schedule_type: str = "none", **kwargs + ) -> "BaseScheduler": """Create a scheduler instance by type name. Args: @@ -73,23 +62,13 @@ class SchedulerFactory: Returns: Scheduler instance - - Raises: - ValueError: If schedule_type is not supported """ - if schedule_type not in cls.SCHEDULER_MAP: - raise ValueError( - f"Unknown schedule type: '{schedule_type}'. 
" - f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}" - ) - - scheduler_cls = cls.SCHEDULER_MAP[schedule_type] - return scheduler_cls(optimizer, **kwargs) + return super().create(schedule_type, optimizer, **kwargs) @classmethod def available_types(cls) -> list: """Return list of registered scheduler type names.""" - return list(cls.SCHEDULER_MAP.keys()) + return cls.list_registered() # ----------- Scheduler implementations ----------- diff --git a/astrai/trainer/strategy.py b/astrai/trainer/strategy.py index 445cd86..d6a27e8 100644 --- a/astrai/trainer/strategy.py +++ b/astrai/trainer/strategy.py @@ -10,6 +10,8 @@ import torch.nn.functional as F from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP +from astrai.core.factory import BaseFactory + def unwrap_model(model: nn.Module) -> nn.Module: """Unwrap DDP wrapper if present to get the original model.""" @@ -106,7 +108,7 @@ class BaseStrategy(ABC): return self.compute_loss(batch) -class StrategyFactory: +class StrategyFactory(BaseFactory["BaseStrategy"]): """Factory class for creating training strategy instances. Supports decorator-based registration for extensible strategy types. @@ -117,68 +119,36 @@ class StrategyFactory: class CustomStrategy(BaseStrategy): ... - strategy = StrategyFactory.create(model, "custom", device) + strategy = StrategyFactory.create("custom", model, device) """ - SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"}) - STRATEGY_MAP: Dict[str, type] = {} + _registry: Dict[str, type] = {} @classmethod - def register(cls, name: str): - """Decorator to register a new strategy class. 
- - Args: - name: Registration name for the strategy - - Returns: - Decorator function that registers the strategy class - """ - - def decorator(strategy_cls: type) -> type: - if not issubclass(strategy_cls, BaseStrategy): - raise TypeError( - f"{strategy_cls.__name__} must inherit from BaseStrategy" - ) - cls.STRATEGY_MAP[name] = strategy_cls - return strategy_cls - - return decorator + def _validate_component(cls, strategy_cls: type) -> None: + """Validate that the strategy class inherits from BaseStrategy.""" + if not issubclass(strategy_cls, BaseStrategy): + raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy") @classmethod - def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy: + def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy": """Create a strategy instance based on training type. Args: - model: Model instance for the strategy train_type: Type of training ("seq", "sft", "dpo", "grpo") + model: Model instance for the strategy device: Device to run the strategy on **kwargs: Additional arguments passed to strategy constructor Returns: Strategy instance - - Raises: - ValueError: If train_type is not supported - NotImplementedError: If train_type is in supported list but not implemented """ - if train_type not in cls.SUPPORTED_STRATEGIES: - raise ValueError( - f"Unknown training strategy: '{train_type}'. " - f"Supported strategies: {sorted(cls.SUPPORTED_STRATEGIES)}" - ) - - if train_type not in cls.STRATEGY_MAP: - raise NotImplementedError( - f"Strategy '{train_type}' is supported but not yet implemented." 
- ) - - strategy_cls = cls.STRATEGY_MAP[train_type] - return strategy_cls(model, device, **kwargs) + return super().create(train_type, model, device, **kwargs) @classmethod def available_strategies(cls) -> list: """Return list of registered strategy names.""" - return list(cls.STRATEGY_MAP.keys()) + return cls.list_registered() # ============== Strategy Classes ============== diff --git a/scripts/tools/train.py b/scripts/tools/train.py index 903873e..ec14cff 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -8,7 +8,7 @@ import torch.optim as optim from torch.nn.parallel import DistributedDataParallel as DDP from astrai.config import ModelParameter, TrainConfig -from astrai.data import DatasetLoader +from astrai.data import DatasetFactory from astrai.parallel import get_rank from astrai.trainer import SchedulerFactory, Trainer @@ -205,7 +205,7 @@ def train( strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing} - dataset = DatasetLoader.load( + dataset = DatasetFactory.load( train_type=train_type, load_path=data_root_path, window_size=window_size, diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index b54013d..6de18af 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -23,7 +23,7 @@ def test_dataset_loader_random_paths(base_test_env): save_h5(test_dir, f"data_{i}", dummy_data) # Test loading with multiple paths - loaded_dataset = DatasetLoader.load( + loaded_dataset = DatasetFactory.load( train_type="seq", load_path=test_dir, window_size=64, @@ -57,7 +57,7 @@ def test_dpo_strategy_with_random_data(base_test_env): save_h5(test_dir, "dpo_data", dummy_data) # Load DPO dataset - dpo_dataset = DatasetLoader.load( + dpo_dataset = DatasetFactory.load( train_type="dpo", load_path=test_dir, window_size=64, @@ -93,7 +93,7 @@ def test_sft_dataset_with_random_data(base_test_env): save_h5(test_dir, "sft_data", dummy_data) # Load SFT dataset - sft_dataset = DatasetLoader.load( + sft_dataset = 
DatasetFactory.load( train_type="sft", load_path=test_dir, window_size=64, @@ -127,7 +127,7 @@ def test_dataset_with_custom_stride(base_test_env): # Test with custom stride custom_stride = 32 - dataset = DatasetLoader.load( + dataset = DatasetFactory.load( train_type="seq", load_path=test_dir, window_size=64, stride=custom_stride ) @@ -136,7 +136,7 @@ def test_dataset_with_custom_stride(base_test_env): # With stride 32 and window 64 on 200 length data, we should get more samples # than with default stride (which equals window size) - default_stride_dataset = DatasetLoader.load( + default_stride_dataset = DatasetFactory.load( train_type="seq", load_path=test_dir, window_size=64,