feat(inference): 实现采样策略并优化生成器逻辑
This commit is contained in:
parent
98efca7b9d
commit
e72e244df6
|
|
@ -1,29 +1,86 @@
|
|||
import torch
|
||||
|
||||
from torch import Tensor
|
||||
from typing import List, Tuple, Union, Optional, Generator, Self
|
||||
from typing import List, Tuple, Union, Optional, Self
|
||||
from khaosz.config.param_config import ModelParameter
|
||||
|
||||
|
||||
def apply_sampling_strategies(
    logits: Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    filter_value: float = -float("inf")
) -> Tensor:
    """
    Apply temperature scaling, top-k and top-p (nucleus) filtering to logits.

    Args:
        logits (Tensor): Logits over the vocabulary; the last dimension is
            the vocabulary axis.
        temperature (float): Softmax temperature; values != 1.0 rescale the
            logits (lower -> sharper, higher -> flatter distribution).
        top_k (int): Keep only the k highest-scoring tokens. 0 disables.
        top_p (float): Keep the smallest set of top tokens whose cumulative
            probability exceeds top_p. 1.0 disables.
        filter_value (float, optional): Value assigned to filtered-out
            logits. Defaults to -float("inf").

    Returns:
        Tensor: A new tensor of filtered logits; the input is not mutated.
    """
    # Work on a copy so the caller's tensor is never mutated. (The original
    # only copied implicitly when temperature != 1.0, so with the default
    # temperature it silently clobbered the caller's logits in place.)
    logits = logits.clone()

    if temperature != 1.0:
        logits = logits / temperature

    if top_k > 0:
        # Clamp k to the vocabulary size, then drop everything strictly
        # below the k-th largest logit.
        top_k = min(top_k, logits.size(-1))
        kth_value = torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
        logits[logits < kth_value] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

        # Mark tokens once cumulative probability exceeds top_p; shift the
        # mask right so the token that crosses the threshold is kept, and
        # always keep the single most likely token.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the mask back into vocabulary order. Fixed: dim=-1
        # instead of the hard-coded dim=1, so logits of any rank (1D,
        # batched 2D, ...) are handled; identical for the 2D case.
        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
        indices_to_remove.scatter_(
            dim=-1,
            index=sorted_indices,
            src=sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value

    return logits
|
||||
|
||||
|
||||
class GeneratorCore:
    """
    Shared base for the generator classes: holds the model, tokenizer and
    config, and performs a single sampled decoding step per call.
    """

    def __init__(self, parameter: ModelParameter):
        self.model = parameter.model
        self.tokenizer = parameter.tokenizer
        self.config = parameter.config

    def generate_iterator(
        self,
        input_ids: Tensor,
        temperature: float,
        top_k: int,
        top_p: float,
        attn_mask: Optional[Tensor] = None,
        kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
        start_pos: int = 0
    ) -> Tuple[Tensor, int]:
        """
        Run one forward pass and sample the next token.

        Args:
            input_ids (Tensor): Token ids fed to the model this step.
            temperature (float): Sampling temperature.
            top_k (int): Top-k filtering parameter (0 disables).
            top_p (float): Nucleus-sampling parameter (1.0 disables).
            attn_mask (Optional[Tensor]): Attention/padding mask.
            kv_caches (Optional[List[Tuple[Tensor, Tensor]]]): Per-layer
                key/value caches passed through to the model.
            start_pos (int): Position offset of input_ids within the cache.

        Returns:
            Tuple[Tensor, int]: The sampled next-token ids and the number
            of cache positions consumed (the input length).
        """
        with torch.inference_mode():
            # NOTE(review): assumes the model returns a mapping with a
            # "logits" entry of shape (batch, seq, vocab) — confirm against
            # the model implementation.
            outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
            logits = outputs["logits"][:, -1, :]
            cache_increase = input_ids.size(-1)

            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
            probs = torch.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)

        return next_token_id, cache_increase

    def to(self, *args, **kargs) -> Self:
        self.model.to(*args, **kargs)
        # Fixed: return self so the method satisfies its `-> Self`
        # annotation and supports chaining (consistent with
        # EmbeddingEncoderCore.to).
        return self
|
||||
|
|
@ -95,3 +152,71 @@ class EmbeddingEncoderCore:
|
|||
def to(self, *args, **kargs) -> Self:
|
||||
self.model.to(*args, **kargs)
|
||||
return self
|
||||
|
||||
|
||||
class KVCacheManager:
    """
    Pre-allocated per-layer key/value caches plus a padding mask for
    batched autoregressive decoding.

    Each layer's cache is a (k_cache, v_cache) pair of shape
    (batch_size, max_len, num_heads, head_dim); the sequence mask is
    (batch_size, max_len) with True marking valid (non-padding) positions.
    """

    def __init__(
        self,
        num_layers: int,
        batch_size: int,
        max_len: int,
        num_heads: int,
        head_dim: int,
        device: torch.device = "cuda",
        dtype: torch.dtype = torch.bfloat16
    ):
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.max_len = max_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.device = device
        self.dtype = dtype

        # Fixed annotations: both fields become None after a full reset,
        # so they must be Optional (the original declared them as always
        # populated).
        self._kv_cache: Optional[List[Tuple[Tensor, Tensor]]] = None
        self._seq_mask: Optional[Tensor] = None
        self._initialize()

    def _initialize(self):
        """Allocate zeroed K/V buffers for every layer and an all-True mask."""
        cache_shape = (self.batch_size, self.max_len, self.num_heads, self.head_dim)
        self._kv_cache = [
            (
                torch.zeros(cache_shape, device=self.device, dtype=self.dtype),
                torch.zeros(cache_shape, device=self.device, dtype=self.dtype),
            )
            for _ in range(self.num_layers)
        ]
        self._seq_mask = torch.ones(
            (self.batch_size, self.max_len),
            device=self.device, dtype=torch.bool
        )

    def update(self, active_mask: Tensor):
        """Keep only the batch rows selected by the boolean ``active_mask``.

        Shrinks every layer's caches and the sequence mask along the batch
        dimension (used to drop finished sequences during batch decoding).
        """
        for i, (k_cache, v_cache) in enumerate(self._kv_cache):
            self._kv_cache[i] = (k_cache[active_mask], v_cache[active_mask])
        self._seq_mask = self._seq_mask[active_mask]

    def reset(self, full_reset=False):
        """Re-zero the caches; with ``full_reset`` release them entirely."""
        if full_reset:
            self._kv_cache = None
            self._seq_mask = None
        else:
            self._initialize()

    def set_seq_mask(self, input_ids: Tensor, pad_id: int):
        """Mark padding positions of ``input_ids`` as invalid in the mask."""
        batch_size, seq_len = input_ids.shape
        self._seq_mask[:batch_size, :seq_len] = (input_ids != pad_id)

    def get_kvcache(self) -> Optional[List[Tuple[Tensor, Tensor]]]:
        """Return the per-layer (k, v) cache list, or None after a full reset."""
        return self._kv_cache

    def get_seq_mask(self) -> Optional[Tensor]:
        """Return the (batch, max_len) validity mask, or None after a full reset."""
        return self._seq_mask
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
from typing import List, Tuple, Union, Optional, Generator
|
||||
from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore
|
||||
from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
|
||||
from khaosz.config.param_config import ModelParameter
|
||||
|
||||
|
||||
|
|
@ -51,125 +51,6 @@ def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> Li
|
|||
|
||||
return new_ids_list
|
||||
|
||||
def apply_sampling_strategies(
|
||||
logits: Tensor,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
filter_value: float = -float("inf")
|
||||
) -> Tensor:
|
||||
"""
|
||||
Apply sampling strategies to the logits tensor.
|
||||
|
||||
Args:
|
||||
logits (Tensor): The logits tensor.
|
||||
temperature (float): The temperature parameter.
|
||||
top_k (int): The top-k parameter.
|
||||
top_p (float): The top-p parameter.
|
||||
filter_value (float, optional): The filter value. Defaults to -float("inf").
|
||||
|
||||
Returns:
|
||||
Tensor: The sampled logits tensor.
|
||||
|
||||
"""
|
||||
|
||||
if temperature != 1.0:
|
||||
logits = logits / temperature
|
||||
|
||||
if top_k > 0:
|
||||
top_k = min(top_k, logits.size(-1))
|
||||
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
|
||||
logits[indices_to_remove] = filter_value
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
||||
indices_to_remove.scatter_(
|
||||
dim=1,
|
||||
index=sorted_indices,
|
||||
src=sorted_indices_to_remove
|
||||
)
|
||||
|
||||
logits[indices_to_remove] = filter_value
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
class KVCacheManager:
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int,
|
||||
batch_size: int,
|
||||
max_len: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
device: torch.device = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16
|
||||
):
|
||||
self.num_layers = num_layers
|
||||
self.batch_size = batch_size
|
||||
self.max_len = max_len
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
self._kv_cache: List[Tuple[Tensor, Tensor]] = None
|
||||
self._seq_mask: Tensor = None
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self):
|
||||
self._kv_cache = []
|
||||
for _ in range(self.num_layers):
|
||||
k_cache = torch.zeros(
|
||||
(self.batch_size, self.max_len, self.num_heads, self.head_dim),
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
(self.batch_size, self.max_len, self.num_heads, self.head_dim),
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
self._kv_cache.append((k_cache, v_cache))
|
||||
|
||||
self._seq_mask = torch.ones(
|
||||
(self.batch_size, self.max_len),
|
||||
device=self.device, dtype=torch.bool
|
||||
)
|
||||
|
||||
def update(self, active_mask: Tensor):
|
||||
for i in range(self.num_layers):
|
||||
k_cache, v_cache = self._kv_cache[i]
|
||||
new_k_cache, new_v_cache = k_cache[active_mask], v_cache[active_mask]
|
||||
self._kv_cache[i] = (new_k_cache, new_v_cache)
|
||||
|
||||
self._seq_mask = self._seq_mask[active_mask]
|
||||
|
||||
def reset(self, full_reset=False):
|
||||
if full_reset:
|
||||
self._kv_cache = None
|
||||
self._seq_mask = None
|
||||
else:
|
||||
self._initialize()
|
||||
|
||||
def set_seq_mask(self, input_ids: Tensor, pad_id: int):
|
||||
batch_size, seq_len = input_ids.shape
|
||||
bool_mask = (input_ids != pad_id)
|
||||
self._seq_mask[: batch_size, : seq_len] = bool_mask
|
||||
|
||||
def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]:
|
||||
return self._kv_cache
|
||||
|
||||
def get_seq_mask(self) -> Tensor:
|
||||
return self._seq_mask
|
||||
|
||||
|
||||
|
||||
|
||||
class TextGenerator(GeneratorCore):
|
||||
def __init__(self, parameter: ModelParameter):
|
||||
|
|
@ -202,17 +83,11 @@ class TextGenerator(GeneratorCore):
|
|||
start_cache_pos = len(ids)
|
||||
cur_cache_pos = 0
|
||||
self.model.eval()
|
||||
|
||||
while len(ids) < self.config.m_len:
|
||||
kv_caches = cache_manager.get_kvcache()
|
||||
logits, cache_increase = self.compute_logits(
|
||||
input_ids,
|
||||
kv_caches=kv_caches,
|
||||
start_pos=cur_cache_pos
|
||||
)
|
||||
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
next_token_id = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
for _ in range(len(ids), self.config.m_len):
|
||||
next_token_id, cache_increase = self.generate_iterator(
|
||||
input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
|
||||
|
||||
input_ids = next_token_id
|
||||
ids.append(next_token_id.item())
|
||||
|
|
@ -226,7 +101,6 @@ class TextGenerator(GeneratorCore):
|
|||
return response
|
||||
|
||||
|
||||
|
||||
class ChatGenerator(GeneratorCore):
|
||||
def __init__(self, parameter: ModelParameter):
|
||||
super().__init__(parameter)
|
||||
|
|
@ -263,18 +137,11 @@ class ChatGenerator(GeneratorCore):
|
|||
start_cache_pos = len(ids)
|
||||
cur_cache_pos = 0
|
||||
self.model.eval()
|
||||
|
||||
|
||||
while len(ids) < self.config.m_len:
|
||||
kv_caches = cache_manager.get_kvcache()
|
||||
logits, cache_increase = self.compute_logits(
|
||||
input_ids,
|
||||
kv_caches=kv_caches,
|
||||
start_pos=cur_cache_pos
|
||||
)
|
||||
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
next_token_id = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
for _ in range(len(ids), self.config.m_len):
|
||||
next_token_id, cache_increase = self.generate_iterator(
|
||||
input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
|
||||
|
||||
input_ids = next_token_id
|
||||
ids.append(next_token_id.item())
|
||||
|
|
@ -325,18 +192,11 @@ class StreamGenerator(GeneratorCore):
|
|||
start_cache_pos = len(ids)
|
||||
cur_cache_pos = 0
|
||||
self.model.eval()
|
||||
|
||||
|
||||
while len(ids) < self.config.m_len:
|
||||
kv_caches = cache_manager.get_kvcache()
|
||||
logits, cache_increase = self.compute_logits(
|
||||
input_ids,
|
||||
kv_caches=kv_caches,
|
||||
start_pos=cur_cache_pos
|
||||
)
|
||||
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
next_token_id = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
for _ in range(len(ids), self.config.m_len):
|
||||
next_token_id, cache_increase = self.generate_iterator(
|
||||
input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
|
||||
|
||||
input_ids = next_token_id
|
||||
ids.append(next_token_id.item())
|
||||
|
|
@ -397,18 +257,10 @@ class BatchGenerator(GeneratorCore):
|
|||
kv_caches = cache_manager.get_kvcache()
|
||||
attn_mask =cache_manager.get_seq_mask()
|
||||
|
||||
logits, cache_increase = self.compute_logits(
|
||||
input_tensor,
|
||||
attn_mask=attn_mask,
|
||||
kv_caches=kv_caches,
|
||||
start_pos=cur_cache_pos
|
||||
)
|
||||
next_token_id, cache_increase = self.generate_iterator(
|
||||
input_tensor, temperature, top_k, top_p, attn_mask=attn_mask, kv_caches=kv_caches, start_pos=cur_cache_pos)
|
||||
|
||||
cur_cache_pos += cache_increase
|
||||
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
next_token_id = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
active_mask = []
|
||||
c_ids = 0
|
||||
|
||||
|
|
@ -437,7 +289,6 @@ class BatchGenerator(GeneratorCore):
|
|||
return responses
|
||||
|
||||
|
||||
|
||||
class RetrievalGenerator(GeneratorCore):
|
||||
def __init__(self, retriever_parameter: ModelParameter):
|
||||
super().__init__(retriever_parameter)
|
||||
|
|
|
|||
Loading…
Reference in New Issue