"""Unified inference engine.""" import gc import logging import threading from typing import Any, Dict, Generator, List, Optional, Union import torch import torch.nn as nn from astrai.inference.scheduler import InferenceScheduler from astrai.tokenize import AutoTokenizer logger = logging.getLogger(__name__) class GenerationRequest: """Request parameters for text generation.""" def __init__( self, messages: List[Dict[str, str]], top_k: int = 50, top_p: float = 1.0, temperature: float = 1.0, max_len: int = 1024, stream: bool = False, ): self.messages = messages self.top_k = top_k self.top_p = top_p self.temperature = temperature self.max_len = max_len self.stream = stream self._validate() def _validate(self): """Validate request parameters.""" if not (isinstance(self.top_k, int) and self.top_k >= 0): raise ValueError("top_k must be a non-negative integer") if not (0.0 <= self.top_p <= 1.0): raise ValueError("top_p must be a float between 0.0 and 1.0") if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0): raise ValueError("temperature must be a non-negative number") class _Result: """Unified result holder for streaming/non-streaming modes.""" def __init__(self, count: int = 1, stream: bool = False): self._stream = stream self._lock = threading.Lock() self._event = threading.Event() self.tokens: List[str] = [] self.results: List[str] = [""] * count if count > 1 else [""] self.done_flags: List[bool] = [False] * count self._completed_count = 0 def append(self, token: str, idx: int = 0): with self._lock: if self._stream: self.tokens.append(token) else: if token == "[DONE]": if not self.done_flags[idx]: self.done_flags[idx] = True self._completed_count += 1 if self._completed_count == len(self.results): self._event.set() else: self.results[idx] += token self._event.set() def pop_all(self) -> List[str]: with self._lock: tokens = self.tokens.copy() self.tokens.clear() if not tokens: self._event.clear() return tokens def wait(self, timeout: float = None) -> bool: return self._event.wait(timeout=timeout) def get_results(self) -> List[str]: with self._lock: return self.results.copy() class InferenceEngine: """Unified inference engine for continuous batching.""" def __init__( self, model: nn.Module, tokenizer: AutoTokenizer, max_batch_size: int = 1, max_seq_len: Optional[int] = None, max_prefix_len: int = 512, cache_capacity: int = 1000, ): """ Initialize inference engine with separate model and tokenizer. Args: model: The language model for inference (nn.Module, e.g., Transformer) tokenizer: The tokenizer for encoding/decoding text config: Model configuration max_batch_size: Maximum batch size for continuous batching max_seq_len: Maximum sequence length (defaults to config.max_len) max_prefix_len: Maximum prefix length for cache (default: 512) cache_capacity: Maximum number of cached prefixes (default: 1000) """ self.model = model self.tokenizer = tokenizer # Get device and dtype from model parameters try: first_param = next(model.parameters()) device = first_param.device dtype = first_param.dtype except StopIteration: # Model has no parameters, use default device/dtype device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 self.scheduler = InferenceScheduler( model=self.model, tokenizer=self.tokenizer, max_batch_size=max_batch_size, max_seq_len=max_seq_len, max_prefix_len=max_prefix_len, cache_capacity=cache_capacity, device=device, dtype=dtype, ) self.kv_cache = self.scheduler.kv_cache self.seq_mask = self.scheduler.seq_mask self.scheduler.start() def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): """Handle exceptions on exit.""" self.shutdown() return False def generate( self, prompt: Union[str, List[str]], stream: bool = False, max_tokens: int = 1024, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 50, abort_on_exception: bool = True, ) -> Union[Generator[str, None, None], str, List[str]]: """Unified generation interface. Args: abort_on_exception: If True, abort the generation when consumer stops iterating (GeneratorExit/StopIteration). Default: True. """ is_batch = isinstance(prompt, list) prompts = prompt if is_batch else [prompt] if stream: return self._generate_streaming( prompts, is_batch, max_tokens, temperature, top_p, top_k, abort_on_exception, ) else: return self._generate_non_streaming( prompts, is_batch, max_tokens, temperature, top_p, top_k ) def generate_with_request( self, request: GenerationRequest ) -> Union[Generator[str, None, None], str, List[str]]: """Generate with GenerationRequest object.""" # Use tokenizer's chat template with messages prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False) return self.generate( prompt=prompt, stream=request.stream, max_tokens=request.max_len, temperature=request.temperature, top_p=request.top_p, top_k=request.top_k, ) def _generate_streaming( self, prompts: List[str], is_batch: bool, max_tokens: int, temperature: float, top_p: float, top_k: int, abort_on_exception: bool = True, ) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]: """Generate with streaming output. Args: abort_on_exception: If True, abort the task when generator is stopped early by consumer (GeneratorExit/StopIteration). """ if is_batch: raise NotImplementedError("Batch streaming is not implemented yet") result = _Result(stream=True) task_id = self.scheduler.add_task( prompt=prompts[0], max_tokens=max_tokens, temperature=temperature, top_p=top_p, top_k=top_k, stream_callback=result.append, ) def gen(): try: while True: tokens = result.pop_all() for token in tokens: if token == "[DONE]": return yield token result.wait(timeout=0.05) except Exception: # Consumer stopped iterating - abort the task if abort_on_exception: self.scheduler.remove_task(task_id) raise gen.task_id = task_id return gen() def _generate_non_streaming( self, prompts: List[str], is_batch: bool, max_tokens: int, temperature: float, top_p: float, top_k: int, ) -> Union[str, List[str]]: """Generate without streaming.""" result = _Result(count=len(prompts)) for i, p in enumerate(prompts): # Create closure to capture current index value using factory function def make_callback(idx): def callback(token): result.append(idx, token) return callback self.scheduler.add_task( prompt=p, max_tokens=max_tokens, temperature=temperature, top_p=top_p, top_k=top_k, stream_callback=make_callback(i), ) result.wait() results = result.get_results() return results if is_batch else results[0] def get_stats(self) -> Dict[str, Any]: """Get engine statistics.""" return self.scheduler.get_stats() def shutdown(self) -> None: """Shutdown the engine and release all resources.""" self.scheduler.stop() if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect()