From 0f518473afe30f1994337afc9643a34474dfd680 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Thu, 19 Mar 2026 22:23:51 +0800
Subject: [PATCH] fix: correct issues in the reinforcement learning algorithms
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

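get_logprobs now truncates the model inputs to [:, :-1] and gathers the
log-probabilities of the shifted targets directly, so the separate
pad-mask bookkeeping (and the pad_token_id parameter) is no longer
needed. The "mean" reduction divides by the number of unmasked tokens
rather than the full sequence length. DpoStrategy scores the chosen and
rejected sequences in one concatenated forward pass and then computes

    dpo_loss = -log sigmoid(beta * ((log_pi_chosen - log_pi_rejected)
                                  - (log_ref_chosen - log_ref_rejected)))

GrpoStrategy no longer recomputes an importance ratio from the reference
model: with a single policy update per rollout the old policy equals the
current policy, so the ratio is identically 1 and only the
group-normalized advantages drive the surrogate loss. tools/train.py
passes the slimmed-down strategy_kwargs through extra_kwargs.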
---
 khaosz/trainer/strategy.py | 138 ++++++++++++++++---------------------
 tools/train.py             |   7 +-
 2 files changed, 62 insertions(+), 83 deletions(-)

diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py
index 2db5556..8df3fd8 100644
--- a/khaosz/trainer/strategy.py
+++ b/khaosz/trainer/strategy.py
@@ -15,35 +15,32 @@ def get_logprobs(
     model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
     input_ids: Tensor,
     mask: Tensor,
-    pad_token_id: int,
     reduction: str,
 ):
+    # reduction on seq_len dim
     allowed_reductions = ["mean", "sum", "none"]
     if reduction not in allowed_reductions:
         raise ValueError(f"reduction must be one of {allowed_reductions}, got '{reduction}'")
-
-    pad_mask = input_ids.ne(pad_token_id)
-    logits = model(input_ids, pad_mask)["logits"]
-    log_probs = torch.log_softmax(logits.float(), dim=-1)
-
-    shifted_log_probs = log_probs[:, :-1, :]
+
     shifted_input_ids = input_ids[:, 1:]
     shifted_mask = mask[:, 1:]
-    prompt_mask = pad_mask[:, 1:]
-
+
+    logits = model(input_ids[:, :-1], mask[:, :-1])["logits"]
+    log_probs = torch.log_softmax(logits.float(), dim=-1)
+
+    # [batch_size, seq_len - 1]
     token_logprobs = torch.gather(
-        shifted_log_probs,
+        log_probs,
         dim=-1,
         index=shifted_input_ids.unsqueeze(-1)
     ).squeeze(-1)
-    valid_mask = (prompt_mask & shifted_mask)
 
     if reduction == "mean":
-        return (token_logprobs * valid_mask).mean(dim=-1)
+        return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(dim=-1).clamp(min=1.0)
     elif reduction == "sum":
-        return (token_logprobs * valid_mask).sum(dim=-1)
+        return (token_logprobs * shifted_mask).sum(dim=-1)
     else:
-        return token_logprobs
+        return token_logprobs * shifted_mask
 
 
 class BaseStrategy(ABC):
@@ -104,7 +101,6 @@ class DpoStrategy(BaseStrategy):
         self,
         model,
         device,
-        pad_token_id: int,
         beta: float,
         reduction: str,
 
@@ -113,30 +109,35 @@
         ref_model = copy.deepcopy(self.model)
         ref_model.requires_grad_(False)
         ref_model.eval()
-
+
         self.ref_model = ref_model
-        self.pad_token_id = pad_token_id
         self.beta = beta
         self.reduction = reduction
 
     def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
         batch = move_to_device(batch, self.device)
-        good_ids, bad_ids = batch["chosen"], batch["rejected"]
-        good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"]
-
-        log_pi_good = get_logprobs(self.model, good_ids, good_mask, self.pad_token_id, self.reduction)
-        log_pi_bad = get_logprobs(self.model, bad_ids, bad_mask, self.pad_token_id, self.reduction)
+        chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
+        chosen_mask, rejected_mask = batch["chosen_mask"], batch["rejected_mask"]
+
+        concat_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
+        concat_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
+        log_pi = get_logprobs(self.model, concat_ids, concat_mask, self.reduction)
+
         with torch.no_grad():
-            log_ref_good = get_logprobs(self.ref_model, good_ids, good_mask, self.pad_token_id, self.reduction)
-            log_ref_bad = get_logprobs(self.ref_model, bad_ids, bad_mask, self.pad_token_id, self.reduction)
+            log_ref = get_logprobs(self.ref_model, concat_ids, concat_mask, self.reduction)
 
-        pi_log_ratio = log_pi_good - log_pi_bad
-        ref_log_ratio = log_ref_good - log_ref_bad
+        log_pi_chosen = log_pi[:chosen_ids.shape[0]]
+        log_pi_rejected = log_pi[chosen_ids.shape[0]:]
+        log_ref_chosen = log_ref[:chosen_ids.shape[0]]
+        log_ref_rejected = log_ref[chosen_ids.shape[0]:]
 
+        pi_log_ratio = log_pi_chosen - log_pi_rejected
+        ref_log_ratio = log_ref_chosen - log_ref_rejected
+
         ratio_diff = pi_log_ratio - ref_log_ratio
         dpo_loss = -F.logsigmoid(self.beta * ratio_diff).mean()
+
         return dpo_loss
 
@@ -146,7 +147,6 @@ class GrpoStrategy(BaseStrategy):
         self,
         model,
         device,
-        pad_token_id: int,
         clip_eps: float,
         kl_coef: float,
         group_size: int,
@@ -159,60 +159,44 @@
         ref_model.eval()
 
         self.ref_model = ref_model
-        self.pad_token_id = pad_token_id
         self.clip_eps = clip_eps
         self.kl_coef = kl_coef
         self.group_size = group_size
         self.reduction = reduction
 
-    def compute_advantages(self, rewards: Tensor, eps=1e-8) -> Tensor:
+    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
+        batch = move_to_device(batch, self.device)
+        prompts = batch["prompts"]
+        responses = batch["responses"]
+        masks = batch["masks"]
+        rewards = batch["rewards"]
+
+        batch_size, group_size, response_len = responses.shape
+        responses_flat = responses.view(-1, response_len)
+        masks_flat = masks.view(-1, response_len)
+        prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)
+
+        # Shape: (batch_size * group_size, seq_len + response_len)
+        full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
+        full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
+
+        log_probs_policy = get_logprobs(self.model, full_sequences, full_masks, self.reduction)
+        log_probs_policy = log_probs_policy.view(batch_size, group_size)
+
+        with torch.no_grad():
+            log_probs_ref = get_logprobs(self.ref_model, full_sequences, full_masks, self.reduction)
+            log_probs_ref = log_probs_ref.view(batch_size, group_size)
+
+        # Compute advantages from rewards
+        eps = torch.finfo(log_probs_policy.dtype).eps
         mean = rewards.mean(dim=-1, keepdim=True)
         std = rewards.std(dim=-1, keepdim=True)
         advantages = (rewards - mean) / (std + eps)
-        return advantages
-
-    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
-        batch = move_to_device(batch, self.device)
-        input_ids = batch["input_ids"]
-        responses = batch["responses"]
-        response_masks = batch["response_masks"]
-        rewards = batch["rewards"]
-
-        batch_size, group_size, response_len = responses.shape
-
-        # Shape: (batch_size * group_size, response_len)
-        responses_flat = responses.view(-1, response_len)
-        masks_flat = response_masks.view(-1, response_len)
-
-        # Shape: (batch_size * group_size, seq_len)
-        input_ids_expanded = input_ids.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)
-
-        # Shape: (batch_size * group_size, seq_len + response_len)
-        full_sequences = torch.cat([input_ids_expanded, responses_flat], dim=-1)
-        full_masks = torch.cat([torch.ones_like(input_ids_expanded), masks_flat], dim=-1)
-
-        # Get log probabilities from policy model
-        log_probs_policy = get_logprobs(self.model, full_sequences,
-                                        full_masks, self.pad_token_id, self.reduction)
-        # Reshape to (batch_size, group_size)
-        log_probs_policy = log_probs_policy.view(batch_size, group_size)
-
-        # Get log probabilities from reference model (no grad)
-        with torch.no_grad():
-            log_probs_ref = get_logprobs(self.ref_model, full_sequences,
-                                         full_masks, self.pad_token_id, self.reduction)
-            log_probs_ref = log_probs_ref.view(batch_size, group_size)
-
-        # Compute advantages from rewards
-        advantages = self.compute_advantages(rewards)
-
-        # Compute importance sampling ratio
-        # Since we're re-generating responses, we assume old policy = reference policy
-        log_ratio = log_probs_policy - log_probs_ref
-        ratio = torch.exp(log_ratio)
-
-        # Advantages shape: (batch_size, group_size)
+        # log_ratio = log_probs_policy - log_probs_old
+        # ratio = torch.exp(log_ratio)
+        # on-policy: the old policy equals the current policy, so log_ratio = 0 and ratio = 1
+        ratio = torch.ones_like(advantages)
 
         surr1 = ratio * advantages
         surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
@@ -240,17 +224,15 @@ class StrategyFactory:
             "dpo": lambda: DpoStrategy(
                 model,
                 device,
-                kwargs.get("pad_token_id"),
                 kwargs.get("dpo_beta"),
                 kwargs.get("reduction", "mean")
             ),
             "grpo": lambda: GrpoStrategy(
                 model,
                 device,
-                kwargs.get("pad_token_id"),
-                kwargs.get("grpo_clip_eps", 0.2),
-                kwargs.get("grpo_kl_coef", 0.04),
-                kwargs.get("grpo_group_size", 4),
+                kwargs.get("grpo_clip_eps", 0.2),
+                kwargs.get("grpo_kl_coef", 0.04),
+                kwargs.get("grpo_group_size", 4),
                 kwargs.get("reduction", "mean")
             )
         }
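(Illustration only, not part of the patch.) A self-contained sketch of the
group-relative advantage and clipped surrogate used in the GRPO hunk above,
with made-up reward values; the final combination is outside the hunk, so the
-torch.min(...) line is an assumption based on the usual PPO-style objective:

    import torch

    rewards = torch.tensor([[1.0, 0.0, 0.5, 0.25]])  # (batch_size=1, group_size=4), hypothetical
    eps = torch.finfo(rewards.dtype).eps

    # group-relative advantages: normalize each group's rewards to zero mean, unit std
    advantages = (rewards - rewards.mean(dim=-1, keepdim=True)) / (rewards.std(dim=-1, keepdim=True) + eps)

    clip_eps = 0.2
    ratio = torch.ones_like(advantages)  # on-policy: old policy == current policy
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()  # assumed PPO-style combination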
diff --git a/tools/train.py b/tools/train.py
index ec22690..bf3287f 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -112,11 +112,8 @@ def train(
 
     model = parameter.model
 
-    kwargs = {
+    strategy_kwargs = {
         "dpo_beta": dpo_beta,
-        "bos_token_id": parameter.tokenizer.bos_id,
-        "eos_token_id": parameter.tokenizer.eos_id,
-        "pad_token_id": parameter.tokenizer.pad_id,
         "label_smoothing": label_smoothing
     }
 
@@ -158,7 +155,7 @@ def train(
         parallel_wrapper=ddp_wrap,
         state_dict_fn=prepare_checkpoint,
         device_type=device_type,
-        extra_kwargs=kwargs,
+        extra_kwargs=strategy_kwargs,
     )
 
     trainer = Trainer(train_config)
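(Illustration only, not part of the patch.) A standalone sketch of the
get_logprobs contract after this change, using a toy random-logits model in
place of the khaosz model; toy_model, vocab_size, and the tensor shapes are
invented for the example:

    import torch

    def toy_model(input_ids, mask):
        # stand-in for the real forward pass; returns random logits
        vocab_size = 16
        return {"logits": torch.randn(*input_ids.shape, vocab_size)}

    input_ids = torch.randint(0, 16, (2, 8))  # (batch_size, seq_len)
    mask = torch.ones(2, 8)                   # 1 = token counts toward the loss

    # positions 0..seq_len-2 predict tokens 1..seq_len-1
    logits = toy_model(input_ids[:, :-1], mask[:, :-1])["logits"]
    log_probs = torch.log_softmax(logits.float(), dim=-1)
    token_logprobs = torch.gather(
        log_probs, dim=-1, index=input_ids[:, 1:].unsqueeze(-1)
    ).squeeze(-1)                             # (batch_size, seq_len - 1)

    shifted_mask = mask[:, 1:]
    mean_logprob = (token_logprobs * shifted_mask).sum(-1) / shifted_mask.sum(-1).clamp(min=1.0)
    print(mean_logprob.shape)                 # torch.Size([2])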