# AstrAI/astrai/inference/engine.py

"""
Continuous Batching Inference Engine
This module provides the main continuous batching components:
- Task: Individual generation task with state management
- TaskStatus: Task state enumeration
- InferenceScheduler: Handles request scheduling and KV cache management
- InferenceEngine: Unified inference engine
Author: AstrAI Team
"""
import threading
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
import torch
from torch import Tensor
from astrai.config import ModelConfig, ModelParameter
from astrai.tokenize.chat_template import HistoryType, build_prompt
# Lightweight debug hook: a no-op by default; swap the body for print(...) when debugging.
def _debug(*args, **kwargs):
    pass
@dataclass
class GenerationRequest:
"""Request parameters for text generation."""
top_k: int
top_p: float
temperature: float
max_len: int
query: Union[str, List[str]]
history: Optional[Union[HistoryType, List[HistoryType]]] = None
system_prompt: Optional[str] = None
stream: bool = False
def __post_init__(self):
if not isinstance(self.top_k, int) or self.top_k < 0:
raise ValueError("top_k must be a non-negative integer")
if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0:
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not isinstance(self.temperature, float) or self.temperature < 0.0:
raise ValueError("temperature must be a non-negative float")
class TaskStatus(Enum):
"""Task state enumeration for continuous batching.
States:
PENDING: Task is waiting to be scheduled
RUNNING: Task is currently being processed
FINISHED: Task completed successfully
ABORTED: Task was cancelled or failed
"""
PENDING = auto()
RUNNING = auto()
FINISHED = auto()
ABORTED = auto()
@dataclass
class Task:
"""Individual task for continuous batching.
Attributes:
task_id: Unique task identifier
prompt_ids: Input token IDs
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Top-p sampling parameter
top_k: Top-k sampling parameter
status: Current task status
output_ids: Generated token IDs
input_tokens: Number of input tokens
output_tokens: Number of output tokens generated
slot: Batch slot position (-1 if not assigned)
arrival_time: Task arrival timestamp
finish_time: Task completion timestamp
stream_callback: Callback for streaming output
"""
task_id: str
prompt_ids: List[int]
max_tokens: int = 1024
temperature: float = 1.0
top_p: float = 1.0
top_k: int = 50
status: TaskStatus = TaskStatus.PENDING
output_ids: List[int] = field(default_factory=list)
input_tokens: int = 0
output_tokens: int = 0
slot: int = -1
arrival_time: float = field(default_factory=time.time)
finish_time: Optional[float] = None
stream_callback: Optional[Callable[[str], None]] = None
def is_finished(self, stop_ids: List[int]) -> bool:
"""Check if task is finished."""
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
if self.output_tokens >= self.max_tokens:
return True
return False
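# Example (illustrative): a task is finished once it emits a stop id or reaches max_tokens.
#
#     t = Task(task_id="t0", prompt_ids=[1, 2, 3], max_tokens=8)
#     t.output_ids = [42]
#     t.output_tokens = 1
#     t.is_finished(stop_ids=[42])   # True: last output id is a stop id
#     t.is_finished(stop_ids=[99])   # False: no stop id and still under max_tokens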
def apply_sampling_strategies(
logits: Tensor,
temperature: float,
top_k: int,
top_p: float,
filter_value: float = -float("inf"),
) -> Tensor:
"""Apply sampling strategies to the logits tensor."""
    if temperature != 1.0:
        # Guard against temperature == 0, which would otherwise divide by zero
        logits = logits / max(temperature, 1e-5)
if top_k > 0:
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
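# Example (illustrative): filter a toy 1x4 logits row with top-k and nucleus sampling,
# then sample from the surviving tokens.
#
#     logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
#     filtered = apply_sampling_strategies(logits, temperature=0.8, top_k=2, top_p=0.95)
#     probs = torch.softmax(filtered, dim=-1)            # probability mass only on the top-2 ids
#     next_id = torch.multinomial(probs, num_samples=1)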
class InferenceScheduler:
"""Inference scheduler with continuous batching support.
Manages request scheduling, KV cache allocation, and generation loop.
Supports dynamic batch composition where new requests can join at any time
and completed requests are immediately released.
"""
def __init__(
self,
model,
tokenizer,
config: ModelConfig,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
self.model = model
self.tokenizer = tokenizer
self.config = config
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len
self.device = device
self.dtype = dtype
num_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads
n_layers = config.n_layers
k_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_heads,
head_dim,
),
device=device,
dtype=dtype,
)
v_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_heads,
head_dim,
),
device=device,
dtype=dtype,
)
self.kv_cache = (k_cache, v_cache)
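        # Rough size of the preallocated cache (K and V combined), useful when choosing
        # max_batch_size / max_seq_len:
        #   2 * max_batch_size * max_seq_len * n_layers * num_heads * head_dim * dtype_size bytes
        # e.g. (illustrative) batch 16, seq 4096, 32 layers, 8 KV heads, head_dim 128 in bf16 ≈ 8 GiB.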
self.seq_mask = torch.ones(
(max_batch_size, self.max_seq_len), device=device, dtype=torch.bool
)
self.waiting_queue: List[Task] = []
self.active_tasks: List[Task] = []
self._running = False
self._task_event = threading.Event()
self._lock = threading.Lock()
self._total_tasks = 0
self._total_tokens = 0
def add_task(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
) -> str:
"""Add a new task to the waiting queue."""
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
_debug(
f"add_task: task_id={task_id}, prompt_len={len(prompt_ids)}, has_callback={stream_callback is not None}"
)
task = Task(
task_id=task_id,
prompt_ids=prompt_ids,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=stream_callback,
)
with self._lock:
self.waiting_queue.append(task)
self._total_tasks += 1
self._task_event.set()
return task_id
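    # Example (illustrative, assuming a constructed InferenceScheduler bound to `scheduler`):
    # enqueue a prompt and stream decoded tokens to stdout.
    #
    #     def on_token(tok: str) -> None:
    #         print(tok, end="", flush=True)
    #
    #     task_id = scheduler.add_task("Hello, world", max_tokens=64, stream_callback=on_token)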
def remove_task(self, task_id: str) -> None:
"""Remove a task from the scheduler."""
with self._lock:
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
def _remove_finished_tasks(self) -> None:
"""Remove finished tasks from active batch and update caches."""
finished = []
for task in self.active_tasks:
if task.is_finished(self.tokenizer.stop_ids):
task.status = TaskStatus.FINISHED
task.finish_time = time.time()
finished.append(task)
self._total_tokens += task.output_tokens
for task in finished:
slot = task.slot
            # Slots index the batch dimension, so bound-check against max_batch_size,
            # not the number of currently active tasks.
            if 0 <= slot < self.max_batch_size:
                self.seq_mask[slot, :] = False
task.slot = -1
self.active_tasks = [
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
]
def _refill_active_batch(self) -> None:
"""Refill active batch with waiting tasks."""
available_slots = self.max_batch_size - len(self.active_tasks)
if available_slots <= 0:
return
with self._lock:
to_add = []
for _ in range(min(available_slots, len(self.waiting_queue))):
if self.waiting_queue:
task = self.waiting_queue.pop(0)
task.status = TaskStatus.RUNNING
to_add.append(task)
for task in to_add:
for i in range(self.max_batch_size):
if all(t.slot != i for t in self.active_tasks):
task.slot = i
break
self.active_tasks.append(task)
def _execute_prefill(self, tasks: List[Task]) -> None:
"""Execute Prefill phase: process entire prompt at once."""
if not tasks:
return
_debug(f"_execute_prefill: processing {len(tasks)} tasks")
# Sort tasks by slot to ensure correct batch indexing with KV cache
tasks = sorted(tasks, key=lambda t: t.slot)
prompt_lens = [len(task.prompt_ids) for task in tasks]
max_len = max(prompt_lens)
        # Right-pad each prompt to max_len. Pad with the tokenizer's pad id when
        # available so the padding value is well-defined.
        pad_id = self.tokenizer.pad_id if self.tokenizer.pad_id is not None else 0
        input_ids = torch.full(
            (len(tasks), max_len), pad_id, dtype=torch.long, device=self.device
        )
        for i, task in enumerate(tasks):
            if len(task.prompt_ids) > 0:
                input_ids[i, : len(task.prompt_ids)] = torch.tensor(
                    task.prompt_ids, device=self.device
                )
        # Boolean attention mask built from the true prompt lengths, so padding
        # positions are never attended to even if pad_id occurs inside a prompt.
        input_mask = torch.zeros(
            len(tasks), max_len, dtype=torch.bool, device=self.device
        )
        for i, prompt_len in enumerate(prompt_lens):
            input_mask[i, :prompt_len] = True
        _debug(
            f"_execute_prefill: input_ids shape={input_ids.shape}, max_len={max_len}"
        )
try:
with torch.inference_mode():
outputs = self.model(
input_ids,
input_mask=input_mask,
start_pos=0,
persistent_key_values=self.kv_cache,
)
_debug(
f"_execute_prefill: model forward done, output keys={outputs.keys() if hasattr(outputs, 'keys') else 'no keys'}"
)
except Exception as e:
_debug(f"_execute_prefill: ERROR: {e}")
raise
for i, task in enumerate(tasks):
task.input_tokens = prompt_lens[i]
task.output_tokens = 0
_debug(
f" task {task.task_id}: input_tokens={task.input_tokens}, output_tokens={task.output_tokens}"
)
for task in tasks:
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
_debug(f"_execute_prefill: done, {len(tasks)} tasks marked as prefill complete")
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
"""Execute Decode phase: generate one token at a time."""
if not tasks:
return
_debug(f"_execute_decode: processing {len(tasks)} tasks, start_pos={start_pos}")
# Sort tasks by slot to ensure batch index aligns with slot (KV cache position)
# Task at slot 0 → batch index 0 → KV stored at cache[0]
# Task at slot 1 → batch index 1 → KV stored at cache[1]
tasks = sorted(tasks, key=lambda t: t.slot)
input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device)
for i, task in enumerate(tasks):
if task.output_ids:
input_ids[i] = task.output_ids[-1]
else:
input_ids[i] = task.prompt_ids[-1]
input_tensor = input_ids.unsqueeze(1) # shape: (batch, 1)
        # Attention mask for the single new token per sequence: shape (batch, 1)
        active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device)
_debug(
f"_execute_decode: input_tensor shape={input_tensor.shape}, active_mask shape={active_mask.shape}"
)
try:
with torch.inference_mode():
outputs = self.model(
input_tensor,
input_mask=active_mask,
persistent_key_values=self.kv_cache,
start_pos=start_pos,
)
_debug(
f"_execute_decode: model forward done, logits shape={outputs['logits'].shape}"
)
logits = outputs["logits"][:, -1, :]
except Exception as e:
_debug(f"_execute_decode: ERROR: {e}")
raise
next_token_ids = []
for i, task in enumerate(tasks):
logit = logits[i : i + 1]
logit = apply_sampling_strategies(
logit,
task.temperature,
task.top_k,
task.top_p,
)
probs = torch.softmax(logit, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
next_token_ids.append(next_token.item())
_debug(f"_execute_decode: next_tokens={next_token_ids}")
for task, next_token in zip(tasks, next_token_ids):
task.output_ids.append(next_token)
task.output_tokens += 1
pos = task.input_tokens + task.output_tokens
if task.slot >= 0 and pos < self.max_seq_len:
self.seq_mask[task.slot, pos] = True
if task.stream_callback:
token_str = self.tokenizer.decode([next_token])
task.stream_callback(token_str)
# Check if any task reached max_tokens or stop token
for task in tasks:
if task.output_tokens >= task.max_tokens or (
task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids
):
_debug(
f"decode: task {task.task_id} finished, output_tokens={task.output_tokens}, max_tokens={task.max_tokens}"
)
if task.stream_callback:
task.stream_callback("[DONE]")
def _run_generation_loop(self) -> None:
"""Main generation loop with continuous batching."""
_debug("generation_loop: started")
while self._running:
self._remove_finished_tasks()
self._refill_active_batch()
if not self.active_tasks:
self._task_event.wait(timeout=0.01)
self._task_event.clear()
continue
_debug(
f"generation_loop: active={len(self.active_tasks)}, waiting={len(self.waiting_queue)}"
)
new_tasks = [t for t in self.active_tasks if t.output_tokens == 0]
decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0]
_debug(
f"generation_loop: new_tasks={len(new_tasks)}, decode_tasks={len(decode_tasks)}"
)
for t in self.active_tasks:
_debug(
f" active task {t.task_id}: output_tokens={t.output_tokens}, input_tokens={t.input_tokens}"
)
if decode_tasks:
start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks)
else:
start_pos = 0
# First run prefill for new tasks
if new_tasks:
_debug(f"generation_loop: running prefill for {len(new_tasks)} tasks")
self._execute_prefill(new_tasks)
_debug(f"generation_loop: prefill done")
            # Decode only the freshly prefilled tasks this iteration; tasks that were
            # already decoding resume on the next pass through the loop.
            decode_tasks = new_tasks
            start_pos = max(t.input_tokens for t in decode_tasks)
_debug(
f"generation_loop: after prefill, decode_tasks={len(decode_tasks)}, start_pos={start_pos}"
)
if decode_tasks:
_debug(
f"generation_loop: running decode for {len(decode_tasks)} tasks, start_pos={start_pos}"
)
self._execute_decode(decode_tasks, start_pos)
_debug(f"generation_loop: decode done")
if not self.active_tasks and not self.waiting_queue:
time.sleep(0.001)
def start(self) -> None:
"""Start the generation loop in a background thread."""
if not self._running:
_debug("InferenceScheduler.start: starting loop thread")
self._running = True
self._loop_thread = threading.Thread(target=self._run_generation_loop)
self._loop_thread.daemon = True
self._loop_thread.start()
_debug("InferenceScheduler.start: loop thread started")
def stop(self) -> None:
"""Stop the generation loop."""
self._running = False
if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=1.0)
def get_stats(self) -> Dict[str, Any]:
"""Get scheduler statistics."""
return {
"total_tasks": self._total_tasks,
"total_tokens": self._total_tokens,
"active_tasks": len(self.active_tasks),
"waiting_queue": len(self.waiting_queue),
}
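    # Example (illustrative) snapshot returned by get_stats():
    #     {"total_tasks": 3, "total_tokens": 512, "active_tasks": 1, "waiting_queue": 0}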
class InferenceEngine:
"""Unified inference engine for continuous batching.
Provides a single interface for:
- Single request generation (streaming or non-streaming)
- Batch request generation (streaming or non-streaming)
"""
def __init__(
self,
parameter: ModelParameter,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
):
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
model_params = next(self.model.parameters())
self.device = model_params.device
self.dtype = model_params.dtype
self.scheduler = InferenceScheduler(
model=self.model,
tokenizer=self.tokenizer,
config=self.config,
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
device=self.device,
dtype=self.dtype,
)
self.kv_cache = self.scheduler.kv_cache
self.seq_mask = self.scheduler.seq_mask
_debug("InferenceEngine: starting scheduler")
self.scheduler.start()
_debug("InferenceEngine: scheduler started")
def generate(
self,
prompt: Union[str, List[str]],
stream: bool = False,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
) -> Union[Generator[str, None, None], str, List[str]]:
"""Unified generation interface."""
is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt]
if stream:
return self._generate_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
else:
return self._generate_non_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
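    # Example (illustrative, assuming an InferenceEngine instance bound to `engine`):
    # the call styles supported by generate().
    #
    #     text = engine.generate("Tell me a joke", max_tokens=64)        # -> str
    #     texts = engine.generate(["Hi there", "Bye"], max_tokens=32)    # -> List[str]
    #     for chunk in engine.generate("Hi there", stream=True, max_tokens=32):
    #         print(chunk, end="", flush=True)                           # single-prompt streaming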
def generate_with_request(
self, request: GenerationRequest
) -> Union[Generator[str, None, None], str, List[str]]:
"""Generate with GenerationRequest object."""
prompt = build_prompt(request.query, request.history)
return self.generate(
prompt=prompt,
stream=request.stream,
max_tokens=request.max_len,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
def _generate_streaming(
self,
prompts: List[str],
is_batch: bool,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
"""Generate with streaming output (synchronous)."""
results = []
_debug(f"_generate_streaming: prompts={len(prompts)}")
if is_batch:
raise NotImplementedError("Batch streaming is not implemented yet")
def make_callback(idx: int):
def cb(token: str):
_debug(f"callback[{idx}]: token={token!r}")
results.append(token)
return cb
for i, p in enumerate(prompts):
_debug(f"_generate_streaming: adding task {i}: {p[:30]}...")
self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=make_callback(i),
)
def gen():
_debug("generator: start yielding")
while True:
# Yield accumulated tokens
while results:
token = results.pop(0)
if token == "[DONE]":
_debug("generator: got [DONE]")
return
_debug(f"generator: yielding {token!r}")
yield token
time.sleep(0.01)
return gen()
def _generate_non_streaming(
self,
prompts: List[str],
is_batch: bool,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
) -> Union[str, List[str]]:
"""Generate without streaming."""
results = ["" for _ in range(len(prompts))]
done_flags = [False] * len(prompts)
lock = threading.Lock()
def make_callback(idx: int):
def cb(token: str):
if token == "[DONE]":
done_flags[idx] = True
else:
with lock:
results[idx] += token
return cb
for i, p in enumerate(prompts):
self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=make_callback(i),
)
while not all(done_flags):
time.sleep(0.001)
return results if is_batch else results[0]
def get_stats(self) -> Dict[str, Any]:
"""Get engine statistics."""
return self.scheduler.get_stats()
def shutdown(self) -> None:
"""Shutdown the engine."""
self.scheduler.stop()
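# Example (illustrative end-to-end sketch; assumes a loaded ModelParameter bound to `params`):
#
#     engine = InferenceEngine(params, max_batch_size=4)
#     try:
#         print(engine.generate("What is continuous batching?", max_tokens=128))
#     finally:
#         engine.shutdown()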