fix: resolve reinforcement learning algorithm issues
parent a5574f92e2
commit 0f518473af
@@ -15,35 +15,32 @@ def get_logprobs(
     model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
     input_ids: Tensor,
     mask: Tensor,
-    pad_token_id: int,
     reduction: str,
 ):
+    # reduction on seq_len dim
     allowed_reductions = ["mean", "sum", "none"]
     if reduction not in allowed_reductions:
         raise ValueError(f"reduction must be one of {allowed_reductions}, got '{reduction}'")
 
-    pad_mask = input_ids.ne(pad_token_id)
-    logits = model(input_ids, pad_mask)["logits"]
-    log_probs = torch.log_softmax(logits.float(), dim=-1)
-
-    shifted_log_probs = log_probs[:, :-1, :]
     shifted_input_ids = input_ids[:, 1:]
     shifted_mask = mask[:, 1:]
-    prompt_mask = pad_mask[:, 1:]
+
+    logits = model(input_ids[:, :-1], mask[:, :-1])["logits"]
+    log_probs = torch.log_softmax(logits.float(), dim=-1)
 
+    # [batch_size, seq_len - 1]
     token_logprobs = torch.gather(
-        shifted_log_probs,
+        log_probs,
         dim=-1,
         index=shifted_input_ids.unsqueeze(-1)
     ).squeeze(-1)
-    valid_mask = (prompt_mask & shifted_mask)
 
     if reduction == "mean":
-        return (token_logprobs * valid_mask).mean(dim=-1)
+        return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(dim=-1).clamp(min=1.0)
     elif reduction == "sum":
-        return (token_logprobs * valid_mask).sum(dim=-1)
+        return (token_logprobs * shifted_mask).sum(dim=-1)
     else:
-        return token_logprobs
+        return token_logprobs * shifted_mask
 
 
 class BaseStrategy(ABC):
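
Note on the mean reduction: the old code took .mean(dim=-1) over the full
sequence length, so padded positions dragged the per-sequence log-probability
toward zero. The new version divides by the number of unmasked tokens only. A
minimal sketch with toy values (not from the repo) showing the difference:

    import torch

    token_logprobs = torch.tensor([[-1.0, -1.0, -1.0, 0.0]])
    mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # last position is padding

    # Old behavior: divides by the full length (4), counting masked positions.
    old_mean = (token_logprobs * mask).mean(dim=-1)  # tensor([-0.7500])

    # New behavior: divides by the 3 unmasked tokens; clamp(min=1.0) guards
    # an all-masked row against division by zero.
    new_mean = (token_logprobs * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1.0)  # tensor([-1.])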
@@ -104,7 +101,6 @@ class DpoStrategy(BaseStrategy):
         self,
         model,
         device,
-        pad_token_id: int,
         beta: float,
         reduction: str,
@@ -113,30 +109,35 @@ class DpoStrategy(BaseStrategy):
         ref_model = copy.deepcopy(self.model)
         ref_model.requires_grad_(False)
         ref_model.eval()
 
         self.ref_model = ref_model
-        self.pad_token_id = pad_token_id
         self.beta = beta
         self.reduction = reduction
 
     def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
         batch = move_to_device(batch, self.device)
-        good_ids, bad_ids = batch["chosen"], batch["rejected"]
-        good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"]
+        chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
+        chosen_mask, rejected_mask = batch["chosen_mask"], batch["rejected_mask"]
 
-        log_pi_good = get_logprobs(self.model, good_ids, good_mask, self.pad_token_id, self.reduction)
-        log_pi_bad = get_logprobs(self.model, bad_ids, bad_mask, self.pad_token_id, self.reduction)
+        concat_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
+        concat_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
+
+        log_pi = get_logprobs(self.model, concat_ids, concat_mask, self.reduction)
 
         with torch.no_grad():
-            log_ref_good = get_logprobs(self.ref_model, good_ids, good_mask, self.pad_token_id, self.reduction)
-            log_ref_bad = get_logprobs(self.ref_model, bad_ids, bad_mask, self.pad_token_id, self.reduction)
+            log_ref = get_logprobs(self.ref_model, concat_ids, concat_mask, self.reduction)
 
-        pi_log_ratio = log_pi_good - log_pi_bad
-        ref_log_ratio = log_ref_good - log_ref_bad
+        log_pi_chosen = log_pi[:chosen_ids.shape[0]]
+        log_pi_rejected = log_pi[chosen_ids.shape[0]:]
+        log_ref_chosen = log_ref[:chosen_ids.shape[0]]
+        log_ref_rejected = log_ref[chosen_ids.shape[0]:]
+
+        pi_log_ratio = log_pi_chosen - log_pi_rejected
+        ref_log_ratio = log_ref_chosen - log_ref_rejected
 
         ratio_diff = pi_log_ratio - ref_log_ratio
 
         dpo_loss = -F.logsigmoid(self.beta * ratio_diff).mean()
 
         return dpo_loss
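
The rewritten compute_loss concatenates the chosen and rejected sequences into
a single batch, so the policy and the frozen reference each run one forward
pass instead of two. The objective itself is unchanged; a self-contained
sketch of the DPO loss being computed, with per-sequence log-probabilities
assumed precomputed (illustrative names, not from the repo):

    import torch
    import torch.nn.functional as F

    def dpo_loss(log_pi_chosen, log_pi_rejected,
                 log_ref_chosen, log_ref_rejected, beta: float = 0.1):
        # L = -log sigmoid(beta * [(log_pi_c - log_pi_r) - (log_ref_c - log_ref_r)])
        pi_log_ratio = log_pi_chosen - log_pi_rejected
        ref_log_ratio = log_ref_chosen - log_ref_rejected
        return -F.logsigmoid(beta * (pi_log_ratio - ref_log_ratio)).mean()

    # When the policy prefers the chosen response more strongly than the
    # reference does, the loss drops below log(2) ≈ 0.693.
    loss = dpo_loss(torch.tensor([-5.0]), torch.tensor([-9.0]),
                    torch.tensor([-6.0]), torch.tensor([-8.0]))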
@@ -146,7 +147,6 @@ class GrpoStrategy(BaseStrategy):
         self,
         model,
         device,
-        pad_token_id: int,
         clip_eps: float,
         kl_coef: float,
         group_size: int,
@@ -159,60 +159,44 @@ class GrpoStrategy(BaseStrategy):
         ref_model.eval()
 
         self.ref_model = ref_model
-        self.pad_token_id = pad_token_id
         self.clip_eps = clip_eps
         self.kl_coef = kl_coef
         self.group_size = group_size
         self.reduction = reduction
 
-    def compute_advantages(self, rewards: Tensor, eps=1e-8) -> Tensor:
+    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
+        batch = move_to_device(batch, self.device)
+        prompts = batch["prompts"]
+        responses = batch["responses"]
+        masks = batch["masks"]
+        rewards = batch["rewards"]
+
+        batch_size, group_size, response_len = responses.shape
+        responses_flat = responses.view(-1, response_len)
+        masks_flat = masks.view(-1, response_len)
+        prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)
+
+        # Shape: (batch_size * group_size, seq_len + response_len)
+        full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
+        full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
+
+        log_probs_policy = get_logprobs(self.model, full_sequences, full_masks, self.reduction)
+        log_probs_policy = log_probs_policy.view(batch_size, group_size)
+
+        with torch.no_grad():
+            log_probs_ref = get_logprobs(self.ref_model, full_sequences, full_masks, self.reduction)
+            log_probs_ref = log_probs_ref.view(batch_size, group_size)
+
+        # Compute advantages from rewards
+        eps = torch.finfo(log_probs_policy.dtype).eps
         mean = rewards.mean(dim=-1, keepdim=True)
         std = rewards.std(dim=-1, keepdim=True)
         advantages = (rewards - mean) / (std + eps)
 
-        return advantages
-
-    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
-        batch = move_to_device(batch, self.device)
-        input_ids = batch["input_ids"]
-        responses = batch["responses"]
-        response_masks = batch["response_masks"]
-        rewards = batch["rewards"]
-
-        batch_size, group_size, response_len = responses.shape
-
-        # Shape: (batch_size * group_size, response_len)
-        responses_flat = responses.view(-1, response_len)
-        masks_flat = response_masks.view(-1, response_len)
-
-        # Shape: (batch_size * group_size, seq_len)
-        input_ids_expanded = input_ids.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)
-
-        # Shape: (batch_size * group_size, seq_len + response_len)
-        full_sequences = torch.cat([input_ids_expanded, responses_flat], dim=-1)
-        full_masks = torch.cat([torch.ones_like(input_ids_expanded), masks_flat], dim=-1)
-
-        # Get log probabilities from policy model
-        log_probs_policy = get_logprobs(self.model, full_sequences,
-                                        full_masks, self.pad_token_id, self.reduction)
-        # Reshape to (batch_size, group_size)
-        log_probs_policy = log_probs_policy.view(batch_size, group_size)
-
-        # Get log probabilities from reference model (no grad)
-        with torch.no_grad():
-            log_probs_ref = get_logprobs(self.ref_model, full_sequences,
-                                         full_masks, self.pad_token_id, self.reduction)
-            log_probs_ref = log_probs_ref.view(batch_size, group_size)
-
-        # Compute advantages from rewards
-        advantages = self.compute_advantages(rewards)
-
-        # Compute importance sampling ratio
-        # Since we're re-generating responses, we assume old policy = reference policy
-        log_ratio = log_probs_policy - log_probs_ref
-        ratio = torch.exp(log_ratio)
-
-        # Advantages shape: (batch_size, group_size)
+        # log_ratio = log_probs_policy - log_probs_old
+        # ratio = torch.exp(log_ratio)
+        # on-policy: policy_model == old_model, so the ratio is identically 1
+        ratio = torch.ones_like(advantages)
+
         surr1 = ratio * advantages
         surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
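
The GRPO rewrite folds compute_advantages into compute_loss and drops the old
ratio = exp(log_probs_policy - log_probs_ref), which conflated the reference
model with the behavior policy, in favor of a fixed ratio of 1 for the fully
on-policy case. The advantage is still the group-normalized reward: the
group_size responses sampled for one prompt are standardized against their own
mean and std, so no learned value function is needed. A minimal sketch of the
normalization and the clipped surrogate with toy values (names illustrative):

    import torch

    clip_eps = 0.2
    rewards = torch.tensor([[1.0, 0.0, 0.5, 0.5]])  # (batch_size=1, group_size=4)

    # Group-relative advantages: standardize within each prompt's group.
    eps = torch.finfo(rewards.dtype).eps
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    advantages = (rewards - mean) / (std + eps)

    # With a tracked old policy the ratio would be exp(log_pi - log_pi_old);
    # fully on-policy it is identically 1 and the clipping is a no-op.
    log_pi = torch.tensor([[-3.0, -4.0, -3.5, -3.5]])
    ratio = torch.exp(log_pi - log_pi.detach())

    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()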
@@ -240,17 +224,15 @@ class StrategyFactory:
             "dpo": lambda: DpoStrategy(
                 model,
                 device,
-                kwargs.get("pad_token_id"),
                 kwargs.get("dpo_beta"),
                 kwargs.get("reduction", "mean")
             ),
             "grpo": lambda: GrpoStrategy(
                 model,
                 device,
-                kwargs.get("pad_token_id"),
-                kwargs.get("grpo_clip_eps", 0.2),
-                kwargs.get("grpo_kl_coef", 0.04),
-                kwargs.get("grpo_group_size", 4),
+                kwargs.get("grpo_clip_eps"),
+                kwargs.get("grpo_kl_coef"),
+                kwargs.get("grpo_group_size"),
                 kwargs.get("reduction", "mean")
             )
         }
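
With the defaults removed from kwargs.get, a missing grpo_* key now reaches
GrpoStrategy as None and only fails later. If fail-fast construction is
preferred, one option (an assumption, not what this commit does) is to index
the dict directly:

    # Hypothetical fail-fast variant: raises KeyError at construction time
    # instead of passing None through to GrpoStrategy.
    "grpo": lambda: GrpoStrategy(
        model,
        device,
        kwargs["grpo_clip_eps"],
        kwargs["grpo_kl_coef"],
        kwargs["grpo_group_size"],
        kwargs.get("reduction", "mean")
    )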
@@ -112,11 +112,8 @@ def train(
 
     model = parameter.model
 
-    kwargs = {
+    strategy_kwargs = {
         "dpo_beta": dpo_beta,
-        "bos_token_id": parameter.tokenizer.bos_id,
-        "eos_token_id": parameter.tokenizer.eos_id,
-        "pad_token_id": parameter.tokenizer.pad_id,
         "label_smoothing": label_smoothing
     }
@@ -158,7 +155,7 @@ def train(
         parallel_wrapper=ddp_wrap,
         state_dict_fn=prepare_checkpoint,
         device_type=device_type,
-        extra_kwargs=kwargs,
+        extra_kwargs=strategy_kwargs,
     )
 
     trainer = Trainer(train_config)