feat(trainer): 改进模型输入和损失计算中的数据类型精度
This commit is contained in:
parent
0093ba7bb8
commit
c86e573195
|
|
@ -8,10 +8,15 @@ from typing import Any, Callable, Dict, Union
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int):
|
def get_logprobs(
|
||||||
|
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
|
||||||
|
input_ids: Tensor,
|
||||||
|
mask: Tensor,
|
||||||
|
pad_token_id: int
|
||||||
|
):
|
||||||
input_mask = input_ids.ne(pad_token_id)
|
input_mask = input_ids.ne(pad_token_id)
|
||||||
logits = model(input_ids, input_mask)["logits"]
|
logits = model(input_ids, input_mask)["logits"]
|
||||||
log_probs = torch.log_softmax(logits, dim=-1)
|
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
||||||
|
|
||||||
shifted_log_probs = log_probs[:, :-1, :]
|
shifted_log_probs = log_probs[:, :-1, :]
|
||||||
shifted_input_ids = input_ids[:, 1:]
|
shifted_input_ids = input_ids[:, 1:]
|
||||||
|
|
@ -55,7 +60,7 @@ class SeqStrategy(BaseStrategy):
|
||||||
logits = self.model(input_ids=input_ids)["logits"]
|
logits = self.model(input_ids=input_ids)["logits"]
|
||||||
|
|
||||||
loss = F.cross_entropy(
|
loss = F.cross_entropy(
|
||||||
input=logits.flatten(0, 1),
|
input=logits.flatten(0, 1).float(),
|
||||||
target=target_ids.flatten()
|
target=target_ids.flatten()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -75,7 +80,7 @@ class SftStrategy(BaseStrategy):
|
||||||
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||||
|
|
||||||
loss = F.cross_entropy(
|
loss = F.cross_entropy(
|
||||||
input=logits.flatten(0, 1),
|
input=logits.flatten(0, 1).float(),
|
||||||
target=target_ids.flatten(),
|
target=target_ids.flatten(),
|
||||||
ignore_index=ignore_index
|
ignore_index=ignore_index
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue