diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index 37e6510..81cb15d 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from typing import Any, Literal, Tuple, Callable, Dict +from typing import Any, Literal, Tuple, Callable, Dict, Union from abc import ABC, abstractmethod from dataclasses import dataclass, field @@ -35,7 +35,7 @@ def move_to_device(batch:Dict[str, Tensor], device: str) -> Any: class BaseStrategy(ABC): - def __init__(self, model: nn.Module, device: str): + def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str): self.model = model self.device = device @@ -54,13 +54,13 @@ class SeqStrategy(BaseStrategy): def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: batch = move_to_device(batch, self.device) input_ids, target_ids = batch["input_ids"], batch["target_ids"] - B, L = input_ids.size() - logits: Tensor = self.model(input_ids=input_ids)["logits"] + logits = self.model(input_ids=input_ids)["logits"] loss = F.cross_entropy( - input=logits.view(B * L, -1), + input=logits.flatten(0, 1), target=target_ids.flatten() ) + return loss @@ -74,17 +74,11 @@ class SftStrategy(BaseStrategy): loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"] ignore_index = -100 - B, L = input_ids.size() - - logits: Tensor = self.model( - input_ids=input_ids, - input_mask=attn_mask - )["logits"] - + logits = self.model(input_ids=input_ids, input_mask=attn_mask)["logits"] target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index) loss = F.cross_entropy( - input=logits.view(B * L, -1), + input=logits.flatten(0, 1), target=target_ids.flatten(), ignore_index=ignore_index )