From c86e573195cb954cfbffdebfdd0131ab2a506f7b Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 8 Dec 2025 14:10:08 +0800 Subject: [PATCH] =?UTF-8?q?feat(trainer):=20=E6=94=B9=E8=BF=9B=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E8=BE=93=E5=85=A5=E5=92=8C=E6=8D=9F=E5=A4=B1=E8=AE=A1?= =?UTF-8?q?=E7=AE=97=E4=B8=AD=E7=9A=84=E6=95=B0=E6=8D=AE=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E7=B2=BE=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/strategy.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index 21e57ac..e0adbb7 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -8,10 +8,15 @@ from typing import Any, Callable, Dict, Union from abc import ABC, abstractmethod -def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int): +def get_logprobs( + model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], + input_ids: Tensor, + mask: Tensor, + pad_token_id: int +): input_mask = input_ids.ne(pad_token_id) logits = model(input_ids, input_mask)["logits"] - log_probs = torch.log_softmax(logits, dim=-1) + log_probs = torch.log_softmax(logits.float(), dim=-1) shifted_log_probs = log_probs[:, :-1, :] shifted_input_ids = input_ids[:, 1:] @@ -55,7 +60,7 @@ class SeqStrategy(BaseStrategy): logits = self.model(input_ids=input_ids)["logits"] loss = F.cross_entropy( - input=logits.flatten(0, 1), + input=logits.flatten(0, 1).float(), target=target_ids.flatten() ) @@ -75,7 +80,7 @@ class SftStrategy(BaseStrategy): target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index) loss = F.cross_entropy( - input=logits.flatten(0, 1), + input=logits.flatten(0, 1).float(), target=target_ids.flatten(), ignore_index=ignore_index )