fix: 修复一些已知问题

This commit is contained in:
ViperEkura 2026-03-30 01:08:19 +08:00
parent c01791ff54
commit 60f4df95bd
4 changed files with 54 additions and 32 deletions

View File

@@ -171,13 +171,13 @@ class GRPODataset(BaseDataset):
def __getitem__(self, index: int) -> Dict[str, Tensor]: def __getitem__(self, index: int) -> Dict[str, Tensor]:
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
prompts = self._fetch_data(begin_idx, end_idx, "prompts"), prompts = self._fetch_data(begin_idx, end_idx, "prompts")
responses = self._fetch_data(begin_idx, end_idx, "responses"), responses = self._fetch_data(begin_idx, end_idx, "responses")
masks = self._fetch_data(begin_idx, end_idx, "masks"), masks = self._fetch_data(begin_idx, end_idx, "masks")
rewards = self._fetch_data(begin_idx, end_idx, "rewards") rewards = self._fetch_data(begin_idx, end_idx, "rewards")
return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
class DatasetLoader: class DatasetLoader:
@staticmethod @staticmethod

View File

@@ -2,12 +2,32 @@ import copy
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch import Tensor from torch import Tensor
from typing import Any, Callable, Dict, Union from typing import Any, Callable, Dict, Union, Optional
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
def unwrap_model(model: nn.Module) -> nn.Module:
"""Unwrap DDP wrapper if present to get the original model."""
if isinstance(model, DDP):
return model.module
return model
def create_ref_model(model: nn.Module) -> nn.Module:
"""
Create a reference model for DPO/GRPO training.
Handles DDP-wrapped models safely.
"""
original_model = unwrap_model(model)
ref_model = copy.deepcopy(original_model)
ref_model.requires_grad_(False)
ref_model.eval()
return ref_model
def move_to_device(batch:Dict[str, Tensor], device: str) -> Any: def move_to_device(batch:Dict[str, Tensor], device: str) -> Any:
return {key: value.to(device, non_blocking=True) for key, value in batch.items()} return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
@@ -17,6 +37,18 @@ def get_logprobs(
mask: Tensor, mask: Tensor,
reduction: str, reduction: str,
): ):
"""
Compute token-wise log probabilities from model outputs.
Args:
model: The language model
input_ids: Input token IDs of shape [batch_size, seq_len]
mask: Attention mask of shape [batch_size, seq_len]
reduction: How to reduce over sequence dimension ("mean", "sum", "none")
Returns:
Log probabilities with reduction applied over sequence dimension
"""
# reduction on seq_len dim # reduction on seq_len dim
allowed_reductions = ["mean", "sum", "none"] allowed_reductions = ["mean", "sum", "none"]
if reduction not in allowed_reductions: if reduction not in allowed_reductions:
@@ -25,7 +57,7 @@ def get_logprobs(
shifted_input_ids = input_ids[:, 1:] shifted_input_ids = input_ids[:, 1:]
shifted_mask = mask[:, 1:] shifted_mask = mask[:, 1:]
logits = model(input_ids[:, :-1, :], mask[:, :-1, :])["logits"] logits = model(input_ids[:, :-1], mask[:, :-1])["logits"]
log_probs = torch.log_softmax(logits.float(), dim=-1) log_probs = torch.log_softmax(logits.float(), dim=-1)
# [batch_size, seq_len - 1] # [batch_size, seq_len - 1]
@@ -99,18 +131,13 @@ class SFTStrategy(BaseStrategy):
class DPOStrategy(BaseStrategy): class DPOStrategy(BaseStrategy):
def __init__( def __init__(
self, self,
model, model: nn.Module,
device, device: str,
beta: float, beta: float,
reduction: str, reduction: str,
): ):
super().__init__(model, device) super().__init__(model, device)
ref_model = copy.deepcopy(self.model) self.ref_model = create_ref_model(model)
ref_model.requires_grad_(False)
ref_model.eval()
self.ref_model = ref_model
self.beta = beta self.beta = beta
self.reduction = reduction self.reduction = reduction
@@ -145,20 +172,15 @@ class GRPOStrategy(BaseStrategy):
def __init__( def __init__(
self, self,
model, model: nn.Module,
device, device: str,
clip_eps: float, clip_eps: float,
kl_coef: float, kl_coef: float,
group_size: int, group_size: int,
reduction: str, reduction: str,
): ):
super().__init__(model, device) super().__init__(model, device)
ref_model = copy.deepcopy(self.model) self.ref_model = create_ref_model(model)
ref_model.requires_grad_(False)
ref_model.eval()
self.ref_model = ref_model
self.clip_eps = clip_eps self.clip_eps = clip_eps
self.kl_coef = kl_coef self.kl_coef = kl_coef
self.group_size = group_size self.group_size = group_size

View File

@@ -88,7 +88,7 @@ class TrainContextBuilder:
def with_strategy(self) -> Self: def with_strategy(self) -> Self:
self._context.strategy = StrategyFactory.load( self._context.strategy = StrategyFactory.load(
model=self.config.model, model=self._context.model,
train_type=self.config.strategy, train_type=self.config.strategy,
device=get_current_device(), device=get_current_device(),
**self.config.extra_kwargs **self.config.extra_kwargs

View File

@@ -72,13 +72,6 @@ class Trainer:
self._call_callbacks('on_epoch_begin', context) self._call_callbacks('on_epoch_begin', context)
for batch in context.dataloader: for batch in context.dataloader:
if context.iteration % self.train_config.accumulation_steps == 0:
# 2. step
self._call_callbacks('on_step_begin', context)
context.optimizer.step()
context.optimizer.zero_grad()
self._call_callbacks('on_step_end', context)
# 3. batch # 3. batch
self._call_callbacks('on_batch_begin', context) self._call_callbacks('on_batch_begin', context)
loss = context.strategy(batch) loss = context.strategy(batch)
@@ -91,6 +84,13 @@ class Trainer:
self._call_callbacks('on_batch_end', context) self._call_callbacks('on_batch_end', context)
if context.iteration % self.train_config.accumulation_steps == 0:
# 2. step
self._call_callbacks('on_step_begin', context)
context.optimizer.step()
context.optimizer.zero_grad()
self._call_callbacks('on_step_end', context)
self._call_callbacks('on_epoch_end', context) self._call_callbacks('on_epoch_end', context)
except Exception as e: except Exception as e: