refactor: optimize interface setup, remove redundant code

ViperEkura 2026-03-18 15:07:35 +08:00
parent e23a5ca426
commit 62fba9a298
8 changed files with 172 additions and 298 deletions

View File

@@ -12,7 +12,7 @@ def batch_generate():
inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"]
responses = model.batch_generate(
queries=inputs,
query=inputs,
temperature=0.8,
top_p=0.95,
top_k=50
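
For reference, a minimal sketch of calling the renamed batch API (assuming the top-level package exports Khaosz; the checkpoint path is a placeholder):

    from khaosz import Khaosz

    model = Khaosz("path/to/model_dir")  # hypothetical checkpoint directory
    model.to("cuda")                     # forwards to ModelParameter.to
    responses = model.batch_generate(
        query=["Hello", "What is artificial intelligence"],  # note: query=, not queries=
        temperature=0.8,
        top_p=0.95,
        top_k=50,
    )
    for r in responses:
        print(r)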

View File

@@ -26,9 +26,10 @@ if __name__ == "__main__":
query = "作者设计了一个怎样的模型"
emb_query = model.encode(query)
retrieved = retriever.retrieve(emb_query, retrive_top_k)
retrieved_content = "\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)])
retrive_response = model.retrieve_generate(
retrieved=retrieved,
retrieved=retrieved_content,
query=query,
temperature=0.8,
top_p=0.95,
@@ -36,7 +37,7 @@ if __name__ == "__main__":
)
print("retrieve content:")
print("\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)]))
print(retrieved_content)
print("\n\nretrive generate:")
print(retrive_response)
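
The (text, score) tuples coming back from the retriever are now flattened into a single numbered block before being handed to retrieve_generate. With hypothetical retrieval results, the join produces:

    retrieved = [("The authors design a decoder-only model.", 0.91),
                 ("Training uses a BPE tokenizer.", 0.87)]  # made-up (text, score) pairs
    retrieved_content = "\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)])
    # retrieved_content:
    # 1. The authors design a decoder-only model.
    # 2. Training uses a BPE tokenizer.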

View File

@@ -17,12 +17,11 @@ from khaosz.data import (
BpeTokenizer
)
from khaosz.inference.generator import (
TextGenerator,
ChatGenerator,
GenerationRequest,
LoopGenerator,
StreamGenerator,
BatchGenerator,
RetrievalGenerator,
EmbeddingEncoder
EmbeddingEncoder,
)
from khaosz.trainer import (
@@ -46,11 +45,10 @@ __all__ = [
"DatasetLoader",
"BpeTokenizer",
"TextGenerator",
"ChatGenerator",
"GenerationRequest",
"LoopGenerator",
"StreamGenerator",
"BatchGenerator",
"RetrievalGenerator",
"EmbeddingEncoder",
"Trainer",

View File

@@ -1,21 +1,42 @@
from torch import nn
from torch import Tensor
from contextlib import contextmanager
from typing import List, Tuple, Generator, Union
from khaosz.inference.generator import (
TextGenerator,
ChatGenerator,
GenerationRequest,
LoopGenerator,
StreamGenerator,
BatchGenerator,
RetrievalGenerator,
EmbeddingEncoder
)
from khaosz.config.param_config import ModelParameter
@contextmanager
def disable_random_init():
init_functions = [
'xavier_normal_', 'xavier_uniform_',
'kaiming_normal_', 'kaiming_uniform_',
'zeros_', 'ones_', 'constant_',
'normal_', 'uniform_'
]
original_funcs = {}
for name in init_functions:
if hasattr(nn.init, name):
original_funcs[name] = getattr(nn.init, name)
setattr(nn.init, name, lambda *args, **kwargs: None)
try:
yield
finally:
for name, orig_func in original_funcs.items():
setattr(nn.init, name, orig_func)
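
The context manager temporarily replaces the listed nn.init functions with no-ops, so module constructors skip random initialization that a checkpoint load would immediately overwrite. A minimal sketch of the effect, only safe when every parameter is loaded from disk afterwards (path is hypothetical):

    import torch
    from torch import nn

    with disable_random_init():
        layer = nn.Linear(4096, 4096)   # its kaiming_uniform_/uniform_ init calls are no-ops
    layer.load_state_dict(torch.load("layer.pt"))  # hypothetical path; every weight is overwritten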
class Khaosz:
def __init__(self, model_dir: str):
self.parameter = ModelParameter()
self.parameter.load(model_dir)
with disable_random_init():
self.parameter = ModelParameter()
self.parameter.load(model_dir)
def to(self, *args, **kwargs):
self.parameter.to(*args, **kwargs)
@@ -29,32 +50,33 @@ class Khaosz:
top_k: int=50,
top_p: float=0.95,
) -> str:
generator = ChatGenerator(self.parameter)
generator = LoopGenerator(self.parameter)
return generator.generate(
query,
GenerationRequest(
top_k, top_p, temperature,
self.parameter.config.max_len,
query=query,
history=history,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
build_prompt=True
))
def batch_generate(
self,
queries: List[str],
histories: List[Tuple[str, str]]=None,
query: List[str],
history: List[Tuple[str, str]]=None,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> List[str]:
generator = BatchGenerator(self.parameter)
return generator.generate(
queries,
histories=histories,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
GenerationRequest(
top_k, top_p, temperature,
self.parameter.config.max_len,
query=query,
history=history,
build_prompt=True
))
def stream_generate(
self,
@@ -66,12 +88,13 @@ class Khaosz:
) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
stream_generator = StreamGenerator(self.parameter)
return stream_generator.generate(
query,
history=history,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
GenerationRequest(
top_k, top_p, temperature,
self.parameter.config.max_len,
query=query,
history=history,
build_prompt=True
))
def retrieve_generate(
self,
@@ -82,15 +105,16 @@ class Khaosz:
top_k: int=50,
top_p: float=0.95,
) -> str:
generator = RetrievalGenerator(self.parameter)
generator = LoopGenerator(self.parameter)
return generator.generate(
retrieved,
query,
GenerationRequest(
top_k, top_p, temperature,
self.parameter.config.max_len,
query=query,
history=history,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
system_prompt=retrieved,
build_prompt=True
))
def text_generate(
self,
@@ -99,14 +123,15 @@ class Khaosz:
top_k: int=50,
top_p: float=0.95,
) -> str:
generator = TextGenerator(self.parameter)
generator = LoopGenerator(self.parameter)
return generator.generate(
GenerationRequest(
top_k, top_p, temperature,
self.parameter.config.max_len,
query=query,
build_prompt=False
))
return generator.generate(
query,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
encoder = EmbeddingEncoder(self.parameter)
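
encode accepts either a single sentence or a list, mirroring the Union annotation; a small sketch of both shapes:

    emb = model.encode("What is attention")             # -> Tensor
    embs = model.encode(["first sentence", "second"])   # -> List[Tensor]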

View File

@@ -1 +1,25 @@
# init file
from khaosz.inference.core import (
GeneratorCore,
EmbeddingEncoderCore,
KVCacheManager,
)
from khaosz.inference.generator import (
GenerationRequest,
LoopGenerator,
StreamGenerator,
BatchGenerator,
EmbeddingEncoder,
)
__all__ = [
"GeneratorCore",
"EmbeddingEncoderCore",
"KVCacheManager",
"GenerationRequest",
"LoopGenerator",
"StreamGenerator",
"BatchGenerator",
"EmbeddingEncoder",
]
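
Since the inference package now re-exports its core and generator layers, the generators can also be driven directly. A hedged sketch (parameter is assumed to be a loaded ModelParameter; the max_len value is arbitrary):

    from khaosz.inference import LoopGenerator, GenerationRequest

    gen = LoopGenerator(parameter)  # parameter: a loaded ModelParameter (assumed)
    text = gen.generate(GenerationRequest(
        top_k=50, top_p=0.95, temperature=0.8, max_len=1024,
        query="Hello", build_prompt=False,   # raw continuation, no chat template
    ))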

View File

@@ -1,98 +0,0 @@
import torch
from torch import Tensor
from functools import wraps
from inspect import signature
class CudaGraphWrapper:
def __init__(self, function, device="cuda", cast=False):
self.function = function
self.cast = cast
self.device = device
self.static_input = None
self.static_output = None
self.graph = None
self.signature = signature(function)
def _update_inplace(self, lhs, rhs):
if isinstance(lhs, Tensor) and isinstance(rhs, Tensor):
if lhs.shape != rhs.shape:
raise ValueError(
f"Tensor shape mismatch! "
f"Expected: {lhs.shape}, Got: {rhs.shape}. "
f"Function: {self.function}"
)
if self.cast:
if lhs.device != rhs.device:
rhs = rhs.to(device=lhs.device)
if lhs.dtype != rhs.dtype:
rhs = rhs.to(dtype=lhs.dtype)
else:
if lhs.device != rhs.device:
raise ValueError(
f"Tensor device mismatch! "
f"Expected: {lhs.device}, Got: {rhs.device}. "
f"Function: {self.function}"
)
if lhs.dtype != rhs.dtype:
raise ValueError(
f"Tensor dtype mismatch! "
f"Expected: {lhs.dtype}, Got: {rhs.dtype}. "
f"Function: {self.function}"
)
lhs.copy_(rhs)
elif isinstance(lhs, dict):
for k in lhs:
if k in rhs:
self._update_inplace(lhs[k], rhs[k])
elif isinstance(lhs, (list, tuple)):
for i in range(len(lhs)):
if i < len(rhs):
self._update_inplace(lhs[i], rhs[i])
elif isinstance(lhs, (int, float, bool, str, type(None))):
if lhs != rhs:
raise ValueError("Does not support changing control parameters.")
def _update_args(self, input_args, input_kwargs):
bound_args = self.signature.bind(*input_args, **input_kwargs)
bound_args.apply_defaults()
args_dict = bound_args.arguments
if self.static_input is None:
self.static_input = args_dict
else:
self._update_inplace(self.static_input, args_dict)
def run(self, *args, **kwargs):
self._update_args(args, kwargs)
if self.graph is None:
# warmup
_ = torch.matmul(
torch.randn(100, 100, device=self.device),
torch.randn(100, 100, device=self.device)
)
torch.cuda.synchronize()
# capture graph
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
self.static_output = self.function(**self.static_input)
self.graph.replay()
return self.static_output
def cuda_graph(device="cuda", cast=False):
def decorator(func):
wrapper = CudaGraphWrapper(func, device, cast)
@wraps(func)
def wrapped(*args, **kwargs):
return wrapper.run(*args, **kwargs)
return wrapped
return decorator
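
For context on what is being deleted: the wrapper captured the decorated function into a torch.cuda.CUDAGraph on the first call (after a warmup matmul), then replayed it with new inputs copied in place via copy_. A hypothetical call site (none remain in the repo; a, b, c, d stand for CUDA tensors of fixed shape):

    @cuda_graph(device="cuda", cast=True)
    def decode_step(x, w):
        return x @ w   # fixed-shape work, as CUDA graph capture requires

    y1 = decode_step(a, b)   # first call: warmup, graph capture, replay
    y2 = decode_step(c, d)   # later calls: copy_ into static inputs, replay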

View File

@@ -1,10 +1,13 @@
import torch
from dataclasses import dataclass
from torch import Tensor
from typing import List, Tuple, Union, Optional, Generator
from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
from khaosz.config.param_config import ModelParameter
HistoryType = List[Tuple[str, str]]
def build_prompt(
query: str,
init_prompt: Optional[str] = None,
@@ -34,7 +37,7 @@ def build_prompt(
return prompt
def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]:
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
"""
Pad a list of sequences to the batch maximum length.
@@ -47,34 +50,47 @@ def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> Li
Tuple[List[List[int]], int]: The padded sequences and the maximum length.
"""
max_ids_len = max(len(ids) for ids in ids_list)
new_ids_list = []
for ids in ids_list:
pad_len = max_ids_len - len(ids)
padded_seq = [pad_id] * pad_len + ids
new_ids_list.append(padded_seq)
return new_ids_list
return new_ids_list, max_ids_len
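
pad_sequence no longer takes the target length; it computes the batch maximum itself, left-pads, and returns that length alongside the padded ids. A tiny worked example:

    ids, max_len = pad_sequence([[5, 6, 7], [9]], pad_id=0)
    # ids == [[5, 6, 7], [0, 0, 9]]   (left-padded)
    # max_len == 3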
@dataclass
class GenerationRequest:
top_k: int
top_p: float
temperature: float
max_len: int
query: Union[str, List[str]]
history: Optional[Union[HistoryType, List[HistoryType]]] = None
system_prompt: Optional[str] = None
build_prompt: bool = True
def __post_init__(self):
if not isinstance(self.top_k, int) or self.top_k < 0:
raise ValueError("top_k must be a non-negative integer")
if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0:
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not isinstance(self.temperature, float) or self.temperature < 0.0:
raise ValueError("temperature must be a non-negative float")
class TextGenerator(GeneratorCore):
class LoopGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(
self,
query: str,
temperature: float,
top_k: int,
top_p: float,
) -> str:
assert temperature >= 0.0
assert top_k >= 0
assert top_p >= 0.0 and top_p <= 1.0
def generate(self, request: GenerationRequest) -> str:
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(self.config, 1, device=device)
ids = self.tokenizer.encode(query)
input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query
ids = self.tokenizer.encode(input_args)
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
start_cache_pos = len(ids)
@@ -83,53 +99,10 @@ class TextGenerator(GeneratorCore):
kv_caches = cache_manager.get_kvcache()
ids = self.generate_loop(
input_ids, ids, temperature, top_k, top_p,
input_ids, ids, request.temperature, request.top_k, request.top_p,
kv_caches=kv_caches,
start_pos=cur_cache_pos
)
response = self.tokenizer.decode(ids[start_cache_pos:])
return response
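
LoopGenerator replaces both TextGenerator and ChatGenerator: build_prompt=True runs the chat template over query/history, build_prompt=False feeds the raw string. A sketch of the two modes (parameter assumed loaded; positional order is top_k, top_p, temperature, max_len):

    gen = LoopGenerator(parameter)
    chat = gen.generate(GenerationRequest(50, 0.95, 0.8, 1024,
                                          query="Hi", history=[], build_prompt=True))
    plain = gen.generate(GenerationRequest(50, 0.95, 0.8, 1024,
                                           query="Once upon a time", build_prompt=False))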
class ChatGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(
self,
query: str,
history: List[Tuple[str, str]],
temperature: float,
top_k: int,
top_p: float,
) -> str:
assert temperature >= 0.0
assert top_k >= 0
assert top_p >= 0.0 and top_p <= 1.0
if history is None:
history = []
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(self.config, 1, device=device)
ids = self.tokenizer.encode(build_prompt(query, history))
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
start_cache_pos = len(ids)
cur_cache_pos = 0
self.model.eval()
kv_caches = cache_manager.get_kvcache()
ids = self.generate_loop(
input_ids, ids, temperature, top_k, top_p,
kv_caches=kv_caches,
start_pos=cur_cache_pos
)
response = self.tokenizer.decode(ids[start_cache_pos:])
return response
@@ -139,28 +112,17 @@ class StreamGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(
self,
query: str,
history: List[Tuple[str, str]],
temperature: float,
top_k: int,
top_p: float,
) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
def generate(self, request: GenerationRequest) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
assert temperature >= 0.0
assert top_k >= 0
assert top_p >= 0.0 and top_p <= 1.0
if history is None:
history = []
if request.history is None:
request.history = []
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(self.config, 1, device=device)
ids = self.tokenizer.encode(build_prompt(query, history))
input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query
ids = self.tokenizer.encode(input_args)
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
cpy_history = history.copy()
start_cache_pos = len(ids)
cur_cache_pos = 0
@@ -169,17 +131,20 @@ class StreamGenerator(GeneratorCore):
for _ in range(len(ids), self.config.max_len):
next_token_id, cache_increase = self.generate_iterator(
input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
input_ids, request.temperature, request.top_k, request.top_p,
kv_caches=kv_caches,
start_pos=cur_cache_pos
)
input_ids = next_token_id
ids.append(next_token_id.item())
cur_cache_pos += cache_increase
response = self.tokenizer.decode(ids[start_cache_pos:])
yield response, cpy_history + [(query, response)]
yield response, request.history + [(request.query, response)]
if next_token_id.item() in self.tokenizer.stop_ids:
yield response + "\n", cpy_history + [(query, response)]
yield response + "\n", request.history + [(request.query, response)]
break
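
The streaming path yields the partial response plus an updated history on every step, and once more (with a trailing newline) when a stop token is hit. Consuming it looks roughly like this (parameter assumed loaded):

    stream = StreamGenerator(parameter).generate(GenerationRequest(
        50, 0.95, 0.8, 1024, query="Hello", history=[]))
    for partial, history in stream:
        print(partial)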
@@ -187,27 +152,15 @@ class BatchGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(
self,
queries: List[str],
histories: List[List[Tuple[str, str]]],
temperature: float,
top_k: int,
top_p: float
) -> List[str]:
def generate(self, request: GenerationRequest) -> List[str]:
batch_size = len(request.query)
if request.history is None:
request.history = [[] for _ in range(batch_size)]
assert temperature >= 0.0
assert top_k >= 0
assert top_p >= 0.0 and top_p <= 1.0
prompts = [build_prompt(query, history) for query, history in zip(request.query, request.history)]
batch_size = len(queries)
if histories is None:
histories = [[] for _ in range(batch_size)]
prompts = [build_prompt(query, history) for query, history in zip(queries, histories)]
ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
max_ids_len = max(len(ids) for ids in ids_list)
ids_list = pad_sequence(ids_list, max_ids_len, self.tokenizer.pad_id)
ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(self.config, batch_size, device=device)
@@ -224,7 +177,11 @@ class BatchGenerator(GeneratorCore):
attn_mask = cache_manager.get_seq_mask()
next_token_id, cache_increase = self.generate_iterator(
input_tensor, temperature, top_k, top_p, attn_mask=attn_mask, kv_caches=kv_caches, start_pos=cur_cache_pos)
input_tensor, request.temperature, request.top_k, request.top_p,
attn_mask=attn_mask,
kv_caches=kv_caches,
start_pos=cur_cache_pos
)
cur_cache_pos += cache_increase
active_mask = []
@@ -246,47 +203,14 @@ class BatchGenerator(GeneratorCore):
max_ids_len += 1
responses = [str()] * batch_size
for i in range(batch_size):
responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:])
histories[i].append((queries[i], responses[i]))
request.history[i].append((request.query[i], responses[i]))
return responses
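
In the batch path, query is a list and, when history is omitted, the request grows one history list per element which is appended to in place, so the conversation state can be read back off the request afterwards (parameter assumed loaded):

    req = GenerationRequest(50, 0.95, 0.8, 1024,
                            query=["Hello", "What is a GPU"])
    responses = BatchGenerator(parameter).generate(req)
    # req.history[i] now ends with (req.query[i], responses[i])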
class RetrievalGenerator(GeneratorCore):
def __init__(self, retriever_parameter: ModelParameter):
super().__init__(retriever_parameter)
def generate(
self,
retrieved: List[str],
query: str,
history: List[Tuple[str, str]],
temperature: float,
top_k: int,
top_p: float,
) -> str:
assert temperature >= 0.0
assert top_k >= 0
assert top_p >= 0.0 and top_p <= 1.0
if history is None:
history = []
retrieved = "\n".join([f"{idx + 1}. {key}" for idx, key in enumerate(retrieved)]) if retrieved else ""
retrieved_query = f"{retrieved}\n\n{query}" if retrieved else query
parameter = ModelParameter(self.model, self.tokenizer, self.config)
return ChatGenerator(parameter).generate(
retrieved_query,
history,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
class EmbeddingEncoder(EmbeddingEncoderCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)

View File

@@ -13,35 +13,35 @@ PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
def batch_generate(
model: Khaosz,
queries: List[str],
query: List[str],
temperature: float,
top_k: int,
top_p: float,
batch_size: int,
) -> List:
assert batch_size > 0
sorted_queries = sorted(queries, key=lambda x: len(x), reverse=True)
original_indices = {query: idx for idx, query in enumerate(queries)}
sorted_query = sorted(query, key=lambda x: len(x), reverse=True)
original_indices = {query: idx for idx, query in enumerate(query)}
responses = [None] * len(queries)
total_batches = (len(sorted_queries) + batch_size - 1) // batch_size
responses = [None] * len(query)
total_batches = (len(sorted_query) + batch_size - 1) // batch_size
for i in tqdm(range(0, total_batches * batch_size, batch_size), desc="Generating responses"):
batch_queries = sorted_queries[i: min(i + batch_size, len(queries))]
if not isinstance(batch_queries, list):
batch_queries = [batch_queries]
batch_query = sorted_query[i: min(i + batch_size, len(query))]
if not isinstance(batch_query, list):
batch_query = [batch_query]
batch_responses = model.batch_generate(
queries=batch_queries,
query=batch_query,
temperature=temperature,
top_k=top_k,
top_p=top_p
)
for batch_query, batch_response in zip(batch_queries, batch_responses):
for batch_query, batch_response in zip(batch_query, batch_responses):
print(f"Q: {batch_query[:50]} \nR: {batch_response[:50]})")
for query, response in zip(batch_queries, batch_responses):
for query, response in zip(batch_query, batch_responses):
original_idx = original_indices[query]
responses[original_idx] = response
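
The script still batches longest-first so sequences of similar length share a batch and padding stays small, then maps responses back through the original indices; on hypothetical data:

    query = ["bb", "a", "cccc"]
    sorted_query = sorted(query, key=lambda x: len(x), reverse=True)   # ["cccc", "bb", "a"]
    original_indices = {q: idx for idx, q in enumerate(query)}         # {"bb": 0, "a": 1, "cccc": 2}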
@@ -60,11 +60,11 @@ def processor(
):
with open(input_json_file, "r", encoding='utf-8') as f:
input_dict = [json.loads(line) for line in f]
queries = [item[question_key] for item in input_dict]
query = [item[question_key] for item in input_dict]
output_dict = batch_generate(
model=model,
queries=queries,
query=query,
temperature=temperature,
top_k=top_k,
top_p=top_p,