refactor: 优化工厂模式结构
This commit is contained in:
parent
7861af12e4
commit
e97536758f
|
|
@ -5,7 +5,8 @@ from astrai.config import (
|
||||||
ModelConfig,
|
ModelConfig,
|
||||||
TrainConfig,
|
TrainConfig,
|
||||||
)
|
)
|
||||||
from astrai.data import BpeTokenizer, DatasetLoader
|
from astrai.core.factory import BaseFactory
|
||||||
|
from astrai.data import BpeTokenizer, DatasetFactory
|
||||||
from astrai.inference.generator import (
|
from astrai.inference.generator import (
|
||||||
BatchGenerator,
|
BatchGenerator,
|
||||||
EmbeddingEncoder,
|
EmbeddingEncoder,
|
||||||
|
|
@ -21,7 +22,7 @@ __all__ = [
|
||||||
"Transformer",
|
"Transformer",
|
||||||
"ModelConfig",
|
"ModelConfig",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
"DatasetLoader",
|
"DatasetFactory",
|
||||||
"BpeTokenizer",
|
"BpeTokenizer",
|
||||||
"GenerationRequest",
|
"GenerationRequest",
|
||||||
"LoopGenerator",
|
"LoopGenerator",
|
||||||
|
|
@ -32,4 +33,5 @@ __all__ = [
|
||||||
"Trainer",
|
"Trainer",
|
||||||
"StrategyFactory",
|
"StrategyFactory",
|
||||||
"SchedulerFactory",
|
"SchedulerFactory",
|
||||||
|
"BaseFactory",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,105 @@
|
||||||
|
"""Base factory class for extensible component registration."""
|
||||||
|
|
||||||
|
from abc import ABC
|
||||||
|
from typing import Callable, Dict, Generic, Type, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFactory(ABC, Generic[T]):
|
||||||
|
"""Generic factory class for component registration and creation.
|
||||||
|
|
||||||
|
This base class provides a decorator-based registration pattern
|
||||||
|
for creating extensible component factories.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
class MyFactory(BaseFactory[MyBaseClass]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@MyFactory.register("custom")
|
||||||
|
class CustomComponent(MyBaseClass):
|
||||||
|
...
|
||||||
|
|
||||||
|
component = MyFactory.create("custom", *args, **kwargs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
_registry: Dict[str, Type[T]] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, name: str) -> Callable[[Type[T]], Type[T]]:
|
||||||
|
"""Decorator to register a component class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Registration name for the component
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decorator function that registers the component class
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the factory's _validate_component hook rejects the decorated class
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(component_cls: Type[T]) -> Type[T]:
|
||||||
|
cls._validate_component(component_cls)
|
||||||
|
cls._registry[name] = component_cls
|
||||||
|
return component_cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, name: str, *args, **kwargs) -> T:
|
||||||
|
"""Create a component instance by name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Registered name of the component
|
||||||
|
*args: Positional arguments passed to component constructor
|
||||||
|
**kwargs: Keyword arguments passed to component constructor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Component instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the component name is not registered
|
||||||
|
"""
|
||||||
|
if name not in cls._registry:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown component: '{name}'. "
|
||||||
|
f"Supported types: {sorted(cls._registry.keys())}"
|
||||||
|
)
|
||||||
|
component_cls = cls._registry[name]
|
||||||
|
return component_cls(*args, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_component(cls, component_cls: Type[T]) -> None:
|
||||||
|
"""Validate that the component class is valid for this factory.
|
||||||
|
|
||||||
|
Override this method in subclasses to add custom validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
component_cls: Component class to validate
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the component class is invalid
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_registered(cls) -> list:
|
||||||
|
"""List all registered component names.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of registered component names
|
||||||
|
"""
|
||||||
|
return sorted(cls._registry.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_registered(cls, name: str) -> bool:
|
||||||
|
"""Check if a component name is registered.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Component name to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if registered, False otherwise
|
||||||
|
"""
|
||||||
|
return name in cls._registry
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from astrai.data.dataset import (
|
from astrai.data.dataset import (
|
||||||
BaseDataset,
|
BaseDataset,
|
||||||
DatasetFactory,
|
DatasetFactory,
|
||||||
DatasetLoader,
|
DatasetFactory,
|
||||||
DPODataset,
|
DPODataset,
|
||||||
GRPODataset,
|
GRPODataset,
|
||||||
MultiSegmentFetcher,
|
MultiSegmentFetcher,
|
||||||
|
|
@ -21,8 +21,8 @@ __all__ = [
|
||||||
"GRPODataset",
|
"GRPODataset",
|
||||||
# Fetchers
|
# Fetchers
|
||||||
"MultiSegmentFetcher",
|
"MultiSegmentFetcher",
|
||||||
# Factory (DatasetLoader is alias for backward compatibility)
|
# Factory (replaces the former DatasetLoader alias)
|
||||||
"DatasetLoader",
|
"DatasetFactory",
|
||||||
"DatasetFactory",
|
"DatasetFactory",
|
||||||
# Tokenizer and sampler
|
# Tokenizer and sampler
|
||||||
"BpeTokenizer",
|
"BpeTokenizer",
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from astrai.core.factory import BaseFactory
|
||||||
from astrai.data.serialization import load_h5
|
from astrai.data.serialization import load_h5
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -165,7 +166,7 @@ class BaseDataset(Dataset, ABC):
|
||||||
return (self.total_samples - 1 - self.window_size) // self.stride + 1
|
return (self.total_samples - 1 - self.window_size) // self.stride + 1
|
||||||
|
|
||||||
|
|
||||||
class DatasetFactory:
|
class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
"""Factory class for creating dataset instances.
|
"""Factory class for creating dataset instances.
|
||||||
|
|
||||||
Supports decorator-based registration for extensible dataset types.
|
Supports decorator-based registration for extensible dataset types.
|
||||||
|
|
@ -180,30 +181,16 @@ class DatasetFactory:
|
||||||
dataset = DatasetFactory.create("custom", window_size, stride)
|
dataset = DatasetFactory.create("custom", window_size, stride)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SUPPORTED_TYPES = frozenset({"seq", "sft", "dpo", "grpo"})
|
_registry: Dict[str, type] = {}
|
||||||
DATASET_MAP: Dict[str, type] = {}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register(cls, name: str):
|
def _validate_component(cls, dataset_cls: type) -> None:
|
||||||
"""Decorator to register a new dataset class.
|
"""Validate that the dataset class inherits from BaseDataset."""
|
||||||
|
if not issubclass(dataset_cls, BaseDataset):
|
||||||
Args:
|
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
||||||
name: Registration name for the dataset type
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Decorator function that registers the dataset class
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(dataset_cls: type) -> type:
|
|
||||||
if not issubclass(dataset_cls, BaseDataset):
|
|
||||||
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
|
||||||
cls.DATASET_MAP[name] = dataset_cls
|
|
||||||
return dataset_cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, train_type: str, window_size: int, stride: int) -> BaseDataset:
|
def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
|
||||||
"""Create a dataset instance.
|
"""Create a dataset instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -214,19 +201,7 @@ class DatasetFactory:
|
||||||
Returns:
|
Returns:
|
||||||
Dataset instance
|
Dataset instance
|
||||||
"""
|
"""
|
||||||
if train_type not in cls.SUPPORTED_TYPES:
|
return super().create(train_type, window_size, stride)
|
||||||
raise ValueError(
|
|
||||||
f"Unknown dataset type: '{train_type}'. "
|
|
||||||
f"Supported types: {sorted(cls.SUPPORTED_TYPES)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if train_type not in cls.DATASET_MAP:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Dataset type '{train_type}' is supported but not yet implemented."
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset_cls = cls.DATASET_MAP[train_type]
|
|
||||||
return dataset_cls(window_size, stride)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(
|
def load(
|
||||||
|
|
@ -235,7 +210,7 @@ class DatasetFactory:
|
||||||
load_path: str,
|
load_path: str,
|
||||||
window_size: int,
|
window_size: int,
|
||||||
stride: Optional[int] = None,
|
stride: Optional[int] = None,
|
||||||
) -> BaseDataset:
|
) -> "BaseDataset":
|
||||||
"""Create and load a dataset in one step.
|
"""Create and load a dataset in one step.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -258,7 +233,7 @@ class DatasetFactory:
|
||||||
@classmethod
|
@classmethod
|
||||||
def available_types(cls) -> list:
|
def available_types(cls) -> list:
|
||||||
"""Return list of registered dataset type names."""
|
"""Return list of registered dataset type names."""
|
||||||
return list(cls.DATASET_MAP.keys())
|
return cls.list_registered()
|
||||||
|
|
||||||
|
|
||||||
# ============== Dataset Classes ==============
|
# ============== Dataset Classes ==============
|
||||||
|
|
@ -362,7 +337,3 @@ class GRPODataset(BaseDataset):
|
||||||
"masks": masks,
|
"masks": masks,
|
||||||
"rewards": rewards,
|
"rewards": rewards,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# Backward compatibility alias
|
|
||||||
DatasetLoader = DatasetFactory
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Generator, List, Optional, Tuple, Union
|
from typing import Dict, Generator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
|
from astrai.core.factory import BaseFactory
|
||||||
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
|
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
|
||||||
|
|
||||||
HistoryType = List[Tuple[str, str]]
|
HistoryType = List[Tuple[str, str]]
|
||||||
|
|
@ -254,7 +255,7 @@ class EmbeddingEncoder(EmbeddingEncoderCore):
|
||||||
return super().encode(sentence)
|
return super().encode(sentence)
|
||||||
|
|
||||||
|
|
||||||
class GeneratorFactory:
|
class GeneratorFactory(BaseFactory[GeneratorCore]):
|
||||||
"""Factory class for creating generator instances.
|
"""Factory class for creating generator instances.
|
||||||
|
|
||||||
Provides smart generator selection based on request characteristics:
|
Provides smart generator selection based on request characteristics:
|
||||||
|
|
@ -263,14 +264,14 @@ class GeneratorFactory:
|
||||||
- Single: Use LoopGenerator for single query non-streaming
|
- Single: Use LoopGenerator for single query non-streaming
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
generator = GeneratorFactory.create_generator(parameter, request)
|
generator = GeneratorFactory.create(parameter, request)
|
||||||
result = generator.generate(request)
|
result = generator.generate(request)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_registry: Dict[str, type] = {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_generator(
|
def create(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
|
||||||
parameter: ModelParameter, request: GenerationRequest
|
|
||||||
) -> GeneratorCore:
|
|
||||||
"""Create a generator based on request characteristics.
|
"""Create a generator based on request characteristics.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,8 @@ from typing import Any, Dict, List, Type
|
||||||
|
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
|
||||||
|
from astrai.core.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
class BaseScheduler(LRScheduler, ABC):
|
class BaseScheduler(LRScheduler, ABC):
|
||||||
"""Base scheduler class for all other schedulers."""
|
"""Base scheduler class for all other schedulers."""
|
||||||
|
|
@ -25,7 +27,7 @@ class BaseScheduler(LRScheduler, ABC):
|
||||||
super().load_state_dict(state_dict)
|
super().load_state_dict(state_dict)
|
||||||
|
|
||||||
|
|
||||||
class SchedulerFactory:
|
class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
||||||
"""Factory class for creating learning rate schedulers.
|
"""Factory class for creating learning rate schedulers.
|
||||||
|
|
||||||
Supports decorator-based registration for extensible scheduler types.
|
Supports decorator-based registration for extensible scheduler types.
|
||||||
|
|
@ -36,34 +38,21 @@ class SchedulerFactory:
|
||||||
class CustomScheduler(BaseScheduler):
|
class CustomScheduler(BaseScheduler):
|
||||||
...
|
...
|
||||||
|
|
||||||
scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)
|
scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {}
|
_registry: Dict[str, Type[BaseScheduler]] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register(cls, name: str):
|
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
|
||||||
"""Decorator to register a new scheduler class.
|
"""Validate that the scheduler class inherits from BaseScheduler."""
|
||||||
|
if not issubclass(scheduler_cls, BaseScheduler):
|
||||||
Args:
|
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
||||||
name: Registration name for the scheduler
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Decorator function that registers the scheduler class
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]:
|
|
||||||
if not issubclass(scheduler_cls, BaseScheduler):
|
|
||||||
raise TypeError(
|
|
||||||
f"{scheduler_cls.__name__} must inherit from BaseScheduler"
|
|
||||||
)
|
|
||||||
cls.SCHEDULER_MAP[name] = scheduler_cls
|
|
||||||
return scheduler_cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, optimizer, schedule_type: str = "none", **kwargs) -> BaseScheduler:
|
def create(
|
||||||
|
cls, optimizer, schedule_type: str = "none", **kwargs
|
||||||
|
) -> "BaseScheduler":
|
||||||
"""Create a scheduler instance by type name.
|
"""Create a scheduler instance by type name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -73,23 +62,13 @@ class SchedulerFactory:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Scheduler instance
|
Scheduler instance
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If schedule_type is not supported
|
|
||||||
"""
|
"""
|
||||||
if schedule_type not in cls.SCHEDULER_MAP:
|
return super().create(schedule_type, optimizer, **kwargs)
|
||||||
raise ValueError(
|
|
||||||
f"Unknown schedule type: '{schedule_type}'. "
|
|
||||||
f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
scheduler_cls = cls.SCHEDULER_MAP[schedule_type]
|
|
||||||
return scheduler_cls(optimizer, **kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def available_types(cls) -> list:
|
def available_types(cls) -> list:
|
||||||
"""Return list of registered scheduler type names."""
|
"""Return list of registered scheduler type names."""
|
||||||
return list(cls.SCHEDULER_MAP.keys())
|
return cls.list_registered()
|
||||||
|
|
||||||
|
|
||||||
# ----------- Scheduler implementations -----------
|
# ----------- Scheduler implementations -----------
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,8 @@ import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
|
from astrai.core.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
def unwrap_model(model: nn.Module) -> nn.Module:
|
def unwrap_model(model: nn.Module) -> nn.Module:
|
||||||
"""Unwrap DDP wrapper if present to get the original model."""
|
"""Unwrap DDP wrapper if present to get the original model."""
|
||||||
|
|
@ -106,7 +108,7 @@ class BaseStrategy(ABC):
|
||||||
return self.compute_loss(batch)
|
return self.compute_loss(batch)
|
||||||
|
|
||||||
|
|
||||||
class StrategyFactory:
|
class StrategyFactory(BaseFactory["BaseStrategy"]):
|
||||||
"""Factory class for creating training strategy instances.
|
"""Factory class for creating training strategy instances.
|
||||||
|
|
||||||
Supports decorator-based registration for extensible strategy types.
|
Supports decorator-based registration for extensible strategy types.
|
||||||
|
|
@ -117,68 +119,36 @@ class StrategyFactory:
|
||||||
class CustomStrategy(BaseStrategy):
|
class CustomStrategy(BaseStrategy):
|
||||||
...
|
...
|
||||||
|
|
||||||
strategy = StrategyFactory.create(model, "custom", device)
|
strategy = StrategyFactory.create("custom", model, device)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"})
|
_registry: Dict[str, type] = {}
|
||||||
STRATEGY_MAP: Dict[str, type] = {}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register(cls, name: str):
|
def _validate_component(cls, strategy_cls: type) -> None:
|
||||||
"""Decorator to register a new strategy class.
|
"""Validate that the strategy class inherits from BaseStrategy."""
|
||||||
|
if not issubclass(strategy_cls, BaseStrategy):
|
||||||
Args:
|
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
||||||
name: Registration name for the strategy
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Decorator function that registers the strategy class
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(strategy_cls: type) -> type:
|
|
||||||
if not issubclass(strategy_cls, BaseStrategy):
|
|
||||||
raise TypeError(
|
|
||||||
f"{strategy_cls.__name__} must inherit from BaseStrategy"
|
|
||||||
)
|
|
||||||
cls.STRATEGY_MAP[name] = strategy_cls
|
|
||||||
return strategy_cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy:
|
def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
|
||||||
"""Create a strategy instance based on training type.
|
"""Create a strategy instance based on training type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: Model instance for the strategy
|
|
||||||
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
||||||
|
model: Model instance for the strategy
|
||||||
device: Device to run the strategy on
|
device: Device to run the strategy on
|
||||||
**kwargs: Additional arguments passed to strategy constructor
|
**kwargs: Additional arguments passed to strategy constructor
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Strategy instance
|
Strategy instance
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If train_type is not supported
|
|
||||||
NotImplementedError: If train_type is in supported list but not implemented
|
|
||||||
"""
|
"""
|
||||||
if train_type not in cls.SUPPORTED_STRATEGIES:
|
return super().create(train_type, model, device, **kwargs)
|
||||||
raise ValueError(
|
|
||||||
f"Unknown training strategy: '{train_type}'. "
|
|
||||||
f"Supported strategies: {sorted(cls.SUPPORTED_STRATEGIES)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if train_type not in cls.STRATEGY_MAP:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Strategy '{train_type}' is supported but not yet implemented."
|
|
||||||
)
|
|
||||||
|
|
||||||
strategy_cls = cls.STRATEGY_MAP[train_type]
|
|
||||||
return strategy_cls(model, device, **kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def available_strategies(cls) -> list:
|
def available_strategies(cls) -> list:
|
||||||
"""Return list of registered strategy names."""
|
"""Return list of registered strategy names."""
|
||||||
return list(cls.STRATEGY_MAP.keys())
|
return cls.list_registered()
|
||||||
|
|
||||||
|
|
||||||
# ============== Strategy Classes ==============
|
# ============== Strategy Classes ==============
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ import torch.optim as optim
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
from astrai.config import ModelParameter, TrainConfig
|
from astrai.config import ModelParameter, TrainConfig
|
||||||
from astrai.data import DatasetLoader
|
from astrai.data import DatasetFactory
|
||||||
from astrai.parallel import get_rank
|
from astrai.parallel import get_rank
|
||||||
from astrai.trainer import SchedulerFactory, Trainer
|
from astrai.trainer import SchedulerFactory, Trainer
|
||||||
|
|
||||||
|
|
@ -205,7 +205,7 @@ def train(
|
||||||
|
|
||||||
strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}
|
strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}
|
||||||
|
|
||||||
dataset = DatasetLoader.load(
|
dataset = DatasetFactory.load(
|
||||||
train_type=train_type,
|
train_type=train_type,
|
||||||
load_path=data_root_path,
|
load_path=data_root_path,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ def test_dataset_loader_random_paths(base_test_env):
|
||||||
save_h5(test_dir, f"data_{i}", dummy_data)
|
save_h5(test_dir, f"data_{i}", dummy_data)
|
||||||
|
|
||||||
# Test loading with multiple paths
|
# Test loading with multiple paths
|
||||||
loaded_dataset = DatasetLoader.load(
|
loaded_dataset = DatasetFactory.load(
|
||||||
train_type="seq",
|
train_type="seq",
|
||||||
load_path=test_dir,
|
load_path=test_dir,
|
||||||
window_size=64,
|
window_size=64,
|
||||||
|
|
@ -57,7 +57,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
|
||||||
save_h5(test_dir, "dpo_data", dummy_data)
|
save_h5(test_dir, "dpo_data", dummy_data)
|
||||||
|
|
||||||
# Load DPO dataset
|
# Load DPO dataset
|
||||||
dpo_dataset = DatasetLoader.load(
|
dpo_dataset = DatasetFactory.load(
|
||||||
train_type="dpo",
|
train_type="dpo",
|
||||||
load_path=test_dir,
|
load_path=test_dir,
|
||||||
window_size=64,
|
window_size=64,
|
||||||
|
|
@ -93,7 +93,7 @@ def test_sft_dataset_with_random_data(base_test_env):
|
||||||
save_h5(test_dir, "sft_data", dummy_data)
|
save_h5(test_dir, "sft_data", dummy_data)
|
||||||
|
|
||||||
# Load SFT dataset
|
# Load SFT dataset
|
||||||
sft_dataset = DatasetLoader.load(
|
sft_dataset = DatasetFactory.load(
|
||||||
train_type="sft",
|
train_type="sft",
|
||||||
load_path=test_dir,
|
load_path=test_dir,
|
||||||
window_size=64,
|
window_size=64,
|
||||||
|
|
@ -127,7 +127,7 @@ def test_dataset_with_custom_stride(base_test_env):
|
||||||
|
|
||||||
# Test with custom stride
|
# Test with custom stride
|
||||||
custom_stride = 32
|
custom_stride = 32
|
||||||
dataset = DatasetLoader.load(
|
dataset = DatasetFactory.load(
|
||||||
train_type="seq", load_path=test_dir, window_size=64, stride=custom_stride
|
train_type="seq", load_path=test_dir, window_size=64, stride=custom_stride
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -136,7 +136,7 @@ def test_dataset_with_custom_stride(base_test_env):
|
||||||
|
|
||||||
# With stride 32 and window 64 on 200 length data, we should get more samples
|
# With stride 32 and window 64 on 200 length data, we should get more samples
|
||||||
# than with default stride (which equals window size)
|
# than with default stride (which equals window size)
|
||||||
default_stride_dataset = DatasetLoader.load(
|
default_stride_dataset = DatasetFactory.load(
|
||||||
train_type="seq",
|
train_type="seq",
|
||||||
load_path=test_dir,
|
load_path=test_dir,
|
||||||
window_size=64,
|
window_size=64,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue