240 lines
8.3 KiB
Python
240 lines
8.3 KiB
Python
import torch
|
|
from torch import Tensor
|
|
from typing import Any, Callable, List, Tuple, Union, Optional, Self
|
|
from khaosz.config import ModelParameter, TransformerConfig
|
|
|
|
|
|
def apply_sampling_strategies(
    logits: Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    filter_value: float = -float("inf")
) -> Tensor:
    """
    Apply temperature scaling, top-k and top-p (nucleus) filtering to logits.

    All operations act on the last dimension, so any number of leading batch
    dimensions is supported. The input tensor is never modified; a new tensor
    is returned.

    Args:
        logits (Tensor): The logits tensor; the vocabulary is the last dimension.
        temperature (float): Divisor applied to the logits when it differs
            from 1.0. Must be positive (0 would divide by zero).
        top_k (int): Keep only logits >= the k-th largest value per row.
            Values <= 0 disable top-k filtering.
        top_p (float): Keep the smallest prefix of tokens (sorted by
            probability) whose cumulative probability exceeds ``top_p``; the
            most probable token is always kept. Values >= 1.0 disable it.
        filter_value (float, optional): The value written into filtered
            positions. Defaults to -float("inf").

    Returns:
        Tensor: The filtered logits, same shape as the input.
    """
    if temperature != 1.0:
        logits = logits / temperature

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        # k-th largest value of every row; anything strictly below it is dropped.
        kth_value = torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
        # Out-of-place: never write into the caller's tensor.
        logits = logits.masked_fill(logits < kth_value, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the token that crosses the threshold is kept, and
        # always keep the most probable token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order along the
        # last dimension (the same dimension the sort used).
        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
        indices_to_remove.scatter_(
            dim=-1,
            index=sorted_indices,
            src=sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, filter_value)

    return logits
|
|
|
|
|
|
class GeneratorCore:
    """Autoregressive token generation driven by ``parameter.model``."""

    def __init__(self, parameter: ModelParameter):
        self.model = parameter.model
        self.tokenizer = parameter.tokenizer
        self.config = parameter.config

    def generate_iterator(
        self,
        input_ids: Tensor,
        temperature: float,
        top_k: int,
        top_p: float,
        attn_mask: Optional[Tensor] = None,
        kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
        start_pos: int = 0
    ) -> Tuple[Tensor, int]:
        """
        Run one forward pass and sample the next token.

        Args:
            input_ids (Tensor): Token ids fed to the model for this step.
            temperature (float): Passed to ``apply_sampling_strategies``.
            top_k (int): Passed to ``apply_sampling_strategies``.
            top_p (float): Passed to ``apply_sampling_strategies``.
            attn_mask (Optional[Tensor]): Forwarded to the model unchanged.
            kv_caches (Optional[List[Tuple[Tensor, Tensor]]]): Forwarded to the
                model unchanged.
            start_pos (int): Cache position forwarded to the model.

        Returns:
            Tuple[Tensor, int]: ``(next_token_id, cache_increase)``. The first
            item comes from ``torch.multinomial(probs, num_samples=1)`` on the
            last-position logits. ``cache_increase`` is ``input_ids.size(-1)``,
            the number of positions consumed by this step.
        """
        with torch.inference_mode():
            outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
            logits = outputs["logits"][:, -1, :]
            # Sampling stays inside inference mode: ``logits`` is an inference
            # tensor, and in-place updates to an inference tensor outside
            # inference mode raise a RuntimeError (apply_sampling_strategies
            # may filter in place).
            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
            probs = torch.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)

        return next_token_id, input_ids.size(-1)

    def to(self, *args, **kwargs) -> Self:
        """Move the underlying model (``self.model.to``) and return ``self``."""
        self.model.to(*args, **kwargs)
        return self

    def generate_loop(
        self,
        input_ids: Tensor,
        ids: List[int],
        temperature: float,
        top_k: int,
        top_p: float,
        attn_mask: Optional[Tensor] = None,
        kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
        start_pos: int = 0,
        callback: Optional[Callable[..., Any]] = None
    ) -> List[int]:
        """
        Sample tokens one at a time, appending each to ``ids``.

        Generation stops after a token contained in ``self.tokenizer.stop_ids``
        (that token is still appended) or once ``len(ids)`` reaches
        ``self.config.m_len``. Each sampled token is converted with ``.item()``,
        so the sampled tensor must hold a single element (batch size 1).

        Args:
            input_ids (Tensor): Ids for the first forward pass; afterwards only
                the newly sampled token is fed back.
            ids (List[int]): Token history; extended in place and returned.
            temperature, top_k, top_p: Sampling parameters.
            attn_mask, kv_caches: Forwarded to ``generate_iterator`` every step.
            start_pos (int): Initial cache position; advanced by each step's
                ``cache_increase``.
            callback (Optional[Callable]): Called after every step as
                ``callback(token_id, copy_of_ids)``.

        Returns:
            List[int]: The same ``ids`` list object, now extended.
        """
        cur_cache_pos = start_pos

        for _ in range(len(ids), self.config.m_len):
            next_token_id, cache_increase = self.generate_iterator(
                input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos)

            # One host/device sync per step instead of three ``.item()`` calls.
            token = next_token_id.item()

            input_ids = next_token_id
            ids.append(token)
            cur_cache_pos += cache_increase

            if callback is not None:
                callback(token, ids.copy())

            if token in self.tokenizer.stop_ids:
                break

        return ids
|
|
|
|
|
|
class EmbeddingEncoderCore:
|
|
def __init__(self, parameter: ModelParameter):
|
|
self.model = parameter.model
|
|
self.tokenizer = parameter.tokenizer
|
|
self.config = parameter.config
|
|
|
|
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
|
with_batch = isinstance(sentence, list)
|
|
ids = self.tokenizer.encode(sentence)
|
|
batch_ids = ids if with_batch else [ids]
|
|
max_model_len = self.config.m_len
|
|
|
|
all_fragments = []
|
|
fragment_origin_idx = []
|
|
|
|
for i, seq in enumerate(batch_ids):
|
|
if len(seq) > max_model_len:
|
|
fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
|
|
all_fragments.extend(fragments)
|
|
fragment_origin_idx.extend([i] * len(fragments))
|
|
else:
|
|
all_fragments.append(seq)
|
|
fragment_origin_idx.append(i)
|
|
|
|
#if empty fragments
|
|
if not all_fragments or not ids:
|
|
return [] if with_batch else torch.tensor([])
|
|
|
|
device = next(self.model.parameters()).device
|
|
max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
|
|
|
|
padded_ids = []
|
|
masks = []
|
|
for seq in all_fragments:
|
|
pad_len = max_len - len(seq)
|
|
padded_seq = seq + [self.tokenizer.pad_id] * pad_len
|
|
mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
|
|
padded_ids.append(padded_seq)
|
|
masks.append(mask)
|
|
|
|
input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
|
|
seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
|
|
|
|
with torch.inference_mode():
|
|
outputs = self.model(input_tensor, seq_mask)["hidden_states"]
|
|
# [num_fragments, seq_len, hidden_size]
|
|
fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
|
|
|
|
sentence_embs: List[Tensor] = []
|
|
for i in range(len(batch_ids)):
|
|
indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
|
|
if indices is not None:
|
|
sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size]
|
|
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1]
|
|
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
|
|
sentence_embs.append(emb.flatten())
|
|
|
|
if with_batch:
|
|
return [emb.flatten() for emb in sentence_embs]
|
|
else:
|
|
return sentence_embs[0].flatten()
|
|
|
|
def to(self, *args, **kargs) -> Self:
|
|
self.model.to(*args, **kargs)
|
|
return self
|
|
|
|
|
|
class KVCacheManager:
    """
    Owns a pre-allocated key/value cache pair and a boolean sequence mask.

    Both cache tensors have shape
    ``(batch_size, n_layer, m_len, n_kvhead, n_dim // n_head)`` and start
    zeroed. The mask has shape ``(batch_size, m_len)`` and starts all ``True``.
    """

    def __init__(
        self,
        config: TransformerConfig,
        batch_size: int,
        device: torch.device = "cuda",
        dtype: torch.dtype = torch.bfloat16
    ):
        self.batch_size = batch_size
        self.device = device
        self.dtype = dtype
        self.num_layers = config.n_layer
        self.max_len = config.m_len
        self.num_heads = config.n_kvhead
        self.head_dim = config.n_dim // config.n_head

        # Filled by _initialize(); set to None again by reset(full_reset=True).
        self._kv_cache: Optional[Tuple[Tensor, Tensor]] = None
        self._seq_mask: Optional[Tensor] = None
        self._initialize()

    def _initialize(self):
        """Allocate zeroed key/value caches and an all-True mask at ``self.batch_size``."""
        cache_shape = (self.batch_size, self.num_layers, self.max_len, self.num_heads, self.head_dim)
        # Keys first, then values.
        self._kv_cache = tuple(
            torch.zeros(cache_shape, device=self.device, dtype=self.dtype)
            for _ in range(2)
        )
        self._seq_mask = torch.ones((self.batch_size, self.max_len), device=self.device, dtype=torch.bool)

    def update(self, active_mask: Tensor):
        """Keep only the batch rows selected by ``active_mask`` in both caches and in the mask.

        ``self.batch_size`` is left untouched, so a later ``reset()`` allocates
        at the original batch size again.
        """
        self._kv_cache = tuple(cache[active_mask] for cache in self._kv_cache)
        self._seq_mask = self._seq_mask[active_mask]

    def reset(self, full_reset=False):
        """Re-allocate fresh buffers, or with ``full_reset=True`` drop them (set to ``None``)."""
        if not full_reset:
            self._initialize()
            return
        self._kv_cache = None
        self._seq_mask = None

    def set_seq_mask(self, input_ids: Tensor, pad_id: int):
        """Write ``input_ids != pad_id`` into the top-left corner of the mask; other entries are left unchanged."""
        rows, cols = input_ids.shape
        self._seq_mask[:rows, :cols] = input_ids != pad_id

    def get_kvcache(self) -> Tuple[Tensor, Tensor]:
        """Return the ``(keys, values)`` cache tuple (``None`` after a full reset)."""
        return self._kv_cache

    def get_seq_mask(self) -> Tensor:
        """Return the sequence mask (``None`` after a full reset)."""
        return self._seq_mask