"""Chat service module"""
|
|
import json

from typing import Any, AsyncGenerator, Dict, List, Optional

from luxx.models import Conversation, Message
from luxx.tools.executor import ToolExecutor
from luxx.tools.core import registry
from luxx.services.llm_client import LLMClient
from luxx.config import config
|
|
|
|
|
|
# Maximum iterations to prevent infinite loops
# (each iteration is one LLM call, possibly followed by tool execution).
MAX_ITERATIONS = 10
|
|
|
|
|
|
def get_llm_client(conversation: Optional[Conversation] = None) -> LLMClient:
    """Return an LLM client for *conversation*.

    If the conversation is bound to a specific provider (``provider_id``),
    build a client from that provider's stored credentials; if the provider
    row is missing, or no conversation/provider is given, fall back to a
    client driven by the global config.

    Args:
        conversation: Optional conversation whose pinned provider should be
            used. ``None`` selects the global default client.

    Returns:
        A configured ``LLMClient`` instance.
    """
    if conversation and conversation.provider_id:
        # Local imports — presumably to avoid a circular import at module
        # load time; confirm against luxx.database / luxx.models.
        from luxx.models import LLMProvider
        from luxx.database import SessionLocal

        db = SessionLocal()
        try:
            provider = db.query(LLMProvider).filter(
                LLMProvider.id == conversation.provider_id
            ).first()
            if provider:
                return LLMClient(
                    api_key=provider.api_key,
                    api_url=provider.base_url,
                    model=provider.default_model,
                )
        finally:
            # Always release the session, even if the query raises.
            db.close()

    # Fallback to global config (no conversation, no provider_id, or
    # provider row not found).
    return LLMClient()
|
|
|
|
|
|
class ChatService:
    """Chat service.

    Orchestrates a conversation with an LLM, including the tool-calling
    loop: when the model requests tool calls, they are executed via
    ToolExecutor and the results are appended to the message history for
    the next model call, bounded by MAX_ITERATIONS rounds.
    """

    def __init__(self):
        # Executes model-requested tool calls (supports parallel execution).
        self.tool_executor = ToolExecutor()

    def build_messages(
        self,
        conversation: Conversation,
        include_system: bool = True
    ) -> List[Dict[str, str]]:
        """Build the OpenAI-style message list for *conversation*.

        Produces an optional system message (when include_system is True
        and the conversation has a system_prompt) followed by the stored
        messages in created_at order.
        """
        # Local imports — presumably to avoid a circular import at module
        # load time; confirm against luxx.database / luxx.models.
        from luxx.database import SessionLocal
        from luxx.models import Message

        messages = []

        if include_system and conversation.system_prompt:
            messages.append({
                "role": "system",
                "content": conversation.system_prompt
            })

        db = SessionLocal()
        try:
            db_messages = db.query(Message).filter(
                Message.conversation_id == conversation.id
            ).order_by(Message.created_at).all()

            for msg in db_messages:
                messages.append({
                    "role": msg.role,
                    "content": msg.content
                })
        finally:
            # Always release the session, even if the query raises.
            db.close()

        return messages

    async def stream_response(
        self,
        conversation: Conversation,
        user_message: str,
        tools_enabled: bool = True
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Streaming response generator.

        Yielded event types (dicts keyed by "type"):
        - text: incremental assistant text content
        - tool_call: tool call(s) requested by the model
        - tool_result: results of executed tool calls
        - done: final response complete
        - error: on error (exception message in "error")
        """
        try:
            messages = self.build_messages(conversation)

            # Append the new user turn to the rebuilt history.
            messages.append({
                "role": "user",
                "content": user_message
            })

            # Only advertise tools to the model when enabled.
            tools = registry.list_all() if tools_enabled else None

            iteration = 0

            llm = get_llm_client(conversation)
            # Conversation-specific model wins, then the client default,
            # then a hard-coded fallback.
            model = conversation.model or llm.default_model or "gpt-4"

            # LLM <-> tool loop: each pass is one streaming LLM call;
            # bounded by MAX_ITERATIONS to prevent an infinite loop.
            while iteration < MAX_ITERATIONS:
                iteration += 1
                print(f"[CHAT] Starting iteration {iteration}, messages: {len(messages)}")

                tool_calls_this_round = None

                async for event in llm.stream_call(
                    model=model,
                    messages=messages,
                    tools=tools,
                    temperature=conversation.temperature,
                    max_tokens=conversation.max_tokens
                ):
                    event_type = event.get("type")

                    if event_type == "content_delta":
                        # Forward incremental text to the caller.
                        content = event.get("content", "")
                        if content:
                            yield {"type": "text", "content": content}

                    elif event_type == "tool_call_delta":
                        # Forward partial tool-call info as it streams in.
                        tool_call = event.get("tool_call", {})
                        yield {"type": "tool_call", "data": tool_call}

                    elif event_type == "done":
                        # Stream finished; the model may have requested tools.
                        tool_calls_this_round = event.get("tool_calls")
                        print(f"[CHAT] Done event, tool_calls: {tool_calls_this_round}")

                        if tool_calls_this_round and tools_enabled:
                            print(f"[CHAT] Executing tools")
                            yield {"type": "tool_call", "data": tool_calls_this_round}

                            # NOTE(review): the second argument {} looks like a
                            # shared context dict, but non_stream_response calls
                            # this method with one argument — confirm against
                            # ToolExecutor.process_tool_calls_parallel's signature.
                            tool_results = self.tool_executor.process_tool_calls_parallel(
                                tool_calls_this_round,
                                {}
                            )

                            # Record the assistant turn that requested the tools
                            # (content left empty)...
                            messages.append({
                                "role": "assistant",
                                "content": "",
                                "tool_calls": tool_calls_this_round
                            })

                            # ...and one "tool" message per result so the next
                            # LLM call can see the outcomes.
                            for tr in tool_results:
                                messages.append({
                                    "role": "tool",
                                    "tool_call_id": tr.get("tool_call_id"),
                                    "content": str(tr.get("result", ""))
                                })

                            yield {"type": "tool_result", "data": tool_results}
                        else:
                            # No tools requested (or disabled): stop consuming
                            # this stream.
                            print(f"[CHAT] Breaking: tool_calls={tool_calls_this_round}, tools_enabled={tools_enabled}")
                            break

                # Exit the outer loop when the round produced no tool calls —
                # the response is complete.
                if not tool_calls_this_round or not tools_enabled:
                    print(f"[CHAT] Breaking at outer")
                    break

            yield {"type": "done"}

        except Exception as e:
            # Surface any failure as a terminal error event instead of raising
            # through the generator.
            print(f"[CHAT] Exception: {type(e).__name__}: {str(e)}")
            yield {"type": "error", "error": str(e)}

    def non_stream_response(
        self,
        conversation: Conversation,
        user_message: str,
        tools_enabled: bool = False
    ) -> Dict[str, Any]:
        """Non-streaming response.

        Returns {"success": True, "content": ...} on completion, or
        {"success": False, "error": ...} if any step raises. Note that
        exhausting MAX_ITERATIONS still reports success with a placeholder
        message.
        """
        try:
            messages = self.build_messages(conversation)
            messages.append({
                "role": "user",
                "content": user_message
            })

            # Only advertise tools to the model when enabled.
            tools = registry.list_all() if tools_enabled else None

            iteration = 0

            llm_client = get_llm_client(conversation)
            model = conversation.model or llm_client.default_model or "gpt-4"

            # Same LLM <-> tool loop as stream_response, but synchronous.
            while iteration < MAX_ITERATIONS:
                iteration += 1

                response = llm_client.sync_call(
                    model=model,
                    messages=messages,
                    tools=tools,
                    temperature=conversation.temperature,
                    max_tokens=conversation.max_tokens
                )

                tool_calls = response.tool_calls

                if tool_calls and tools_enabled:
                    # Record the assistant turn that requested the tools.
                    messages.append({
                        "role": "assistant",
                        "content": response.content,
                        "tool_calls": tool_calls
                    })

                    # NOTE(review): called with one argument here but with an
                    # extra {} in stream_response — confirm which is correct.
                    tool_results = self.tool_executor.process_tool_calls_parallel(tool_calls)

                    for tr in tool_results:
                        messages.append({
                            "role": "tool",
                            "tool_call_id": tr.get("tool_call_id"),
                            "content": str(tr.get("result", ""))
                        })
                else:
                    # No tool calls: the model's answer is final.
                    return {
                        "success": True,
                        "content": response.content
                    }

            # Loop exhausted without a tool-free answer.
            return {
                "success": True,
                "content": "Max iterations reached"
            }

        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }
|
|
|
|
|
|
# Global chat service
# Module-level singleton shared by callers importing this module.
chat_service = ChatService()
|