feat(trainer): 改进模型输入和损失计算中的数据类型精度

This commit is contained in:
ViperEkura 2025-12-08 14:10:08 +08:00
parent 0093ba7bb8
commit c86e573195
1 changed files with 9 additions and 4 deletions

View File

@ -8,10 +8,15 @@ from typing import Any, Callable, Dict, Union
from abc import ABC, abstractmethod
def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int):
def get_logprobs(
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
input_ids: Tensor,
mask: Tensor,
pad_token_id: int
):
input_mask = input_ids.ne(pad_token_id)
logits = model(input_ids, input_mask)["logits"]
log_probs = torch.log_softmax(logits, dim=-1)
log_probs = torch.log_softmax(logits.float(), dim=-1)
shifted_log_probs = log_probs[:, :-1, :]
shifted_input_ids = input_ids[:, 1:]
@ -55,7 +60,7 @@ class SeqStrategy(BaseStrategy):
logits = self.model(input_ids=input_ids)["logits"]
loss = F.cross_entropy(
input=logits.flatten(0, 1),
input=logits.flatten(0, 1).float(),
target=target_ids.flatten()
)
@ -75,7 +80,7 @@ class SftStrategy(BaseStrategy):
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
loss = F.cross_entropy(
input=logits.flatten(0, 1),
input=logits.flatten(0, 1).float(),
target=target_ids.flatten(),
ignore_index=ignore_index
)