AstrAI/astrai/inference/engine.py

353 lines
10 KiB
Python

"""Unified inference engine."""
import gc
import logging
import signal
import threading
from typing import Any, Dict, Generator, List, Optional, Union
import torch
import torch.nn as nn
from astrai.tokenize.tokenizer import TextTokenizer
from astrai.inference.scheduler import InferenceScheduler
logger = logging.getLogger(__name__)
class GenerationRequest:
    """Request parameters for text generation.

    Args:
        messages: Chat messages; ``InferenceEngine.generate_with_request``
            renders them with the tokenizer's chat template.
        top_k: Top-k sampling parameter; must be a non-negative integer.
        top_p: Top-p (nucleus) sampling parameter; must lie in [0.0, 1.0].
        temperature: Sampling temperature; must be non-negative.
        max_len: Generation length limit (forwarded as ``max_tokens``);
            must be a positive integer.
        stream: Whether generation should stream tokens.

    Raises:
        ValueError: If any numeric parameter has the wrong type or is out of
            range (NaN is rejected as out of range).
    """

    def __init__(
        self,
        messages: List[Dict[str, str]],
        top_k: int = 50,
        top_p: float = 1.0,
        temperature: float = 1.0,
        max_len: int = 1024,
        stream: bool = False,
    ):
        self.messages = messages
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        self.max_len = max_len
        self.stream = stream
        self._validate()

    def _validate(self):
        """Validate request parameters, raising ``ValueError`` on bad input."""

        def is_strict_int(value) -> bool:
            # bool is a subclass of int; True/False are not meaningful here.
            return isinstance(value, int) and not isinstance(value, bool)

        def is_number(value) -> bool:
            # Accept ints as well as floats (e.g. temperature=1).
            return isinstance(value, (int, float)) and not isinstance(value, bool)

        if not is_strict_int(self.top_k) or self.top_k < 0:
            raise ValueError("top_k must be a non-negative integer")
        # The negated chained comparison also rejects NaN, for which every
        # comparison is False.
        if not is_number(self.top_p) or not 0.0 <= self.top_p <= 1.0:
            raise ValueError("top_p must be a float between 0.0 and 1.0")
        if not is_number(self.temperature) or not self.temperature >= 0.0:
            raise ValueError("temperature must be a non-negative float")
        if not is_strict_int(self.max_len) or self.max_len <= 0:
            raise ValueError("max_len must be a positive integer")
class _StreamingResult:
"""Streaming result holder with event-based notification."""
def __init__(self):
self.tokens: List[str] = []
self._event = threading.Event()
self._lock = threading.Lock()
def append(self, token: str):
with self._lock:
self.tokens.append(token)
self._event.set()
def pop_all(self) -> List[str]:
with self._lock:
tokens = self.tokens.copy()
self.tokens.clear()
if not tokens:
self._event.clear()
return tokens
def wait(self, timeout: float = None) -> bool:
return self._event.wait(timeout=timeout)
class _NonStreamingResult:
"""Non-streaming result holder with event-based completion notification."""
def __init__(self, count: int):
self.results: List[str] = ["" for _ in range(count)]
self.done_flags: List[bool] = [False] * count
self._completed_count = 0
self._event = threading.Event()
self._lock = threading.Lock()
def append(self, idx: int, token: str):
with self._lock:
if token == "[DONE]":
if not self.done_flags[idx]:
self.done_flags[idx] = True
self._completed_count += 1
if self._completed_count == len(self.results):
self._event.set()
else:
self.results[idx] += token
def is_all_done(self) -> bool:
with self._lock:
return all(self.done_flags)
def wait(self, timeout: float = None) -> bool:
return self._event.wait(timeout=timeout)
def get_results(self) -> List[str]:
with self._lock:
return self.results.copy()
class InferenceEngine:
    """Unified inference engine for continuous batching.

    Owns an ``InferenceScheduler`` that is started in ``__init__`` and
    stopped by :meth:`shutdown`. When used as a context manager, the engine
    is always shut down on exit.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: TextTokenizer,
        max_batch_size: int = 1,
        max_seq_len: Optional[int] = None,
    ):
        """
        Initialize the inference engine and start its scheduler.

        Args:
            model: The language model for inference (nn.Module, e.g., Transformer)
            tokenizer: The tokenizer for encoding/decoding text
            max_batch_size: Maximum batch size for continuous batching
            max_seq_len: Maximum sequence length. Forwarded to the scheduler
                as-is, so ``None`` leaves the choice to the scheduler
                (TODO confirm its default in ``InferenceScheduler``).
        """
        self.model = model
        self.tokenizer = tokenizer

        # Hand the scheduler the same device/dtype as the model weights.
        try:
            first_param = next(model.parameters())
            device = first_param.device
            dtype = first_param.dtype
        except StopIteration:
            # Model has no parameters: use CUDA when available, else CPU.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            dtype = torch.float32

        self.scheduler = InferenceScheduler(
            model=self.model,
            tokenizer=self.tokenizer,
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_len,
            device=device,
            dtype=dtype,
        )
        # Expose the scheduler's cache/mask objects directly on the engine.
        self.kv_cache = self.scheduler.kv_cache
        self.seq_mask = self.scheduler.seq_mask
        self.scheduler.start()

    def __enter__(self):
        """Return the engine itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Shut the engine down, first trying to save state if an error occurred.

        Never suppresses the in-flight exception (always returns ``False``).
        """
        try:
            if exc_type is not None:
                logger.warning(
                    "Exception %s: %s, saving state...", exc_type.__name__, exc_val
                )
                # Best effort. ``save_state`` is not defined by the code shown
                # here, so look it up dynamically. A failure must be logged
                # but must not mask the original exception.
                save_state = getattr(self, "save_state", None)
                if save_state is None:
                    logger.warning("save_state() is not available; state not saved")
                else:
                    try:
                        save_state("./inference_state")
                    except Exception:
                        logger.exception("Failed to save inference state")
        finally:
            # Release resources even if state saving raised something exotic.
            self.shutdown()
        return False

    def generate(
        self,
        prompt: Union[str, List[str]],
        stream: bool = False,
        max_tokens: int = 1024,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
        abort_on_exception: bool = True,
    ) -> Union[Generator[str, None, None], str, List[str]]:
        """Unified generation interface for a single prompt or a batch.

        Args:
            prompt: A prompt string, or a list of prompt strings (batch).
            stream: If True, return a generator that yields tokens as the
                scheduler produces them. Batch streaming is not supported.
            max_tokens: Generation length limit, forwarded to the scheduler.
            temperature: Sampling temperature, forwarded to the scheduler.
            top_p: Top-p sampling parameter, forwarded to the scheduler.
            top_k: Top-k sampling parameter, forwarded to the scheduler.
            abort_on_exception: Streaming only. If True, remove the task from
                the scheduler when the consumer stops early (the generator is
                closed) or an exception is thrown into it. Default: True.

        Returns:
            A token generator when ``stream`` is True. Otherwise the generated
            string, or a list of strings (in prompt order) for a batch.

        Raises:
            NotImplementedError: If ``stream`` is True and ``prompt`` is a list.
        """
        is_batch = isinstance(prompt, list)
        prompts = prompt if is_batch else [prompt]
        if stream:
            return self._generate_streaming(
                prompts,
                is_batch,
                max_tokens,
                temperature,
                top_p,
                top_k,
                abort_on_exception,
            )
        return self._generate_non_streaming(
            prompts, is_batch, max_tokens, temperature, top_p, top_k
        )

    def generate_with_request(
        self, request: GenerationRequest
    ) -> Union[Generator[str, None, None], str, List[str]]:
        """Generate from a ``GenerationRequest``.

        The request's chat messages are rendered to a prompt string with the
        tokenizer's chat template, then passed to :meth:`generate`.
        """
        prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
        return self.generate(
            prompt=prompt,
            stream=request.stream,
            max_tokens=request.max_len,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
        )

    def _generate_streaming(
        self,
        prompts: List[str],
        is_batch: bool,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
        abort_on_exception: bool = True,
    ) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
        """Submit one prompt and return a generator over its tokens.

        The task is submitted eagerly. The returned generator yields tokens
        until the scheduler emits the ``"[DONE]"`` sentinel.

        Args:
            abort_on_exception: If True, remove the task from the scheduler
                when the generator ends before ``"[DONE]"`` (``close()``,
                garbage collection of a started generator, or an exception).
        """
        if is_batch:
            raise NotImplementedError("Batch streaming is not implemented yet")

        result = _StreamingResult()
        task_id = self.scheduler.add_task(
            prompt=prompts[0],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stream_callback=result.append,
        )

        def token_stream() -> Generator[str, None, None]:
            finished = False
            try:
                while True:
                    for token in result.pop_all():
                        if token == "[DONE]":
                            finished = True
                            return
                        yield token
                    # Poll with a short timeout so the loop keeps checking
                    # even if a wake-up is missed.
                    result.wait(timeout=0.05)
            finally:
                # ``finally`` is required: a consumer stopping early raises
                # GeneratorExit, which is a BaseException and was never caught
                # by the previous ``except Exception``. Note that a generator
                # that was never started does not run this block when closed.
                if not finished and abort_on_exception:
                    self.scheduler.remove_task(task_id)

        return token_stream()

    def _generate_non_streaming(
        self,
        prompts: List[str],
        is_batch: bool,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Union[str, List[str]]:
        """Submit every prompt and block until all of them finish."""
        if not prompts:
            # Empty batch: nothing to schedule and nothing to wait for.
            return []

        result = _NonStreamingResult(len(prompts))
        for idx, text in enumerate(prompts):
            self.scheduler.add_task(
                prompt=text,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                # The callback receives only the token (the streaming path
                # passes the one-argument ``_StreamingResult.append``), so
                # bind this prompt's slot index here. ``idx=idx`` freezes the
                # current value and avoids the late-binding closure pitfall.
                stream_callback=lambda token, idx=idx: result.append(idx, token),
            )
        result.wait()
        results = result.get_results()
        return results if is_batch else results[0]

    def get_stats(self) -> Dict[str, Any]:
        """Return the scheduler's statistics."""
        return self.scheduler.get_stats()

    def shutdown(self) -> None:
        """Stop the scheduler, then release cached CUDA memory and run GC."""
        self.scheduler.stop()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def force_stop(self) -> None:
        """
        Force stop the engine immediately without saving state.
        Use this for emergency shutdown when graceful shutdown is not possible.
        """
        # ``stop_watching`` / ``_original_sig*`` are not set by the code shown
        # here. The hasattr guards keep this safe if they are provided
        # elsewhere (NOTE(review): confirm whether they still exist).
        if hasattr(self, "stop_watching"):
            self.stop_watching()
        if hasattr(self, "_original_sigint"):
            signal.signal(signal.SIGINT, self._original_sigint)
        if hasattr(self, "_original_sigterm"):
            signal.signal(signal.SIGTERM, self._original_sigterm)
        # Writes a private scheduler attribute instead of calling stop(),
        # presumably the scheduler's run-loop flag (TODO confirm).
        self.scheduler._running = False
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()

    @classmethod
    def create_and_run(cls, model, tokenizer, **kwargs):
        """
        Create engine, run generation, and shutdown automatically.
        This is a convenience method for simple scripts.

        Args:
            model: The model to use
            tokenizer: The tokenizer to use
            **kwargs: Arguments passed to generate()

        Returns:
            Generated text result. With ``stream=True``, a generator over the
            tokens, which are collected before the engine shuts down.
        """
        with cls(model, tokenizer) as engine:
            result = engine.generate(**kwargs)
            if kwargs.get("stream", False):
                # A lazy stream would otherwise be consumed after the engine
                # has shut down. Drain it while the engine is alive and hand
                # back a generator over the collected tokens instead.
                tokens = list(result)
                result = (token for token in tokens)
        return result