AstrAI/khaosz/inference/core.py

97 lines
3.7 KiB
Python

import torch
from torch import Tensor
from typing import List, Tuple, Union, Optional, Generator, Self
from khaosz.config.param_config import ModelParameter
class GeneratorCore:
def __init__(self, parameter: ModelParameter):
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
def compute_logits(
self,
input_ids: Tensor,
attn_mask: Optional[Tensor] = None,
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
start_pos: int = 0
) -> Tuple[Tensor, int]:
with torch.inference_mode():
outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
logits = outputs["logits"][:, -1, :]
cache_increase = input_ids.size(-1)
return logits, cache_increase
def to(self, *args, **kargs) -> Self:
self.model.to(*args, **kargs)
return self
class EmbeddingEncoderCore:
def __init__(self, parameter: ModelParameter):
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
with_batch = isinstance(sentence, list)
ids = self.tokenizer.encode(sentence)
batch_ids = ids if with_batch else [ids]
max_model_len = self.config.m_len
all_fragments = []
fragment_origin_idx = []
for i, seq in enumerate(batch_ids):
if len(seq) > max_model_len:
fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
all_fragments.extend(fragments)
fragment_origin_idx.extend([i] * len(fragments))
else:
all_fragments.append(seq)
fragment_origin_idx.append(i)
#if empty fragments
if not all_fragments or not ids:
return [] if with_batch else torch.tensor([])
device = next(self.model.parameters()).device
max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
padded_ids = []
masks = []
for seq in all_fragments:
pad_len = max_len - len(seq)
padded_seq = seq + [self.tokenizer.pad_id] * pad_len
mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
padded_ids.append(padded_seq)
masks.append(mask)
input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
with torch.inference_mode():
outputs = self.model(input_tensor, seq_mask)["hidden_states"]
# [num_fragments, seq_len, hidden_size]
fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
sentence_embs: List[Tensor] = []
for i in range(len(batch_ids)):
indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
if indices is not None:
sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size]
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1]
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
sentence_embs.append(emb.flatten())
if with_batch:
return [emb.flatten() for emb in sentence_embs]
else:
return sentence_embs[0].flatten()
def to(self, *args, **kargs) -> Self:
self.model.to(*args, **kargs)
return self