feat(inference): implement sampling strategies and optimize generator logic

This commit is contained in:
ViperEkura 2025-10-20 13:00:41 +08:00
parent 98efca7b9d
commit e72e244df6
2 changed files with 151 additions and 175 deletions

View File

@@ -1,29 +1,86 @@
import torch
from torch import Tensor
from typing import List, Tuple, Union, Optional, Generator, Self
from typing import List, Tuple, Union, Optional, Self
from khaosz.config.param_config import ModelParameter

def apply_sampling_strategies(
    logits: Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    filter_value: float = -float("inf")
) -> Tensor:
    """
    Apply sampling strategies to the logits tensor.

    Args:
        logits (Tensor): The logits tensor.
        temperature (float): The temperature parameter.
        top_k (int): The top-k parameter.
        top_p (float): The top-p parameter.
        filter_value (float, optional): The filter value. Defaults to -float("inf").

    Returns:
        Tensor: The sampled logits tensor.
    """
    if temperature != 1.0:
        logits = logits / temperature

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
        indices_to_remove.scatter_(
            dim=1,
            index=sorted_indices,
            src=sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value

    return logits
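
# Usage sketch (not part of this diff): how temperature, top-k, and top-p combine
# on a toy logits tensor; the numeric values below are illustrative assumptions only.
toy_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])        # (batch=1, vocab=4)
filtered = apply_sampling_strategies(toy_logits.clone(), temperature=0.8, top_k=2, top_p=0.9)
probs = torch.softmax(filtered, dim=-1)                    # removed tokens get zero probability
next_token = torch.multinomial(probs, num_samples=1)       # sample only among the surviving tokens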

class GeneratorCore:
    def __init__(self, parameter: ModelParameter):
        self.model = parameter.model
        self.tokenizer = parameter.tokenizer
        self.config = parameter.config

    def compute_logits(
    def generate_iterator(
        self,
        input_ids: Tensor,
        temperature: float,
        top_k: int,
        top_p: float,
        attn_mask: Optional[Tensor] = None,
        kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
        start_pos: int = 0
    ) -> Tuple[Tensor, int]:
        with torch.inference_mode():
            outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
            logits = outputs["logits"][:, -1, :]
            cache_increase = input_ids.size(-1)
        return logits, cache_increase
        logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
        probs = torch.softmax(logits, dim=-1)
        next_token_id = torch.multinomial(probs, num_samples=1)
        return next_token_id, cache_increase

    def to(self, *args, **kargs) -> Self:
        self.model.to(*args, **kargs)
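
# Decode-loop sketch (an assumption distilled from the generator classes in the second
# file, not code from this diff): each generate_iterator call returns the sampled next
# token id plus the number of new cache positions. `core`, `prompt_ids`, and
# `cache_manager` are hypothetical names used only for illustration.
ids = list(prompt_ids)
input_ids = torch.tensor([prompt_ids], device="cuda")
cur_cache_pos = 0

while len(ids) < core.config.m_len:
    kv_caches = cache_manager.get_kvcache()
    next_token_id, cache_increase = core.generate_iterator(
        input_ids, temperature=0.9, top_k=50, top_p=0.95,
        kv_caches=kv_caches, start_pos=cur_cache_pos
    )
    cur_cache_pos += cache_increase
    input_ids = next_token_id          # with a KV cache, only the new token is fed on the next step
    ids.append(next_token_id.item())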
@@ -95,3 +152,71 @@ class EmbeddingEncoderCore:
    def to(self, *args, **kargs) -> Self:
        self.model.to(*args, **kargs)
        return self

class KVCacheManager:
    def __init__(
        self,
        num_layers: int,
        batch_size: int,
        max_len: int,
        num_heads: int,
        head_dim: int,
        device: torch.device = "cuda",
        dtype: torch.dtype = torch.bfloat16
    ):
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.max_len = max_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.device = device
        self.dtype = dtype
        self._kv_cache: List[Tuple[Tensor, Tensor]] = None
        self._seq_mask: Tensor = None
        self._initialize()

    def _initialize(self):
        self._kv_cache = []
        for _ in range(self.num_layers):
            k_cache = torch.zeros(
                (self.batch_size, self.max_len, self.num_heads, self.head_dim),
                device=self.device, dtype=self.dtype
            )
            v_cache = torch.zeros(
                (self.batch_size, self.max_len, self.num_heads, self.head_dim),
                device=self.device, dtype=self.dtype
            )
            self._kv_cache.append((k_cache, v_cache))
        self._seq_mask = torch.ones(
            (self.batch_size, self.max_len),
            device=self.device, dtype=torch.bool
        )

    def update(self, active_mask: Tensor):
        for i in range(self.num_layers):
            k_cache, v_cache = self._kv_cache[i]
            new_k_cache, new_v_cache = k_cache[active_mask], v_cache[active_mask]
            self._kv_cache[i] = (new_k_cache, new_v_cache)
        self._seq_mask = self._seq_mask[active_mask]

    def reset(self, full_reset=False):
        if full_reset:
            self._kv_cache = None
            self._seq_mask = None
        else:
            self._initialize()

    def set_seq_mask(self, input_ids: Tensor, pad_id: int):
        batch_size, seq_len = input_ids.shape
        bool_mask = (input_ids != pad_id)
        self._seq_mask[: batch_size, : seq_len] = bool_mask

    def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]:
        return self._kv_cache

    def get_seq_mask(self) -> Tensor:
        return self._seq_mask
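
# Life-cycle sketch for KVCacheManager (not part of this diff); the dimensions,
# `padded_input_ids`, and the pad id value are illustrative assumptions.
cache_manager = KVCacheManager(
    num_layers=24, batch_size=4, max_len=2048,
    num_heads=16, head_dim=64,
    device=torch.device("cuda"), dtype=torch.bfloat16
)
cache_manager.set_seq_mask(padded_input_ids, pad_id=0)   # positions equal to pad_id become False in the mask
kv_caches = cache_manager.get_kvcache()                  # handed to the model forward on every step
attn_mask = cache_manager.get_seq_mask()

# When some sequences in the batch finish, keep only the rows that are still active:
active_mask = torch.tensor([True, False, True, True], device="cuda")
cache_manager.update(active_mask)

cache_manager.reset()                 # re-allocate zeroed caches for the next batch
cache_manager.reset(full_reset=True)  # or drop the cache tensors entirely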

View File

@@ -1,7 +1,7 @@
import torch
from torch import Tensor
from typing import List, Tuple, Union, Optional, Generator
from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore
from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
from khaosz.config.param_config import ModelParameter
@@ -51,125 +51,6 @@ def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> Li
    return new_ids_list

def apply_sampling_strategies(
    logits: Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    filter_value: float = -float("inf")
) -> Tensor:
    """
    Apply sampling strategies to the logits tensor.

    Args:
        logits (Tensor): The logits tensor.
        temperature (float): The temperature parameter.
        top_k (int): The top-k parameter.
        top_p (float): The top-p parameter.
        filter_value (float, optional): The filter value. Defaults to -float("inf").

    Returns:
        Tensor: The sampled logits tensor.
    """
    if temperature != 1.0:
        logits = logits / temperature

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
        indices_to_remove.scatter_(
            dim=1,
            index=sorted_indices,
            src=sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value

    return logits

class KVCacheManager:
    def __init__(
        self,
        num_layers: int,
        batch_size: int,
        max_len: int,
        num_heads: int,
        head_dim: int,
        device: torch.device = "cuda",
        dtype: torch.dtype = torch.bfloat16
    ):
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.max_len = max_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.device = device
        self.dtype = dtype
        self._kv_cache: List[Tuple[Tensor, Tensor]] = None
        self._seq_mask: Tensor = None
        self._initialize()

    def _initialize(self):
        self._kv_cache = []
        for _ in range(self.num_layers):
            k_cache = torch.zeros(
                (self.batch_size, self.max_len, self.num_heads, self.head_dim),
                device=self.device, dtype=self.dtype
            )
            v_cache = torch.zeros(
                (self.batch_size, self.max_len, self.num_heads, self.head_dim),
                device=self.device, dtype=self.dtype
            )
            self._kv_cache.append((k_cache, v_cache))
        self._seq_mask = torch.ones(
            (self.batch_size, self.max_len),
            device=self.device, dtype=torch.bool
        )

    def update(self, active_mask: Tensor):
        for i in range(self.num_layers):
            k_cache, v_cache = self._kv_cache[i]
            new_k_cache, new_v_cache = k_cache[active_mask], v_cache[active_mask]
            self._kv_cache[i] = (new_k_cache, new_v_cache)
        self._seq_mask = self._seq_mask[active_mask]

    def reset(self, full_reset=False):
        if full_reset:
            self._kv_cache = None
            self._seq_mask = None
        else:
            self._initialize()

    def set_seq_mask(self, input_ids: Tensor, pad_id: int):
        batch_size, seq_len = input_ids.shape
        bool_mask = (input_ids != pad_id)
        self._seq_mask[: batch_size, : seq_len] = bool_mask

    def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]:
        return self._kv_cache

    def get_seq_mask(self) -> Tensor:
        return self._seq_mask

class TextGenerator(GeneratorCore):
    def __init__(self, parameter: ModelParameter):
@@ -202,17 +83,11 @@ class TextGenerator(GeneratorCore):
        start_cache_pos = len(ids)
        cur_cache_pos = 0
        self.model.eval()

        while len(ids) < self.config.m_len:
            kv_caches = cache_manager.get_kvcache()
            logits, cache_increase = self.compute_logits(
                input_ids,
                kv_caches=kv_caches,
                start_pos=cur_cache_pos
            )
            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
            probs = torch.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)
        for _ in range(len(ids), self.config.m_len):
            next_token_id, cache_increase = self.generate_iterator(
                input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
            input_ids = next_token_id
            ids.append(next_token_id.item())
@@ -226,7 +101,6 @@ class TextGenerator(GeneratorCore):
        return response

class ChatGenerator(GeneratorCore):
    def __init__(self, parameter: ModelParameter):
        super().__init__(parameter)
@@ -263,18 +137,11 @@ class ChatGenerator(GeneratorCore):
        start_cache_pos = len(ids)
        cur_cache_pos = 0
        self.model.eval()

        while len(ids) < self.config.m_len:
            kv_caches = cache_manager.get_kvcache()
            logits, cache_increase = self.compute_logits(
                input_ids,
                kv_caches=kv_caches,
                start_pos=cur_cache_pos
            )
            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
            probs = torch.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)
        for _ in range(len(ids), self.config.m_len):
            next_token_id, cache_increase = self.generate_iterator(
                input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
            input_ids = next_token_id
            ids.append(next_token_id.item())
@@ -325,18 +192,11 @@ class StreamGenerator(GeneratorCore):
        start_cache_pos = len(ids)
        cur_cache_pos = 0
        self.model.eval()

        while len(ids) < self.config.m_len:
            kv_caches = cache_manager.get_kvcache()
            logits, cache_increase = self.compute_logits(
                input_ids,
                kv_caches=kv_caches,
                start_pos=cur_cache_pos
            )
            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
            probs = torch.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)
        for _ in range(len(ids), self.config.m_len):
            next_token_id, cache_increase = self.generate_iterator(
                input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
            input_ids = next_token_id
            ids.append(next_token_id.item())
@@ -397,18 +257,10 @@ class BatchGenerator(GeneratorCore):
            kv_caches = cache_manager.get_kvcache()
            attn_mask = cache_manager.get_seq_mask()
            logits, cache_increase = self.compute_logits(
                input_tensor,
                attn_mask=attn_mask,
                kv_caches=kv_caches,
                start_pos=cur_cache_pos
            )
            next_token_id, cache_increase = self.generate_iterator(
                input_tensor, temperature, top_k, top_p, attn_mask=attn_mask, kv_caches=kv_caches, start_pos=cur_cache_pos)
            cur_cache_pos += cache_increase
            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
            probs = torch.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)

            active_mask = []
            c_ids = 0
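
# Sketch (an assumption, not code from this diff) of how the active_mask collected above
# is typically consumed: finished rows are dropped from the batch and KVCacheManager.update()
# prunes the cached keys/values to match. `eos_id` and `batch_ids` are hypothetical names.
active = [seq[-1] != eos_id and len(seq) < self.config.m_len for seq in batch_ids]
active_mask = torch.tensor(active, device=input_tensor.device, dtype=torch.bool)
if not active_mask.any():
    pass                                       # every sequence finished; the decode loop would stop here
else:
    input_tensor = input_tensor[active_mask]   # keep only the unfinished sequences
    cache_manager.update(active_mask)          # shrink KV caches and seq mask to the active rows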
@@ -437,7 +289,6 @@ class BatchGenerator(GeneratorCore):
        return responses

class RetrievalGenerator(GeneratorCore):
    def __init__(self, retriever_parameter: ModelParameter):
        super().__init__(retriever_parameter)