refactor: 统一并增强项目中的工厂模式实现

This commit is contained in:
ViperEkura 2026-03-30 01:33:14 +08:00
parent 60f4df95bd
commit 3e33c14376
9 changed files with 550 additions and 157 deletions

View File

@ -1,16 +1,26 @@
from khaosz.config.model_config import ModelConfig from khaosz.config.model_config import ModelConfig
from khaosz.config.param_config import BaseModelIO, ModelParameter from khaosz.config.param_config import BaseModelIO, ModelParameter
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig from khaosz.config.schedule_config import (
ScheduleConfig,
CosineScheduleConfig,
SGDRScheduleConfig,
ScheduleConfigFactory
)
from khaosz.config.train_config import TrainConfig from khaosz.config.train_config import TrainConfig
__all__ = [ __all__ = [
# Base I/O
"BaseModelIO", "BaseModelIO",
"ModelParameter", "ModelParameter",
# Model configuration
"ModelConfig", "ModelConfig",
"TrainConfig", "TrainConfig",
# Schedule configuration
"ScheduleConfig", "ScheduleConfig",
"CosineScheduleConfig", "CosineScheduleConfig",
"SGDRScheduleConfig", "SGDRScheduleConfig",
"ScheduleConfigFactory",
] ]

View File

@ -1,10 +1,15 @@
from typing import Any, Dict from typing import Any, Dict, Type
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
@dataclass @dataclass
class ScheduleConfig(ABC): class ScheduleConfig(ABC):
"""Base configuration class for learning rate schedulers.
Provides common validation and interface for all schedule types.
"""
schedule_type: str = field( schedule_type: str = field(
default="cosine", default="cosine",
metadata={ metadata={
@ -23,6 +28,7 @@ class ScheduleConfig(ABC):
@abstractmethod @abstractmethod
def get_kwargs(self) -> Dict[str, Any]: def get_kwargs(self) -> Dict[str, Any]:
"""Get configuration kwargs for scheduler creation."""
raise NotImplementedError raise NotImplementedError
def validate(self) -> None: def validate(self) -> None:
@ -35,6 +41,8 @@ class ScheduleConfig(ABC):
@dataclass @dataclass
class CosineScheduleConfig(ScheduleConfig): class CosineScheduleConfig(ScheduleConfig):
"""Cosine annealing learning rate schedule configuration."""
total_steps: int = field( total_steps: int = field(
default=None, default=None,
metadata={"help": "Total training steps for cosine schedule."} metadata={"help": "Total training steps for cosine schedule."}
@ -63,6 +71,8 @@ class CosineScheduleConfig(ScheduleConfig):
@dataclass @dataclass
class SGDRScheduleConfig(ScheduleConfig): class SGDRScheduleConfig(ScheduleConfig):
"""Stochastic Gradient Descent with Warm Restarts schedule configuration."""
cycle_length: int = field( cycle_length: int = field(
default=1000, default=1000,
metadata={"help": "Length of the first cycle in steps."} metadata={"help": "Length of the first cycle in steps."}
@ -91,3 +101,50 @@ class SGDRScheduleConfig(ScheduleConfig):
raise ValueError(f"cycle_length must be positive, got {self.cycle_length}") raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
if self.t_mult < 1: if self.t_mult < 1:
raise ValueError(f"t_mult must be >= 1, got {self.t_mult}") raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
class ScheduleConfigFactory:
    """Factory for ScheduleConfig instances.

    Configs may be built directly or through the factory:

        # Direct creation
        config = CosineScheduleConfig(total_steps=10000)

        # Factory method
        config = ScheduleConfigFactory.create("cosine", total_steps=10000)
    """

    # Mapping of schedule-type name -> config class.
    CONFIG_MAP: Dict[str, Type[ScheduleConfig]] = {
        "cosine": CosineScheduleConfig,
        "sgdr": SGDRScheduleConfig,
    }

    @classmethod
    def create(cls, schedule_type: str, **kwargs) -> ScheduleConfig:
        """Build the config class registered under *schedule_type*.

        Args:
            schedule_type: Type of schedule ("cosine", "sgdr")
            **kwargs: Forwarded to the config constructor

        Returns:
            ScheduleConfig instance

        Raises:
            ValueError: If schedule_type is not supported
        """
        selected_cls = cls.CONFIG_MAP.get(schedule_type)
        if selected_cls is None:
            raise ValueError(
                f"Unknown schedule type: '{schedule_type}'. "
                f"Supported types: {sorted(cls.CONFIG_MAP.keys())}"
            )
        return selected_cls(**kwargs)

    @classmethod
    def available_types(cls) -> list:
        """Names of all registered schedule types."""
        return list(cls.CONFIG_MAP.keys())

View File

@ -5,20 +5,31 @@ from khaosz.data.dataset import (
SFTDataset, SFTDataset,
GRPODataset, GRPODataset,
MultiSegmentFetcher, MultiSegmentFetcher,
DatasetLoader DatasetLoader,
DatasetFactory
) )
from khaosz.data.tokenizer import BpeTokenizer from khaosz.data.tokenizer import BpeTokenizer
from khaosz.data.sampler import ResumableDistributedSampler from khaosz.data.sampler import ResumableDistributedSampler
__all__ = [ __all__ = [
# Base classes
"BaseDataset", "BaseDataset",
# Dataset implementations
"SEQDataset", "SEQDataset",
"SFTDataset", "SFTDataset",
"DPODataset", "DPODataset",
"GRPODataset", "GRPODataset",
# Fetchers
"MultiSegmentFetcher", "MultiSegmentFetcher",
# Factory (DatasetLoader is alias for backward compatibility)
"DatasetLoader", "DatasetLoader",
"DatasetFactory",
# Tokenizer and sampler
"BpeTokenizer", "BpeTokenizer",
"ResumableDistributedSampler" "ResumableDistributedSampler"
] ]

View File

@ -1,3 +1,5 @@
"""Dataset implementations with factory pattern for training."""
import torch import torch
import bisect import bisect
@ -8,8 +10,13 @@ from khaosz.data.serialization import load_h5
from typing import Callable, List, Dict, Literal, Optional, Union from typing import Callable, List, Dict, Literal, Optional, Union
class BaseSegmentFetcher: class BaseSegmentFetcher:
"""Fetches data segments across multiple tensor segments.
Maintains cumulative lengths for efficient range queries across
multiple discontinuous segments.
"""
def __init__(self, segments: List[Tensor]): def __init__(self, segments: List[Tensor]):
self.segments = segments self.segments = segments
self.cum_lengths = [] self.cum_lengths = []
@ -25,12 +32,21 @@ class BaseSegmentFetcher:
return self.total_length return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx).
Args:
begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive)
Returns:
Concatenated tensor of data in the specified range
"""
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length): if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
raise ValueError("begin_idx or end_idx out of bounds") raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx: if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long) return torch.tensor([], dtype=torch.long)
# fix the range index bug # Find segment boundaries for the range
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx) seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx) seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
@ -47,6 +63,11 @@ class BaseSegmentFetcher:
class MultiSegmentFetcher: class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys.
Each key corresponds to a different type of data (e.g., "sequence", "mask").
"""
def __init__(self, muti_segments: Dict): def __init__(self, muti_segments: Dict):
self.muti_keys = list(muti_segments.keys()) self.muti_keys = list(muti_segments.keys())
self.muti_fetchers = { self.muti_fetchers = {
@ -55,10 +76,21 @@ class MultiSegmentFetcher:
} }
def __len__(self) -> int: def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
len_list = [len(seg) for seg in self.muti_fetchers.values()] len_list = [len(seg) for seg in self.muti_fetchers.values()]
return min(len_list) return min(len_list)
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict: def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
"""Fetch data for specific keys.
Args:
begin_idx: Starting index
end_idx: Ending index
keys: Single key or list of keys to fetch
Returns:
Dictionary of tensors if multiple keys, single tensor if one key
"""
fetch_dict = {} fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys keys = [keys] if isinstance(keys, str) else keys
@ -70,23 +102,43 @@ class MultiSegmentFetcher:
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]] return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict: def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.muti_keys) return self.key_fetch(begin_idx, end_idx, self.muti_keys)
class BaseDataset(Dataset, ABC): class BaseDataset(Dataset, ABC):
"""Abstract base class for all dataset types.
Implements common functionality for window-based data fetching.
"""
def __init__(self, window_size: int, stride: int): def __init__(self, window_size: int, stride: int):
super().__init__() super().__init__()
self.segments = {} self.segments = {}
self.window_size = window_size self.window_size = window_size
self.stride = stride self.stride = stride
self.total_samples = None self.total_samples = None
self.fetcher: Optional[MultiSegmentFetcher] = None
def load(self, load_path: str): def load(self, load_path: str):
"""Load dataset from HDF5 file.
Args:
load_path: Path to the HDF5 data file
"""
self.segments = load_h5(load_path) self.segments = load_h5(load_path)
self.fetcher = MultiSegmentFetcher(self.segments) self.fetcher = MultiSegmentFetcher(self.segments)
self.total_samples = len(self.fetcher) self.total_samples = len(self.fetcher)
def get_index(self, index: int) -> int: def get_index(self, index: int) -> tuple:
"""Calculate begin and end indices for a sample.
Args:
index: Sample index
Returns:
Tuple of (begin_idx, end_idx)
"""
assert self.total_samples > self.window_size assert self.total_samples > self.window_size
begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size) begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
@ -96,6 +148,10 @@ class BaseDataset(Dataset, ABC):
@abstractmethod @abstractmethod
def __getitem__(self, index: int) -> Dict[str, Tensor]: def __getitem__(self, index: int) -> Dict[str, Tensor]:
"""Get a single sample by index.
Must be implemented by subclasses.
"""
raise NotImplementedError raise NotImplementedError
def __len__(self) -> int: def __len__(self) -> int:
@ -105,16 +161,109 @@ class BaseDataset(Dataset, ABC):
return (self.total_samples - 1 - self.window_size) // self.stride + 1 return (self.total_samples - 1 - self.window_size) // self.stride + 1
class DatasetFactory:
    """Factory class for creating dataset instances.

    Dataset classes register themselves with the :meth:`register` decorator;
    the default types (seq, sft, dpo, grpo) are registered automatically at
    class definition time.

    Example usage:

        @DatasetFactory.register("custom")
        class CustomDataset(BaseDataset):
            ...

        dataset = DatasetFactory.create("custom", window_size, stride)
    """

    # Built-in types expected to ship with the library.  Used only to give a
    # clearer error (NotImplementedError vs ValueError) when one of them has
    # not been registered yet.
    SUPPORTED_TYPES = frozenset({"seq", "sft", "dpo", "grpo"})
    DATASET_MAP: Dict[str, type] = {}

    @classmethod
    def register(cls, name: str):
        """Decorator registering a BaseDataset subclass under *name*.

        Args:
            name: Registration name for the dataset type

        Returns:
            Decorator function that registers the dataset class

        Raises:
            TypeError: If the decorated class is not a BaseDataset subclass
        """
        def decorator(dataset_cls: type) -> type:
            if not issubclass(dataset_cls, BaseDataset):
                raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
            cls.DATASET_MAP[name] = dataset_cls
            return dataset_cls
        return decorator

    @classmethod
    def create(cls, train_type: str, window_size: int, stride: int) -> BaseDataset:
        """Create a dataset instance.

        Args:
            train_type: Registered dataset type name (e.g. "seq", "sft",
                "dpo", "grpo", or any custom name added via ``register``)
            window_size: Window size for data sampling
            stride: Stride between consecutive samples

        Returns:
            Dataset instance

        Raises:
            NotImplementedError: If *train_type* is a known built-in type
                whose class has not been registered yet
            ValueError: If *train_type* is unknown
        """
        # Consult the registry first so custom types added via register()
        # actually work, as the class docstring promises.  The previous
        # implementation gated on SUPPORTED_TYPES before the registry, which
        # rejected every custom registration with ValueError.
        if train_type in cls.DATASET_MAP:
            return cls.DATASET_MAP[train_type](window_size, stride)
        if train_type in cls.SUPPORTED_TYPES:
            raise NotImplementedError(
                f"Dataset type '{train_type}' is supported but not yet implemented."
            )
        raise ValueError(
            f"Unknown dataset type: '{train_type}'. "
            f"Supported types: {sorted(cls.SUPPORTED_TYPES | set(cls.DATASET_MAP))}"
        )

    @classmethod
    def load(cls, train_type: str, load_path: str, window_size: int, stride: Optional[int] = None) -> BaseDataset:
        """Create and load a dataset in one step.

        Args:
            train_type: Type of training dataset
            load_path: Path to the data file
            window_size: Window size for data sampling
            stride: Stride between consecutive samples (default: same as window_size)

        Returns:
            Loaded dataset instance
        """
        if stride is None:
            stride = window_size
        dataset = cls.create(train_type, window_size, stride)
        dataset.load(load_path)
        return dataset

    @classmethod
    def available_types(cls) -> list:
        """Return list of registered dataset type names."""
        return list(cls.DATASET_MAP.keys())
# ============== Dataset Classes ==============
# All dataset classes are registered at class definition time using the decorator
@DatasetFactory.register("seq")
class SEQDataset(BaseDataset): class SEQDataset(BaseDataset):
"""Dataset for sequential next-token prediction training."""
def __init__(self, window_size: int, stride: int): def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride) super().__init__(window_size, stride)
self.fetcher = MultiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence") return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
def __getitem__(self, index): def __getitem__(self, index):
# fix the range index bug
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long) x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
@ -123,10 +272,12 @@ class SEQDataset(BaseDataset):
return {"input_ids": x, "target_ids": y} return {"input_ids": x, "target_ids": y}
@DatasetFactory.register("sft")
class SFTDataset(BaseDataset): class SFTDataset(BaseDataset):
"""Dataset for supervised fine-tuning with loss masking."""
def __init__(self, window_size: int, stride: int): def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride) super().__init__(window_size, stride)
self.fetcher = MultiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, key)
@ -141,10 +292,12 @@ class SFTDataset(BaseDataset):
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask} return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
@DatasetFactory.register("dpo")
class DPODataset(BaseDataset): class DPODataset(BaseDataset):
"""Dataset for Direct Preference Optimization training."""
def __init__(self, window_size: int, stride: int): def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride) super().__init__(window_size, stride)
self.fetcher = MultiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, key)
@ -160,10 +313,12 @@ class DPODataset(BaseDataset):
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask} return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
@DatasetFactory.register("grpo")
class GRPODataset(BaseDataset): class GRPODataset(BaseDataset):
"""Dataset for Group Relative Policy Optimization training."""
def __init__(self, window_size: int, stride: int): def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride) super().__init__(window_size, stride)
self.fetcher = MultiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, key)
@ -179,24 +334,5 @@ class GRPODataset(BaseDataset):
return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards} return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
class DatasetLoader: # Backward compatibility alias
@staticmethod DatasetLoader = DatasetFactory
def load(
train_type: Literal["seq", "sft", "dpo"],
load_path: str,
window_size: int,
stride: Optional[int] = None,
) -> BaseDataset:
if stride is None:
stride = window_size
dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
"seq": lambda window_size: SEQDataset(window_size, stride),
"sft": lambda window_size: SFTDataset(window_size, stride),
"dpo": lambda window_size: DPODataset(window_size, stride),
"grpo": lambda window_size: GRPODataset(window_size, stride),
}
dataset = dataset_router[train_type](window_size)
dataset.load(load_path)
return dataset

View File

@ -77,6 +77,7 @@ class GenerationRequest:
query: Input query (string or list of strings for batch). query: Input query (string or list of strings for batch).
history: Conversation history. history: Conversation history.
system_prompt: System prompt for the conversation. system_prompt: System prompt for the conversation.
stream: Whether to use streaming generation.
""" """
top_k: int top_k: int
top_p: float top_p: float
@ -86,6 +87,7 @@ class GenerationRequest:
query: Union[str, List[str]] query: Union[str, List[str]]
history: Optional[Union[HistoryType, List[HistoryType]]] = None history: Optional[Union[HistoryType, List[HistoryType]]] = None
system_prompt: Optional[str] = None system_prompt: Optional[str] = None
stream: bool = False
def __post_init__(self): def __post_init__(self):
if not isinstance(self.top_k, int) or self.top_k < 0: if not isinstance(self.top_k, int) or self.top_k < 0:
@ -233,33 +235,62 @@ class EmbeddingEncoder(EmbeddingEncoderCore):
class GeneratorFactory: class GeneratorFactory:
"""Factory class for creating appropriate generator instances based on request features.""" """Factory class for creating generator instances.
Provides smart generator selection based on request characteristics:
- Streaming: Use StreamGenerator for streaming output
- Batch: Use BatchGenerator when query is a list
- Single: Use LoopGenerator for single query non-streaming
Example usage:
generator = GeneratorFactory.create_generator(parameter, request)
result = generator.generate(request)
"""
@staticmethod @staticmethod
def create_generator(parameter: ModelParameter, request: GenerationRequest): def create_generator(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
"""Create a generator based on request characteristics.
Args:
parameter: Model parameters containing model, tokenizer, config
request: Generation request with query, options, etc.
Returns:
Appropriate GeneratorCore subclass instance
""" """
Create a generator based on the characteristics of GenerationRequest. # Streaming generation: check stream field first
if request.stream:
return StreamGenerator(parameter)
# Batch generation: query is a list of strings
if isinstance(request.query, list):
return BatchGenerator(parameter)
# Default: single query non-streaming
return LoopGenerator(parameter)
@staticmethod
def create_encoder(parameter: ModelParameter) -> EmbeddingEncoderCore:
"""Create an embedding encoder instance.
Args:
parameter: Model parameters
Returns:
EmbeddingEncoderCore instance
"""
return EmbeddingEncoder(parameter)
@classmethod
def create(cls, parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
"""Convenience method that delegates to create_generator.
Args: Args:
parameter: Model parameters parameter: Model parameters
request: Generation request request: Generation request
Returns: Returns:
Subclass instance of GeneratorCore Generator instance
""" """
return cls.create_generator(parameter, request)
# Streaming generation detection: check stream field
if request.stream:
return StreamGenerator(parameter)
# Batch generation detection: query is a list
if isinstance(request.query, list):
return BatchGenerator(parameter)
# Default return LoopGenerator
return LoopGenerator(parameter)
@staticmethod
def create_encoder(parameter: ModelParameter):
"""Create an EmbeddingEncoder instance"""
return EmbeddingEncoder(parameter)

View File

@ -134,8 +134,8 @@ def spawn_parallel_fn(
if world_size == 1: if world_size == 1:
device_ids = device_ids or [0] device_ids = device_ids or [0]
deice_id = torch.device(device_type, device_ids[0]) device_id = torch.device(device_type, device_ids[0])
os.environ["LOCAL_DEVICE"] = str(deice_id) os.environ["LOCAL_DEVICE"] = str(device_id)
func(**kwargs) func(**kwargs)
return return

View File

@ -1,29 +1,33 @@
from khaosz.trainer.trainer import Trainer from khaosz.trainer.trainer import Trainer
from khaosz.trainer.strategy import StrategyFactory from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
from khaosz.trainer.schedule import SchedulerFactory from khaosz.trainer.schedule import SchedulerFactory, BaseScheduler
from khaosz.trainer.train_callback import ( from khaosz.trainer.train_callback import (
TrainCallback, TrainCallback,
ProgressBarCallback, GradientClippingCallback,
CheckpointCallback,
TrainCallback,
SchedulerCallback, SchedulerCallback,
MetricLoggerCallback CheckpointCallback,
ProgressBarCallback,
MetricLoggerCallback,
) )
__all__ = [ __all__ = [
# trainer # Main trainer
"Trainer", "Trainer",
# factory # Strategy factory
"StrategyFactory", "StrategyFactory",
"SchedulerFactory", "BaseStrategy",
# callback # Scheduler factory
"TrainCallback", "SchedulerFactory",
"ProgressBarCallback", "BaseScheduler",
"CheckpointCallback",
# Callbacks
"TrainCallback", "TrainCallback",
"GradientClippingCallback",
"SchedulerCallback", "SchedulerCallback",
"MetricLoggerCallback" "CheckpointCallback",
"ProgressBarCallback",
"MetricLoggerCallback",
] ]

View File

@ -1,20 +1,21 @@
"""Learning rate scheduler implementations with factory pattern."""
import math import math
from abc import abstractmethod, ABC from abc import abstractmethod, ABC
from typing import Any, Dict, List from typing import Any, Dict, List, Type
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
from khaosz.config.schedule_config import ScheduleConfig from khaosz.config.schedule_config import ScheduleConfig
class BaseScheduler(LRScheduler, ABC): class BaseScheduler(LRScheduler, ABC):
""" """Base scheduler class for all other schedulers."""
Base scheduler class for all other schedulers.
"""
def __init__(self, optimizer, last_epoch: int = -1): def __init__(self, optimizer, last_epoch: int = -1):
super().__init__(optimizer, last_epoch) super().__init__(optimizer, last_epoch)
@abstractmethod @abstractmethod
def get_lr(self) -> List[float]: def get_lr(self) -> List[float]:
"""Calculate the current learning rate."""
raise NotImplementedError raise NotImplementedError
def state_dict(self) -> Dict[str, Any]: def state_dict(self) -> Dict[str, Any]:
@ -24,10 +25,95 @@ class BaseScheduler(LRScheduler, ABC):
super().load_state_dict(state_dict) super().load_state_dict(state_dict)
class SchedulerFactory:
    """Factory for learning-rate schedulers.

    New scheduler classes are added via the :meth:`register` decorator;
    instances can be built either from a type name or from a ScheduleConfig.

    Example usage:

        @SchedulerFactory.register("custom")
        class CustomScheduler(BaseScheduler):
            ...

        scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)

        # Or from config
        config = CosineScheduleConfig(total_steps=10000)
        scheduler = SchedulerFactory.load(optimizer, config)
    """

    # Registry mapping type name -> scheduler class, populated by register().
    SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {}

    @classmethod
    def register(cls, name: str):
        """Return a decorator that registers a BaseScheduler subclass under *name*.

        Args:
            name: Registration name for the scheduler

        Returns:
            Decorator function that registers the scheduler class

        Raises:
            TypeError: If the decorated class is not a BaseScheduler subclass
        """
        def add_to_registry(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]:
            if not issubclass(scheduler_cls, BaseScheduler):
                raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
            cls.SCHEDULER_MAP[name] = scheduler_cls
            return scheduler_cls
        return add_to_registry

    @classmethod
    def create(cls, optimizer, schedule_type: str, **kwargs) -> BaseScheduler:
        """Build a scheduler by its registered type name.

        Args:
            optimizer: PyTorch optimizer
            schedule_type: Type of scheduler ("cosine", "sgdr")
            **kwargs: Forwarded to the scheduler constructor

        Returns:
            Scheduler instance

        Raises:
            ValueError: If schedule_type is not supported
        """
        selected_cls = cls.SCHEDULER_MAP.get(schedule_type)
        if selected_cls is None:
            raise ValueError(
                f"Unknown schedule type: '{schedule_type}'. "
                f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}"
            )
        return selected_cls(optimizer, **kwargs)

    @staticmethod
    def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
        """Build a scheduler from a ScheduleConfig object.

        The config's kwargs supply the constructor arguments; its
        "schedule_type" entry selects the scheduler class.

        Args:
            optimizer: PyTorch optimizer
            schedule_config: ScheduleConfig instance

        Returns:
            Scheduler instance
        """
        config_kwargs = schedule_config.get_kwargs()
        selected_type = config_kwargs.pop("schedule_type")
        return SchedulerFactory.create(optimizer, selected_type, **config_kwargs)

    @classmethod
    def available_types(cls) -> list:
        """Names of all registered scheduler types."""
        return list(cls.SCHEDULER_MAP.keys())
# ============== Scheduler Classes ==============
# All scheduler classes are registered at class definition time using the decorator
@SchedulerFactory.register("cosine")
class CosineScheduler(BaseScheduler): class CosineScheduler(BaseScheduler):
""" """Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler."""
Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler.
"""
def __init__( def __init__(
self, self,
@ -75,10 +161,9 @@ class CosineScheduler(BaseScheduler):
super().load_state_dict(state_dict) super().load_state_dict(state_dict)
@SchedulerFactory.register("sgdr")
class SGDRScheduler(BaseScheduler): class SGDRScheduler(BaseScheduler):
""" """SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler."""
SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler,
"""
def __init__( def __init__(
self, self,
@ -142,23 +227,3 @@ class SGDRScheduler(BaseScheduler):
self.min_rate = state_dict.pop('min_rate') self.min_rate = state_dict.pop('min_rate')
self.t_mult = state_dict.pop('t_mult') self.t_mult = state_dict.pop('t_mult')
super().load_state_dict(state_dict) super().load_state_dict(state_dict)
class SchedulerFactory:
"""
Factory class for creating learning rate schedulers.
"""
@staticmethod
def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
kwargs = schedule_config.get_kwargs()
schedule_type = kwargs.pop("schedule_type")
if schedule_type == "cosine":
return CosineScheduler(optimizer, **kwargs)
elif schedule_type == "sgdr":
return SGDRScheduler(optimizer, **kwargs)
else:
raise ValueError(f"Unsupported schedule type: {schedule_type}")

View File

@ -1,3 +1,5 @@
"""Training strategy implementations with factory pattern."""
import copy import copy
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -17,9 +19,10 @@ def unwrap_model(model: nn.Module) -> nn.Module:
def create_ref_model(model: nn.Module) -> nn.Module: def create_ref_model(model: nn.Module) -> nn.Module:
""" """Create a reference model for DPO/GRPO training.
Create a reference model for DPO/GRPO training.
Handles DDP-wrapped models safely. Handles DDP-wrapped models safely by unwrapping first,
then creating a deep copy with frozen gradients.
""" """
original_model = unwrap_model(model) original_model = unwrap_model(model)
ref_model = copy.deepcopy(original_model) ref_model = copy.deepcopy(original_model)
@ -29,16 +32,17 @@ def create_ref_model(model: nn.Module) -> nn.Module:
def move_to_device(batch: Dict[str, Tensor], device: str) -> Any: def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
"""Move batch tensors to specified device with non-blocking transfer."""
return {key: value.to(device, non_blocking=True) for key, value in batch.items()} return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
def get_logprobs( def get_logprobs(
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
input_ids: Tensor, input_ids: Tensor,
mask: Tensor, mask: Tensor,
reduction: str, reduction: str,
): ):
""" """Compute token-wise log probabilities from model outputs.
Compute token-wise log probabilities from model outputs.
Args: Args:
model: The language model model: The language model
@ -49,7 +53,6 @@ def get_logprobs(
Returns: Returns:
Log probabilities with reduction applied over sequence dimension Log probabilities with reduction applied over sequence dimension
""" """
# reduction on seq_len dim
allowed_reductions = ["mean", "sum", "none"] allowed_reductions = ["mean", "sum", "none"]
if reduction not in allowed_reductions: if reduction not in allowed_reductions:
raise ValueError(f"reduction must be one of {allowed_reductions}, got '{reduction}'") raise ValueError(f"reduction must be one of {allowed_reductions}, got '{reduction}'")
@ -60,7 +63,6 @@ def get_logprobs(
logits = model(input_ids[:, :-1], mask[:, :-1])["logits"] logits = model(input_ids[:, :-1], mask[:, :-1])["logits"]
log_probs = torch.log_softmax(logits.float(), dim=-1) log_probs = torch.log_softmax(logits.float(), dim=-1)
# [batch_size, seq_len - 1]
token_logprobs = torch.gather( token_logprobs = torch.gather(
log_probs, log_probs,
dim=-1, dim=-1,
@ -76,20 +78,112 @@ def get_logprobs(
class BaseStrategy(ABC): class BaseStrategy(ABC):
"""Abstract base class for training strategies."""
def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str): def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str):
self.model = model self.model = model
self.device = device self.device = device
@abstractmethod @abstractmethod
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
"""Compute loss for the given batch.
Args:
batch: Dictionary containing batch tensors
Returns:
Computed loss tensor
"""
raise NotImplementedError raise NotImplementedError
def __call__(self, batch: Dict[str, Tensor]) -> Tensor: def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
"""Allow calling strategy directly as a callable."""
return self.compute_loss(batch) return self.compute_loss(batch)
class StrategyFactory:
    """Factory class for creating training strategy instances.

    Strategy classes register themselves with the :meth:`register`
    decorator; the default strategies (seq, sft, dpo, grpo) are registered
    automatically at class definition time.

    Example usage:

        @StrategyFactory.register("custom")
        class CustomStrategy(BaseStrategy):
            ...

        strategy = StrategyFactory.create(model, "custom", device)
    """

    # Built-in strategies expected to ship with the library.  Used only to
    # distinguish "known but missing" (NotImplementedError) from "unknown"
    # (ValueError) in create().
    SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"})
    STRATEGY_MAP: Dict[str, type] = {}

    @classmethod
    def register(cls, name: str):
        """Decorator registering a BaseStrategy subclass under *name*.

        Args:
            name: Registration name for the strategy

        Returns:
            Decorator function that registers the strategy class

        Raises:
            TypeError: If the decorated class is not a BaseStrategy subclass
        """
        def decorator(strategy_cls: type) -> type:
            if not issubclass(strategy_cls, BaseStrategy):
                raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
            cls.STRATEGY_MAP[name] = strategy_cls
            return strategy_cls
        return decorator

    @classmethod
    def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy:
        """Create a strategy instance based on training type.

        Args:
            model: Model instance for the strategy
            train_type: Registered strategy name (e.g. "seq", "sft", "dpo",
                "grpo", or any custom name added via ``register``)
            device: Device to run the strategy on
            **kwargs: Additional arguments passed to strategy constructor

        Returns:
            Strategy instance

        Raises:
            NotImplementedError: If *train_type* is a known built-in strategy
                whose class has not been registered yet
            ValueError: If *train_type* is unknown
        """
        # Consult the registry first so custom strategies added via
        # register() are usable, as the class docstring promises.  The
        # previous implementation gated on SUPPORTED_STRATEGIES before the
        # registry, which rejected every custom registration with ValueError.
        if train_type in cls.STRATEGY_MAP:
            return cls.STRATEGY_MAP[train_type](model, device, **kwargs)
        if train_type in cls.SUPPORTED_STRATEGIES:
            raise NotImplementedError(
                f"Strategy '{train_type}' is supported but not yet implemented."
            )
        raise ValueError(
            f"Unknown training strategy: '{train_type}'. "
            f"Supported strategies: {sorted(cls.SUPPORTED_STRATEGIES | set(cls.STRATEGY_MAP))}"
        )

    @classmethod
    def available_strategies(cls) -> list:
        """Return list of registered strategy names."""
        return list(cls.STRATEGY_MAP.keys())
# ============== Strategy Classes ==============
# All strategies are registered at class definition time using the decorator
@StrategyFactory.register("seq")
class SEQStrategy(BaseStrategy): class SEQStrategy(BaseStrategy):
def __init__(self, model, device, label_smoothing): """Standard next-token prediction training strategy.
Computes cross-entropy loss for next token prediction.
"""
    def __init__(self, model, device, label_smoothing: float = 0.0):
        """Initialize the next-token-prediction strategy.

        Args:
            model: Model used to produce logits (forwarded to BaseStrategy).
            device: Device identifier the strategy runs on.
            label_smoothing: Smoothing factor forwarded to cross-entropy;
                0.0 disables smoothing.
        """
        super().__init__(model, device)
        self.label_smoothing = label_smoothing
@ -100,14 +194,21 @@ class SEQStrategy(BaseStrategy):
loss = F.cross_entropy( loss = F.cross_entropy(
input=logits.flatten(0, 1).float(), input=logits.flatten(0, 1).float(),
target=target_ids.flatten() target=target_ids.flatten(),
label_smoothing=self.label_smoothing
) )
return loss return loss
@StrategyFactory.register("sft")
class SFTStrategy(BaseStrategy): class SFTStrategy(BaseStrategy):
def __init__(self, model, device, label_smoothing): """Supervised Fine-tuning strategy with loss masking.
Applies cross-entropy loss only to tokens where loss_mask is True.
"""
    def __init__(self, model, device, label_smoothing: float = 0.0):
        """Initialize the supervised fine-tuning strategy.

        Args:
            model: Model used to produce logits (forwarded to BaseStrategy).
            device: Device identifier the strategy runs on.
            label_smoothing: Smoothing factor forwarded to cross-entropy;
                0.0 disables smoothing.
        """
        super().__init__(model, device)
        self.label_smoothing = label_smoothing
@ -122,19 +223,27 @@ class SFTStrategy(BaseStrategy):
loss = F.cross_entropy( loss = F.cross_entropy(
input=logits.flatten(0, 1).float(), input=logits.flatten(0, 1).float(),
target=target_ids.flatten(), target=target_ids.flatten(),
ignore_index=ignore_index ignore_index=ignore_index,
label_smoothing=self.label_smoothing
) )
return loss return loss
@StrategyFactory.register("dpo")
class DPOStrategy(BaseStrategy): class DPOStrategy(BaseStrategy):
"""Direct Preference Optimization strategy.
Implements the DPO loss from the paper "Direct Preference Optimization".
Uses a reference model to compute KL divergence penalty.
"""
def __init__( def __init__(
self, self,
model: nn.Module, model: nn.Module,
device: str, device: str,
beta: float, beta: float = 0.1,
reduction: str, reduction: str = "mean",
): ):
super().__init__(model, device) super().__init__(model, device)
self.ref_model = create_ref_model(model) self.ref_model = create_ref_model(model)
@ -168,16 +277,21 @@ class DPOStrategy(BaseStrategy):
return dpo_loss return dpo_loss
@StrategyFactory.register("grpo")
class GRPOStrategy(BaseStrategy): class GRPOStrategy(BaseStrategy):
"""Group Relative Policy Optimization strategy.
Implements GRPO with clipping and KL penalty.
"""
def __init__( def __init__(
self, self,
model: nn.Module, model: nn.Module,
device: str, device: str,
clip_eps: float, clip_eps: float = 0.2,
kl_coef: float, kl_coef: float = 0.01,
group_size: int, group_size: int = 4,
reduction: str, reduction: str = "mean",
): ):
super().__init__(model, device) super().__init__(model, device)
self.ref_model = create_ref_model(model) self.ref_model = create_ref_model(model)
@ -209,16 +323,14 @@ class GRPOStrategy(BaseStrategy):
log_probs_ref = get_logprobs(self.ref_model, full_sequences, full_masks, self.reduction) log_probs_ref = get_logprobs(self.ref_model, full_sequences, full_masks, self.reduction)
log_probs_ref = log_probs_ref.view(batch_size, group_size) log_probs_ref = log_probs_ref.view(batch_size, group_size)
# Compute advantages from rewards # Compute advantages from rewards with normalization
eps = torch.finfo(log_probs_policy.dtype).eps eps = torch.finfo(log_probs_policy.dtype).eps
mean = rewards.mean(dim=-1, keepdim=True) mean = rewards.mean(dim=-1, keepdim=True)
std = rewards.std(dim=-1, keepdim=True) std = rewards.std(dim=-1, keepdim=True)
advantages = (rewards - mean) / (std + eps) advantages = (rewards - mean) / (std + eps)
# log_ratio = log_probs_policy - log_probs_old # PPO-style clipped surrogate objective
# ratio = torch.exp(log_ratio) ratio = torch.exp(0) # Off-policy: policy_model = old_model
# off policy: policy_model = old_model, then ratio = 1
ratio = torch.exp(0)
surr1 = ratio * advantages surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
@ -227,36 +339,3 @@ class GRPOStrategy(BaseStrategy):
total_loss = policy_loss + kl_penalty total_loss = policy_loss + kl_penalty
return total_loss return total_loss
class StrategyFactory:
def load(model, train_type, device, **kwargs):
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
"seq": lambda: SEQStrategy(
model,
device,
kwargs.get("label_smoothing", 0.0)
),
"sft": lambda: SFTStrategy(
model,
device,
kwargs.get("label_smoothing", 0.0)
),
"dpo": lambda: DPOStrategy(
model,
device,
kwargs.get("dpo_beta"),
kwargs.get("reduction", "mean")
),
"grpo": lambda: GRPOStrategy(
model,
device,
kwargs.get("grpo_clip_eps"),
kwargs.get("grpo_kl_coef"),
kwargs.get("grpo_group_size"),
kwargs.get("reduction", "mean")
)
}
strategy = train_strategy[train_type]()
return strategy