# AstrAI/khaosz/inference/generator.py


import torch
from dataclasses import dataclass
from torch import Tensor
from typing import List, Tuple, Union, Optional, Generator
from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
from khaosz.config.param_config import ModelParameter

HistoryType = List[Tuple[str, str]]


def build_prompt(
    query: str,
    init_prompt: Optional[str] = None,
    history: Optional[List[Tuple[str, str]]] = None
) -> str:
    """
    Build a prompt in ChatML format from the query, optional system prompt, and conversation history.

    Args:
        query (str): The current user query.
        init_prompt (Optional[str]): Optional system prompt placed at the start of the prompt.
        history (Optional[List[Tuple[str, str]]]): Previous (user, assistant) message pairs.

    Returns:
        str: The prompt string in ChatML format.
    """
    prompt = f"<|im_start|>system\n{init_prompt}<|im_end|>\n" if init_prompt else ""
    # Convert (user, assistant) tuples into ChatML turns.
    if history:
        for user_msg, assistant_msg in history:
            prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
            prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
    prompt += f"<|im_start|>user\n{query}<|im_end|>\n"
    prompt += "<|im_start|>assistant\n"
    return prompt
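# Illustrative call (values are examples only): the system prompt, history, and query
# are wrapped into ChatML turns, and the assistant turn is left open for generation.
#     >>> build_prompt("Hi", init_prompt="You are helpful.")
#     '<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n'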
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
    """
    Left-pad a list of token-id sequences to a common length.

    Args:
        ids_list (List[List[int]]): A list of token-id sequences.
        pad_id (int): The id used to pad shorter sequences.

    Returns:
        Tuple[List[List[int]], int]: The padded sequences and the common (maximum) length.
    """
    max_ids_len = max(len(ids) for ids in ids_list)
    new_ids_list = []
    for ids in ids_list:
        pad_len = max_ids_len - len(ids)
        padded_seq = [pad_id] * pad_len + ids
        new_ids_list.append(padded_seq)
    return new_ids_list, max_ids_len
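# Illustrative call (values are examples only): shorter sequences are left-padded with
# pad_id so every row reaches the length of the longest sequence.
#     >>> pad_sequence([[1, 2, 3], [4]], pad_id=0)
#     ([[1, 2, 3], [0, 0, 4]], 3)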
@dataclass
class GenerationRequest:
    top_k: int
    top_p: float
    temperature: float
    max_len: int
    query: Union[str, List[str]]
    history: Optional[Union[HistoryType, List[HistoryType]]] = None
    system_prompt: Optional[str] = None
    build_prompt: bool = True

    def __post_init__(self):
        if not isinstance(self.top_k, int) or self.top_k < 0:
            raise ValueError("top_k must be a non-negative integer")
        if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0:
            raise ValueError("top_p must be a float between 0.0 and 1.0")
        if not isinstance(self.temperature, float) or self.temperature < 0.0:
            raise ValueError("temperature must be a non-negative float")
class LoopGenerator(GeneratorCore):

    def __init__(self, parameter: ModelParameter):
        super().__init__(parameter)

    def generate(self, request: GenerationRequest) -> str:
        device = next(self.model.parameters()).device
        cache_manager = KVCacheManager(self.config, 1, device=device)
        input_args = (
            build_prompt(request.query, request.system_prompt, request.history)
            if request.build_prompt else request.query
        )
        ids = self.tokenizer.encode(input_args)
        input_ids = torch.tensor([ids], device=device, dtype=torch.long)
        start_cache_pos = len(ids)
        cur_cache_pos = 0

        self.model.eval()
        kv_caches = cache_manager.get_kvcache()
        ids = self.generate_loop(
            input_ids, ids, request.temperature, request.top_k, request.top_p,
            kv_caches=kv_caches,
            start_pos=cur_cache_pos
        )
        response = self.tokenizer.decode(ids[start_cache_pos:])
        return response
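# Usage sketch (assumes `parameter` is an already-loaded ModelParameter):
#     generator = LoopGenerator(parameter)
#     answer = generator.generate(request)  # decoded response as a single string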
class StreamGenerator(GeneratorCore):

    def __init__(self, parameter: ModelParameter):
        super().__init__(parameter)

    def generate(self, request: GenerationRequest) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
        if request.history is None:
            request.history = []
        device = next(self.model.parameters()).device
        cache_manager = KVCacheManager(self.config, 1, device=device)
        input_args = (
            build_prompt(request.query, request.system_prompt, request.history)
            if request.build_prompt else request.query
        )
        ids = self.tokenizer.encode(input_args)
        input_ids = torch.tensor([ids], device=device, dtype=torch.long)
        start_cache_pos = len(ids)
        cur_cache_pos = 0

        self.model.eval()
        kv_caches = cache_manager.get_kvcache()
        for _ in range(len(ids), self.config.max_len):
            next_token_id, cache_increase = self.generate_iterator(
                input_ids, request.temperature, request.top_k, request.top_p,
                kv_caches=kv_caches,
                start_pos=cur_cache_pos
            )
            input_ids = next_token_id
            ids.append(next_token_id.item())
            cur_cache_pos += cache_increase
            response = self.tokenizer.decode(ids[start_cache_pos:])
            yield response, request.history + [(request.query, response)]
            if next_token_id.item() in self.tokenizer.stop_ids:
                yield response + "\n", request.history + [(request.query, response)]
                break
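# Usage sketch (assumes `parameter` is an already-loaded ModelParameter); each iteration
# yields the partial response decoded so far together with the updated history:
#     streamer = StreamGenerator(parameter)
#     for partial, history in streamer.generate(request):
#         print(partial)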
class BatchGenerator(GeneratorCore):

    def __init__(self, parameter: ModelParameter):
        super().__init__(parameter)

    def generate(self, request: GenerationRequest) -> List[str]:
        batch_size = len(request.query)
        if request.history is None:
            request.history = [[] for _ in range(batch_size)]
        prompts = [
            build_prompt(query, request.system_prompt, history)
            for query, history in zip(request.query, request.history)
        ]
        ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
        ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)

        device = next(self.model.parameters()).device
        cache_manager = KVCacheManager(self.config, batch_size, device=device)
        input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
        cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)

        activate_task_mask = [True] * batch_size
        start_cache_pos = max_ids_len
        cur_cache_pos = 0

        while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0:
            kv_caches = cache_manager.get_kvcache()
            attn_mask = cache_manager.get_seq_mask()
            next_token_id, cache_increase = self.generate_iterator(
                input_tensor, request.temperature, request.top_k, request.top_p,
                attn_mask=attn_mask,
                kv_caches=kv_caches,
                start_pos=cur_cache_pos
            )
            cur_cache_pos += cache_increase

            # Append the new token for each still-active sequence and mark finished ones.
            active_mask = []
            c_ids = 0
            for i in range(batch_size):
                if activate_task_mask[i]:
                    token = next_token_id[c_ids, :].item()
                    ids_list[i].append(token)
                    c_ids += 1
                    is_active = token not in self.tokenizer.stop_ids
                    activate_task_mask[i] = is_active
                    active_mask.append(is_active)

            active_mask = torch.tensor(active_mask, device=device, dtype=torch.bool)
            cache_manager.update(active_mask)
            input_tensor = next_token_id[active_mask, :]
            max_ids_len += 1

        responses = [""] * batch_size
        for i in range(batch_size):
            responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:])
            request.history[i].append((request.query[i], responses[i]))
        return responses
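# Usage sketch (assumes `parameter` is an already-loaded ModelParameter; queries are
# illustrative): for batched generation, `query` is a list of strings and `history`,
# if provided, is a list of per-query histories.
#     batch_request = GenerationRequest(
#         top_k=50, top_p=0.9, temperature=0.8, max_len=512,
#         query=["Hello", "Summarize KV caching in one sentence."],
#     )
#     responses = BatchGenerator(parameter).generate(batch_request)  # -> List[str]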
class EmbeddingEncoder(EmbeddingEncoderCore):

    def __init__(self, parameter: ModelParameter):
        super().__init__(parameter)

    def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
        return super().encode(sentence)
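# Usage sketch (assumes `parameter` is an already-loaded ModelParameter):
#     encoder = EmbeddingEncoder(parameter)
#     single_vec = encoder.encode("a single sentence")           # -> Tensor
#     many_vecs = encoder.encode(["first text", "second text"])  # -> List[Tensor]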