import argparse import json import torch import torch.nn as nn import torch.nn.functional as F import tqdm from torch import Tensor from astrai.config.param_config import ModelParameter def compute_perplexity( model: nn.Module, input_ids: Tensor, input_mask: Tensor, ) -> Tensor: """ Compute the perplexity of a batch of input sequences, where PPL = exp(-(1/N) * sum(log P(w_i | w_