From fc98d9b7e656acec80594318befab04f2d24d517 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 4 Oct 2025 21:45:53 +0800 Subject: [PATCH] =?UTF-8?q?refactor(khaosz/trainer):=20=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E6=9C=AA=E4=BD=BF=E7=94=A8=E7=9A=84=E5=AF=BC=E5=85=A5=E6=A8=A1?= =?UTF-8?q?=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/strategy.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index 9dc4dd4..37e6510 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -5,11 +5,9 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from torch.optim import Optimizer -from torch.utils.data import Dataset -from typing import Any, Literal, Optional, Tuple, Callable, Dict +from typing import Any, Literal, Tuple, Callable, Dict from abc import ABC, abstractmethod -from dataclasses import asdict, dataclass, field +from dataclasses import dataclass, field def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int):