refactor: 精简推理引擎代码,优化参数传递规范
This commit is contained in:
parent
ab5e207f42
commit
bbeaff4c60
|
|
@ -45,17 +45,31 @@ class GenerationRequest:
|
|||
raise ValueError("temperature must be a non-negative number")
|
||||
|
||||
|
||||
class _StreamingResult:
|
||||
"""Streaming result holder with event-based notification."""
|
||||
class _Result:
|
||||
"""Unified result holder for streaming/non-streaming modes."""
|
||||
|
||||
def __init__(self):
|
||||
self.tokens: List[str] = []
|
||||
self._event = threading.Event()
|
||||
def __init__(self, count: int = 1, stream: bool = False):
|
||||
self._stream = stream
|
||||
self._lock = threading.Lock()
|
||||
self._event = threading.Event()
|
||||
self.tokens: List[str] = []
|
||||
self.results: List[str] = [""] * count if count > 1 else [""]
|
||||
self.done_flags: List[bool] = [False] * count
|
||||
self._completed_count = 0
|
||||
|
||||
def append(self, token: str):
|
||||
def append(self, token: str, idx: int = 0):
|
||||
with self._lock:
|
||||
self.tokens.append(token)
|
||||
if self._stream:
|
||||
self.tokens.append(token)
|
||||
else:
|
||||
if token == "[DONE]":
|
||||
if not self.done_flags[idx]:
|
||||
self.done_flags[idx] = True
|
||||
self._completed_count += 1
|
||||
if self._completed_count == len(self.results):
|
||||
self._event.set()
|
||||
else:
|
||||
self.results[idx] += token
|
||||
self._event.set()
|
||||
|
||||
def pop_all(self) -> List[str]:
|
||||
|
|
@ -69,35 +83,6 @@ class _StreamingResult:
|
|||
def wait(self, timeout: float = None) -> bool:
|
||||
return self._event.wait(timeout=timeout)
|
||||
|
||||
|
||||
class _NonStreamingResult:
|
||||
"""Non-streaming result holder with event-based completion notification."""
|
||||
|
||||
def __init__(self, count: int):
|
||||
self.results: List[str] = [""] * count
|
||||
self.done_flags: List[bool] = [False] * count
|
||||
self._completed_count = 0
|
||||
self._event = threading.Event()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def append(self, idx: int, token: str):
|
||||
with self._lock:
|
||||
if token == "[DONE]":
|
||||
if not self.done_flags[idx]:
|
||||
self.done_flags[idx] = True
|
||||
self._completed_count += 1
|
||||
if self._completed_count == len(self.results):
|
||||
self._event.set()
|
||||
else:
|
||||
self.results[idx] += token
|
||||
|
||||
def is_all_done(self) -> bool:
|
||||
with self._lock:
|
||||
return all(self.done_flags)
|
||||
|
||||
def wait(self, timeout: float = None) -> bool:
|
||||
return self._event.wait(timeout=timeout)
|
||||
|
||||
def get_results(self) -> List[str]:
|
||||
with self._lock:
|
||||
return self.results.copy()
|
||||
|
|
@ -233,7 +218,7 @@ class InferenceEngine:
|
|||
if is_batch:
|
||||
raise NotImplementedError("Batch streaming is not implemented yet")
|
||||
|
||||
result = _StreamingResult()
|
||||
result = _Result(stream=True)
|
||||
|
||||
task_id = self.scheduler.add_task(
|
||||
prompt=prompts[0],
|
||||
|
|
@ -272,7 +257,7 @@ class InferenceEngine:
|
|||
top_k: int,
|
||||
) -> Union[str, List[str]]:
|
||||
"""Generate without streaming."""
|
||||
result = _NonStreamingResult(len(prompts))
|
||||
result = _Result(count=len(prompts))
|
||||
|
||||
for i, p in enumerate(prompts):
|
||||
# Create closure to capture current index value using factory function
|
||||
|
|
|
|||
|
|
@ -168,9 +168,10 @@ class Task:
|
|||
|
||||
def is_finished(self, stop_ids: List[int]) -> bool:
|
||||
"""Check if task is finished."""
|
||||
if self.output_ids and self.output_ids[-1] in stop_ids:
|
||||
return True
|
||||
return self.output_tokens >= self.max_tokens
|
||||
return (
|
||||
bool(self.output_ids and self.output_ids[-1] in stop_ids)
|
||||
or self.output_tokens >= self.max_tokens
|
||||
)
|
||||
|
||||
|
||||
def apply_sampling_strategies(
|
||||
|
|
@ -360,58 +361,48 @@ class InferenceScheduler:
|
|||
return
|
||||
|
||||
with self._lock:
|
||||
to_add = []
|
||||
for _ in range(min(available_slots, len(self.waiting_queue))):
|
||||
if self.waiting_queue:
|
||||
task = self.waiting_queue.pop(0)
|
||||
task.status = TaskStatus.RUNNING
|
||||
to_add.append(task)
|
||||
|
||||
to_add = [
|
||||
self.waiting_queue.pop(0)
|
||||
for _ in range(min(available_slots, len(self.waiting_queue)))
|
||||
]
|
||||
for task in to_add:
|
||||
for i in range(self.max_batch_size):
|
||||
if all(t.slot != i for t in self.active_tasks):
|
||||
task.slot = i
|
||||
break
|
||||
task.slot = self._allocate_slot()
|
||||
task.status = TaskStatus.RUNNING
|
||||
self.active_tasks.append(task)
|
||||
|
||||
def _allocate_slot(self) -> int:
|
||||
"""Allocate an available slot for a task."""
|
||||
for i in range(self.max_batch_size):
|
||||
if not any(t.slot == i for t in self.active_tasks):
|
||||
return i
|
||||
return -1
|
||||
|
||||
def _execute_prefill(self, tasks: List[Task]) -> None:
|
||||
"""Execute Prefill phase with incremental prefill support."""
|
||||
if not tasks:
|
||||
return
|
||||
|
||||
# Group tasks by their prefix_len to handle different prefill scenarios
|
||||
fully_cached_tasks = [] # prefix_len == total_len, skip prefill
|
||||
partial_prefill_tasks = [] # prefix_len > 0, need incremental prefill
|
||||
full_prefill_tasks = [] # prefix_len == 0, full prefill
|
||||
|
||||
# Group tasks by prefix cache status
|
||||
fully_cached, partial, full = [], [], []
|
||||
for task in tasks:
|
||||
total_len = len(task.prompt_ids)
|
||||
prefix_len = task.prefix_len
|
||||
|
||||
total_len, prefix_len = len(task.prompt_ids), task.prefix_len
|
||||
if prefix_len == total_len:
|
||||
# Scenario 1: complete match, skip prefill
|
||||
task.input_tokens = total_len
|
||||
task.output_tokens = 0
|
||||
fully_cached_tasks.append(task)
|
||||
fully_cached.append(task)
|
||||
elif prefix_len > 0:
|
||||
# Scenario 2: partial match, incremental prefill
|
||||
partial_prefill_tasks.append(task)
|
||||
partial.append(task)
|
||||
else:
|
||||
# Scenario 3: no match, full prefill
|
||||
full_prefill_tasks.append(task)
|
||||
full.append(task)
|
||||
|
||||
# Handle fully cached tasks - update seq_mask
|
||||
for task in fully_cached_tasks:
|
||||
if task.slot >= 0:
|
||||
self.seq_mask[task.slot, : task.input_tokens] = True
|
||||
# Handle fully cached tasks
|
||||
for t in fully_cached:
|
||||
t.input_tokens, t.output_tokens = len(t.prompt_ids), 0
|
||||
if t.slot >= 0:
|
||||
self.seq_mask[t.slot, : t.input_tokens] = True
|
||||
|
||||
# Execute full prefill for new prefixes
|
||||
if full_prefill_tasks:
|
||||
self._execute_full_prefill(full_prefill_tasks)
|
||||
|
||||
# Execute incremental prefill for partial matches
|
||||
if partial_prefill_tasks:
|
||||
self._execute_partial_prefill(partial_prefill_tasks)
|
||||
if full:
|
||||
self._execute_full_prefill(full)
|
||||
if partial:
|
||||
self._execute_partial_prefill(partial)
|
||||
|
||||
def _execute_full_prefill(self, tasks: List[Task]) -> None:
|
||||
"""Execute full prefill for tasks without prefix cache."""
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import json
|
|||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
|
|
@ -134,78 +134,6 @@ class CompletionResponse(BaseModel):
|
|||
choices: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class StreamCompletionResponse(BaseModel):
|
||||
id: str = "chatcmpl-default"
|
||||
object: str = "chat.completion.chunk"
|
||||
created: int = 0
|
||||
model: str = "astrai"
|
||||
choices: List[Dict[str, Any]]
|
||||
|
||||
|
||||
def convert_messages_to_history(
|
||||
messages: List[ChatMessage],
|
||||
) -> tuple[Optional[str], Optional[List[Tuple[str, str]]]]:
|
||||
"""Convert OpenAI-style messages to system_prompt and history."""
|
||||
system_prompt = None
|
||||
history: List[Tuple[str, str]] = []
|
||||
user_buffer = []
|
||||
assistant_buffer = []
|
||||
for msg in messages:
|
||||
if msg.role == "system":
|
||||
system_prompt = msg.content
|
||||
elif msg.role == "user":
|
||||
if assistant_buffer:
|
||||
# Flush previous pair
|
||||
history.append(("".join(user_buffer), "".join(assistant_buffer)))
|
||||
user_buffer = []
|
||||
assistant_buffer = []
|
||||
user_buffer.append(msg.content)
|
||||
elif msg.role == "assistant":
|
||||
assistant_buffer.append(msg.content)
|
||||
else:
|
||||
logger.warning(f"Unknown role {msg.role}")
|
||||
return system_prompt, history if history else None
|
||||
|
||||
|
||||
def convert_messages_to_prompt(
|
||||
messages: List[ChatMessage], engine: InferenceEngine = None
|
||||
) -> str:
|
||||
"""Convert messages to prompt string.
|
||||
|
||||
Args:
|
||||
messages: List of ChatMessage objects
|
||||
engine: InferenceEngine instance for accessing tokenizer
|
||||
|
||||
Returns:
|
||||
str: Formatted prompt string
|
||||
"""
|
||||
# Convert to dict format for chat template
|
||||
msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
|
||||
|
||||
# Extract system prompt if present
|
||||
system_prompt = None
|
||||
filtered_messages = []
|
||||
for msg in msg_dicts:
|
||||
if msg["role"] == "system":
|
||||
system_prompt = msg["content"]
|
||||
else:
|
||||
filtered_messages.append(msg)
|
||||
|
||||
# Use engine's tokenizer chat template if available
|
||||
if engine is not None and engine.tokenizer is not None:
|
||||
return engine.tokenizer.apply_chat_template(
|
||||
filtered_messages, system_prompt=system_prompt, tokenize=False
|
||||
)
|
||||
|
||||
# Fallback: simple concatenation (deprecated)
|
||||
prompt_parts = []
|
||||
for msg in filtered_messages:
|
||||
prompt_parts.append(
|
||||
f"<|im▁start|>{msg['role']}\n{msg['content']}<|im▁end|>"
|
||||
)
|
||||
return "\n".join(prompt_parts) + "\n<|im▁start|>assistant\n"
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {
|
||||
|
|
@ -233,7 +161,12 @@ async def chat_completion(request: ChatCompletionRequest):
|
|||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
|
||||
# Convert messages to prompt using engine's tokenizer
|
||||
prompt = convert_messages_to_prompt(request.messages, engine=_engine)
|
||||
# Extract system prompt if present, then apply chat template
|
||||
# Apply chat template directly with messages
|
||||
prompt = _engine.tokenizer.apply_chat_template(
|
||||
[{"role": m.role, "content": m.content} for m in request.messages],
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
# Streaming response (use synchronous generator)
|
||||
|
|
|
|||
|
|
@ -209,9 +209,9 @@ class AutoTokenizer:
|
|||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
system_prompt: Optional system prompt string.
|
||||
system_prompt: Optional system prompt string (auto-converted to first message).
|
||||
tokenize: Whether to return token IDs (True) or raw string (False).
|
||||
add_generation_prompt: Whether to add the generation prompt (default: False).
|
||||
add_generation_prompt: Whether to add the generation prompt (default: True).
|
||||
**kwargs: Additional variables to pass to the template.
|
||||
|
||||
Returns:
|
||||
|
|
@ -225,10 +225,13 @@ class AutoTokenizer:
|
|||
"Chat template not set. Use set_chat_template() to set a template first."
|
||||
)
|
||||
|
||||
# Auto-convert system_prompt to first message if provided
|
||||
if system_prompt:
|
||||
messages = [{"role": "system", "content": system_prompt}] + list(messages)
|
||||
|
||||
# Render the template
|
||||
rendered = self._chat_template.render(
|
||||
messages=messages,
|
||||
system_prompt=system_prompt,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue