feat(strategy): 支持模型输入可调用对象并优化损失计算

This commit is contained in:
ViperEkura 2025-10-06 17:08:56 +08:00
parent 8c9e973179
commit 4ffa7454f2
1 changed files with 7 additions and 13 deletions

View File

@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Any, Literal, Tuple, Callable, Dict
from typing import Any, Literal, Tuple, Callable, Dict, Union
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
@ -35,7 +35,7 @@ def move_to_device(batch:Dict[str, Tensor], device: str) -> Any:
class BaseStrategy(ABC):
def __init__(self, model: nn.Module, device: str):
def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str):
self.model = model
self.device = device
@ -54,13 +54,13 @@ class SeqStrategy(BaseStrategy):
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
B, L = input_ids.size()
logits: Tensor = self.model(input_ids=input_ids)["logits"]
logits = self.model(input_ids=input_ids)["logits"]
loss = F.cross_entropy(
input=logits.view(B * L, -1),
input=logits.flatten(0, 1),
target=target_ids.flatten()
)
return loss
@ -74,17 +74,11 @@ class SftStrategy(BaseStrategy):
loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"]
ignore_index = -100
B, L = input_ids.size()
logits: Tensor = self.model(
input_ids=input_ids,
input_mask=attn_mask
)["logits"]
logits = self.model(input_ids=input_ids, input_mask=attn_mask)["logits"]
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
loss = F.cross_entropy(
input=logits.view(B * L, -1),
input=logits.flatten(0, 1),
target=target_ids.flatten(),
ignore_index=ignore_index
)