"""LLM API客户端""" import json from typing import Dict, List, Optional, Generator, Any, Callable, AsyncGenerator from dataclasses import dataclass import httpx from alcor.config import config @dataclass class LLMResponse: """LLM响应""" content: str tool_calls: Optional[List[Dict[str, Any]]] = None usage: Optional[Dict[str, int]] = None finish_reason: Optional[str] = None raw: Optional[Dict] = None class LLMClient: """LLM API客户端,支持多种提供商""" def __init__( self, api_key: Optional[str] = None, api_url: Optional[str] = None, provider: Optional[str] = None ): self.api_key = api_key or config.llm_api_key self.api_url = api_url or config.llm_api_url self.provider = provider or config.llm_provider or self._detect_provider() self._client: Optional[httpx.AsyncClient] = None def _detect_provider(self) -> str: """检测提供商""" url = self.api_url.lower() if "deepseek" in url: return "deepseek" elif "bigmodel" in url or "glm" in url: return "glm" elif "zhipu" in url: return "glm" elif "qwen" in url or "dashscope" in url: return "qwen" elif "moonshot" in url or "moonshot" in url: return "moonshot" return "openai" @property def client(self) -> httpx.AsyncClient: """获取HTTP客户端""" if self._client is None: self._client = httpx.AsyncClient( timeout=httpx.Timeout(120.0, connect=30.0), headers={ "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json" } ) return self._client async def close(self): """关闭客户端""" if self._client: await self._client.aclose() self._client = None def _build_headers(self) -> Dict[str, str]: """构建请求头""" return { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json" } def _build_body( self, model: str, messages: List[Dict[str, str]], tools: Optional[List[Dict]] = None, stream: bool = True, **kwargs ) -> Dict[str, Any]: """构建请求体""" body = { "model": model, "messages": messages, "stream": stream } # 添加可选参数 if "temperature" in kwargs: body["temperature"] = kwargs["temperature"] if "max_tokens" in kwargs: body["max_tokens"] = kwargs["max_tokens"] if "top_p" in kwargs: body["top_p"] = kwargs["top_p"] if "thinking_enabled" in kwargs: body["thinking_enabled"] = kwargs["thinking_enabled"] # 添加工具 if tools: body["tools"] = tools return body def _parse_response(self, data: Dict) -> LLMResponse: """解析响应""" # 通用字段 content = "" tool_calls = None usage = None finish_reason = None # OpenAI格式 if "choices" in data: choice = data["choices"][0] message = choice.get("message", {}) content = message.get("content", "") tool_calls = message.get("tool_calls") finish_reason = choice.get("finish_reason") # 使用量统计 if "usage" in data: usage = { "prompt_tokens": data["usage"].get("prompt_tokens", 0), "completion_tokens": data["usage"].get("completion_tokens", 0), "total_tokens": data["usage"].get("total_tokens", 0) } return LLMResponse( content=content, tool_calls=tool_calls, usage=usage, finish_reason=finish_reason, raw=data ) async def call( self, model: str, messages: List[Dict[str, str]], tools: Optional[List[Dict]] = None, **kwargs ) -> LLMResponse: """调用LLM API(非流式)""" body = self._build_body(model, messages, tools, stream=False, **kwargs) try: response = await self.client.post( self.api_url, json=body, headers=self._build_headers() ) response.raise_for_status() data = response.json() return self._parse_response(data) except httpx.HTTPStatusError as e: raise Exception(f"HTTP error: {e.response.status_code} - {e.response.text}") except Exception as e: raise Exception(f"LLM API error: {str(e)}") async def stream( self, model: str, messages: List[Dict[str, str]], tools: Optional[List[Dict]] = None, **kwargs ) -> AsyncGenerator[Dict[str, Any], None]: """流式调用LLM API""" body = self._build_body(model, messages, tools, stream=True, **kwargs) try: async with self.client.stream( "POST", self.api_url, json=body, headers=self._build_headers() ) as response: response.raise_for_status() accumulated_content = "" accumulated_tool_calls: Dict[int, Dict] = {} async for line in response.aiter_lines(): if not line.strip(): continue # 跳过SSE前缀 if line.startswith("data: "): line = line[6:] if line == "[DONE]": break try: chunk = json.loads(line) except json.JSONDecodeError: continue # 解析SSE数据 delta = chunk.get("choices", [{}])[0].get("delta", {}) # 内容增量 content_delta = delta.get("content", "") if content_delta: accumulated_content += content_delta yield { "type": "content_delta", "content": content_delta, "full_content": accumulated_content } # 工具调用增量 tool_calls = delta.get("tool_calls", []) for tc in tool_calls: index = tc.get("index", 0) if index not in accumulated_tool_calls: accumulated_tool_calls[index] = { "id": "", "type": "function", "function": {"name": "", "arguments": ""} } if tc.get("id"): accumulated_tool_calls[index]["id"] = tc["id"] if tc.get("function", {}).get("name"): accumulated_tool_calls[index]["function"]["name"] = tc["function"]["name"] if tc.get("function", {}).get("arguments"): accumulated_tool_calls[index]["function"]["arguments"] += tc["function"]["arguments"] # 完成信号 finish_reason = chunk.get("choices", [{}])[0].get("finish_reason") if finish_reason: yield { "type": "done", "finish_reason": finish_reason, "content": accumulated_content, "tool_calls": list(accumulated_tool_calls.values()) if accumulated_tool_calls else None, "usage": chunk.get("usage") } except httpx.HTTPStatusError as e: yield { "type": "error", "error": f"HTTP error: {e.response.status_code}" } except Exception as e: yield { "type": "error", "error": str(e) } # 全局LLM客户端 llm_client = LLMClient()