AstrAI/tools/generate.py

103 lines
2.7 KiB
Python

import torch
import json
import argparse
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import BatchGenerator, GenerationRequest
from astrai.inference.core import disable_random_init
def processor(
    model_dir: str,
    input_json_file: str,
    output_json_file: str,
    batch_size: int,
    temperature: float,
    top_k: int,
    top_p: float,
    question_key: str,
    response_key: str,
) -> None:
    """Generate a response for every question in a JSONL file.

    Loads the model from ``model_dir``, reads one JSON object per line from
    ``input_json_file``, runs batched generation over the ``question_key``
    values, and writes one ``{question_key: ..., response_key: ...}`` JSON
    object per line to ``output_json_file`` (overwriting it).

    Args:
        model_dir: Path to the model directory.
        input_json_file: Input JSONL path; every line must contain ``question_key``.
        output_json_file: Output JSONL path.
        batch_size: Number of queries submitted to the generator per request.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Nucleus (top-p) sampling cutoff.
        question_key: JSON key holding the query text in the input records.
        response_key: JSON key under which the response is written in the output.
    """
    # Skip random weight init during construction; weights are loaded from disk.
    with disable_random_init():
        param = ModelParameter.load(model_dir)
        param.to(device="cuda", dtype=torch.bfloat16)
        generator = BatchGenerator(param)

    with open(input_json_file, "r", encoding="utf-8") as f:
        input_data = [json.loads(line) for line in f]
    queries = [item[question_key] for item in input_data]

    # FIX: batch_size was previously accepted but never used — every query was
    # packed into a single request. Submit in chunks of batch_size instead.
    # Guard against non-positive values so range() below cannot raise.
    step = max(1, batch_size)
    responses = []
    for start in range(0, len(queries), step):
        request = GenerationRequest(
            query=queries[start : start + step],
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_len=param.config.max_len,
            history=None,
            system_prompt=None,
        )
        # NOTE(review): assumes generate() returns one response per query,
        # in input order — confirm against BatchGenerator's contract.
        responses.extend(generator.generate(request))

    with open(output_json_file, "w", encoding="utf-8") as f:
        for query, response in zip(queries, responses):
            output_item = {question_key: query, response_key: response}
            f.write(json.dumps(output_item, ensure_ascii=False) + "\n")
if __name__ == "__main__":
    # (flag, add_argument kwargs) specs, in the order they appear in --help.
    arg_specs = [
        ("--model_dir",
         dict(type=str, required=True, help="Path to the model directory.")),
        ("--input_json_file",
         dict(type=str, required=True, help="Path to the input JSONL file.")),
        ("--output_json_file",
         dict(type=str, required=True, help="Path to the output JSONL file.")),
        ("--question_key",
         dict(type=str, default="question",
              help="Key for the question in the input JSON.")),
        ("--response_key",
         dict(type=str, default="response",
              help="Key for the response in the output JSON.")),
        ("--temperature",
         dict(type=float, default=0.60,
              help="Temperature for generating responses.")),
        ("--top_k",
         dict(type=int, default=30,
              help="Top-k value for generating responses.")),
        ("--top_p",
         dict(type=float, default=0.95,
              help="Top-p value for generating responses.")),
        ("--batch_size",
         dict(type=int, default=1,
              help="Batch size for generating responses.")),
    ]

    parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
    for flag, spec in arg_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # inference_mode disables autograd bookkeeping for the whole run.
    with torch.inference_mode():
        processor(**vars(args))