refactor: 优化工具脚本接口并修复批处理问题
This commit is contained in:
parent
e58dbd7c57
commit
9b22b1651e
|
|
@ -269,13 +269,20 @@ class InferenceEngine:
|
||||||
result = _NonStreamingResult(len(prompts))
|
result = _NonStreamingResult(len(prompts))
|
||||||
|
|
||||||
for i, p in enumerate(prompts):
|
for i, p in enumerate(prompts):
|
||||||
|
# Create closure to capture current index value using factory function
|
||||||
|
def make_callback(idx):
|
||||||
|
def callback(token):
|
||||||
|
result.append(idx, token)
|
||||||
|
|
||||||
|
return callback
|
||||||
|
|
||||||
self.scheduler.add_task(
|
self.scheduler.add_task(
|
||||||
prompt=p,
|
prompt=p,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
stream_callback=result.append,
|
stream_callback=make_callback(i),
|
||||||
)
|
)
|
||||||
|
|
||||||
result.wait()
|
result.wait()
|
||||||
|
|
|
||||||
|
|
@ -97,7 +97,7 @@ def load_model(
|
||||||
|
|
||||||
# Load tokenizer separately
|
# Load tokenizer separately
|
||||||
tokenizer = TextTokenizer.from_pretrained(param_path)
|
tokenizer = TextTokenizer.from_pretrained(param_path)
|
||||||
_model_param = AutoModel.from_pretrained(param_path, tokenizer=tokenizer)
|
_model_param = AutoModel.from_pretrained(param_path)
|
||||||
_model_param.to(device=device, dtype=dtype)
|
_model_param.to(device=device, dtype=dtype)
|
||||||
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,11 +13,9 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
def generate_text():
|
def generate_text():
|
||||||
# Load model from pretrained
|
# Load model from pretrained
|
||||||
model = AutoModel.from_pretrained(PARAMETER_ROOT)
|
model = AutoModel.from_pretrained(PARAMETER_ROOT)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
||||||
model.to(device="cuda", dtype=torch.bfloat16)
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
# Load tokenizer from pretrained
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT / "tokenizer")
|
|
||||||
|
|
||||||
query = input(">> ")
|
query = input(">> ")
|
||||||
|
|
||||||
engine = InferenceEngine(
|
engine = InferenceEngine(
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,9 @@ from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.model import AutoModel
|
|
||||||
from astrai.inference import InferenceEngine
|
from astrai.inference import InferenceEngine
|
||||||
|
from astrai.model import AutoModel
|
||||||
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
@ -11,9 +12,9 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
||||||
def batch_generate():
|
def batch_generate():
|
||||||
# Load model using AutoModel
|
# Load model using AutoModel
|
||||||
model = AutoModel.from_pretrained(
|
model = AutoModel.from_pretrained(PARAMETER_ROOT)
|
||||||
PARAMETER_ROOT, device="cuda", dtype=torch.bfloat16
|
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
||||||
)
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
inputs = [
|
inputs = [
|
||||||
"你好",
|
"你好",
|
||||||
|
|
@ -24,13 +25,13 @@ def batch_generate():
|
||||||
]
|
]
|
||||||
|
|
||||||
engine = InferenceEngine(
|
engine = InferenceEngine(
|
||||||
model=model.model,
|
model=model,
|
||||||
tokenizer=model.tokenizer,
|
tokenizer=tokenizer,
|
||||||
)
|
)
|
||||||
responses = engine.generate(
|
responses = engine.generate(
|
||||||
prompt=inputs,
|
prompt=inputs,
|
||||||
stream=False,
|
stream=False,
|
||||||
max_tokens=model.config.max_len,
|
max_tokens=2048,
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ def chat():
|
||||||
for token in engine.generate(
|
for token in engine.generate(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=True,
|
stream=True,
|
||||||
max_tokens=model.config.max_len,
|
max_tokens=2048,
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
|
|
|
||||||
|
|
@ -17,30 +17,56 @@ def processor(
|
||||||
top_p: float,
|
top_p: float,
|
||||||
question_key: str,
|
question_key: str,
|
||||||
response_key: str,
|
response_key: str,
|
||||||
|
max_tokens: int,
|
||||||
):
|
):
|
||||||
# Load model using AutoModel
|
# Load model and tokenizer
|
||||||
model = AutoModel.from_pretrained(model_dir, device="cuda", dtype=torch.bfloat16)
|
model = AutoModel.from_pretrained(model_dir)
|
||||||
engine = InferenceEngine(model=model.model, tokenizer=model.tokenizer)
|
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||||
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# Create inference engine
|
||||||
|
engine = InferenceEngine(model=model, tokenizer=tokenizer)
|
||||||
|
|
||||||
with open(input_json_file, "r", encoding="utf-8") as f:
|
with open(input_json_file, "r", encoding="utf-8") as f:
|
||||||
input_data = [json.loads(line) for line in f]
|
input_data = [json.loads(line) for line in f]
|
||||||
|
|
||||||
queries = [item[question_key] for item in input_data]
|
# Check input format: chat messages or raw text
|
||||||
|
if input_data and "messages" in input_data[0]:
|
||||||
|
# Chat format: [{"messages": [...]}]
|
||||||
|
prompts = [
|
||||||
|
tokenizer.apply_chat_template(item["messages"], tokenize=False)
|
||||||
|
for item in input_data
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# Raw text format: [{"question": "..."}]
|
||||||
|
prompts = [item[question_key] for item in input_data]
|
||||||
|
|
||||||
|
# Use provided max_tokens or default to model config max_len
|
||||||
|
if max_tokens is None:
|
||||||
|
max_tokens = model.config.max_len
|
||||||
|
|
||||||
|
# Generate responses (batch)
|
||||||
responses = engine.generate(
|
responses = engine.generate(
|
||||||
prompt=queries,
|
prompt=prompts,
|
||||||
stream=False,
|
stream=False,
|
||||||
max_tokens=model.config.max_len,
|
max_tokens=max_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Write results
|
||||||
with open(output_json_file, "w", encoding="utf-8") as f:
|
with open(output_json_file, "w", encoding="utf-8") as f:
|
||||||
for query, response in zip(queries, responses):
|
for prompt, response in zip(prompts, responses):
|
||||||
output_item = {question_key: query, response_key: response}
|
if input_data and "messages" in input_data[0]:
|
||||||
|
output_item = {"response": response}
|
||||||
|
else:
|
||||||
|
output_item = {question_key: prompt, response_key: response}
|
||||||
f.write(json.dumps(output_item, ensure_ascii=False) + "\n")
|
f.write(json.dumps(output_item, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
engine.shutdown()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
|
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
|
||||||
|
|
@ -90,6 +116,12 @@ if __name__ == "__main__":
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--batch_size", type=int, default=1, help="Batch size for generating responses."
|
"--batch_size", type=int, default=1, help="Batch size for generating responses."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_tokens",
|
||||||
|
type=int,
|
||||||
|
default=2048,
|
||||||
|
help="Maximum tokens to generate (default: model config max_len).",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,62 +2,42 @@ import argparse
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import tqdm
|
import tqdm
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from astrai.model import AutoModel
|
from astrai.model import AutoModel
|
||||||
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
def compute_perplexity(
|
|
||||||
model: nn.Module,
|
|
||||||
input_ids: Tensor,
|
|
||||||
input_mask: Tensor,
|
|
||||||
) -> Tensor:
|
|
||||||
"""
|
|
||||||
Compute the perplexity of a batch of input sequences,
|
|
||||||
where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))).
|
|
||||||
"""
|
|
||||||
|
|
||||||
output = model(input_ids, input_mask=input_mask)
|
|
||||||
logits = output["logits"]
|
|
||||||
|
|
||||||
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
|
|
||||||
shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1]
|
|
||||||
shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1]
|
|
||||||
|
|
||||||
loss = F.cross_entropy(
|
|
||||||
shifted_logits.flatten(0, 1), shifted_input_ids.flatten(0, 1), reduction="none"
|
|
||||||
)
|
|
||||||
|
|
||||||
loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1]
|
|
||||||
loss = loss * shifted_mask
|
|
||||||
sentence_loss = (loss).sum(dim=1) / shifted_mask.sum(dim=1)
|
|
||||||
perplexity = torch.exp(sentence_loss) # [batch_size]
|
|
||||||
|
|
||||||
return perplexity
|
|
||||||
|
|
||||||
|
|
||||||
def process_file(
|
def process_file(
|
||||||
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
||||||
):
|
):
|
||||||
# Load model using AutoModel
|
# Load model and tokenizer
|
||||||
model = AutoModel.from_pretrained(model_dir, device="cuda", dtype=torch.bfloat16)
|
model = AutoModel.from_pretrained(model_dir)
|
||||||
tokenizer = model.tokenizer
|
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||||
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
with open(input_file, "r", encoding="utf-8") as f:
|
with open(input_file, "r", encoding="utf-8") as f:
|
||||||
input_data = [json.loads(line) for line in f]
|
input_data = [json.loads(line) for line in f]
|
||||||
|
|
||||||
texts = [item[text_key] for item in input_data]
|
texts = [item[text_key] for item in input_data]
|
||||||
|
|
||||||
|
# Encode all texts
|
||||||
|
print(f"Encoding {len(texts)} texts...")
|
||||||
encoded_texts = [tokenizer.encode(text) for text in texts]
|
encoded_texts = [tokenizer.encode(text) for text in texts]
|
||||||
|
|
||||||
output_data = []
|
output_data = []
|
||||||
|
total_batches = (len(encoded_texts) + batch_size - 1) // batch_size
|
||||||
|
|
||||||
for i in tqdm.tqdm(
|
for i in tqdm.tqdm(
|
||||||
range(0, len(encoded_texts), batch_size), desc="Computing perplexity"
|
range(0, len(encoded_texts), batch_size),
|
||||||
|
total=total_batches,
|
||||||
|
desc="Computing perplexity",
|
||||||
):
|
):
|
||||||
batch_encoded = encoded_texts[i : i + batch_size]
|
batch_encoded = encoded_texts[i : i + batch_size]
|
||||||
batch_texts = texts[i : i + batch_size]
|
batch_texts = texts[i : i + batch_size]
|
||||||
|
|
||||||
|
# Find max length in batch and pad
|
||||||
max_len = max(len(seq) for seq in batch_encoded)
|
max_len = max(len(seq) for seq in batch_encoded)
|
||||||
padded_ids = []
|
padded_ids = []
|
||||||
masks = []
|
masks = []
|
||||||
|
|
@ -69,17 +49,41 @@ def process_file(
|
||||||
padded_ids.append(padded_seq)
|
padded_ids.append(padded_seq)
|
||||||
masks.append(mask)
|
masks.append(mask)
|
||||||
|
|
||||||
|
# Convert to tensors
|
||||||
input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
|
input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
|
||||||
input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)
|
input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)
|
||||||
perplexity = compute_perplexity(model.model, input_ids, input_mask)
|
|
||||||
|
# Compute perplexity
|
||||||
|
output = model(input_ids, input_mask=input_mask)
|
||||||
|
logits = output["logits"]
|
||||||
|
|
||||||
|
# Shift for causal language modeling
|
||||||
|
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
|
||||||
|
shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1]
|
||||||
|
shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1]
|
||||||
|
|
||||||
|
# Compute cross entropy loss
|
||||||
|
loss = F.cross_entropy(
|
||||||
|
shifted_logits.flatten(0, 1),
|
||||||
|
shifted_input_ids.flatten(0, 1),
|
||||||
|
reduction="none",
|
||||||
|
)
|
||||||
|
|
||||||
|
loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1]
|
||||||
|
loss = loss * shifted_mask
|
||||||
|
sentence_loss = loss.sum(dim=1) / shifted_mask.sum(dim=1).clamp(min=1)
|
||||||
|
perplexity = torch.exp(sentence_loss) # [batch_size]
|
||||||
|
|
||||||
for text, ppl in zip(batch_texts, perplexity):
|
for text, ppl in zip(batch_texts, perplexity):
|
||||||
output_data.append({text_key: text, "ppl": float(ppl.item())})
|
output_data.append({text_key: text, "ppl": float(ppl.item())})
|
||||||
|
|
||||||
|
# Write results
|
||||||
with open(output_file, "w", encoding="utf-8") as f:
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
for item in output_data:
|
for item in output_data:
|
||||||
f.write(json.dumps(item, ensure_ascii=False) + "\n")
|
f.write(json.dumps(item, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
|
print(f"Perplexity computation complete. Results saved to {output_file}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from astrai.inference.server import run_server
|
from astrai.inference.server import run_server
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -21,17 +23,49 @@ def main():
|
||||||
default=None,
|
default=None,
|
||||||
help="Path to model parameters (default: project_root/params)",
|
help="Path to model parameters (default: project_root/params)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cuda",
|
||||||
|
help="Device to load model on (default: cuda)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
type=str,
|
||||||
|
default="bfloat16",
|
||||||
|
choices=["bfloat16", "float16", "float32"],
|
||||||
|
help="Data type for model weights (default: bfloat16)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_batch_size",
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
help="Maximum batch size for continuous batching (default: 16)",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# If param_path is provided, set environment variable or modify global?
|
# Convert dtype string to torch dtype
|
||||||
# Currently the server loads from default location on startup.
|
dtype_map = {
|
||||||
# We could pass it via an environment variable, but for simplicity we assume
|
"bfloat16": torch.bfloat16,
|
||||||
# the default location is correct.
|
"float16": torch.float16,
|
||||||
project_root = Path(__file__).parent.parent
|
"float32": torch.float32,
|
||||||
|
}
|
||||||
|
dtype = dtype_map[args.dtype]
|
||||||
|
|
||||||
|
project_root = Path(__file__).parent.parent.parent
|
||||||
param_path = args.param_path or (project_root / "params")
|
param_path = args.param_path or (project_root / "params")
|
||||||
print(f"Starting AstrAI inference server on http://{args.host}:{args.port}")
|
print(f"Starting AstrAI inference server on http://{args.host}:{args.port}")
|
||||||
print(f"Model parameters expected at: {[param_path]}")
|
print(f"Model parameters expected at: {param_path}")
|
||||||
run_server(host=args.host, port=args.port, reload=args.reload)
|
print(f"Device: {args.device}, Dtype: {args.dtype}")
|
||||||
|
run_server(
|
||||||
|
host=args.host,
|
||||||
|
port=args.port,
|
||||||
|
reload=args.reload,
|
||||||
|
device=args.device,
|
||||||
|
dtype=dtype,
|
||||||
|
param_path=param_path,
|
||||||
|
max_batch_size=args.max_batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue