diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index 21e57ac..e0adbb7 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -8,10 +8,15 @@ from typing import Any, Callable, Dict, Union from abc import ABC, abstractmethod -def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int): +def get_logprobs( + model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], + input_ids: Tensor, + mask: Tensor, + pad_token_id: int +): input_mask = input_ids.ne(pad_token_id) logits = model(input_ids, input_mask)["logits"] - log_probs = torch.log_softmax(logits, dim=-1) + log_probs = torch.log_softmax(logits.float(), dim=-1) shifted_log_probs = log_probs[:, :-1, :] shifted_input_ids = input_ids[:, 1:] @@ -55,7 +60,7 @@ class SeqStrategy(BaseStrategy): logits = self.model(input_ids=input_ids)["logits"] loss = F.cross_entropy( - input=logits.flatten(0, 1), + input=logits.flatten(0, 1).float(), target=target_ids.flatten() ) @@ -75,7 +80,7 @@ class SftStrategy(BaseStrategy): target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index) loss = F.cross_entropy( - input=logits.flatten(0, 1), + input=logits.flatten(0, 1).float(), target=target_ids.flatten(), ignore_index=ignore_index )