From 9b22b1651ec16ea36b5cc1236b429dff40d59652 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sun, 5 Apr 2026 21:56:22 +0800
Subject: [PATCH] refactor: optimize tool script interfaces and fix batch processing issues
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 astrai/inference/engine.py     |  9 +++-
 astrai/inference/server.py     |  2 +-
 scripts/demo/generate_ar.py    |  4 +-
 scripts/demo/generate_batch.py | 15 +++----
 scripts/demo/stream_chat.py    |  2 +-
 scripts/tools/generate.py      | 48 +++++++++++++++++----
 scripts/tools/perplexity.py    | 76 ++++++++++++++++++----------
 scripts/tools/server.py        | 48 +++++++++++++++++----
 8 files changed, 140 insertions(+), 64 deletions(-)

diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py
index 7f83047..7b0aba3 100644
--- a/astrai/inference/engine.py
+++ b/astrai/inference/engine.py
@@ -269,13 +269,20 @@ class InferenceEngine:
         result = _NonStreamingResult(len(prompts))
 
         for i, p in enumerate(prompts):
+            # Create closure to capture current index value using factory function
+            def make_callback(idx):
+                def callback(token):
+                    result.append(idx, token)
+
+                return callback
+
             self.scheduler.add_task(
                 prompt=p,
                 max_tokens=max_tokens,
                 temperature=temperature,
                 top_p=top_p,
                 top_k=top_k,
-                stream_callback=result.append,
+                stream_callback=make_callback(i),
             )
 
         result.wait()
diff --git a/astrai/inference/server.py b/astrai/inference/server.py
index aa6b801..d7cec2d 100644
--- a/astrai/inference/server.py
+++ b/astrai/inference/server.py
@@ -97,7 +97,7 @@ def load_model(
 
     # Load tokenizer separately
    tokenizer = TextTokenizer.from_pretrained(param_path)
-    _model_param = AutoModel.from_pretrained(param_path, tokenizer=tokenizer)
+    _model_param = AutoModel.from_pretrained(param_path)
     _model_param.to(device=device, dtype=dtype)
 
     logger.info(f"Model loaded on {device} with dtype {dtype}")
diff --git a/scripts/demo/generate_ar.py b/scripts/demo/generate_ar.py
index 698ecec..80ec033 100644
--- a/scripts/demo/generate_ar.py
+++ b/scripts/demo/generate_ar.py
@@ -13,11 +13,9 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
 def generate_text():
     # Load model from pretrained
     model = AutoModel.from_pretrained(PARAMETER_ROOT)
+    tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
     model.to(device="cuda", dtype=torch.bfloat16)
 
-    # Load tokenizer from pretrained
-    tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT / "tokenizer")
-
     query = input(">> ")
 
     engine = InferenceEngine(
diff --git a/scripts/demo/generate_batch.py b/scripts/demo/generate_batch.py
index 0662d3a..a074976 100644
--- a/scripts/demo/generate_batch.py
+++ b/scripts/demo/generate_batch.py
@@ -2,8 +2,9 @@ from pathlib import Path
 
 import torch
 
-from astrai.model import AutoModel
 from astrai.inference import InferenceEngine
+from astrai.model import AutoModel
+from astrai.tokenize import AutoTokenizer
 
 PROJECT_ROOT = Path(__file__).resolve().parents[2]
 PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
@@ -11,9 +12,9 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
 
 def batch_generate():
     # Load model using AutoModel
-    model = AutoModel.from_pretrained(
-        PARAMETER_ROOT, device="cuda", dtype=torch.bfloat16
-    )
+    model = AutoModel.from_pretrained(PARAMETER_ROOT)
+    tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
+    model.to(device="cuda", dtype=torch.bfloat16)
 
     inputs = [
         "你好",
@@ -24,13 +25,13 @@ def batch_generate():
     ]
 
     engine = InferenceEngine(
-        model=model.model,
-        tokenizer=model.tokenizer,
+        model=model,
+        tokenizer=tokenizer,
     )
     responses = engine.generate(
         prompt=inputs,
         stream=False,
-        max_tokens=model.config.max_len,
+        max_tokens=2048,
         temperature=0.8,
         top_p=0.95,
         top_k=50,
diff --git a/scripts/demo/stream_chat.py b/scripts/demo/stream_chat.py
index 1a9af74..1a5ce95 100644
--- a/scripts/demo/stream_chat.py
+++ b/scripts/demo/stream_chat.py
@@ -33,7 +33,7 @@ def chat():
         for token in engine.generate(
             prompt=prompt,
             stream=True,
-            max_tokens=model.config.max_len,
+            max_tokens=2048,
             temperature=0.8,
             top_p=0.95,
             top_k=50,
diff --git a/scripts/tools/generate.py b/scripts/tools/generate.py
index 641c4ad..3ab02c9 100644
--- a/scripts/tools/generate.py
+++ b/scripts/tools/generate.py
@@ -17,30 +17,56 @@ def processor(
     top_p: float,
     question_key: str,
     response_key: str,
+    max_tokens: int,
 ):
-    # Load model using AutoModel
-    model = AutoModel.from_pretrained(model_dir, device="cuda", dtype=torch.bfloat16)
-    engine = InferenceEngine(model=model.model, tokenizer=model.tokenizer)
+    # Load model and tokenizer
+    model = AutoModel.from_pretrained(model_dir)
+    tokenizer = AutoTokenizer.from_pretrained(model_dir)
+    model.to(device="cuda", dtype=torch.bfloat16)
+
+    # Create inference engine
+    engine = InferenceEngine(model=model, tokenizer=tokenizer)
 
     with open(input_json_file, "r", encoding="utf-8") as f:
         input_data = [json.loads(line) for line in f]
 
-    queries = [item[question_key] for item in input_data]
+    # Check input format: chat messages or raw text
+    if input_data and "messages" in input_data[0]:
+        # Chat format: [{"messages": [...]}]
+        prompts = [
+            tokenizer.apply_chat_template(item["messages"], tokenize=False)
+            for item in input_data
+        ]
+    else:
+        # Raw text format: [{"question": "..."}]
+        prompts = [item[question_key] for item in input_data]
 
+    # Use provided max_tokens or default to model config max_len
+    if max_tokens is None:
+        max_tokens = model.config.max_len
+
+    # Generate responses (batch)
     responses = engine.generate(
-        prompt=queries,
+        prompt=prompts,
         stream=False,
-        max_tokens=model.config.max_len,
+        max_tokens=max_tokens,
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
     )
 
+    # Write results
     with open(output_json_file, "w", encoding="utf-8") as f:
-        for query, response in zip(queries, responses):
-            output_item = {question_key: query, response_key: response}
+        for prompt, response in zip(prompts, responses):
+            if input_data and "messages" in input_data[0]:
+                output_item = {"response": response}
+            else:
+                output_item = {question_key: prompt, response_key: response}
             f.write(json.dumps(output_item, ensure_ascii=False) + "\n")
 
+    # Cleanup
+    engine.shutdown()
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
@@ -90,6 +116,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--batch_size", type=int, default=1, help="Batch size for generating responses."
     )
+    parser.add_argument(
+        "--max_tokens",
+        type=int,
+        default=2048,
+        help="Maximum tokens to generate (default: model config max_len).",
+    )
 
     args = parser.parse_args()
 
diff --git a/scripts/tools/perplexity.py b/scripts/tools/perplexity.py
index ebc4060..a410320 100644
--- a/scripts/tools/perplexity.py
+++ b/scripts/tools/perplexity.py
@@ -2,62 +2,42 @@ import argparse
 import json
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 import tqdm
-from torch import Tensor
 
 from astrai.model import AutoModel
-
-
-def compute_perplexity(
-    model: nn.Module,
-    input_ids: Tensor,
-    input_mask: Tensor,
-) -> Tensor:
-    """
-    Compute the perplexity of a batch of input sequences,
-    where PPL = exp(-(1/N) * sum(log P(w_i | w_