def apply_sampling_strategies(
    logits: Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    filter_value: float = -float("inf")
) -> Tensor:
    """
    Apply temperature scaling, top-k and top-p (nucleus) filtering to logits.

    The vocabulary is always taken to be the last dimension, so the function
    works for logits of any rank. The input tensor is never modified in
    place; when no strategy is active (temperature == 1.0, top_k <= 0 and
    top_p >= 1.0) the input tensor itself is returned.

    Args:
        logits (Tensor): Unnormalized scores; vocabulary on the last dim.
        temperature (float): Divisor applied to the logits. Must be > 0.
        top_k (int): Keep only the k highest logits. Values <= 0 disable it.
        top_p (float): Keep the smallest set of top tokens whose cumulative
            probability exceeds top_p. Values >= 1.0 disable it.
        filter_value (float, optional): Value written to filtered positions.
            Defaults to -float("inf").

    Returns:
        Tensor: The filtered logits (not sampled token ids), same shape as
        the input.

    Raises:
        ValueError: If temperature is not strictly positive.
    """
    if temperature <= 0:
        # Dividing by zero yields inf/nan logits that only fail later (and
        # obscurely) in softmax/multinomial; fail early with a clear message.
        raise ValueError(f"temperature must be positive, got {temperature}")

    if temperature != 1.0:
        logits = logits / temperature

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        # Smallest logit that is still inside the top-k, kept broadcastable.
        kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1:]
        # Out-of-place so the caller's tensor is left untouched.
        logits = logits.masked_fill(logits < kth_value, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_remove = cumulative_probs > top_p
        # Shift right by one so the first token that crosses the threshold is
        # kept, and the most likely token is never removed.
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order along the
        # last dim (not a hard-coded dim=1) so any rank is supported.
        remove = torch.zeros_like(sorted_remove).scatter(
            -1, sorted_indices, sorted_remove
        )
        logits = logits.masked_fill(remove, filter_value)

    return logits
num_layers: int, + batch_size: int, + max_len: int, + num_heads: int, + head_dim: int, + device: torch.device = "cuda", + dtype: torch.dtype = torch.bfloat16 + ): + self.num_layers = num_layers + self.batch_size = batch_size + self.max_len = max_len + self.num_heads = num_heads + self.head_dim = head_dim + self.device = device + self.dtype = dtype + + self._kv_cache: List[Tuple[Tensor, Tensor]] = None + self._seq_mask: Tensor = None + self._initialize() + + def _initialize(self): + self._kv_cache = [] + for _ in range(self.num_layers): + k_cache = torch.zeros( + (self.batch_size, self.max_len, self.num_heads, self.head_dim), + device=self.device, dtype=self.dtype + ) + v_cache = torch.zeros( + (self.batch_size, self.max_len, self.num_heads, self.head_dim), + device=self.device, dtype=self.dtype + ) + self._kv_cache.append((k_cache, v_cache)) + + self._seq_mask = torch.ones( + (self.batch_size, self.max_len), + device=self.device, dtype=torch.bool + ) + + def update(self, active_mask: Tensor): + for i in range(self.num_layers): + k_cache, v_cache = self._kv_cache[i] + new_k_cache, new_v_cache = k_cache[active_mask], v_cache[active_mask] + self._kv_cache[i] = (new_k_cache, new_v_cache) + + self._seq_mask = self._seq_mask[active_mask] + + def reset(self, full_reset=False): + if full_reset: + self._kv_cache = None + self._seq_mask = None + else: + self._initialize() + + def set_seq_mask(self, input_ids: Tensor, pad_id: int): + batch_size, seq_len = input_ids.shape + bool_mask = (input_ids != pad_id) + self._seq_mask[: batch_size, : seq_len] = bool_mask + + def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]: + return self._kv_cache + + def get_seq_mask(self) -> Tensor: + return self._seq_mask \ No newline at end of file diff --git a/khaosz/inference/generator.py b/khaosz/inference/generator.py index 719fbd2..297c9d5 100644 --- a/khaosz/inference/generator.py +++ b/khaosz/inference/generator.py @@ -1,7 +1,7 @@ import torch from torch import Tensor from typing import 
class KVCacheManager:
    """
    Pre-allocated per-layer key/value cache plus a boolean sequence mask.

    For every layer one ``(key, value)`` pair of zero tensors with shape
    ``(batch_size, max_len, num_heads, head_dim)`` is allocated, together
    with a single ``(batch_size, max_len)`` bool mask that starts all-True.

    Args:
        num_layers (int): Number of (key, value) pairs to allocate.
        batch_size (int): Size of the first dimension of every buffer.
        max_len (int): Size of the second dimension of every buffer.
        num_heads (int): Size of the third dimension of the K/V buffers.
        head_dim (int): Size of the last dimension of the K/V buffers.
        device (str | torch.device, optional): Device for all buffers.
            Defaults to "cuda".
        dtype (torch.dtype, optional): Dtype of the K/V buffers.
            Defaults to torch.bfloat16.
    """

    def __init__(
        self,
        num_layers: int,
        batch_size: int,
        max_len: int,
        num_heads: int,
        head_dim: int,
        device: Union[str, torch.device] = "cuda",
        dtype: torch.dtype = torch.bfloat16
    ):
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.max_len = max_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.device = device
        self.dtype = dtype

        # Both are None only after reset(full_reset=True).
        self._kv_cache: Optional[List[Tuple[Tensor, Tensor]]] = None
        self._seq_mask: Optional[Tensor] = None
        self._initialize()

    def _initialize(self):
        """Allocate fresh zeroed K/V buffers and an all-True sequence mask."""
        cache_shape = (self.batch_size, self.max_len, self.num_heads, self.head_dim)

        def new_buffer() -> Tensor:
            return torch.zeros(cache_shape, device=self.device, dtype=self.dtype)

        # Keys and values get distinct buffers for every layer.
        self._kv_cache = [
            (new_buffer(), new_buffer()) for _ in range(self.num_layers)
        ]
        self._seq_mask = torch.ones(
            (self.batch_size, self.max_len),
            device=self.device, dtype=torch.bool
        )

    def update(self, active_mask: Tensor):
        """
        Keep only the batch rows selected by ``active_mask``.

        The mask indexes the first dimension of every K/V buffer and of the
        sequence mask. ``self.batch_size`` is intentionally left unchanged,
        so a later ``reset()`` re-allocates at the original batch size.
        """
        self._kv_cache = [
            (k_cache[active_mask], v_cache[active_mask])
            for k_cache, v_cache in self._kv_cache
        ]
        self._seq_mask = self._seq_mask[active_mask]

    def reset(self, full_reset=False):
        """
        Re-allocate the buffers, or drop them entirely.

        Args:
            full_reset (bool, optional): When True the cache and mask are set
                to None (releasing the references) instead of re-allocated.
        """
        if full_reset:
            self._kv_cache = None
            self._seq_mask = None
        else:
            self._initialize()

    def set_seq_mask(self, input_ids: Tensor, pad_id: int):
        """
        Mark padding positions of ``input_ids`` as False in the sequence mask.

        Only the top-left ``input_ids.shape`` region of the mask is written;
        positions outside it keep their current value.
        """
        batch_size, seq_len = input_ids.shape
        bool_mask = (input_ids != pad_id)
        self._seq_mask[: batch_size, : seq_len] = bool_mask

    def get_kvcache(self) -> Optional[List[Tuple[Tensor, Tensor]]]:
        """Return the list of per-layer (key, value) buffers (None after a full reset)."""
        return self._kv_cache

    def get_seq_mask(self) -> Optional[Tensor]:
        """Return the boolean sequence mask (None after a full reset)."""
        return self._seq_mask
top_k, top_p) - probs = torch.softmax(logits, dim=-1) - next_token_id = torch.multinomial(probs, num_samples=1) + for _ in range(len(ids), self.config.m_len): + next_token_id, cache_increase = self.generate_iterator( + input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos) input_ids = next_token_id ids.append(next_token_id.item()) cur_cache_pos += cache_increase - + if next_token_id.item() in self.tokenizer.stop_ids: break @@ -226,7 +101,6 @@ class TextGenerator(GeneratorCore): return response - class ChatGenerator(GeneratorCore): def __init__(self, parameter: ModelParameter): super().__init__(parameter) @@ -263,26 +137,19 @@ class ChatGenerator(GeneratorCore): start_cache_pos = len(ids) cur_cache_pos = 0 self.model.eval() + kv_caches = cache_manager.get_kvcache() - - while len(ids) < self.config.m_len: - kv_caches = cache_manager.get_kvcache() - logits, cache_increase = self.compute_logits( - input_ids, - kv_caches=kv_caches, - start_pos=cur_cache_pos - ) - logits = apply_sampling_strategies(logits, temperature, top_k, top_p) - probs = torch.softmax(logits, dim=-1) - next_token_id = torch.multinomial(probs, num_samples=1) + for _ in range(len(ids), self.config.m_len): + next_token_id, cache_increase = self.generate_iterator( + input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos) input_ids = next_token_id ids.append(next_token_id.item()) cur_cache_pos += cache_increase - + if next_token_id.item() in self.tokenizer.stop_ids: break - + response = self.tokenizer.decode(ids[start_cache_pos:]) cpy_history.append((query, response)) @@ -325,18 +192,11 @@ class StreamGenerator(GeneratorCore): start_cache_pos = len(ids) cur_cache_pos = 0 self.model.eval() + kv_caches = cache_manager.get_kvcache() - - while len(ids) < self.config.m_len: - kv_caches = cache_manager.get_kvcache() - logits, cache_increase = self.compute_logits( - input_ids, - kv_caches=kv_caches, - start_pos=cur_cache_pos - ) - logits = 
apply_sampling_strategies(logits, temperature, top_k, top_p) - probs = torch.softmax(logits, dim=-1) - next_token_id = torch.multinomial(probs, num_samples=1) + for _ in range(len(ids), self.config.m_len): + next_token_id, cache_increase = self.generate_iterator( + input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos) input_ids = next_token_id ids.append(next_token_id.item()) @@ -370,7 +230,7 @@ class BatchGenerator(GeneratorCore): batch_size = len(queries) if histories is None: histories = [[] for _ in range(batch_size)] - + prompts = [build_prompt(query, history) for query, history in zip(queries, histories)] ids_list = [self.tokenizer.encode(prompt) for prompt in prompts] max_ids_len = max(len(ids) for ids in ids_list) @@ -397,18 +257,10 @@ class BatchGenerator(GeneratorCore): kv_caches = cache_manager.get_kvcache() attn_mask =cache_manager.get_seq_mask() - logits, cache_increase = self.compute_logits( - input_tensor, - attn_mask=attn_mask, - kv_caches=kv_caches, - start_pos=cur_cache_pos - ) + next_token_id, cache_increase = self.generate_iterator( + input_tensor, temperature, top_k, top_p, attn_mask=attn_mask, kv_caches=kv_caches, start_pos=cur_cache_pos) cur_cache_pos += cache_increase - logits = apply_sampling_strategies(logits, temperature, top_k, top_p) - probs = torch.softmax(logits, dim=-1) - next_token_id = torch.multinomial(probs, num_samples=1) - active_mask = [] c_ids = 0 @@ -435,7 +287,6 @@ class BatchGenerator(GeneratorCore): histories[i].append((queries[i], responses[i])) return responses - class RetrievalGenerator(GeneratorCore):