reafactor: 统一并增强项目中的工厂模式实现
This commit is contained in:
parent
60f4df95bd
commit
3e33c14376
|
|
@ -1,16 +1,26 @@
|
|||
from khaosz.config.model_config import ModelConfig
|
||||
from khaosz.config.param_config import BaseModelIO, ModelParameter
|
||||
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
|
||||
from khaosz.config.schedule_config import (
|
||||
ScheduleConfig,
|
||||
CosineScheduleConfig,
|
||||
SGDRScheduleConfig,
|
||||
ScheduleConfigFactory
|
||||
)
|
||||
from khaosz.config.train_config import TrainConfig
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Base I/O
|
||||
"BaseModelIO",
|
||||
"ModelParameter",
|
||||
|
||||
# Model configuration
|
||||
"ModelConfig",
|
||||
"TrainConfig",
|
||||
|
||||
# Schedule configuration
|
||||
"ScheduleConfig",
|
||||
"CosineScheduleConfig",
|
||||
"SGDRScheduleConfig",
|
||||
"ScheduleConfigFactory",
|
||||
]
|
||||
|
|
@ -1,10 +1,15 @@
|
|||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Type
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScheduleConfig(ABC):
|
||||
"""Base configuration class for learning rate schedulers.
|
||||
|
||||
Provides common validation and interface for all schedule types.
|
||||
"""
|
||||
|
||||
schedule_type: str = field(
|
||||
default="cosine",
|
||||
metadata={
|
||||
|
|
@ -23,6 +28,7 @@ class ScheduleConfig(ABC):
|
|||
|
||||
@abstractmethod
|
||||
def get_kwargs(self) -> Dict[str, Any]:
|
||||
"""Get configuration kwargs for scheduler creation."""
|
||||
raise NotImplementedError
|
||||
|
||||
def validate(self) -> None:
|
||||
|
|
@ -35,6 +41,8 @@ class ScheduleConfig(ABC):
|
|||
|
||||
@dataclass
|
||||
class CosineScheduleConfig(ScheduleConfig):
|
||||
"""Cosine annealing learning rate schedule configuration."""
|
||||
|
||||
total_steps: int = field(
|
||||
default=None,
|
||||
metadata={"help": "Total training steps for cosine schedule."}
|
||||
|
|
@ -63,6 +71,8 @@ class CosineScheduleConfig(ScheduleConfig):
|
|||
|
||||
@dataclass
|
||||
class SGDRScheduleConfig(ScheduleConfig):
|
||||
"""Stochastic Gradient Descent with Warm Restarts schedule configuration."""
|
||||
|
||||
cycle_length: int = field(
|
||||
default=1000,
|
||||
metadata={"help": "Length of the first cycle in steps."}
|
||||
|
|
@ -91,3 +101,50 @@ class SGDRScheduleConfig(ScheduleConfig):
|
|||
raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
|
||||
if self.t_mult < 1:
|
||||
raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
|
||||
|
||||
|
||||
class ScheduleConfigFactory:
|
||||
"""Factory class for creating ScheduleConfig instances.
|
||||
|
||||
Supports both direct instantiation and factory creation methods.
|
||||
|
||||
Example usage:
|
||||
# Direct creation
|
||||
config = CosineScheduleConfig(total_steps=10000)
|
||||
|
||||
# Factory method
|
||||
config = ScheduleConfigFactory.create("cosine", total_steps=10000)
|
||||
"""
|
||||
|
||||
CONFIG_MAP: Dict[str, Type[ScheduleConfig]] = {
|
||||
"cosine": CosineScheduleConfig,
|
||||
"sgdr": SGDRScheduleConfig,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create(cls, schedule_type: str, **kwargs) -> ScheduleConfig:
|
||||
"""Create a schedule config instance.
|
||||
|
||||
Args:
|
||||
schedule_type: Type of schedule ("cosine", "sgdr")
|
||||
**kwargs: Arguments passed to the config constructor
|
||||
|
||||
Returns:
|
||||
ScheduleConfig instance
|
||||
|
||||
Raises:
|
||||
ValueError: If schedule_type is not supported
|
||||
"""
|
||||
if schedule_type not in cls.CONFIG_MAP:
|
||||
raise ValueError(
|
||||
f"Unknown schedule type: '{schedule_type}'. "
|
||||
f"Supported types: {sorted(cls.CONFIG_MAP.keys())}"
|
||||
)
|
||||
|
||||
config_cls = cls.CONFIG_MAP[schedule_type]
|
||||
return config_cls(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def available_types(cls) -> list:
|
||||
"""Return list of available schedule type names."""
|
||||
return list(cls.CONFIG_MAP.keys())
|
||||
|
|
@ -5,20 +5,31 @@ from khaosz.data.dataset import (
|
|||
SFTDataset,
|
||||
GRPODataset,
|
||||
MultiSegmentFetcher,
|
||||
DatasetLoader
|
||||
DatasetLoader,
|
||||
DatasetFactory
|
||||
)
|
||||
|
||||
from khaosz.data.tokenizer import BpeTokenizer
|
||||
from khaosz.data.sampler import ResumableDistributedSampler
|
||||
|
||||
__all__ = [
|
||||
# Base classes
|
||||
"BaseDataset",
|
||||
|
||||
# Dataset implementations
|
||||
"SEQDataset",
|
||||
"SFTDataset",
|
||||
"DPODataset",
|
||||
"GRPODataset",
|
||||
|
||||
# Fetchers
|
||||
"MultiSegmentFetcher",
|
||||
|
||||
# Factory (DatasetLoader is alias for backward compatibility)
|
||||
"DatasetLoader",
|
||||
"DatasetFactory",
|
||||
|
||||
# Tokenizer and sampler
|
||||
"BpeTokenizer",
|
||||
"ResumableDistributedSampler"
|
||||
]
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
"""Dataset implementations with factory pattern for training."""
|
||||
|
||||
import torch
|
||||
import bisect
|
||||
|
||||
|
|
@ -8,8 +10,13 @@ from khaosz.data.serialization import load_h5
|
|||
from typing import Callable, List, Dict, Literal, Optional, Union
|
||||
|
||||
|
||||
|
||||
class BaseSegmentFetcher:
|
||||
"""Fetches data segments across multiple tensor segments.
|
||||
|
||||
Maintains cumulative lengths for efficient range queries across
|
||||
multiple discontinuous segments.
|
||||
"""
|
||||
|
||||
def __init__(self, segments: List[Tensor]):
|
||||
self.segments = segments
|
||||
self.cum_lengths = []
|
||||
|
|
@ -25,12 +32,21 @@ class BaseSegmentFetcher:
|
|||
return self.total_length
|
||||
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
"""Fetch data in the range [begin_idx, end_idx).
|
||||
|
||||
Args:
|
||||
begin_idx: Starting index (inclusive)
|
||||
end_idx: Ending index (exclusive)
|
||||
|
||||
Returns:
|
||||
Concatenated tensor of data in the specified range
|
||||
"""
|
||||
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
|
||||
raise ValueError("begin_idx or end_idx out of bounds")
|
||||
if begin_idx >= end_idx:
|
||||
return torch.tensor([], dtype=torch.long)
|
||||
|
||||
# fix the range index bug
|
||||
# Find segment boundaries for the range
|
||||
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
|
||||
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
|
||||
|
||||
|
|
@ -47,6 +63,11 @@ class BaseSegmentFetcher:
|
|||
|
||||
|
||||
class MultiSegmentFetcher:
|
||||
"""Manages multiple segment fetchers for different data keys.
|
||||
|
||||
Each key corresponds to a different type of data (e.g., "sequence", "mask").
|
||||
"""
|
||||
|
||||
def __init__(self, muti_segments: Dict):
|
||||
self.muti_keys = list(muti_segments.keys())
|
||||
self.muti_fetchers = {
|
||||
|
|
@ -55,10 +76,21 @@ class MultiSegmentFetcher:
|
|||
}
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the minimum length across all fetchers."""
|
||||
len_list = [len(seg) for seg in self.muti_fetchers.values()]
|
||||
return min(len_list)
|
||||
|
||||
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
|
||||
"""Fetch data for specific keys.
|
||||
|
||||
Args:
|
||||
begin_idx: Starting index
|
||||
end_idx: Ending index
|
||||
keys: Single key or list of keys to fetch
|
||||
|
||||
Returns:
|
||||
Dictionary of tensors if multiple keys, single tensor if one key
|
||||
"""
|
||||
fetch_dict = {}
|
||||
keys = [keys] if isinstance(keys, str) else keys
|
||||
|
||||
|
|
@ -70,23 +102,43 @@ class MultiSegmentFetcher:
|
|||
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
||||
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
||||
"""Fetch all keys."""
|
||||
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
|
||||
|
||||
|
||||
class BaseDataset(Dataset, ABC):
|
||||
"""Abstract base class for all dataset types.
|
||||
|
||||
Implements common functionality for window-based data fetching.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__()
|
||||
self.segments = {}
|
||||
self.window_size = window_size
|
||||
self.stride = stride
|
||||
self.total_samples = None
|
||||
self.fetcher: Optional[MultiSegmentFetcher] = None
|
||||
|
||||
def load(self, load_path: str):
|
||||
"""Load dataset from HDF5 file.
|
||||
|
||||
Args:
|
||||
load_path: Path to the HDF5 data file
|
||||
"""
|
||||
self.segments = load_h5(load_path)
|
||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||
self.total_samples = len(self.fetcher)
|
||||
|
||||
def get_index(self, index: int) -> int:
|
||||
def get_index(self, index: int) -> tuple:
|
||||
"""Calculate begin and end indices for a sample.
|
||||
|
||||
Args:
|
||||
index: Sample index
|
||||
|
||||
Returns:
|
||||
Tuple of (begin_idx, end_idx)
|
||||
"""
|
||||
assert self.total_samples > self.window_size
|
||||
|
||||
begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
|
||||
|
|
@ -96,6 +148,10 @@ class BaseDataset(Dataset, ABC):
|
|||
|
||||
@abstractmethod
|
||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||
"""Get a single sample by index.
|
||||
|
||||
Must be implemented by subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __len__(self) -> int:
|
||||
|
|
@ -105,16 +161,109 @@ class BaseDataset(Dataset, ABC):
|
|||
return (self.total_samples - 1 - self.window_size) // self.stride + 1
|
||||
|
||||
|
||||
class DatasetFactory:
|
||||
"""Factory class for creating dataset instances.
|
||||
|
||||
Supports decorator-based registration for extensible dataset types.
|
||||
All default dataset types (seq, sft, dpo, grpo) are registered automatically
|
||||
when their classes are defined with the decorator.
|
||||
|
||||
Example usage:
|
||||
@DatasetFactory.register("custom")
|
||||
class CustomDataset(BaseDataset):
|
||||
...
|
||||
|
||||
dataset = DatasetFactory.create("custom", window_size, stride)
|
||||
"""
|
||||
|
||||
SUPPORTED_TYPES = frozenset({"seq", "sft", "dpo", "grpo"})
|
||||
DATASET_MAP: Dict[str, type] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str):
|
||||
"""Decorator to register a new dataset class.
|
||||
|
||||
Args:
|
||||
name: Registration name for the dataset type
|
||||
|
||||
Returns:
|
||||
Decorator function that registers the dataset class
|
||||
"""
|
||||
def decorator(dataset_cls: type) -> type:
|
||||
if not issubclass(dataset_cls, BaseDataset):
|
||||
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
||||
cls.DATASET_MAP[name] = dataset_cls
|
||||
return dataset_cls
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def create(cls, train_type: str, window_size: int, stride: int) -> BaseDataset:
|
||||
"""Create a dataset instance.
|
||||
|
||||
Args:
|
||||
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
||||
window_size: Window size for data sampling
|
||||
stride: Stride between consecutive samples
|
||||
|
||||
Returns:
|
||||
Dataset instance
|
||||
"""
|
||||
if train_type not in cls.SUPPORTED_TYPES:
|
||||
raise ValueError(
|
||||
f"Unknown dataset type: '{train_type}'. "
|
||||
f"Supported types: {sorted(cls.SUPPORTED_TYPES)}"
|
||||
)
|
||||
|
||||
if train_type not in cls.DATASET_MAP:
|
||||
raise NotImplementedError(
|
||||
f"Dataset type '{train_type}' is supported but not yet implemented."
|
||||
)
|
||||
|
||||
dataset_cls = cls.DATASET_MAP[train_type]
|
||||
return dataset_cls(window_size, stride)
|
||||
|
||||
@classmethod
|
||||
def load(cls, train_type: str, load_path: str, window_size: int, stride: Optional[int] = None) -> BaseDataset:
|
||||
"""Create and load a dataset in one step.
|
||||
|
||||
Args:
|
||||
train_type: Type of training dataset
|
||||
load_path: Path to the data file
|
||||
window_size: Window size for data sampling
|
||||
stride: Stride between consecutive samples (default: same as window_size)
|
||||
|
||||
Returns:
|
||||
Loaded dataset instance
|
||||
"""
|
||||
if stride is None:
|
||||
stride = window_size
|
||||
|
||||
dataset = cls.create(train_type, window_size, stride)
|
||||
dataset.load(load_path)
|
||||
|
||||
return dataset
|
||||
|
||||
@classmethod
|
||||
def available_types(cls) -> list:
|
||||
"""Return list of registered dataset type names."""
|
||||
return list(cls.DATASET_MAP.keys())
|
||||
|
||||
|
||||
# ============== Dataset Classes ==============
|
||||
# All dataset classes are registered at class definition time using the decorator
|
||||
|
||||
|
||||
@DatasetFactory.register("seq")
|
||||
class SEQDataset(BaseDataset):
|
||||
"""Dataset for sequential next-token prediction training."""
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
|
||||
|
||||
def __getitem__(self, index):
|
||||
# fix the range index bug
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
||||
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
|
||||
|
|
@ -123,10 +272,12 @@ class SEQDataset(BaseDataset):
|
|||
return {"input_ids": x, "target_ids": y}
|
||||
|
||||
|
||||
@DatasetFactory.register("sft")
|
||||
class SFTDataset(BaseDataset):
|
||||
"""Dataset for supervised fine-tuning with loss masking."""
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
|
@ -141,10 +292,12 @@ class SFTDataset(BaseDataset):
|
|||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
||||
|
||||
|
||||
@DatasetFactory.register("dpo")
|
||||
class DPODataset(BaseDataset):
|
||||
"""Dataset for Direct Preference Optimization training."""
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
|
@ -160,10 +313,12 @@ class DPODataset(BaseDataset):
|
|||
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
|
||||
|
||||
|
||||
@DatasetFactory.register("grpo")
|
||||
class GRPODataset(BaseDataset):
|
||||
"""Dataset for Group Relative Policy Optimization training."""
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
|
@ -179,24 +334,5 @@ class GRPODataset(BaseDataset):
|
|||
return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
|
||||
|
||||
|
||||
class DatasetLoader:
|
||||
@staticmethod
|
||||
def load(
|
||||
train_type: Literal["seq", "sft", "dpo"],
|
||||
load_path: str,
|
||||
window_size: int,
|
||||
stride: Optional[int] = None,
|
||||
) -> BaseDataset:
|
||||
if stride is None:
|
||||
stride = window_size
|
||||
|
||||
dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
|
||||
"seq": lambda window_size: SEQDataset(window_size, stride),
|
||||
"sft": lambda window_size: SFTDataset(window_size, stride),
|
||||
"dpo": lambda window_size: DPODataset(window_size, stride),
|
||||
"grpo": lambda window_size: GRPODataset(window_size, stride),
|
||||
}
|
||||
dataset = dataset_router[train_type](window_size)
|
||||
dataset.load(load_path)
|
||||
|
||||
return dataset
|
||||
# Backward compatibility alias
|
||||
DatasetLoader = DatasetFactory
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ class GenerationRequest:
|
|||
query: Input query (string or list of strings for batch).
|
||||
history: Conversation history.
|
||||
system_prompt: System prompt for the conversation.
|
||||
stream: Whether to use streaming generation.
|
||||
"""
|
||||
top_k: int
|
||||
top_p: float
|
||||
|
|
@ -86,6 +87,7 @@ class GenerationRequest:
|
|||
query: Union[str, List[str]]
|
||||
history: Optional[Union[HistoryType, List[HistoryType]]] = None
|
||||
system_prompt: Optional[str] = None
|
||||
stream: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if not isinstance(self.top_k, int) or self.top_k < 0:
|
||||
|
|
@ -233,33 +235,62 @@ class EmbeddingEncoder(EmbeddingEncoderCore):
|
|||
|
||||
|
||||
class GeneratorFactory:
|
||||
"""Factory class for creating appropriate generator instances based on request features."""
|
||||
"""Factory class for creating generator instances.
|
||||
|
||||
Provides smart generator selection based on request characteristics:
|
||||
- Streaming: Use StreamGenerator for streaming output
|
||||
- Batch: Use BatchGenerator when query is a list
|
||||
- Single: Use LoopGenerator for single query non-streaming
|
||||
|
||||
Example usage:
|
||||
generator = GeneratorFactory.create_generator(parameter, request)
|
||||
result = generator.generate(request)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_generator(parameter: ModelParameter, request: GenerationRequest):
|
||||
def create_generator(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
|
||||
"""Create a generator based on request characteristics.
|
||||
|
||||
Args:
|
||||
parameter: Model parameters containing model, tokenizer, config
|
||||
request: Generation request with query, options, etc.
|
||||
|
||||
Returns:
|
||||
Appropriate GeneratorCore subclass instance
|
||||
"""
|
||||
Create a generator based on the characteristics of GenerationRequest.
|
||||
# Streaming generation: check stream field first
|
||||
if request.stream:
|
||||
return StreamGenerator(parameter)
|
||||
|
||||
# Batch generation: query is a list of strings
|
||||
if isinstance(request.query, list):
|
||||
return BatchGenerator(parameter)
|
||||
|
||||
# Default: single query non-streaming
|
||||
return LoopGenerator(parameter)
|
||||
|
||||
@staticmethod
|
||||
def create_encoder(parameter: ModelParameter) -> EmbeddingEncoderCore:
|
||||
"""Create an embedding encoder instance.
|
||||
|
||||
Args:
|
||||
parameter: Model parameters
|
||||
|
||||
Returns:
|
||||
EmbeddingEncoderCore instance
|
||||
"""
|
||||
return EmbeddingEncoder(parameter)
|
||||
|
||||
@classmethod
|
||||
def create(cls, parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
|
||||
"""Convenience method that delegates to create_generator.
|
||||
|
||||
Args:
|
||||
parameter: Model parameters
|
||||
request: Generation request
|
||||
|
||||
Returns:
|
||||
Subclass instance of GeneratorCore
|
||||
Generator instance
|
||||
"""
|
||||
|
||||
# Streaming generation detection: check stream field
|
||||
if request.stream:
|
||||
return StreamGenerator(parameter)
|
||||
|
||||
# Batch generation detection: query is a list
|
||||
if isinstance(request.query, list):
|
||||
return BatchGenerator(parameter)
|
||||
|
||||
# Default return LoopGenerator
|
||||
return LoopGenerator(parameter)
|
||||
|
||||
@staticmethod
|
||||
def create_encoder(parameter: ModelParameter):
|
||||
"""Create an EmbeddingEncoder instance"""
|
||||
return EmbeddingEncoder(parameter)
|
||||
return cls.create_generator(parameter, request)
|
||||
|
||||
|
|
@ -134,8 +134,8 @@ def spawn_parallel_fn(
|
|||
|
||||
if world_size == 1:
|
||||
device_ids = device_ids or [0]
|
||||
deice_id = torch.device(device_type, device_ids[0])
|
||||
os.environ["LOCAL_DEVICE"] = str(deice_id)
|
||||
device_id = torch.device(device_type, device_ids[0])
|
||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||
|
||||
func(**kwargs)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -1,29 +1,33 @@
|
|||
from khaosz.trainer.trainer import Trainer
|
||||
from khaosz.trainer.strategy import StrategyFactory
|
||||
from khaosz.trainer.schedule import SchedulerFactory
|
||||
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
|
||||
from khaosz.trainer.schedule import SchedulerFactory, BaseScheduler
|
||||
|
||||
from khaosz.trainer.train_callback import (
|
||||
TrainCallback,
|
||||
ProgressBarCallback,
|
||||
CheckpointCallback,
|
||||
TrainCallback,
|
||||
GradientClippingCallback,
|
||||
SchedulerCallback,
|
||||
MetricLoggerCallback
|
||||
CheckpointCallback,
|
||||
ProgressBarCallback,
|
||||
MetricLoggerCallback,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# trainer
|
||||
# Main trainer
|
||||
"Trainer",
|
||||
|
||||
# factory
|
||||
# Strategy factory
|
||||
"StrategyFactory",
|
||||
"SchedulerFactory",
|
||||
"BaseStrategy",
|
||||
|
||||
# callback
|
||||
"TrainCallback",
|
||||
"ProgressBarCallback",
|
||||
"CheckpointCallback",
|
||||
# Scheduler factory
|
||||
"SchedulerFactory",
|
||||
"BaseScheduler",
|
||||
|
||||
# Callbacks
|
||||
"TrainCallback",
|
||||
"GradientClippingCallback",
|
||||
"SchedulerCallback",
|
||||
"MetricLoggerCallback"
|
||||
"CheckpointCallback",
|
||||
"ProgressBarCallback",
|
||||
"MetricLoggerCallback",
|
||||
]
|
||||
|
|
@ -1,20 +1,21 @@
|
|||
"""Learning rate scheduler implementations with factory pattern."""
|
||||
|
||||
import math
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Type
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from khaosz.config.schedule_config import ScheduleConfig
|
||||
|
||||
|
||||
class BaseScheduler(LRScheduler, ABC):
|
||||
"""
|
||||
Base scheduler class for all other schedulers.
|
||||
"""
|
||||
"""Base scheduler class for all other schedulers."""
|
||||
|
||||
def __init__(self, optimizer, last_epoch: int = -1):
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
@abstractmethod
|
||||
def get_lr(self) -> List[float]:
|
||||
"""Calculate the current learning rate."""
|
||||
raise NotImplementedError
|
||||
|
||||
def state_dict(self) -> Dict[str, Any]:
|
||||
|
|
@ -24,10 +25,95 @@ class BaseScheduler(LRScheduler, ABC):
|
|||
super().load_state_dict(state_dict)
|
||||
|
||||
|
||||
class SchedulerFactory:
|
||||
"""Factory class for creating learning rate schedulers.
|
||||
|
||||
Supports decorator-based registration for extensible scheduler types.
|
||||
Also supports creation from ScheduleConfig objects.
|
||||
|
||||
Example usage:
|
||||
@SchedulerFactory.register("custom")
|
||||
class CustomScheduler(BaseScheduler):
|
||||
...
|
||||
|
||||
scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)
|
||||
|
||||
# Or from config
|
||||
config = CosineScheduleConfig(total_steps=10000)
|
||||
scheduler = SchedulerFactory.load(optimizer, config)
|
||||
"""
|
||||
|
||||
SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str):
|
||||
"""Decorator to register a new scheduler class.
|
||||
|
||||
Args:
|
||||
name: Registration name for the scheduler
|
||||
|
||||
Returns:
|
||||
Decorator function that registers the scheduler class
|
||||
"""
|
||||
def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]:
|
||||
if not issubclass(scheduler_cls, BaseScheduler):
|
||||
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
||||
cls.SCHEDULER_MAP[name] = scheduler_cls
|
||||
return scheduler_cls
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def create(cls, optimizer, schedule_type: str, **kwargs) -> BaseScheduler:
|
||||
"""Create a scheduler instance by type name.
|
||||
|
||||
Args:
|
||||
optimizer: PyTorch optimizer
|
||||
schedule_type: Type of scheduler ("cosine", "sgdr")
|
||||
**kwargs: Arguments passed to the scheduler constructor
|
||||
|
||||
Returns:
|
||||
Scheduler instance
|
||||
|
||||
Raises:
|
||||
ValueError: If schedule_type is not supported
|
||||
"""
|
||||
if schedule_type not in cls.SCHEDULER_MAP:
|
||||
raise ValueError(
|
||||
f"Unknown schedule type: '{schedule_type}'. "
|
||||
f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}"
|
||||
)
|
||||
|
||||
scheduler_cls = cls.SCHEDULER_MAP[schedule_type]
|
||||
return scheduler_cls(optimizer, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
|
||||
"""Create a scheduler from a ScheduleConfig object.
|
||||
|
||||
Args:
|
||||
optimizer: PyTorch optimizer
|
||||
schedule_config: ScheduleConfig instance
|
||||
|
||||
Returns:
|
||||
Scheduler instance
|
||||
"""
|
||||
kwargs = schedule_config.get_kwargs()
|
||||
schedule_type = kwargs.pop("schedule_type")
|
||||
return SchedulerFactory.create(optimizer, schedule_type, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def available_types(cls) -> list:
|
||||
"""Return list of registered scheduler type names."""
|
||||
return list(cls.SCHEDULER_MAP.keys())
|
||||
|
||||
|
||||
# ============== Scheduler Classes ==============
|
||||
# All scheduler classes are registered at class definition time using the decorator
|
||||
|
||||
|
||||
@SchedulerFactory.register("cosine")
|
||||
class CosineScheduler(BaseScheduler):
|
||||
"""
|
||||
Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler.
|
||||
"""
|
||||
"""Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -75,10 +161,9 @@ class CosineScheduler(BaseScheduler):
|
|||
super().load_state_dict(state_dict)
|
||||
|
||||
|
||||
@SchedulerFactory.register("sgdr")
|
||||
class SGDRScheduler(BaseScheduler):
|
||||
"""
|
||||
SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler,
|
||||
"""
|
||||
"""SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -142,23 +227,3 @@ class SGDRScheduler(BaseScheduler):
|
|||
self.min_rate = state_dict.pop('min_rate')
|
||||
self.t_mult = state_dict.pop('t_mult')
|
||||
super().load_state_dict(state_dict)
|
||||
|
||||
|
||||
|
||||
class SchedulerFactory:
|
||||
"""
|
||||
Factory class for creating learning rate schedulers.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
|
||||
kwargs = schedule_config.get_kwargs()
|
||||
schedule_type = kwargs.pop("schedule_type")
|
||||
|
||||
if schedule_type == "cosine":
|
||||
return CosineScheduler(optimizer, **kwargs)
|
||||
elif schedule_type == "sgdr":
|
||||
return SGDRScheduler(optimizer, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported schedule type: {schedule_type}")
|
||||
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
"""Training strategy implementations with factory pattern."""
|
||||
|
||||
import copy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -17,9 +19,10 @@ def unwrap_model(model: nn.Module) -> nn.Module:
|
|||
|
||||
|
||||
def create_ref_model(model: nn.Module) -> nn.Module:
|
||||
"""
|
||||
Create a reference model for DPO/GRPO training.
|
||||
Handles DDP-wrapped models safely.
|
||||
"""Create a reference model for DPO/GRPO training.
|
||||
|
||||
Handles DDP-wrapped models safely by unwrapping first,
|
||||
then creating a deep copy with frozen gradients.
|
||||
"""
|
||||
original_model = unwrap_model(model)
|
||||
ref_model = copy.deepcopy(original_model)
|
||||
|
|
@ -28,17 +31,18 @@ def create_ref_model(model: nn.Module) -> nn.Module:
|
|||
return ref_model
|
||||
|
||||
|
||||
def move_to_device(batch:Dict[str, Tensor], device: str) -> Any:
|
||||
def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
|
||||
"""Move batch tensors to specified device with non-blocking transfer."""
|
||||
return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
|
||||
|
||||
|
||||
def get_logprobs(
|
||||
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
|
||||
input_ids: Tensor,
|
||||
mask: Tensor,
|
||||
reduction: str,
|
||||
):
|
||||
"""
|
||||
Compute token-wise log probabilities from model outputs.
|
||||
"""Compute token-wise log probabilities from model outputs.
|
||||
|
||||
Args:
|
||||
model: The language model
|
||||
|
|
@ -49,7 +53,6 @@ def get_logprobs(
|
|||
Returns:
|
||||
Log probabilities with reduction applied over sequence dimension
|
||||
"""
|
||||
# reduction on seq_len dim
|
||||
allowed_reductions = ["mean", "sum", "none"]
|
||||
if reduction not in allowed_reductions:
|
||||
raise ValueError(f"reduction must be one of {allowed_reductions}, got '{reduction}'")
|
||||
|
|
@ -60,7 +63,6 @@ def get_logprobs(
|
|||
logits = model(input_ids[:, :-1], mask[:, :-1])["logits"]
|
||||
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
||||
|
||||
# [batch_size, seq_len - 1]
|
||||
token_logprobs = torch.gather(
|
||||
log_probs,
|
||||
dim=-1,
|
||||
|
|
@ -76,20 +78,112 @@ def get_logprobs(
|
|||
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
"""Abstract base class for training strategies."""
|
||||
|
||||
def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str):
|
||||
self.model = model
|
||||
self.device = device
|
||||
|
||||
@abstractmethod
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
"""Compute loss for the given batch.
|
||||
|
||||
Args:
|
||||
batch: Dictionary containing batch tensors
|
||||
|
||||
Returns:
|
||||
Computed loss tensor
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
"""Allow calling strategy directly as a callable."""
|
||||
return self.compute_loss(batch)
|
||||
|
||||
|
||||
class StrategyFactory:
|
||||
"""Factory class for creating training strategy instances.
|
||||
|
||||
Supports decorator-based registration for extensible strategy types.
|
||||
All default strategies (seq, sft, dpo, grpo) are automatically registered.
|
||||
|
||||
Example usage:
|
||||
@StrategyFactory.register("custom")
|
||||
class CustomStrategy(BaseStrategy):
|
||||
...
|
||||
|
||||
strategy = StrategyFactory.create(model, "custom", device)
|
||||
"""
|
||||
|
||||
SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"})
|
||||
STRATEGY_MAP: Dict[str, type] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str):
|
||||
"""Decorator to register a new strategy class.
|
||||
|
||||
Args:
|
||||
name: Registration name for the strategy
|
||||
|
||||
Returns:
|
||||
Decorator function that registers the strategy class
|
||||
"""
|
||||
def decorator(strategy_cls: type) -> type:
|
||||
if not issubclass(strategy_cls, BaseStrategy):
|
||||
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
||||
cls.STRATEGY_MAP[name] = strategy_cls
|
||||
return strategy_cls
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy:
|
||||
"""Create a strategy instance based on training type.
|
||||
|
||||
Args:
|
||||
model: Model instance for the strategy
|
||||
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
||||
device: Device to run the strategy on
|
||||
**kwargs: Additional arguments passed to strategy constructor
|
||||
|
||||
Returns:
|
||||
Strategy instance
|
||||
|
||||
Raises:
|
||||
ValueError: If train_type is not supported
|
||||
NotImplementedError: If train_type is in supported list but not implemented
|
||||
"""
|
||||
if train_type not in cls.SUPPORTED_STRATEGIES:
|
||||
raise ValueError(
|
||||
f"Unknown training strategy: '{train_type}'. "
|
||||
f"Supported strategies: {sorted(cls.SUPPORTED_STRATEGIES)}"
|
||||
)
|
||||
|
||||
if train_type not in cls.STRATEGY_MAP:
|
||||
raise NotImplementedError(
|
||||
f"Strategy '{train_type}' is supported but not yet implemented."
|
||||
)
|
||||
|
||||
strategy_cls = cls.STRATEGY_MAP[train_type]
|
||||
return strategy_cls(model, device, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def available_strategies(cls) -> list:
|
||||
"""Return list of registered strategy names."""
|
||||
return list(cls.STRATEGY_MAP.keys())
|
||||
|
||||
|
||||
# ============== Strategy Classes ==============
|
||||
# All strategies are registered at class definition time using the decorator
|
||||
|
||||
|
||||
@StrategyFactory.register("seq")
|
||||
class SEQStrategy(BaseStrategy):
|
||||
def __init__(self, model, device, label_smoothing):
|
||||
"""Standard next-token prediction training strategy.
|
||||
|
||||
Computes cross-entropy loss for next token prediction.
|
||||
"""
|
||||
|
||||
def __init__(self, model, device, label_smoothing: float = 0.0):
|
||||
super().__init__(model, device)
|
||||
self.label_smoothing = label_smoothing
|
||||
|
||||
|
|
@ -99,15 +193,22 @@ class SEQStrategy(BaseStrategy):
|
|||
logits = self.model(input_ids=input_ids)["logits"]
|
||||
|
||||
loss = F.cross_entropy(
|
||||
input=logits.flatten(0, 1).float(),
|
||||
target=target_ids.flatten()
|
||||
input=logits.flatten(0, 1).float(),
|
||||
target=target_ids.flatten(),
|
||||
label_smoothing=self.label_smoothing
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
@StrategyFactory.register("sft")
|
||||
class SFTStrategy(BaseStrategy):
|
||||
def __init__(self, model, device, label_smoothing):
|
||||
"""Supervised Fine-tuning strategy with loss masking.
|
||||
|
||||
Applies cross-entropy loss only to tokens where loss_mask is True.
|
||||
"""
|
||||
|
||||
def __init__(self, model, device, label_smoothing: float = 0.0):
|
||||
super().__init__(model, device)
|
||||
self.label_smoothing = label_smoothing
|
||||
|
||||
|
|
@ -122,19 +223,27 @@ class SFTStrategy(BaseStrategy):
|
|||
loss = F.cross_entropy(
|
||||
input=logits.flatten(0, 1).float(),
|
||||
target=target_ids.flatten(),
|
||||
ignore_index=ignore_index
|
||||
ignore_index=ignore_index,
|
||||
label_smoothing=self.label_smoothing
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
@StrategyFactory.register("dpo")
|
||||
class DPOStrategy(BaseStrategy):
|
||||
"""Direct Preference Optimization strategy.
|
||||
|
||||
Implements the DPO loss from the paper "Direct Preference Optimization".
|
||||
Uses a reference model to compute KL divergence penalty.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
device: str,
|
||||
beta: float,
|
||||
reduction: str,
|
||||
beta: float = 0.1,
|
||||
reduction: str = "mean",
|
||||
):
|
||||
super().__init__(model, device)
|
||||
self.ref_model = create_ref_model(model)
|
||||
|
|
@ -168,16 +277,21 @@ class DPOStrategy(BaseStrategy):
|
|||
return dpo_loss
|
||||
|
||||
|
||||
@StrategyFactory.register("grpo")
|
||||
class GRPOStrategy(BaseStrategy):
|
||||
"""Group Relative Policy Optimization strategy.
|
||||
|
||||
Implements GRPO with clipping and KL penalty.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
device: str,
|
||||
clip_eps: float,
|
||||
kl_coef: float,
|
||||
group_size: int,
|
||||
reduction: str,
|
||||
clip_eps: float = 0.2,
|
||||
kl_coef: float = 0.01,
|
||||
group_size: int = 4,
|
||||
reduction: str = "mean",
|
||||
):
|
||||
super().__init__(model, device)
|
||||
self.ref_model = create_ref_model(model)
|
||||
|
|
@ -209,16 +323,14 @@ class GRPOStrategy(BaseStrategy):
|
|||
log_probs_ref = get_logprobs(self.ref_model, full_sequences, full_masks, self.reduction)
|
||||
log_probs_ref = log_probs_ref.view(batch_size, group_size)
|
||||
|
||||
# Compute advantages from rewards
|
||||
# Compute advantages from rewards with normalization
|
||||
eps = torch.finfo(log_probs_policy.dtype).eps
|
||||
mean = rewards.mean(dim=-1, keepdim=True)
|
||||
std = rewards.std(dim=-1, keepdim=True)
|
||||
advantages = (rewards - mean) / (std + eps)
|
||||
|
||||
# log_ratio = log_probs_policy - log_probs_old
|
||||
# ratio = torch.exp(log_ratio)
|
||||
# off policy: policy_model = old_model, then ratio = 1
|
||||
ratio = torch.exp(0)
|
||||
# PPO-style clipped surrogate objective
|
||||
ratio = torch.exp(0) # Off-policy: policy_model = old_model
|
||||
surr1 = ratio * advantages
|
||||
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
||||
|
||||
|
|
@ -227,36 +339,3 @@ class GRPOStrategy(BaseStrategy):
|
|||
total_loss = policy_loss + kl_penalty
|
||||
|
||||
return total_loss
|
||||
|
||||
|
||||
class StrategyFactory:
|
||||
|
||||
def load(model, train_type, device, **kwargs):
|
||||
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
||||
"seq": lambda: SEQStrategy(
|
||||
model,
|
||||
device,
|
||||
kwargs.get("label_smoothing", 0.0)
|
||||
),
|
||||
"sft": lambda: SFTStrategy(
|
||||
model,
|
||||
device,
|
||||
kwargs.get("label_smoothing", 0.0)
|
||||
),
|
||||
"dpo": lambda: DPOStrategy(
|
||||
model,
|
||||
device,
|
||||
kwargs.get("dpo_beta"),
|
||||
kwargs.get("reduction", "mean")
|
||||
),
|
||||
"grpo": lambda: GRPOStrategy(
|
||||
model,
|
||||
device,
|
||||
kwargs.get("grpo_clip_eps"),
|
||||
kwargs.get("grpo_kl_coef"),
|
||||
kwargs.get("grpo_group_size"),
|
||||
kwargs.get("reduction", "mean")
|
||||
)
|
||||
}
|
||||
strategy = train_strategy[train_type]()
|
||||
return strategy
|
||||
Loading…
Reference in New Issue