chore: 增加ppl计算工具并优化代码格式
This commit is contained in:
parent
62fba9a298
commit
abc3a06266
|
|
@ -1,4 +1,3 @@
|
||||||
import os
|
|
||||||
import torch
|
import torch
|
||||||
import json
|
import json
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -9,8 +8,6 @@ from typing import List
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
|
|
||||||
def batch_generate(
|
def batch_generate(
|
||||||
model: Khaosz,
|
model: Khaosz,
|
||||||
query: List[str],
|
query: List[str],
|
||||||
|
|
@ -38,9 +35,6 @@ def batch_generate(
|
||||||
top_p=top_p
|
top_p=top_p
|
||||||
)
|
)
|
||||||
|
|
||||||
for batch_query, batch_response in zip(batch_query, batch_responses):
|
|
||||||
print(f"Q: {batch_query[:50]} \nR: {batch_response[:50]})")
|
|
||||||
|
|
||||||
for query, response in zip(batch_query, batch_responses):
|
for query, response in zip(batch_query, batch_responses):
|
||||||
original_idx = original_indices[query]
|
original_idx = original_indices[query]
|
||||||
responses[original_idx] = response
|
responses[original_idx] = response
|
||||||
|
|
@ -49,20 +43,23 @@ def batch_generate(
|
||||||
|
|
||||||
|
|
||||||
def processor(
|
def processor(
|
||||||
model: Khaosz,
|
model_dir: str,
|
||||||
input_json_file: str,
|
input_json_file: str,
|
||||||
output_json_file: str,
|
output_json_file: str,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
temperature: float,
|
temperature: float,
|
||||||
top_p: float,
|
|
||||||
top_k: int,
|
top_k: int,
|
||||||
question_key: str="question",
|
top_p: float,
|
||||||
|
question_key: str,
|
||||||
|
response_key: str,
|
||||||
):
|
):
|
||||||
with open(input_json_file, "r", encoding='utf-8') as f:
|
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||||
input_dict = [json.loads(line) for line in f]
|
|
||||||
query = [item[question_key] for item in input_dict]
|
|
||||||
|
|
||||||
output_dict = batch_generate(
|
with open(input_json_file, "r", encoding='utf-8') as f:
|
||||||
|
input_data = [json.loads(line) for line in f]
|
||||||
|
query = [item[question_key] for item in input_data]
|
||||||
|
|
||||||
|
responses = batch_generate(
|
||||||
model=model,
|
model=model,
|
||||||
query=query,
|
query=query,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
|
|
@ -71,8 +68,12 @@ def processor(
|
||||||
batch_size=batch_size
|
batch_size=batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Write output in JSONL format
|
||||||
with open(output_json_file, "w", encoding='utf-8') as f:
|
with open(output_json_file, "w", encoding='utf-8') as f:
|
||||||
json.dump(output_dict, f, indent=4, ensure_ascii=False)
|
for query, response in zip(query, responses):
|
||||||
|
output_item = {question_key: query, response_key: response}
|
||||||
|
f.write(json.dumps(output_item, ensure_ascii=False) + '\n')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
|
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
|
||||||
|
|
@ -81,21 +82,11 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--input_json_file", type=str, required=True, help="Path to the input JSONL file.")
|
parser.add_argument("--input_json_file", type=str, required=True, help="Path to the input JSONL file.")
|
||||||
parser.add_argument("--output_json_file", type=str, required=True, help="Path to the output JSONL file.")
|
parser.add_argument("--output_json_file", type=str, required=True, help="Path to the output JSONL file.")
|
||||||
parser.add_argument("--question_key", type=str, default="question", help="Key for the question in the input JSON.")
|
parser.add_argument("--question_key", type=str, default="question", help="Key for the question in the input JSON.")
|
||||||
|
parser.add_argument("--response_key", type=str, default="response", help="Key for the response in the output JSON.")
|
||||||
parser.add_argument("--temperature", type=float, default=0.60, help="Temperature for generating responses.")
|
parser.add_argument("--temperature", type=float, default=0.60, help="Temperature for generating responses.")
|
||||||
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p value for generating responses.")
|
|
||||||
parser.add_argument("--top_k", type=int, default=30, help="Top-k value for generating responses.")
|
parser.add_argument("--top_k", type=int, default=30, help="Top-k value for generating responses.")
|
||||||
|
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p value for generating responses.")
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.")
|
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
model = Khaosz(args.model_dir).to(device='cuda', dtype=torch.bfloat16)
|
processor(**vars(args))
|
||||||
|
|
||||||
processor(
|
|
||||||
model,
|
|
||||||
input_json_file=args.input_json_file,
|
|
||||||
output_json_file=args.output_json_file,
|
|
||||||
question_key=args.question_key,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
temperature=args.temperature,
|
|
||||||
top_k=args.top_k,
|
|
||||||
top_p=args.top_p
|
|
||||||
)
|
|
||||||
|
|
@ -0,0 +1,102 @@
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import argparse
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
from khaosz import Khaosz
|
||||||
|
|
||||||
|
def compute_perplexity(
    model: nn.Module,
    input_ids: Tensor,
    input_mask: Tensor,
) -> Tensor:
    """
    Compute the per-sequence perplexity of a batch of token sequences,
    where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))) and N is the number
    of scored tokens in the sequence.

    A token is scored only when both the token itself and the token right
    before it are marked valid in ``input_mask``. This keeps padding out of
    the result: with left padding, the first real token would otherwise be
    scored from the logits of a pad position, making the perplexity of a
    sequence depend on how much padding its batch happened to need.

    Args:
        model: Callable invoked as ``model(input_ids, input_mask)`` that
            returns a mapping with a ``"logits"`` entry of shape
            [batch_size, seq_len, vocab_size].
        input_ids: Token ids, shape [batch_size, seq_len].
        input_mask: Validity mask, shape [batch_size, seq_len]; truthy
            entries mark real (non-padding) tokens.

    Returns:
        A float32 tensor of shape [batch_size] holding one perplexity per
        sequence. A row without any scored token (fewer than two valid
        tokens) yields NaN, because its perplexity is undefined.
    """
    output = model(input_ids, input_mask)
    logits = output["logits"]

    # Logits at position t predict the token at position t + 1.
    # Upcast so the loss is not accumulated in a low-precision dtype
    # such as bfloat16.
    shifted_logits = logits[:, :-1, :].float()  # [batch_size, seq_len-1, vocab_size]
    shifted_input_ids = input_ids[:, 1:]        # [batch_size, seq_len-1]

    # Score a target only if the target AND its predecessor are real tokens.
    valid = input_mask.bool()
    shifted_mask = (valid[:, 1:] & valid[:, :-1]).to(shifted_logits.dtype)

    loss = F.cross_entropy(
        shifted_logits.flatten(0, 1),
        shifted_input_ids.flatten(0, 1),
        reduction='none',
    )

    loss = loss.view(shifted_input_ids.shape) * shifted_mask  # [batch_size, seq_len-1]
    sentence_loss = loss.sum(dim=1) / shifted_mask.sum(dim=1)
    perplexity = torch.exp(sentence_loss)  # [batch_size]

    return perplexity
|
||||||
|
|
||||||
|
def process_file(
    model_dir: str,
    input_file: str,
    output_file: str,
    batch_size: int,
    text_key: str
):
    """
    Compute the perplexity of every record in a JSONL file and write the
    results to another JSONL file.

    Args:
        model_dir: Directory passed to ``Khaosz`` to load the model.
        input_file: Path to the input JSONL file (one JSON object per line).
        output_file: Path of the JSONL file to write. Each output line is
            ``{text_key: <text>, "ppl": <float>}``, in input order.
        batch_size: Number of texts evaluated per forward pass.
        text_key: Key under which each input record stores its text.

    Raises:
        KeyError: If an input record has no ``text_key`` entry.
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
    """
    model = Khaosz(model_dir).to(device="cuda", dtype=torch.bfloat16)
    tokenizer = model.parameter.tokenizer

    with open(input_file, "r", encoding='utf-8') as f:
        # Blank lines (e.g. a trailing newline at end of file) are not records.
        input_data = [json.loads(line) for line in f if line.strip()]

    texts = [item[text_key] for item in input_data]
    encoded_texts = [tokenizer.encode(text) for text in texts]

    output_data = []

    # The file imports the ``tqdm`` module itself, so the progress-bar
    # class has to be reached as ``tqdm.tqdm``; calling the module raises
    # TypeError.
    batch_starts = range(0, len(encoded_texts), batch_size)
    for i in tqdm.tqdm(batch_starts, desc="Computing perplexity"):
        batch_encoded = encoded_texts[i:i + batch_size]
        batch_texts = texts[i:i + batch_size]

        # Pad sequences to the same length (left padding)
        max_len = max(len(seq) for seq in batch_encoded)
        padded_ids = []
        masks = []

        for seq in batch_encoded:
            pad_len = max_len - len(seq)
            padded_ids.append([tokenizer.pad_id] * pad_len + seq)
            masks.append([False] * pad_len + [True] * len(seq))

        input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
        input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)

        # Compute perplexity
        with torch.inference_mode():
            perplexity = compute_perplexity(model.parameter.model, input_ids, input_mask)

        for text, ppl in zip(batch_texts, perplexity):
            output_data.append({text_key: text, "ppl": float(ppl.item())})

    with open(output_file, "w", encoding='utf-8') as f:
        for item in output_data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse the command-line options and run the perplexity computation."""
    parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")

    # One (flag, add_argument keyword options) entry per command-line option.
    cli_options = (
        ("--model_dir", dict(type=str, required=True, help="Path to the model directory.")),
        ("--input_file", dict(type=str, required=True, help="Path to the input file.")),
        ("--output_file", dict(type=str, required=True, help="Path to the output file.")),
        ("--batch_size", dict(type=int, default=4, help="Batch size for evaluation.")),
        ("--text_key", dict(type=str, default="text", help="Key for the text field in the input data.")),
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    # The option names match process_file's parameter names one-to-one,
    # so the parsed namespace can be forwarded as keyword arguments.
    parsed = parser.parse_args()
    process_file(**vars(parsed))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Loading…
Reference in New Issue