From 62fba9a298f0c00842c688b84528a956524fa372 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Wed, 18 Mar 2026 15:07:35 +0800
Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E6=8E=A5?=
 =?UTF-8?q?=E5=8F=A3=E8=AE=BE=E7=BD=AE=EF=BC=8C=20=E5=8E=BB=E9=99=A4?=
 =?UTF-8?q?=E5=86=97=E4=BD=99=E4=BB=A3=E7=A0=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 demo/generate_batch.py         |   2 +-
 demo/generate_retrieve.py      |   5 +-
 khaosz/__init__.py             |  16 ++-
 khaosz/api.py                  | 107 ++++++++++++-------
 khaosz/inference/__init__.py   |  26 ++++-
 khaosz/inference/cuda_graph.py |  98 -----------------
 khaosz/inference/generator.py  | 190 ++++++++++-----------------------
 tools/generate.py              |  26 ++---
 8 files changed, 172 insertions(+), 298 deletions(-)
 delete mode 100644 khaosz/inference/cuda_graph.py

diff --git a/demo/generate_batch.py b/demo/generate_batch.py
index 43ece47..85a563f 100644
--- a/demo/generate_batch.py
+++ b/demo/generate_batch.py
@@ -12,7 +12,7 @@ def batch_generate():
     inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"]
 
     responses = model.batch_generate(
-        queries=inputs,
+        query=inputs,
         temperature=0.8,
         top_p=0.95,
         top_k=50
diff --git a/demo/generate_retrieve.py b/demo/generate_retrieve.py
index 0431fbd..c49ac54 100644
--- a/demo/generate_retrieve.py
+++ b/demo/generate_retrieve.py
@@ -26,9 +26,10 @@ if __name__ == "__main__":
     query = "作者设计了一个怎样的模型"
     emb_query = model.encode(query)
     retrieved = retriever.retrieve(emb_query, retrive_top_k)
+    retrieved_content = "\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)])
 
     retrive_response = model.retrieve_generate(
-        retrieved=retrieved,
+        retrieved=retrieved_content,
         query=query,
         temperature=0.8,
         top_p=0.95,
@@ -36,7 +37,7 @@
     )
 
     print("retrieve content:")
-    print("\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)]))
+    print(retrieved_content)
 
     print("\n\nretrive generate:")
     print(retrive_response)
\ No newline at end of file
diff --git a/khaosz/__init__.py b/khaosz/__init__.py
index be0414a..f9a14f1 100644
--- a/khaosz/__init__.py
+++ b/khaosz/__init__.py
@@ -17,12 +17,11 @@ from khaosz.data import (
     BpeTokenizer
 )
 from khaosz.inference.generator import (
-    TextGenerator,
-    ChatGenerator,
-    StreamGenerator,
-    BatchGenerator,
-    RetrievalGenerator,
-    EmbeddingEncoder
+    GenerationRequest,
+    LoopGenerator,
+    StreamGenerator,
+    BatchGenerator,
+    EmbeddingEncoder,
 )
 
 from khaosz.trainer import (
@@ -46,11 +45,10 @@
     "DatasetLoader",
     "BpeTokenizer",
 
-    "TextGenerator",
-    "ChatGenerator",
+    "GenerationRequest",
+    "LoopGenerator",
     "StreamGenerator",
     "BatchGenerator",
-    "RetrievalGenerator",
     "EmbeddingEncoder",
 
     "Trainer",
diff --git a/khaosz/api.py b/khaosz/api.py
index 5ce2349..ef9f321 100644
--- a/khaosz/api.py
+++ b/khaosz/api.py
@@ -1,21 +1,42 @@
+from torch import nn
 from torch import Tensor
+from contextlib import contextmanager
 from typing import List, Tuple, Generator, Union
 
 from khaosz.inference.generator import (
-    TextGenerator,
-    ChatGenerator,
+    GenerationRequest,
+    LoopGenerator,
     StreamGenerator,
     BatchGenerator,
-    RetrievalGenerator,
     EmbeddingEncoder
 )
 from khaosz.config.param_config import ModelParameter
 
 
+@contextmanager
+def disable_random_init():
+    init_functions = [
+        'xavier_normal_', 'xavier_uniform_',
+        'kaiming_normal_', 'kaiming_uniform_',
+        'zeros_', 'ones_', 'constant_',
+        'normal_', 'uniform_'
+    ]
+    original_funcs = {}
+    for name in init_functions:
+        if hasattr(nn.init, name):
+            original_funcs[name] = getattr(nn.init, name)
+            setattr(nn.init, name, lambda *args, **kwargs: None)
+    try:
+        yield
+    finally:
+        for name, orig_func in original_funcs.items():
+            setattr(nn.init, name, orig_func)
+
 class Khaosz:
     def __init__(self, model_dir: str):
-        self.parameter = ModelParameter()
-        self.parameter.load(model_dir)
+        with disable_random_init():
+            self.parameter = ModelParameter()
+            self.parameter.load(model_dir)
 
     def to(self, *args, **kwargs):
         self.parameter.to(*args, **kwargs)
@@ -29,32 +50,33 @@
         top_k: int=50,
         top_p: float=0.95,
     ) -> str:
-        generator = ChatGenerator(self.parameter)
+        generator = LoopGenerator(self.parameter)
         return generator.generate(
-            query,
+            GenerationRequest(
+                top_k, top_p, temperature,
+                self.parameter.config.max_len,
+                query=query,
             history=history,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
+                build_prompt=True
+            ))
 
     def batch_generate(
         self,
-        queries: List[str],
-        histories: List[Tuple[str, str]]=None,
+        query: List[str],
+        history: List[Tuple[str, str]]=None,
         temperature: float=0.8,
         top_k: int=50,
         top_p: float=0.95,
     ) -> List[str]:
         generator = BatchGenerator(self.parameter)
         return generator.generate(
-            queries,
-            histories=histories,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
-
+            GenerationRequest(
+                top_k, top_p, temperature,
+                self.parameter.config.max_len,
+                query=query,
+                history=history,
+                build_prompt=True
+            ))
 
     def stream_generate(
         self,
@@ -66,12 +88,13 @@
     ) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
         stream_generator = StreamGenerator(self.parameter)
         return stream_generator.generate(
-            query,
-            history=history,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
+            GenerationRequest(
+                top_k, top_p, temperature,
+                self.parameter.config.max_len,
+                query=query,
+                history=history,
+                build_prompt=True
+            ))
 
     def retrieve_generate(
         self,
@@ -82,15 +105,16 @@
         top_k: int=50,
         top_p: float=0.95,
     ) -> str:
-        generator = RetrievalGenerator(self.parameter)
+        generator = LoopGenerator(self.parameter)
         return generator.generate(
-            retrieved,
-            query,
+            GenerationRequest(
+                top_k, top_p, temperature,
+                self.parameter.config.max_len,
+                query=query,
             history=history,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
+                system_prompt=retrieved,
+                build_prompt=True
+            ))
 
     def text_generate(
         self,
@@ -99,14 +123,15 @@
         top_k: int=50,
         top_p: float=0.95,
     ) -> str:
-        generator = TextGenerator(self.parameter)
-
-        return generator.generate(
-            query,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
+        generator = LoopGenerator(self.parameter)
+        return generator.generate(
+            GenerationRequest(
+                top_k, top_p, temperature,
+                self.parameter.config.max_len,
+                query=query,
+                build_prompt=False
+            ))
+
 
     def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
         encoder = EmbeddingEncoder(self.parameter)
diff --git a/khaosz/inference/__init__.py b/khaosz/inference/__init__.py
index 5bd6efa..f94aec6 100644
--- a/khaosz/inference/__init__.py
+++ b/khaosz/inference/__init__.py
@@ -1 +1,25 @@
-# init file
\ No newline at end of file
+from khaosz.inference.core import (
+    GeneratorCore,
+    EmbeddingEncoderCore,
+    KVCacheManager,
+)
+
+from khaosz.inference.generator import (
+    GenerationRequest,
+    LoopGenerator,
+    StreamGenerator,
+    BatchGenerator,
+    EmbeddingEncoder,
+)
+
+__all__ = [
+    "GeneratorCore",
+    "EmbeddingEncoderCore",
+    "KVCacheManager",
+
+    "GenerationRequest",
+    "LoopGenerator",
+    "StreamGenerator",
+    "BatchGenerator",
+    "EmbeddingEncoder",
+]
\ No newline at end of file
diff --git a/khaosz/inference/cuda_graph.py b/khaosz/inference/cuda_graph.py
deleted file mode 100644
index 09c05a7..0000000
--- a/khaosz/inference/cuda_graph.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import torch
-from torch import Tensor
-from functools import wraps
-from inspect import signature
-
-
-class CudaGraphWrapper:
-    def __init__(self, function, device="cuda", cast=False):
-        self.function = function
-        self.cast = cast
-        self.device = device
-        self.static_input = None
-        self.static_output = None
-        self.graph = None
-        self.signature = signature(function)
-
-    def _update_inplace(self, lhs, rhs):
-        if isinstance(lhs, Tensor) and isinstance(rhs, Tensor):
-            if lhs.shape != rhs.shape:
-                raise ValueError(
-                    f"Tensor shape mismatch! "
-                    f"Expected: {lhs.shape}, Got: {rhs.shape}. "
-                    f"Function: {self.function}"
-                )
-            if self.cast:
-                if lhs.device != rhs.device:
-                    rhs = rhs.to(device=lhs.device)
-
-                if lhs.dtype != rhs.dtype:
-                    rhs = rhs.to(dtype=lhs.dtype)
-            else:
-                if lhs.device != rhs.device:
-                    raise ValueError(
-                        f"Tensor device mismatch! "
-                        f"Expected: {lhs.device}, Got: {rhs.device}. "
-                        f"Function: {self.function}"
-                    )
-                if lhs.dtype != rhs.dtype:
-                    raise ValueError(
-                        f"Tensor dtype mismatch! "
-                        f"Expected: {lhs.dtype}, Got: {rhs.dtype}. "
-                        f"Function: {self.function}"
-                    )
-            lhs.copy_(rhs)
-        elif isinstance(lhs, dict):
-            for k in lhs:
-                if k in rhs:
-                    self._update_inplace(lhs[k], rhs[k])
-        elif isinstance(lhs, (list, tuple)):
-            for i in range(len(lhs)):
-                if i < len(rhs):
-                    self._update_inplace(lhs[i], rhs[i])
-        elif isinstance(lhs, (int, float, bool, str, type(None))):
-            if lhs != rhs:
-                raise ValueError("Does not support changing control parameters.")
-
-    def _update_args(self, input_args, input_kwargs):
-        bound_args = self.signature.bind(*input_args, **input_kwargs)
-        bound_args.apply_defaults()
-        args_dict = bound_args.arguments
-
-        if self.static_input is None:
-            self.static_input = args_dict
-        else:
-            self._update_inplace(self.static_input, args_dict)
-
-    def run(self, *args, **kwargs):
-        self._update_args(args, kwargs)
-
-        if self.graph is None:
-            # warmup
-            _ = torch.matmul(
-                torch.randn(100, 100, device=self.device),
-                torch.randn(100, 100, device=self.device)
-            )
-            torch.cuda.synchronize()
-
-            # capture graph
-            self.graph = torch.cuda.CUDAGraph()
-            with torch.cuda.graph(self.graph):
-                self.static_output = self.function(**self.static_input)
-
-        self.graph.replay()
-
-        return self.static_output
-
-
-def cuda_graph(device="cuda", cast=False):
-    def decorator(func):
-        wrapper = CudaGraphWrapper(func, device, cast)
-
-        @wraps(func)
-        def wrapped(*args, **kwargs):
-            return wrapper.run(*args, **kwargs)
-
-        return wrapped
-
-    return decorator
\ No newline at end of file
diff --git a/khaosz/inference/generator.py b/khaosz/inference/generator.py
index 034a1b5..877807b 100644
--- a/khaosz/inference/generator.py
+++ b/khaosz/inference/generator.py
@@ -1,10 +1,13 @@
 import torch
+from dataclasses import dataclass
 from torch import Tensor
 from typing import List, Tuple, Union, Optional, Generator
 from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
 from khaosz.config.param_config import ModelParameter
 
 
+HistoryType = List[Tuple[str, str]]
+
 def build_prompt(
     query: str,
     init_prompt: Optional[str] = None,
@@ -34,7 +37,7 @@ def build_prompt(
     return prompt
 
 
-def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]:
+def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
     """
     Pad a list of sequences to a fixed length.
 
@@ -47,34 +50,47 @@ def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> Li
         List[List[int]]: A list of padded sequences.
     """
+    max_ids_len = max(len(ids) for ids in ids_list)
     new_ids_list = []
     for ids in ids_list:
         pad_len = max_ids_len - len(ids)
         padded_seq = [pad_id] * pad_len + ids
         new_ids_list.append(padded_seq)
 
-    return new_ids_list
+    return new_ids_list, max_ids_len
+
+
+@dataclass
+class GenerationRequest:
+    top_k: int
+    top_p: float
+    temperature: float
+    max_len: int
+
+    query: Union[str, List[str]]
+    history: Optional[Union[HistoryType, List[HistoryType]]] = None
+    system_prompt: Optional[str] = None
+
+    build_prompt: bool = True
+
+    def __post_init__(self):
+        if not isinstance(self.top_k, int) or self.top_k < 0:
+            raise ValueError("top_k must be a non-negative integer")
+        if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0:
+            raise ValueError("top_p must be a float between 0.0 and 1.0")
+        if not isinstance(self.temperature, float) or self.temperature < 0.0:
+            raise ValueError("temperature must be a non-negative float")
 
 
-class TextGenerator(GeneratorCore):
+class LoopGenerator(GeneratorCore):
     def __init__(self, parameter: ModelParameter):
         super().__init__(parameter)
 
-    def generate(
-        self,
-        query: str,
-        temperature: float,
-        top_k: int,
-        top_p: float,
-    ) -> str:
-
-        assert temperature >= 0.0
-        assert top_k >= 0
-        assert top_p >= 0.0 and top_p <= 1.0
-
+    def generate(self, request: GenerationRequest) -> str:
         device = next(self.model.parameters()).device
         cache_manager = KVCacheManager(self.config, 1, device=device)
 
-        ids = self.tokenizer.encode(query)
+        input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query
+        ids = self.tokenizer.encode(input_args)
         input_ids = torch.tensor([ids], device=device, dtype=torch.long)
 
         start_cache_pos = len(ids)
@@ -83,84 +99,30 @@
         kv_caches = cache_manager.get_kvcache()
 
         ids = self.generate_loop(
-            input_ids, ids, temperature, top_k, top_p,
+            input_ids, ids, request.temperature, request.top_k, request.top_p,
             kv_caches=kv_caches,
             start_pos=cur_cache_pos
         )
-
         response = self.tokenizer.decode(ids[start_cache_pos:])
 
         return response
 
 
-class ChatGenerator(GeneratorCore):
-    def __init__(self, parameter: ModelParameter):
-        super().__init__(parameter)
-
-    def generate(
-        self,
-        query: str,
-        history: List[Tuple[str, str]],
-        temperature: float,
-        top_k: int,
-        top_p: float,
-    ) -> str:
-
-        assert temperature >= 0.0
-        assert top_k >= 0
-        assert top_p >= 0.0 and top_p <= 1.0
-
-        if history is None:
-            history = []
-
-        device = next(self.model.parameters()).device
-        cache_manager = KVCacheManager(self.config, 1, device=device)
-
-        ids = self.tokenizer.encode(build_prompt(query, history))
-        input_ids = torch.tensor([ids], device=device, dtype=torch.long)
-
-        start_cache_pos = len(ids)
-        cur_cache_pos = 0
-        self.model.eval()
-        kv_caches = cache_manager.get_kvcache()
-
-        ids = self.generate_loop(
-            input_ids, ids, temperature, top_k, top_p,
-            kv_caches=kv_caches,
-            start_pos=cur_cache_pos
-        )
-
-        response = self.tokenizer.decode(ids[start_cache_pos:])
-
-        return response
-
-
 class StreamGenerator(GeneratorCore):
     def __init__(self, parameter: ModelParameter):
         super().__init__(parameter)
 
-    def generate(
-        self,
-        query: str,
-        history: List[Tuple[str, str]],
-        temperature: float,
-        top_k: int,
-        top_p: float,
-    ) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
-
-        assert temperature >= 0.0
-        assert top_k >= 0
-        assert top_p >= 0.0 and top_p <= 1.0
+    def generate(self, request: GenerationRequest) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
 
-        if history is None:
-            history = []
+        if request.history is None:
+            request.history = []
 
         device = next(self.model.parameters()).device
         cache_manager = KVCacheManager(self.config, 1, device=device)
-
-        ids = self.tokenizer.encode(build_prompt(query, history))
+
+        input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query
+        ids = self.tokenizer.encode(input_args)
         input_ids = torch.tensor([ids], device=device, dtype=torch.long)
-        cpy_history = history.copy()
 
         start_cache_pos = len(ids)
         cur_cache_pos = 0
@@ -169,45 +131,36 @@
         for _ in range(len(ids), self.config.max_len):
             next_token_id, cache_increase = self.generate_iterator(
-                input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
+                input_ids, request.temperature, request.top_k, request.top_p,
+                kv_caches=kv_caches,
+                start_pos=cur_cache_pos
+            )
 
             input_ids = next_token_id
             ids.append(next_token_id.item())
             cur_cache_pos += cache_increase
 
             response = self.tokenizer.decode(ids[start_cache_pos:])
-            yield response, cpy_history + [(query, response)]
+            yield response, request.history + [(request.query, response)]
 
             if next_token_id.item() in self.tokenizer.stop_ids:
-                yield response + "\n", cpy_history + [(query, response)]
+                yield response + "\n", request.history + [(request.query, response)]
                 break
-
+
 
 class BatchGenerator(GeneratorCore):
     def __init__(self, parameter: ModelParameter):
         super().__init__(parameter)
 
-    def generate(
-        self,
-        queries: List[str],
-        histories: List[List[Tuple[str, str]]],
-        temperature: float,
-        top_k: int,
-        top_p: float
-    ) -> List[str]:
+    def generate(self, request: GenerationRequest) -> List[str]:
+        batch_size = len(request.query)
+        if request.history is None:
+            request.history = [[] for _ in range(batch_size)]
 
-        assert temperature >= 0.0
-        assert top_k >= 0
-        assert top_p >= 0.0 and top_p <= 1.0
+        prompts = [build_prompt(query, history) for query, history in zip(request.query, request.history)]
 
-        batch_size = len(queries)
-        if histories is None:
-            histories = [[] for _ in range(batch_size)]
-
-        prompts = [build_prompt(query, history) for query, history in zip(queries, histories)]
         ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
-        max_ids_len = max(len(ids) for ids in ids_list)
-        ids_list = pad_sequence(ids_list, max_ids_len, self.tokenizer.pad_id)
+        ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)
 
         device = next(self.model.parameters()).device
         cache_manager = KVCacheManager(self.config, batch_size, device=device)
@@ -224,7 +177,11 @@
             attn_mask =cache_manager.get_seq_mask()
 
             next_token_id, cache_increase = self.generate_iterator(
-                input_tensor, temperature, top_k, top_p, attn_mask=attn_mask, kv_caches=kv_caches, start_pos=cur_cache_pos)
+                input_tensor, request.temperature, request.top_k, request.top_p,
+                attn_mask=attn_mask,
+                kv_caches=kv_caches,
+                start_pos=cur_cache_pos
+            )
             cur_cache_pos += cache_increase
 
             active_mask = []
@@ -246,47 +203,14 @@
             max_ids_len += 1
 
-
         responses = [str()] * batch_size
         for i in range(batch_size):
             responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:])
-            histories[i].append((queries[i], responses[i]))
+            request.history[i].append((request.query[i], responses[i]))
 
         return responses
 
 
-class RetrievalGenerator(GeneratorCore):
-    def __init__(self, retriever_parameter: ModelParameter):
-        super().__init__(retriever_parameter)
-
-    def generate(
-        self,
-        retrieved: List[str],
-        query: str,
-        history: List[Tuple[str, str]],
-        temperature: float,
-        top_k: int,
-        top_p: float,
-    ) -> str:
-        assert temperature >= 0.0
-        assert top_k >= 0
-        assert top_p >= 0.0 and top_p <= 1.0
-
-        if history is None:
-            history = []
-
-        retrieved = "\n".join([f"{idx + 1}. {key}" for idx, key in enumerate(retrieved)]) if retrieved else ""
-        retrieved_query = f"{retrieved}\n\n{query}" if retrieved else query
-        parameter = ModelParameter(self.model, self.tokenizer, self.config)
-
-        return ChatGenerator(parameter).generate(
-            retrieved_query,
-            history,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
-
-
 class EmbeddingEncoder(EmbeddingEncoderCore):
     def __init__(self, parameter: ModelParameter):
         super().__init__(parameter)
diff --git a/tools/generate.py b/tools/generate.py
index a18c433..cc00851 100644
--- a/tools/generate.py
+++ b/tools/generate.py
@@ -13,35 +13,35 @@ PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
 
 def batch_generate(
     model: Khaosz,
-    queries: List[str],
+    query: List[str],
     temperature: float,
    top_k: int,
     top_p: float,
     batch_size: int,
 ) -> List:
     assert batch_size > 0
-    sorted_queries = sorted(queries, key=lambda x: len(x), reverse=True)
-    original_indices = {query: idx for idx, query in enumerate(queries)}
+    sorted_query = sorted(query, key=lambda x: len(x), reverse=True)
+    original_indices = {query: idx for idx, query in enumerate(query)}
 
-    responses = [None] * len(queries)
-    total_batches = (len(sorted_queries) + batch_size - 1) // batch_size
+    responses = [None] * len(query)
+    total_batches = (len(sorted_query) + batch_size - 1) // batch_size
 
     for i in tqdm(range(0, total_batches * batch_size, batch_size), desc="Generating responses"):
-        batch_queries = sorted_queries[i: min(i + batch_size, len(queries))]
-        if not isinstance(batch_queries, list):
-            batch_queries = [batch_queries]
+        batch_query = sorted_query[i: min(i + batch_size, len(query))]
+        if not isinstance(batch_query, list):
+            batch_query = [batch_query]
 
         batch_responses = model.batch_generate(
-            queries=batch_queries,
+            query=batch_query,
             temperature=temperature,
             top_k=top_k,
             top_p=top_p
         )
 
-        for batch_query, batch_response in zip(batch_queries, batch_responses):
+        for batch_query, batch_response in zip(batch_query, batch_responses):
             print(f"Q: {batch_query[:50]} \nR: {batch_response[:50]})")
 
-        for query, response in zip(batch_queries, batch_responses):
+        for query, response in zip(batch_query, batch_responses):
             original_idx = original_indices[query]
             responses[original_idx] = response
 
@@ -60,11 +60,11 @@ def processor(
 ):
     with open(input_json_file, "r", encoding='utf-8') as f:
         input_dict = [json.loads(line) for line in f]
-        queries = [item[question_key] for item in input_dict]
+        query = [item[question_key] for item in input_dict]
 
     output_dict = batch_generate(
         model=model,
-        queries=queries,
+        query=query,
        temperature=temperature,
         top_k=top_k,
         top_p=top_p,
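
Usage sketch (reviewer note, not part of the patch): a minimal example of the refactored interface, assuming `Khaosz` is importable from the top-level `khaosz` package and that "path/to/checkpoint" is a placeholder for a real model directory; everything else mirrors calls visible in the diff.

    from khaosz import Khaosz, GenerationRequest, LoopGenerator

    model = Khaosz("path/to/checkpoint")  # hypothetical model directory

    # High-level API: the keyword argument was renamed from `queries` to `query`.
    responses = model.batch_generate(
        query=["你好", "请问什么是人工智能"],
        temperature=0.8,
        top_p=0.95,
        top_k=50,
    )

    # Low-level API: LoopGenerator consumes a GenerationRequest
    # (positional order: top_k, top_p, temperature, max_len).
    request = GenerationRequest(
        50, 0.95, 0.8,
        model.parameter.config.max_len,
        query="请问什么是显卡",
        history=[],
        build_prompt=True,
    )
    print(LoopGenerator(model.parameter).generate(request))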