refactor: 精简推理引擎代码,优化参数传递规范

This commit is contained in:
ViperEkura 2026-04-09 14:17:48 +08:00
parent ab5e207f42
commit bbeaff4c60
4 changed files with 68 additions and 156 deletions

View File

@ -45,17 +45,31 @@ class GenerationRequest:
raise ValueError("temperature must be a non-negative number") raise ValueError("temperature must be a non-negative number")
class _StreamingResult: class _Result:
"""Streaming result holder with event-based notification.""" """Unified result holder for streaming/non-streaming modes."""
def __init__(self): def __init__(self, count: int = 1, stream: bool = False):
self.tokens: List[str] = [] self._stream = stream
self._event = threading.Event()
self._lock = threading.Lock() self._lock = threading.Lock()
self._event = threading.Event()
self.tokens: List[str] = []
self.results: List[str] = [""] * count if count > 1 else [""]
self.done_flags: List[bool] = [False] * count
self._completed_count = 0
def append(self, token: str): def append(self, token: str, idx: int = 0):
with self._lock: with self._lock:
self.tokens.append(token) if self._stream:
self.tokens.append(token)
else:
if token == "[DONE]":
if not self.done_flags[idx]:
self.done_flags[idx] = True
self._completed_count += 1
if self._completed_count == len(self.results):
self._event.set()
else:
self.results[idx] += token
self._event.set() self._event.set()
def pop_all(self) -> List[str]: def pop_all(self) -> List[str]:
@ -69,35 +83,6 @@ class _StreamingResult:
def wait(self, timeout: float = None) -> bool: def wait(self, timeout: float = None) -> bool:
return self._event.wait(timeout=timeout) return self._event.wait(timeout=timeout)
class _NonStreamingResult:
"""Non-streaming result holder with event-based completion notification."""
def __init__(self, count: int):
self.results: List[str] = [""] * count
self.done_flags: List[bool] = [False] * count
self._completed_count = 0
self._event = threading.Event()
self._lock = threading.Lock()
def append(self, idx: int, token: str):
with self._lock:
if token == "[DONE]":
if not self.done_flags[idx]:
self.done_flags[idx] = True
self._completed_count += 1
if self._completed_count == len(self.results):
self._event.set()
else:
self.results[idx] += token
def is_all_done(self) -> bool:
with self._lock:
return all(self.done_flags)
def wait(self, timeout: float = None) -> bool:
return self._event.wait(timeout=timeout)
def get_results(self) -> List[str]: def get_results(self) -> List[str]:
with self._lock: with self._lock:
return self.results.copy() return self.results.copy()
@ -233,7 +218,7 @@ class InferenceEngine:
if is_batch: if is_batch:
raise NotImplementedError("Batch streaming is not implemented yet") raise NotImplementedError("Batch streaming is not implemented yet")
result = _StreamingResult() result = _Result(stream=True)
task_id = self.scheduler.add_task( task_id = self.scheduler.add_task(
prompt=prompts[0], prompt=prompts[0],
@ -272,7 +257,7 @@ class InferenceEngine:
top_k: int, top_k: int,
) -> Union[str, List[str]]: ) -> Union[str, List[str]]:
"""Generate without streaming.""" """Generate without streaming."""
result = _NonStreamingResult(len(prompts)) result = _Result(count=len(prompts))
for i, p in enumerate(prompts): for i, p in enumerate(prompts):
# Create closure to capture current index value using factory function # Create closure to capture current index value using factory function

View File

@ -168,9 +168,10 @@ class Task:
def is_finished(self, stop_ids: List[int]) -> bool: def is_finished(self, stop_ids: List[int]) -> bool:
"""Check if task is finished.""" """Check if task is finished."""
if self.output_ids and self.output_ids[-1] in stop_ids: return (
return True bool(self.output_ids and self.output_ids[-1] in stop_ids)
return self.output_tokens >= self.max_tokens or self.output_tokens >= self.max_tokens
)
def apply_sampling_strategies( def apply_sampling_strategies(
@ -360,58 +361,48 @@ class InferenceScheduler:
return return
with self._lock: with self._lock:
to_add = [] to_add = [
for _ in range(min(available_slots, len(self.waiting_queue))): self.waiting_queue.pop(0)
if self.waiting_queue: for _ in range(min(available_slots, len(self.waiting_queue)))
task = self.waiting_queue.pop(0) ]
task.status = TaskStatus.RUNNING
to_add.append(task)
for task in to_add: for task in to_add:
for i in range(self.max_batch_size): task.slot = self._allocate_slot()
if all(t.slot != i for t in self.active_tasks): task.status = TaskStatus.RUNNING
task.slot = i
break
self.active_tasks.append(task) self.active_tasks.append(task)
def _allocate_slot(self) -> int:
"""Allocate an available slot for a task."""
for i in range(self.max_batch_size):
if not any(t.slot == i for t in self.active_tasks):
return i
return -1
def _execute_prefill(self, tasks: List[Task]) -> None: def _execute_prefill(self, tasks: List[Task]) -> None:
"""Execute Prefill phase with incremental prefill support.""" """Execute Prefill phase with incremental prefill support."""
if not tasks: if not tasks:
return return
# Group tasks by their prefix_len to handle different prefill scenarios # Group tasks by prefix cache status
fully_cached_tasks = [] # prefix_len == total_len, skip prefill fully_cached, partial, full = [], [], []
partial_prefill_tasks = [] # prefix_len > 0, need incremental prefill
full_prefill_tasks = [] # prefix_len == 0, full prefill
for task in tasks: for task in tasks:
total_len = len(task.prompt_ids) total_len, prefix_len = len(task.prompt_ids), task.prefix_len
prefix_len = task.prefix_len
if prefix_len == total_len: if prefix_len == total_len:
# Scenario 1: complete match, skip prefill fully_cached.append(task)
task.input_tokens = total_len
task.output_tokens = 0
fully_cached_tasks.append(task)
elif prefix_len > 0: elif prefix_len > 0:
# Scenario 2: partial match, incremental prefill partial.append(task)
partial_prefill_tasks.append(task)
else: else:
# Scenario 3: no match, full prefill full.append(task)
full_prefill_tasks.append(task)
# Handle fully cached tasks - update seq_mask # Handle fully cached tasks
for task in fully_cached_tasks: for t in fully_cached:
if task.slot >= 0: t.input_tokens, t.output_tokens = len(t.prompt_ids), 0
self.seq_mask[task.slot, : task.input_tokens] = True if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True
# Execute full prefill for new prefixes if full:
if full_prefill_tasks: self._execute_full_prefill(full)
self._execute_full_prefill(full_prefill_tasks) if partial:
self._execute_partial_prefill(partial)
# Execute incremental prefill for partial matches
if partial_prefill_tasks:
self._execute_partial_prefill(partial_prefill_tasks)
def _execute_full_prefill(self, tasks: List[Task]) -> None: def _execute_full_prefill(self, tasks: List[Task]) -> None:
"""Execute full prefill for tasks without prefix cache.""" """Execute full prefill for tasks without prefix cache."""

View File

@ -9,7 +9,7 @@ import json
import logging import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional
import torch import torch
import uvicorn import uvicorn
@ -134,78 +134,6 @@ class CompletionResponse(BaseModel):
choices: List[Dict[str, Any]] choices: List[Dict[str, Any]]
class StreamCompletionResponse(BaseModel):
id: str = "chatcmpl-default"
object: str = "chat.completion.chunk"
created: int = 0
model: str = "astrai"
choices: List[Dict[str, Any]]
def convert_messages_to_history(
messages: List[ChatMessage],
) -> tuple[Optional[str], Optional[List[Tuple[str, str]]]]:
"""Convert OpenAI-style messages to system_prompt and history."""
system_prompt = None
history: List[Tuple[str, str]] = []
user_buffer = []
assistant_buffer = []
for msg in messages:
if msg.role == "system":
system_prompt = msg.content
elif msg.role == "user":
if assistant_buffer:
# Flush previous pair
history.append(("".join(user_buffer), "".join(assistant_buffer)))
user_buffer = []
assistant_buffer = []
user_buffer.append(msg.content)
elif msg.role == "assistant":
assistant_buffer.append(msg.content)
else:
logger.warning(f"Unknown role {msg.role}")
return system_prompt, history if history else None
def convert_messages_to_prompt(
messages: List[ChatMessage], engine: InferenceEngine = None
) -> str:
"""Convert messages to prompt string.
Args:
messages: List of ChatMessage objects
engine: InferenceEngine instance for accessing tokenizer
Returns:
str: Formatted prompt string
"""
# Convert to dict format for chat template
msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
# Extract system prompt if present
system_prompt = None
filtered_messages = []
for msg in msg_dicts:
if msg["role"] == "system":
system_prompt = msg["content"]
else:
filtered_messages.append(msg)
# Use engine's tokenizer chat template if available
if engine is not None and engine.tokenizer is not None:
return engine.tokenizer.apply_chat_template(
filtered_messages, system_prompt=system_prompt, tokenize=False
)
# Fallback: simple concatenation (deprecated)
prompt_parts = []
for msg in filtered_messages:
prompt_parts.append(
f"<im▁start>{msg['role']}\n{msg['content']}<im▁end>"
)
return "\n".join(prompt_parts) + "\n<im▁start>assistant\n"
@app.get("/health") @app.get("/health")
async def health(): async def health():
return { return {
@ -233,7 +161,12 @@ async def chat_completion(request: ChatCompletionRequest):
raise HTTPException(status_code=503, detail="Engine not initialized") raise HTTPException(status_code=503, detail="Engine not initialized")
# Convert messages to prompt using engine's tokenizer # Convert messages to prompt using engine's tokenizer
prompt = convert_messages_to_prompt(request.messages, engine=_engine) # Extract system prompt if present, then apply chat template
# Apply chat template directly with messages
prompt = _engine.tokenizer.apply_chat_template(
[{"role": m.role, "content": m.content} for m in request.messages],
tokenize=False,
)
if request.stream: if request.stream:
# Streaming response (use synchronous generator) # Streaming response (use synchronous generator)

View File

@ -209,9 +209,9 @@ class AutoTokenizer:
Args: Args:
messages: List of message dicts with 'role' and 'content'. messages: List of message dicts with 'role' and 'content'.
system_prompt: Optional system prompt string. system_prompt: Optional system prompt string (auto-converted to first message).
tokenize: Whether to return token IDs (True) or raw string (False). tokenize: Whether to return token IDs (True) or raw string (False).
add_generation_prompt: Whether to add the generation prompt (default: False). add_generation_prompt: Whether to add the generation prompt (default: True).
**kwargs: Additional variables to pass to the template. **kwargs: Additional variables to pass to the template.
Returns: Returns:
@ -225,10 +225,13 @@ class AutoTokenizer:
"Chat template not set. Use set_chat_template() to set a template first." "Chat template not set. Use set_chat_template() to set a template first."
) )
# Auto-convert system_prompt to first message if provided
if system_prompt:
messages = [{"role": "system", "content": system_prompt}] + list(messages)
# Render the template # Render the template
rendered = self._chat_template.render( rendered = self._chat_template.render(
messages=messages, messages=messages,
system_prompt=system_prompt,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
**kwargs, **kwargs,
) )