From abc3a06266b46fd06a83b1a9159666fa56470da7 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 18 Mar 2026 16:16:02 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E5=A2=9E=E5=8A=A0ppl=E8=AE=A1?= =?UTF-8?q?=E7=AE=97=E5=B7=A5=E5=85=B7=E5=B9=B6=E4=BC=98=E5=8C=96=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/generate.py | 45 ++++++++----------- tools/perplexity.py | 102 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 27 deletions(-) create mode 100644 tools/perplexity.py diff --git a/tools/generate.py b/tools/generate.py index cc00851..cedefdc 100644 --- a/tools/generate.py +++ b/tools/generate.py @@ -1,4 +1,3 @@ -import os import torch import json import torch @@ -9,8 +8,6 @@ from typing import List from tqdm import tqdm -PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__)) - def batch_generate( model: Khaosz, query: List[str], @@ -38,9 +35,6 @@ def batch_generate( top_p=top_p ) - for batch_query, batch_response in zip(batch_query, batch_responses): - print(f"Q: {batch_query[:50]} \nR: {batch_response[:50]})") - for query, response in zip(batch_query, batch_responses): original_idx = original_indices[query] responses[original_idx] = response @@ -49,20 +43,23 @@ def batch_generate( def processor( - model: Khaosz, + model_dir: str, input_json_file: str, output_json_file: str, batch_size: int, temperature: float, - top_p: float, top_k: int, - question_key: str="question", + top_p: float, + question_key: str, + response_key: str, ): - with open(input_json_file, "r", encoding='utf-8') as f: - input_dict = [json.loads(line) for line in f] - query = [item[question_key] for item in input_dict] + model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16) - output_dict = batch_generate( + with open(input_json_file, "r", encoding='utf-8') as f: + input_data = [json.loads(line) for line in f] + query = [item[question_key] for item in input_data] + + responses = batch_generate( model=model, query=query, temperature=temperature, @@ -71,8 +68,12 @@ def processor( batch_size=batch_size ) + # Write output in JSONL format with open(output_json_file, "w", encoding='utf-8') as f: - json.dump(output_dict, f, indent=4, ensure_ascii=False) + for query, response in zip(query, responses): + output_item = {question_key: query, response_key: response} + f.write(json.dumps(output_item, ensure_ascii=False) + '\n') + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.") @@ -81,21 +82,11 @@ if __name__ == "__main__": parser.add_argument("--input_json_file", type=str, required=True, help="Path to the input JSONL file.") parser.add_argument("--output_json_file", type=str, required=True, help="Path to the output JSONL file.") parser.add_argument("--question_key", type=str, default="question", help="Key for the question in the input JSON.") + parser.add_argument("--response_key", type=str, default="response", help="Key for the response in the output JSON.") parser.add_argument("--temperature", type=float, default=0.60, help="Temperature for generating responses.") - parser.add_argument("--top_p", type=float, default=0.95, help="Top-p value for generating responses.") parser.add_argument("--top_k", type=int, default=30, help="Top-k value for generating responses.") + parser.add_argument("--top_p", type=float, default=0.95, help="Top-p value for generating responses.") parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.") args = parser.parse_args() - model = Khaosz(args.model_dir).to(device='cuda', dtype=torch.bfloat16) - - processor( - model, - input_json_file=args.input_json_file, - output_json_file=args.output_json_file, - question_key=args.question_key, - batch_size=args.batch_size, - temperature=args.temperature, - top_k=args.top_k, - top_p=args.top_p - ) \ No newline at end of file + processor(**vars(args)) \ No newline at end of file diff --git a/tools/perplexity.py b/tools/perplexity.py new file mode 100644 index 0000000..fe3bf56 --- /dev/null +++ b/tools/perplexity.py @@ -0,0 +1,102 @@ +import json +import torch +import torch.nn as nn +import torch.nn.functional as F +import argparse +import tqdm + +from torch import Tensor +from khaosz import Khaosz + +def compute_perplexity( + model: nn.Module, + input_ids: Tensor, + input_mask: Tensor, + ) -> Tensor: + """ + Compute the perplexity of a batch of input sequences, + where PPL = exp(-(1/N) * sum(log P(w_i | w_