refactor: improve tool script interfaces and fix batch processing issues

ViperEkura 2026-04-05 21:56:22 +08:00
parent e58dbd7c57
commit 9b22b1651e
8 changed files with 140 additions and 64 deletions

View File

@@ -269,13 +269,20 @@ class InferenceEngine:
result = _NonStreamingResult(len(prompts))
for i, p in enumerate(prompts):
# Use a factory function so each callback captures the current index value
def make_callback(idx):
def callback(token):
result.append(idx, token)
return callback
self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=result.append,
stream_callback=make_callback(i),
)
result.wait()
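For background, the replacement avoids Python's late-binding closures: a callback defined directly inside the loop (e.g. a lambda referencing `i`) would see the loop's final index by the time tokens arrive, so every prompt's tokens would be routed to the same slot. A minimal standalone sketch of the pitfall and the factory fix (the `results`/`store` names are illustrative, not from the repo):

results = {0: [], 1: [], 2: []}

# Buggy: every lambda closes over the same loop variable `i`,
# so a callback fired later sees the loop's final value (2).
buggy = [lambda tok: results[i].append(tok) for i in range(3)]

# Fixed: a factory function binds the index at definition time,
# which is what make_callback does in the diff above.
def make_callback(idx, store):
    def callback(tok):
        store[idx].append(tok)
    return callback

fixed = [make_callback(i, results) for i in range(3)]

buggy[0]("a")   # lands in results[2], not results[0]
fixed[0]("b")   # lands in results[0] as intended
print(results)  # {0: ['b'], 1: [], 2: ['a']}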

View File

@@ -97,7 +97,7 @@ def load_model(
# Load tokenizer separately
tokenizer = TextTokenizer.from_pretrained(param_path)
_model_param = AutoModel.from_pretrained(param_path, tokenizer=tokenizer)
_model_param = AutoModel.from_pretrained(param_path)
_model_param.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}")
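The same loading pattern recurs in the scripts touched below: the tokenizer and the model are now loaded independently from the same parameter directory, and device/dtype are applied afterwards with `.to()`. A minimal sketch of the intended call sequence, assuming the `astrai` entry points shown in this diff (the directory path is illustrative):

import torch
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer

param_path = "params"  # illustrative directory
tokenizer = AutoTokenizer.from_pretrained(param_path)   # tokenizer loaded on its own
model = AutoModel.from_pretrained(param_path)           # no tokenizer argument anymore
model.to(device="cuda", dtype=torch.bfloat16)           # device/dtype applied after loading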

View File

@@ -13,11 +13,9 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def generate_text():
# Load model from pretrained
model = AutoModel.from_pretrained(PARAMETER_ROOT)
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16)
# Load tokenizer from pretrained
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT / "tokenizer")
query = input(">> ")
engine = InferenceEngine(

View File

@@ -2,8 +2,9 @@ from pathlib import Path
import torch
from astrai.model import AutoModel
from astrai.inference import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
@@ -11,9 +12,9 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def batch_generate():
# Load model using AutoModel
model = AutoModel.from_pretrained(
PARAMETER_ROOT, device="cuda", dtype=torch.bfloat16
)
model = AutoModel.from_pretrained(PARAMETER_ROOT)
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16)
inputs = [
"你好",
@@ -24,13 +25,13 @@ def batch_generate():
]
engine = InferenceEngine(
model=model.model,
tokenizer=model.tokenizer,
model=model,
tokenizer=tokenizer,
)
responses = engine.generate(
prompt=inputs,
stream=False,
max_tokens=model.config.max_len,
max_tokens=2048,
temperature=0.8,
top_p=0.95,
top_k=50,
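Assuming `engine.generate(..., stream=False)` returns one response string per input prompt in the same order (which the per-index callback fix above is meant to guarantee), the results can be paired back with the inputs like this (an illustrative usage, not code from the repo):

for query, reply in zip(inputs, responses):
    print(f"Q: {query}\nA: {reply}\n")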

View File

@@ -33,7 +33,7 @@ def chat():
for token in engine.generate(
prompt=prompt,
stream=True,
max_tokens=model.config.max_len,
max_tokens=2048,
temperature=0.8,
top_p=0.95,
top_k=50,

View File

@@ -17,30 +17,56 @@ def processor(
top_p: float,
question_key: str,
response_key: str,
max_tokens: int,
):
# Load model using AutoModel
model = AutoModel.from_pretrained(model_dir, device="cuda", dtype=torch.bfloat16)
engine = InferenceEngine(model=model.model, tokenizer=model.tokenizer)
# Load model and tokenizer
model = AutoModel.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model.to(device="cuda", dtype=torch.bfloat16)
# Create inference engine
engine = InferenceEngine(model=model, tokenizer=tokenizer)
with open(input_json_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f]
queries = [item[question_key] for item in input_data]
# Check input format: chat messages or raw text
if input_data and "messages" in input_data[0]:
# Chat format: [{"messages": [...]}]
prompts = [
tokenizer.apply_chat_template(item["messages"], tokenize=False)
for item in input_data
]
else:
# Raw text format: [{"question": "..."}]
prompts = [item[question_key] for item in input_data]
# Use provided max_tokens or default to model config max_len
if max_tokens is None:
max_tokens = model.config.max_len
# Generate responses (batch)
responses = engine.generate(
prompt=queries,
prompt=prompts,
stream=False,
max_tokens=model.config.max_len,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
# Write results
with open(output_json_file, "w", encoding="utf-8") as f:
for query, response in zip(queries, responses):
output_item = {question_key: query, response_key: response}
for prompt, response in zip(prompts, responses):
if input_data and "messages" in input_data[0]:
output_item = {"response": response}
else:
output_item = {question_key: prompt, response_key: response}
f.write(json.dumps(output_item, ensure_ascii=False) + "\n")
# Cleanup
engine.shutdown()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
@@ -90,6 +116,12 @@ if __name__ == "__main__":
parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size for generating responses."
)
parser.add_argument(
"--max_tokens",
type=int,
default=None,
help="Maximum tokens to generate (default: model config max_len).",
)
args = parser.parse_args()
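For reference, the format detection above distinguishes two JSONL input shapes. A small sketch of one line of each, plus the matching output lines the script writes; the "question" key and the chat-message structure expected by `apply_chat_template` are assumptions, not taken from this diff:

import json

# Chat-format record: detected via the "messages" key and rendered with the chat template.
chat_record = {"messages": [{"role": "user", "content": "你好"}]}

# Raw-text record: keyed by --question_key (shown here as "question").
raw_record = {"question": "What is perplexity?"}

# One JSON object per line, matching how the script reads input and writes output.
print(json.dumps(chat_record, ensure_ascii=False))   # its output line would be {"response": "..."}
print(json.dumps(raw_record, ensure_ascii=False))    # its output line keeps the question and adds the response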

View File

@@ -2,62 +2,42 @@ import argparse
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from torch import Tensor
from astrai.model import AutoModel
def compute_perplexity(
model: nn.Module,
input_ids: Tensor,
input_mask: Tensor,
) -> Tensor:
"""
Compute the perplexity of a batch of input sequences,
where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))).
"""
output = model(input_ids, input_mask=input_mask)
logits = output["logits"]
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1]
shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1]
loss = F.cross_entropy(
shifted_logits.flatten(0, 1), shifted_input_ids.flatten(0, 1), reduction="none"
)
loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1]
loss = loss * shifted_mask
sentence_loss = (loss).sum(dim=1) / shifted_mask.sum(dim=1)
perplexity = torch.exp(sentence_loss) # [batch_size]
return perplexity
from astrai.tokenize import AutoTokenizer
def process_file(
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
):
# Load model using AutoModel
model = AutoModel.from_pretrained(model_dir, device="cuda", dtype=torch.bfloat16)
tokenizer = model.tokenizer
# Load model and tokenizer
model = AutoModel.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model.to(device="cuda", dtype=torch.bfloat16)
with open(input_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f]
texts = [item[text_key] for item in input_data]
# Encode all texts
print(f"Encoding {len(texts)} texts...")
encoded_texts = [tokenizer.encode(text) for text in texts]
output_data = []
total_batches = (len(encoded_texts) + batch_size - 1) // batch_size
for i in tqdm.tqdm(
range(0, len(encoded_texts), batch_size), desc="Computing perplexity"
range(0, len(encoded_texts), batch_size),
total=total_batches,
desc="Computing perplexity",
):
batch_encoded = encoded_texts[i : i + batch_size]
batch_texts = texts[i : i + batch_size]
# Find max length in batch and pad
max_len = max(len(seq) for seq in batch_encoded)
padded_ids = []
masks = []
@@ -69,17 +49,41 @@ def process_file(
padded_ids.append(padded_seq)
masks.append(mask)
# Convert to tensors
input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)
perplexity = compute_perplexity(model.model, input_ids, input_mask)
# Compute perplexity
output = model(input_ids, input_mask=input_mask)
logits = output["logits"]
# Shift for causal language modeling
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1]
shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1]
# Compute cross entropy loss
loss = F.cross_entropy(
shifted_logits.flatten(0, 1),
shifted_input_ids.flatten(0, 1),
reduction="none",
)
loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1]
loss = loss * shifted_mask
sentence_loss = loss.sum(dim=1) / shifted_mask.sum(dim=1).clamp(min=1)
perplexity = torch.exp(sentence_loss) # [batch_size]
for text, ppl in zip(batch_texts, perplexity):
output_data.append({text_key: text, "ppl": float(ppl.item())})
# Write results
with open(output_file, "w", encoding="utf-8") as f:
for item in output_data:
f.write(json.dumps(item, ensure_ascii=False) + "\n")
print(f"Perplexity computation complete. Results saved to {output_file}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")

View File

@@ -1,6 +1,8 @@
import argparse
from pathlib import Path
import torch
from astrai.inference.server import run_server
@@ -21,17 +23,49 @@ def main():
default=None,
help="Path to model parameters (default: project_root/params)",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to load model on (default: cuda)",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["bfloat16", "float16", "float32"],
help="Data type for model weights (default: bfloat16)",
)
parser.add_argument(
"--max_batch_size",
type=int,
default=16,
help="Maximum batch size for continuous batching (default: 16)",
)
args = parser.parse_args()
# If param_path is provided, set environment variable or modify global?
# Currently the server loads from default location on startup.
# We could pass it via an environment variable, but for simplicity we assume
# the default location is correct.
project_root = Path(__file__).parent.parent
# Convert dtype string to torch dtype
dtype_map = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32,
}
dtype = dtype_map[args.dtype]
project_root = Path(__file__).parent.parent.parent
param_path = args.param_path or (project_root / "params")
print(f"Starting AstrAI inference server on http://{args.host}:{args.port}")
print(f"Model parameters expected at: {[param_path]}")
run_server(host=args.host, port=args.port, reload=args.reload)
print(f"Model parameters expected at: {param_path}")
print(f"Device: {args.device}, Dtype: {args.dtype}")
run_server(
host=args.host,
port=args.port,
reload=args.reload,
device=args.device,
dtype=dtype,
param_path=param_path,
max_batch_size=args.max_batch_size,
)
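A programmatic equivalent of launching the server with the new flags, e.g. `--device cuda --dtype bfloat16 --max_batch_size 16`; the keyword arguments match the `run_server` call in this diff, while the host/port values are illustrative defaults:

import torch
from pathlib import Path
from astrai.inference.server import run_server

run_server(
    host="0.0.0.0",          # illustrative; the real defaults are defined earlier in this script
    port=8000,
    reload=False,
    device="cuda",
    dtype=torch.bfloat16,    # the CLI maps "bfloat16"/"float16"/"float32" to torch dtypes
    param_path=Path("params"),
    max_batch_size=16,
)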
if __name__ == "__main__":