diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index 9dc4dd4..37e6510 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -5,11 +5,9 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from torch.optim import Optimizer -from torch.utils.data import Dataset -from typing import Any, Literal, Optional, Tuple, Callable, Dict +from typing import Any, Literal, Tuple, Callable, Dict from abc import ABC, abstractmethod -from dataclasses import asdict, dataclass, field +from dataclasses import dataclass, field def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int):