feat(inference): implement sampling strategies and streamline generator logic
This commit is contained in:
parent 98efca7b9d
commit e72e244df6
@@ -1,29 +1,86 @@
 import torch
 from torch import Tensor
-from typing import List, Tuple, Union, Optional, Generator, Self
+from typing import List, Tuple, Union, Optional, Self
 from khaosz.config.param_config import ModelParameter
 
 
+def apply_sampling_strategies(
+    logits: Tensor,
+    temperature: float,
+    top_k: int,
+    top_p: float,
+    filter_value: float = -float("inf")
+) -> Tensor:
+    """
+    Apply sampling strategies to the logits tensor.
+
+    Args:
+        logits (Tensor): The logits tensor.
+        temperature (float): The temperature parameter.
+        top_k (int): The top-k parameter.
+        top_p (float): The top-p parameter.
+        filter_value (float, optional): The filter value. Defaults to -float("inf").
+
+    Returns:
+        Tensor: The sampled logits tensor.
+    """
+    if temperature != 1.0:
+        logits = logits / temperature
+
+    if top_k > 0:
+        top_k = min(top_k, logits.size(-1))
+        indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p < 1.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+
+        sorted_indices_to_remove = cumulative_probs > top_p
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+
+        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
+        indices_to_remove.scatter_(
+            dim=1,
+            index=sorted_indices,
+            src=sorted_indices_to_remove
+        )
+        logits[indices_to_remove] = filter_value
+
+    return logits
+
+
 class GeneratorCore:
     def __init__(self, parameter: ModelParameter):
         self.model = parameter.model
         self.tokenizer = parameter.tokenizer
         self.config = parameter.config
 
-    def compute_logits(
+    def generate_iterator(
         self,
         input_ids: Tensor,
+        temperature: float,
+        top_k: int,
+        top_p: float,
         attn_mask: Optional[Tensor] = None,
         kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
         start_pos: int = 0
     ) -> Tuple[Tensor, int]:
         with torch.inference_mode():
             outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
             logits = outputs["logits"][:, -1, :]
             cache_increase = input_ids.size(-1)
 
-            return logits, cache_increase
+            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
+            probs = torch.softmax(logits, dim=-1)
+            next_token_id = torch.multinomial(probs, num_samples=1)
+
+            return next_token_id, cache_increase
 
     def to(self, *args, **kargs) -> Self:
         self.model.to(*args, **kargs)
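The helper centralizes standard temperature, top-k and top-p (nucleus) filtering, and generate_iterator now returns a sampled token id instead of raw logits. A minimal standalone sketch of the same temperature/top-k filtering on a dummy logits row (values are illustrative only; the real code path is the apply_sampling_strategies call inside generate_iterator above):

    import torch

    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])

    # Temperature below 1.0 sharpens the distribution, above 1.0 flattens it.
    scaled = logits / 0.7

    # top_k = 2: keep only the two largest logits, push the rest to -inf.
    kth_value = torch.topk(scaled, 2, dim=-1)[0][..., -1, None]
    scaled = scaled.masked_fill(scaled < kth_value, float("-inf"))

    # Sampling can then only pick from the surviving candidates.
    probs = torch.softmax(scaled, dim=-1)
    next_id = torch.multinomial(probs, num_samples=1)

Folding this into GeneratorCore means every generator shares one sampling path instead of repeating the filter / softmax / multinomial sequence per class.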
@@ -95,3 +152,71 @@ class EmbeddingEncoderCore:
     def to(self, *args, **kargs) -> Self:
         self.model.to(*args, **kargs)
         return self
+
+
+class KVCacheManager:
+    def __init__(
+        self,
+        num_layers: int,
+        batch_size: int,
+        max_len: int,
+        num_heads: int,
+        head_dim: int,
+        device: torch.device = "cuda",
+        dtype: torch.dtype = torch.bfloat16
+    ):
+        self.num_layers = num_layers
+        self.batch_size = batch_size
+        self.max_len = max_len
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.device = device
+        self.dtype = dtype
+
+        self._kv_cache: List[Tuple[Tensor, Tensor]] = None
+        self._seq_mask: Tensor = None
+        self._initialize()
+
+    def _initialize(self):
+        self._kv_cache = []
+        for _ in range(self.num_layers):
+            k_cache = torch.zeros(
+                (self.batch_size, self.max_len, self.num_heads, self.head_dim),
+                device=self.device, dtype=self.dtype
+            )
+            v_cache = torch.zeros(
+                (self.batch_size, self.max_len, self.num_heads, self.head_dim),
+                device=self.device, dtype=self.dtype
+            )
+            self._kv_cache.append((k_cache, v_cache))
+
+        self._seq_mask = torch.ones(
+            (self.batch_size, self.max_len),
+            device=self.device, dtype=torch.bool
+        )
+
+    def update(self, active_mask: Tensor):
+        for i in range(self.num_layers):
+            k_cache, v_cache = self._kv_cache[i]
+            new_k_cache, new_v_cache = k_cache[active_mask], v_cache[active_mask]
+            self._kv_cache[i] = (new_k_cache, new_v_cache)
+
+        self._seq_mask = self._seq_mask[active_mask]
+
+    def reset(self, full_reset=False):
+        if full_reset:
+            self._kv_cache = None
+            self._seq_mask = None
+        else:
+            self._initialize()
+
+    def set_seq_mask(self, input_ids: Tensor, pad_id: int):
+        batch_size, seq_len = input_ids.shape
+        bool_mask = (input_ids != pad_id)
+        self._seq_mask[: batch_size, : seq_len] = bool_mask
+
+    def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]:
+        return self._kv_cache
+
+    def get_seq_mask(self) -> Tensor:
+        return self._seq_mask
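The hunks below belong to a second file, the one defining the generator classes; it now imports these helpers from khaosz.inference.core instead of keeping local, byte-identical copies (removed further down). A rough sketch of how KVCacheManager appears intended to be driven, based only on the methods above; the dimensions are placeholder values, not the project's real configuration, and the example assumes the khaosz package is importable:

    import torch
    from khaosz.inference.core import KVCacheManager

    # Placeholder sizes; a real caller would take these from the model config.
    cache_manager = KVCacheManager(
        num_layers=12, batch_size=4, max_len=1024,
        num_heads=8, head_dim=64, device="cpu", dtype=torch.float32
    )

    # Mark padding positions of a padded prompt batch as invalid attention targets.
    input_ids = torch.randint(5, 100, (4, 16))
    cache_manager.set_seq_mask(input_ids, pad_id=0)

    kv_caches = cache_manager.get_kvcache()   # per-layer (k_cache, v_cache) buffers
    attn_mask = cache_manager.get_seq_mask()  # (batch_size, max_len) bool mask

    # Drop finished sequences: keep only the rows selected by a boolean mask.
    active_mask = torch.tensor([True, True, False, True])
    cache_manager.update(active_mask)

    # Re-zero the buffers for the next request, or release them entirely.
    cache_manager.reset()
    cache_manager.reset(full_reset=True)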
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 from typing import List, Tuple, Union, Optional, Generator
-from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore
+from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
 from khaosz.config.param_config import ModelParameter
 
 
@@ -51,125 +51,6 @@ def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> Li
     return new_ids_list
 
 
-def apply_sampling_strategies(
-    logits: Tensor,
-    temperature: float,
-    top_k: int,
-    top_p: float,
-    filter_value: float = -float("inf")
-) -> Tensor:
-    """
-    Apply sampling strategies to the logits tensor.
-
-    Args:
-        logits (Tensor): The logits tensor.
-        temperature (float): The temperature parameter.
-        top_k (int): The top-k parameter.
-        top_p (float): The top-p parameter.
-        filter_value (float, optional): The filter value. Defaults to -float("inf").
-
-    Returns:
-        Tensor: The sampled logits tensor.
-    """
-    if temperature != 1.0:
-        logits = logits / temperature
-
-    if top_k > 0:
-        top_k = min(top_k, logits.size(-1))
-        indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-
-    if top_p < 1.0:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-
-        sorted_indices_to_remove = cumulative_probs > top_p
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-
-        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
-        indices_to_remove.scatter_(
-            dim=1,
-            index=sorted_indices,
-            src=sorted_indices_to_remove
-        )
-        logits[indices_to_remove] = filter_value
-
-    return logits
-
-
-class KVCacheManager:
-    def __init__(
-        self,
-        num_layers: int,
-        batch_size: int,
-        max_len: int,
-        num_heads: int,
-        head_dim: int,
-        device: torch.device = "cuda",
-        dtype: torch.dtype = torch.bfloat16
-    ):
-        self.num_layers = num_layers
-        self.batch_size = batch_size
-        self.max_len = max_len
-        self.num_heads = num_heads
-        self.head_dim = head_dim
-        self.device = device
-        self.dtype = dtype
-
-        self._kv_cache: List[Tuple[Tensor, Tensor]] = None
-        self._seq_mask: Tensor = None
-        self._initialize()
-
-    def _initialize(self):
-        self._kv_cache = []
-        for _ in range(self.num_layers):
-            k_cache = torch.zeros(
-                (self.batch_size, self.max_len, self.num_heads, self.head_dim),
-                device=self.device, dtype=self.dtype
-            )
-            v_cache = torch.zeros(
-                (self.batch_size, self.max_len, self.num_heads, self.head_dim),
-                device=self.device, dtype=self.dtype
-            )
-            self._kv_cache.append((k_cache, v_cache))
-
-        self._seq_mask = torch.ones(
-            (self.batch_size, self.max_len),
-            device=self.device, dtype=torch.bool
-        )
-
-    def update(self, active_mask: Tensor):
-        for i in range(self.num_layers):
-            k_cache, v_cache = self._kv_cache[i]
-            new_k_cache, new_v_cache = k_cache[active_mask], v_cache[active_mask]
-            self._kv_cache[i] = (new_k_cache, new_v_cache)
-
-        self._seq_mask = self._seq_mask[active_mask]
-
-    def reset(self, full_reset=False):
-        if full_reset:
-            self._kv_cache = None
-            self._seq_mask = None
-        else:
-            self._initialize()
-
-    def set_seq_mask(self, input_ids: Tensor, pad_id: int):
-        batch_size, seq_len = input_ids.shape
-        bool_mask = (input_ids != pad_id)
-        self._seq_mask[: batch_size, : seq_len] = bool_mask
-
-    def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]:
-        return self._kv_cache
-
-    def get_seq_mask(self) -> Tensor:
-        return self._seq_mask
-
-
-
 class TextGenerator(GeneratorCore):
     def __init__(self, parameter: ModelParameter):
@@ -202,17 +83,11 @@ class TextGenerator(GeneratorCore):
         start_cache_pos = len(ids)
         cur_cache_pos = 0
         self.model.eval()
+        kv_caches = cache_manager.get_kvcache()
 
-        while len(ids) < self.config.m_len:
-            kv_caches = cache_manager.get_kvcache()
-            logits, cache_increase = self.compute_logits(
-                input_ids,
-                kv_caches=kv_caches,
-                start_pos=cur_cache_pos
-            )
-            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
-            probs = torch.softmax(logits, dim=-1)
-            next_token_id = torch.multinomial(probs, num_samples=1)
-
+        for _ in range(len(ids), self.config.m_len):
+            next_token_id, cache_increase = self.generate_iterator(
+                input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
+
             input_ids = next_token_id
             ids.append(next_token_id.item())
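The same rewrite is applied to ChatGenerator and StreamGenerator in the hunks that follow: per-step filtering and sampling move into generate_iterator, the cache handles are fetched once before the loop, and a bounded for-loop replaces the open-ended while. A hedged sketch of the resulting driving loop, written as a hypothetical helper; stop_ids stands in for the stop condition, which these hunks do not show:

    from typing import List, Set
    from torch import Tensor

    def drive_decode_loop(generator, cache_manager, input_ids: Tensor, ids: List[int],
                          temperature: float, top_k: int, top_p: float,
                          stop_ids: Set[int]) -> List[int]:
        # Illustrative wrapper around the new per-step contract of generate_iterator.
        cur_cache_pos = 0
        kv_caches = cache_manager.get_kvcache()
        for _ in range(len(ids), generator.config.m_len):
            next_token_id, cache_increase = generator.generate_iterator(
                input_ids, temperature, top_k, top_p,
                kv_caches=kv_caches, start_pos=cur_cache_pos
            )
            cur_cache_pos += cache_increase   # cache advanced by the tokens just fed in
            input_ids = next_token_id         # only the sampled token is fed next step
            ids.append(next_token_id.item())
            if ids[-1] in stop_ids:
                break
        return ids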
@@ -226,7 +101,6 @@ class TextGenerator(GeneratorCore):
         return response
 
 
-
 class ChatGenerator(GeneratorCore):
     def __init__(self, parameter: ModelParameter):
         super().__init__(parameter)
@@ -263,18 +137,11 @@ class ChatGenerator(GeneratorCore):
         start_cache_pos = len(ids)
         cur_cache_pos = 0
         self.model.eval()
+        kv_caches = cache_manager.get_kvcache()
 
-        while len(ids) < self.config.m_len:
-            kv_caches = cache_manager.get_kvcache()
-            logits, cache_increase = self.compute_logits(
-                input_ids,
-                kv_caches=kv_caches,
-                start_pos=cur_cache_pos
-            )
-            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
-            probs = torch.softmax(logits, dim=-1)
-            next_token_id = torch.multinomial(probs, num_samples=1)
-
+        for _ in range(len(ids), self.config.m_len):
+            next_token_id, cache_increase = self.generate_iterator(
+                input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
+
             input_ids = next_token_id
             ids.append(next_token_id.item())
@@ -325,18 +192,11 @@ class StreamGenerator(GeneratorCore):
         start_cache_pos = len(ids)
         cur_cache_pos = 0
         self.model.eval()
+        kv_caches = cache_manager.get_kvcache()
 
-        while len(ids) < self.config.m_len:
-            kv_caches = cache_manager.get_kvcache()
-            logits, cache_increase = self.compute_logits(
-                input_ids,
-                kv_caches=kv_caches,
-                start_pos=cur_cache_pos
-            )
-            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
-            probs = torch.softmax(logits, dim=-1)
-            next_token_id = torch.multinomial(probs, num_samples=1)
-
+        for _ in range(len(ids), self.config.m_len):
+            next_token_id, cache_increase = self.generate_iterator(
+                input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
+
             input_ids = next_token_id
             ids.append(next_token_id.item())
@@ -397,18 +257,10 @@ class BatchGenerator(GeneratorCore):
             kv_caches = cache_manager.get_kvcache()
             attn_mask = cache_manager.get_seq_mask()
 
-            logits, cache_increase = self.compute_logits(
-                input_tensor,
-                attn_mask=attn_mask,
-                kv_caches=kv_caches,
-                start_pos=cur_cache_pos
-            )
+            next_token_id, cache_increase = self.generate_iterator(
+                input_tensor, temperature, top_k, top_p, attn_mask=attn_mask, kv_caches=kv_caches, start_pos=cur_cache_pos)
 
             cur_cache_pos += cache_increase
-            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
-            probs = torch.softmax(logits, dim=-1)
-            next_token_id = torch.multinomial(probs, num_samples=1)
 
             active_mask = []
             c_ids = 0
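In the batched path the only additional input is the padding mask from the cache manager, forwarded as attn_mask. After each step the generator rebuilds active_mask and hands it to cache_manager.update() so finished rows drop out of every layer's cache. A small sketch of that pruning decision, using a hypothetical eos_id that is not taken from the diff:

    import torch

    eos_id = 2  # hypothetical stop token id
    next_token_id = torch.tensor([[17], [eos_id], [42]])  # (batch, 1) sampled ids

    # Rows that have not just emitted the stop token stay active.
    active_mask = next_token_id.squeeze(-1) != eos_id     # tensor([True, False, True])

    # KVCacheManager.update() indexes every layer's k/v cache and the sequence mask
    # with this boolean row mask, shrinking the batch to the surviving sequences:
    # cache_manager.update(active_mask)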
@@ -437,7 +289,6 @@ class BatchGenerator(GeneratorCore):
         return responses
 
 
-
 class RetrievalGenerator(GeneratorCore):
     def __init__(self, retriever_parameter: ModelParameter):
         super().__init__(retriever_parameter)