Compare commits
1 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
6961039db0 |
|
|
@ -1,9 +1,15 @@
|
||||||
"""Chat service module"""
|
"""Chat service module - Refactored with step-by-step flow"""
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
import logging
|
import logging
|
||||||
|
from typing import List, Dict, Any, AsyncGenerator, Tuple, Optional
|
||||||
|
|
||||||
|
# For Python < 3.9 compatibility
|
||||||
|
try:
|
||||||
|
from typing import List
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
from typing import List, Dict,AsyncGenerator
|
|
||||||
from luxx.models import Conversation, Message, LLMProvider
|
from luxx.models import Conversation, Message, LLMProvider
|
||||||
from luxx.tools.executor import ToolExecutor
|
from luxx.tools.executor import ToolExecutor
|
||||||
from luxx.tools.core import registry
|
from luxx.tools.core import registry
|
||||||
|
|
@ -11,85 +17,274 @@ from luxx.services.llm_client import LLMClient
|
||||||
from luxx.database import SessionLocal
|
from luxx.database import SessionLocal
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
# Maximum iterations to prevent infinite loops
|
|
||||||
MAX_ITERATIONS = 20
|
MAX_ITERATIONS = 20
|
||||||
|
|
||||||
|
|
||||||
def _sse_event(event: str, data: dict) -> str:
|
def _sse_event(event: str, data: dict) -> str:
|
||||||
"""Format a Server-Sent Event string."""
|
"""Format SSE event string."""
|
||||||
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
def get_llm_client(conversation: Conversation = None):
|
def get_llm_client(conversation: Conversation = None) -> Tuple[LLMClient, Optional[int]]:
|
||||||
"""Get LLM client, optionally using conversation's provider. Returns (client, max_tokens)"""
|
"""Get LLM client from conversation's provider."""
|
||||||
max_tokens = None
|
|
||||||
if conversation and conversation.provider_id:
|
if conversation and conversation.provider_id:
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
provider = db.query(LLMProvider).filter(LLMProvider.id == conversation.provider_id).first()
|
provider = db.query(LLMProvider).filter(LLMProvider.id == conversation.provider_id).first()
|
||||||
if provider:
|
if provider:
|
||||||
max_tokens = provider.max_tokens
|
return LLMClient(
|
||||||
client = LLMClient(
|
|
||||||
api_key=provider.api_key,
|
api_key=provider.api_key,
|
||||||
api_url=provider.base_url,
|
api_url=provider.base_url,
|
||||||
model=provider.default_model
|
model=provider.default_model
|
||||||
)
|
), provider.max_tokens
|
||||||
return client, max_tokens
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
return LLMClient(), None
|
||||||
|
|
||||||
# Fallback to global config
|
|
||||||
client = LLMClient()
|
class StreamState:
|
||||||
return client, max_tokens
|
"""Holds streaming state across iterations."""
|
||||||
|
|
||||||
|
def __init__(self):
    """Create an empty state object for one streaming run.

    The first group of attributes accumulates over the whole run; the
    second group describes only the LLM round currently being streamed.
    """
    # --- accumulated over the whole run ---
    self.step_index: int = 0
    self.messages: List[Dict] = []
    self.all_steps: List[Dict] = []
    self.all_tool_calls: List[Dict] = []
    self.all_tool_results: List[Dict] = []
    self.total_usage: Dict[str, int] = dict.fromkeys(
        ("prompt_tokens", "completion_tokens", "total_tokens"), 0
    )

    # --- current iteration only ---
    self.full_thinking: str = ""
    self.full_content: str = ""
    self.tool_calls_list: List[Dict] = []
    self.thinking_step_id: Optional[str] = None
    self.text_step_id: Optional[str] = None
    self.thinking_step_idx: Optional[int] = None
    self.text_step_idx: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class ChatService:
|
class ChatService:
|
||||||
"""Chat service with tool support"""
|
"""Chat service with step-by-step flow architecture."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.tool_executor = ToolExecutor()
|
self.tool_executor = ToolExecutor()
|
||||||
|
|
||||||
def build_messages(
|
# ==================== Step 1: Initialize ====================
|
||||||
self,
|
|
||||||
conversation: Conversation,
|
|
||||||
include_system: bool = True
|
|
||||||
) -> List[Dict[str, str]]:
|
|
||||||
"""Build message list"""
|
|
||||||
|
|
||||||
|
def build_messages(self, conversation: Conversation, user_message: str) -> List[Dict[str, str]]:
|
||||||
|
"""Build message list including user message."""
|
||||||
messages = []
|
messages = []
|
||||||
|
if conversation.system_prompt:
|
||||||
if include_system and conversation.system_prompt:
|
messages.append({"role": "system", "content": conversation.system_prompt})
|
||||||
messages.append({
|
|
||||||
"role": "system",
|
|
||||||
"content": conversation.system_prompt
|
|
||||||
})
|
|
||||||
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
db_messages = db.query(Message).filter(
|
for msg in db.query(Message).filter(
|
||||||
Message.conversation_id == conversation.id
|
Message.conversation_id == conversation.id
|
||||||
).order_by(Message.created_at).all()
|
).order_by(Message.created_at).all():
|
||||||
|
|
||||||
for msg in db_messages:
|
|
||||||
# Parse JSON content if possible
|
|
||||||
try:
|
try:
|
||||||
content_obj = json.loads(msg.content) if msg.content else {}
|
obj = json.loads(msg.content) if msg.content else {}
|
||||||
if isinstance(content_obj, dict):
|
content = obj.get("text", msg.content) if isinstance(obj, dict) else msg.content
|
||||||
content = content_obj.get("text", msg.content)
|
|
||||||
else:
|
|
||||||
content = msg.content
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
content = msg.content
|
content = msg.content
|
||||||
|
messages.append({"role": msg.role, "content": content})
|
||||||
messages.append({
|
|
||||||
"role": msg.role,
|
|
||||||
"content": content
|
|
||||||
})
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
messages.append({"role": "user", "content": json.dumps({"text": user_message, "attachments": []})})
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
def init_stream_state(
    self, conversation: Conversation, user_message: str, enabled_tools: list
) -> Tuple[StreamState, LLMClient, List[Dict], str, Optional[int], Dict[str, Any]]:
    """Build everything a streaming run needs before the first LLM call.

    Args:
        conversation: Conversation whose history, provider, and model are used.
        user_message: Text of the new user turn; appended to the history by
            ``build_messages``.
        enabled_tools: Names of the tools the model may call. Empty/None
            disables tools entirely.

    Returns:
        A 6-tuple ``(state, llm, tools, model, max_tokens, tool_context)``:
        a fresh ``StreamState`` with ``messages`` populated, the LLM client,
        the filtered tool definitions, the model name, the provider's token
        limit (may be ``None``), and a default tool-execution context.
    """
    state = StreamState()
    state.messages = self.build_messages(conversation, user_message)

    # Keep only the registered tools whose function name was enabled.
    if enabled_tools:
        tools = [
            tool for tool in registry.list_all()
            if tool.get("function", {}).get("name") in enabled_tools
        ]
    else:
        tools = []

    llm, max_tokens = get_llm_client(conversation)
    # Conversation-level model wins, then the client's default, then "gpt-4".
    model = conversation.model or llm.default_model or "gpt-4"

    # Default context; the identity fields start empty here and are presumably
    # filled in by the caller before tools run.
    tool_context = {"workspace": None, "user_id": None, "username": None, "user_permission_level": 1}

    return state, llm, tools, model, max_tokens, tool_context
|
||||||
|
|
||||||
|
# ==================== Step 2: Stream LLM ====================
|
||||||
|
|
||||||
|
def parse_sse_line(self, sse_line: str) -> Tuple[Optional[str], Optional[str]]:
    """Split one raw SSE frame into ``(event_type, data_str)``.

    Field lines are accepted both with and without the optional space after
    the colon (``data: x`` and ``data:x``), and with LF, CRLF, or CR line
    ends. When a frame carries several ``data`` lines they are joined with
    a newline. Lines for any other field, and comment lines starting with
    ``:``, are ignored.

    Args:
        sse_line: Raw frame text, e.g. ``"event: x\\ndata: {...}\\n\\n"``.

    Returns:
        ``(event_type, data_str)``; either element is ``None`` when the
        frame does not contain that field. Empty or ``None`` input yields
        ``(None, None)``.
    """
    if not sse_line:
        return None, None

    event_type: Optional[str] = None
    data_parts: List[str] = []

    # Normalise line ends by hand: str.splitlines() would also split on
    # characters such as U+2028 that may legally appear inside a JSON payload.
    normalised = sse_line.strip().replace("\r\n", "\n").replace("\r", "\n")
    for line in normalised.split("\n"):
        field, colon, value = line.partition(":")
        if not colon:
            continue
        if field == "event":
            event_type = value.strip()
        elif field == "data":
            data_parts.append(value.strip())

    data_str = "\n".join(data_parts) if data_parts else None
    return event_type, data_str
|
||||||
|
|
||||||
|
async def stream_llm_response(
    self, llm, model: str, messages: List[Dict], tools: list,
    temperature: float, max_tokens: Optional[int], thinking_enabled: bool
) -> AsyncGenerator[Tuple[str, Any], None]:
    """Stream one LLM call and yield ``(sse_line, parsed_chunk)`` pairs.

    This is an async generator: the caller consumes it with ``async for``
    and ``llm.stream_call`` is itself iterated with ``async for``.

    Args:
        llm: Client object exposing ``stream_call(...)``.
        model: Model name passed through to the client.
        messages: Chat history sent to the model.
        tools: Tool definitions offered to the model.
        temperature: Sampling temperature passed through to the client.
        max_tokens: Completion limit; falls back to 8192 when falsy.
        thinking_enabled: Passed through to the client unchanged.

    Yields:
        ``(sse_line, chunk)`` where ``sse_line`` is the raw frame and
        ``chunk`` is its JSON-decoded ``data`` payload, or ``None`` when the
        frame has no data or the data is not valid JSON.
    """
    stream = llm.stream_call(
        model=model, messages=messages, tools=tools,
        temperature=temperature, max_tokens=max_tokens or 8192,
        thinking_enabled=thinking_enabled
    )
    async for sse_line in stream:
        _, data_str = self.parse_sse_line(sse_line)
        chunk = None
        if data_str:
            try:
                chunk = json.loads(data_str)
            except json.JSONDecodeError:
                # Non-JSON payloads are passed on as "no chunk"; the raw
                # frame is still yielded so the caller can inspect it.
                chunk = None
        yield sse_line, chunk
|
||||||
|
|
||||||
|
def process_delta(self, state: "StreamState", delta: dict) -> List[str]:
    """Fold one streamed delta into ``state`` and return the SSE events to emit.

    Reasoning and content fragments are appended to the running buffers; the
    first fragment of each kind in an iteration claims the next step index.
    Each returned event carries the full accumulated text of its step, not
    just the new fragment. Tool-call fragments are merged into
    ``state.tool_calls_list`` by their ``index`` and produce no events here.

    Args:
        state: Streaming state mutated in place.
        delta: One delta dict; the keys read are ``reasoning_content``,
            ``content`` and ``tool_calls``. Any of them may be missing or null.

    Returns:
        Zero, one, or two formatted ``process_step`` SSE strings.
    """
    events: List[str] = []

    # Handle thinking/reasoning
    reasoning = delta.get("reasoning_content") or ""
    if reasoning:
        if not state.full_thinking:
            state.thinking_step_idx = state.step_index
            state.thinking_step_id = f"step-{state.step_index}"
            state.step_index += 1
        state.full_thinking += reasoning
        events.append(_sse_event("process_step", {
            "step": {"id": state.thinking_step_id, "index": state.thinking_step_idx, "type": "thinking", "content": state.full_thinking}
        }))

    # Handle content
    content = delta.get("content") or ""
    if content:
        if not state.full_content:
            state.text_step_idx = state.step_index
            state.text_step_id = f"step-{state.step_index}"
            state.step_index += 1
        state.full_content += content
        events.append(_sse_event("process_step", {
            "step": {"id": state.text_step_id, "index": state.text_step_idx, "type": "text", "content": state.full_content}
        }))

    # Handle tool calls. "tool_calls" may be present but null, so use `or []`.
    for tc in delta.get("tool_calls") or []:
        idx = tc.get("index") or 0
        # Pad up to idx so a fragment whose index skips ahead cannot raise
        # IndexError on the lookup below.
        while len(state.tool_calls_list) <= idx:
            state.tool_calls_list.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}})
        call = state.tool_calls_list[idx]
        # The id may arrive on any fragment; keep the first non-empty one.
        if tc.get("id") and not call["id"]:
            call["id"] = tc["id"]
        func = tc.get("function") or {}
        if func.get("name"):
            call["function"]["name"] += func["name"]
        if func.get("arguments"):
            call["function"]["arguments"] += func["arguments"]

    return events
|
||||||
|
|
||||||
|
def save_steps(self, state: "StreamState"):
    """Append the current iteration's thinking and text steps to ``state.all_steps``.

    A step is recorded only when its id has been set; the thinking step is
    always appended before the text step.
    """
    candidates = (
        ("thinking", state.thinking_step_id, state.thinking_step_idx, state.full_thinking),
        ("text", state.text_step_id, state.text_step_idx, state.full_content),
    )
    for step_type, step_id, step_idx, body in candidates:
        if not step_id:
            continue
        state.all_steps.append({"id": step_id, "index": step_idx, "type": step_type, "content": body})
|
||||||
|
|
||||||
|
# ==================== Step 3: Execute Tools ====================
|
||||||
|
|
||||||
|
def execute_tools(self, state: "StreamState", tool_context: Dict) -> Tuple[List[Dict], List[str]]:
    """Run the tool calls collected in this iteration and describe them as steps.

    For every pending tool call a ``tool_call`` step is recorded, then all
    calls are handed to ``self.tool_executor.process_tool_calls_parallel``,
    and a ``tool_result`` step is recorded for every result. Results are
    paired with calls by position.

    Args:
        state: Streaming state; ``all_tool_calls``, ``all_steps``,
            ``all_tool_results`` and ``step_index`` are updated in place.
        tool_context: Context dict forwarded unchanged to the tool executor.

    Returns:
        ``(results, events)``: the executor's raw results and the formatted
        ``process_step`` SSE strings in emission order. Both are empty when
        there are no pending tool calls.
    """
    if not state.tool_calls_list:
        return [], []

    state.all_tool_calls.extend(state.tool_calls_list)
    tool_call_ids: List[str] = []
    events: List[str] = []

    # Emit one tool_call step per pending call.
    for tc in state.tool_calls_list:
        step_id = f"step-{state.step_index}"
        tool_call_ids.append(step_id)
        state.step_index += 1
        step = {
            "id": step_id, "index": len(state.all_steps), "type": "tool_call",
            "id_ref": tc.get("id", ""), "name": tc["function"]["name"], "arguments": tc["function"]["arguments"]
        }
        state.all_steps.append(step)
        events.append(_sse_event("process_step", {"step": step}))

    # Execute tools
    results = self.tool_executor.process_tool_calls_parallel(state.tool_calls_list, tool_context)

    # Emit one tool_result step per result.
    for i, tr in enumerate(results):
        ref_id = tool_call_ids[i] if i < len(tool_call_ids) else f"tool-{i}"
        step_id = f"step-{state.step_index}"
        state.step_index += 1

        content = tr.get("content", "")
        # A result counts as successful unless it is a JSON object that says
        # otherwise. Only parsing failures are tolerated here; anything else
        # (KeyboardInterrupt, SystemExit, ...) must propagate.
        success = True
        try:
            parsed = json.loads(content)
        except (ValueError, TypeError):
            parsed = None
        if isinstance(parsed, dict):
            success = parsed.get("success", True)

        step = {
            "id": step_id, "index": len(state.all_steps), "type": "tool_result",
            "id_ref": ref_id, "name": tr.get("name", ""), "content": content, "success": success
        }
        state.all_steps.append(step)
        events.append(_sse_event("process_step", {"step": step}))

        state.all_tool_results.append({"role": "tool", "tool_call_id": tr.get("tool_call_id", ""), "content": content})

    return results, events
|
||||||
|
|
||||||
|
def update_messages_for_next_iteration(self, state: "StreamState", results: List[Dict]):
    """Extend ``state.messages`` so the next LLM round sees this round's outcome.

    The assistant turn (text plus its tool calls) is always appended. When
    ``results`` is non-empty, the last ``len(results)`` buffered tool
    messages follow it and the buffer is cleared.
    """
    assistant_turn = {
        "role": "assistant",
        "content": state.full_content if state.full_content else "",
        "tool_calls": state.tool_calls_list,
    }
    state.messages.append(assistant_turn)

    if not results:
        return

    newest_tool_messages = state.all_tool_results[-len(results):]
    state.messages.extend(newest_tool_messages)
    state.all_tool_results = []
|
||||||
|
|
||||||
|
def reset_iteration_state(self, state: "StreamState"):
    """Clear the per-iteration fields of ``state`` before a new LLM round.

    Text buffers become empty strings, the pending tool-call list becomes a
    new empty list, and the four step id/index markers become ``None``.
    Run-wide accumulators are left untouched.
    """
    state.full_content = ""
    state.full_thinking = ""
    state.tool_calls_list = []
    for marker in ("thinking_step_id", "thinking_step_idx", "text_step_id", "text_step_idx"):
        setattr(state, marker, None)
|
||||||
|
|
||||||
|
# ==================== Step 4: Finalize ====================
|
||||||
|
|
||||||
|
def save_message(self, conversation_id: str, state: "StreamState", token_count: int,
                 message_id: Optional[str] = None) -> str:
    """Persist the assistant message for this run and return its id.

    The stored content is a JSON object with the final text and the recorded
    steps, plus ``tool_calls`` when any were made.

    Args:
        conversation_id: Conversation the message belongs to.
        state: Streaming state providing ``full_content``, ``all_steps``,
            ``all_tool_calls`` and ``total_usage``.
        token_count: Value stored in the message's ``token_count`` column.
        message_id: Id to store the row under; a new UUID4 string is
            generated when omitted.

    Returns:
        The id the message was stored under, so a caller can report the same
        id it persisted.

    Raises:
        Exception: Any error raised while adding or committing is re-raised
            after the session has been rolled back.
    """
    message_id = message_id or str(uuid.uuid4())

    content_json = {"text": state.full_content, "steps": state.all_steps}
    if state.all_tool_calls:
        content_json["tool_calls"] = state.all_tool_calls

    db = SessionLocal()
    try:
        db.add(Message(
            id=message_id,
            conversation_id=conversation_id,
            role="assistant",
            content=json.dumps(content_json, ensure_ascii=False),
            token_count=token_count,
            usage=json.dumps(state.total_usage) if state.total_usage else None
        ))
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()

    return message_id
|
||||||
|
|
||||||
|
# ==================== Main Orchestrator ====================
|
||||||
|
|
||||||
async def stream_response(
|
async def stream_response(
|
||||||
self,
|
self,
|
||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
|
|
@ -100,360 +295,100 @@ class ChatService:
|
||||||
username: str = None,
|
username: str = None,
|
||||||
workspace: str = None,
|
workspace: str = None,
|
||||||
user_permission_level: int = 1
|
user_permission_level: int = 1
|
||||||
) -> AsyncGenerator[Dict[str, str], None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""Main streaming orchestrator - step-by-step flow."""
|
||||||
Streaming response generator
|
|
||||||
|
|
||||||
Yields raw SSE event strings for direct forwarding.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
messages = self.build_messages(conversation)
|
# Step 1: Initialize
|
||||||
|
state, llm, tools, model, max_tokens, tool_context = self.init_stream_state(
|
||||||
messages.append({
|
conversation, user_message, enabled_tools or []
|
||||||
"role": "user",
|
)
|
||||||
"content": json.dumps({"text": user_message, "attachments": []})
|
tool_context.update
|
||||||
|
({
|
||||||
|
"user_id": user_id,
|
||||||
|
"username": username,
|
||||||
|
"workspace": workspace,
|
||||||
|
"user_permission_level": user_permission_level
|
||||||
})
|
})
|
||||||
|
|
||||||
# Get tools based on enabled_tools filter
|
# ReAct loop
|
||||||
if enabled_tools:
|
|
||||||
tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools]
|
|
||||||
else:
|
|
||||||
tools = []
|
|
||||||
|
|
||||||
llm, provider_max_tokens = get_llm_client(conversation)
|
|
||||||
model = conversation.model or llm.default_model or "gpt-4"
|
|
||||||
# 直接使用 provider 的 max_tokens
|
|
||||||
max_tokens = provider_max_tokens
|
|
||||||
|
|
||||||
# State tracking
|
|
||||||
all_steps = []
|
|
||||||
all_tool_calls = []
|
|
||||||
all_tool_results = []
|
|
||||||
step_index = 0
|
|
||||||
|
|
||||||
# Token usage tracking
|
|
||||||
total_usage = {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
}
|
|
||||||
|
|
||||||
# Global step IDs for thinking and text (persist across iterations)
|
|
||||||
thinking_step_id = None
|
|
||||||
thinking_step_idx = None
|
|
||||||
text_step_id = None
|
|
||||||
text_step_idx = None
|
|
||||||
|
|
||||||
for _ in range(MAX_ITERATIONS):
|
for _ in range(MAX_ITERATIONS):
|
||||||
# Stream from LLM
|
self.reset_iteration_state(state)
|
||||||
full_content = ""
|
|
||||||
full_thinking = ""
|
|
||||||
tool_calls_list = []
|
|
||||||
|
|
||||||
# Step tracking - use unified step-{index} format
|
# Step 2: Stream LLM
|
||||||
thinking_step_id = None
|
async for sse_line, chunk in self.stream_llm_response(
|
||||||
thinking_step_idx = None
|
llm, model, state.messages, tools,
|
||||||
text_step_id = None
|
conversation.temperature, max_tokens,
|
||||||
text_step_idx = None
|
thinking_enabled or conversation.thinking_enabled
|
||||||
|
|
||||||
async for sse_line in llm.stream_call(
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
tools=tools,
|
|
||||||
temperature=conversation.temperature,
|
|
||||||
max_tokens=max_tokens or 8192,
|
|
||||||
thinking_enabled=thinking_enabled or conversation.thinking_enabled
|
|
||||||
):
|
):
|
||||||
# Parse SSE line
|
# Handle error events
|
||||||
# Format: "event: xxx\ndata: {...}\n\n"
|
event_type, data_str = self.parse_sse_line(sse_line)
|
||||||
event_type = None
|
if event_type == 'error':
|
||||||
data_str = None
|
error_data = json.loads(data_str) if data_str else {}
|
||||||
|
yield _sse_event("error", {"content": error_data.get("content", "Unknown error")})
|
||||||
|
return
|
||||||
|
|
||||||
for line in sse_line.strip().split('\n'):
|
if not chunk:
|
||||||
if line.startswith('event: '):
|
|
||||||
event_type = line[7:].strip()
|
|
||||||
elif line.startswith('data: '):
|
|
||||||
data_str = line[6:].strip()
|
|
||||||
|
|
||||||
if data_str is None:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Handle error events from LLM
|
# Extract usage
|
||||||
if event_type == 'error':
|
|
||||||
try:
|
|
||||||
error_data = json.loads(data_str)
|
|
||||||
yield _sse_event("error", {"content": error_data.get("content", "Unknown error")})
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
yield _sse_event("error", {"content": data_str})
|
|
||||||
return
|
|
||||||
|
|
||||||
# Parse the data
|
|
||||||
try:
|
|
||||||
chunk = json.loads(data_str)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
yield _sse_event("error", {"content": f"Failed to parse response: {data_str}"})
|
|
||||||
return
|
|
||||||
|
|
||||||
# 提取 API 返回的 usage 信息
|
|
||||||
if "usage" in chunk:
|
if "usage" in chunk:
|
||||||
usage = chunk["usage"]
|
u = chunk["usage"]
|
||||||
total_usage["prompt_tokens"] = usage.get("prompt_tokens", 0)
|
state.total_usage = {
|
||||||
total_usage["completion_tokens"] = usage.get("completion_tokens", 0)
|
"prompt_tokens": u.get("prompt_tokens", 0),
|
||||||
total_usage["total_tokens"] = usage.get("total_tokens", 0)
|
"completion_tokens": u.get("completion_tokens", 0),
|
||||||
|
"total_tokens": u.get("total_tokens", 0)
|
||||||
|
}
|
||||||
|
|
||||||
# Check for error in response
|
# Check for API errors
|
||||||
if "error" in chunk:
|
if "error" in chunk:
|
||||||
error_msg = chunk["error"].get("message", str(chunk["error"]))
|
yield _sse_event("error", {"content": f"API Error: {chunk['error'].get('message', str(chunk['error']))}"})
|
||||||
yield _sse_event("error", {"content": f"API Error: {error_msg}"})
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get delta
|
# Get delta
|
||||||
choices = chunk.get("choices", [])
|
|
||||||
delta = None
|
delta = None
|
||||||
|
choices = chunk.get("choices", [])
|
||||||
if choices:
|
if choices:
|
||||||
delta = choices[0].get("delta", {})
|
delta = choices[0].get("delta", {})
|
||||||
# If no delta but has message (non-streaming response)
|
|
||||||
if not delta:
|
if not delta:
|
||||||
message = choices[0].get("message", {})
|
content = choices[0].get("message", {}).get("content")
|
||||||
if message.get("content"):
|
if content:
|
||||||
delta = {"content": message["content"]}
|
delta = {"content": content}
|
||||||
|
|
||||||
if not delta:
|
if not delta:
|
||||||
# Check if there's any content in the response (for non-standard LLM responses)
|
|
||||||
content = chunk.get("content") or chunk.get("message", {}).get("content", "")
|
content = chunk.get("content") or chunk.get("message", {}).get("content", "")
|
||||||
if content:
|
if content:
|
||||||
delta = {"content": content}
|
delta = {"content": content}
|
||||||
|
|
||||||
if delta:
|
if delta:
|
||||||
# Handle reasoning (thinking)
|
# Step 2b: Process delta
|
||||||
reasoning = delta.get("reasoning_content", "")
|
for event in self.process_delta(state, delta):
|
||||||
if reasoning:
|
yield event
|
||||||
prev_thinking_len = len(full_thinking)
|
|
||||||
full_thinking += reasoning
|
|
||||||
if prev_thinking_len == 0: # New thinking stream started
|
|
||||||
thinking_step_idx = step_index
|
|
||||||
thinking_step_id = f"step-{step_index}"
|
|
||||||
step_index += 1
|
|
||||||
yield _sse_event("process_step", {
|
|
||||||
"step": {
|
|
||||||
"id": thinking_step_id,
|
|
||||||
"index": thinking_step_idx,
|
|
||||||
"type": "thinking",
|
|
||||||
"content": full_thinking
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
# Handle content
|
# Save steps after streaming
|
||||||
content = delta.get("content", "")
|
self.save_steps(state)
|
||||||
if content:
|
|
||||||
prev_content_len = len(full_content)
|
|
||||||
full_content += content
|
|
||||||
if prev_content_len == 0: # New text stream started
|
|
||||||
text_step_idx = step_index
|
|
||||||
text_step_id = f"step-{step_index}"
|
|
||||||
step_index += 1
|
|
||||||
yield _sse_event("process_step", {
|
|
||||||
"step": {
|
|
||||||
"id": text_step_id,
|
|
||||||
"index": text_step_idx,
|
|
||||||
"type": "text",
|
|
||||||
"content": full_content
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
# Accumulate tool calls
|
# Step 3: Execute tools if present
|
||||||
tool_calls_delta = delta.get("tool_calls", [])
|
if state.tool_calls_list:
|
||||||
for tc in tool_calls_delta:
|
results, events = self.execute_tools(state, tool_context)
|
||||||
idx = tc.get("index", 0)
|
for event in events:
|
||||||
if idx >= len(tool_calls_list):
|
yield event
|
||||||
tool_calls_list.append({
|
self.update_messages_for_next_iteration(state, results)
|
||||||
"id": tc.get("id", ""),
|
|
||||||
"type": "function",
|
|
||||||
"function": {"name": "", "arguments": ""}
|
|
||||||
})
|
|
||||||
func = tc.get("function", {})
|
|
||||||
if func.get("name"):
|
|
||||||
tool_calls_list[idx]["function"]["name"] += func["name"]
|
|
||||||
if func.get("arguments"):
|
|
||||||
tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
|
|
||||||
|
|
||||||
# Save thinking step
|
|
||||||
if thinking_step_id is not None:
|
|
||||||
all_steps.append({
|
|
||||||
"id": thinking_step_id,
|
|
||||||
"index": thinking_step_idx,
|
|
||||||
"type": "thinking",
|
|
||||||
"content": full_thinking
|
|
||||||
})
|
|
||||||
|
|
||||||
# Save text step
|
|
||||||
if text_step_id is not None:
|
|
||||||
all_steps.append({
|
|
||||||
"id": text_step_id,
|
|
||||||
"index": text_step_idx,
|
|
||||||
"type": "text",
|
|
||||||
"content": full_content
|
|
||||||
})
|
|
||||||
|
|
||||||
# Handle tool calls
|
|
||||||
if tool_calls_list:
|
|
||||||
all_tool_calls.extend(tool_calls_list)
|
|
||||||
|
|
||||||
# Yield tool_call steps - use unified step-{index} format
|
|
||||||
tool_call_step_ids = [] # Track step IDs for tool calls
|
|
||||||
for tc in tool_calls_list:
|
|
||||||
call_step_idx = step_index
|
|
||||||
call_step_id = f"step-{step_index}"
|
|
||||||
tool_call_step_ids.append(call_step_id)
|
|
||||||
step_index += 1
|
|
||||||
call_step = {
|
|
||||||
"id": call_step_id,
|
|
||||||
"index": call_step_idx,
|
|
||||||
"type": "tool_call",
|
|
||||||
"id_ref": tc.get("id", ""),
|
|
||||||
"name": tc["function"]["name"],
|
|
||||||
"arguments": tc["function"]["arguments"]
|
|
||||||
}
|
|
||||||
all_steps.append(call_step)
|
|
||||||
yield _sse_event("process_step", {"step": call_step})
|
|
||||||
|
|
||||||
# Execute tools
|
|
||||||
tool_context = {
|
|
||||||
"workspace": workspace,
|
|
||||||
"user_id": user_id,
|
|
||||||
"username": username,
|
|
||||||
"user_permission_level": user_permission_level
|
|
||||||
}
|
|
||||||
tool_results = self.tool_executor.process_tool_calls_parallel(
|
|
||||||
tool_calls_list, tool_context
|
|
||||||
)
|
|
||||||
|
|
||||||
# Yield tool_result steps - use unified step-{index} format
|
|
||||||
for i, tr in enumerate(tool_results):
|
|
||||||
tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
|
|
||||||
result_step_idx = step_index
|
|
||||||
result_step_id = f"step-{step_index}"
|
|
||||||
step_index += 1
|
|
||||||
|
|
||||||
content = tr.get("content", "")
|
|
||||||
success = True
|
|
||||||
try:
|
|
||||||
content_obj = json.loads(content)
|
|
||||||
if isinstance(content_obj, dict):
|
|
||||||
success = content_obj.get("success", True)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
result_step = {
|
|
||||||
"id": result_step_id,
|
|
||||||
"index": result_step_idx,
|
|
||||||
"type": "tool_result",
|
|
||||||
"id_ref": tool_call_step_id, # Reference to the tool_call step
|
|
||||||
"name": tr.get("name", ""),
|
|
||||||
"content": content,
|
|
||||||
"success": success
|
|
||||||
}
|
|
||||||
all_steps.append(result_step)
|
|
||||||
yield _sse_event("process_step", {"step": result_step})
|
|
||||||
|
|
||||||
all_tool_results.append({
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": tr.get("tool_call_id", ""),
|
|
||||||
"content": tr.get("content", "")
|
|
||||||
})
|
|
||||||
|
|
||||||
# Add assistant message with tool calls for next iteration
|
|
||||||
messages.append({
|
|
||||||
"role": "assistant",
|
|
||||||
"content": full_content or "",
|
|
||||||
"tool_calls": tool_calls_list
|
|
||||||
})
|
|
||||||
messages.extend(all_tool_results[-len(tool_results):])
|
|
||||||
all_tool_results = []
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# No tool calls - final iteration, save message
|
# Step 4: Finalize (no tool calls)
|
||||||
msg_id = str(uuid.uuid4())
|
token_count = state.total_usage.get("completion_tokens", 0)
|
||||||
|
self.save_message(conversation.id, state, token_count)
|
||||||
actual_token_count = total_usage.get("completion_tokens", 0)
|
yield _sse_event("done", {"message_id": str(uuid.uuid4()), "token_count": token_count, "usage": state.total_usage})
|
||||||
logger.info(f"total_usage: {total_usage}")
|
|
||||||
|
|
||||||
self._save_message(
|
|
||||||
conversation.id,
|
|
||||||
msg_id,
|
|
||||||
full_content,
|
|
||||||
all_tool_calls,
|
|
||||||
all_tool_results,
|
|
||||||
all_steps,
|
|
||||||
actual_token_count,
|
|
||||||
total_usage
|
|
||||||
)
|
|
||||||
|
|
||||||
yield _sse_event("done", {
|
|
||||||
"message_id": msg_id,
|
|
||||||
"token_count": actual_token_count,
|
|
||||||
"usage": total_usage
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Max iterations exceeded - save message before error
|
# Max iterations exceeded
|
||||||
if full_content or all_tool_calls:
|
if state.full_content or state.all_tool_calls:
|
||||||
msg_id = str(uuid.uuid4())
|
self.save_message(conversation.id, state, state.total_usage.get("completion_tokens", 0))
|
||||||
self._save_message(
|
|
||||||
conversation.id,
|
|
||||||
msg_id,
|
|
||||||
full_content,
|
|
||||||
all_tool_calls,
|
|
||||||
all_tool_results,
|
|
||||||
all_steps,
|
|
||||||
actual_token_count,
|
|
||||||
total_usage
|
|
||||||
)
|
|
||||||
yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})
|
yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield _sse_event("error", {"content": str(e)})
|
yield _sse_event("error", {"content": str(e)})
|
||||||
|
|
||||||
def _save_message(
|
|
||||||
self,
|
|
||||||
conversation_id: str,
|
|
||||||
msg_id: str,
|
|
||||||
full_content: str,
|
|
||||||
all_tool_calls: list,
|
|
||||||
all_tool_results: list,
|
|
||||||
all_steps: list,
|
|
||||||
token_count: int = 0,
|
|
||||||
usage: dict = None
|
|
||||||
):
|
|
||||||
"""Save the assistant message to database."""
|
|
||||||
|
|
||||||
|
# Global service
|
||||||
content_json = {
|
|
||||||
"text": full_content,
|
|
||||||
"steps": all_steps
|
|
||||||
}
|
|
||||||
if all_tool_calls:
|
|
||||||
content_json["tool_calls"] = all_tool_calls
|
|
||||||
|
|
||||||
db = SessionLocal()
|
|
||||||
try:
|
|
||||||
msg = Message(
|
|
||||||
id=msg_id,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
role="assistant",
|
|
||||||
content=json.dumps(content_json, ensure_ascii=False),
|
|
||||||
token_count=token_count,
|
|
||||||
usage=json.dumps(usage) if usage else None
|
|
||||||
)
|
|
||||||
db.add(msg)
|
|
||||||
db.commit()
|
|
||||||
except Exception as e:
|
|
||||||
db.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
|
|
||||||
# Global chat service
|
|
||||||
chat_service = ChatService()
|
chat_service = ChatService()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue