fix: remove accidental 1-tuple assignments in GRPODataset.__getitem__, make reference-model creation DDP-safe via create_ref_model, fix get_logprobs slicing, wire StrategyFactory to the context model, and run optimizer step after backward/batch callbacks
This commit is contained in:
parent
c01791ff54
commit
60f4df95bd
|
|
@ -171,13 +171,13 @@ class GRPODataset(BaseDataset):
|
||||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
prompts = self._fetch_data(begin_idx, end_idx, "prompts"),
|
prompts = self._fetch_data(begin_idx, end_idx, "prompts")
|
||||||
responses = self._fetch_data(begin_idx, end_idx, "responses"),
|
responses = self._fetch_data(begin_idx, end_idx, "responses")
|
||||||
masks = self._fetch_data(begin_idx, end_idx, "masks"),
|
masks = self._fetch_data(begin_idx, end_idx, "masks")
|
||||||
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
||||||
|
|
||||||
return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
|
|
||||||
|
|
||||||
|
return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
|
||||||
|
|
||||||
|
|
||||||
class DatasetLoader:
|
class DatasetLoader:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,32 @@ import copy
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import Any, Callable, Dict, Union
|
from typing import Any, Callable, Dict, Union, Optional
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the underlying module, stripping a DDP wrapper when present."""
    return model.module if isinstance(model, DDP) else model
||||||
|
|
||||||
|
|
||||||
|
def create_ref_model(model: nn.Module) -> nn.Module:
    """Create a reference model for DPO/GRPO training.

    Handles DDP-wrapped models safely: the wrapper is stripped before
    copying so the reference is the bare module. The copy is frozen
    (gradients disabled) and switched to eval mode.
    """
    reference = copy.deepcopy(unwrap_model(model))
    # requires_grad_ and eval both return the module, so chaining is safe.
    return reference.requires_grad_(False).eval()
||||||
|
|
||||||
|
|
||||||
def move_to_device(batch: Dict[str, Tensor], device: str) -> Dict[str, Tensor]:
    """Move every tensor in *batch* to *device*.

    Args:
        batch: Mapping from field name to tensor.
        device: Target device string, e.g. "cpu" or "cuda:0".

    Returns:
        A new dict with each tensor transferred. ``non_blocking=True``
        allows async host-to-device copies when the source tensor is in
        pinned memory; it is a no-op otherwise.
    """
    return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
|
||||||
|
|
||||||
|
|
@ -17,6 +37,18 @@ def get_logprobs(
|
||||||
mask: Tensor,
|
mask: Tensor,
|
||||||
reduction: str,
|
reduction: str,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Compute token-wise log probabilities from model outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The language model
|
||||||
|
input_ids: Input token IDs of shape [batch_size, seq_len]
|
||||||
|
mask: Attention mask of shape [batch_size, seq_len]
|
||||||
|
reduction: How to reduce over sequence dimension ("mean", "sum", "none")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Log probabilities with reduction applied over sequence dimension
|
||||||
|
"""
|
||||||
# reduction on seq_len dim
|
# reduction on seq_len dim
|
||||||
allowed_reductions = ["mean", "sum", "none"]
|
allowed_reductions = ["mean", "sum", "none"]
|
||||||
if reduction not in allowed_reductions:
|
if reduction not in allowed_reductions:
|
||||||
|
|
@ -25,7 +57,7 @@ def get_logprobs(
|
||||||
shifted_input_ids = input_ids[:, 1:]
|
shifted_input_ids = input_ids[:, 1:]
|
||||||
shifted_mask = mask[:, 1:]
|
shifted_mask = mask[:, 1:]
|
||||||
|
|
||||||
logits = model(input_ids[:, :-1, :], mask[:, :-1, :])["logits"]
|
logits = model(input_ids[:, :-1], mask[:, :-1])["logits"]
|
||||||
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
||||||
|
|
||||||
# [batch_size, seq_len - 1]
|
# [batch_size, seq_len - 1]
|
||||||
|
|
@ -99,18 +131,13 @@ class SFTStrategy(BaseStrategy):
|
||||||
class DPOStrategy(BaseStrategy):
|
class DPOStrategy(BaseStrategy):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model: nn.Module,
|
||||||
device,
|
device: str,
|
||||||
beta: float,
|
beta: float,
|
||||||
reduction: str,
|
reduction: str,
|
||||||
|
|
||||||
):
|
):
|
||||||
super().__init__(model, device)
|
super().__init__(model, device)
|
||||||
ref_model = copy.deepcopy(self.model)
|
self.ref_model = create_ref_model(model)
|
||||||
ref_model.requires_grad_(False)
|
|
||||||
ref_model.eval()
|
|
||||||
|
|
||||||
self.ref_model = ref_model
|
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
|
|
||||||
|
|
@ -145,20 +172,15 @@ class GRPOStrategy(BaseStrategy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model: nn.Module,
|
||||||
device,
|
device: str,
|
||||||
clip_eps: float,
|
clip_eps: float,
|
||||||
kl_coef: float,
|
kl_coef: float,
|
||||||
group_size: int,
|
group_size: int,
|
||||||
reduction: str,
|
reduction: str,
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__(model, device)
|
super().__init__(model, device)
|
||||||
ref_model = copy.deepcopy(self.model)
|
self.ref_model = create_ref_model(model)
|
||||||
ref_model.requires_grad_(False)
|
|
||||||
ref_model.eval()
|
|
||||||
|
|
||||||
self.ref_model = ref_model
|
|
||||||
self.clip_eps = clip_eps
|
self.clip_eps = clip_eps
|
||||||
self.kl_coef = kl_coef
|
self.kl_coef = kl_coef
|
||||||
self.group_size = group_size
|
self.group_size = group_size
|
||||||
|
|
|
||||||
|
|
@ -88,7 +88,7 @@ class TrainContextBuilder:
|
||||||
|
|
||||||
def with_strategy(self) -> Self:
|
def with_strategy(self) -> Self:
|
||||||
self._context.strategy = StrategyFactory.load(
|
self._context.strategy = StrategyFactory.load(
|
||||||
model=self.config.model,
|
model=self._context.model,
|
||||||
train_type=self.config.strategy,
|
train_type=self.config.strategy,
|
||||||
device=get_current_device(),
|
device=get_current_device(),
|
||||||
**self.config.extra_kwargs
|
**self.config.extra_kwargs
|
||||||
|
|
|
||||||
|
|
@ -72,13 +72,6 @@ class Trainer:
|
||||||
self._call_callbacks('on_epoch_begin', context)
|
self._call_callbacks('on_epoch_begin', context)
|
||||||
|
|
||||||
for batch in context.dataloader:
|
for batch in context.dataloader:
|
||||||
if context.iteration % self.train_config.accumulation_steps == 0:
|
|
||||||
# 2. step
|
|
||||||
self._call_callbacks('on_step_begin', context)
|
|
||||||
context.optimizer.step()
|
|
||||||
context.optimizer.zero_grad()
|
|
||||||
self._call_callbacks('on_step_end', context)
|
|
||||||
|
|
||||||
# 3. batch
|
# 3. batch
|
||||||
self._call_callbacks('on_batch_begin', context)
|
self._call_callbacks('on_batch_begin', context)
|
||||||
loss = context.strategy(batch)
|
loss = context.strategy(batch)
|
||||||
|
|
@ -91,6 +84,13 @@ class Trainer:
|
||||||
|
|
||||||
self._call_callbacks('on_batch_end', context)
|
self._call_callbacks('on_batch_end', context)
|
||||||
|
|
||||||
|
if context.iteration % self.train_config.accumulation_steps == 0:
|
||||||
|
# 2. step
|
||||||
|
self._call_callbacks('on_step_begin', context)
|
||||||
|
context.optimizer.step()
|
||||||
|
context.optimizer.zero_grad()
|
||||||
|
self._call_callbacks('on_step_end', context)
|
||||||
|
|
||||||
self._call_callbacks('on_epoch_end', context)
|
self._call_callbacks('on_epoch_end', context)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue