AstrAI/astrai/inference/generator.py

from dataclasses import dataclass
from typing import Generator, List, Optional, Tuple, Union
import torch
from torch import Tensor
from astrai.config.param_config import ModelParameter
from astrai.factory import BaseFactory
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
from astrai.tokenizer.chat_template import HistoryType, build_prompt
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
"""
Pad a list of sequences to a fixed length.
Args:
ids_list (List[List[int]]): A list of sequences.
max_ids_len (int): The maximum length of sequences.
pad_id (int): The id to pad sequences.
Returns:
List[List[int]]: A list of padded sequences.
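
    Example:
        >>> pad_sequence([[1, 2, 3], [4]], pad_id=0)
        ([[1, 2, 3], [0, 0, 4]], 3)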
"""
max_ids_len = max(len(ids) for ids in ids_list)
new_ids_list = []
for ids in ids_list:
pad_len = max_ids_len - len(ids)
padded_seq = [pad_id] * pad_len + ids
new_ids_list.append(padded_seq)
return new_ids_list, max_ids_len
@dataclass
class GenerationRequest:
"""
Request parameters for text generation.
Attributes:
top_k: Top-k sampling parameter.
top_p: Top-p (nucleus) sampling parameter.
temperature: Sampling temperature.
max_len: Maximum generation length.
query: Input query (string or list of strings for batch).
history: Conversation history.
system_prompt: System prompt for the conversation.
stream: Whether to use streaming generation.
"""
top_k: int
top_p: float
temperature: float
max_len: int
query: Union[str, List[str]]
history: Optional[Union[HistoryType, List[HistoryType]]] = None
system_prompt: Optional[str] = None
stream: bool = False
def __post_init__(self):
if not isinstance(self.top_k, int) or self.top_k < 0:
raise ValueError("top_k must be a non-negative integer")
if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0:
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not isinstance(self.temperature, float) or self.temperature < 0.0:
raise ValueError("temperature must be a non-negative float")
class LoopGenerator(GeneratorCore):
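    """Non-streaming generator for a single query.

    Runs the full decode loop (``self.generate_loop``) with the requested
    sampling settings and returns only the generated completion; the prompt
    tokens are stripped using the prompt length.
    """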
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(self, request: GenerationRequest) -> str:
model_params = next(self.model.parameters())
device = model_params.device
dtype = model_params.dtype
cache_manager = KVCacheManager(self.config, 1, device=device, dtype=dtype)
prompt = build_prompt(request.query, request.history)
ids = self.tokenizer.encode(prompt)
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
start_cache_pos = len(ids)
self.model.eval()
kv_caches = cache_manager.get_kvcache()
ids = self.generate_loop(
input_ids,
ids,
request.temperature,
request.top_k,
request.top_p,
kv_caches=kv_caches,
)
response = self.tokenizer.decode(ids[start_cache_pos:])
return response
class StreamGenerator(GeneratorCore):
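    """Streaming generator for a single query.

    ``generate`` yields the full decoded response-so-far after each new token
    (not an incremental delta). When a stop token is sampled, the response is
    yielded once more with a trailing newline and generation ends; otherwise
    the loop stops once ``config.max_len`` tokens have been produced.
    """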
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
model_params = next(self.model.parameters())
device = model_params.device
dtype = model_params.dtype
cache_manager = KVCacheManager(self.config, 1, device=device, dtype=dtype)
prompt = build_prompt(request.query, request.history)
ids = self.tokenizer.encode(prompt)
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
start_cache_pos = len(ids)
cur_cache_pos = 0
self.model.eval()
kv_caches = cache_manager.get_kvcache()
for _ in range(len(ids), self.config.max_len):
next_token_id, cache_increase = self.generate_iterator(
input_ids,
request.temperature,
request.top_k,
request.top_p,
kv_caches=kv_caches,
start_pos=cur_cache_pos,
)
input_ids = next_token_id
ids.append(next_token_id.item())
cur_cache_pos += cache_increase
response = self.tokenizer.decode(ids[start_cache_pos:])
yield response
if next_token_id.item() in self.tokenizer.stop_ids:
yield response + "\n"
break
class BatchGenerator(GeneratorCore):
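    """Non-streaming generator for a batch of queries.

    Prompts are left-padded to a common length and decoded together; each
    sequence is dropped from the batch (and the KV cache) as soon as it emits
    a stop token. Generated responses are appended to ``request.history``.
    """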
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(self, request: GenerationRequest) -> List[str]:
batch_size = len(request.query)
if request.history is None:
request.history = [[] for _ in range(batch_size)]
prompts = [
build_prompt(query, history)
for query, history in zip(request.query, request.history)
]
ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)
model_params = next(self.model.parameters())
device = model_params.device
dtype = model_params.dtype
cache_manager = KVCacheManager(
self.config, batch_size, device=device, dtype=dtype
)
input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
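        # Record which positions are padding so attention can ignore them
        # (the cache manager derives the sequence mask from pad_id).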
cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
activate_task_mask = [True] * batch_size
start_cache_pos = max_ids_len
cur_cache_pos = 0
        while max_ids_len < self.config.max_len and any(activate_task_mask):
kv_caches = cache_manager.get_kvcache()
attn_mask = cache_manager.get_seq_mask()
next_token_id, cache_increase = self.generate_iterator(
input_tensor,
request.temperature,
request.top_k,
request.top_p,
attn_mask=attn_mask,
kv_caches=kv_caches,
start_pos=cur_cache_pos,
)
cur_cache_pos += cache_increase
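            # next_token_id only has rows for sequences that were still active
            # on this step, so c_ids walks those compacted rows while i walks
            # the original batch indices.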
active_mask = []
c_ids = 0
for i in range(batch_size):
if activate_task_mask[i]:
token = next_token_id[c_ids, :].item()
ids_list[i].append(token)
c_ids += 1
is_active = token not in self.tokenizer.stop_ids
activate_task_mask[i] = is_active
active_mask.append(is_active)
active_mask = torch.tensor(active_mask, device=device, dtype=torch.bool)
cache_manager.update(active_mask)
input_tensor = next_token_id[active_mask, :]
max_ids_len += 1
        responses = [""] * batch_size
for i in range(batch_size):
responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:])
request.history[i].append((request.query[i], responses[i]))
return responses
class EmbeddingEncoder(EmbeddingEncoderCore):
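    """Thin wrapper around EmbeddingEncoderCore for sentence embeddings."""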
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
return super().encode(sentence)
class GeneratorFactory(BaseFactory[GeneratorCore]):
"""Factory class for creating generator instances.
Provides smart generator selection based on request characteristics:
- Streaming: Use StreamGenerator for streaming output
- Batch: Use BatchGenerator when query is a list
- Single: Use LoopGenerator for single query non-streaming
Example usage:
generator = GeneratorFactory.create(parameter, request)
result = generator.generate(request)
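
    For a streaming request, ``generate`` returns an iterator that yields the
    full response-so-far (illustrative sketch; parameter values are arbitrary):

        request = GenerationRequest(
            top_k=50, top_p=0.9, temperature=0.8, max_len=512,
            query="Hello", stream=True,
        )
        for partial in GeneratorFactory.create(parameter, request).generate(request):
            print(partial, end="\r")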
"""
@staticmethod
def create(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
"""Create a generator based on request characteristics.
Args:
parameter: Model parameters containing model, tokenizer, config
request: Generation request with query, options, etc.
Returns:
Appropriate GeneratorCore subclass instance
"""
# Streaming generation: check stream field first
if request.stream:
return StreamGenerator(parameter)
# Batch generation: query is a list of strings
if isinstance(request.query, list):
return BatchGenerator(parameter)
# Default: single query non-streaming
return LoopGenerator(parameter)
@staticmethod
def create_encoder(parameter: ModelParameter) -> EmbeddingEncoderCore:
"""Create an embedding encoder instance.
Args:
parameter: Model parameters
Returns:
EmbeddingEncoderCore instance
"""
return EmbeddingEncoder(parameter)