AstrAI/khaosz/inference/core.py

import torch
import torch.nn as nn
from torch import Tensor
from contextlib import contextmanager
from typing import Any, Callable, List, Tuple, Union, Optional, Self
from khaosz.config import ModelParameter, ModelConfig


def apply_sampling_strategies(
    logits: Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    filter_value: float = -float("inf"),
) -> Tensor:
    """
    Apply temperature scaling, top-k, and top-p (nucleus) filtering to a logits tensor.

    Args:
        logits (Tensor): Logits of shape [batch, vocab_size].
        temperature (float): Temperature applied before filtering (1.0 disables).
        top_k (int): Keep only the k highest-scoring tokens (0 disables).
        top_p (float): Keep the smallest set of tokens whose cumulative
            probability exceeds top_p (1.0 disables).
        filter_value (float, optional): Value assigned to removed logits.
            Defaults to -float("inf").

    Returns:
        Tensor: The filtered logits tensor.
    """
    if temperature != 1.0:
        logits = logits / temperature

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token that crosses the threshold is
        # kept, and never remove the most probable token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
        indices_to_remove.scatter_(
            dim=1, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value

    return logits
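

# Hedged usage sketch (illustrative, not part of the module's public surface):
# one decoding step that filters a [batch, vocab_size] logits tensor with the
# helper above and samples one token id per row. The hyperparameter values are
# placeholders.
def _example_sampling_step(logits: Tensor) -> Tensor:
    # clone() because the helper mutates its input in place
    filtered = apply_sampling_strategies(
        logits.clone(), temperature=0.8, top_k=50, top_p=0.95
    )
    probs = torch.softmax(filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # [batch, 1]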


@contextmanager
def disable_random_init():
    """Temporarily replace the common nn.init initializers with no-ops.

    Useful when constructing a model whose weights are about to be overwritten
    by a checkpoint, so random initialization is wasted work.
    """
    init_functions = [
        "xavier_normal_",
        "xavier_uniform_",
        "kaiming_normal_",
        "kaiming_uniform_",
        "zeros_",
        "ones_",
        "constant_",
        "normal_",
        "uniform_",
    ]
    original_funcs = {}
    for name in init_functions:
        if hasattr(nn.init, name):
            original_funcs[name] = getattr(nn.init, name)
            # No-op replacement that still returns the tensor, matching the
            # in-place initializers it stands in for.
            setattr(nn.init, name, lambda tensor, *args, **kwargs: tensor)
    try:
        yield
    finally:
        for name, orig_func in original_funcs.items():
            setattr(nn.init, name, orig_func)
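

# Hedged usage sketch: build a module under disable_random_init() when its
# parameters will immediately be replaced by a checkpoint. The checkpoint path
# and the Linear layer are placeholders for illustration only.
def _example_skip_init(checkpoint_path: str) -> nn.Module:
    with disable_random_init():
        module = nn.Linear(1024, 1024)  # created without running the initializers
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    module.load_state_dict(state_dict)  # real weights come from the checkpoint
    return module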


class GeneratorCore:
    def __init__(self, parameter: ModelParameter):
        self.model = parameter.model
        self.tokenizer = parameter.tokenizer
        self.config = parameter.config

    def generate_iterator(
        self,
        input_ids: Tensor,
        temperature: float,
        top_k: int,
        top_p: float,
        attn_mask: Optional[Tensor] = None,
        kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
        start_pos: int = 0,
    ) -> Tuple[Tensor, int]:
        with torch.inference_mode():
            outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
            logits = outputs["logits"][:, -1, :]
            cache_increase = input_ids.size(-1)
            logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
            probs = torch.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)
        return next_token_id, cache_increase

    def generate_loop(
        self,
        input_ids: Tensor,
        ids: List[int],
        temperature: float,
        top_k: int,
        top_p: float,
        attn_mask: Optional[Tensor] = None,
        kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
        start_pos: int = 0,
        callback: Optional[Callable[..., Any]] = None,
    ) -> List[int]:
        cur_cache_pos = start_pos
        for _ in range(len(ids), self.config.max_len):
            next_token_id, cache_increase = self.generate_iterator(
                input_ids,
                temperature,
                top_k,
                top_p,
                attn_mask,
                kv_caches,
                cur_cache_pos,
            )
            input_ids = next_token_id
            ids.append(next_token_id.item())
            cur_cache_pos += cache_increase

            if callback:
                callback(next_token_id.item(), ids.copy())
            if next_token_id.item() in self.tokenizer.stop_ids:
                break

        return ids

    def to(self, *args, **kwargs) -> Self:
        self.model.to(*args, **kwargs)
        return self
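

# Hedged usage sketch: drive GeneratorCore end to end for a single prompt. It
# assumes the tokenizer exposes encode() returning a list of ids (as used by
# EmbeddingEncoderCore below) and a decode() method; decode() is an assumption
# made for illustration. KV-cache and attention-mask wiring is omitted, so this
# mirrors the simplest call shape rather than the most efficient one.
def _example_generate(parameter: ModelParameter, prompt: str) -> str:
    core = GeneratorCore(parameter)
    ids = list(core.tokenizer.encode(prompt))
    device = next(core.model.parameters()).device
    input_ids = torch.tensor([ids], device=device, dtype=torch.long)
    out_ids = core.generate_loop(input_ids, ids, temperature=0.8, top_k=50, top_p=0.95)
    return core.tokenizer.decode(out_ids)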


class EmbeddingEncoderCore:
    def __init__(self, parameter: ModelParameter):
        self.model = parameter.model
        self.tokenizer = parameter.tokenizer
        self.config = parameter.config

    def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
        with_batch = isinstance(sentence, list)
        ids = self.tokenizer.encode(sentence)
        batch_ids = ids if with_batch else [ids]
        max_model_len = self.config.max_len

        # Split sequences longer than the model's context window into fragments,
        # remembering which input sentence each fragment came from.
        all_fragments = []
        fragment_origin_idx = []
        for i, seq in enumerate(batch_ids):
            if len(seq) > max_model_len:
                fragments = [
                    seq[j : j + max_model_len]
                    for j in range(0, len(seq), max_model_len)
                ]
                all_fragments.extend(fragments)
                fragment_origin_idx.extend([i] * len(fragments))
            else:
                all_fragments.append(seq)
                fragment_origin_idx.append(i)

        # Nothing to encode: return an empty result matching the input type.
        if not all_fragments or not ids:
            return [] if with_batch else torch.tensor([])

        device = next(self.model.parameters()).device
        max_len = min(max(len(seq) for seq in all_fragments), max_model_len)

        # Pad every fragment to the same length and build the padding mask.
        padded_ids = []
        masks = []
        for seq in all_fragments:
            pad_len = max_len - len(seq)
            padded_seq = seq + [self.tokenizer.pad_id] * pad_len
            mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
            padded_ids.append(padded_seq)
            masks.append(mask)

        input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
        seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)

        with torch.inference_mode():
            outputs = self.model(input_tensor, seq_mask)["hidden_states"]
            # Zero out padded positions: [num_fragments, seq_len, hidden_size]
            fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))

        # Mean-pool each fragment over its valid tokens, then combine the
        # fragments belonging to the same original sentence.
        sentence_embs: List[Tensor] = []
        for i in range(len(batch_ids)):
            indices = [
                idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i
            ]
            if indices:
                sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1)  # [frags, hidden_size]
                length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1)  # [frags, 1]
                emb = torch.sum(sum_frags / length, dim=0)  # [hidden_size]
                sentence_embs.append(emb.flatten())

        if with_batch:
            return [emb.flatten() for emb in sentence_embs]
        else:
            return sentence_embs[0].flatten()

    def to(self, *args, **kwargs) -> Self:
        self.model.to(*args, **kwargs)
        return self
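

# Hedged usage sketch: encode two sentences and compare them with cosine
# similarity. Each embedding returned by encode() above is a flat [hidden_size]
# tensor, so dim=0 is the comparison axis.
def _example_similarity(parameter: ModelParameter, a: str, b: str) -> float:
    encoder = EmbeddingEncoderCore(parameter)
    emb_a, emb_b = encoder.encode([a, b])
    return torch.nn.functional.cosine_similarity(emb_a, emb_b, dim=0).item()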


class KVCacheManager:
    def __init__(
        self,
        config: ModelConfig,
        batch_size: int,
        device: Union[str, torch.device] = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.batch_size = batch_size
        self.device = device
        self.dtype = dtype
        self.num_layers = config.n_layers
        self.max_len = config.max_len
        self.num_heads = config.n_kv_heads
        self.head_dim = config.dim // config.n_heads

        self._kv_cache: Optional[Tuple[Tensor, Tensor]] = None
        self._seq_mask: Optional[Tensor] = None
        self._initialize()

    def _initialize(self):
        k_cache = torch.empty(
            (
                self.batch_size,
                self.max_len,
                self.num_layers,
                self.num_heads,
                self.head_dim,
            ),
            device=self.device,
            dtype=self.dtype,
        )
        v_cache = torch.empty(
            (
                self.batch_size,
                self.max_len,
                self.num_layers,
                self.num_heads,
                self.head_dim,
            ),
            device=self.device,
            dtype=self.dtype,
        )
        self._kv_cache = (k_cache, v_cache)
        self._seq_mask = torch.ones(
            (self.batch_size, self.max_len), device=self.device, dtype=torch.bool
        )

    def update(self, active_mask: Tensor):
        # Keep only the cache rows whose sequences are still active.
        k_cache, v_cache = self._kv_cache
        self._kv_cache = (k_cache[active_mask], v_cache[active_mask])
        self._seq_mask = self._seq_mask[active_mask]

    def reset(self, full_reset=False):
        if full_reset:
            self._kv_cache = None
            self._seq_mask = None
        else:
            self._initialize()

    def set_seq_mask(self, input_ids: Tensor, pad_id: int):
        batch_size, seq_len = input_ids.shape
        bool_mask = input_ids != pad_id
        self._seq_mask[:batch_size, :seq_len] = bool_mask

    def get_kvcache(self) -> Tuple[Tensor, Tensor]:
        return self._kv_cache

    def get_seq_mask(self) -> Tensor:
        return self._seq_mask
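

# Hedged usage sketch: allocate caches for a batch, derive the padding mask from
# the prompt ids, and drop finished rows between decoding steps. The model
# forward calls are omitted; `still_active` stands in for whatever end-of-sequence
# bookkeeping the caller maintains.
def _example_kv_cache_lifecycle(config: ModelConfig, input_ids: Tensor, pad_id: int) -> None:
    manager = KVCacheManager(config, batch_size=input_ids.size(0), device=input_ids.device)
    manager.set_seq_mask(input_ids, pad_id)
    k_cache, v_cache = manager.get_kvcache()
    # ... run decoding steps that read from / write into k_cache and v_cache ...
    still_active = torch.ones(input_ids.size(0), dtype=torch.bool, device=input_ids.device)
    manager.update(still_active)    # keep only rows that are still generating
    manager.reset(full_reset=True)  # release the buffers when decoding is done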