feat(strategy): 支持模型输入可调用对象并优化损失计算
This commit is contained in:
parent
8c9e973179
commit
4ffa7454f2
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from torch import Tensor
|
||||
from typing import Any, Literal, Tuple, Callable, Dict
|
||||
from typing import Any, Literal, Tuple, Callable, Dict, Union
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
|
@ -35,7 +35,7 @@ def move_to_device(batch:Dict[str, Tensor], device: str) -> Any:
|
|||
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
def __init__(self, model: nn.Module, device: str):
|
||||
def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str):
|
||||
self.model = model
|
||||
self.device = device
|
||||
|
||||
|
|
@ -54,13 +54,13 @@ class SeqStrategy(BaseStrategy):
|
|||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
batch = move_to_device(batch, self.device)
|
||||
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
|
||||
B, L = input_ids.size()
|
||||
logits: Tensor = self.model(input_ids=input_ids)["logits"]
|
||||
logits = self.model(input_ids=input_ids)["logits"]
|
||||
|
||||
loss = F.cross_entropy(
|
||||
input=logits.view(B * L, -1),
|
||||
input=logits.flatten(0, 1),
|
||||
target=target_ids.flatten()
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
|
|
@ -74,17 +74,11 @@ class SftStrategy(BaseStrategy):
|
|||
loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"]
|
||||
|
||||
ignore_index = -100
|
||||
B, L = input_ids.size()
|
||||
|
||||
logits: Tensor = self.model(
|
||||
input_ids=input_ids,
|
||||
input_mask=attn_mask
|
||||
)["logits"]
|
||||
|
||||
logits = self.model(input_ids=input_ids, input_mask=attn_mask)["logits"]
|
||||
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||
|
||||
loss = F.cross_entropy(
|
||||
input=logits.view(B * L, -1),
|
||||
input=logits.flatten(0, 1),
|
||||
target=target_ids.flatten(),
|
||||
ignore_index=ignore_index
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue