feat(strategy): 支持模型输入可调用对象并优化损失计算

This commit is contained in:
ViperEkura 2025-10-06 17:08:56 +08:00
parent 8c9e973179
commit 4ffa7454f2
1 changed file with 7 additions and 13 deletions

View File

@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from typing import Any, Literal, Tuple, Callable, Dict from typing import Any, Literal, Tuple, Callable, Dict, Union
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -35,7 +35,7 @@ def move_to_device(batch:Dict[str, Tensor], device: str) -> Any:
class BaseStrategy(ABC): class BaseStrategy(ABC):
def __init__(self, model: nn.Module, device: str): def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str):
self.model = model self.model = model
self.device = device self.device = device
@ -54,13 +54,13 @@ class SeqStrategy(BaseStrategy):
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device) batch = move_to_device(batch, self.device)
input_ids, target_ids = batch["input_ids"], batch["target_ids"] input_ids, target_ids = batch["input_ids"], batch["target_ids"]
B, L = input_ids.size() logits = self.model(input_ids=input_ids)["logits"]
logits: Tensor = self.model(input_ids=input_ids)["logits"]
loss = F.cross_entropy( loss = F.cross_entropy(
input=logits.view(B * L, -1), input=logits.flatten(0, 1),
target=target_ids.flatten() target=target_ids.flatten()
) )
return loss return loss
@ -74,17 +74,11 @@ class SftStrategy(BaseStrategy):
loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"] loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"]
ignore_index = -100 ignore_index = -100
B, L = input_ids.size() logits = self.model(input_ids=input_ids, input_mask=attn_mask)["logits"]
logits: Tensor = self.model(
input_ids=input_ids,
input_mask=attn_mask
)["logits"]
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index) target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
loss = F.cross_entropy( loss = F.cross_entropy(
input=logits.view(B * L, -1), input=logits.flatten(0, 1),
target=target_ids.flatten(), target=target_ids.flatten(),
ignore_index=ignore_index ignore_index=ignore_index
) )