From 4ffa7454f2e1279b00e4457525efc6c299fa6e81 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 6 Oct 2025 17:08:56 +0800 Subject: [PATCH] =?UTF-8?q?feat(strategy):=20=E6=94=AF=E6=8C=81=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E8=BE=93=E5=85=A5=E5=8F=AF=E8=B0=83=E7=94=A8=E5=AF=B9?= =?UTF-8?q?=E8=B1=A1=E5=B9=B6=E4=BC=98=E5=8C=96=E6=8D=9F=E5=A4=B1=E8=AE=A1?= =?UTF-8?q?=E7=AE=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/strategy.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index 37e6510..81cb15d 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from typing import Any, Literal, Tuple, Callable, Dict +from typing import Any, Literal, Tuple, Callable, Dict, Union from abc import ABC, abstractmethod from dataclasses import dataclass, field @@ -35,7 +35,7 @@ def move_to_device(batch:Dict[str, Tensor], device: str) -> Any: class BaseStrategy(ABC): - def __init__(self, model: nn.Module, device: str): + def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str): self.model = model self.device = device @@ -54,13 +54,13 @@ class SeqStrategy(BaseStrategy): def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: batch = move_to_device(batch, self.device) input_ids, target_ids = batch["input_ids"], batch["target_ids"] - B, L = input_ids.size() - logits: Tensor = self.model(input_ids=input_ids)["logits"] + logits = self.model(input_ids=input_ids)["logits"] loss = F.cross_entropy( - input=logits.view(B * L, -1), + input=logits.flatten(0, 1), target=target_ids.flatten() ) + return loss @@ -74,17 +74,11 @@ class SftStrategy(BaseStrategy): loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"] ignore_index = -100 - B, L = input_ids.size() - - logits: Tensor = self.model( - input_ids=input_ids, - input_mask=attn_mask - )["logits"] - + logits = self.model(input_ids=input_ids, input_mask=attn_mask)["logits"] target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index) loss = F.cross_entropy( - input=logits.view(B * L, -1), + input=logits.flatten(0, 1), target=target_ids.flatten(), ignore_index=ignore_index )