fix: remove tuple-wrapping trailing commas in GRPODataset.__getitem__; deduplicate reference-model creation via create_ref_model; add type hints to strategy constructors
This commit is contained in:
parent
c01791ff54
commit
60f4df95bd
|
|
@ -171,13 +171,13 @@ class GRPODataset(BaseDataset):
|
|||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
    """Return one training sample for GRPO.

    Args:
        index: Sample index into the dataset.

    Returns:
        Dict with keys "prompts", "responses", "masks", "rewards",
        each mapping to the tensor fetched for the [begin_idx, end_idx) span.
    """
    begin_idx, end_idx = self.get_index(index)

    # NOTE: no trailing commas on these assignments — a trailing comma
    # would silently wrap each tensor in a 1-tuple instead of returning
    # the tensor itself (the bug this collapses/fixes).
    prompts = self._fetch_data(begin_idx, end_idx, "prompts")
    responses = self._fetch_data(begin_idx, end_idx, "responses")
    masks = self._fetch_data(begin_idx, end_idx, "masks")
    rewards = self._fetch_data(begin_idx, end_idx, "rewards")

    return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
|
||||
|
||||
|
||||
class DatasetLoader:
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -2,12 +2,32 @@ import copy
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from torch import Tensor
|
||||
from typing import Any, Callable, Dict, Union
|
||||
from typing import Any, Callable, Dict, Union, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
def unwrap_model(model: nn.Module) -> nn.Module:
    """Strip a DistributedDataParallel wrapper, returning the raw module.

    Non-DDP models pass through unchanged.
    """
    return model.module if isinstance(model, DDP) else model
|
||||
|
||||
|
||||
def create_ref_model(model: nn.Module) -> nn.Module:
    """Build a frozen, eval-mode reference model for DPO/GRPO training.

    The model is unwrapped from any DDP wrapper first, then deep-copied so
    the reference's weights are independent of subsequent optimizer steps.
    """
    # Inline DDP unwrap: copy the underlying module, not the wrapper.
    base = model.module if isinstance(model, DDP) else model
    reference = copy.deepcopy(base)
    reference.eval()
    reference.requires_grad_(False)
    return reference
|
||||
|
||||
|
||||
def move_to_device(batch: Dict[str, Tensor], device: str) -> Dict[str, Tensor]:
    """Move every tensor in *batch* to *device* and return the new dict.

    non_blocking=True lets host-to-GPU copies overlap compute when the
    source tensors live in pinned memory; it is a no-op otherwise.
    """
    # Fixed: return type was annotated Any although a Dict[str, Tensor]
    # is always returned; also PEP8 spacing after the parameter colon.
    return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
|
||||
|
||||
|
|
@ -17,6 +37,18 @@ def get_logprobs(
|
|||
mask: Tensor,
|
||||
reduction: str,
|
||||
):
|
||||
"""
|
||||
Compute token-wise log probabilities from model outputs.
|
||||
|
||||
Args:
|
||||
model: The language model
|
||||
input_ids: Input token IDs of shape [batch_size, seq_len]
|
||||
mask: Attention mask of shape [batch_size, seq_len]
|
||||
reduction: How to reduce over sequence dimension ("mean", "sum", "none")
|
||||
|
||||
Returns:
|
||||
Log probabilities with reduction applied over sequence dimension
|
||||
"""
|
||||
# reduction on seq_len dim
|
||||
allowed_reductions = ["mean", "sum", "none"]
|
||||
if reduction not in allowed_reductions:
|
||||
|
|
@ -25,7 +57,7 @@ def get_logprobs(
|
|||
shifted_input_ids = input_ids[:, 1:]
|
||||
shifted_mask = mask[:, 1:]
|
||||
|
||||
logits = model(input_ids[:, :-1, :], mask[:, :-1, :])["logits"]
|
||||
logits = model(input_ids[:, :-1], mask[:, :-1])["logits"]
|
||||
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
||||
|
||||
# [batch_size, seq_len - 1]
|
||||
|
|
@ -99,18 +131,13 @@ class SFTStrategy(BaseStrategy):
|
|||
class DPOStrategy(BaseStrategy):
    """DPO training strategy holding a frozen reference model.

    NOTE(review): span collapsed from an interleaved before/after diff —
    duplicate untyped parameters and a duplicate inline deepcopy of the
    reference model removed in favor of the shared create_ref_model helper.
    """

    def __init__(
        self,
        model: nn.Module,
        device: str,
        beta: float,
        reduction: str,
    ):
        super().__init__(model, device)
        # Frozen eval-mode copy (DDP-safe) used for reference log-probs.
        self.ref_model = create_ref_model(model)
        self.beta = beta
        self.reduction = reduction
|
||||
|
||||
|
|
@ -145,20 +172,15 @@ class GRPOStrategy(BaseStrategy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
device,
|
||||
model: nn.Module,
|
||||
device: str,
|
||||
clip_eps: float,
|
||||
kl_coef: float,
|
||||
group_size: int,
|
||||
reduction: str,
|
||||
):
|
||||
|
||||
super().__init__(model, device)
|
||||
ref_model = copy.deepcopy(self.model)
|
||||
ref_model.requires_grad_(False)
|
||||
ref_model.eval()
|
||||
|
||||
self.ref_model = ref_model
|
||||
self.ref_model = create_ref_model(model)
|
||||
self.clip_eps = clip_eps
|
||||
self.kl_coef = kl_coef
|
||||
self.group_size = group_size
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ class TrainContextBuilder:
|
|||
|
||||
def with_strategy(self) -> Self:
|
||||
self._context.strategy = StrategyFactory.load(
|
||||
model=self.config.model,
|
||||
model=self._context.model,
|
||||
train_type=self.config.strategy,
|
||||
device=get_current_device(),
|
||||
**self.config.extra_kwargs
|
||||
|
|
|
|||
|
|
@ -72,13 +72,6 @@ class Trainer:
|
|||
self._call_callbacks('on_epoch_begin', context)
|
||||
|
||||
for batch in context.dataloader:
|
||||
if context.iteration % self.train_config.accumulation_steps == 0:
|
||||
# 2. step
|
||||
self._call_callbacks('on_step_begin', context)
|
||||
context.optimizer.step()
|
||||
context.optimizer.zero_grad()
|
||||
self._call_callbacks('on_step_end', context)
|
||||
|
||||
# 3. batch
|
||||
self._call_callbacks('on_batch_begin', context)
|
||||
loss = context.strategy(batch)
|
||||
|
|
@ -91,6 +84,13 @@ class Trainer:
|
|||
|
||||
self._call_callbacks('on_batch_end', context)
|
||||
|
||||
if context.iteration % self.train_config.accumulation_steps == 0:
|
||||
# 2. step
|
||||
self._call_callbacks('on_step_begin', context)
|
||||
context.optimizer.step()
|
||||
context.optimizer.zero_grad()
|
||||
self._call_callbacks('on_step_end', context)
|
||||
|
||||
self._call_callbacks('on_epoch_end', context)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
Loading…
Reference in New Issue