feat(trainer): 改进模型输入和损失计算中的数据类型精度
This commit is contained in:
parent
0093ba7bb8
commit
c86e573195
|
|
@ -8,10 +8,15 @@ from typing import Any, Callable, Dict, Union
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int):
|
||||
def get_logprobs(
|
||||
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
|
||||
input_ids: Tensor,
|
||||
mask: Tensor,
|
||||
pad_token_id: int
|
||||
):
|
||||
input_mask = input_ids.ne(pad_token_id)
|
||||
logits = model(input_ids, input_mask)["logits"]
|
||||
log_probs = torch.log_softmax(logits, dim=-1)
|
||||
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
||||
|
||||
shifted_log_probs = log_probs[:, :-1, :]
|
||||
shifted_input_ids = input_ids[:, 1:]
|
||||
|
|
@ -55,7 +60,7 @@ class SeqStrategy(BaseStrategy):
|
|||
logits = self.model(input_ids=input_ids)["logits"]
|
||||
|
||||
loss = F.cross_entropy(
|
||||
input=logits.flatten(0, 1),
|
||||
input=logits.flatten(0, 1).float(),
|
||||
target=target_ids.flatten()
|
||||
)
|
||||
|
||||
|
|
@ -75,7 +80,7 @@ class SftStrategy(BaseStrategy):
|
|||
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||
|
||||
loss = F.cross_entropy(
|
||||
input=logits.flatten(0, 1),
|
||||
input=logits.flatten(0, 1).float(),
|
||||
target=target_ids.flatten(),
|
||||
ignore_index=ignore_index
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue