"""LLM API client."""
import json
from typing import Dict, List, Optional, Generator, Any, Callable, AsyncGenerator
from dataclasses import dataclass

import httpx

from alcor.config import config


@dataclass
class LLMResponse:
    """Parsed result of a single (non-streaming) LLM API call."""
    # Assistant message text; empty string when the response had no content
    # (e.g. a pure tool-call response).
    content: str
    # OpenAI-style tool-call dicts from the response message, if any.
    tool_calls: Optional[List[Dict[str, Any]]] = None
    # Token accounting: prompt_tokens / completion_tokens / total_tokens.
    usage: Optional[Dict[str, int]] = None
    # Why generation stopped, as reported by the API (e.g. "stop", "tool_calls").
    finish_reason: Optional[str] = None
    # The full raw response payload, kept for debugging/inspection.
    raw: Optional[Dict] = None
class LLMClient:
    """LLM API client supporting multiple OpenAI-compatible providers.

    The provider name is taken from the constructor argument, then global
    config, then inferred from keywords in the API URL. Requests use the
    OpenAI chat-completions wire format.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        provider: Optional[str] = None
    ):
        """Create a client; unset arguments fall back to global config."""
        self.api_key = api_key or config.llm_api_key
        self.api_url = api_url or config.llm_api_url
        self.provider = provider or config.llm_provider or self._detect_provider()
        # Lazily created, shared async HTTP client (see the `client` property).
        self._client: Optional[httpx.AsyncClient] = None

    def _detect_provider(self) -> str:
        """Infer the provider name from keywords in the API URL.

        Returns "openai" when no known keyword matches (or the URL is unset).
        """
        # Guard against a missing/None URL from config (original crashed here).
        url = (self.api_url or "").lower()
        if "deepseek" in url:
            return "deepseek"
        # Zhipu GLM is reachable via bigmodel.cn / zhipu domains; the original
        # spread this over two branches returning the same value.
        if "bigmodel" in url or "glm" in url or "zhipu" in url:
            return "glm"
        if "qwen" in url or "dashscope" in url:
            return "qwen"
        # Original checked `"moonshot" in url or "moonshot" in url` — the
        # duplicate is dead. NOTE(review): second keyword was possibly meant
        # to be "kimi"; confirm before adding it.
        if "moonshot" in url:
            return "moonshot"
        return "openai"

    @property
    def client(self) -> httpx.AsyncClient:
        """Lazily create and return the shared async HTTP client."""
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(120.0, connect=30.0),
                # Reuse the single source of truth for headers instead of
                # duplicating the dict literal here.
                headers=self._build_headers()
            )
        return self._client

    async def close(self):
        """Close the underlying HTTP client, if one was created."""
        if self._client:
            await self._client.aclose()
            self._client = None

    def _build_headers(self) -> Dict[str, str]:
        """Build the bearer-auth / JSON request headers."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

    def _build_body(
        self,
        model: str,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict]] = None,
        stream: bool = True,
        **kwargs
    ) -> Dict[str, Any]:
        """Build the JSON request body for a chat-completions call.

        Only the whitelisted optional parameters below are forwarded from
        ``kwargs``; anything else is silently ignored.
        """
        body: Dict[str, Any] = {
            "model": model,
            "messages": messages,
            "stream": stream
        }

        # Forward recognized optional parameters (same order as before).
        for key in ("temperature", "max_tokens", "top_p", "thinking_enabled"):
            if key in kwargs:
                body[key] = kwargs[key]

        # Attach tool (function-calling) definitions when provided.
        if tools:
            body["tools"] = tools

        return body

    def _parse_response(self, data: Dict) -> LLMResponse:
        """Parse a non-streaming API response into an LLMResponse."""
        content = ""
        tool_calls = None
        usage = None
        finish_reason = None

        # OpenAI-compatible format: the first choice carries the message.
        if "choices" in data:
            choice = data["choices"][0]
            message = choice.get("message", {})
            content = message.get("content", "")
            tool_calls = message.get("tool_calls")
            finish_reason = choice.get("finish_reason")

        # Token usage statistics, when the server reports them.
        if "usage" in data:
            usage = {
                "prompt_tokens": data["usage"].get("prompt_tokens", 0),
                "completion_tokens": data["usage"].get("completion_tokens", 0),
                "total_tokens": data["usage"].get("total_tokens", 0)
            }

        return LLMResponse(
            content=content,
            tool_calls=tool_calls,
            usage=usage,
            finish_reason=finish_reason,
            raw=data
        )

    async def call(
        self,
        model: str,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict]] = None,
        **kwargs
    ) -> LLMResponse:
        """Call the LLM API (non-streaming) and return the parsed response.

        Raises:
            Exception: on HTTP errors (with status and body) or any other
                request/parse failure; the original cause is chained.
        """
        body = self._build_body(model, messages, tools, stream=False, **kwargs)

        try:
            response = await self.client.post(
                self.api_url,
                json=body,
                headers=self._build_headers()
            )
            response.raise_for_status()
            data = response.json()
            return self._parse_response(data)
        except httpx.HTTPStatusError as e:
            # Chain the cause so the traceback keeps the httpx details.
            raise Exception(f"HTTP error: {e.response.status_code} - {e.response.text}") from e
        except Exception as e:
            raise Exception(f"LLM API error: {str(e)}") from e

    async def stream(
        self,
        model: str,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict]] = None,
        **kwargs
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream a chat-completions call as Server-Sent Events.

        Yields dicts of three shapes:
            {"type": "content_delta", "content": ..., "full_content": ...}
            {"type": "done", "finish_reason": ..., "content": ...,
             "tool_calls": ..., "usage": ...}
            {"type": "error", "error": ...}
        """
        body = self._build_body(model, messages, tools, stream=True, **kwargs)

        try:
            async with self.client.stream(
                "POST",
                self.api_url,
                json=body,
                headers=self._build_headers()
            ) as response:
                response.raise_for_status()

                accumulated_content = ""
                # Tool-call fragments accumulated by their stream index.
                accumulated_tool_calls: Dict[int, Dict] = {}

                async for line in response.aiter_lines():
                    if not line.strip():
                        continue

                    # Strip the SSE "data: " prefix.
                    if line.startswith("data: "):
                        line = line[6:]

                    if line == "[DONE]":
                        break

                    try:
                        chunk = json.loads(line)
                    except json.JSONDecodeError:
                        # Skip malformed/partial SSE payloads.
                        continue

                    # Some servers send a final usage-only chunk with
                    # "choices": [] — `or [{}]` avoids the IndexError the
                    # original hit on a present-but-empty list.
                    choices = chunk.get("choices") or [{}]
                    delta = choices[0].get("delta", {})

                    # Incremental assistant content.
                    content_delta = delta.get("content", "")
                    if content_delta:
                        accumulated_content += content_delta
                        yield {
                            "type": "content_delta",
                            "content": content_delta,
                            "full_content": accumulated_content
                        }

                    # Incremental tool-call fragments: id/name overwrite,
                    # arguments concatenate across chunks.
                    tool_calls = delta.get("tool_calls", [])
                    for tc in tool_calls:
                        index = tc.get("index", 0)
                        if index not in accumulated_tool_calls:
                            accumulated_tool_calls[index] = {
                                "id": "",
                                "type": "function",
                                "function": {"name": "", "arguments": ""}
                            }

                        if tc.get("id"):
                            accumulated_tool_calls[index]["id"] = tc["id"]
                        if tc.get("function", {}).get("name"):
                            accumulated_tool_calls[index]["function"]["name"] = tc["function"]["name"]
                        if tc.get("function", {}).get("arguments"):
                            accumulated_tool_calls[index]["function"]["arguments"] += tc["function"]["arguments"]

                    # Completion signal for this response.
                    finish_reason = choices[0].get("finish_reason")
                    if finish_reason:
                        yield {
                            "type": "done",
                            "finish_reason": finish_reason,
                            "content": accumulated_content,
                            "tool_calls": list(accumulated_tool_calls.values()) if accumulated_tool_calls else None,
                            "usage": chunk.get("usage")
                        }

        except httpx.HTTPStatusError as e:
            yield {
                "type": "error",
                "error": f"HTTP error: {e.response.status_code}"
            }
        except Exception as e:
            yield {
                "type": "error",
                "error": str(e)
            }
# Global shared LLM client instance, configured from global config defaults.
llm_client = LLMClient()