refactor: 优化工具脚本接口并修复批处理问题
This commit is contained in:
parent
e58dbd7c57
commit
9b22b1651e
|
|
@ -269,13 +269,20 @@ class InferenceEngine:
|
|||
result = _NonStreamingResult(len(prompts))
|
||||
|
||||
for i, p in enumerate(prompts):
|
||||
# Create closure to capture current index value using factory function
|
||||
def make_callback(idx):
|
||||
def callback(token):
|
||||
result.append(idx, token)
|
||||
|
||||
return callback
|
||||
|
||||
self.scheduler.add_task(
|
||||
prompt=p,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
stream_callback=result.append,
|
||||
stream_callback=make_callback(i),
|
||||
)
|
||||
|
||||
result.wait()
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ def load_model(
|
|||
|
||||
# Load tokenizer separately
|
||||
tokenizer = TextTokenizer.from_pretrained(param_path)
|
||||
_model_param = AutoModel.from_pretrained(param_path, tokenizer=tokenizer)
|
||||
_model_param = AutoModel.from_pretrained(param_path)
|
||||
_model_param.to(device=device, dtype=dtype)
|
||||
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
||||
|
||||
|
|
|
|||
|
|
@ -13,11 +13,9 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
|||
def generate_text():
|
||||
# Load model from pretrained
|
||||
model = AutoModel.from_pretrained(PARAMETER_ROOT)
|
||||
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
||||
model.to(device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Load tokenizer from pretrained
|
||||
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT / "tokenizer")
|
||||
|
||||
query = input(">> ")
|
||||
|
||||
engine = InferenceEngine(
|
||||
|
|
|
|||
|
|
@ -2,8 +2,9 @@ from pathlib import Path
|
|||
|
||||
import torch
|
||||
|
||||
from astrai.model import AutoModel
|
||||
from astrai.inference import InferenceEngine
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||
|
|
@ -11,9 +12,9 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
|||
|
||||
def batch_generate():
|
||||
# Load model using AutoModel
|
||||
model = AutoModel.from_pretrained(
|
||||
PARAMETER_ROOT, device="cuda", dtype=torch.bfloat16
|
||||
)
|
||||
model = AutoModel.from_pretrained(PARAMETER_ROOT)
|
||||
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
||||
model.to(device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
inputs = [
|
||||
"你好",
|
||||
|
|
@ -24,13 +25,13 @@ def batch_generate():
|
|||
]
|
||||
|
||||
engine = InferenceEngine(
|
||||
model=model.model,
|
||||
tokenizer=model.tokenizer,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
responses = engine.generate(
|
||||
prompt=inputs,
|
||||
stream=False,
|
||||
max_tokens=model.config.max_len,
|
||||
max_tokens=2048,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
top_k=50,
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ def chat():
|
|||
for token in engine.generate(
|
||||
prompt=prompt,
|
||||
stream=True,
|
||||
max_tokens=model.config.max_len,
|
||||
max_tokens=2048,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
top_k=50,
|
||||
|
|
|
|||
|
|
@ -17,30 +17,56 @@ def processor(
|
|||
top_p: float,
|
||||
question_key: str,
|
||||
response_key: str,
|
||||
max_tokens: int,
|
||||
):
|
||||
# Load model using AutoModel
|
||||
model = AutoModel.from_pretrained(model_dir, device="cuda", dtype=torch.bfloat16)
|
||||
engine = InferenceEngine(model=model.model, tokenizer=model.tokenizer)
|
||||
# Load model and tokenizer
|
||||
model = AutoModel.from_pretrained(model_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
model.to(device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Create inference engine
|
||||
engine = InferenceEngine(model=model, tokenizer=tokenizer)
|
||||
|
||||
with open(input_json_file, "r", encoding="utf-8") as f:
|
||||
input_data = [json.loads(line) for line in f]
|
||||
|
||||
queries = [item[question_key] for item in input_data]
|
||||
# Check input format: chat messages or raw text
|
||||
if input_data and "messages" in input_data[0]:
|
||||
# Chat format: [{"messages": [...]}]
|
||||
prompts = [
|
||||
tokenizer.apply_chat_template(item["messages"], tokenize=False)
|
||||
for item in input_data
|
||||
]
|
||||
else:
|
||||
# Raw text format: [{"question": "..."}]
|
||||
prompts = [item[question_key] for item in input_data]
|
||||
|
||||
# Use provided max_tokens or default to model config max_len
|
||||
if max_tokens is None:
|
||||
max_tokens = model.config.max_len
|
||||
|
||||
# Generate responses (batch)
|
||||
responses = engine.generate(
|
||||
prompt=queries,
|
||||
prompt=prompts,
|
||||
stream=False,
|
||||
max_tokens=model.config.max_len,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
# Write results
|
||||
with open(output_json_file, "w", encoding="utf-8") as f:
|
||||
for query, response in zip(queries, responses):
|
||||
output_item = {question_key: query, response_key: response}
|
||||
for prompt, response in zip(prompts, responses):
|
||||
if input_data and "messages" in input_data[0]:
|
||||
output_item = {"response": response}
|
||||
else:
|
||||
output_item = {question_key: prompt, response_key: response}
|
||||
f.write(json.dumps(output_item, ensure_ascii=False) + "\n")
|
||||
|
||||
# Cleanup
|
||||
engine.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
|
||||
|
|
@ -90,6 +116,12 @@ if __name__ == "__main__":
|
|||
parser.add_argument(
|
||||
"--batch_size", type=int, default=1, help="Batch size for generating responses."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_tokens",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Maximum tokens to generate (default: model config max_len).",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
|
|||
|
|
@ -2,62 +2,42 @@ import argparse
|
|||
import json
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import tqdm
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.model import AutoModel
|
||||
|
||||
|
||||
def compute_perplexity(
|
||||
model: nn.Module,
|
||||
input_ids: Tensor,
|
||||
input_mask: Tensor,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Compute the perplexity of a batch of input sequences,
|
||||
where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))).
|
||||
"""
|
||||
|
||||
output = model(input_ids, input_mask=input_mask)
|
||||
logits = output["logits"]
|
||||
|
||||
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
|
||||
shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1]
|
||||
shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1]
|
||||
|
||||
loss = F.cross_entropy(
|
||||
shifted_logits.flatten(0, 1), shifted_input_ids.flatten(0, 1), reduction="none"
|
||||
)
|
||||
|
||||
loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1]
|
||||
loss = loss * shifted_mask
|
||||
sentence_loss = (loss).sum(dim=1) / shifted_mask.sum(dim=1)
|
||||
perplexity = torch.exp(sentence_loss) # [batch_size]
|
||||
|
||||
return perplexity
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
||||
def process_file(
|
||||
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
||||
):
|
||||
# Load model using AutoModel
|
||||
model = AutoModel.from_pretrained(model_dir, device="cuda", dtype=torch.bfloat16)
|
||||
tokenizer = model.tokenizer
|
||||
# Load model and tokenizer
|
||||
model = AutoModel.from_pretrained(model_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
model.to(device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
with open(input_file, "r", encoding="utf-8") as f:
|
||||
input_data = [json.loads(line) for line in f]
|
||||
|
||||
texts = [item[text_key] for item in input_data]
|
||||
|
||||
# Encode all texts
|
||||
print(f"Encoding {len(texts)} texts...")
|
||||
encoded_texts = [tokenizer.encode(text) for text in texts]
|
||||
|
||||
output_data = []
|
||||
total_batches = (len(encoded_texts) + batch_size - 1) // batch_size
|
||||
|
||||
for i in tqdm.tqdm(
|
||||
range(0, len(encoded_texts), batch_size), desc="Computing perplexity"
|
||||
range(0, len(encoded_texts), batch_size),
|
||||
total=total_batches,
|
||||
desc="Computing perplexity",
|
||||
):
|
||||
batch_encoded = encoded_texts[i : i + batch_size]
|
||||
batch_texts = texts[i : i + batch_size]
|
||||
|
||||
# Find max length in batch and pad
|
||||
max_len = max(len(seq) for seq in batch_encoded)
|
||||
padded_ids = []
|
||||
masks = []
|
||||
|
|
@ -69,17 +49,41 @@ def process_file(
|
|||
padded_ids.append(padded_seq)
|
||||
masks.append(mask)
|
||||
|
||||
# Convert to tensors
|
||||
input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
|
||||
input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)
|
||||
perplexity = compute_perplexity(model.model, input_ids, input_mask)
|
||||
|
||||
# Compute perplexity
|
||||
output = model(input_ids, input_mask=input_mask)
|
||||
logits = output["logits"]
|
||||
|
||||
# Shift for causal language modeling
|
||||
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
|
||||
shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1]
|
||||
shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1]
|
||||
|
||||
# Compute cross entropy loss
|
||||
loss = F.cross_entropy(
|
||||
shifted_logits.flatten(0, 1),
|
||||
shifted_input_ids.flatten(0, 1),
|
||||
reduction="none",
|
||||
)
|
||||
|
||||
loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1]
|
||||
loss = loss * shifted_mask
|
||||
sentence_loss = loss.sum(dim=1) / shifted_mask.sum(dim=1).clamp(min=1)
|
||||
perplexity = torch.exp(sentence_loss) # [batch_size]
|
||||
|
||||
for text, ppl in zip(batch_texts, perplexity):
|
||||
output_data.append({text_key: text, "ppl": float(ppl.item())})
|
||||
|
||||
# Write results
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
for item in output_data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + "\n")
|
||||
|
||||
print(f"Perplexity computation complete. Results saved to {output_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from astrai.inference.server import run_server
|
||||
|
||||
|
||||
|
|
@ -21,17 +23,49 @@ def main():
|
|||
default=None,
|
||||
help="Path to model parameters (default: project_root/params)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
help="Device to load model on (default: cuda)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
default="bfloat16",
|
||||
choices=["bfloat16", "float16", "float32"],
|
||||
help="Data type for model weights (default: bfloat16)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_batch_size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Maximum batch size for continuous batching (default: 16)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# If param_path is provided, set environment variable or modify global?
|
||||
# Currently the server loads from default location on startup.
|
||||
# We could pass it via an environment variable, but for simplicity we assume
|
||||
# the default location is correct.
|
||||
project_root = Path(__file__).parent.parent
|
||||
# Convert dtype string to torch dtype
|
||||
dtype_map = {
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
dtype = dtype_map[args.dtype]
|
||||
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
param_path = args.param_path or (project_root / "params")
|
||||
print(f"Starting AstrAI inference server on http://{args.host}:{args.port}")
|
||||
print(f"Model parameters expected at: {[param_path]}")
|
||||
run_server(host=args.host, port=args.port, reload=args.reload)
|
||||
print(f"Model parameters expected at: {param_path}")
|
||||
print(f"Device: {args.device}, Dtype: {args.dtype}")
|
||||
run_server(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
reload=args.reload,
|
||||
device=args.device,
|
||||
dtype=dtype,
|
||||
param_path=param_path,
|
||||
max_batch_size=args.max_batch_size,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue