Luxx/alcor/services/llm_client.py

257 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LLM API client."""
import json
from typing import Dict, List, Optional, Generator, Any, Callable, AsyncGenerator
from dataclasses import dataclass
import httpx
from alcor.config import config
@dataclass
class LLMResponse:
    """Parsed result of a non-streaming LLM API call (see LLMClient._parse_response)."""
    # Assistant message text; "" when the response carried no content.
    content: str
    # OpenAI-style tool-call dicts from choices[0].message, if any.
    tool_calls: Optional[List[Dict[str, Any]]] = None
    # Token counts: prompt_tokens / completion_tokens / total_tokens.
    usage: Optional[Dict[str, int]] = None
    # Provider finish reason, e.g. "stop" or "tool_calls".
    finish_reason: Optional[str] = None
    # Full raw JSON payload, kept for debugging/inspection.
    raw: Optional[Dict] = None
class LLMClient:
    """Async client for OpenAI-compatible LLM chat-completion APIs.

    Supports several providers (DeepSeek, GLM/Zhipu, Qwen, Moonshot, OpenAI)
    that all speak the OpenAI chat-completions wire format.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        provider: Optional[str] = None
    ):
        """Create a client.

        Args:
            api_key: Bearer token; falls back to config.llm_api_key.
            api_url: Chat-completions endpoint; falls back to config.llm_api_url.
            provider: Provider name; falls back to config.llm_provider, then
                URL-based auto-detection.
        """
        self.api_key = api_key or config.llm_api_key
        self.api_url = api_url or config.llm_api_url
        # Explicit argument wins, then configured provider, then URL sniffing.
        self.provider = provider or config.llm_provider or self._detect_provider()
        # Created lazily by the `client` property so construction stays cheap.
        self._client: Optional[httpx.AsyncClient] = None

    def _detect_provider(self) -> str:
        """Infer the provider name from the API URL (default: "openai")."""
        url = (self.api_url or "").lower()  # guard: api_url may be unset/None
        if "deepseek" in url:
            return "deepseek"
        # "bigmodel" / "glm" / "zhipu" all belong to Zhipu's GLM platform.
        if "bigmodel" in url or "glm" in url or "zhipu" in url:
            return "glm"
        if "qwen" in url or "dashscope" in url:
            return "qwen"
        # BUG FIX: original tested `"moonshot" in url` twice ("moonshot" in url
        # or "moonshot" in url); the duplicate second operand was dead code.
        if "moonshot" in url:
            return "moonshot"
        return "openai"

    @property
    def client(self) -> httpx.AsyncClient:
        """Lazily create and return the shared async HTTP client."""
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(120.0, connect=30.0),
                # Consistency fix: reuse the canonical header builder instead
                # of duplicating the header dict inline.
                headers=self._build_headers()
            )
        return self._client

    async def close(self):
        """Close the underlying HTTP client, if one was created."""
        if self._client:
            await self._client.aclose()
            self._client = None

    def _build_headers(self) -> Dict[str, str]:
        """Build the auth and content-type request headers."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

    def _build_body(
        self,
        model: str,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict]] = None,
        stream: bool = True,
        **kwargs
    ) -> Dict[str, Any]:
        """Build the JSON request body, forwarding only supplied options."""
        body: Dict[str, Any] = {
            "model": model,
            "messages": messages,
            "stream": stream
        }
        # Only forward options the caller actually supplied, so provider
        # defaults are not clobbered.
        for key in ("temperature", "max_tokens", "top_p", "thinking_enabled"):
            if key in kwargs:
                body[key] = kwargs[key]
        if tools:
            body["tools"] = tools
        return body

    def _parse_response(self, data: Dict) -> LLMResponse:
        """Parse a non-streaming OpenAI-format response into an LLMResponse."""
        content = ""
        tool_calls = None
        usage = None
        finish_reason = None
        # OpenAI format: first choice carries the assistant message.
        if "choices" in data:
            choice = data["choices"][0]
            message = choice.get("message", {})
            # BUG FIX: providers send `"content": null` alongside tool calls;
            # `.get("content", "")` kept None, violating the `content: str`
            # contract of LLMResponse.
            content = message.get("content") or ""
            tool_calls = message.get("tool_calls")
            finish_reason = choice.get("finish_reason")
        if "usage" in data:
            usage = {
                "prompt_tokens": data["usage"].get("prompt_tokens", 0),
                "completion_tokens": data["usage"].get("completion_tokens", 0),
                "total_tokens": data["usage"].get("total_tokens", 0)
            }
        return LLMResponse(
            content=content,
            tool_calls=tool_calls,
            usage=usage,
            finish_reason=finish_reason,
            raw=data
        )

    async def call(
        self,
        model: str,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict]] = None,
        **kwargs
    ) -> LLMResponse:
        """Call the LLM API (non-streaming) and return the parsed response.

        Raises:
            Exception: on HTTP status errors or any other request/parse failure.
        """
        body = self._build_body(model, messages, tools, stream=False, **kwargs)
        try:
            response = await self.client.post(
                self.api_url,
                json=body,
                headers=self._build_headers()
            )
            response.raise_for_status()
            data = response.json()
            return self._parse_response(data)
        except httpx.HTTPStatusError as e:
            # Chain the cause so the original traceback is preserved.
            raise Exception(f"HTTP error: {e.response.status_code} - {e.response.text}") from e
        except Exception as e:
            raise Exception(f"LLM API error: {str(e)}") from e

    async def stream(
        self,
        model: str,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict]] = None,
        **kwargs
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream an LLM call, yielding event dicts.

        Yields dicts keyed by "type":
            "content_delta": {"content", "full_content"} incremental text;
            "done": {"finish_reason", "content", "tool_calls", "usage"};
            "error": {"error"} on failure (errors are yielded, not raised).
        """
        body = self._build_body(model, messages, tools, stream=True, **kwargs)
        try:
            async with self.client.stream(
                "POST",
                self.api_url,
                json=body,
                headers=self._build_headers()
            ) as response:
                response.raise_for_status()
                accumulated_content = ""
                accumulated_tool_calls: Dict[int, Dict] = {}
                async for line in response.aiter_lines():
                    if not line.strip():
                        continue
                    # Strip the SSE "data: " prefix.
                    if line.startswith("data: "):
                        line = line[6:]
                    if line == "[DONE]":
                        break
                    try:
                        chunk = json.loads(line)
                    except json.JSONDecodeError:
                        continue  # tolerate keep-alives / malformed lines
                    # BUG FIX: some providers emit a final usage-only chunk
                    # with "choices": [] — the old chunk.get("choices", [{}])[0]
                    # raised IndexError on it (the default only applies when
                    # the key is absent, not when the list is empty).
                    choices = chunk.get("choices") or [{}]
                    choice = choices[0]
                    delta = choice.get("delta", {})
                    content_delta = delta.get("content", "")
                    if content_delta:
                        accumulated_content += content_delta
                        yield {
                            "type": "content_delta",
                            "content": content_delta,
                            "full_content": accumulated_content
                        }
                    # Tool-call fragments arrive incrementally; merge by index.
                    # (`or []` guards against an explicit "tool_calls": null.)
                    tool_calls = delta.get("tool_calls") or []
                    for tc in tool_calls:
                        index = tc.get("index", 0)
                        if index not in accumulated_tool_calls:
                            accumulated_tool_calls[index] = {
                                "id": "",
                                "type": "function",
                                "function": {"name": "", "arguments": ""}
                            }
                        if tc.get("id"):
                            accumulated_tool_calls[index]["id"] = tc["id"]
                        if tc.get("function", {}).get("name"):
                            accumulated_tool_calls[index]["function"]["name"] = tc["function"]["name"]
                        if tc.get("function", {}).get("arguments"):
                            accumulated_tool_calls[index]["function"]["arguments"] += tc["function"]["arguments"]
                    # Final chunk: emit the accumulated state.
                    finish_reason = choice.get("finish_reason")
                    if finish_reason:
                        yield {
                            "type": "done",
                            "finish_reason": finish_reason,
                            "content": accumulated_content,
                            "tool_calls": list(accumulated_tool_calls.values()) if accumulated_tool_calls else None,
                            "usage": chunk.get("usage")
                        }
        except httpx.HTTPStatusError as e:
            yield {
                "type": "error",
                "error": f"HTTP error: {e.response.status_code}"
            }
        except Exception as e:
            yield {
                "type": "error",
                "error": str(e)
            }
# Global shared LLM client instance (configured from alcor.config defaults).
llm_client = LLMClient()