feat(inference): implement sampling strategies and optimize generator logic

ViperEkura 2025-10-20 13:00:41 +08:00
parent 98efca7b9d
commit e72e244df6
2 changed files with 151 additions and 175 deletions

View File

@@ -1,29 +1,86 @@
 import torch
 from torch import Tensor
-from typing import List, Tuple, Union, Optional, Generator, Self
+from typing import List, Tuple, Union, Optional, Self
 from khaosz.config.param_config import ModelParameter
+
+
+def apply_sampling_strategies(
+    logits: Tensor,
+    temperature: float,
+    top_k: int,
+    top_p: float,
+    filter_value: float = -float("inf")
+) -> Tensor:
+    """
+    Apply sampling strategies to the logits tensor.
+
+    Args:
+        logits (Tensor): The logits tensor.
+        temperature (float): The temperature parameter.
+        top_k (int): The top-k parameter.
+        top_p (float): The top-p parameter.
+        filter_value (float, optional): The filter value. Defaults to -float("inf").
+
+    Returns:
+        Tensor: The filtered logits tensor.
+    """
+    if temperature != 1.0:
+        logits = logits / temperature
+
+    if top_k > 0:
+        top_k = min(top_k, logits.size(-1))
+        indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p < 1.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+        sorted_indices_to_remove = cumulative_probs > top_p
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+
+        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
+        indices_to_remove.scatter_(
+            dim=1,
+            index=sorted_indices,
+            src=sorted_indices_to_remove
+        )
+        logits[indices_to_remove] = filter_value
+
+    return logits
+
+
 class GeneratorCore:
     def __init__(self, parameter: ModelParameter):
         self.model = parameter.model
         self.tokenizer = parameter.tokenizer
         self.config = parameter.config
 
-    def compute_logits(
+    def generate_iterator(
         self,
         input_ids: Tensor,
+        temperature: float,
+        top_k: int,
+        top_p: float,
         attn_mask: Optional[Tensor] = None,
         kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
         start_pos: int = 0
     ) -> Tuple[Tensor, int]:
         with torch.inference_mode():
             outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
             logits = outputs["logits"][:, -1, :]
             cache_increase = input_ids.size(-1)
-            return logits, cache_increase
+            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
+            probs = torch.softmax(logits, dim=-1)
+            next_token_id = torch.multinomial(probs, num_samples=1)
+            return next_token_id, cache_increase
 
     def to(self, *args, **kargs) -> Self:
         self.model.to(*args, **kargs)
@@ -95,3 +152,71 @@ class EmbeddingEncoderCore:
     def to(self, *args, **kargs) -> Self:
         self.model.to(*args, **kargs)
         return self
+
+
+class KVCacheManager:
+    def __init__(
+        self,
+        num_layers: int,
+        batch_size: int,
+        max_len: int,
+        num_heads: int,
+        head_dim: int,
+        device: torch.device = "cuda",
+        dtype: torch.dtype = torch.bfloat16
+    ):
+        self.num_layers = num_layers
+        self.batch_size = batch_size
+        self.max_len = max_len
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.device = device
+        self.dtype = dtype
+        self._kv_cache: List[Tuple[Tensor, Tensor]] = None
+        self._seq_mask: Tensor = None
+        self._initialize()
+
+    def _initialize(self):
+        self._kv_cache = []
+        for _ in range(self.num_layers):
+            k_cache = torch.zeros(
+                (self.batch_size, self.max_len, self.num_heads, self.head_dim),
+                device=self.device, dtype=self.dtype
+            )
+            v_cache = torch.zeros(
+                (self.batch_size, self.max_len, self.num_heads, self.head_dim),
+                device=self.device, dtype=self.dtype
+            )
+            self._kv_cache.append((k_cache, v_cache))
+
+        self._seq_mask = torch.ones(
+            (self.batch_size, self.max_len),
+            device=self.device, dtype=torch.bool
+        )
+
+    def update(self, active_mask: Tensor):
+        for i in range(self.num_layers):
+            k_cache, v_cache = self._kv_cache[i]
+            new_k_cache, new_v_cache = k_cache[active_mask], v_cache[active_mask]
+            self._kv_cache[i] = (new_k_cache, new_v_cache)
+        self._seq_mask = self._seq_mask[active_mask]
+
+    def reset(self, full_reset=False):
+        if full_reset:
+            self._kv_cache = None
+            self._seq_mask = None
+        else:
+            self._initialize()
+
+    def set_seq_mask(self, input_ids: Tensor, pad_id: int):
+        batch_size, seq_len = input_ids.shape
+        bool_mask = (input_ids != pad_id)
+        self._seq_mask[: batch_size, : seq_len] = bool_mask
+
+    def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]:
+        return self._kv_cache
+
+    def get_seq_mask(self) -> Tensor:
+        return self._seq_mask
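
The relocated apply_sampling_strategies chains the three standard filters in order: temperature scaling, then top-k, then top-p (nucleus) truncation. Rejected entries are set to -inf, so a plain softmax followed by multinomial afterwards samples only from the surviving tokens. A minimal sketch of the behavior, assuming the file above is khaosz/inference/core.py (as the updated import in the second file suggests):

import torch
from khaosz.inference.core import apply_sampling_strategies  # module path assumed from the import below

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])  # batch of 1, toy vocabulary of 5

# temperature rescales first; top_k keeps only the 3 largest logits;
# top_p then drops the sorted tail once cumulative probability passes 0.9
# (the shift by one keeps the token that crosses the threshold).
filtered = apply_sampling_strategies(
    logits.clone(),  # clone: with temperature == 1.0 the masking writes into the caller's tensor
    temperature=0.8, top_k=3, top_p=0.9
)

probs = torch.softmax(filtered, dim=-1)               # -inf slots become probability 0
next_token = torch.multinomial(probs, num_samples=1)  # same sampling step as generate_iterator

Because scatter_ is called with dim=1, the function assumes 2-D [batch, vocab] logits, which matches the logits[:, -1, :] slice taken inside generate_iterator.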

View File

@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 from typing import List, Tuple, Union, Optional, Generator
-from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore
+from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
 from khaosz.config.param_config import ModelParameter
@@ -51,125 +51,6 @@ def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> Li
     return new_ids_list
 
-
-def apply_sampling_strategies(
-    logits: Tensor,
-    temperature: float,
-    top_k: int,
-    top_p: float,
-    filter_value: float = -float("inf")
-) -> Tensor:
-    """
-    Apply sampling strategies to the logits tensor.
-
-    Args:
-        logits (Tensor): The logits tensor.
-        temperature (float): The temperature parameter.
-        top_k (int): The top-k parameter.
-        top_p (float): The top-p parameter.
-        filter_value (float, optional): The filter value. Defaults to -float("inf").
-
-    Returns:
-        Tensor: The filtered logits tensor.
-    """
-    if temperature != 1.0:
-        logits = logits / temperature
-
-    if top_k > 0:
-        top_k = min(top_k, logits.size(-1))
-        indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-
-    if top_p < 1.0:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-        sorted_indices_to_remove = cumulative_probs > top_p
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-
-        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
-        indices_to_remove.scatter_(
-            dim=1,
-            index=sorted_indices,
-            src=sorted_indices_to_remove
-        )
-        logits[indices_to_remove] = filter_value
-
-    return logits
-
-
-class KVCacheManager:
-    def __init__(
-        self,
-        num_layers: int,
-        batch_size: int,
-        max_len: int,
-        num_heads: int,
-        head_dim: int,
-        device: torch.device = "cuda",
-        dtype: torch.dtype = torch.bfloat16
-    ):
-        self.num_layers = num_layers
-        self.batch_size = batch_size
-        self.max_len = max_len
-        self.num_heads = num_heads
-        self.head_dim = head_dim
-        self.device = device
-        self.dtype = dtype
-        self._kv_cache: List[Tuple[Tensor, Tensor]] = None
-        self._seq_mask: Tensor = None
-        self._initialize()
-
-    def _initialize(self):
-        self._kv_cache = []
-        for _ in range(self.num_layers):
-            k_cache = torch.zeros(
-                (self.batch_size, self.max_len, self.num_heads, self.head_dim),
-                device=self.device, dtype=self.dtype
-            )
-            v_cache = torch.zeros(
-                (self.batch_size, self.max_len, self.num_heads, self.head_dim),
-                device=self.device, dtype=self.dtype
-            )
-            self._kv_cache.append((k_cache, v_cache))
-
-        self._seq_mask = torch.ones(
-            (self.batch_size, self.max_len),
-            device=self.device, dtype=torch.bool
-        )
-
-    def update(self, active_mask: Tensor):
-        for i in range(self.num_layers):
-            k_cache, v_cache = self._kv_cache[i]
-            new_k_cache, new_v_cache = k_cache[active_mask], v_cache[active_mask]
-            self._kv_cache[i] = (new_k_cache, new_v_cache)
-        self._seq_mask = self._seq_mask[active_mask]
-
-    def reset(self, full_reset=False):
-        if full_reset:
-            self._kv_cache = None
-            self._seq_mask = None
-        else:
-            self._initialize()
-
-    def set_seq_mask(self, input_ids: Tensor, pad_id: int):
-        batch_size, seq_len = input_ids.shape
-        bool_mask = (input_ids != pad_id)
-        self._seq_mask[: batch_size, : seq_len] = bool_mask
-
-    def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]:
-        return self._kv_cache
-
-    def get_seq_mask(self) -> Tensor:
-        return self._seq_mask
 
 
 class TextGenerator(GeneratorCore):
     def __init__(self, parameter: ModelParameter):
@@ -202,17 +83,11 @@ class TextGenerator(GeneratorCore):
         start_cache_pos = len(ids)
         cur_cache_pos = 0
         self.model.eval()
 
-        while len(ids) < self.config.m_len:
-            kv_caches = cache_manager.get_kvcache()
-            logits, cache_increase = self.compute_logits(
-                input_ids,
-                kv_caches=kv_caches,
-                start_pos=cur_cache_pos
-            )
-            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
-            probs = torch.softmax(logits, dim=-1)
-            next_token_id = torch.multinomial(probs, num_samples=1)
+        kv_caches = cache_manager.get_kvcache()
+        for _ in range(len(ids), self.config.m_len):
+            next_token_id, cache_increase = self.generate_iterator(
+                input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
             input_ids = next_token_id
             ids.append(next_token_id.item())
@@ -226,7 +101,6 @@ class TextGenerator(GeneratorCore):
         return response
-
 
 
 class ChatGenerator(GeneratorCore):
     def __init__(self, parameter: ModelParameter):
         super().__init__(parameter)
@@ -263,18 +137,11 @@ class ChatGenerator(GeneratorCore):
         start_cache_pos = len(ids)
         cur_cache_pos = 0
         self.model.eval()
 
-        while len(ids) < self.config.m_len:
-            kv_caches = cache_manager.get_kvcache()
-            logits, cache_increase = self.compute_logits(
-                input_ids,
-                kv_caches=kv_caches,
-                start_pos=cur_cache_pos
-            )
-            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
-            probs = torch.softmax(logits, dim=-1)
-            next_token_id = torch.multinomial(probs, num_samples=1)
+        kv_caches = cache_manager.get_kvcache()
+        for _ in range(len(ids), self.config.m_len):
+            next_token_id, cache_increase = self.generate_iterator(
+                input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
             input_ids = next_token_id
             ids.append(next_token_id.item())
@@ -325,18 +192,11 @@ class StreamGenerator(GeneratorCore):
         start_cache_pos = len(ids)
         cur_cache_pos = 0
         self.model.eval()
 
-        while len(ids) < self.config.m_len:
-            kv_caches = cache_manager.get_kvcache()
-            logits, cache_increase = self.compute_logits(
-                input_ids,
-                kv_caches=kv_caches,
-                start_pos=cur_cache_pos
-            )
-            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
-            probs = torch.softmax(logits, dim=-1)
-            next_token_id = torch.multinomial(probs, num_samples=1)
+        kv_caches = cache_manager.get_kvcache()
+        for _ in range(len(ids), self.config.m_len):
+            next_token_id, cache_increase = self.generate_iterator(
+                input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
             input_ids = next_token_id
             ids.append(next_token_id.item())
@@ -397,18 +257,10 @@ class BatchGenerator(GeneratorCore):
             kv_caches = cache_manager.get_kvcache()
             attn_mask = cache_manager.get_seq_mask()
 
-            logits, cache_increase = self.compute_logits(
-                input_tensor,
-                attn_mask=attn_mask,
-                kv_caches=kv_caches,
-                start_pos=cur_cache_pos
-            )
+            next_token_id, cache_increase = self.generate_iterator(
+                input_tensor, temperature, top_k, top_p, attn_mask=attn_mask, kv_caches=kv_caches, start_pos=cur_cache_pos)
             cur_cache_pos += cache_increase
-            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
-            probs = torch.softmax(logits, dim=-1)
-            next_token_id = torch.multinomial(probs, num_samples=1)
 
             active_mask = []
             c_ids = 0
@@ -437,7 +289,6 @@ class BatchGenerator(GeneratorCore):
         return responses
-
 
 
 class RetrievalGenerator(GeneratorCore):
     def __init__(self, retriever_parameter: ModelParameter):
         super().__init__(retriever_parameter)
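
The BatchGenerator hunk above is the main consumer of the relocated KVCacheManager: per-layer K/V buffers are allocated once at max_len, and update() shrinks the batch dimension as sequences finish generating. A minimal sketch with hypothetical sizes (CPU and float32 so it runs without a GPU; the class defaults to "cuda" and bfloat16):

import torch
from khaosz.inference.core import KVCacheManager

manager = KVCacheManager(
    num_layers=2, batch_size=4, max_len=16,
    num_heads=8, head_dim=64,
    device=torch.device("cpu"), dtype=torch.float32,
)

kv = manager.get_kvcache()
assert len(kv) == 2                          # one (K, V) pair per layer
assert kv[0][0].shape == (4, 16, 8, 64)      # (batch, max_len, num_heads, head_dim)

# Keep only the sequences that are still generating (rows 0 and 2):
active = torch.tensor([True, False, True, False])
manager.update(active)
assert manager.get_kvcache()[0][0].shape[0] == 2
assert manager.get_seq_mask().shape == (2, 16)

manager.reset()                  # reallocates zeroed buffers at the original batch size
manager.reset(full_reset=True)   # drops the buffer references entirely

Pre-allocating at max_len trades memory for not growing the cache every decoding step. Note that reset() re-runs _initialize() with the original batch_size, so a manager shrunk by update() comes back at full size for the next prompt.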