AstrAI/astrai/inference/engine.py

"""Unified inference engine."""
import functools
import threading
from typing import Any, Dict, Generator, List, Optional, Union

from astrai.config import ModelParameter
from astrai.tokenize.chat_template import build_prompt
from astrai.inference.scheduler import InferenceScheduler


class GenerationRequest:
    """Request parameters for text generation."""

    def __init__(
        self,
        query: Union[str, List[str]],
        top_k: int = 50,
        top_p: float = 1.0,
        temperature: float = 1.0,
        max_len: int = 1024,
        history: Optional[Any] = None,
        system_prompt: Optional[str] = None,
        stream: bool = False,
    ):
        self.query = query
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        self.max_len = max_len
        self.history = history
        self.system_prompt = system_prompt
        self.stream = stream
        self._validate()

    def _validate(self):
        """Validate request parameters."""
        if not isinstance(self.top_k, int) or self.top_k < 0:
            raise ValueError("top_k must be a non-negative integer")
        if not isinstance(self.top_p, float) or not 0.0 <= self.top_p <= 1.0:
            raise ValueError("top_p must be a float between 0.0 and 1.0")
        if not isinstance(self.temperature, float) or self.temperature < 0.0:
            raise ValueError("temperature must be a non-negative float")
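

# A minimal usage sketch (illustrative only; `engine` is an InferenceEngine,
# defined further down in this module):
#
#     request = GenerationRequest(
#         "Explain KV caching.", temperature=0.7, max_len=256, stream=False
#     )
#     text = engine.generate_with_request(request)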


class _StreamingResult:
    """Streaming result holder with event-based notification."""

    def __init__(self):
        self.tokens: List[str] = []
        self._event = threading.Event()
        self._lock = threading.Lock()

    def append(self, token: str):
        with self._lock:
            self.tokens.append(token)
            self._event.set()

    def pop_all(self) -> List[str]:
        with self._lock:
            tokens = self.tokens.copy()
            self.tokens.clear()
            if not tokens:
                # After a non-empty pop the event stays set, so the consumer
                # re-polls once before blocking; clear only when drained.
                self._event.clear()
            return tokens

    def wait(self, timeout: Optional[float] = None) -> bool:
        return self._event.wait(timeout=timeout)
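

# Callback contract, as inferred from this module: the scheduler invokes
# stream_callback once per decoded token and finally with the literal
# sentinel "[DONE]" to signal completion. Both result holders rely on it.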


class _NonStreamingResult:
    """Non-streaming result holder with event-based completion notification."""

    def __init__(self, count: int):
        self.results: List[str] = ["" for _ in range(count)]
        self.done_flags: List[bool] = [False] * count
        self._completed_count = 0
        self._event = threading.Event()
        self._lock = threading.Lock()

    def append(self, idx: int, token: str):
        with self._lock:
            if token == "[DONE]":
                if not self.done_flags[idx]:
                    self.done_flags[idx] = True
                    self._completed_count += 1
                    if self._completed_count == len(self.results):
                        self._event.set()
            else:
                self.results[idx] += token

    def is_all_done(self) -> bool:
        with self._lock:
            return all(self.done_flags)

    def wait(self, timeout: Optional[float] = None) -> bool:
        return self._event.wait(timeout=timeout)

    def get_results(self) -> List[str]:
        with self._lock:
            return self.results.copy()


class InferenceEngine:
    """Unified inference engine for continuous batching."""

    def __init__(
        self,
        parameter: ModelParameter,
        max_batch_size: int = 16,
        max_seq_len: Optional[int] = None,
    ):
        self.model = parameter.model
        self.tokenizer = parameter.tokenizer
        self.config = parameter.config
        # Infer device and dtype from the first model parameter tensor.
        model_params = next(self.model.parameters())
        self.device = model_params.device
        self.dtype = model_params.dtype
        self.scheduler = InferenceScheduler(
            model=self.model,
            tokenizer=self.tokenizer,
            config=self.config,
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_len,
            device=self.device,
            dtype=self.dtype,
        )
        self.kv_cache = self.scheduler.kv_cache
        self.seq_mask = self.scheduler.seq_mask
        # The scheduler runs its own loop; start it eagerly.
        self.scheduler.start()
    def generate(
        self,
        prompt: Union[str, List[str]],
        stream: bool = False,
        max_tokens: int = 1024,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
    ) -> Union[Generator[str, None, None], str, List[str]]:
        """Unified generation interface."""
        is_batch = isinstance(prompt, list)
        prompts = prompt if is_batch else [prompt]
        if stream:
            return self._generate_streaming(
                prompts, is_batch, max_tokens, temperature, top_p, top_k
            )
        return self._generate_non_streaming(
            prompts, is_batch, max_tokens, temperature, top_p, top_k
        )

    def generate_with_request(
        self, request: GenerationRequest
    ) -> Union[Generator[str, None, None], str, List[str]]:
        """Generate with a GenerationRequest object."""
        prompt = build_prompt(request.query, request.history)
        return self.generate(
            prompt=prompt,
            stream=request.stream,
            max_tokens=request.max_len,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
        )

    def _generate_streaming(
        self,
        prompts: List[str],
        is_batch: bool,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
        """Generate with streaming output."""
        if is_batch:
            raise NotImplementedError("Batch streaming is not implemented yet")
        result = _StreamingResult()
        self.scheduler.add_task(
            prompt=prompts[0],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stream_callback=result.append,
        )

        def gen():
            # Drain any buffered tokens, then wait briefly for more; the
            # "[DONE]" sentinel terminates the stream.
            while True:
                tokens = result.pop_all()
                for token in tokens:
                    if token == "[DONE]":
                        return
                    yield token
                result.wait(timeout=0.05)

        return gen()
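
    # A consumption sketch for the streaming path (illustrative only):
    #
    #     for token in engine.generate("Once upon a time", stream=True):
    #         print(token, end="", flush=True)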

    def _generate_non_streaming(
        self,
        prompts: List[str],
        is_batch: bool,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Union[str, List[str]]:
        """Generate without streaming."""
        result = _NonStreamingResult(len(prompts))
        for i, p in enumerate(prompts):
            self.scheduler.add_task(
                prompt=p,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                # Bind the prompt index so the scheduler's single-argument
                # callback routes each token to the right result slot.
                stream_callback=functools.partial(result.append, i),
            )
        result.wait()
        results = result.get_results()
        return results if is_batch else results[0]

    def get_stats(self) -> Dict[str, Any]:
        """Get engine statistics."""
        return self.scheduler.get_stats()

    def shutdown(self) -> None:
        """Shut down the engine."""
        self.scheduler.stop()
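

# A minimal end-to-end sketch (illustrative only): load_parameter() is a
# hypothetical helper standing in for however a ModelParameter is actually
# constructed; it is not part of this module.
#
#     if __name__ == "__main__":
#         parameter = load_parameter("my-model")  # hypothetical
#         engine = InferenceEngine(parameter, max_batch_size=8)
#         try:
#             print(engine.generate("Hello, world!", max_tokens=64))
#             print(engine.get_stats())
#         finally:
#             engine.shutdown()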