refactor: 精简推理引擎代码,优化参数传递规范

This commit is contained in:
ViperEkura 2026-04-09 14:17:48 +08:00
parent ab5e207f42
commit bbeaff4c60
4 changed files with 68 additions and 156 deletions

View File

@@ -45,17 +45,31 @@ class GenerationRequest:
raise ValueError("temperature must be a non-negative number")
class _StreamingResult:
"""Streaming result holder with event-based notification."""
class _Result:
"""Unified result holder for streaming/non-streaming modes."""
def __init__(self):
self.tokens: List[str] = []
self._event = threading.Event()
def __init__(self, count: int = 1, stream: bool = False):
    """Set up result storage for *count* prompts; *stream* selects the mode."""
    self._stream = stream
    self._lock = threading.Lock()
    self._event = threading.Event()
    # Flat token queue, used only by the streaming path.
    self.tokens: List[str] = []
    # One accumulated result string per prompt; falls back to a single
    # slot whenever count <= 1 (including count == 0).
    self.results: List[str] = [""] * count if count > 1 else [""]
    self.done_flags: List[bool] = [False] * count
    self._completed_count = 0
def append(self, token: str):
def append(self, token: str, idx: int = 0):
with self._lock:
self.tokens.append(token)
if self._stream:
self.tokens.append(token)
else:
if token == "[DONE]":
if not self.done_flags[idx]:
self.done_flags[idx] = True
self._completed_count += 1
if self._completed_count == len(self.results):
self._event.set()
else:
self.results[idx] += token
self._event.set()
def pop_all(self) -> List[str]:
@@ -69,35 +83,6 @@ class _StreamingResult:
def wait(self, timeout: float = None) -> bool:
    """Block until the completion event fires; False if *timeout* elapses first."""
    return self._event.wait(timeout)
class _NonStreamingResult:
"""Non-streaming result holder with event-based completion notification."""
def __init__(self, count: int):
self.results: List[str] = [""] * count
self.done_flags: List[bool] = [False] * count
self._completed_count = 0
self._event = threading.Event()
self._lock = threading.Lock()
def append(self, idx: int, token: str):
with self._lock:
if token == "[DONE]":
if not self.done_flags[idx]:
self.done_flags[idx] = True
self._completed_count += 1
if self._completed_count == len(self.results):
self._event.set()
else:
self.results[idx] += token
def is_all_done(self) -> bool:
with self._lock:
return all(self.done_flags)
def wait(self, timeout: float = None) -> bool:
return self._event.wait(timeout=timeout)
def get_results(self) -> List[str]:
with self._lock:
return self.results.copy()
@@ -233,7 +218,7 @@ class InferenceEngine:
if is_batch:
raise NotImplementedError("Batch streaming is not implemented yet")
result = _StreamingResult()
result = _Result(stream=True)
task_id = self.scheduler.add_task(
prompt=prompts[0],
@@ -272,7 +257,7 @@
top_k: int,
) -> Union[str, List[str]]:
"""Generate without streaming."""
result = _NonStreamingResult(len(prompts))
result = _Result(count=len(prompts))
for i, p in enumerate(prompts):
# Create closure to capture current index value using factory function

View File

@@ -168,9 +168,10 @@ class Task:
def is_finished(self, stop_ids: List[int]) -> bool:
    """Return True once the task emitted a stop token or hit its token budget.

    Args:
        stop_ids: Token ids that terminate generation.

    Returns:
        bool: True when the last output token is a stop token, or when
        output_tokens has reached max_tokens.
    """
    # The span contained both the pre- and post-refactor bodies, leaving
    # unreachable duplicate returns; keep the single equivalent expression.
    # bool(...) guards the empty-output case so the expression always
    # yields a bool rather than a falsy list.
    return (
        bool(self.output_ids and self.output_ids[-1] in stop_ids)
        or self.output_tokens >= self.max_tokens
    )
def apply_sampling_strategies(
@@ -360,58 +361,48 @@ class InferenceScheduler:
return
with self._lock:
to_add = []
for _ in range(min(available_slots, len(self.waiting_queue))):
if self.waiting_queue:
task = self.waiting_queue.pop(0)
task.status = TaskStatus.RUNNING
to_add.append(task)
to_add = [
self.waiting_queue.pop(0)
for _ in range(min(available_slots, len(self.waiting_queue)))
]
for task in to_add:
for i in range(self.max_batch_size):
if all(t.slot != i for t in self.active_tasks):
task.slot = i
break
task.slot = self._allocate_slot()
task.status = TaskStatus.RUNNING
self.active_tasks.append(task)
def _allocate_slot(self) -> int:
"""Allocate an available slot for a task."""
for i in range(self.max_batch_size):
if not any(t.slot == i for t in self.active_tasks):
return i
return -1
def _execute_prefill(self, tasks: List[Task]) -> None:
    """Execute the prefill phase with incremental (prefix-cache) support.

    Tasks are grouped by how much of their prompt is already covered by
    the prefix cache:
      * prefix_len == len(prompt_ids): fully cached — no prefill needed,
        only bookkeeping and the sequence mask are updated.
      * 0 < prefix_len < len(prompt_ids): partial match — incremental
        prefill of the uncached suffix.
      * prefix_len == 0: no match — full prefill.

    The span contained interleaved pre- and post-refactor variants
    (duplicate grouping lists and cached-task handling); this is the
    single coherent final version.

    Args:
        tasks: Tasks to prefill; an empty list is a no-op.
    """
    if not tasks:
        return
    # Group tasks by prefix cache status.
    fully_cached, partial, full = [], [], []
    for task in tasks:
        total_len, prefix_len = len(task.prompt_ids), task.prefix_len
        if prefix_len == total_len:
            fully_cached.append(task)
        elif prefix_len > 0:
            partial.append(task)
        else:
            full.append(task)
    # Fully cached tasks: mark the whole prompt as valid in the sequence
    # mask and reset output accounting; no model forward pass required.
    for t in fully_cached:
        t.input_tokens, t.output_tokens = len(t.prompt_ids), 0
        if t.slot >= 0:  # slot < 0 means no batch slot was allocated
            self.seq_mask[t.slot, : t.input_tokens] = True
    # Run full prefill for new prefixes, then incremental prefill for
    # partial matches.
    if full:
        self._execute_full_prefill(full)
    if partial:
        self._execute_partial_prefill(partial)
def _execute_full_prefill(self, tasks: List[Task]) -> None:
"""Execute full prefill for tasks without prefix cache."""

View File

@@ -9,7 +9,7 @@ import json
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional
import torch
import uvicorn
@@ -134,78 +134,6 @@ class CompletionResponse(BaseModel):
choices: List[Dict[str, Any]]
class StreamCompletionResponse(BaseModel):
    """One chunk of an OpenAI-style streaming chat completion (SSE payload)."""

    # Placeholder identifiers — not unique per request.
    id: str = "chatcmpl-default"
    object: str = "chat.completion.chunk"
    created: int = 0
    model: str = "astrai"
    # Per-chunk choice dicts; presumably delta-shaped as in the OpenAI
    # streaming API — confirm against the producer.
    choices: List[Dict[str, Any]]
def convert_messages_to_history(
    messages: List[ChatMessage],
) -> tuple[Optional[str], Optional[List[Tuple[str, str]]]]:
    """Convert OpenAI-style messages to system_prompt and history.

    Walks the message list once, splitting out the system prompt and
    pairing user/assistant turns into (user_text, assistant_text) tuples.
    Returns (system_prompt, history), where history is None when no
    completed pair was found.
    """
    system_prompt = None
    history: List[Tuple[str, str]] = []
    # Consecutive same-role messages are concatenated via these buffers.
    user_buffer = []
    assistant_buffer = []
    for msg in messages:
        if msg.role == "system":
            # Last system message wins if several are present.
            system_prompt = msg.content
        elif msg.role == "user":
            if assistant_buffer:
                # A new user turn closes the previous user/assistant pair.
                history.append(("".join(user_buffer), "".join(assistant_buffer)))
                user_buffer = []
                assistant_buffer = []
            user_buffer.append(msg.content)
        elif msg.role == "assistant":
            assistant_buffer.append(msg.content)
        else:
            logger.warning(f"Unknown role {msg.role}")
    # NOTE(review): a trailing pair (list ending with an assistant message)
    # is never flushed into history, and a trailing unpaired user message is
    # dropped here (presumably it is the current prompt) — confirm intent.
    return system_prompt, history if history else None
def convert_messages_to_prompt(
    messages: List[ChatMessage], engine: InferenceEngine = None
) -> str:
    """Convert messages to a single prompt string.

    Args:
        messages: List of ChatMessage objects.
        engine: InferenceEngine instance for accessing the tokenizer; may
            be None, in which case a simple fallback format is used.

    Returns:
        str: Formatted prompt string ready for generation.
    """
    # Convert to plain dicts, the shape expected by the chat template.
    msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
    # Extract system prompt if present; it is passed to the template
    # separately from the remaining messages.
    system_prompt = None
    filtered_messages = []
    for msg in msg_dicts:
        if msg["role"] == "system":
            # Last system message wins if several are present.
            system_prompt = msg["content"]
        else:
            filtered_messages.append(msg)
    # Preferred path: render via the tokenizer's own chat template.
    if engine is not None and engine.tokenizer is not None:
        return engine.tokenizer.apply_chat_template(
            filtered_messages, system_prompt=system_prompt, tokenize=False
        )
    # Fallback: simple concatenation (deprecated). Note that the system
    # prompt is silently dropped on this path.
    prompt_parts = []
    for msg in filtered_messages:
        prompt_parts.append(
            f"<im▁start>{msg['role']}\n{msg['content']}<im▁end>"
        )
    return "\n".join(prompt_parts) + "\n<im▁start>assistant\n"
@app.get("/health")
async def health():
return {
@@ -233,7 +161,12 @@ async def chat_completion(request: ChatCompletionRequest):
raise HTTPException(status_code=503, detail="Engine not initialized")
# Convert messages to prompt using engine's tokenizer
prompt = convert_messages_to_prompt(request.messages, engine=_engine)
# Extract system prompt if present, then apply chat template
# Apply chat template directly with messages
prompt = _engine.tokenizer.apply_chat_template(
[{"role": m.role, "content": m.content} for m in request.messages],
tokenize=False,
)
if request.stream:
# Streaming response (use synchronous generator)

View File

@@ -209,9 +209,9 @@ class AutoTokenizer:
Args:
messages: List of message dicts with 'role' and 'content'.
system_prompt: Optional system prompt string.
system_prompt: Optional system prompt string (auto-converted to first message).
tokenize: Whether to return token IDs (True) or raw string (False).
add_generation_prompt: Whether to add the generation prompt (default: False).
add_generation_prompt: Whether to add the generation prompt (default: True).
**kwargs: Additional variables to pass to the template.
Returns:
@@ -225,10 +225,13 @@
"Chat template not set. Use set_chat_template() to set a template first."
)
# Auto-convert system_prompt to first message if provided
if system_prompt:
messages = [{"role": "system", "content": system_prompt}] + list(messages)
# Render the template
rendered = self._chat_template.render(
messages=messages,
system_prompt=system_prompt,
add_generation_prompt=add_generation_prompt,
**kwargs,
)