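"""Batch-generate responses with an astrai model over a JSONL file.

Reads records (chat messages or raw questions) from an input JSONL file,
generates a response for each one with an InferenceEngine, and writes the
results back out as JSONL.

Example invocation (assuming this file is saved as generate.py):

    python generate.py --model_dir ./model \
        --input_json_file input.jsonl \
        --output_json_file output.jsonl
"""
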
import argparse
import json

import torch

from astrai.inference import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer


def processor(
    model_dir: str,
    input_json_file: str,
    output_json_file: str,
    temperature: float,
    top_k: int,
    top_p: float,
    question_key: str,
    response_key: str,
    max_tokens: int | None,
    batch_size: int = 1,  # accepted from the CLI; see docstring below
):
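    """Generate a response for every record of a JSONL input file.

    Records may be chat-formatted ({"messages": [...]}) or raw text keyed by
    question_key. When max_tokens is None, the model config's max_len is used.
    batch_size is accepted from the CLI, but all prompts are currently passed
    to the engine in a single generate() call.
    """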
    # Load model and tokenizer
    model = AutoModel.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model.to(device="cuda", dtype=torch.bfloat16)
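    # NOTE: the script assumes a CUDA device is available; the .to(device="cuda")
    # call above will fail on CPU-only machines.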

    # Create inference engine
    engine = InferenceEngine(model=model, tokenizer=tokenizer)

    with open(input_json_file, "r", encoding="utf-8") as f:
        input_data = [json.loads(line) for line in f]
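    # Each line of the input file must be a self-contained JSON object (JSONL).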

    # Check input format: chat messages or raw text
    chat_format = bool(input_data) and "messages" in input_data[0]
    if chat_format:
        # Chat format: [{"messages": [...]}]
        prompts = [
            tokenizer.apply_chat_template(item["messages"], tokenize=False)
            for item in input_data
        ]
    else:
        # Raw text format: [{"question": "..."}]
        prompts = [item[question_key] for item in input_data]

    # Use provided max_tokens or default to model config max_len
    if max_tokens is None:
        max_tokens = model.config.max_len

    # Generate responses (batch)
    responses = engine.generate(
        prompt=prompts,
        stream=False,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )
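    # engine.generate is expected to return one response per prompt, in order.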

    # Write results
    with open(output_json_file, "w", encoding="utf-8") as f:
        for prompt, response in zip(prompts, responses):
            if chat_format:
                output_item = {"response": response}
            else:
                output_item = {question_key: prompt, response_key: response}
            f.write(json.dumps(output_item, ensure_ascii=False) + "\n")

    # Cleanup
    engine.shutdown()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run batch generation with a Khaosz model."
    )

    parser.add_argument(
        "--model_dir", type=str, required=True, help="Path to the model directory."
    )
    parser.add_argument(
        "--input_json_file",
        type=str,
        required=True,
        help="Path to the input JSONL file.",
    )
    parser.add_argument(
        "--output_json_file",
        type=str,
        required=True,
        help="Path to the output JSONL file.",
    )
    parser.add_argument(
        "--question_key",
        type=str,
        default="question",
        help="Key for the question in the input JSONL records.",
    )
    parser.add_argument(
        "--response_key",
        type=str,
        default="response",
        help="Key for the response in the output JSONL records.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.60,
        help="Temperature for generating responses.",
    )
    parser.add_argument(
        "--top_k", type=int, default=30, help="Top-k value for generating responses."
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.95,
        help="Top-p value for generating responses.",
    )
    parser.add_argument(
        "--batch_size", type=int, default=1, help="Batch size for generating responses."
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=None,
        help="Maximum tokens to generate (default: model config max_len).",
    )

    args = parser.parse_args()

    with torch.inference_mode():
        processor(**vars(args))