diff --git a/khaosz/config/__init__.py b/khaosz/config/__init__.py index 1caac27..c153017 100644 --- a/khaosz/config/__init__.py +++ b/khaosz/config/__init__.py @@ -1,16 +1,26 @@ from khaosz.config.model_config import ModelConfig from khaosz.config.param_config import BaseModelIO, ModelParameter -from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig +from khaosz.config.schedule_config import ( + ScheduleConfig, + CosineScheduleConfig, + SGDRScheduleConfig, + ScheduleConfigFactory +) from khaosz.config.train_config import TrainConfig __all__ = [ + # Base I/O "BaseModelIO", "ModelParameter", + + # Model configuration "ModelConfig", "TrainConfig", + # Schedule configuration "ScheduleConfig", "CosineScheduleConfig", "SGDRScheduleConfig", + "ScheduleConfigFactory", ] \ No newline at end of file diff --git a/khaosz/config/schedule_config.py b/khaosz/config/schedule_config.py index 10c9d39..5a53a33 100644 --- a/khaosz/config/schedule_config.py +++ b/khaosz/config/schedule_config.py @@ -1,10 +1,15 @@ -from typing import Any, Dict +from typing import Any, Dict, Type from abc import ABC, abstractmethod from dataclasses import dataclass, field @dataclass class ScheduleConfig(ABC): + """Base configuration class for learning rate schedulers. + + Provides common validation and interface for all schedule types. + """ + schedule_type: str = field( default="cosine", metadata={ @@ -23,6 +28,7 @@ class ScheduleConfig(ABC): @abstractmethod def get_kwargs(self) -> Dict[str, Any]: + """Get configuration kwargs for scheduler creation.""" raise NotImplementedError def validate(self) -> None: @@ -35,6 +41,8 @@ class ScheduleConfig(ABC): @dataclass class CosineScheduleConfig(ScheduleConfig): + """Cosine annealing learning rate schedule configuration.""" + total_steps: int = field( default=None, metadata={"help": "Total training steps for cosine schedule."} @@ -63,6 +71,8 @@ class CosineScheduleConfig(ScheduleConfig): @dataclass class SGDRScheduleConfig(ScheduleConfig): + """Stochastic Gradient Descent with Warm Restarts schedule configuration.""" + cycle_length: int = field( default=1000, metadata={"help": "Length of the first cycle in steps."} @@ -90,4 +100,51 @@ class SGDRScheduleConfig(ScheduleConfig): if self.cycle_length <= 0: raise ValueError(f"cycle_length must be positive, got {self.cycle_length}") if self.t_mult < 1: - raise ValueError(f"t_mult must be >= 1, got {self.t_mult}") \ No newline at end of file + raise ValueError(f"t_mult must be >= 1, got {self.t_mult}") + + +class ScheduleConfigFactory: + """Factory class for creating ScheduleConfig instances. + + Supports both direct instantiation and factory creation methods. + + Example usage: + # Direct creation + config = CosineScheduleConfig(total_steps=10000) + + # Factory method + config = ScheduleConfigFactory.create("cosine", total_steps=10000) + """ + + CONFIG_MAP: Dict[str, Type[ScheduleConfig]] = { + "cosine": CosineScheduleConfig, + "sgdr": SGDRScheduleConfig, + } + + @classmethod + def create(cls, schedule_type: str, **kwargs) -> ScheduleConfig: + """Create a schedule config instance. + + Args: + schedule_type: Type of schedule ("cosine", "sgdr") + **kwargs: Arguments passed to the config constructor + + Returns: + ScheduleConfig instance + + Raises: + ValueError: If schedule_type is not supported + """ + if schedule_type not in cls.CONFIG_MAP: + raise ValueError( + f"Unknown schedule type: '{schedule_type}'. " + f"Supported types: {sorted(cls.CONFIG_MAP.keys())}" + ) + + config_cls = cls.CONFIG_MAP[schedule_type] + return config_cls(**kwargs) + + @classmethod + def available_types(cls) -> list: + """Return list of available schedule type names.""" + return list(cls.CONFIG_MAP.keys()) \ No newline at end of file diff --git a/khaosz/data/__init__.py b/khaosz/data/__init__.py index 44648a8..a10af3e 100644 --- a/khaosz/data/__init__.py +++ b/khaosz/data/__init__.py @@ -5,20 +5,31 @@ from khaosz.data.dataset import ( SFTDataset, GRPODataset, MultiSegmentFetcher, - DatasetLoader + DatasetLoader, + DatasetFactory ) from khaosz.data.tokenizer import BpeTokenizer from khaosz.data.sampler import ResumableDistributedSampler __all__ = [ + # Base classes "BaseDataset", + + # Dataset implementations "SEQDataset", "SFTDataset", "DPODataset", "GRPODataset", + + # Fetchers "MultiSegmentFetcher", + + # Factory (DatasetLoader is alias for backward compatibility) "DatasetLoader", + "DatasetFactory", + + # Tokenizer and sampler "BpeTokenizer", "ResumableDistributedSampler" ] \ No newline at end of file diff --git a/khaosz/data/dataset.py b/khaosz/data/dataset.py index bd40f5b..fc64ae3 100644 --- a/khaosz/data/dataset.py +++ b/khaosz/data/dataset.py @@ -1,3 +1,5 @@ +"""Dataset implementations with factory pattern for training.""" + import torch import bisect @@ -8,8 +10,13 @@ from khaosz.data.serialization import load_h5 from typing import Callable, List, Dict, Literal, Optional, Union - class BaseSegmentFetcher: + """Fetches data segments across multiple tensor segments. + + Maintains cumulative lengths for efficient range queries across + multiple discontinuous segments. + """ + def __init__(self, segments: List[Tensor]): self.segments = segments self.cum_lengths = [] @@ -25,12 +32,21 @@ class BaseSegmentFetcher: return self.total_length def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: + """Fetch data in the range [begin_idx, end_idx). + + Args: + begin_idx: Starting index (inclusive) + end_idx: Ending index (exclusive) + + Returns: + Concatenated tensor of data in the specified range + """ if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length): raise ValueError("begin_idx or end_idx out of bounds") if begin_idx >= end_idx: return torch.tensor([], dtype=torch.long) - # fix the range index bug + # Find segment boundaries for the range seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx) seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx) @@ -44,9 +60,14 @@ class BaseSegmentFetcher: result_segments.append(data) return torch.cat(result_segments, dim=0) - + class MultiSegmentFetcher: + """Manages multiple segment fetchers for different data keys. + + Each key corresponds to a different type of data (e.g., "sequence", "mask"). + """ + def __init__(self, muti_segments: Dict): self.muti_keys = list(muti_segments.keys()) self.muti_fetchers = { @@ -55,10 +76,21 @@ class MultiSegmentFetcher: } def __len__(self) -> int: + """Returns the minimum length across all fetchers.""" len_list = [len(seg) for seg in self.muti_fetchers.values()] return min(len_list) def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict: + """Fetch data for specific keys. + + Args: + begin_idx: Starting index + end_idx: Ending index + keys: Single key or list of keys to fetch + + Returns: + Dictionary of tensors if multiple keys, single tensor if one key + """ fetch_dict = {} keys = [keys] if isinstance(keys, str) else keys @@ -70,32 +102,56 @@ class MultiSegmentFetcher: return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]] def fetch_data(self, begin_idx: int, end_idx: int) -> Dict: + """Fetch all keys.""" return self.key_fetch(begin_idx, end_idx, self.muti_keys) class BaseDataset(Dataset, ABC): + """Abstract base class for all dataset types. + + Implements common functionality for window-based data fetching. + """ + def __init__(self, window_size: int, stride: int): super().__init__() self.segments = {} self.window_size = window_size self.stride = stride self.total_samples = None + self.fetcher: Optional[MultiSegmentFetcher] = None def load(self, load_path: str): + """Load dataset from HDF5 file. + + Args: + load_path: Path to the HDF5 data file + """ self.segments = load_h5(load_path) self.fetcher = MultiSegmentFetcher(self.segments) self.total_samples = len(self.fetcher) - def get_index(self, index: int) -> int: + def get_index(self, index: int) -> tuple: + """Calculate begin and end indices for a sample. + + Args: + index: Sample index + + Returns: + Tuple of (begin_idx, end_idx) + """ assert self.total_samples > self.window_size begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size) end_idx = min(begin_idx + self.window_size, self.total_samples - 1) return begin_idx, end_idx - + @abstractmethod def __getitem__(self, index: int) -> Dict[str, Tensor]: + """Get a single sample by index. + + Must be implemented by subclasses. + """ raise NotImplementedError def __len__(self) -> int: @@ -103,30 +159,125 @@ class BaseDataset(Dataset, ABC): if self.total_samples <= self.window_size: return 0 return (self.total_samples - 1 - self.window_size) // self.stride + 1 - + +class DatasetFactory: + """Factory class for creating dataset instances. + + Supports decorator-based registration for extensible dataset types. + All default dataset types (seq, sft, dpo, grpo) are registered automatically + when their classes are defined with the decorator. + + Example usage: + @DatasetFactory.register("custom") + class CustomDataset(BaseDataset): + ... + + dataset = DatasetFactory.create("custom", window_size, stride) + """ + + SUPPORTED_TYPES = frozenset({"seq", "sft", "dpo", "grpo"}) + DATASET_MAP: Dict[str, type] = {} + + @classmethod + def register(cls, name: str): + """Decorator to register a new dataset class. + + Args: + name: Registration name for the dataset type + + Returns: + Decorator function that registers the dataset class + """ + def decorator(dataset_cls: type) -> type: + if not issubclass(dataset_cls, BaseDataset): + raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset") + cls.DATASET_MAP[name] = dataset_cls + return dataset_cls + return decorator + + @classmethod + def create(cls, train_type: str, window_size: int, stride: int) -> BaseDataset: + """Create a dataset instance. + + Args: + train_type: Type of training ("seq", "sft", "dpo", "grpo") + window_size: Window size for data sampling + stride: Stride between consecutive samples + + Returns: + Dataset instance + """ + if train_type not in cls.SUPPORTED_TYPES: + raise ValueError( + f"Unknown dataset type: '{train_type}'. " + f"Supported types: {sorted(cls.SUPPORTED_TYPES)}" + ) + + if train_type not in cls.DATASET_MAP: + raise NotImplementedError( + f"Dataset type '{train_type}' is supported but not yet implemented." + ) + + dataset_cls = cls.DATASET_MAP[train_type] + return dataset_cls(window_size, stride) + + @classmethod + def load(cls, train_type: str, load_path: str, window_size: int, stride: Optional[int] = None) -> BaseDataset: + """Create and load a dataset in one step. + + Args: + train_type: Type of training dataset + load_path: Path to the data file + window_size: Window size for data sampling + stride: Stride between consecutive samples (default: same as window_size) + + Returns: + Loaded dataset instance + """ + if stride is None: + stride = window_size + + dataset = cls.create(train_type, window_size, stride) + dataset.load(load_path) + + return dataset + + @classmethod + def available_types(cls) -> list: + """Return list of registered dataset type names.""" + return list(cls.DATASET_MAP.keys()) + + +# ============== Dataset Classes ============== +# All dataset classes are registered at class definition time using the decorator + + +@DatasetFactory.register("seq") class SEQDataset(BaseDataset): + """Dataset for sequential next-token prediction training.""" + def __init__(self, window_size: int, stride: int): super().__init__(window_size, stride) - self.fetcher = MultiSegmentFetcher(self.segments) def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: return self.fetcher.key_fetch(begin_idx, end_idx, "sequence") def __getitem__(self, index): - # fix the range index bug begin_idx, end_idx = self.get_index(index) x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long) y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long) return {"input_ids": x, "target_ids": y} - - + + +@DatasetFactory.register("sft") class SFTDataset(BaseDataset): + """Dataset for supervised fine-tuning with loss masking.""" + def __init__(self, window_size: int, stride: int): super().__init__(window_size, stride) - self.fetcher = MultiSegmentFetcher(self.segments) def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: return self.fetcher.key_fetch(begin_idx, end_idx, key) @@ -141,10 +292,12 @@ class SFTDataset(BaseDataset): return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask} +@DatasetFactory.register("dpo") class DPODataset(BaseDataset): + """Dataset for Direct Preference Optimization training.""" + def __init__(self, window_size: int, stride: int): super().__init__(window_size, stride) - self.fetcher = MultiSegmentFetcher(self.segments) def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: return self.fetcher.key_fetch(begin_idx, end_idx, key) @@ -160,10 +313,12 @@ class DPODataset(BaseDataset): return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask} +@DatasetFactory.register("grpo") class GRPODataset(BaseDataset): + """Dataset for Group Relative Policy Optimization training.""" + def __init__(self, window_size: int, stride: int): super().__init__(window_size, stride) - self.fetcher = MultiSegmentFetcher(self.segments) def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: return self.fetcher.key_fetch(begin_idx, end_idx, key) @@ -179,24 +334,5 @@ class GRPODataset(BaseDataset): return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards} -class DatasetLoader: - @staticmethod - def load( - train_type: Literal["seq", "sft", "dpo"], - load_path: str, - window_size: int, - stride: Optional[int] = None, - ) -> BaseDataset: - if stride is None: - stride = window_size - - dataset_router: Dict[str, Callable[[int], BaseDataset]] = { - "seq": lambda window_size: SEQDataset(window_size, stride), - "sft": lambda window_size: SFTDataset(window_size, stride), - "dpo": lambda window_size: DPODataset(window_size, stride), - "grpo": lambda window_size: GRPODataset(window_size, stride), - } - dataset = dataset_router[train_type](window_size) - dataset.load(load_path) - - return dataset +# Backward compatibility alias +DatasetLoader = DatasetFactory diff --git a/khaosz/inference/generator.py b/khaosz/inference/generator.py index d42489a..e047899 100644 --- a/khaosz/inference/generator.py +++ b/khaosz/inference/generator.py @@ -77,6 +77,7 @@ class GenerationRequest: query: Input query (string or list of strings for batch). history: Conversation history. system_prompt: System prompt for the conversation. + stream: Whether to use streaming generation. """ top_k: int top_p: float @@ -86,6 +87,7 @@ class GenerationRequest: query: Union[str, List[str]] history: Optional[Union[HistoryType, List[HistoryType]]] = None system_prompt: Optional[str] = None + stream: bool = False def __post_init__(self): if not isinstance(self.top_k, int) or self.top_k < 0: @@ -233,33 +235,62 @@ class EmbeddingEncoder(EmbeddingEncoderCore): class GeneratorFactory: - """Factory class for creating appropriate generator instances based on request features.""" + """Factory class for creating generator instances. + + Provides smart generator selection based on request characteristics: + - Streaming: Use StreamGenerator for streaming output + - Batch: Use BatchGenerator when query is a list + - Single: Use LoopGenerator for single query non-streaming + + Example usage: + generator = GeneratorFactory.create_generator(parameter, request) + result = generator.generate(request) + """ @staticmethod - def create_generator(parameter: ModelParameter, request: GenerationRequest): - """ - Create a generator based on the characteristics of GenerationRequest. - Args: - parameter: Model parameters - request: Generation request + def create_generator(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore: + """Create a generator based on request characteristics. + Args: + parameter: Model parameters containing model, tokenizer, config + request: Generation request with query, options, etc. + Returns: - Subclass instance of GeneratorCore + Appropriate GeneratorCore subclass instance """ - - # Streaming generation detection: check stream field + # Streaming generation: check stream field first if request.stream: return StreamGenerator(parameter) - # Batch generation detection: query is a list + # Batch generation: query is a list of strings if isinstance(request.query, list): return BatchGenerator(parameter) - # Default return LoopGenerator + # Default: single query non-streaming return LoopGenerator(parameter) @staticmethod - def create_encoder(parameter: ModelParameter): - """Create an EmbeddingEncoder instance""" + def create_encoder(parameter: ModelParameter) -> EmbeddingEncoderCore: + """Create an embedding encoder instance. + + Args: + parameter: Model parameters + + Returns: + EmbeddingEncoderCore instance + """ return EmbeddingEncoder(parameter) + + @classmethod + def create(cls, parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore: + """Convenience method that delegates to create_generator. + + Args: + parameter: Model parameters + request: Generation request + + Returns: + Generator instance + """ + return cls.create_generator(parameter, request) \ No newline at end of file diff --git a/khaosz/parallel/setup.py b/khaosz/parallel/setup.py index 932dc38..c1c6686 100644 --- a/khaosz/parallel/setup.py +++ b/khaosz/parallel/setup.py @@ -134,8 +134,8 @@ def spawn_parallel_fn( if world_size == 1: device_ids = device_ids or [0] - deice_id = torch.device(device_type, device_ids[0]) - os.environ["LOCAL_DEVICE"] = str(deice_id) + device_id = torch.device(device_type, device_ids[0]) + os.environ["LOCAL_DEVICE"] = str(device_id) func(**kwargs) return diff --git a/khaosz/trainer/__init__.py b/khaosz/trainer/__init__.py index bbe99ee..af1bd3b 100644 --- a/khaosz/trainer/__init__.py +++ b/khaosz/trainer/__init__.py @@ -1,29 +1,33 @@ from khaosz.trainer.trainer import Trainer -from khaosz.trainer.strategy import StrategyFactory -from khaosz.trainer.schedule import SchedulerFactory +from khaosz.trainer.strategy import StrategyFactory, BaseStrategy +from khaosz.trainer.schedule import SchedulerFactory, BaseScheduler from khaosz.trainer.train_callback import ( TrainCallback, - ProgressBarCallback, - CheckpointCallback, - TrainCallback, + GradientClippingCallback, SchedulerCallback, - MetricLoggerCallback + CheckpointCallback, + ProgressBarCallback, + MetricLoggerCallback, ) __all__ = [ - # trainer + # Main trainer "Trainer", - # factory + # Strategy factory "StrategyFactory", - "SchedulerFactory", + "BaseStrategy", - # callback - "TrainCallback", - "ProgressBarCallback", - "CheckpointCallback", + # Scheduler factory + "SchedulerFactory", + "BaseScheduler", + + # Callbacks "TrainCallback", + "GradientClippingCallback", "SchedulerCallback", - "MetricLoggerCallback" + "CheckpointCallback", + "ProgressBarCallback", + "MetricLoggerCallback", ] \ No newline at end of file diff --git a/khaosz/trainer/schedule.py b/khaosz/trainer/schedule.py index 84135f0..8853690 100644 --- a/khaosz/trainer/schedule.py +++ b/khaosz/trainer/schedule.py @@ -1,20 +1,21 @@ +"""Learning rate scheduler implementations with factory pattern.""" + import math from abc import abstractmethod, ABC -from typing import Any, Dict, List +from typing import Any, Dict, List, Type from torch.optim.lr_scheduler import LRScheduler from khaosz.config.schedule_config import ScheduleConfig class BaseScheduler(LRScheduler, ABC): - """ - Base scheduler class for all other schedulers. - """ + """Base scheduler class for all other schedulers.""" def __init__(self, optimizer, last_epoch: int = -1): super().__init__(optimizer, last_epoch) @abstractmethod def get_lr(self) -> List[float]: + """Calculate the current learning rate.""" raise NotImplementedError def state_dict(self) -> Dict[str, Any]: @@ -24,10 +25,95 @@ class BaseScheduler(LRScheduler, ABC): super().load_state_dict(state_dict) +class SchedulerFactory: + """Factory class for creating learning rate schedulers. + + Supports decorator-based registration for extensible scheduler types. + Also supports creation from ScheduleConfig objects. + + Example usage: + @SchedulerFactory.register("custom") + class CustomScheduler(BaseScheduler): + ... + + scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs) + + # Or from config + config = CosineScheduleConfig(total_steps=10000) + scheduler = SchedulerFactory.load(optimizer, config) + """ + + SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {} + + @classmethod + def register(cls, name: str): + """Decorator to register a new scheduler class. + + Args: + name: Registration name for the scheduler + + Returns: + Decorator function that registers the scheduler class + """ + def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]: + if not issubclass(scheduler_cls, BaseScheduler): + raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler") + cls.SCHEDULER_MAP[name] = scheduler_cls + return scheduler_cls + return decorator + + @classmethod + def create(cls, optimizer, schedule_type: str, **kwargs) -> BaseScheduler: + """Create a scheduler instance by type name. + + Args: + optimizer: PyTorch optimizer + schedule_type: Type of scheduler ("cosine", "sgdr") + **kwargs: Arguments passed to the scheduler constructor + + Returns: + Scheduler instance + + Raises: + ValueError: If schedule_type is not supported + """ + if schedule_type not in cls.SCHEDULER_MAP: + raise ValueError( + f"Unknown schedule type: '{schedule_type}'. " + f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}" + ) + + scheduler_cls = cls.SCHEDULER_MAP[schedule_type] + return scheduler_cls(optimizer, **kwargs) + + @staticmethod + def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler: + """Create a scheduler from a ScheduleConfig object. + + Args: + optimizer: PyTorch optimizer + schedule_config: ScheduleConfig instance + + Returns: + Scheduler instance + """ + kwargs = schedule_config.get_kwargs() + schedule_type = kwargs.pop("schedule_type") + return SchedulerFactory.create(optimizer, schedule_type, **kwargs) + + @classmethod + def available_types(cls) -> list: + """Return list of registered scheduler type names.""" + return list(cls.SCHEDULER_MAP.keys()) + + +# ============== Scheduler Classes ============== +# All scheduler classes are registered at class definition time using the decorator + + +@SchedulerFactory.register("cosine") class CosineScheduler(BaseScheduler): - """ - Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler. - """ + """Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler.""" def __init__( self, @@ -75,10 +161,9 @@ class CosineScheduler(BaseScheduler): super().load_state_dict(state_dict) +@SchedulerFactory.register("sgdr") class SGDRScheduler(BaseScheduler): - """ - SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler, - """ + """SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler.""" def __init__( self, @@ -141,24 +226,4 @@ class SGDRScheduler(BaseScheduler): self.cycle_length = state_dict.pop('cycle_length') self.min_rate = state_dict.pop('min_rate') self.t_mult = state_dict.pop('t_mult') - super().load_state_dict(state_dict) - - - -class SchedulerFactory: - """ - Factory class for creating learning rate schedulers. - """ - - @staticmethod - def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler: - kwargs = schedule_config.get_kwargs() - schedule_type = kwargs.pop("schedule_type") - - if schedule_type == "cosine": - return CosineScheduler(optimizer, **kwargs) - elif schedule_type == "sgdr": - return SGDRScheduler(optimizer, **kwargs) - else: - raise ValueError(f"Unsupported schedule type: {schedule_type}") - \ No newline at end of file + super().load_state_dict(state_dict) \ No newline at end of file diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index 44f61e6..85679d2 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -1,3 +1,5 @@ +"""Training strategy implementations with factory pattern.""" + import copy import torch import torch.nn as nn @@ -17,9 +19,10 @@ def unwrap_model(model: nn.Module) -> nn.Module: def create_ref_model(model: nn.Module) -> nn.Module: - """ - Create a reference model for DPO/GRPO training. - Handles DDP-wrapped models safely. + """Create a reference model for DPO/GRPO training. + + Handles DDP-wrapped models safely by unwrapping first, + then creating a deep copy with frozen gradients. """ original_model = unwrap_model(model) ref_model = copy.deepcopy(original_model) @@ -28,17 +31,18 @@ def create_ref_model(model: nn.Module) -> nn.Module: return ref_model -def move_to_device(batch:Dict[str, Tensor], device: str) -> Any: +def move_to_device(batch: Dict[str, Tensor], device: str) -> Any: + """Move batch tensors to specified device with non-blocking transfer.""" return {key: value.to(device, non_blocking=True) for key, value in batch.items()} + def get_logprobs( model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], input_ids: Tensor, mask: Tensor, reduction: str, ): - """ - Compute token-wise log probabilities from model outputs. + """Compute token-wise log probabilities from model outputs. Args: model: The language model @@ -49,7 +53,6 @@ def get_logprobs( Returns: Log probabilities with reduction applied over sequence dimension """ - # reduction on seq_len dim allowed_reductions = ["mean", "sum", "none"] if reduction not in allowed_reductions: raise ValueError(f"reduction must be one of {allowed_reductions}, got '{reduction}'") @@ -60,7 +63,6 @@ def get_logprobs( logits = model(input_ids[:, :-1], mask[:, :-1])["logits"] log_probs = torch.log_softmax(logits.float(), dim=-1) - # [batch_size, seq_len - 1] token_logprobs = torch.gather( log_probs, dim=-1, @@ -76,20 +78,112 @@ def get_logprobs( class BaseStrategy(ABC): + """Abstract base class for training strategies.""" + def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str): self.model = model self.device = device @abstractmethod def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: + """Compute loss for the given batch. + + Args: + batch: Dictionary containing batch tensors + + Returns: + Computed loss tensor + """ raise NotImplementedError def __call__(self, batch: Dict[str, Tensor]) -> Tensor: + """Allow calling strategy directly as a callable.""" return self.compute_loss(batch) +class StrategyFactory: + """Factory class for creating training strategy instances. + + Supports decorator-based registration for extensible strategy types. + All default strategies (seq, sft, dpo, grpo) are automatically registered. + + Example usage: + @StrategyFactory.register("custom") + class CustomStrategy(BaseStrategy): + ... + + strategy = StrategyFactory.create(model, "custom", device) + """ + + SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"}) + STRATEGY_MAP: Dict[str, type] = {} + + @classmethod + def register(cls, name: str): + """Decorator to register a new strategy class. + + Args: + name: Registration name for the strategy + + Returns: + Decorator function that registers the strategy class + """ + def decorator(strategy_cls: type) -> type: + if not issubclass(strategy_cls, BaseStrategy): + raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy") + cls.STRATEGY_MAP[name] = strategy_cls + return strategy_cls + return decorator + + @classmethod + def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy: + """Create a strategy instance based on training type. + + Args: + model: Model instance for the strategy + train_type: Type of training ("seq", "sft", "dpo", "grpo") + device: Device to run the strategy on + **kwargs: Additional arguments passed to strategy constructor + + Returns: + Strategy instance + + Raises: + ValueError: If train_type is not supported + NotImplementedError: If train_type is in supported list but not implemented + """ + if train_type not in cls.SUPPORTED_STRATEGIES: + raise ValueError( + f"Unknown training strategy: '{train_type}'. " + f"Supported strategies: {sorted(cls.SUPPORTED_STRATEGIES)}" + ) + + if train_type not in cls.STRATEGY_MAP: + raise NotImplementedError( + f"Strategy '{train_type}' is supported but not yet implemented." + ) + + strategy_cls = cls.STRATEGY_MAP[train_type] + return strategy_cls(model, device, **kwargs) + + @classmethod + def available_strategies(cls) -> list: + """Return list of registered strategy names.""" + return list(cls.STRATEGY_MAP.keys()) + + +# ============== Strategy Classes ============== +# All strategies are registered at class definition time using the decorator + + +@StrategyFactory.register("seq") class SEQStrategy(BaseStrategy): - def __init__(self, model, device, label_smoothing): + """Standard next-token prediction training strategy. + + Computes cross-entropy loss for next token prediction. + """ + + def __init__(self, model, device, label_smoothing: float = 0.0): super().__init__(model, device) self.label_smoothing = label_smoothing @@ -99,15 +193,22 @@ class SEQStrategy(BaseStrategy): logits = self.model(input_ids=input_ids)["logits"] loss = F.cross_entropy( - input=logits.flatten(0, 1).float(), - target=target_ids.flatten() + input=logits.flatten(0, 1).float(), + target=target_ids.flatten(), + label_smoothing=self.label_smoothing ) return loss - + +@StrategyFactory.register("sft") class SFTStrategy(BaseStrategy): - def __init__(self, model, device, label_smoothing): + """Supervised Fine-tuning strategy with loss masking. + + Applies cross-entropy loss only to tokens where loss_mask is True. + """ + + def __init__(self, model, device, label_smoothing: float = 0.0): super().__init__(model, device) self.label_smoothing = label_smoothing @@ -122,19 +223,27 @@ class SFTStrategy(BaseStrategy): loss = F.cross_entropy( input=logits.flatten(0, 1).float(), target=target_ids.flatten(), - ignore_index=ignore_index + ignore_index=ignore_index, + label_smoothing=self.label_smoothing ) return loss +@StrategyFactory.register("dpo") class DPOStrategy(BaseStrategy): + """Direct Preference Optimization strategy. + + Implements the DPO loss from the paper "Direct Preference Optimization". + Uses a reference model to compute KL divergence penalty. + """ + def __init__( self, model: nn.Module, device: str, - beta: float, - reduction: str, + beta: float = 0.1, + reduction: str = "mean", ): super().__init__(model, device) self.ref_model = create_ref_model(model) @@ -168,16 +277,21 @@ class DPOStrategy(BaseStrategy): return dpo_loss +@StrategyFactory.register("grpo") class GRPOStrategy(BaseStrategy): + """Group Relative Policy Optimization strategy. + + Implements GRPO with clipping and KL penalty. + """ def __init__( self, model: nn.Module, device: str, - clip_eps: float, - kl_coef: float, - group_size: int, - reduction: str, + clip_eps: float = 0.2, + kl_coef: float = 0.01, + group_size: int = 4, + reduction: str = "mean", ): super().__init__(model, device) self.ref_model = create_ref_model(model) @@ -209,16 +323,14 @@ class GRPOStrategy(BaseStrategy): log_probs_ref = get_logprobs(self.ref_model, full_sequences, full_masks, self.reduction) log_probs_ref = log_probs_ref.view(batch_size, group_size) - # Compute advantages from rewards + # Compute advantages from rewards with normalization eps = torch.finfo(log_probs_policy.dtype).eps mean = rewards.mean(dim=-1, keepdim=True) std = rewards.std(dim=-1, keepdim=True) advantages = (rewards - mean) / (std + eps) - # log_ratio = log_probs_policy - log_probs_old - # ratio = torch.exp(log_ratio) - # off policy: policy_model = old_model, then ratio = 1 - ratio = torch.exp(0) + # PPO-style clipped surrogate objective + ratio = torch.exp(0) # Off-policy: policy_model = old_model surr1 = ratio * advantages surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages @@ -227,36 +339,3 @@ class GRPOStrategy(BaseStrategy): total_loss = policy_loss + kl_penalty return total_loss - - -class StrategyFactory: - - def load(model, train_type, device, **kwargs): - train_strategy: Dict[str, Callable[[], BaseStrategy]] = { - "seq": lambda: SEQStrategy( - model, - device, - kwargs.get("label_smoothing", 0.0) - ), - "sft": lambda: SFTStrategy( - model, - device, - kwargs.get("label_smoothing", 0.0) - ), - "dpo": lambda: DPOStrategy( - model, - device, - kwargs.get("dpo_beta"), - kwargs.get("reduction", "mean") - ), - "grpo": lambda: GRPOStrategy( - model, - device, - kwargs.get("grpo_clip_eps"), - kwargs.get("grpo_kl_coef"), - kwargs.get("grpo_group_size"), - kwargs.get("reduction", "mean") - ) - } - strategy = train_strategy[train_type]() - return strategy \ No newline at end of file