diff --git a/khaosz/data/dataset.py b/khaosz/data/dataset.py
index b4e7fd6..bd40f5b 100644
--- a/khaosz/data/dataset.py
+++ b/khaosz/data/dataset.py
@@ -171,13 +171,13 @@ class GRPODataset(BaseDataset):
     
     def __getitem__(self, index: int) -> Dict[str, Tensor]:
         begin_idx, end_idx = self.get_index(index)
-        prompts = self._fetch_data(begin_idx, end_idx, "prompts"),
-        responses = self._fetch_data(begin_idx, end_idx, "responses"),
-        masks = self._fetch_data(begin_idx, end_idx, "masks"),
-        rewards = self._fetch_data(begin_idx, end_idx, "rewards")
-        
-        return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
+        prompts = self._fetch_data(begin_idx, end_idx, "prompts")
+        responses = self._fetch_data(begin_idx, end_idx, "responses")
+        masks = self._fetch_data(begin_idx, end_idx, "masks")
+        rewards = self._fetch_data(begin_idx, end_idx, "rewards")
+        return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
+
 
 class DatasetLoader:
     
     @staticmethod
diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py
index 49c1744..44f61e6 100644
--- a/khaosz/trainer/strategy.py
+++ b/khaosz/trainer/strategy.py
@@ -2,12 +2,32 @@ import copy
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch import Tensor
-from typing import Any, Callable, Dict, Union
+from typing import Any, Callable, Dict, Union, Optional
 from abc import ABC, abstractmethod
 
 
+def unwrap_model(model: nn.Module) -> nn.Module:
+    """Unwrap DDP wrapper if present to get the original model."""
+    if isinstance(model, DDP):
+        return model.module
+    return model
+
+
+def create_ref_model(model: nn.Module) -> nn.Module:
+    """
+    Create a reference model for DPO/GRPO training.
+    Handles DDP-wrapped models safely.
+    """
+    original_model = unwrap_model(model)
+    ref_model = copy.deepcopy(original_model)
+    ref_model.requires_grad_(False)
+    ref_model.eval()
+    return ref_model
+
+
 def move_to_device(batch:Dict[str, Tensor], device: str) -> Any:
     return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
 
 
@@ -17,6 +37,18 @@ def get_logprobs(
     mask: Tensor,
     reduction: str,
 ):
+    """
+    Compute token-wise log probabilities from model outputs.
+
+    Args:
+        model: The language model
+        input_ids: Input token IDs of shape [batch_size, seq_len]
+        mask: Attention mask of shape [batch_size, seq_len]
+        reduction: How to reduce over sequence dimension ("mean", "sum", "none")
+
+    Returns:
+        Log probabilities with reduction applied over sequence dimension
+    """
     # reduction on seq_len dim
     allowed_reductions = ["mean", "sum", "none"]
     if reduction not in allowed_reductions:
@@ -25,7 +57,7 @@ def get_logprobs(
     shifted_input_ids = input_ids[:, 1:]
     shifted_mask = mask[:, 1:]
 
-    logits = model(input_ids[:, :-1, :], mask[:, :-1, :])["logits"]
+    logits = model(input_ids[:, :-1], mask[:, :-1])["logits"]
     log_probs = torch.log_softmax(logits.float(), dim=-1)
     # [batch_size, seq_len - 1]
 
@@ -99,18 +131,13 @@ class SFTStrategy(BaseStrategy):
 
 class DPOStrategy(BaseStrategy):
     def __init__(
         self,
-        model,
-        device,
+        model: nn.Module,
+        device: str,
         beta: float,
         reduction: str,
-
     ):
         super().__init__(model, device)
-        ref_model = copy.deepcopy(self.model)
-        ref_model.requires_grad_(False)
-        ref_model.eval()
-
-        self.ref_model = ref_model
+        self.ref_model = create_ref_model(model)
         self.beta = beta
         self.reduction = reduction
@@ -145,20 +172,16 @@ class GRPOStrategy(BaseStrategy):
 
     def __init__(
         self,
-        model,
-        device,
+        model: nn.Module,
+        device: str,
         clip_eps: float,
         kl_coef: float,
         group_size: int,
         reduction: str,
     ):
         super().__init__(model, device)
-        ref_model = copy.deepcopy(self.model)
-        ref_model.requires_grad_(False)
-        ref_model.eval()
-
-        self.ref_model = ref_model
+        self.ref_model = create_ref_model(model)
         self.clip_eps = clip_eps
         self.kl_coef = kl_coef
         self.group_size = group_size
         self.reduction = reduction
diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py
index ad4bdf4..f2172ad 100644
--- a/khaosz/trainer/train_context.py
+++ b/khaosz/trainer/train_context.py
@@ -88,7 +88,7 @@ class TrainContextBuilder:
 
     def with_strategy(self) -> Self:
         self._context.strategy = StrategyFactory.load(
-            model=self.config.model,
+            model=self._context.model,
             train_type=self.config.strategy,
             device=get_current_device(),
             **self.config.extra_kwargs
diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index e1155d9..a75ef36 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -72,13 +72,6 @@
             self._call_callbacks('on_epoch_begin', context)
 
             for batch in context.dataloader:
-                if context.iteration % self.train_config.accumulation_steps == 0:
-                    # 2. step
-                    self._call_callbacks('on_step_begin', context)
-                    context.optimizer.step()
-                    context.optimizer.zero_grad()
-                    self._call_callbacks('on_step_end', context)
-
                 # 3. batch
                 self._call_callbacks('on_batch_begin', context)
                 loss = context.strategy(batch)
@@ -91,6 +84,13 @@
 
                 self._call_callbacks('on_batch_end', context)
 
+                if context.iteration % self.train_config.accumulation_steps == 0:
+                    # 2. step
+                    self._call_callbacks('on_step_begin', context)
+                    context.optimizer.step()
+                    context.optimizer.zero_grad()
+                    self._call_callbacks('on_step_end', context)
+
             self._call_callbacks('on_epoch_end', context)
 
         except Exception as e: