refactor: 统一并增强项目中的工厂模式实现

This commit is contained in:
ViperEkura 2026-03-30 01:33:14 +08:00
parent 60f4df95bd
commit 3e33c14376
9 changed files with 550 additions and 157 deletions

View File

@ -1,16 +1,26 @@
from khaosz.config.model_config import ModelConfig
from khaosz.config.param_config import BaseModelIO, ModelParameter
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
from khaosz.config.schedule_config import (
ScheduleConfig,
CosineScheduleConfig,
SGDRScheduleConfig,
ScheduleConfigFactory
)
from khaosz.config.train_config import TrainConfig
__all__ = [
# Base I/O
"BaseModelIO",
"ModelParameter",
# Model configuration
"ModelConfig",
"TrainConfig",
# Schedule configuration
"ScheduleConfig",
"CosineScheduleConfig",
"SGDRScheduleConfig",
"ScheduleConfigFactory",
]

View File

@ -1,10 +1,15 @@
from typing import Any, Dict
from typing import Any, Dict, Type
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
@dataclass
class ScheduleConfig(ABC):
"""Base configuration class for learning rate schedulers.
Provides common validation and interface for all schedule types.
"""
schedule_type: str = field(
default="cosine",
metadata={
@ -23,6 +28,7 @@ class ScheduleConfig(ABC):
@abstractmethod
def get_kwargs(self) -> Dict[str, Any]:
"""Get configuration kwargs for scheduler creation."""
raise NotImplementedError
def validate(self) -> None:
@ -35,6 +41,8 @@ class ScheduleConfig(ABC):
@dataclass
class CosineScheduleConfig(ScheduleConfig):
"""Cosine annealing learning rate schedule configuration."""
total_steps: int = field(
default=None,
metadata={"help": "Total training steps for cosine schedule."}
@ -63,6 +71,8 @@ class CosineScheduleConfig(ScheduleConfig):
@dataclass
class SGDRScheduleConfig(ScheduleConfig):
"""Stochastic Gradient Descent with Warm Restarts schedule configuration."""
cycle_length: int = field(
default=1000,
metadata={"help": "Length of the first cycle in steps."}
@ -91,3 +101,50 @@ class SGDRScheduleConfig(ScheduleConfig):
raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
if self.t_mult < 1:
raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
class ScheduleConfigFactory:
    """Factory for building ScheduleConfig objects from a type name.

    Configs can be constructed directly or through the factory:

        config = CosineScheduleConfig(total_steps=10000)          # direct
        config = ScheduleConfigFactory.create("cosine", total_steps=10000)
    """

    # Registry mapping type name -> config class.
    CONFIG_MAP: Dict[str, Type[ScheduleConfig]] = {
        "cosine": CosineScheduleConfig,
        "sgdr": SGDRScheduleConfig,
    }

    @classmethod
    def create(cls, schedule_type: str, **kwargs) -> ScheduleConfig:
        """Instantiate the config registered under ``schedule_type``.

        Args:
            schedule_type: Type of schedule ("cosine", "sgdr")
            **kwargs: Forwarded to the config constructor

        Returns:
            ScheduleConfig instance

        Raises:
            ValueError: If schedule_type is not supported
        """
        config_cls = cls.CONFIG_MAP.get(schedule_type)
        if config_cls is None:
            raise ValueError(
                f"Unknown schedule type: '{schedule_type}'. "
                f"Supported types: {sorted(cls.CONFIG_MAP.keys())}"
            )
        return config_cls(**kwargs)

    @classmethod
    def available_types(cls) -> list:
        """Return list of available schedule type names."""
        return [name for name in cls.CONFIG_MAP]

View File

@ -5,20 +5,31 @@ from khaosz.data.dataset import (
SFTDataset,
GRPODataset,
MultiSegmentFetcher,
DatasetLoader
DatasetLoader,
DatasetFactory
)
from khaosz.data.tokenizer import BpeTokenizer
from khaosz.data.sampler import ResumableDistributedSampler
__all__ = [
# Base classes
"BaseDataset",
# Dataset implementations
"SEQDataset",
"SFTDataset",
"DPODataset",
"GRPODataset",
# Fetchers
"MultiSegmentFetcher",
# Factory (DatasetLoader is alias for backward compatibility)
"DatasetLoader",
"DatasetFactory",
# Tokenizer and sampler
"BpeTokenizer",
"ResumableDistributedSampler"
]

View File

@ -1,3 +1,5 @@
"""Dataset implementations with factory pattern for training."""
import torch
import bisect
@ -8,8 +10,13 @@ from khaosz.data.serialization import load_h5
from typing import Callable, List, Dict, Literal, Optional, Union
class BaseSegmentFetcher:
"""Fetches data segments across multiple tensor segments.
Maintains cumulative lengths for efficient range queries across
multiple discontinuous segments.
"""
def __init__(self, segments: List[Tensor]):
self.segments = segments
self.cum_lengths = []
@ -25,12 +32,21 @@ class BaseSegmentFetcher:
return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx).
Args:
begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive)
Returns:
Concatenated tensor of data in the specified range
"""
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
# fix the range index bug
# Find segment boundaries for the range
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
@ -47,6 +63,11 @@ class BaseSegmentFetcher:
class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys.
Each key corresponds to a different type of data (e.g., "sequence", "mask").
"""
def __init__(self, muti_segments: Dict):
self.muti_keys = list(muti_segments.keys())
self.muti_fetchers = {
@ -55,10 +76,21 @@ class MultiSegmentFetcher:
}
def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
len_list = [len(seg) for seg in self.muti_fetchers.values()]
return min(len_list)
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
"""Fetch data for specific keys.
Args:
begin_idx: Starting index
end_idx: Ending index
keys: Single key or list of keys to fetch
Returns:
Dictionary of tensors if multiple keys, single tensor if one key
"""
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
@ -70,23 +102,43 @@ class MultiSegmentFetcher:
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
class BaseDataset(Dataset, ABC):
"""Abstract base class for all dataset types.
Implements common functionality for window-based data fetching.
"""
def __init__(self, window_size: int, stride: int):
super().__init__()
self.segments = {}
self.window_size = window_size
self.stride = stride
self.total_samples = None
self.fetcher: Optional[MultiSegmentFetcher] = None
def load(self, load_path: str):
"""Load dataset from HDF5 file.
Args:
load_path: Path to the HDF5 data file
"""
self.segments = load_h5(load_path)
self.fetcher = MultiSegmentFetcher(self.segments)
self.total_samples = len(self.fetcher)
def get_index(self, index: int) -> int:
def get_index(self, index: int) -> tuple:
"""Calculate begin and end indices for a sample.
Args:
index: Sample index
Returns:
Tuple of (begin_idx, end_idx)
"""
assert self.total_samples > self.window_size
begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
@ -96,6 +148,10 @@ class BaseDataset(Dataset, ABC):
@abstractmethod
def __getitem__(self, index: int) -> Dict[str, Tensor]:
"""Get a single sample by index.
Must be implemented by subclasses.
"""
raise NotImplementedError
def __len__(self) -> int:
@ -105,16 +161,109 @@ class BaseDataset(Dataset, ABC):
return (self.total_samples - 1 - self.window_size) // self.stride + 1
class DatasetFactory:
    """Factory for constructing dataset instances by training type.

    New dataset types plug in through decorator-based registration; the
    default types (seq, sft, dpo, grpo) register themselves when their
    classes are defined with the decorator.

    Example usage:
        @DatasetFactory.register("custom")
        class CustomDataset(BaseDataset):
            ...

        dataset = DatasetFactory.create("custom", window_size, stride)
    """

    # Names the factory recognizes, independent of what is registered.
    SUPPORTED_TYPES = frozenset({"seq", "sft", "dpo", "grpo"})
    # Registry mapping name -> dataset class, populated by @register.
    DATASET_MAP: Dict[str, type] = {}

    @classmethod
    def register(cls, name: str):
        """Decorator that registers a dataset class under ``name``.

        Args:
            name: Registration name for the dataset type

        Returns:
            Decorator function that registers the dataset class
        """
        def decorator(dataset_cls: type) -> type:
            if not issubclass(dataset_cls, BaseDataset):
                raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
            cls.DATASET_MAP[name] = dataset_cls
            return dataset_cls
        return decorator

    @classmethod
    def create(cls, train_type: str, window_size: int, stride: int) -> BaseDataset:
        """Build a dataset instance for ``train_type``.

        Args:
            train_type: Type of training ("seq", "sft", "dpo", "grpo")
            window_size: Window size for data sampling
            stride: Stride between consecutive samples

        Returns:
            Dataset instance
        """
        if train_type not in cls.SUPPORTED_TYPES:
            raise ValueError(
                f"Unknown dataset type: '{train_type}'. "
                f"Supported types: {sorted(cls.SUPPORTED_TYPES)}"
            )
        if train_type not in cls.DATASET_MAP:
            raise NotImplementedError(
                f"Dataset type '{train_type}' is supported but not yet implemented."
            )
        return cls.DATASET_MAP[train_type](window_size, stride)

    @classmethod
    def load(cls, train_type: str, load_path: str, window_size: int, stride: Optional[int] = None) -> BaseDataset:
        """Create a dataset and load its data in one step.

        Args:
            train_type: Type of training dataset
            load_path: Path to the data file
            window_size: Window size for data sampling
            stride: Stride between consecutive samples (default: same as window_size)

        Returns:
            Loaded dataset instance
        """
        # Non-overlapping windows by default.
        effective_stride = window_size if stride is None else stride
        dataset = cls.create(train_type, window_size, effective_stride)
        dataset.load(load_path)
        return dataset

    @classmethod
    def available_types(cls) -> list:
        """Return list of registered dataset type names."""
        return [name for name in cls.DATASET_MAP]
# ============== Dataset Classes ==============
# All dataset classes are registered at class definition time using the decorator
@DatasetFactory.register("seq")
class SEQDataset(BaseDataset):
"""Dataset for sequential next-token prediction training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
self.fetcher = MultiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
def __getitem__(self, index):
# fix the range index bug
begin_idx, end_idx = self.get_index(index)
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
@ -123,10 +272,12 @@ class SEQDataset(BaseDataset):
return {"input_ids": x, "target_ids": y}
@DatasetFactory.register("sft")
class SFTDataset(BaseDataset):
"""Dataset for supervised fine-tuning with loss masking."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
self.fetcher = MultiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
@ -141,10 +292,12 @@ class SFTDataset(BaseDataset):
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
@DatasetFactory.register("dpo")
class DPODataset(BaseDataset):
"""Dataset for Direct Preference Optimization training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
self.fetcher = MultiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
@ -160,10 +313,12 @@ class DPODataset(BaseDataset):
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
@DatasetFactory.register("grpo")
class GRPODataset(BaseDataset):
"""Dataset for Group Relative Policy Optimization training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
self.fetcher = MultiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
@ -179,24 +334,5 @@ class GRPODataset(BaseDataset):
return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
class DatasetLoader:
    """Legacy loader that builds and loads a dataset by training type.

    NOTE(review): the Literal annotation omits "grpo" even though the
    router below handles it — "grpo" works at runtime but would fail
    static type checking.
    """
    @staticmethod
    def load(
        train_type: Literal["seq", "sft", "dpo"],
        load_path: str,
        window_size: int,
        stride: Optional[int] = None,
    ) -> BaseDataset:
        """Build the dataset for ``train_type`` and load ``load_path``.

        Args:
            train_type: Training type key ("seq", "sft", "dpo"; "grpo" also routed).
            load_path: Path to the data file passed to ``dataset.load``.
            window_size: Window size for data sampling.
            stride: Stride between samples; defaults to ``window_size``.

        Returns:
            Loaded BaseDataset instance.
        """
        # Default stride gives non-overlapping windows.
        if stride is None:
            stride = window_size
        # Dispatch table; each lambda captures `stride` from this scope and
        # takes window_size as its single argument.
        dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
            "seq": lambda window_size: SEQDataset(window_size, stride),
            "sft": lambda window_size: SFTDataset(window_size, stride),
            "dpo": lambda window_size: DPODataset(window_size, stride),
            "grpo": lambda window_size: GRPODataset(window_size, stride),
        }
        # An unknown train_type raises KeyError here (no friendly message).
        dataset = dataset_router[train_type](window_size)
        dataset.load(load_path)
        return dataset
# Backward compatibility alias
DatasetLoader = DatasetFactory

View File

@ -77,6 +77,7 @@ class GenerationRequest:
query: Input query (string or list of strings for batch).
history: Conversation history.
system_prompt: System prompt for the conversation.
stream: Whether to use streaming generation.
"""
top_k: int
top_p: float
@ -86,6 +87,7 @@ class GenerationRequest:
query: Union[str, List[str]]
history: Optional[Union[HistoryType, List[HistoryType]]] = None
system_prompt: Optional[str] = None
stream: bool = False
def __post_init__(self):
if not isinstance(self.top_k, int) or self.top_k < 0:
@ -233,33 +235,62 @@ class EmbeddingEncoder(EmbeddingEncoderCore):
class GeneratorFactory:
"""Factory class for creating appropriate generator instances based on request features."""
"""Factory class for creating generator instances.
Provides smart generator selection based on request characteristics:
- Streaming: Use StreamGenerator for streaming output
- Batch: Use BatchGenerator when query is a list
- Single: Use LoopGenerator for single query non-streaming
Example usage:
generator = GeneratorFactory.create_generator(parameter, request)
result = generator.generate(request)
"""
@staticmethod
def create_generator(parameter: ModelParameter, request: GenerationRequest):
def create_generator(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
"""Create a generator based on request characteristics.
Args:
parameter: Model parameters containing model, tokenizer, config
request: Generation request with query, options, etc.
Returns:
Appropriate GeneratorCore subclass instance
"""
Create a generator based on the characteristics of GenerationRequest.
# Streaming generation: check stream field first
if request.stream:
return StreamGenerator(parameter)
# Batch generation: query is a list of strings
if isinstance(request.query, list):
return BatchGenerator(parameter)
# Default: single query non-streaming
return LoopGenerator(parameter)
@staticmethod
def create_encoder(parameter: ModelParameter) -> EmbeddingEncoderCore:
"""Create an embedding encoder instance.
Args:
parameter: Model parameters
Returns:
EmbeddingEncoderCore instance
"""
return EmbeddingEncoder(parameter)
@classmethod
def create(cls, parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
"""Convenience method that delegates to create_generator.
Args:
parameter: Model parameters
request: Generation request
Returns:
Subclass instance of GeneratorCore
Generator instance
"""
# Streaming generation detection: check stream field
if request.stream:
return StreamGenerator(parameter)
# Batch generation detection: query is a list
if isinstance(request.query, list):
return BatchGenerator(parameter)
# Default return LoopGenerator
return LoopGenerator(parameter)
@staticmethod
def create_encoder(parameter: ModelParameter):
"""Create an EmbeddingEncoder instance"""
return EmbeddingEncoder(parameter)
return cls.create_generator(parameter, request)

View File

@ -134,8 +134,8 @@ def spawn_parallel_fn(
if world_size == 1:
device_ids = device_ids or [0]
deice_id = torch.device(device_type, device_ids[0])
os.environ["LOCAL_DEVICE"] = str(deice_id)
device_id = torch.device(device_type, device_ids[0])
os.environ["LOCAL_DEVICE"] = str(device_id)
func(**kwargs)
return

View File

@ -1,29 +1,33 @@
from khaosz.trainer.trainer import Trainer
from khaosz.trainer.strategy import StrategyFactory
from khaosz.trainer.schedule import SchedulerFactory
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
from khaosz.trainer.schedule import SchedulerFactory, BaseScheduler
from khaosz.trainer.train_callback import (
TrainCallback,
ProgressBarCallback,
CheckpointCallback,
TrainCallback,
GradientClippingCallback,
SchedulerCallback,
MetricLoggerCallback
CheckpointCallback,
ProgressBarCallback,
MetricLoggerCallback,
)
__all__ = [
# trainer
# Main trainer
"Trainer",
# factory
# Strategy factory
"StrategyFactory",
"SchedulerFactory",
"BaseStrategy",
# callback
"TrainCallback",
"ProgressBarCallback",
"CheckpointCallback",
# Scheduler factory
"SchedulerFactory",
"BaseScheduler",
# Callbacks
"TrainCallback",
"GradientClippingCallback",
"SchedulerCallback",
"MetricLoggerCallback"
"CheckpointCallback",
"ProgressBarCallback",
"MetricLoggerCallback",
]

View File

@ -1,20 +1,21 @@
"""Learning rate scheduler implementations with factory pattern."""
import math
from abc import abstractmethod, ABC
from typing import Any, Dict, List
from typing import Any, Dict, List, Type
from torch.optim.lr_scheduler import LRScheduler
from khaosz.config.schedule_config import ScheduleConfig
class BaseScheduler(LRScheduler, ABC):
"""
Base scheduler class for all other schedulers.
"""
"""Base scheduler class for all other schedulers."""
def __init__(self, optimizer, last_epoch: int = -1):
super().__init__(optimizer, last_epoch)
@abstractmethod
def get_lr(self) -> List[float]:
"""Calculate the current learning rate."""
raise NotImplementedError
def state_dict(self) -> Dict[str, Any]:
@ -24,10 +25,95 @@ class BaseScheduler(LRScheduler, ABC):
super().load_state_dict(state_dict)
class SchedulerFactory:
    """Factory class for creating learning rate schedulers.

    Supports decorator-based registration for extensible scheduler types.
    Also supports creation from ScheduleConfig objects.

    Example usage:
        @SchedulerFactory.register("custom")
        class CustomScheduler(BaseScheduler):
            ...

        scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)

        # Or from config
        config = CosineScheduleConfig(total_steps=10000)
        scheduler = SchedulerFactory.load(optimizer, config)
    """

    # Registry mapping type name -> scheduler class, populated by @register.
    SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {}

    @classmethod
    def register(cls, name: str):
        """Decorator to register a new scheduler class.

        Args:
            name: Registration name for the scheduler

        Returns:
            Decorator function that registers the scheduler class
        """
        def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]:
            if not issubclass(scheduler_cls, BaseScheduler):
                raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
            cls.SCHEDULER_MAP[name] = scheduler_cls
            return scheduler_cls
        return decorator

    @classmethod
    def create(cls, optimizer, schedule_type: str, **kwargs) -> BaseScheduler:
        """Create a scheduler instance by type name.

        Args:
            optimizer: PyTorch optimizer
            schedule_type: Type of scheduler ("cosine", "sgdr")
            **kwargs: Arguments passed to the scheduler constructor

        Returns:
            Scheduler instance

        Raises:
            ValueError: If schedule_type is not supported
        """
        if schedule_type not in cls.SCHEDULER_MAP:
            raise ValueError(
                f"Unknown schedule type: '{schedule_type}'. "
                f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}"
            )
        scheduler_cls = cls.SCHEDULER_MAP[schedule_type]
        return scheduler_cls(optimizer, **kwargs)

    @classmethod
    def load(cls, optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
        """Create a scheduler from a ScheduleConfig object.

        Changed from @staticmethod to @classmethod for consistency with the
        other factory methods (create/register) and with DatasetFactory.load,
        so subclasses dispatch through their own registry. Call sites
        (SchedulerFactory.load(optimizer, config)) are unchanged.

        Args:
            optimizer: PyTorch optimizer
            schedule_config: ScheduleConfig instance

        Returns:
            Scheduler instance
        """
        kwargs = schedule_config.get_kwargs()
        # get_kwargs() carries the discriminator key used for dispatch.
        schedule_type = kwargs.pop("schedule_type")
        return cls.create(optimizer, schedule_type, **kwargs)

    @classmethod
    def available_types(cls) -> list:
        """Return list of registered scheduler type names."""
        return list(cls.SCHEDULER_MAP.keys())
# ============== Scheduler Classes ==============
# All scheduler classes are registered at class definition time using the decorator
@SchedulerFactory.register("cosine")
class CosineScheduler(BaseScheduler):
"""
Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler.
"""
"""Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler."""
def __init__(
self,
@ -75,10 +161,9 @@ class CosineScheduler(BaseScheduler):
super().load_state_dict(state_dict)
@SchedulerFactory.register("sgdr")
class SGDRScheduler(BaseScheduler):
"""
SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler,
"""
"""SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler."""
def __init__(
self,
@ -142,23 +227,3 @@ class SGDRScheduler(BaseScheduler):
self.min_rate = state_dict.pop('min_rate')
self.t_mult = state_dict.pop('t_mult')
super().load_state_dict(state_dict)
class SchedulerFactory:
    """Legacy factory that maps a ScheduleConfig to a scheduler instance
    via a hard-coded if/elif dispatch (no registration mechanism)."""
    @staticmethod
    def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
        """Build the scheduler described by ``schedule_config``.

        Args:
            optimizer: Optimizer handed to the scheduler constructor.
            schedule_config: Config whose ``get_kwargs()`` supplies both the
                dispatch key ("schedule_type") and the constructor kwargs.

        Returns:
            BaseScheduler instance.

        Raises:
            ValueError: If the schedule type is not "cosine" or "sgdr".
        """
        # get_kwargs() includes the discriminator key "schedule_type";
        # pop it so only real constructor kwargs remain.
        kwargs = schedule_config.get_kwargs()
        schedule_type = kwargs.pop("schedule_type")
        if schedule_type == "cosine":
            return CosineScheduler(optimizer, **kwargs)
        elif schedule_type == "sgdr":
            return SGDRScheduler(optimizer, **kwargs)
        else:
            raise ValueError(f"Unsupported schedule type: {schedule_type}")

View File

@ -1,3 +1,5 @@
"""Training strategy implementations with factory pattern."""
import copy
import torch
import torch.nn as nn
@ -17,9 +19,10 @@ def unwrap_model(model: nn.Module) -> nn.Module:
def create_ref_model(model: nn.Module) -> nn.Module:
"""
Create a reference model for DPO/GRPO training.
Handles DDP-wrapped models safely.
"""Create a reference model for DPO/GRPO training.
Handles DDP-wrapped models safely by unwrapping first,
then creating a deep copy with frozen gradients.
"""
original_model = unwrap_model(model)
ref_model = copy.deepcopy(original_model)
@ -28,17 +31,18 @@ def create_ref_model(model: nn.Module) -> nn.Module:
return ref_model
def move_to_device(batch:Dict[str, Tensor], device: str) -> Any:
def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
"""Move batch tensors to specified device with non-blocking transfer."""
return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
def get_logprobs(
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
input_ids: Tensor,
mask: Tensor,
reduction: str,
):
"""
Compute token-wise log probabilities from model outputs.
"""Compute token-wise log probabilities from model outputs.
Args:
model: The language model
@ -49,7 +53,6 @@ def get_logprobs(
Returns:
Log probabilities with reduction applied over sequence dimension
"""
# reduction on seq_len dim
allowed_reductions = ["mean", "sum", "none"]
if reduction not in allowed_reductions:
raise ValueError(f"reduction must be one of {allowed_reductions}, got '{reduction}'")
@ -60,7 +63,6 @@ def get_logprobs(
logits = model(input_ids[:, :-1], mask[:, :-1])["logits"]
log_probs = torch.log_softmax(logits.float(), dim=-1)
# [batch_size, seq_len - 1]
token_logprobs = torch.gather(
log_probs,
dim=-1,
@ -76,20 +78,112 @@ def get_logprobs(
class BaseStrategy(ABC):
    """Common interface for training strategies.

    A strategy owns a model and a target device and exposes a single
    loss-computation entry point; instances are also directly callable.
    """

    def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str):
        # The model may be a plain nn.Module or any callable returning a
        # dict of output tensors.
        self.model = model
        self.device = device

    @abstractmethod
    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the training loss for ``batch``.

        Args:
            batch: Dictionary containing batch tensors

        Returns:
            Computed loss tensor
        """
        raise NotImplementedError

    def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
        """Convenience: ``strategy(batch)`` delegates to compute_loss."""
        return self.compute_loss(batch)
class StrategyFactory:
    """Factory building training strategy instances from a type name.

    Strategy classes register themselves via the ``register`` decorator;
    the default strategies (seq, sft, dpo, grpo) do so at class
    definition time.

    Example usage:
        @StrategyFactory.register("custom")
        class CustomStrategy(BaseStrategy):
            ...

        strategy = StrategyFactory.create(model, "custom", device)
    """

    # Names the factory recognizes, independent of what is registered.
    SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"})
    # Registry mapping name -> strategy class, populated by @register.
    STRATEGY_MAP: Dict[str, type] = {}

    @classmethod
    def register(cls, name: str):
        """Decorator that registers a strategy class under ``name``.

        Args:
            name: Registration name for the strategy

        Returns:
            Decorator function that registers the strategy class
        """
        def decorator(strategy_cls: type) -> type:
            if not issubclass(strategy_cls, BaseStrategy):
                raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
            cls.STRATEGY_MAP[name] = strategy_cls
            return strategy_cls
        return decorator

    @classmethod
    def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy:
        """Build the strategy registered under ``train_type``.

        Args:
            model: Model instance for the strategy
            train_type: Type of training ("seq", "sft", "dpo", "grpo")
            device: Device to run the strategy on
            **kwargs: Additional arguments passed to strategy constructor

        Returns:
            Strategy instance

        Raises:
            ValueError: If train_type is not supported
            NotImplementedError: If train_type is in supported list but not implemented
        """
        if train_type not in cls.SUPPORTED_STRATEGIES:
            raise ValueError(
                f"Unknown training strategy: '{train_type}'. "
                f"Supported strategies: {sorted(cls.SUPPORTED_STRATEGIES)}"
            )
        if train_type not in cls.STRATEGY_MAP:
            raise NotImplementedError(
                f"Strategy '{train_type}' is supported but not yet implemented."
            )
        return cls.STRATEGY_MAP[train_type](model, device, **kwargs)

    @classmethod
    def available_strategies(cls) -> list:
        """Return list of registered strategy names."""
        return [name for name in cls.STRATEGY_MAP]
# ============== Strategy Classes ==============
# All strategies are registered at class definition time using the decorator
@StrategyFactory.register("seq")
class SEQStrategy(BaseStrategy):
def __init__(self, model, device, label_smoothing):
"""Standard next-token prediction training strategy.
Computes cross-entropy loss for next token prediction.
"""
def __init__(self, model, device, label_smoothing: float = 0.0):
super().__init__(model, device)
self.label_smoothing = label_smoothing
@ -99,15 +193,22 @@ class SEQStrategy(BaseStrategy):
logits = self.model(input_ids=input_ids)["logits"]
loss = F.cross_entropy(
input=logits.flatten(0, 1).float(),
target=target_ids.flatten()
input=logits.flatten(0, 1).float(),
target=target_ids.flatten(),
label_smoothing=self.label_smoothing
)
return loss
@StrategyFactory.register("sft")
class SFTStrategy(BaseStrategy):
def __init__(self, model, device, label_smoothing):
"""Supervised Fine-tuning strategy with loss masking.
Applies cross-entropy loss only to tokens where loss_mask is True.
"""
def __init__(self, model, device, label_smoothing: float = 0.0):
super().__init__(model, device)
self.label_smoothing = label_smoothing
@ -122,19 +223,27 @@ class SFTStrategy(BaseStrategy):
loss = F.cross_entropy(
input=logits.flatten(0, 1).float(),
target=target_ids.flatten(),
ignore_index=ignore_index
ignore_index=ignore_index,
label_smoothing=self.label_smoothing
)
return loss
@StrategyFactory.register("dpo")
class DPOStrategy(BaseStrategy):
"""Direct Preference Optimization strategy.
Implements the DPO loss from the paper "Direct Preference Optimization".
Uses a reference model to compute KL divergence penalty.
"""
def __init__(
self,
model: nn.Module,
device: str,
beta: float,
reduction: str,
beta: float = 0.1,
reduction: str = "mean",
):
super().__init__(model, device)
self.ref_model = create_ref_model(model)
@ -168,16 +277,21 @@ class DPOStrategy(BaseStrategy):
return dpo_loss
@StrategyFactory.register("grpo")
class GRPOStrategy(BaseStrategy):
"""Group Relative Policy Optimization strategy.
Implements GRPO with clipping and KL penalty.
"""
def __init__(
self,
model: nn.Module,
device: str,
clip_eps: float,
kl_coef: float,
group_size: int,
reduction: str,
clip_eps: float = 0.2,
kl_coef: float = 0.01,
group_size: int = 4,
reduction: str = "mean",
):
super().__init__(model, device)
self.ref_model = create_ref_model(model)
@ -209,16 +323,14 @@ class GRPOStrategy(BaseStrategy):
log_probs_ref = get_logprobs(self.ref_model, full_sequences, full_masks, self.reduction)
log_probs_ref = log_probs_ref.view(batch_size, group_size)
# Compute advantages from rewards
# Compute advantages from rewards with normalization
eps = torch.finfo(log_probs_policy.dtype).eps
mean = rewards.mean(dim=-1, keepdim=True)
std = rewards.std(dim=-1, keepdim=True)
advantages = (rewards - mean) / (std + eps)
# log_ratio = log_probs_policy - log_probs_old
# ratio = torch.exp(log_ratio)
# off policy: policy_model = old_model, then ratio = 1
ratio = torch.exp(0)
# PPO-style clipped surrogate objective
ratio = torch.exp(0) # Off-policy: policy_model = old_model
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
@ -227,36 +339,3 @@ class GRPOStrategy(BaseStrategy):
total_loss = policy_loss + kl_penalty
return total_loss
class StrategyFactory:
    """Legacy strategy factory with a hard-coded lambda dispatch table.

    NOTE(review): ``load`` lacks @staticmethod; it still works when called
    as ``StrategyFactory.load(model, ...)`` because Python 3 exposes it as
    a plain function on the class, but calling it on an instance would
    mis-bind ``model`` to that instance.
    """
    def load(model, train_type, device, **kwargs):
        """Build the strategy for ``train_type``.

        Args:
            model: Model the strategy wraps.
            train_type: One of "seq", "sft", "dpo", "grpo".
            device: Device string for the strategy.
            **kwargs: Strategy hyperparameters (label_smoothing, dpo_beta,
                reduction, grpo_clip_eps, grpo_kl_coef, grpo_group_size).

        Returns:
            BaseStrategy instance.
        """
        # Lazy dispatch table: only the selected lambda executes, so only
        # the chosen strategy class is instantiated.
        # NOTE(review): the dpo/grpo entries use kwargs.get(...) with no
        # default, passing None when the key is absent — presumably callers
        # always supply these; verify before relying on it.
        train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
            "seq": lambda: SEQStrategy(
                model,
                device,
                kwargs.get("label_smoothing", 0.0)
            ),
            "sft": lambda: SFTStrategy(
                model,
                device,
                kwargs.get("label_smoothing", 0.0)
            ),
            "dpo": lambda: DPOStrategy(
                model,
                device,
                kwargs.get("dpo_beta"),
                kwargs.get("reduction", "mean")
            ),
            "grpo": lambda: GRPOStrategy(
                model,
                device,
                kwargs.get("grpo_clip_eps"),
                kwargs.get("grpo_kl_coef"),
                kwargs.get("grpo_group_size"),
                kwargs.get("reduction", "mean")
            )
        }
        # Unknown train_type raises KeyError here (no friendly message).
        strategy = train_strategy[train_type]()
        return strategy