241 lines
9.5 KiB
Python
241 lines
9.5 KiB
Python
"""LLM API client"""
|
|
import json
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, List, Optional

import httpx

from luxx.config import config
|
|
|
|
|
|
@dataclass
class LLMResponse:
    """Parsed response from an LLM chat-completion call.

    Previously the fields were declared twice (class-level annotations plus a
    hand-written ``__init__`` mirroring them); ``@dataclass`` generates the
    same ``__init__(content="", tool_calls=None, usage=None)`` signature and
    adds ``__repr__``/``__eq__`` for free.

    Attributes:
        content: Assistant message text ("" when the API sent none).
        tool_calls: Raw tool-call dicts exactly as returned by the API, or None.
        usage: Token-usage statistics dict from the API, or None.
    """

    content: str = ""
    tool_calls: Optional[List[Dict]] = None
    usage: Optional[Dict] = None
|
|
|
|
|
|
class LLMClient:
    """Async client for OpenAI-compatible chat-completion APIs.

    Supports providers sharing the OpenAI wire format (OpenAI, DeepSeek,
    GLM/Zhipu). HTTP connections are pooled in a single ``httpx.AsyncClient``
    created lazily by :meth:`client` and released by :meth:`close`.
    """

    def __init__(self, api_key: str = None, api_url: str = None, model: str = None):
        """Initialize the client.

        Args:
            api_key: API key; falls back to ``config.llm_api_key``.
            api_url: Chat-completions endpoint URL; falls back to
                ``config.llm_api_url``.
            model: Default model name (may be None).
        """
        self.api_key = api_key or config.llm_api_key
        self.api_url = api_url or config.llm_api_url
        self.default_model = model
        self.provider = self._detect_provider()
        # Pooled HTTP client; created on first use by client(), reused across
        # calls, and released by close().
        self._client: Optional["httpx.AsyncClient"] = None

    def _detect_provider(self) -> str:
        """Detect the provider from the endpoint URL; defaults to "openai"."""
        # Guard: api_url may be None when neither the argument nor the config
        # value is set (previously this raised AttributeError).
        url = (self.api_url or "").lower()
        if "deepseek" in url:
            return "deepseek"
        if "glm" in url or "zhipu" in url:
            return "glm"
        # OpenAI itself and any unrecognized OpenAI-compatible endpoint.
        return "openai"

    async def close(self):
        """Close the pooled HTTP client, if one was created."""
        if self._client:
            await self._client.aclose()
            self._client = None

    def _build_headers(self) -> Dict[str, str]:
        """Build request headers: JSON content type plus bearer auth."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

    def _build_body(
        self,
        model: str,
        messages: List[Dict],
        tools: Optional[List[Dict]] = None,
        stream: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """Build an OpenAI-style chat-completion request body.

        Only recognized optional keys ("temperature", "max_tokens") are
        forwarded from kwargs; anything else is silently ignored.
        """
        body = {
            "model": model,
            "messages": messages,
            "stream": stream
        }
        if "temperature" in kwargs:
            body["temperature"] = kwargs["temperature"]
        if "max_tokens" in kwargs:
            body["max_tokens"] = kwargs["max_tokens"]
        if tools:
            body["tools"] = tools
        return body

    def _parse_response(self, data: Dict) -> "LLMResponse":
        """Parse a non-streaming chat-completion response into LLMResponse."""
        content = ""
        tool_calls = None
        usage = None

        if "choices" in data:
            message = data["choices"][0].get("message", {})
            content = message.get("content", "")
            tool_calls = message.get("tool_calls")

        if "usage" in data:
            usage = data["usage"]

        return LLMResponse(
            content=content,
            tool_calls=tool_calls,
            usage=usage
        )

    async def client(self) -> "httpx.AsyncClient":
        """Return the shared HTTP client, creating it on first use."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=120.0)
        return self._client

    async def sync_call(
        self,
        model: str,
        messages: List[Dict],
        tools: Optional[List[Dict]] = None,
        **kwargs
    ) -> "LLMResponse":
        """Call the LLM API (non-streaming) and return the parsed response.

        Raises:
            httpx.HTTPStatusError: On non-2xx responses.
        """
        body = self._build_body(model, messages, tools, stream=False, **kwargs)

        # BUGFIX: reuse the pooled client instead of opening a brand-new
        # AsyncClient (and TCP/TLS handshake) for every request; close() and
        # client() existed but were never used.
        http = await self.client()
        response = await http.post(
            self.api_url,
            headers=self._build_headers(),
            json=body
        )
        response.raise_for_status()
        return self._parse_response(response.json())

    async def stream_call(
        self,
        model: str,
        messages: List[Dict],
        tools: Optional[List[Dict]] = None,
        **kwargs
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream a chat completion as a sequence of event dicts.

        Yields dicts keyed by "type":
            {"type": "content_delta", "content": str}   -- text (or reasoning) chunk
            {"type": "tool_call_delta", "tool_call": [...]}  -- raw incremental chunk
            {"type": "done", "tool_calls": [...] | None}     -- on finish_reason
            {"type": "error", "error": str}                  -- on any failure
        """
        body = self._build_body(model, messages, tools, stream=True, **kwargs)

        # Tool calls arrive incrementally (name/arguments split across delta
        # chunks, keyed by "index"), so accumulate them here.
        accumulated_tool_calls: Dict[int, Dict[str, Any]] = {}

        # TODO(review): these prints should move to the logging module; kept
        # as-is to preserve observable debug output.
        print(f"[LLM] Starting stream_call for model: {model}")
        print(f"[LLM] Messages count: {len(messages)}")

        try:
            # BUGFIX: reuse the pooled client (see sync_call) instead of a
            # throwaway AsyncClient per stream.
            http = await self.client()
            print(f"[LLM] Sending request to {self.api_url}")
            async with http.stream(
                "POST",
                self.api_url,
                headers=self._build_headers(),
                json=body
            ) as response:
                print(f"[LLM] Response status: {response.status_code}")
                response.raise_for_status()

                chunk_count = 0
                async for line in response.aiter_lines():
                    if not line.strip():
                        continue
                    # SSE frames we care about all start with "data: ".
                    if not line.startswith("data: "):
                        continue

                    data_str = line[6:]
                    if data_str == "[DONE]":
                        print(f"[LLM] Received [DONE], chunk_count: {chunk_count}")
                        # The finish_reason chunk already emitted the "done"
                        # event; nothing more to yield here.
                        continue

                    try:
                        chunk = json.loads(data_str)
                        chunk_count += 1
                    except json.JSONDecodeError:
                        print(f"[LLM] JSON decode error for: {data_str[:100]}")
                        continue

                    if "choices" not in chunk:
                        print("[LLM] No 'choices' in chunk")
                        continue

                    choice = chunk.get("choices", [{}])[0]
                    delta = choice.get("delta", {})

                    # DeepSeek reasoner: prefer 'content' over 'reasoning_content'.
                    content = delta.get("content")
                    reasoning = delta.get("reasoning_content", "")

                    # Print first few chunks for debugging.
                    if chunk_count <= 3:
                        print(f"[LLM] delta: content={repr(content)[:30]}, reasoning={repr(reasoning)[:30]}")

                    if content and isinstance(content, str) and content.strip():
                        print(f"[LLM] Yielding content: {content[:50]}...")
                        yield {"type": "content_delta", "content": content}
                    elif reasoning:
                        print(f"[LLM] Yielding reasoning: {reasoning[:50]}...")
                        yield {"type": "content_delta", "content": reasoning}

                    # Accumulate tool calls from delta chunks.
                    tool_calls_delta = delta.get("tool_calls", [])
                    for tc in tool_calls_delta:
                        idx = tc.get("index", 0)
                        entry = accumulated_tool_calls.setdefault(idx, {"index": idx})
                        # BUGFIX: preserve "id" and "type" (sent only on the
                        # first chunk of each call). Without "id" the final
                        # tool calls cannot be answered with a matching
                        # tool_call_id message.
                        if tc.get("id"):
                            entry["id"] = tc["id"]
                        if tc.get("type"):
                            entry["type"] = tc["type"]
                        if "function" in tc:
                            fn = entry.setdefault("function", {"name": "", "arguments": ""})
                            # ".get(...) or ''" tolerates explicit nulls in
                            # deltas (previously: TypeError on str += None).
                            fn["name"] += tc["function"].get("name") or ""
                            fn["arguments"] += tc["function"].get("arguments") or ""

                    if tool_calls_delta:
                        print(f"[LLM] Found tool_calls in delta: {tool_calls_delta}")
                        yield {"type": "tool_call_delta", "tool_call": tool_calls_delta}

                    # finish_reason marks the end of the stream.
                    finish_reason = choice.get("finish_reason")
                    if finish_reason:
                        print(f"[LLM] finish_reason: {finish_reason}")
                        final_tool_calls = list(accumulated_tool_calls.values()) if accumulated_tool_calls else None
                        yield {"type": "done", "tool_calls": final_tool_calls}

        except httpx.HTTPStatusError as e:
            status_code = e.response.status_code if e.response else "?"
            print(f"[LLM] HTTP error: {status_code}")
            yield {"type": "error", "error": f"HTTP {status_code}: Request failed"}
        except httpx.ResponseNotRead:
            print("[LLM] ResponseNotRead error")
            yield {"type": "error", "error": "Streaming response error"}
        except Exception as e:
            print(f"[LLM] Exception: {type(e).__name__}: {str(e)}")
            yield {"type": "error", "error": str(e)}
|
|
|
|
|
|
# Global LLM client singleton, shared by the rest of the package.
# NOTE: constructed at import time, so config.llm_api_key / config.llm_api_url
# are read when this module is first imported.
llm_client = LLMClient()
|