reafactor: 统一并增强项目中的工厂模式实现
This commit is contained in:
parent
60f4df95bd
commit
3e33c14376
|
|
@ -1,16 +1,26 @@
|
||||||
from khaosz.config.model_config import ModelConfig
|
from khaosz.config.model_config import ModelConfig
|
||||||
from khaosz.config.param_config import BaseModelIO, ModelParameter
|
from khaosz.config.param_config import BaseModelIO, ModelParameter
|
||||||
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
|
from khaosz.config.schedule_config import (
|
||||||
|
ScheduleConfig,
|
||||||
|
CosineScheduleConfig,
|
||||||
|
SGDRScheduleConfig,
|
||||||
|
ScheduleConfigFactory
|
||||||
|
)
|
||||||
from khaosz.config.train_config import TrainConfig
|
from khaosz.config.train_config import TrainConfig
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Base I/O
|
||||||
"BaseModelIO",
|
"BaseModelIO",
|
||||||
"ModelParameter",
|
"ModelParameter",
|
||||||
|
|
||||||
|
# Model configuration
|
||||||
"ModelConfig",
|
"ModelConfig",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
|
|
||||||
|
# Schedule configuration
|
||||||
"ScheduleConfig",
|
"ScheduleConfig",
|
||||||
"CosineScheduleConfig",
|
"CosineScheduleConfig",
|
||||||
"SGDRScheduleConfig",
|
"SGDRScheduleConfig",
|
||||||
|
"ScheduleConfigFactory",
|
||||||
]
|
]
|
||||||
|
|
@ -1,10 +1,15 @@
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, Type
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ScheduleConfig(ABC):
|
class ScheduleConfig(ABC):
|
||||||
|
"""Base configuration class for learning rate schedulers.
|
||||||
|
|
||||||
|
Provides common validation and interface for all schedule types.
|
||||||
|
"""
|
||||||
|
|
||||||
schedule_type: str = field(
|
schedule_type: str = field(
|
||||||
default="cosine",
|
default="cosine",
|
||||||
metadata={
|
metadata={
|
||||||
|
|
@ -23,6 +28,7 @@ class ScheduleConfig(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_kwargs(self) -> Dict[str, Any]:
|
def get_kwargs(self) -> Dict[str, Any]:
|
||||||
|
"""Get configuration kwargs for scheduler creation."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
|
|
@ -35,6 +41,8 @@ class ScheduleConfig(ABC):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CosineScheduleConfig(ScheduleConfig):
|
class CosineScheduleConfig(ScheduleConfig):
|
||||||
|
"""Cosine annealing learning rate schedule configuration."""
|
||||||
|
|
||||||
total_steps: int = field(
|
total_steps: int = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Total training steps for cosine schedule."}
|
metadata={"help": "Total training steps for cosine schedule."}
|
||||||
|
|
@ -63,6 +71,8 @@ class CosineScheduleConfig(ScheduleConfig):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SGDRScheduleConfig(ScheduleConfig):
|
class SGDRScheduleConfig(ScheduleConfig):
|
||||||
|
"""Stochastic Gradient Descent with Warm Restarts schedule configuration."""
|
||||||
|
|
||||||
cycle_length: int = field(
|
cycle_length: int = field(
|
||||||
default=1000,
|
default=1000,
|
||||||
metadata={"help": "Length of the first cycle in steps."}
|
metadata={"help": "Length of the first cycle in steps."}
|
||||||
|
|
@ -91,3 +101,50 @@ class SGDRScheduleConfig(ScheduleConfig):
|
||||||
raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
|
raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
|
||||||
if self.t_mult < 1:
|
if self.t_mult < 1:
|
||||||
raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
|
raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduleConfigFactory:
|
||||||
|
"""Factory class for creating ScheduleConfig instances.
|
||||||
|
|
||||||
|
Supports both direct instantiation and factory creation methods.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
# Direct creation
|
||||||
|
config = CosineScheduleConfig(total_steps=10000)
|
||||||
|
|
||||||
|
# Factory method
|
||||||
|
config = ScheduleConfigFactory.create("cosine", total_steps=10000)
|
||||||
|
"""
|
||||||
|
|
||||||
|
CONFIG_MAP: Dict[str, Type[ScheduleConfig]] = {
|
||||||
|
"cosine": CosineScheduleConfig,
|
||||||
|
"sgdr": SGDRScheduleConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, schedule_type: str, **kwargs) -> ScheduleConfig:
|
||||||
|
"""Create a schedule config instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
schedule_type: Type of schedule ("cosine", "sgdr")
|
||||||
|
**kwargs: Arguments passed to the config constructor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ScheduleConfig instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If schedule_type is not supported
|
||||||
|
"""
|
||||||
|
if schedule_type not in cls.CONFIG_MAP:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown schedule type: '{schedule_type}'. "
|
||||||
|
f"Supported types: {sorted(cls.CONFIG_MAP.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
config_cls = cls.CONFIG_MAP[schedule_type]
|
||||||
|
return config_cls(**kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def available_types(cls) -> list:
|
||||||
|
"""Return list of available schedule type names."""
|
||||||
|
return list(cls.CONFIG_MAP.keys())
|
||||||
|
|
@ -5,20 +5,31 @@ from khaosz.data.dataset import (
|
||||||
SFTDataset,
|
SFTDataset,
|
||||||
GRPODataset,
|
GRPODataset,
|
||||||
MultiSegmentFetcher,
|
MultiSegmentFetcher,
|
||||||
DatasetLoader
|
DatasetLoader,
|
||||||
|
DatasetFactory
|
||||||
)
|
)
|
||||||
|
|
||||||
from khaosz.data.tokenizer import BpeTokenizer
|
from khaosz.data.tokenizer import BpeTokenizer
|
||||||
from khaosz.data.sampler import ResumableDistributedSampler
|
from khaosz.data.sampler import ResumableDistributedSampler
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Base classes
|
||||||
"BaseDataset",
|
"BaseDataset",
|
||||||
|
|
||||||
|
# Dataset implementations
|
||||||
"SEQDataset",
|
"SEQDataset",
|
||||||
"SFTDataset",
|
"SFTDataset",
|
||||||
"DPODataset",
|
"DPODataset",
|
||||||
"GRPODataset",
|
"GRPODataset",
|
||||||
|
|
||||||
|
# Fetchers
|
||||||
"MultiSegmentFetcher",
|
"MultiSegmentFetcher",
|
||||||
|
|
||||||
|
# Factory (DatasetLoader is alias for backward compatibility)
|
||||||
"DatasetLoader",
|
"DatasetLoader",
|
||||||
|
"DatasetFactory",
|
||||||
|
|
||||||
|
# Tokenizer and sampler
|
||||||
"BpeTokenizer",
|
"BpeTokenizer",
|
||||||
"ResumableDistributedSampler"
|
"ResumableDistributedSampler"
|
||||||
]
|
]
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
"""Dataset implementations with factory pattern for training."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import bisect
|
import bisect
|
||||||
|
|
||||||
|
|
@ -8,8 +10,13 @@ from khaosz.data.serialization import load_h5
|
||||||
from typing import Callable, List, Dict, Literal, Optional, Union
|
from typing import Callable, List, Dict, Literal, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class BaseSegmentFetcher:
|
class BaseSegmentFetcher:
|
||||||
|
"""Fetches data segments across multiple tensor segments.
|
||||||
|
|
||||||
|
Maintains cumulative lengths for efficient range queries across
|
||||||
|
multiple discontinuous segments.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, segments: List[Tensor]):
|
def __init__(self, segments: List[Tensor]):
|
||||||
self.segments = segments
|
self.segments = segments
|
||||||
self.cum_lengths = []
|
self.cum_lengths = []
|
||||||
|
|
@ -25,12 +32,21 @@ class BaseSegmentFetcher:
|
||||||
return self.total_length
|
return self.total_length
|
||||||
|
|
||||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||||
|
"""Fetch data in the range [begin_idx, end_idx).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
begin_idx: Starting index (inclusive)
|
||||||
|
end_idx: Ending index (exclusive)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Concatenated tensor of data in the specified range
|
||||||
|
"""
|
||||||
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
|
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
|
||||||
raise ValueError("begin_idx or end_idx out of bounds")
|
raise ValueError("begin_idx or end_idx out of bounds")
|
||||||
if begin_idx >= end_idx:
|
if begin_idx >= end_idx:
|
||||||
return torch.tensor([], dtype=torch.long)
|
return torch.tensor([], dtype=torch.long)
|
||||||
|
|
||||||
# fix the range index bug
|
# Find segment boundaries for the range
|
||||||
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
|
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
|
||||||
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
|
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
|
||||||
|
|
||||||
|
|
@ -47,6 +63,11 @@ class BaseSegmentFetcher:
|
||||||
|
|
||||||
|
|
||||||
class MultiSegmentFetcher:
|
class MultiSegmentFetcher:
|
||||||
|
"""Manages multiple segment fetchers for different data keys.
|
||||||
|
|
||||||
|
Each key corresponds to a different type of data (e.g., "sequence", "mask").
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, muti_segments: Dict):
|
def __init__(self, muti_segments: Dict):
|
||||||
self.muti_keys = list(muti_segments.keys())
|
self.muti_keys = list(muti_segments.keys())
|
||||||
self.muti_fetchers = {
|
self.muti_fetchers = {
|
||||||
|
|
@ -55,10 +76,21 @@ class MultiSegmentFetcher:
|
||||||
}
|
}
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
"""Returns the minimum length across all fetchers."""
|
||||||
len_list = [len(seg) for seg in self.muti_fetchers.values()]
|
len_list = [len(seg) for seg in self.muti_fetchers.values()]
|
||||||
return min(len_list)
|
return min(len_list)
|
||||||
|
|
||||||
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
|
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
|
||||||
|
"""Fetch data for specific keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
begin_idx: Starting index
|
||||||
|
end_idx: Ending index
|
||||||
|
keys: Single key or list of keys to fetch
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of tensors if multiple keys, single tensor if one key
|
||||||
|
"""
|
||||||
fetch_dict = {}
|
fetch_dict = {}
|
||||||
keys = [keys] if isinstance(keys, str) else keys
|
keys = [keys] if isinstance(keys, str) else keys
|
||||||
|
|
||||||
|
|
@ -70,23 +102,43 @@ class MultiSegmentFetcher:
|
||||||
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
||||||
|
|
||||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
||||||
|
"""Fetch all keys."""
|
||||||
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
|
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
|
||||||
|
|
||||||
|
|
||||||
class BaseDataset(Dataset, ABC):
|
class BaseDataset(Dataset, ABC):
|
||||||
|
"""Abstract base class for all dataset types.
|
||||||
|
|
||||||
|
Implements common functionality for window-based data fetching.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.segments = {}
|
self.segments = {}
|
||||||
self.window_size = window_size
|
self.window_size = window_size
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.total_samples = None
|
self.total_samples = None
|
||||||
|
self.fetcher: Optional[MultiSegmentFetcher] = None
|
||||||
|
|
||||||
def load(self, load_path: str):
|
def load(self, load_path: str):
|
||||||
|
"""Load dataset from HDF5 file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
load_path: Path to the HDF5 data file
|
||||||
|
"""
|
||||||
self.segments = load_h5(load_path)
|
self.segments = load_h5(load_path)
|
||||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||||
self.total_samples = len(self.fetcher)
|
self.total_samples = len(self.fetcher)
|
||||||
|
|
||||||
def get_index(self, index: int) -> int:
|
def get_index(self, index: int) -> tuple:
|
||||||
|
"""Calculate begin and end indices for a sample.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index: Sample index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (begin_idx, end_idx)
|
||||||
|
"""
|
||||||
assert self.total_samples > self.window_size
|
assert self.total_samples > self.window_size
|
||||||
|
|
||||||
begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
|
begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
|
||||||
|
|
@ -96,6 +148,10 @@ class BaseDataset(Dataset, ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||||
|
"""Get a single sample by index.
|
||||||
|
|
||||||
|
Must be implemented by subclasses.
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
|
@ -105,16 +161,109 @@ class BaseDataset(Dataset, ABC):
|
||||||
return (self.total_samples - 1 - self.window_size) // self.stride + 1
|
return (self.total_samples - 1 - self.window_size) // self.stride + 1
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetFactory:
|
||||||
|
"""Factory class for creating dataset instances.
|
||||||
|
|
||||||
|
Supports decorator-based registration for extensible dataset types.
|
||||||
|
All default dataset types (seq, sft, dpo, grpo) are registered automatically
|
||||||
|
when their classes are defined with the decorator.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
@DatasetFactory.register("custom")
|
||||||
|
class CustomDataset(BaseDataset):
|
||||||
|
...
|
||||||
|
|
||||||
|
dataset = DatasetFactory.create("custom", window_size, stride)
|
||||||
|
"""
|
||||||
|
|
||||||
|
SUPPORTED_TYPES = frozenset({"seq", "sft", "dpo", "grpo"})
|
||||||
|
DATASET_MAP: Dict[str, type] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, name: str):
|
||||||
|
"""Decorator to register a new dataset class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Registration name for the dataset type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decorator function that registers the dataset class
|
||||||
|
"""
|
||||||
|
def decorator(dataset_cls: type) -> type:
|
||||||
|
if not issubclass(dataset_cls, BaseDataset):
|
||||||
|
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
||||||
|
cls.DATASET_MAP[name] = dataset_cls
|
||||||
|
return dataset_cls
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, train_type: str, window_size: int, stride: int) -> BaseDataset:
|
||||||
|
"""Create a dataset instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
||||||
|
window_size: Window size for data sampling
|
||||||
|
stride: Stride between consecutive samples
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset instance
|
||||||
|
"""
|
||||||
|
if train_type not in cls.SUPPORTED_TYPES:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown dataset type: '{train_type}'. "
|
||||||
|
f"Supported types: {sorted(cls.SUPPORTED_TYPES)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if train_type not in cls.DATASET_MAP:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Dataset type '{train_type}' is supported but not yet implemented."
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_cls = cls.DATASET_MAP[train_type]
|
||||||
|
return dataset_cls(window_size, stride)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, train_type: str, load_path: str, window_size: int, stride: Optional[int] = None) -> BaseDataset:
|
||||||
|
"""Create and load a dataset in one step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_type: Type of training dataset
|
||||||
|
load_path: Path to the data file
|
||||||
|
window_size: Window size for data sampling
|
||||||
|
stride: Stride between consecutive samples (default: same as window_size)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loaded dataset instance
|
||||||
|
"""
|
||||||
|
if stride is None:
|
||||||
|
stride = window_size
|
||||||
|
|
||||||
|
dataset = cls.create(train_type, window_size, stride)
|
||||||
|
dataset.load(load_path)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def available_types(cls) -> list:
|
||||||
|
"""Return list of registered dataset type names."""
|
||||||
|
return list(cls.DATASET_MAP.keys())
|
||||||
|
|
||||||
|
|
||||||
|
# ============== Dataset Classes ==============
|
||||||
|
# All dataset classes are registered at class definition time using the decorator
|
||||||
|
|
||||||
|
|
||||||
|
@DatasetFactory.register("seq")
|
||||||
class SEQDataset(BaseDataset):
|
class SEQDataset(BaseDataset):
|
||||||
|
"""Dataset for sequential next-token prediction training."""
|
||||||
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||||
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
|
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
# fix the range index bug
|
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
|
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
|
||||||
|
|
@ -123,10 +272,12 @@ class SEQDataset(BaseDataset):
|
||||||
return {"input_ids": x, "target_ids": y}
|
return {"input_ids": x, "target_ids": y}
|
||||||
|
|
||||||
|
|
||||||
|
@DatasetFactory.register("sft")
|
||||||
class SFTDataset(BaseDataset):
|
class SFTDataset(BaseDataset):
|
||||||
|
"""Dataset for supervised fine-tuning with loss masking."""
|
||||||
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||||
|
|
@ -141,10 +292,12 @@ class SFTDataset(BaseDataset):
|
||||||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
||||||
|
|
||||||
|
|
||||||
|
@DatasetFactory.register("dpo")
|
||||||
class DPODataset(BaseDataset):
|
class DPODataset(BaseDataset):
|
||||||
|
"""Dataset for Direct Preference Optimization training."""
|
||||||
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||||
|
|
@ -160,10 +313,12 @@ class DPODataset(BaseDataset):
|
||||||
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
|
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
|
||||||
|
|
||||||
|
|
||||||
|
@DatasetFactory.register("grpo")
|
||||||
class GRPODataset(BaseDataset):
|
class GRPODataset(BaseDataset):
|
||||||
|
"""Dataset for Group Relative Policy Optimization training."""
|
||||||
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||||
|
|
@ -179,24 +334,5 @@ class GRPODataset(BaseDataset):
|
||||||
return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
|
return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
|
||||||
|
|
||||||
|
|
||||||
class DatasetLoader:
|
# Backward compatibility alias
|
||||||
@staticmethod
|
DatasetLoader = DatasetFactory
|
||||||
def load(
|
|
||||||
train_type: Literal["seq", "sft", "dpo"],
|
|
||||||
load_path: str,
|
|
||||||
window_size: int,
|
|
||||||
stride: Optional[int] = None,
|
|
||||||
) -> BaseDataset:
|
|
||||||
if stride is None:
|
|
||||||
stride = window_size
|
|
||||||
|
|
||||||
dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
|
|
||||||
"seq": lambda window_size: SEQDataset(window_size, stride),
|
|
||||||
"sft": lambda window_size: SFTDataset(window_size, stride),
|
|
||||||
"dpo": lambda window_size: DPODataset(window_size, stride),
|
|
||||||
"grpo": lambda window_size: GRPODataset(window_size, stride),
|
|
||||||
}
|
|
||||||
dataset = dataset_router[train_type](window_size)
|
|
||||||
dataset.load(load_path)
|
|
||||||
|
|
||||||
return dataset
|
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,7 @@ class GenerationRequest:
|
||||||
query: Input query (string or list of strings for batch).
|
query: Input query (string or list of strings for batch).
|
||||||
history: Conversation history.
|
history: Conversation history.
|
||||||
system_prompt: System prompt for the conversation.
|
system_prompt: System prompt for the conversation.
|
||||||
|
stream: Whether to use streaming generation.
|
||||||
"""
|
"""
|
||||||
top_k: int
|
top_k: int
|
||||||
top_p: float
|
top_p: float
|
||||||
|
|
@ -86,6 +87,7 @@ class GenerationRequest:
|
||||||
query: Union[str, List[str]]
|
query: Union[str, List[str]]
|
||||||
history: Optional[Union[HistoryType, List[HistoryType]]] = None
|
history: Optional[Union[HistoryType, List[HistoryType]]] = None
|
||||||
system_prompt: Optional[str] = None
|
system_prompt: Optional[str] = None
|
||||||
|
stream: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if not isinstance(self.top_k, int) or self.top_k < 0:
|
if not isinstance(self.top_k, int) or self.top_k < 0:
|
||||||
|
|
@ -233,33 +235,62 @@ class EmbeddingEncoder(EmbeddingEncoderCore):
|
||||||
|
|
||||||
|
|
||||||
class GeneratorFactory:
|
class GeneratorFactory:
|
||||||
"""Factory class for creating appropriate generator instances based on request features."""
|
"""Factory class for creating generator instances.
|
||||||
|
|
||||||
|
Provides smart generator selection based on request characteristics:
|
||||||
|
- Streaming: Use StreamGenerator for streaming output
|
||||||
|
- Batch: Use BatchGenerator when query is a list
|
||||||
|
- Single: Use LoopGenerator for single query non-streaming
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
generator = GeneratorFactory.create_generator(parameter, request)
|
||||||
|
result = generator.generate(request)
|
||||||
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_generator(parameter: ModelParameter, request: GenerationRequest):
|
def create_generator(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
|
||||||
|
"""Create a generator based on request characteristics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parameter: Model parameters containing model, tokenizer, config
|
||||||
|
request: Generation request with query, options, etc.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Appropriate GeneratorCore subclass instance
|
||||||
"""
|
"""
|
||||||
Create a generator based on the characteristics of GenerationRequest.
|
# Streaming generation: check stream field first
|
||||||
|
if request.stream:
|
||||||
|
return StreamGenerator(parameter)
|
||||||
|
|
||||||
|
# Batch generation: query is a list of strings
|
||||||
|
if isinstance(request.query, list):
|
||||||
|
return BatchGenerator(parameter)
|
||||||
|
|
||||||
|
# Default: single query non-streaming
|
||||||
|
return LoopGenerator(parameter)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_encoder(parameter: ModelParameter) -> EmbeddingEncoderCore:
|
||||||
|
"""Create an embedding encoder instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parameter: Model parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmbeddingEncoderCore instance
|
||||||
|
"""
|
||||||
|
return EmbeddingEncoder(parameter)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
|
||||||
|
"""Convenience method that delegates to create_generator.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
parameter: Model parameters
|
parameter: Model parameters
|
||||||
request: Generation request
|
request: Generation request
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Subclass instance of GeneratorCore
|
Generator instance
|
||||||
"""
|
"""
|
||||||
|
return cls.create_generator(parameter, request)
|
||||||
# Streaming generation detection: check stream field
|
|
||||||
if request.stream:
|
|
||||||
return StreamGenerator(parameter)
|
|
||||||
|
|
||||||
# Batch generation detection: query is a list
|
|
||||||
if isinstance(request.query, list):
|
|
||||||
return BatchGenerator(parameter)
|
|
||||||
|
|
||||||
# Default return LoopGenerator
|
|
||||||
return LoopGenerator(parameter)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_encoder(parameter: ModelParameter):
|
|
||||||
"""Create an EmbeddingEncoder instance"""
|
|
||||||
return EmbeddingEncoder(parameter)
|
|
||||||
|
|
||||||
|
|
@ -134,8 +134,8 @@ def spawn_parallel_fn(
|
||||||
|
|
||||||
if world_size == 1:
|
if world_size == 1:
|
||||||
device_ids = device_ids or [0]
|
device_ids = device_ids or [0]
|
||||||
deice_id = torch.device(device_type, device_ids[0])
|
device_id = torch.device(device_type, device_ids[0])
|
||||||
os.environ["LOCAL_DEVICE"] = str(deice_id)
|
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||||
|
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -1,29 +1,33 @@
|
||||||
from khaosz.trainer.trainer import Trainer
|
from khaosz.trainer.trainer import Trainer
|
||||||
from khaosz.trainer.strategy import StrategyFactory
|
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
|
||||||
from khaosz.trainer.schedule import SchedulerFactory
|
from khaosz.trainer.schedule import SchedulerFactory, BaseScheduler
|
||||||
|
|
||||||
from khaosz.trainer.train_callback import (
|
from khaosz.trainer.train_callback import (
|
||||||
TrainCallback,
|
TrainCallback,
|
||||||
ProgressBarCallback,
|
GradientClippingCallback,
|
||||||
CheckpointCallback,
|
|
||||||
TrainCallback,
|
|
||||||
SchedulerCallback,
|
SchedulerCallback,
|
||||||
MetricLoggerCallback
|
CheckpointCallback,
|
||||||
|
ProgressBarCallback,
|
||||||
|
MetricLoggerCallback,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# trainer
|
# Main trainer
|
||||||
"Trainer",
|
"Trainer",
|
||||||
|
|
||||||
# factory
|
# Strategy factory
|
||||||
"StrategyFactory",
|
"StrategyFactory",
|
||||||
"SchedulerFactory",
|
"BaseStrategy",
|
||||||
|
|
||||||
# callback
|
# Scheduler factory
|
||||||
"TrainCallback",
|
"SchedulerFactory",
|
||||||
"ProgressBarCallback",
|
"BaseScheduler",
|
||||||
"CheckpointCallback",
|
|
||||||
|
# Callbacks
|
||||||
"TrainCallback",
|
"TrainCallback",
|
||||||
|
"GradientClippingCallback",
|
||||||
"SchedulerCallback",
|
"SchedulerCallback",
|
||||||
"MetricLoggerCallback"
|
"CheckpointCallback",
|
||||||
|
"ProgressBarCallback",
|
||||||
|
"MetricLoggerCallback",
|
||||||
]
|
]
|
||||||
|
|
@ -1,20 +1,21 @@
|
||||||
|
"""Learning rate scheduler implementations with factory pattern."""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from abc import abstractmethod, ABC
|
from abc import abstractmethod, ABC
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Type
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
from khaosz.config.schedule_config import ScheduleConfig
|
from khaosz.config.schedule_config import ScheduleConfig
|
||||||
|
|
||||||
|
|
||||||
class BaseScheduler(LRScheduler, ABC):
|
class BaseScheduler(LRScheduler, ABC):
|
||||||
"""
|
"""Base scheduler class for all other schedulers."""
|
||||||
Base scheduler class for all other schedulers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, optimizer, last_epoch: int = -1):
|
def __init__(self, optimizer, last_epoch: int = -1):
|
||||||
super().__init__(optimizer, last_epoch)
|
super().__init__(optimizer, last_epoch)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_lr(self) -> List[float]:
|
def get_lr(self) -> List[float]:
|
||||||
|
"""Calculate the current learning rate."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def state_dict(self) -> Dict[str, Any]:
|
def state_dict(self) -> Dict[str, Any]:
|
||||||
|
|
@ -24,10 +25,95 @@ class BaseScheduler(LRScheduler, ABC):
|
||||||
super().load_state_dict(state_dict)
|
super().load_state_dict(state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerFactory:
|
||||||
|
"""Factory class for creating learning rate schedulers.
|
||||||
|
|
||||||
|
Supports decorator-based registration for extensible scheduler types.
|
||||||
|
Also supports creation from ScheduleConfig objects.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
@SchedulerFactory.register("custom")
|
||||||
|
class CustomScheduler(BaseScheduler):
|
||||||
|
...
|
||||||
|
|
||||||
|
scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)
|
||||||
|
|
||||||
|
# Or from config
|
||||||
|
config = CosineScheduleConfig(total_steps=10000)
|
||||||
|
scheduler = SchedulerFactory.load(optimizer, config)
|
||||||
|
"""
|
||||||
|
|
||||||
|
SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, name: str):
|
||||||
|
"""Decorator to register a new scheduler class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Registration name for the scheduler
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decorator function that registers the scheduler class
|
||||||
|
"""
|
||||||
|
def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]:
|
||||||
|
if not issubclass(scheduler_cls, BaseScheduler):
|
||||||
|
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
||||||
|
cls.SCHEDULER_MAP[name] = scheduler_cls
|
||||||
|
return scheduler_cls
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, optimizer, schedule_type: str, **kwargs) -> BaseScheduler:
|
||||||
|
"""Create a scheduler instance by type name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer: PyTorch optimizer
|
||||||
|
schedule_type: Type of scheduler ("cosine", "sgdr")
|
||||||
|
**kwargs: Arguments passed to the scheduler constructor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Scheduler instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If schedule_type is not supported
|
||||||
|
"""
|
||||||
|
if schedule_type not in cls.SCHEDULER_MAP:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown schedule type: '{schedule_type}'. "
|
||||||
|
f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
scheduler_cls = cls.SCHEDULER_MAP[schedule_type]
|
||||||
|
return scheduler_cls(optimizer, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
|
||||||
|
"""Create a scheduler from a ScheduleConfig object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer: PyTorch optimizer
|
||||||
|
schedule_config: ScheduleConfig instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Scheduler instance
|
||||||
|
"""
|
||||||
|
kwargs = schedule_config.get_kwargs()
|
||||||
|
schedule_type = kwargs.pop("schedule_type")
|
||||||
|
return SchedulerFactory.create(optimizer, schedule_type, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def available_types(cls) -> list:
|
||||||
|
"""Return list of registered scheduler type names."""
|
||||||
|
return list(cls.SCHEDULER_MAP.keys())
|
||||||
|
|
||||||
|
|
||||||
|
# ============== Scheduler Classes ==============
|
||||||
|
# All scheduler classes are registered at class definition time using the decorator
|
||||||
|
|
||||||
|
|
||||||
|
@SchedulerFactory.register("cosine")
|
||||||
class CosineScheduler(BaseScheduler):
|
class CosineScheduler(BaseScheduler):
|
||||||
"""
|
"""Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler."""
|
||||||
Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -75,10 +161,9 @@ class CosineScheduler(BaseScheduler):
|
||||||
super().load_state_dict(state_dict)
|
super().load_state_dict(state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
@SchedulerFactory.register("sgdr")
|
||||||
class SGDRScheduler(BaseScheduler):
|
class SGDRScheduler(BaseScheduler):
|
||||||
"""
|
"""SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler."""
|
||||||
SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler,
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -142,23 +227,3 @@ class SGDRScheduler(BaseScheduler):
|
||||||
self.min_rate = state_dict.pop('min_rate')
|
self.min_rate = state_dict.pop('min_rate')
|
||||||
self.t_mult = state_dict.pop('t_mult')
|
self.t_mult = state_dict.pop('t_mult')
|
||||||
super().load_state_dict(state_dict)
|
super().load_state_dict(state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SchedulerFactory:
|
|
||||||
"""
|
|
||||||
Factory class for creating learning rate schedulers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
|
|
||||||
kwargs = schedule_config.get_kwargs()
|
|
||||||
schedule_type = kwargs.pop("schedule_type")
|
|
||||||
|
|
||||||
if schedule_type == "cosine":
|
|
||||||
return CosineScheduler(optimizer, **kwargs)
|
|
||||||
elif schedule_type == "sgdr":
|
|
||||||
return SGDRScheduler(optimizer, **kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported schedule type: {schedule_type}")
|
|
||||||
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
"""Training strategy implementations with factory pattern."""
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
@ -17,9 +19,10 @@ def unwrap_model(model: nn.Module) -> nn.Module:
|
||||||
|
|
||||||
|
|
||||||
def create_ref_model(model: nn.Module) -> nn.Module:
|
def create_ref_model(model: nn.Module) -> nn.Module:
|
||||||
"""
|
"""Create a reference model for DPO/GRPO training.
|
||||||
Create a reference model for DPO/GRPO training.
|
|
||||||
Handles DDP-wrapped models safely.
|
Handles DDP-wrapped models safely by unwrapping first,
|
||||||
|
then creating a deep copy with frozen gradients.
|
||||||
"""
|
"""
|
||||||
original_model = unwrap_model(model)
|
original_model = unwrap_model(model)
|
||||||
ref_model = copy.deepcopy(original_model)
|
ref_model = copy.deepcopy(original_model)
|
||||||
|
|
@ -28,17 +31,18 @@ def create_ref_model(model: nn.Module) -> nn.Module:
|
||||||
return ref_model
|
return ref_model
|
||||||
|
|
||||||
|
|
||||||
def move_to_device(batch:Dict[str, Tensor], device: str) -> Any:
|
def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
|
||||||
|
"""Move batch tensors to specified device with non-blocking transfer."""
|
||||||
return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
|
return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
|
||||||
|
|
||||||
|
|
||||||
def get_logprobs(
|
def get_logprobs(
|
||||||
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
|
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
|
||||||
input_ids: Tensor,
|
input_ids: Tensor,
|
||||||
mask: Tensor,
|
mask: Tensor,
|
||||||
reduction: str,
|
reduction: str,
|
||||||
):
|
):
|
||||||
"""
|
"""Compute token-wise log probabilities from model outputs.
|
||||||
Compute token-wise log probabilities from model outputs.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: The language model
|
model: The language model
|
||||||
|
|
@ -49,7 +53,6 @@ def get_logprobs(
|
||||||
Returns:
|
Returns:
|
||||||
Log probabilities with reduction applied over sequence dimension
|
Log probabilities with reduction applied over sequence dimension
|
||||||
"""
|
"""
|
||||||
# reduction on seq_len dim
|
|
||||||
allowed_reductions = ["mean", "sum", "none"]
|
allowed_reductions = ["mean", "sum", "none"]
|
||||||
if reduction not in allowed_reductions:
|
if reduction not in allowed_reductions:
|
||||||
raise ValueError(f"reduction must be one of {allowed_reductions}, got '{reduction}'")
|
raise ValueError(f"reduction must be one of {allowed_reductions}, got '{reduction}'")
|
||||||
|
|
@ -60,7 +63,6 @@ def get_logprobs(
|
||||||
logits = model(input_ids[:, :-1], mask[:, :-1])["logits"]
|
logits = model(input_ids[:, :-1], mask[:, :-1])["logits"]
|
||||||
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
||||||
|
|
||||||
# [batch_size, seq_len - 1]
|
|
||||||
token_logprobs = torch.gather(
|
token_logprobs = torch.gather(
|
||||||
log_probs,
|
log_probs,
|
||||||
dim=-1,
|
dim=-1,
|
||||||
|
|
@ -76,20 +78,112 @@ def get_logprobs(
|
||||||
|
|
||||||
|
|
||||||
class BaseStrategy(ABC):
|
class BaseStrategy(ABC):
|
||||||
|
"""Abstract base class for training strategies."""
|
||||||
|
|
||||||
def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str):
|
def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
|
"""Compute loss for the given batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: Dictionary containing batch tensors
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Computed loss tensor
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
|
def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
|
"""Allow calling strategy directly as a callable."""
|
||||||
return self.compute_loss(batch)
|
return self.compute_loss(batch)
|
||||||
|
|
||||||
|
|
||||||
|
class StrategyFactory:
|
||||||
|
"""Factory class for creating training strategy instances.
|
||||||
|
|
||||||
|
Supports decorator-based registration for extensible strategy types.
|
||||||
|
All default strategies (seq, sft, dpo, grpo) are automatically registered.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
@StrategyFactory.register("custom")
|
||||||
|
class CustomStrategy(BaseStrategy):
|
||||||
|
...
|
||||||
|
|
||||||
|
strategy = StrategyFactory.create(model, "custom", device)
|
||||||
|
"""
|
||||||
|
|
||||||
|
SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"})
|
||||||
|
STRATEGY_MAP: Dict[str, type] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, name: str):
|
||||||
|
"""Decorator to register a new strategy class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Registration name for the strategy
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decorator function that registers the strategy class
|
||||||
|
"""
|
||||||
|
def decorator(strategy_cls: type) -> type:
|
||||||
|
if not issubclass(strategy_cls, BaseStrategy):
|
||||||
|
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
||||||
|
cls.STRATEGY_MAP[name] = strategy_cls
|
||||||
|
return strategy_cls
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy:
|
||||||
|
"""Create a strategy instance based on training type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model instance for the strategy
|
||||||
|
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
||||||
|
device: Device to run the strategy on
|
||||||
|
**kwargs: Additional arguments passed to strategy constructor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Strategy instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If train_type is not supported
|
||||||
|
NotImplementedError: If train_type is in supported list but not implemented
|
||||||
|
"""
|
||||||
|
if train_type not in cls.SUPPORTED_STRATEGIES:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown training strategy: '{train_type}'. "
|
||||||
|
f"Supported strategies: {sorted(cls.SUPPORTED_STRATEGIES)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if train_type not in cls.STRATEGY_MAP:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Strategy '{train_type}' is supported but not yet implemented."
|
||||||
|
)
|
||||||
|
|
||||||
|
strategy_cls = cls.STRATEGY_MAP[train_type]
|
||||||
|
return strategy_cls(model, device, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def available_strategies(cls) -> list:
|
||||||
|
"""Return list of registered strategy names."""
|
||||||
|
return list(cls.STRATEGY_MAP.keys())
|
||||||
|
|
||||||
|
|
||||||
|
# ============== Strategy Classes ==============
|
||||||
|
# All strategies are registered at class definition time using the decorator
|
||||||
|
|
||||||
|
|
||||||
|
@StrategyFactory.register("seq")
|
||||||
class SEQStrategy(BaseStrategy):
|
class SEQStrategy(BaseStrategy):
|
||||||
def __init__(self, model, device, label_smoothing):
|
"""Standard next-token prediction training strategy.
|
||||||
|
|
||||||
|
Computes cross-entropy loss for next token prediction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model, device, label_smoothing: float = 0.0):
|
||||||
super().__init__(model, device)
|
super().__init__(model, device)
|
||||||
self.label_smoothing = label_smoothing
|
self.label_smoothing = label_smoothing
|
||||||
|
|
||||||
|
|
@ -99,15 +193,22 @@ class SEQStrategy(BaseStrategy):
|
||||||
logits = self.model(input_ids=input_ids)["logits"]
|
logits = self.model(input_ids=input_ids)["logits"]
|
||||||
|
|
||||||
loss = F.cross_entropy(
|
loss = F.cross_entropy(
|
||||||
input=logits.flatten(0, 1).float(),
|
input=logits.flatten(0, 1).float(),
|
||||||
target=target_ids.flatten()
|
target=target_ids.flatten(),
|
||||||
|
label_smoothing=self.label_smoothing
|
||||||
)
|
)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
@StrategyFactory.register("sft")
|
||||||
class SFTStrategy(BaseStrategy):
|
class SFTStrategy(BaseStrategy):
|
||||||
def __init__(self, model, device, label_smoothing):
|
"""Supervised Fine-tuning strategy with loss masking.
|
||||||
|
|
||||||
|
Applies cross-entropy loss only to tokens where loss_mask is True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model, device, label_smoothing: float = 0.0):
|
||||||
super().__init__(model, device)
|
super().__init__(model, device)
|
||||||
self.label_smoothing = label_smoothing
|
self.label_smoothing = label_smoothing
|
||||||
|
|
||||||
|
|
@ -122,19 +223,27 @@ class SFTStrategy(BaseStrategy):
|
||||||
loss = F.cross_entropy(
|
loss = F.cross_entropy(
|
||||||
input=logits.flatten(0, 1).float(),
|
input=logits.flatten(0, 1).float(),
|
||||||
target=target_ids.flatten(),
|
target=target_ids.flatten(),
|
||||||
ignore_index=ignore_index
|
ignore_index=ignore_index,
|
||||||
|
label_smoothing=self.label_smoothing
|
||||||
)
|
)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
@StrategyFactory.register("dpo")
|
||||||
class DPOStrategy(BaseStrategy):
|
class DPOStrategy(BaseStrategy):
|
||||||
|
"""Direct Preference Optimization strategy.
|
||||||
|
|
||||||
|
Implements the DPO loss from the paper "Direct Preference Optimization".
|
||||||
|
Uses a reference model to compute KL divergence penalty.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
device: str,
|
device: str,
|
||||||
beta: float,
|
beta: float = 0.1,
|
||||||
reduction: str,
|
reduction: str = "mean",
|
||||||
):
|
):
|
||||||
super().__init__(model, device)
|
super().__init__(model, device)
|
||||||
self.ref_model = create_ref_model(model)
|
self.ref_model = create_ref_model(model)
|
||||||
|
|
@ -168,16 +277,21 @@ class DPOStrategy(BaseStrategy):
|
||||||
return dpo_loss
|
return dpo_loss
|
||||||
|
|
||||||
|
|
||||||
|
@StrategyFactory.register("grpo")
|
||||||
class GRPOStrategy(BaseStrategy):
|
class GRPOStrategy(BaseStrategy):
|
||||||
|
"""Group Relative Policy Optimization strategy.
|
||||||
|
|
||||||
|
Implements GRPO with clipping and KL penalty.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
device: str,
|
device: str,
|
||||||
clip_eps: float,
|
clip_eps: float = 0.2,
|
||||||
kl_coef: float,
|
kl_coef: float = 0.01,
|
||||||
group_size: int,
|
group_size: int = 4,
|
||||||
reduction: str,
|
reduction: str = "mean",
|
||||||
):
|
):
|
||||||
super().__init__(model, device)
|
super().__init__(model, device)
|
||||||
self.ref_model = create_ref_model(model)
|
self.ref_model = create_ref_model(model)
|
||||||
|
|
@ -209,16 +323,14 @@ class GRPOStrategy(BaseStrategy):
|
||||||
log_probs_ref = get_logprobs(self.ref_model, full_sequences, full_masks, self.reduction)
|
log_probs_ref = get_logprobs(self.ref_model, full_sequences, full_masks, self.reduction)
|
||||||
log_probs_ref = log_probs_ref.view(batch_size, group_size)
|
log_probs_ref = log_probs_ref.view(batch_size, group_size)
|
||||||
|
|
||||||
# Compute advantages from rewards
|
# Compute advantages from rewards with normalization
|
||||||
eps = torch.finfo(log_probs_policy.dtype).eps
|
eps = torch.finfo(log_probs_policy.dtype).eps
|
||||||
mean = rewards.mean(dim=-1, keepdim=True)
|
mean = rewards.mean(dim=-1, keepdim=True)
|
||||||
std = rewards.std(dim=-1, keepdim=True)
|
std = rewards.std(dim=-1, keepdim=True)
|
||||||
advantages = (rewards - mean) / (std + eps)
|
advantages = (rewards - mean) / (std + eps)
|
||||||
|
|
||||||
# log_ratio = log_probs_policy - log_probs_old
|
# PPO-style clipped surrogate objective
|
||||||
# ratio = torch.exp(log_ratio)
|
ratio = torch.exp(0) # Off-policy: policy_model = old_model
|
||||||
# off policy: policy_model = old_model, then ratio = 1
|
|
||||||
ratio = torch.exp(0)
|
|
||||||
surr1 = ratio * advantages
|
surr1 = ratio * advantages
|
||||||
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
||||||
|
|
||||||
|
|
@ -227,36 +339,3 @@ class GRPOStrategy(BaseStrategy):
|
||||||
total_loss = policy_loss + kl_penalty
|
total_loss = policy_loss + kl_penalty
|
||||||
|
|
||||||
return total_loss
|
return total_loss
|
||||||
|
|
||||||
|
|
||||||
class StrategyFactory:
|
|
||||||
|
|
||||||
def load(model, train_type, device, **kwargs):
|
|
||||||
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
|
||||||
"seq": lambda: SEQStrategy(
|
|
||||||
model,
|
|
||||||
device,
|
|
||||||
kwargs.get("label_smoothing", 0.0)
|
|
||||||
),
|
|
||||||
"sft": lambda: SFTStrategy(
|
|
||||||
model,
|
|
||||||
device,
|
|
||||||
kwargs.get("label_smoothing", 0.0)
|
|
||||||
),
|
|
||||||
"dpo": lambda: DPOStrategy(
|
|
||||||
model,
|
|
||||||
device,
|
|
||||||
kwargs.get("dpo_beta"),
|
|
||||||
kwargs.get("reduction", "mean")
|
|
||||||
),
|
|
||||||
"grpo": lambda: GRPOStrategy(
|
|
||||||
model,
|
|
||||||
device,
|
|
||||||
kwargs.get("grpo_clip_eps"),
|
|
||||||
kwargs.get("grpo_kl_coef"),
|
|
||||||
kwargs.get("grpo_group_size"),
|
|
||||||
kwargs.get("reduction", "mean")
|
|
||||||
)
|
|
||||||
}
|
|
||||||
strategy = train_strategy[train_type]()
|
|
||||||
return strategy
|
|
||||||
Loading…
Reference in New Issue