diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 6177d75..f701c94 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -45,17 +45,31 @@ class GenerationRequest: raise ValueError("temperature must be a non-negative number") -class _StreamingResult: - """Streaming result holder with event-based notification.""" +class _Result: + """Unified result holder for streaming/non-streaming modes.""" - def __init__(self): - self.tokens: List[str] = [] - self._event = threading.Event() + def __init__(self, count: int = 1, stream: bool = False): + self._stream = stream self._lock = threading.Lock() + self._event = threading.Event() + self.tokens: List[str] = [] + self.results: List[str] = [""] * count if count > 1 else [""] + self.done_flags: List[bool] = [False] * count + self._completed_count = 0 - def append(self, token: str): + def append(self, token: str, idx: int = 0): with self._lock: - self.tokens.append(token) + if self._stream: + self.tokens.append(token) + else: + if token == "[DONE]": + if not self.done_flags[idx]: + self.done_flags[idx] = True + self._completed_count += 1 + if self._completed_count == len(self.results): + self._event.set() + else: + self.results[idx] += token self._event.set() def pop_all(self) -> List[str]: @@ -69,35 +83,6 @@ class _StreamingResult: def wait(self, timeout: float = None) -> bool: return self._event.wait(timeout=timeout) - -class _NonStreamingResult: - """Non-streaming result holder with event-based completion notification.""" - - def __init__(self, count: int): - self.results: List[str] = [""] * count - self.done_flags: List[bool] = [False] * count - self._completed_count = 0 - self._event = threading.Event() - self._lock = threading.Lock() - - def append(self, idx: int, token: str): - with self._lock: - if token == "[DONE]": - if not self.done_flags[idx]: - self.done_flags[idx] = True - self._completed_count += 1 - if self._completed_count == len(self.results): - self._event.set() - else: - self.results[idx] += token - - def is_all_done(self) -> bool: - with self._lock: - return all(self.done_flags) - - def wait(self, timeout: float = None) -> bool: - return self._event.wait(timeout=timeout) - def get_results(self) -> List[str]: with self._lock: return self.results.copy() @@ -233,7 +218,7 @@ class InferenceEngine: if is_batch: raise NotImplementedError("Batch streaming is not implemented yet") - result = _StreamingResult() + result = _Result(stream=True) task_id = self.scheduler.add_task( prompt=prompts[0], @@ -272,7 +257,7 @@ class InferenceEngine: top_k: int, ) -> Union[str, List[str]]: """Generate without streaming.""" - result = _NonStreamingResult(len(prompts)) + result = _Result(count=len(prompts)) for i, p in enumerate(prompts): # Create closure to capture current index value using factory function diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 24f4804..510b256 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -168,9 +168,10 @@ class Task: def is_finished(self, stop_ids: List[int]) -> bool: """Check if task is finished.""" - if self.output_ids and self.output_ids[-1] in stop_ids: - return True - return self.output_tokens >= self.max_tokens + return ( + bool(self.output_ids and self.output_ids[-1] in stop_ids) + or self.output_tokens >= self.max_tokens + ) def apply_sampling_strategies( @@ -360,58 +361,48 @@ class InferenceScheduler: return with self._lock: - to_add = [] - for _ in range(min(available_slots, len(self.waiting_queue))): - if self.waiting_queue: - task = self.waiting_queue.pop(0) - task.status = TaskStatus.RUNNING - to_add.append(task) - + to_add = [ + self.waiting_queue.pop(0) + for _ in range(min(available_slots, len(self.waiting_queue))) + ] for task in to_add: - for i in range(self.max_batch_size): - if all(t.slot != i for t in self.active_tasks): - task.slot = i - break + task.slot = self._allocate_slot() + task.status = TaskStatus.RUNNING self.active_tasks.append(task) + def _allocate_slot(self) -> int: + """Allocate an available slot for a task.""" + for i in range(self.max_batch_size): + if not any(t.slot == i for t in self.active_tasks): + return i + return -1 + def _execute_prefill(self, tasks: List[Task]) -> None: """Execute Prefill phase with incremental prefill support.""" if not tasks: return - # Group tasks by their prefix_len to handle different prefill scenarios - fully_cached_tasks = [] # prefix_len == total_len, skip prefill - partial_prefill_tasks = [] # prefix_len > 0, need incremental prefill - full_prefill_tasks = [] # prefix_len == 0, full prefill - + # Group tasks by prefix cache status + fully_cached, partial, full = [], [], [] for task in tasks: - total_len = len(task.prompt_ids) - prefix_len = task.prefix_len - + total_len, prefix_len = len(task.prompt_ids), task.prefix_len if prefix_len == total_len: - # Scenario 1: complete match, skip prefill - task.input_tokens = total_len - task.output_tokens = 0 - fully_cached_tasks.append(task) + fully_cached.append(task) elif prefix_len > 0: - # Scenario 2: partial match, incremental prefill - partial_prefill_tasks.append(task) + partial.append(task) else: - # Scenario 3: no match, full prefill - full_prefill_tasks.append(task) + full.append(task) - # Handle fully cached tasks - update seq_mask - for task in fully_cached_tasks: - if task.slot >= 0: - self.seq_mask[task.slot, : task.input_tokens] = True + # Handle fully cached tasks + for t in fully_cached: + t.input_tokens, t.output_tokens = len(t.prompt_ids), 0 + if t.slot >= 0: + self.seq_mask[t.slot, : t.input_tokens] = True - # Execute full prefill for new prefixes - if full_prefill_tasks: - self._execute_full_prefill(full_prefill_tasks) - - # Execute incremental prefill for partial matches - if partial_prefill_tasks: - self._execute_partial_prefill(partial_prefill_tasks) + if full: + self._execute_full_prefill(full) + if partial: + self._execute_partial_prefill(partial) def _execute_full_prefill(self, tasks: List[Task]) -> None: """Execute full prefill for tasks without prefix cache.""" diff --git a/astrai/inference/server.py b/astrai/inference/server.py index 86b25c9..8a68ac0 100644 --- a/astrai/inference/server.py +++ b/astrai/inference/server.py @@ -9,7 +9,7 @@ import json import logging from contextlib import asynccontextmanager from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import torch import uvicorn @@ -134,78 +134,6 @@ class CompletionResponse(BaseModel): choices: List[Dict[str, Any]] -class StreamCompletionResponse(BaseModel): - id: str = "chatcmpl-default" - object: str = "chat.completion.chunk" - created: int = 0 - model: str = "astrai" - choices: List[Dict[str, Any]] - - -def convert_messages_to_history( - messages: List[ChatMessage], -) -> tuple[Optional[str], Optional[List[Tuple[str, str]]]]: - """Convert OpenAI-style messages to system_prompt and history.""" - system_prompt = None - history: List[Tuple[str, str]] = [] - user_buffer = [] - assistant_buffer = [] - for msg in messages: - if msg.role == "system": - system_prompt = msg.content - elif msg.role == "user": - if assistant_buffer: - # Flush previous pair - history.append(("".join(user_buffer), "".join(assistant_buffer))) - user_buffer = [] - assistant_buffer = [] - user_buffer.append(msg.content) - elif msg.role == "assistant": - assistant_buffer.append(msg.content) - else: - logger.warning(f"Unknown role {msg.role}") - return system_prompt, history if history else None - - -def convert_messages_to_prompt( - messages: List[ChatMessage], engine: InferenceEngine = None -) -> str: - """Convert messages to prompt string. - - Args: - messages: List of ChatMessage objects - engine: InferenceEngine instance for accessing tokenizer - - Returns: - str: Formatted prompt string - """ - # Convert to dict format for chat template - msg_dicts = [{"role": m.role, "content": m.content} for m in messages] - - # Extract system prompt if present - system_prompt = None - filtered_messages = [] - for msg in msg_dicts: - if msg["role"] == "system": - system_prompt = msg["content"] - else: - filtered_messages.append(msg) - - # Use engine's tokenizer chat template if available - if engine is not None and engine.tokenizer is not None: - return engine.tokenizer.apply_chat_template( - filtered_messages, system_prompt=system_prompt, tokenize=False - ) - - # Fallback: simple concatenation (deprecated) - prompt_parts = [] - for msg in filtered_messages: - prompt_parts.append( - f"<|im▁start|>{msg['role']}\n{msg['content']}<|im▁end|>" - ) - return "\n".join(prompt_parts) + "\n<|im▁start|>assistant\n" - - @app.get("/health") async def health(): return { @@ -233,7 +161,12 @@ async def chat_completion(request: ChatCompletionRequest): raise HTTPException(status_code=503, detail="Engine not initialized") # Convert messages to prompt using engine's tokenizer - prompt = convert_messages_to_prompt(request.messages, engine=_engine) + # Extract system prompt if present, then apply chat template + # Apply chat template directly with messages + prompt = _engine.tokenizer.apply_chat_template( + [{"role": m.role, "content": m.content} for m in request.messages], + tokenize=False, + ) if request.stream: # Streaming response (use synchronous generator) diff --git a/astrai/tokenize/tokenizer.py b/astrai/tokenize/tokenizer.py index b40ca0f..14fcb18 100644 --- a/astrai/tokenize/tokenizer.py +++ b/astrai/tokenize/tokenizer.py @@ -209,9 +209,9 @@ class AutoTokenizer: Args: messages: List of message dicts with 'role' and 'content'. - system_prompt: Optional system prompt string. + system_prompt: Optional system prompt string (auto-converted to first message). tokenize: Whether to return token IDs (True) or raw string (False). - add_generation_prompt: Whether to add the generation prompt (default: False). + add_generation_prompt: Whether to add the generation prompt (default: True). **kwargs: Additional variables to pass to the template. Returns: @@ -225,10 +225,13 @@ class AutoTokenizer: "Chat template not set. Use set_chat_template() to set a template first." ) + # Auto-convert system_prompt to first message if provided + if system_prompt: + messages = [{"role": "system", "content": system_prompt}] + list(messages) + # Render the template rendered = self._chat_template.render( messages=messages, - system_prompt=system_prompt, add_generation_prompt=add_generation_prompt, **kwargs, )