chore: 修改项目名
This commit is contained in:
parent
6749213f62
commit
72a3738388
|
|
@ -9,5 +9,5 @@
|
||||||
!README.md
|
!README.md
|
||||||
!.gitignore
|
!.gitignore
|
||||||
|
|
||||||
!alcor/**/*.py
|
!luxx/**/*.py
|
||||||
!docs/**/*.md
|
!docs/**/*.md
|
||||||
|
|
|
||||||
|
|
@ -1,69 +0,0 @@
|
||||||
"""FastAPI应用工厂"""
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
|
|
||||||
from alcor.config import config
|
|
||||||
from alcor.database import init_db
|
|
||||||
from alcor.routes import api_router
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app: FastAPI):
|
|
||||||
"""应用生命周期管理"""
|
|
||||||
# 启动时
|
|
||||||
print("🚀 Starting up ChatBackend API...")
|
|
||||||
|
|
||||||
# 初始化数据库
|
|
||||||
init_db()
|
|
||||||
print("✅ Database initialized")
|
|
||||||
|
|
||||||
# 加载内置工具
|
|
||||||
from alcor.tools.builtin import crawler, code, data
|
|
||||||
print(f"✅ Loaded {len(api_router.routes)} API routes")
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
# 关闭时
|
|
||||||
print("👋 Shutting down ChatBackend API...")
|
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
|
||||||
"""创建FastAPI应用"""
|
|
||||||
app = FastAPI(
|
|
||||||
title="ChatBackend API",
|
|
||||||
description="智能聊天后端API,支持多模型、流式响应、工具调用",
|
|
||||||
version="1.0.0",
|
|
||||||
lifespan=lifespan
|
|
||||||
)
|
|
||||||
|
|
||||||
# 配置CORS
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"], # 生产环境应限制
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# 注册路由
|
|
||||||
app.include_router(api_router, prefix="/api")
|
|
||||||
|
|
||||||
# 健康检查
|
|
||||||
@app.get("/health")
|
|
||||||
async def health_check():
|
|
||||||
return {"status": "healthy", "service": "chat-backend"}
|
|
||||||
|
|
||||||
@app.get("/")
|
|
||||||
async def root():
|
|
||||||
return {
|
|
||||||
"service": "ChatBackend API",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"docs": "/docs"
|
|
||||||
}
|
|
||||||
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
# 创建应用实例
|
|
||||||
app = create_app()
|
|
||||||
|
|
@ -1,73 +0,0 @@
|
||||||
"""工具路由"""
|
|
||||||
from typing import Optional, List, Dict, Any
|
|
||||||
from fastapi import APIRouter, Depends
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from alcor.database import get_db
|
|
||||||
from alcor.models import User
|
|
||||||
from alcor.routes.auth import get_current_user
|
|
||||||
from alcor.tools.core import registry
|
|
||||||
from alcor.utils.helpers import success_response
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/tools", tags=["工具"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=dict)
|
|
||||||
def list_tools(
|
|
||||||
category: Optional[str] = None,
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""获取可用工具列表"""
|
|
||||||
if category:
|
|
||||||
tools = registry.list_by_category(category)
|
|
||||||
else:
|
|
||||||
tools = registry.list_all()
|
|
||||||
|
|
||||||
# 按分类分组
|
|
||||||
categorized = {}
|
|
||||||
for tool in tools:
|
|
||||||
cat = tool.get("function", {}).get("category", "general")
|
|
||||||
if cat not in categorized:
|
|
||||||
categorized[cat] = []
|
|
||||||
categorized[cat].append(tool)
|
|
||||||
|
|
||||||
return success_response(data={
|
|
||||||
"tools": tools,
|
|
||||||
"categorized": categorized,
|
|
||||||
"total": registry.tool_count
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{name}", response_model=dict)
|
|
||||||
def get_tool(
|
|
||||||
name: str,
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""获取工具详情"""
|
|
||||||
tool = registry.get(name)
|
|
||||||
|
|
||||||
if not tool:
|
|
||||||
return {"success": False, "message": "工具不存在", "code": 404}
|
|
||||||
|
|
||||||
return success_response(data={
|
|
||||||
"name": tool.name,
|
|
||||||
"description": tool.description,
|
|
||||||
"parameters": tool.parameters,
|
|
||||||
"category": tool.category
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{name}/execute", response_model=dict)
|
|
||||||
def execute_tool(
|
|
||||||
name: str,
|
|
||||||
arguments: Dict[str, Any],
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""手动执行工具"""
|
|
||||||
result = registry.execute(name, arguments)
|
|
||||||
|
|
||||||
if not result.get("success"):
|
|
||||||
return {"success": False, "message": result.get("error"), "code": 400}
|
|
||||||
|
|
||||||
return success_response(data=result)
|
|
||||||
|
|
@ -1,11 +0,0 @@
|
||||||
"""服务层模块"""
|
|
||||||
from alcor.services.llm_client import LLMClient, llm_client, LLMResponse
|
|
||||||
from alcor.services.chat import ChatService, chat_service
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"LLMClient",
|
|
||||||
"llm_client",
|
|
||||||
"LLMResponse",
|
|
||||||
"ChatService",
|
|
||||||
"chat_service"
|
|
||||||
]
|
|
||||||
|
|
@ -1,262 +0,0 @@
|
||||||
"""聊天服务模块"""
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from typing import Dict, List, Optional, Any, Generator
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from alcor.models import Conversation, Message
|
|
||||||
from alcor.tools.executor import ToolExecutor
|
|
||||||
from alcor.tools.core import registry
|
|
||||||
from alcor.services.llm_client import llm_client, LLMClient
|
|
||||||
|
|
||||||
|
|
||||||
# 最大迭代次数,防止无限循环
|
|
||||||
MAX_ITERATIONS = 10
|
|
||||||
|
|
||||||
|
|
||||||
class ChatService:
|
|
||||||
"""聊天服务"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
llm_client: Optional[LLMClient] = None,
|
|
||||||
max_iterations: int = MAX_ITERATIONS
|
|
||||||
):
|
|
||||||
self.llm_client = llm_client or llm_client
|
|
||||||
self.tool_executor = ToolExecutor(enable_cache=True, cache_ttl=300)
|
|
||||||
self.max_iterations = max_iterations
|
|
||||||
|
|
||||||
def build_messages(
|
|
||||||
self,
|
|
||||||
conversation: Conversation,
|
|
||||||
include_system: bool = True
|
|
||||||
) -> List[Dict[str, str]]:
|
|
||||||
"""构建消息列表"""
|
|
||||||
messages = []
|
|
||||||
|
|
||||||
# 添加系统提示
|
|
||||||
if include_system and conversation.system_prompt:
|
|
||||||
messages.append({
|
|
||||||
"role": "system",
|
|
||||||
"content": conversation.system_prompt
|
|
||||||
})
|
|
||||||
|
|
||||||
# 添加历史消息
|
|
||||||
for msg in conversation.messages.order_by(Message.created_at).all():
|
|
||||||
try:
|
|
||||||
content_data = json.loads(msg.content) if msg.content else {}
|
|
||||||
if isinstance(content_data, dict):
|
|
||||||
text = content_data.get("text", "")
|
|
||||||
else:
|
|
||||||
text = str(msg.content)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
text = msg.content
|
|
||||||
|
|
||||||
messages.append({
|
|
||||||
"role": msg.role,
|
|
||||||
"content": text
|
|
||||||
})
|
|
||||||
|
|
||||||
return messages
|
|
||||||
|
|
||||||
def stream_response(
|
|
||||||
self,
|
|
||||||
conversation: Conversation,
|
|
||||||
user_message: str,
|
|
||||||
tools_enabled: bool = True,
|
|
||||||
context: Optional[Dict[str, Any]] = None
|
|
||||||
) -> Generator[Dict[str, Any], None, None]:
|
|
||||||
"""
|
|
||||||
流式响应生成器
|
|
||||||
|
|
||||||
生成事件类型:
|
|
||||||
- process_step: thinking/text/tool_call/tool_result 步骤
|
|
||||||
- done: 最终响应完成
|
|
||||||
- error: 出错时
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 构建消息列表
|
|
||||||
messages = self.build_messages(conversation)
|
|
||||||
|
|
||||||
# 添加用户消息
|
|
||||||
messages.append({
|
|
||||||
"role": "user",
|
|
||||||
"content": user_message
|
|
||||||
})
|
|
||||||
|
|
||||||
# 获取工具列表
|
|
||||||
tools = registry.list_all() if tools_enabled else None
|
|
||||||
|
|
||||||
# 迭代处理
|
|
||||||
iteration = 0
|
|
||||||
full_response = ""
|
|
||||||
tool_calls_buffer: List[Dict] = []
|
|
||||||
|
|
||||||
while iteration < self.max_iterations:
|
|
||||||
iteration += 1
|
|
||||||
|
|
||||||
# 调用LLM
|
|
||||||
tool_calls_this_round = None
|
|
||||||
|
|
||||||
for event in self.llm_client.stream(
|
|
||||||
model=conversation.model,
|
|
||||||
messages=messages,
|
|
||||||
tools=tools,
|
|
||||||
temperature=conversation.temperature,
|
|
||||||
max_tokens=conversation.max_tokens,
|
|
||||||
thinking_enabled=conversation.thinking_enabled
|
|
||||||
):
|
|
||||||
event_type = event.get("type")
|
|
||||||
|
|
||||||
if event_type == "content_delta":
|
|
||||||
# 内容增量
|
|
||||||
content = event.get("content", "")
|
|
||||||
if content:
|
|
||||||
full_response += content
|
|
||||||
yield {
|
|
||||||
"type": "process_step",
|
|
||||||
"step_type": "text",
|
|
||||||
"content": content
|
|
||||||
}
|
|
||||||
|
|
||||||
elif event_type == "done":
|
|
||||||
# 完成
|
|
||||||
tool_calls_this_round = event.get("tool_calls")
|
|
||||||
|
|
||||||
# 处理工具调用
|
|
||||||
if tool_calls_this_round and tools_enabled:
|
|
||||||
yield {
|
|
||||||
"type": "process_step",
|
|
||||||
"step_type": "tool_call",
|
|
||||||
"tool_calls": tool_calls_this_round
|
|
||||||
}
|
|
||||||
|
|
||||||
# 执行工具
|
|
||||||
tool_results = self.tool_executor.process_tool_calls_parallel(
|
|
||||||
tool_calls_this_round
|
|
||||||
)
|
|
||||||
|
|
||||||
for result in tool_results:
|
|
||||||
yield {
|
|
||||||
"type": "process_step",
|
|
||||||
"step_type": "tool_result",
|
|
||||||
"result": result
|
|
||||||
}
|
|
||||||
|
|
||||||
# 添加到消息历史
|
|
||||||
messages.append({
|
|
||||||
"role": "assistant",
|
|
||||||
"content": full_response,
|
|
||||||
"tool_calls": tool_calls_this_round
|
|
||||||
})
|
|
||||||
|
|
||||||
# 添加工具结果
|
|
||||||
for tr in tool_results:
|
|
||||||
messages.append({
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": tr.get("tool_call_id"),
|
|
||||||
"content": tr.get("content", ""),
|
|
||||||
"name": tr.get("name")
|
|
||||||
})
|
|
||||||
|
|
||||||
tool_calls_buffer.extend(tool_calls_this_round)
|
|
||||||
else:
|
|
||||||
# 没有工具调用,退出循环
|
|
||||||
break
|
|
||||||
|
|
||||||
# 如果没有更多工具调用,结束
|
|
||||||
if not tool_calls_this_round or not tools_enabled:
|
|
||||||
break
|
|
||||||
|
|
||||||
# 最终完成
|
|
||||||
yield {
|
|
||||||
"type": "done",
|
|
||||||
"content": full_response,
|
|
||||||
"tool_calls": tool_calls_buffer if tool_calls_buffer else None,
|
|
||||||
"iterations": iteration
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
yield {
|
|
||||||
"type": "error",
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
def non_stream_response(
|
|
||||||
self,
|
|
||||||
conversation: Conversation,
|
|
||||||
user_message: str,
|
|
||||||
tools_enabled: bool = True,
|
|
||||||
context: Optional[Dict[str, Any]] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""非流式响应"""
|
|
||||||
try:
|
|
||||||
messages = self.build_messages(conversation)
|
|
||||||
messages.append({
|
|
||||||
"role": "user",
|
|
||||||
"content": user_message
|
|
||||||
})
|
|
||||||
|
|
||||||
tools = registry.list_all() if tools_enabled else None
|
|
||||||
|
|
||||||
# 迭代处理
|
|
||||||
iteration = 0
|
|
||||||
full_response = ""
|
|
||||||
all_tool_calls = []
|
|
||||||
|
|
||||||
while iteration < self.max_iterations:
|
|
||||||
iteration += 1
|
|
||||||
|
|
||||||
response = self.llm_client.call(
|
|
||||||
model=conversation.model,
|
|
||||||
messages=messages,
|
|
||||||
tools=tools,
|
|
||||||
stream=False,
|
|
||||||
temperature=conversation.temperature,
|
|
||||||
max_tokens=conversation.max_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
full_response = response.content
|
|
||||||
tool_calls = response.tool_calls
|
|
||||||
|
|
||||||
if tool_calls and tools_enabled:
|
|
||||||
# 执行工具
|
|
||||||
tool_results = self.tool_executor.process_tool_calls_parallel(tool_calls)
|
|
||||||
all_tool_calls.extend(tool_calls)
|
|
||||||
|
|
||||||
messages.append({
|
|
||||||
"role": "assistant",
|
|
||||||
"content": full_response,
|
|
||||||
"tool_calls": tool_calls
|
|
||||||
})
|
|
||||||
|
|
||||||
for tr in tool_results:
|
|
||||||
messages.append({
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": tr.get("tool_call_id"),
|
|
||||||
"content": tr.get("content", ""),
|
|
||||||
"name": tr.get("name")
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
messages.append({
|
|
||||||
"role": "assistant",
|
|
||||||
"content": full_response
|
|
||||||
})
|
|
||||||
break
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"content": full_response,
|
|
||||||
"tool_calls": all_tool_calls,
|
|
||||||
"iterations": iteration
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# 全局聊天服务
|
|
||||||
chat_service = ChatService()
|
|
||||||
|
|
@ -1,256 +0,0 @@
|
||||||
"""LLM API客户端"""
|
|
||||||
import json
|
|
||||||
from typing import Dict, List, Optional, Generator, Any, Callable, AsyncGenerator
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from alcor.config import config
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LLMResponse:
|
|
||||||
"""LLM响应"""
|
|
||||||
content: str
|
|
||||||
tool_calls: Optional[List[Dict[str, Any]]] = None
|
|
||||||
usage: Optional[Dict[str, int]] = None
|
|
||||||
finish_reason: Optional[str] = None
|
|
||||||
raw: Optional[Dict] = None
|
|
||||||
|
|
||||||
|
|
||||||
class LLMClient:
|
|
||||||
"""LLM API客户端,支持多种提供商"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
api_url: Optional[str] = None,
|
|
||||||
provider: Optional[str] = None
|
|
||||||
):
|
|
||||||
self.api_key = api_key or config.llm_api_key
|
|
||||||
self.api_url = api_url or config.llm_api_url
|
|
||||||
self.provider = provider or config.llm_provider or self._detect_provider()
|
|
||||||
self._client: Optional[httpx.AsyncClient] = None
|
|
||||||
|
|
||||||
def _detect_provider(self) -> str:
|
|
||||||
"""检测提供商"""
|
|
||||||
url = self.api_url.lower()
|
|
||||||
if "deepseek" in url:
|
|
||||||
return "deepseek"
|
|
||||||
elif "bigmodel" in url or "glm" in url:
|
|
||||||
return "glm"
|
|
||||||
elif "zhipu" in url:
|
|
||||||
return "glm"
|
|
||||||
elif "qwen" in url or "dashscope" in url:
|
|
||||||
return "qwen"
|
|
||||||
elif "moonshot" in url or "moonshot" in url:
|
|
||||||
return "moonshot"
|
|
||||||
return "openai"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def client(self) -> httpx.AsyncClient:
|
|
||||||
"""获取HTTP客户端"""
|
|
||||||
if self._client is None:
|
|
||||||
self._client = httpx.AsyncClient(
|
|
||||||
timeout=httpx.Timeout(120.0, connect=30.0),
|
|
||||||
headers={
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return self._client
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
"""关闭客户端"""
|
|
||||||
if self._client:
|
|
||||||
await self._client.aclose()
|
|
||||||
self._client = None
|
|
||||||
|
|
||||||
def _build_headers(self) -> Dict[str, str]:
|
|
||||||
"""构建请求头"""
|
|
||||||
return {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
def _build_body(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
messages: List[Dict[str, str]],
|
|
||||||
tools: Optional[List[Dict]] = None,
|
|
||||||
stream: bool = True,
|
|
||||||
**kwargs
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""构建请求体"""
|
|
||||||
body = {
|
|
||||||
"model": model,
|
|
||||||
"messages": messages,
|
|
||||||
"stream": stream
|
|
||||||
}
|
|
||||||
|
|
||||||
# 添加可选参数
|
|
||||||
if "temperature" in kwargs:
|
|
||||||
body["temperature"] = kwargs["temperature"]
|
|
||||||
if "max_tokens" in kwargs:
|
|
||||||
body["max_tokens"] = kwargs["max_tokens"]
|
|
||||||
if "top_p" in kwargs:
|
|
||||||
body["top_p"] = kwargs["top_p"]
|
|
||||||
if "thinking_enabled" in kwargs:
|
|
||||||
body["thinking_enabled"] = kwargs["thinking_enabled"]
|
|
||||||
|
|
||||||
# 添加工具
|
|
||||||
if tools:
|
|
||||||
body["tools"] = tools
|
|
||||||
|
|
||||||
return body
|
|
||||||
|
|
||||||
def _parse_response(self, data: Dict) -> LLMResponse:
|
|
||||||
"""解析响应"""
|
|
||||||
# 通用字段
|
|
||||||
content = ""
|
|
||||||
tool_calls = None
|
|
||||||
usage = None
|
|
||||||
finish_reason = None
|
|
||||||
|
|
||||||
# OpenAI格式
|
|
||||||
if "choices" in data:
|
|
||||||
choice = data["choices"][0]
|
|
||||||
message = choice.get("message", {})
|
|
||||||
content = message.get("content", "")
|
|
||||||
tool_calls = message.get("tool_calls")
|
|
||||||
finish_reason = choice.get("finish_reason")
|
|
||||||
|
|
||||||
# 使用量统计
|
|
||||||
if "usage" in data:
|
|
||||||
usage = {
|
|
||||||
"prompt_tokens": data["usage"].get("prompt_tokens", 0),
|
|
||||||
"completion_tokens": data["usage"].get("completion_tokens", 0),
|
|
||||||
"total_tokens": data["usage"].get("total_tokens", 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
return LLMResponse(
|
|
||||||
content=content,
|
|
||||||
tool_calls=tool_calls,
|
|
||||||
usage=usage,
|
|
||||||
finish_reason=finish_reason,
|
|
||||||
raw=data
|
|
||||||
)
|
|
||||||
|
|
||||||
async def call(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
messages: List[Dict[str, str]],
|
|
||||||
tools: Optional[List[Dict]] = None,
|
|
||||||
**kwargs
|
|
||||||
) -> LLMResponse:
|
|
||||||
"""调用LLM API(非流式)"""
|
|
||||||
body = self._build_body(model, messages, tools, stream=False, **kwargs)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self.client.post(
|
|
||||||
self.api_url,
|
|
||||||
json=body,
|
|
||||||
headers=self._build_headers()
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
return self._parse_response(data)
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
raise Exception(f"HTTP error: {e.response.status_code} - {e.response.text}")
|
|
||||||
except Exception as e:
|
|
||||||
raise Exception(f"LLM API error: {str(e)}")
|
|
||||||
|
|
||||||
async def stream(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
messages: List[Dict[str, str]],
|
|
||||||
tools: Optional[List[Dict]] = None,
|
|
||||||
**kwargs
|
|
||||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
|
||||||
"""流式调用LLM API"""
|
|
||||||
body = self._build_body(model, messages, tools, stream=True, **kwargs)
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with self.client.stream(
|
|
||||||
"POST",
|
|
||||||
self.api_url,
|
|
||||||
json=body,
|
|
||||||
headers=self._build_headers()
|
|
||||||
) as response:
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
accumulated_content = ""
|
|
||||||
accumulated_tool_calls: Dict[int, Dict] = {}
|
|
||||||
|
|
||||||
async for line in response.aiter_lines():
|
|
||||||
if not line.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 跳过SSE前缀
|
|
||||||
if line.startswith("data: "):
|
|
||||||
line = line[6:]
|
|
||||||
|
|
||||||
if line == "[DONE]":
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
|
||||||
chunk = json.loads(line)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 解析SSE数据
|
|
||||||
delta = chunk.get("choices", [{}])[0].get("delta", {})
|
|
||||||
|
|
||||||
# 内容增量
|
|
||||||
content_delta = delta.get("content", "")
|
|
||||||
if content_delta:
|
|
||||||
accumulated_content += content_delta
|
|
||||||
yield {
|
|
||||||
"type": "content_delta",
|
|
||||||
"content": content_delta,
|
|
||||||
"full_content": accumulated_content
|
|
||||||
}
|
|
||||||
|
|
||||||
# 工具调用增量
|
|
||||||
tool_calls = delta.get("tool_calls", [])
|
|
||||||
for tc in tool_calls:
|
|
||||||
index = tc.get("index", 0)
|
|
||||||
if index not in accumulated_tool_calls:
|
|
||||||
accumulated_tool_calls[index] = {
|
|
||||||
"id": "",
|
|
||||||
"type": "function",
|
|
||||||
"function": {"name": "", "arguments": ""}
|
|
||||||
}
|
|
||||||
|
|
||||||
if tc.get("id"):
|
|
||||||
accumulated_tool_calls[index]["id"] = tc["id"]
|
|
||||||
if tc.get("function", {}).get("name"):
|
|
||||||
accumulated_tool_calls[index]["function"]["name"] = tc["function"]["name"]
|
|
||||||
if tc.get("function", {}).get("arguments"):
|
|
||||||
accumulated_tool_calls[index]["function"]["arguments"] += tc["function"]["arguments"]
|
|
||||||
|
|
||||||
# 完成信号
|
|
||||||
finish_reason = chunk.get("choices", [{}])[0].get("finish_reason")
|
|
||||||
if finish_reason:
|
|
||||||
yield {
|
|
||||||
"type": "done",
|
|
||||||
"finish_reason": finish_reason,
|
|
||||||
"content": accumulated_content,
|
|
||||||
"tool_calls": list(accumulated_tool_calls.values()) if accumulated_tool_calls else None,
|
|
||||||
"usage": chunk.get("usage")
|
|
||||||
}
|
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
yield {
|
|
||||||
"type": "error",
|
|
||||||
"error": f"HTTP error: {e.response.status_code}"
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
yield {
|
|
||||||
"type": "error",
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# 全局LLM客户端
|
|
||||||
llm_client = LLMClient()
|
|
||||||
|
|
@ -1,19 +0,0 @@
|
||||||
"""工具系统模块"""
|
|
||||||
from alcor.tools.core import (
|
|
||||||
ToolDefinition,
|
|
||||||
ToolResult,
|
|
||||||
ToolRegistry,
|
|
||||||
registry
|
|
||||||
)
|
|
||||||
from alcor.tools.factory import tool, tool_function
|
|
||||||
from alcor.tools.executor import ToolExecutor
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"ToolDefinition",
|
|
||||||
"ToolResult",
|
|
||||||
"ToolRegistry",
|
|
||||||
"registry",
|
|
||||||
"tool",
|
|
||||||
"tool_function",
|
|
||||||
"ToolExecutor"
|
|
||||||
]
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
"""内置工具模块"""
|
|
||||||
# 导入所有内置工具以注册它们
|
|
||||||
from alcor.tools.builtin import crawler
|
|
||||||
from alcor.tools.builtin import code
|
|
||||||
from alcor.tools.builtin import data
|
|
||||||
|
|
||||||
__all__ = ["crawler", "code", "data"]
|
|
||||||
|
|
@ -1,270 +0,0 @@
|
||||||
"""数据处理工具"""
|
|
||||||
import re
|
|
||||||
import json
|
|
||||||
import hashlib
|
|
||||||
import base64
|
|
||||||
import urllib.parse
|
|
||||||
from typing import Dict, Any, List
|
|
||||||
|
|
||||||
from alcor.tools.factory import tool
|
|
||||||
|
|
||||||
|
|
||||||
@tool(
|
|
||||||
name="calculate",
|
|
||||||
description="Perform mathematical calculations",
|
|
||||||
parameters={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"expression": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Mathematical expression to evaluate (e.g., '2 + 2', 'sqrt(16)', 'sin(pi/2)')"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["expression"]
|
|
||||||
},
|
|
||||||
category="data"
|
|
||||||
)
|
|
||||||
def calculate(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""执行数学计算"""
|
|
||||||
expression = arguments.get("expression", "")
|
|
||||||
|
|
||||||
if not expression:
|
|
||||||
return {"success": False, "error": "Expression is required"}
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 安全替换数学函数
|
|
||||||
safe_dict = {
|
|
||||||
"abs": abs,
|
|
||||||
"round": round,
|
|
||||||
"min": min,
|
|
||||||
"max": max,
|
|
||||||
"sum": sum,
|
|
||||||
"pow": pow,
|
|
||||||
"sqrt": lambda x: x ** 0.5,
|
|
||||||
"sin": lambda x: __import__("math").sin(x),
|
|
||||||
"cos": lambda x: __import__("math").cos(x),
|
|
||||||
"tan": lambda x: __import__("math").tan(x),
|
|
||||||
"log": lambda x: __import__("math").log(x),
|
|
||||||
"pi": __import__("math").pi,
|
|
||||||
"e": __import__("math").e,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 移除危险字符,只保留数字和运算符
|
|
||||||
safe_expr = re.sub(r"[^0-9+\-*/().%sqrtinsclogmaxminpowabsroundte, ]", "", expression)
|
|
||||||
result = eval(safe_expr, {"__builtins__": {}, **safe_dict})
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"data": {
|
|
||||||
"expression": expression,
|
|
||||||
"result": float(result) if isinstance(result, (int, float)) else result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
except ZeroDivisionError:
|
|
||||||
return {"success": False, "error": "Division by zero"}
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "error": f"Calculation error: {str(e)}"}
|
|
||||||
|
|
||||||
|
|
||||||
@tool(
|
|
||||||
name="text_process",
|
|
||||||
description="Process and transform text",
|
|
||||||
parameters={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"text": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Input text"
|
|
||||||
},
|
|
||||||
"operation": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Operation to perform: upper, lower, title, reverse, word_count, char_count, reverse_words",
|
|
||||||
"enum": ["upper", "lower", "title", "reverse", "word_count", "char_count", "reverse_words"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["text", "operation"]
|
|
||||||
},
|
|
||||||
category="data"
|
|
||||||
)
|
|
||||||
def text_process(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""文本处理"""
|
|
||||||
text = arguments.get("text", "")
|
|
||||||
operation = arguments.get("operation", "")
|
|
||||||
|
|
||||||
if not text:
|
|
||||||
return {"success": False, "error": "Text is required"}
|
|
||||||
|
|
||||||
operations = {
|
|
||||||
"upper": lambda t: t.upper(),
|
|
||||||
"lower": lambda t: t.lower(),
|
|
||||||
"title": lambda t: t.title(),
|
|
||||||
"reverse": lambda t: t[::-1],
|
|
||||||
"word_count": lambda t: len(t.split()),
|
|
||||||
"char_count": lambda t: len(t),
|
|
||||||
"reverse_words": lambda t: " ".join(t.split()[::-1])
|
|
||||||
}
|
|
||||||
|
|
||||||
if operation not in operations:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": f"Unknown operation: {operation}"
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = operations[operation](text)
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"data": {
|
|
||||||
"operation": operation,
|
|
||||||
"input": text,
|
|
||||||
"result": result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "error": str(e)}
|
|
||||||
|
|
||||||
|
|
||||||
@tool(
|
|
||||||
name="hash_text",
|
|
||||||
description="Generate hash of text using various algorithms",
|
|
||||||
parameters={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"text": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Text to hash"
|
|
||||||
},
|
|
||||||
"algorithm": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Hash algorithm: md5, sha1, sha256, sha512",
|
|
||||||
"default": "sha256"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["text"]
|
|
||||||
},
|
|
||||||
category="data"
|
|
||||||
)
|
|
||||||
def hash_text(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""生成文本哈希"""
|
|
||||||
text = arguments.get("text", "")
|
|
||||||
algorithm = arguments.get("algorithm", "sha256")
|
|
||||||
|
|
||||||
if not text:
|
|
||||||
return {"success": False, "error": "Text is required"}
|
|
||||||
|
|
||||||
hash_funcs = {
|
|
||||||
"md5": hashlib.md5,
|
|
||||||
"sha1": hashlib.sha1,
|
|
||||||
"sha256": hashlib.sha256,
|
|
||||||
"sha512": hashlib.sha512
|
|
||||||
}
|
|
||||||
|
|
||||||
if algorithm not in hash_funcs:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": f"Unsupported algorithm: {algorithm}"
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
hash_obj = hash_funcs[algorithm](text.encode("utf-8"))
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"data": {
|
|
||||||
"algorithm": algorithm,
|
|
||||||
"hash": hash_obj.hexdigest()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "error": str(e)}
|
|
||||||
|
|
||||||
|
|
||||||
@tool(
|
|
||||||
name="url_encode_decode",
|
|
||||||
description="URL encode or decode text",
|
|
||||||
parameters={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"text": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Text to encode or decode"
|
|
||||||
},
|
|
||||||
"operation": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Operation: encode or decode",
|
|
||||||
"enum": ["encode", "decode"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["text", "operation"]
|
|
||||||
},
|
|
||||||
category="data"
|
|
||||||
)
|
|
||||||
def url_encode_decode(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""URL编码/解码"""
|
|
||||||
text = arguments.get("text", "")
|
|
||||||
operation = arguments.get("operation", "encode")
|
|
||||||
|
|
||||||
if not text:
|
|
||||||
return {"success": False, "error": "Text is required"}
|
|
||||||
|
|
||||||
try:
|
|
||||||
if operation == "encode":
|
|
||||||
result = urllib.parse.quote(text)
|
|
||||||
else:
|
|
||||||
result = urllib.parse.unquote(text)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"data": {
|
|
||||||
"operation": operation,
|
|
||||||
"input": text,
|
|
||||||
"result": result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "error": str(e)}
|
|
||||||
|
|
||||||
|
|
||||||
@tool(
|
|
||||||
name="base64_encode_decode",
|
|
||||||
description="Base64 encode or decode text",
|
|
||||||
parameters={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"text": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Text to encode or decode"
|
|
||||||
},
|
|
||||||
"operation": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Operation: encode or decode",
|
|
||||||
"enum": ["encode", "decode"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["text", "operation"]
|
|
||||||
},
|
|
||||||
category="data"
|
|
||||||
)
|
|
||||||
def base64_encode_decode(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""Base64编码/解码"""
|
|
||||||
text = arguments.get("text", "")
|
|
||||||
operation = arguments.get("operation", "encode")
|
|
||||||
|
|
||||||
if not text:
|
|
||||||
return {"success": False, "error": "Text is required"}
|
|
||||||
|
|
||||||
try:
|
|
||||||
if operation == "encode":
|
|
||||||
result = base64.b64encode(text.encode("utf-8")).decode("utf-8")
|
|
||||||
else:
|
|
||||||
result = base64.b64decode(text.encode("utf-8")).decode("utf-8")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"data": {
|
|
||||||
"operation": operation,
|
|
||||||
"input": text,
|
|
||||||
"result": result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "error": str(e)}
|
|
||||||
|
|
@ -1,186 +0,0 @@
|
||||||
"""工具执行器"""
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
||||||
from typing import List, Dict, Optional, Any
|
|
||||||
|
|
||||||
from alcor.tools.core import registry, ToolResult
|
|
||||||
|
|
||||||
|
|
||||||
class ToolExecutor:
|
|
||||||
"""工具执行器,支持缓存、并行执行"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
enable_cache: bool = True,
|
|
||||||
cache_ttl: int = 300, # 5分钟
|
|
||||||
max_workers: int = 4
|
|
||||||
):
|
|
||||||
self.enable_cache = enable_cache
|
|
||||||
self.cache_ttl = cache_ttl
|
|
||||||
self.max_workers = max_workers
|
|
||||||
self._cache: Dict[str, tuple] = {} # (result, timestamp)
|
|
||||||
self._call_history: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
def _make_cache_key(self, name: str, args: dict) -> str:
|
|
||||||
"""生成缓存键"""
|
|
||||||
args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
|
|
||||||
return hashlib.md5(f"{name}:{args_str}".encode()).hexdigest()
|
|
||||||
|
|
||||||
def _is_cache_valid(self, cache_key: str) -> bool:
|
|
||||||
"""检查缓存是否有效"""
|
|
||||||
if cache_key not in self._cache:
|
|
||||||
return False
|
|
||||||
_, timestamp = self._cache[cache_key]
|
|
||||||
return (time.time() - timestamp) < self.cache_ttl
|
|
||||||
|
|
||||||
def _get_cached(self, cache_key: str) -> Optional[Dict]:
|
|
||||||
"""获取缓存结果"""
|
|
||||||
if self.enable_cache and self._is_cache_valid(cache_key):
|
|
||||||
return self._cache[cache_key][0]
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _set_cached(self, cache_key: str, result: Dict) -> None:
|
|
||||||
"""设置缓存"""
|
|
||||||
if self.enable_cache:
|
|
||||||
self._cache[cache_key] = (result, time.time())
|
|
||||||
|
|
||||||
def _record_call(self, name: str, args: dict, result: Dict) -> None:
|
|
||||||
"""记录调用历史"""
|
|
||||||
self._call_history.append({
|
|
||||||
"name": name,
|
|
||||||
"args": args,
|
|
||||||
"result": result,
|
|
||||||
"timestamp": time.time()
|
|
||||||
})
|
|
||||||
# 限制历史记录数量
|
|
||||||
if len(self._call_history) > 1000:
|
|
||||||
self._call_history = self._call_history[-500:]
|
|
||||||
|
|
||||||
def process_tool_calls(
|
|
||||||
self,
|
|
||||||
tool_calls: List[Dict[str, Any]],
|
|
||||||
context: Optional[Dict[str, Any]] = None
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""顺序处理工具调用"""
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for call in tool_calls:
|
|
||||||
name = call.get("function", {}).get("name", "")
|
|
||||||
args_str = call.get("function", {}).get("arguments", "{}")
|
|
||||||
call_id = call.get("id", "")
|
|
||||||
|
|
||||||
# 解析JSON参数
|
|
||||||
try:
|
|
||||||
args = json.loads(args_str) if isinstance(args_str, str) else args_str
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
results.append(self._create_error_result(call_id, name, "Invalid JSON arguments"))
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 检查缓存
|
|
||||||
cache_key = self._make_cache_key(name, args)
|
|
||||||
cached_result = self._get_cached(cache_key)
|
|
||||||
|
|
||||||
if cached_result is not None:
|
|
||||||
result = cached_result
|
|
||||||
else:
|
|
||||||
# 执行工具
|
|
||||||
result = registry.execute(name, args)
|
|
||||||
self._set_cached(cache_key, result)
|
|
||||||
|
|
||||||
# 记录调用
|
|
||||||
self._record_call(name, args, result)
|
|
||||||
|
|
||||||
# 创建结果消息
|
|
||||||
results.append(self._create_tool_result(call_id, name, result))
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def process_tool_calls_parallel(
|
|
||||||
self,
|
|
||||||
tool_calls: List[Dict[str, Any]],
|
|
||||||
context: Optional[Dict[str, Any]] = None,
|
|
||||||
max_workers: Optional[int] = None
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""并行处理工具调用"""
|
|
||||||
if len(tool_calls) <= 1:
|
|
||||||
return self.process_tool_calls(tool_calls, context)
|
|
||||||
|
|
||||||
workers = max_workers or self.max_workers
|
|
||||||
results = [None] * len(tool_calls)
|
|
||||||
exec_tasks = {}
|
|
||||||
|
|
||||||
# 解析所有参数
|
|
||||||
for i, call in enumerate(tool_calls):
|
|
||||||
try:
|
|
||||||
name = call.get("function", {}).get("name", "")
|
|
||||||
args_str = call.get("function", {}).get("arguments", "{}")
|
|
||||||
call_id = call.get("id", "")
|
|
||||||
args = json.loads(args_str) if isinstance(args_str, str) else args_str
|
|
||||||
exec_tasks[i] = (call_id, name, args)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
results[i] = self._create_error_result(
|
|
||||||
call.get("id", ""),
|
|
||||||
call.get("function", {}).get("name", ""),
|
|
||||||
"Invalid JSON"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 并行执行
|
|
||||||
def run(call_id: str, name: str, args: dict) -> Dict[str, Any]:
|
|
||||||
# 检查缓存
|
|
||||||
cache_key = self._make_cache_key(name, args)
|
|
||||||
cached_result = self._get_cached(cache_key)
|
|
||||||
|
|
||||||
if cached_result is not None:
|
|
||||||
result = cached_result
|
|
||||||
else:
|
|
||||||
result = registry.execute(name, args)
|
|
||||||
self._set_cached(cache_key, result)
|
|
||||||
|
|
||||||
self._record_call(name, args, result)
|
|
||||||
return self._create_tool_result(call_id, name, result)
|
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
|
||||||
futures = {
|
|
||||||
pool.submit(run, cid, n, a): i
|
|
||||||
for i, (cid, n, a) in exec_tasks.items()
|
|
||||||
}
|
|
||||||
for future in as_completed(futures):
|
|
||||||
idx = futures[future]
|
|
||||||
try:
|
|
||||||
results[idx] = future.result()
|
|
||||||
except Exception as e:
|
|
||||||
results[idx] = self._create_error_result(
|
|
||||||
exec_tasks[idx][0] if idx in exec_tasks else "",
|
|
||||||
exec_tasks[idx][1] if idx in exec_tasks else "",
|
|
||||||
str(e)
|
|
||||||
)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def _create_tool_result(self, call_id: str, name: str, result: Dict) -> Dict[str, Any]:
|
|
||||||
"""创建工具结果消息"""
|
|
||||||
return {
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": call_id,
|
|
||||||
"name": name,
|
|
||||||
"content": json.dumps(result, ensure_ascii=False, default=str)
|
|
||||||
}
|
|
||||||
|
|
||||||
def _create_error_result(self, call_id: str, name: str, error: str) -> Dict[str, Any]:
|
|
||||||
"""创建错误结果消息"""
|
|
||||||
return {
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": call_id,
|
|
||||||
"name": name,
|
|
||||||
"content": json.dumps({"success": False, "error": error})
|
|
||||||
}
|
|
||||||
|
|
||||||
def clear_cache(self) -> None:
|
|
||||||
"""清空缓存"""
|
|
||||||
self._cache.clear()
|
|
||||||
|
|
||||||
def get_history(self, limit: int = 100) -> List[Dict[str, Any]]:
|
|
||||||
"""获取调用历史"""
|
|
||||||
return self._call_history[-limit:]
|
|
||||||
|
|
@ -1,22 +0,0 @@
|
||||||
"""工具函数模块"""
|
|
||||||
from alcor.utils.helpers import (
|
|
||||||
generate_id,
|
|
||||||
hash_password,
|
|
||||||
verify_password,
|
|
||||||
create_access_token,
|
|
||||||
decode_access_token,
|
|
||||||
success_response,
|
|
||||||
error_response,
|
|
||||||
paginate
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"generate_id",
|
|
||||||
"hash_password",
|
|
||||||
"verify_password",
|
|
||||||
"create_access_token",
|
|
||||||
"decode_access_token",
|
|
||||||
"success_response",
|
|
||||||
"error_response",
|
|
||||||
"paginate"
|
|
||||||
]
|
|
||||||
|
|
@ -12,7 +12,7 @@
|
||||||
## 目录结构
|
## 目录结构
|
||||||
|
|
||||||
```
|
```
|
||||||
alcor/
|
luxx/
|
||||||
├── __init__.py # FastAPI 应用工厂
|
├── __init__.py # FastAPI 应用工厂
|
||||||
├── run.py # 入口文件
|
├── run.py # 入口文件
|
||||||
├── config.py # 配置管理(YAML)
|
├── config.py # 配置管理(YAML)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,57 @@
|
||||||
|
"""FastAPI application factory"""
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from luxx.config import config
|
||||||
|
from luxx.database import init_db
|
||||||
|
from luxx.routes import api_router
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""Application lifespan manager"""
|
||||||
|
init_db()
|
||||||
|
from luxx.tools.builtin import crawler, code, data
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
"""Create FastAPI application"""
|
||||||
|
app = FastAPI(
|
||||||
|
title="luxx API",
|
||||||
|
description="Intelligent chat backend API with multi-model support, streaming responses, and tool calling",
|
||||||
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure CORS
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"], # Should be restricted in production
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register routes
|
||||||
|
app.include_router(api_router, prefix="/api")
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
return {"status": "healthy", "service": "luxx"}
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
return {
|
||||||
|
"service": "luxx API",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"docs": "/docs"
|
||||||
|
}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
# Create application instance
|
||||||
|
app = create_app()
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""配置管理模块"""
|
"""Configuration management module"""
|
||||||
import os
|
import os
|
||||||
import yaml
|
import yaml
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@ -6,7 +6,7 @@ from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""配置类(单例模式)"""
|
"""Configuration class (singleton pattern)"""
|
||||||
|
|
||||||
_instance: Optional["Config"] = None
|
_instance: Optional["Config"] = None
|
||||||
_config: Dict[str, Any] = {}
|
_config: Dict[str, Any] = {}
|
||||||
|
|
@ -18,7 +18,7 @@ class Config:
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def _load_config(self) -> None:
|
def _load_config(self) -> None:
|
||||||
"""加载配置文件"""
|
"""Load configuration from YAML file"""
|
||||||
yaml_paths = [
|
yaml_paths = [
|
||||||
Path("config.yaml"),
|
Path("config.yaml"),
|
||||||
Path(__file__).parent.parent / "config.yaml",
|
Path(__file__).parent.parent / "config.yaml",
|
||||||
|
|
@ -35,7 +35,7 @@ class Config:
|
||||||
self._config = {}
|
self._config = {}
|
||||||
|
|
||||||
def _resolve_env_vars(self) -> None:
|
def _resolve_env_vars(self) -> None:
|
||||||
"""解析环境变量引用"""
|
"""Resolve environment variable references"""
|
||||||
def resolve(value: Any) -> Any:
|
def resolve(value: Any) -> Any:
|
||||||
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
|
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
|
||||||
return os.environ.get(value[2:-1], "")
|
return os.environ.get(value[2:-1], "")
|
||||||
|
|
@ -48,7 +48,7 @@ class Config:
|
||||||
self._config = resolve(self._config)
|
self._config = resolve(self._config)
|
||||||
|
|
||||||
def get(self, key: str, default: Any = None) -> Any:
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
"""获取配置值,支持点号分隔的键"""
|
"""Get configuration value, supports dot-separated keys"""
|
||||||
keys = key.split(".")
|
keys = key.split(".")
|
||||||
value = self._config
|
value = self._config
|
||||||
for k in keys:
|
for k in keys:
|
||||||
|
|
@ -60,7 +60,7 @@ class Config:
|
||||||
return default
|
return default
|
||||||
return value
|
return value
|
||||||
|
|
||||||
# App配置
|
# App configuration
|
||||||
@property
|
@property
|
||||||
def secret_key(self) -> str:
|
def secret_key(self) -> str:
|
||||||
return self.get("app.secret_key", "change-me-in-production")
|
return self.get("app.secret_key", "change-me-in-production")
|
||||||
|
|
@ -77,12 +77,12 @@ class Config:
|
||||||
def app_port(self) -> int:
|
def app_port(self) -> int:
|
||||||
return self.get("app.port", 8000)
|
return self.get("app.port", 8000)
|
||||||
|
|
||||||
# 数据库配置
|
# Database configuration
|
||||||
@property
|
@property
|
||||||
def database_url(self) -> str:
|
def database_url(self) -> str:
|
||||||
return self.get("database.url", "sqlite:///./chat.db")
|
return self.get("database.url", "sqlite:///./chat.db")
|
||||||
|
|
||||||
# LLM配置
|
# LLM configuration
|
||||||
@property
|
@property
|
||||||
def llm_api_key(self) -> str:
|
def llm_api_key(self) -> str:
|
||||||
return self.get("llm.api_key", "") or os.environ.get("DEEPSEEK_API_KEY", "")
|
return self.get("llm.api_key", "") or os.environ.get("DEEPSEEK_API_KEY", "")
|
||||||
|
|
@ -95,7 +95,7 @@ class Config:
|
||||||
def llm_provider(self) -> str:
|
def llm_provider(self) -> str:
|
||||||
return self.get("llm.provider", "deepseek")
|
return self.get("llm.provider", "deepseek")
|
||||||
|
|
||||||
# 工具配置
|
# Tools configuration
|
||||||
@property
|
@property
|
||||||
def tools_enable_cache(self) -> bool:
|
def tools_enable_cache(self) -> bool:
|
||||||
return self.get("tools.enable_cache", True)
|
return self.get("tools.enable_cache", True)
|
||||||
|
|
@ -113,5 +113,5 @@ class Config:
|
||||||
return self.get("tools.max_iterations", 10)
|
return self.get("tools.max_iterations", 10)
|
||||||
|
|
||||||
|
|
||||||
# 全局配置实例
|
# Global configuration instance
|
||||||
config = Config()
|
config = Config()
|
||||||
|
|
@ -1,29 +1,27 @@
|
||||||
"""数据库连接模块"""
|
"""Database connection module"""
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.orm import sessionmaker, declarative_base, Mapped
|
||||||
from sqlalchemy.orm import sessionmaker, Session
|
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
|
|
||||||
from alcor.config import config
|
from luxx.config import config
|
||||||
|
|
||||||
|
|
||||||
# 创建数据库引擎
|
# Create database engine
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
config.database_url,
|
config.database_url,
|
||||||
connect_args={"check_same_thread": False} if "sqlite" in config.database_url else {},
|
connect_args={"check_same_thread": False} if "sqlite" in config.database_url else {},
|
||||||
pool_pre_ping=True,
|
|
||||||
echo=config.debug
|
echo=config.debug
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建会话工厂
|
# Create session factory
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
# 创建基类
|
# Create base class
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
def get_db() -> Generator[Session, None, None]:
|
def get_db() -> Generator:
|
||||||
"""获取数据库会话的依赖项"""
|
"""Dependency to get database session"""
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
yield db
|
yield db
|
||||||
|
|
@ -32,5 +30,5 @@ def get_db() -> Generator[Session, None, None]:
|
||||||
|
|
||||||
|
|
||||||
def init_db() -> None:
|
def init_db() -> None:
|
||||||
"""初始化数据库,创建所有表"""
|
"""Initialize database, create all tables"""
|
||||||
Base.metadata.create_all(bind=engine)
|
Base.metadata.create_all(bind=engine)
|
||||||
|
|
@ -1,47 +1,34 @@
|
||||||
"""ORM模型定义"""
|
"""ORM model definitions"""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from sqlalchemy import String, Text, Integer, Float, Boolean, DateTime, ForeignKey
|
from sqlalchemy import String, Integer, Boolean, Float, Text, DateTime, ForeignKey
|
||||||
from sqlalchemy.orm import relationship, Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column, relationship, DeclarativeBase
|
||||||
|
|
||||||
from alcor.database import Base
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Project(Base):
|
class Project(Base):
|
||||||
"""项目模型"""
|
"""Project model"""
|
||||||
__tablename__ = "projects"
|
__tablename__ = "projects"
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False)
|
||||||
name: Mapped[str] = mapped_column(String(255), default="")
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||||
|
|
||||||
# 关系
|
# Relationships
|
||||||
user: Mapped["User"] = relationship("User", backref="projects")
|
user: Mapped["User"] = relationship("User", backref="projects")
|
||||||
conversations: Mapped[List["Conversation"]] = relationship(
|
|
||||||
"Conversation",
|
|
||||||
back_populates="project",
|
|
||||||
lazy="dynamic"
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return {
|
|
||||||
"id": self.id,
|
|
||||||
"user_id": self.user_id,
|
|
||||||
"name": self.name,
|
|
||||||
"description": self.description,
|
|
||||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
|
||||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
"""用户模型"""
|
"""User model"""
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||||
username: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
|
username: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
|
||||||
email: Mapped[Optional[str]] = mapped_column(String(120), unique=True, nullable=True)
|
email: Mapped[Optional[str]] = mapped_column(String(120), unique=True, nullable=True)
|
||||||
password_hash: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
password_hash: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||||
|
|
@ -49,14 +36,12 @@ class User(Base):
|
||||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||||
|
|
||||||
# 关系
|
# Relationships
|
||||||
conversations: Mapped[List["Conversation"]] = relationship(
|
conversations: Mapped[List["Conversation"]] = relationship(
|
||||||
"Conversation",
|
"Conversation", back_populates="user", cascade="all, delete-orphan"
|
||||||
back_populates="user",
|
|
||||||
lazy="dynamic"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self):
|
||||||
return {
|
return {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
"username": self.username,
|
"username": self.username,
|
||||||
|
|
@ -68,33 +53,28 @@ class User(Base):
|
||||||
|
|
||||||
|
|
||||||
class Conversation(Base):
|
class Conversation(Base):
|
||||||
"""会话模型"""
|
"""Conversation model"""
|
||||||
__tablename__ = "conversations"
|
__tablename__ = "conversations"
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False)
|
||||||
project_id: Mapped[Optional[str]] = mapped_column(String(64), ForeignKey("projects.id"), nullable=True)
|
project_id: Mapped[Optional[str]] = mapped_column(String(64), ForeignKey("projects.id"), nullable=True)
|
||||||
title: Mapped[str] = mapped_column(String(255), default="")
|
title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
model: Mapped[str] = mapped_column(String(64), default="glm-5")
|
model: Mapped[str] = mapped_column(String(64), nullable=False, default="deepseek-chat")
|
||||||
system_prompt: Mapped[str] = mapped_column(Text, default="")
|
system_prompt: Mapped[str] = mapped_column(Text, nullable=False, default="You are a helpful assistant.")
|
||||||
temperature: Mapped[float] = mapped_column(Float, default=1.0)
|
temperature: Mapped[float] = mapped_column(Float, default=0.7)
|
||||||
max_tokens: Mapped[int] = mapped_column(Integer, default=65536)
|
max_tokens: Mapped[int] = mapped_column(Integer, default=2000)
|
||||||
thinking_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
|
thinking_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||||
|
|
||||||
# 关系
|
# Relationships
|
||||||
user: Mapped["User"] = relationship("User", back_populates="conversations")
|
user: Mapped["User"] = relationship("User", back_populates="conversations")
|
||||||
project: Mapped[Optional["Project"]] = relationship("Project", back_populates="conversations")
|
|
||||||
messages: Mapped[List["Message"]] = relationship(
|
messages: Mapped[List["Message"]] = relationship(
|
||||||
"Message",
|
"Message", back_populates="conversation", cascade="all, delete-orphan"
|
||||||
back_populates="conversation",
|
|
||||||
lazy="dynamic",
|
|
||||||
cascade="all, delete-orphan",
|
|
||||||
order_by="Message.created_at.asc()"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self):
|
||||||
return {
|
return {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
"user_id": self.user_id,
|
"user_id": self.user_id,
|
||||||
|
|
@ -111,25 +91,20 @@ class Conversation(Base):
|
||||||
|
|
||||||
|
|
||||||
class Message(Base):
|
class Message(Base):
|
||||||
"""消息模型"""
|
"""Message model"""
|
||||||
__tablename__ = "messages"
|
__tablename__ = "messages"
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
conversation_id: Mapped[str] = mapped_column(
|
conversation_id: Mapped[str] = mapped_column(String(64), ForeignKey("conversations.id"), nullable=False)
|
||||||
String(64),
|
role: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||||
ForeignKey("conversations.id"),
|
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
nullable=False,
|
|
||||||
index=True
|
|
||||||
)
|
|
||||||
role: Mapped[str] = mapped_column(String(16), nullable=False) # user, assistant, system, tool
|
|
||||||
content: Mapped[str] = mapped_column(Text, default="") # JSON: {text, steps, tool_calls}
|
|
||||||
token_count: Mapped[int] = mapped_column(Integer, default=0)
|
token_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, index=True)
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||||
|
|
||||||
# 关系
|
# Relationships
|
||||||
conversation: Mapped["Conversation"] = relationship("Conversation", back_populates="messages")
|
conversation: Mapped["Conversation"] = relationship("Conversation", back_populates="messages")
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self):
|
||||||
return {
|
return {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
"conversation_id": self.conversation_id,
|
"conversation_id": self.conversation_id,
|
||||||
|
|
@ -1,14 +1,13 @@
|
||||||
"""API路由模块"""
|
"""API routes module"""
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from alcor.routes import auth, conversations, messages, tools
|
from luxx.routes import auth, conversations, messages, tools
|
||||||
|
|
||||||
|
|
||||||
api_router = APIRouter()
|
api_router = APIRouter()
|
||||||
|
|
||||||
# 注册子路由
|
# Register sub-routes
|
||||||
api_router.include_router(auth.router)
|
api_router.include_router(auth.router)
|
||||||
api_router.include_router(conversations.router)
|
api_router.include_router(conversations.router)
|
||||||
api_router.include_router(messages.router)
|
api_router.include_router(messages.router)
|
||||||
api_router.include_router(tools.router)
|
api_router.include_router(tools.router)
|
||||||
|
|
||||||
__all__ = ["api_router"]
|
|
||||||
|
|
@ -1,134 +1,112 @@
|
||||||
"""认证路由"""
|
"""Authentication routes"""
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Optional
|
from fastapi import APIRouter, Depends, status
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||||
from pydantic import BaseModel, EmailStr
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from alcor.database import get_db
|
from luxx.database import get_db
|
||||||
from alcor.models import User
|
from luxx.models import User
|
||||||
from alcor.utils.helpers import (
|
from luxx.utils.helpers import (
|
||||||
hash_password,
|
hash_password,
|
||||||
verify_password,
|
verify_password,
|
||||||
create_access_token,
|
create_access_token,
|
||||||
|
decode_access_token,
|
||||||
success_response,
|
success_response,
|
||||||
error_response
|
error_response
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["认证"])
|
router = APIRouter(prefix="/auth", tags=["Authentication"])
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
|
||||||
|
|
||||||
|
|
||||||
class UserRegister(BaseModel):
|
class UserRegister(BaseModel):
|
||||||
"""用户注册模型"""
|
"""User registration model"""
|
||||||
username: str
|
username: str
|
||||||
email: Optional[EmailStr] = None
|
email: str | None = None
|
||||||
password: str
|
password: str
|
||||||
|
|
||||||
|
|
||||||
class UserLogin(BaseModel):
|
class UserLogin(BaseModel):
|
||||||
"""用户登录模型"""
|
"""User login model"""
|
||||||
username: str
|
username: str
|
||||||
password: str
|
password: str
|
||||||
|
|
||||||
|
|
||||||
class UserResponse(BaseModel):
|
class UserResponse(BaseModel):
|
||||||
"""用户响应模型"""
|
"""User response model"""
|
||||||
id: int
|
id: int
|
||||||
username: str
|
username: str
|
||||||
email: Optional[str] = None
|
email: str | None
|
||||||
role: str
|
role: str
|
||||||
is_active: bool
|
|
||||||
|
|
||||||
|
|
||||||
class TokenResponse(BaseModel):
|
class TokenResponse(BaseModel):
|
||||||
"""令牌响应模型"""
|
"""Token response model"""
|
||||||
access_token: str
|
access_token: str
|
||||||
token_type: str = "bearer"
|
token_type: str
|
||||||
|
|
||||||
|
|
||||||
def get_current_user(
|
def get_current_user(
|
||||||
token: str = Depends(oauth2_scheme),
|
token: str = Depends(oauth2_scheme),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
) -> User:
|
) -> User:
|
||||||
"""获取当前用户"""
|
"""Get current user"""
|
||||||
from alcor.utils.helpers import decode_access_token
|
|
||||||
|
|
||||||
payload = decode_access_token(token)
|
payload = decode_access_token(token)
|
||||||
if payload is None:
|
if not payload:
|
||||||
raise HTTPException(
|
raise status.HTTP_401_UNAUTHORIZED
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="无效的认证凭证",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
|
|
||||||
user_id = payload.get("sub")
|
user_id = payload.get("sub")
|
||||||
if user_id is None:
|
if not user_id:
|
||||||
raise HTTPException(
|
raise status.HTTP_401_UNAUTHORIZED
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
user = db.query(User).filter(User.id == int(user_id)).first()
|
||||||
detail="无效的认证凭证"
|
if not user:
|
||||||
)
|
raise status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
user = db.query(User).filter(User.id == user_id).first()
|
|
||||||
if user is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="用户不存在"
|
|
||||||
)
|
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=dict)
|
@router.post("/register", response_model=dict)
|
||||||
def register(user_data: UserRegister, db: Session = Depends(get_db)):
|
def register(user_data: UserRegister, db: Session = Depends(get_db)):
|
||||||
"""用户注册"""
|
"""User registration"""
|
||||||
# 检查用户名是否存在
|
|
||||||
existing_user = db.query(User).filter(User.username == user_data.username).first()
|
existing_user = db.query(User).filter(User.username == user_data.username).first()
|
||||||
if existing_user:
|
if existing_user:
|
||||||
return error_response("用户名已存在", 400)
|
return error_response("Username already exists", 400)
|
||||||
|
|
||||||
# 检查邮箱是否存在
|
|
||||||
if user_data.email:
|
if user_data.email:
|
||||||
existing_email = db.query(User).filter(User.email == user_data.email).first()
|
existing_email = db.query(User).filter(User.email == user_data.email).first()
|
||||||
if existing_email:
|
if existing_email:
|
||||||
return error_response("邮箱已被注册", 400)
|
return error_response("Email already registered", 400)
|
||||||
|
|
||||||
# 创建用户
|
|
||||||
password_hash = hash_password(user_data.password)
|
password_hash = hash_password(user_data.password)
|
||||||
user = User(
|
user = User(
|
||||||
username=user_data.username,
|
username=user_data.username,
|
||||||
email=user_data.email,
|
email=user_data.email,
|
||||||
password_hash=password_hash,
|
password_hash=password_hash
|
||||||
role="user"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
db.add(user)
|
db.add(user)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(user)
|
db.refresh(user)
|
||||||
|
|
||||||
return success_response(
|
return success_response(
|
||||||
data={"id": user.id, "username": user.username},
|
data={"id": user.id, "username": user.username},
|
||||||
message="注册成功"
|
message="Registration successful"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login", response_model=dict)
|
@router.post("/login", response_model=dict)
|
||||||
def login(user_data: UserLogin, db: Session = Depends(get_db)):
|
def login(user_data: UserLogin, db: Session = Depends(get_db)):
|
||||||
"""用户登录"""
|
"""User login"""
|
||||||
user = db.query(User).filter(User.username == user_data.username).first()
|
user = db.query(User).filter(User.username == user_data.username).first()
|
||||||
|
|
||||||
if not user or not verify_password(user_data.password, user.password_hash or ""):
|
if not user or not verify_password(user_data.password, user.password_hash or ""):
|
||||||
return error_response("用户名或密码错误", 401)
|
return error_response("Invalid username or password", 401)
|
||||||
|
|
||||||
if not user.is_active:
|
if not user.is_active:
|
||||||
return error_response("用户已被禁用", 403)
|
return error_response("User account is disabled", 403)
|
||||||
|
|
||||||
# 创建访问令牌
|
|
||||||
access_token = create_access_token(
|
access_token = create_access_token(
|
||||||
data={"sub": user.id, "username": user.username},
|
data={"sub": str(user.id)},
|
||||||
expires_delta=timedelta(days=7)
|
expires_delta=timedelta(days=7)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -138,17 +116,17 @@ def login(user_data: UserLogin, db: Session = Depends(get_db)):
|
||||||
"token_type": "bearer",
|
"token_type": "bearer",
|
||||||
"user": user.to_dict()
|
"user": user.to_dict()
|
||||||
},
|
},
|
||||||
message="登录成功"
|
message="Login successful"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/logout")
|
@router.post("/logout")
|
||||||
def logout(current_user: User = Depends(get_current_user)):
|
def logout(current_user: User = Depends(get_current_user)):
|
||||||
"""用户登出(前端清除令牌即可)"""
|
"""User logout (client should delete token)"""
|
||||||
return success_response(message="登出成功")
|
return success_response(message="Logout successful")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=dict)
|
@router.get("/me", response_model=dict)
|
||||||
def get_me(current_user: User = Depends(get_current_user)):
|
def get_me(current_user: User = Depends(get_current_user)):
|
||||||
"""获取当前用户信息"""
|
"""Get current user info"""
|
||||||
return success_response(data=current_user.to_dict())
|
return success_response(data=current_user.to_dict())
|
||||||
|
|
@ -1,31 +1,31 @@
|
||||||
"""会话路由"""
|
"""Conversation routes"""
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from alcor.database import get_db
|
from luxx.database import get_db
|
||||||
from alcor.models import Conversation, User
|
from luxx.models import Conversation, User
|
||||||
from alcor.routes.auth import get_current_user
|
from luxx.routes.auth import get_current_user
|
||||||
from alcor.utils.helpers import generate_id, success_response, error_response, paginate
|
from luxx.utils.helpers import generate_id, success_response, error_response, paginate
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/conversations", tags=["会话"])
|
router = APIRouter(prefix="/conversations", tags=["Conversations"])
|
||||||
|
|
||||||
|
|
||||||
class ConversationCreate(BaseModel):
|
class ConversationCreate(BaseModel):
|
||||||
"""创建会话模型"""
|
"""Create conversation model"""
|
||||||
project_id: Optional[str] = None
|
project_id: Optional[str] = None
|
||||||
title: str = ""
|
title: Optional[str] = None
|
||||||
model: str = "glm-5"
|
model: str = "deepseek-chat"
|
||||||
system_prompt: str = ""
|
system_prompt: str = "You are a helpful assistant."
|
||||||
temperature: float = 1.0
|
temperature: float = 0.7
|
||||||
max_tokens: int = 65536
|
max_tokens: int = 2000
|
||||||
thinking_enabled: bool = False
|
thinking_enabled: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ConversationUpdate(BaseModel):
|
class ConversationUpdate(BaseModel):
|
||||||
"""更新会话模型"""
|
"""Update conversation model"""
|
||||||
title: Optional[str] = None
|
title: Optional[str] = None
|
||||||
model: Optional[str] = None
|
model: Optional[str] = None
|
||||||
system_prompt: Optional[str] = None
|
system_prompt: Optional[str] = None
|
||||||
|
|
@ -36,25 +36,16 @@ class ConversationUpdate(BaseModel):
|
||||||
|
|
||||||
@router.get("/", response_model=dict)
|
@router.get("/", response_model=dict)
|
||||||
def list_conversations(
|
def list_conversations(
|
||||||
project_id: Optional[str] = None,
|
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 20,
|
page_size: int = 20,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""获取会话列表"""
|
"""Get conversation list"""
|
||||||
query = db.query(Conversation).filter(Conversation.user_id == current_user.id)
|
query = db.query(Conversation).filter(Conversation.user_id == current_user.id)
|
||||||
|
result = paginate(query.order_by(Conversation.updated_at.desc()), page, page_size)
|
||||||
if project_id:
|
|
||||||
query = query.filter(Conversation.project_id == project_id)
|
|
||||||
|
|
||||||
query = query.order_by(Conversation.updated_at.desc())
|
|
||||||
|
|
||||||
result = paginate(query, page, page_size)
|
|
||||||
items = [conv.to_dict() for conv in result["items"]]
|
|
||||||
|
|
||||||
return success_response(data={
|
return success_response(data={
|
||||||
"items": items,
|
"conversations": [c.to_dict() for c in result["items"]],
|
||||||
"total": result["total"],
|
"total": result["total"],
|
||||||
"page": result["page"],
|
"page": result["page"],
|
||||||
"page_size": result["page_size"]
|
"page_size": result["page_size"]
|
||||||
|
|
@ -67,12 +58,12 @@ def create_conversation(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""创建会话"""
|
"""Create conversation"""
|
||||||
conversation = Conversation(
|
conversation = Conversation(
|
||||||
id=generate_id("conv"),
|
id=generate_id("conv"),
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
project_id=data.project_id,
|
project_id=data.project_id,
|
||||||
title=data.title or "新会话",
|
title=data.title or "New Conversation",
|
||||||
model=data.model,
|
model=data.model,
|
||||||
system_prompt=data.system_prompt,
|
system_prompt=data.system_prompt,
|
||||||
temperature=data.temperature,
|
temperature=data.temperature,
|
||||||
|
|
@ -84,7 +75,7 @@ def create_conversation(
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(conversation)
|
db.refresh(conversation)
|
||||||
|
|
||||||
return success_response(data=conversation.to_dict(), message="会话创建成功")
|
return success_response(data=conversation.to_dict(), message="Conversation created successfully")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{conversation_id}", response_model=dict)
|
@router.get("/{conversation_id}", response_model=dict)
|
||||||
|
|
@ -93,14 +84,14 @@ def get_conversation(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""获取会话详情"""
|
"""Get conversation details"""
|
||||||
conversation = db.query(Conversation).filter(
|
conversation = db.query(Conversation).filter(
|
||||||
Conversation.id == conversation_id,
|
Conversation.id == conversation_id,
|
||||||
Conversation.user_id == current_user.id
|
Conversation.user_id == current_user.id
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not conversation:
|
if not conversation:
|
||||||
return error_response("会话不存在", 404)
|
return error_response("Conversation not found", 404)
|
||||||
|
|
||||||
return success_response(data=conversation.to_dict())
|
return success_response(data=conversation.to_dict())
|
||||||
|
|
||||||
|
|
@ -112,16 +103,15 @@ def update_conversation(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""更新会话"""
|
"""Update conversation"""
|
||||||
conversation = db.query(Conversation).filter(
|
conversation = db.query(Conversation).filter(
|
||||||
Conversation.id == conversation_id,
|
Conversation.id == conversation_id,
|
||||||
Conversation.user_id == current_user.id
|
Conversation.user_id == current_user.id
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not conversation:
|
if not conversation:
|
||||||
return error_response("会话不存在", 404)
|
return error_response("Conversation not found", 404)
|
||||||
|
|
||||||
# 更新字段
|
|
||||||
update_data = data.dict(exclude_unset=True)
|
update_data = data.dict(exclude_unset=True)
|
||||||
for key, value in update_data.items():
|
for key, value in update_data.items():
|
||||||
setattr(conversation, key, value)
|
setattr(conversation, key, value)
|
||||||
|
|
@ -129,7 +119,7 @@ def update_conversation(
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(conversation)
|
db.refresh(conversation)
|
||||||
|
|
||||||
return success_response(data=conversation.to_dict(), message="会话更新成功")
|
return success_response(data=conversation.to_dict(), message="Conversation updated successfully")
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{conversation_id}", response_model=dict)
|
@router.delete("/{conversation_id}", response_model=dict)
|
||||||
|
|
@ -138,16 +128,16 @@ def delete_conversation(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""删除会话"""
|
"""Delete conversation"""
|
||||||
conversation = db.query(Conversation).filter(
|
conversation = db.query(Conversation).filter(
|
||||||
Conversation.id == conversation_id,
|
Conversation.id == conversation_id,
|
||||||
Conversation.user_id == current_user.id
|
Conversation.user_id == current_user.id
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not conversation:
|
if not conversation:
|
||||||
return error_response("会话不存在", 404)
|
return error_response("Conversation not found", 404)
|
||||||
|
|
||||||
db.delete(conversation)
|
db.delete(conversation)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
return success_response(message="会话删除成功")
|
return success_response(message="Conversation deleted successfully")
|
||||||
|
|
@ -1,120 +1,103 @@
|
||||||
"""消息路由"""
|
"""Message routes"""
|
||||||
import json
|
import json
|
||||||
from typing import Optional, List
|
from fastapi import APIRouter, Depends, Response
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from alcor.database import get_db
|
from luxx.database import get_db
|
||||||
from alcor.models import Conversation, Message, User
|
from luxx.models import Conversation, Message, User
|
||||||
from alcor.routes.auth import get_current_user
|
from luxx.routes.auth import get_current_user
|
||||||
from alcor.services.chat import chat_service
|
from luxx.services.chat import chat_service
|
||||||
from alcor.utils.helpers import generate_id, success_response, error_response
|
from luxx.utils.helpers import generate_id, success_response, error_response
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/messages", tags=["消息"])
|
router = APIRouter(prefix="/messages", tags=["Messages"])
|
||||||
|
|
||||||
|
|
||||||
class MessageCreate(BaseModel):
|
class MessageCreate(BaseModel):
|
||||||
"""创建消息模型"""
|
"""Create message model"""
|
||||||
conversation_id: str
|
conversation_id: str
|
||||||
content: str
|
content: str
|
||||||
tools_enabled: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
class MessageResponse(BaseModel):
|
class MessageResponse(BaseModel):
|
||||||
"""消息响应模型"""
|
"""Message response model"""
|
||||||
id: str
|
id: str
|
||||||
conversation_id: str
|
|
||||||
role: str
|
role: str
|
||||||
content: str
|
content: str
|
||||||
token_count: int
|
token_count: int
|
||||||
created_at: str
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{conversation_id}", response_model=dict)
|
@router.get("/", response_model=dict)
|
||||||
def list_messages(
|
def list_messages(
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
limit: int = 100,
|
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""获取消息列表"""
|
"""Get message list"""
|
||||||
# 验证会话归属
|
|
||||||
conversation = db.query(Conversation).filter(
|
conversation = db.query(Conversation).filter(
|
||||||
Conversation.id == conversation_id,
|
Conversation.id == conversation_id,
|
||||||
Conversation.user_id == current_user.id
|
Conversation.user_id == current_user.id
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not conversation:
|
if not conversation:
|
||||||
return error_response("会话不存在", 404)
|
return error_response("Conversation not found", 404)
|
||||||
|
|
||||||
messages = db.query(Message).filter(
|
messages = db.query(Message).filter(
|
||||||
Message.conversation_id == conversation_id
|
Message.conversation_id == conversation_id
|
||||||
).order_by(Message.created_at.desc()).limit(limit).all()
|
).order_by(Message.created_at).all()
|
||||||
|
|
||||||
items = [msg.to_dict() for msg in reversed(messages)]
|
|
||||||
|
|
||||||
return success_response(data={
|
return success_response(data={
|
||||||
"items": items,
|
"messages": [m.to_dict() for m in messages]
|
||||||
"total": len(items)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
@router.post("/", response_model=dict)
|
@router.post("/", response_model=dict)
|
||||||
async def create_message(
|
def send_message(
|
||||||
data: MessageCreate,
|
data: MessageCreate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""发送消息(非流式)"""
|
"""Send message (non-streaming)"""
|
||||||
# 验证会话
|
|
||||||
conversation = db.query(Conversation).filter(
|
conversation = db.query(Conversation).filter(
|
||||||
Conversation.id == data.conversation_id,
|
Conversation.id == data.conversation_id,
|
||||||
Conversation.user_id == current_user.id
|
Conversation.user_id == current_user.id
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not conversation:
|
if not conversation:
|
||||||
return error_response("会话不存在", 404)
|
return error_response("Conversation not found", 404)
|
||||||
|
|
||||||
# 保存用户消息
|
|
||||||
user_message = Message(
|
user_message = Message(
|
||||||
id=generate_id("msg"),
|
id=generate_id("msg"),
|
||||||
conversation_id=data.conversation_id,
|
conversation_id=data.conversation_id,
|
||||||
role="user",
|
role="user",
|
||||||
content=json.dumps({"text": data.content})
|
content=data.content,
|
||||||
|
token_count=len(data.content) // 4
|
||||||
)
|
)
|
||||||
db.add(user_message)
|
db.add(user_message)
|
||||||
|
|
||||||
# 更新会话时间
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
conversation.updated_at = datetime.utcnow()
|
conversation.updated_at = datetime.utcnow()
|
||||||
|
|
||||||
db.commit()
|
|
||||||
db.refresh(user_message)
|
|
||||||
|
|
||||||
# 获取AI响应(非流式)
|
|
||||||
response = chat_service.non_stream_response(
|
response = chat_service.non_stream_response(
|
||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
user_message=data.content,
|
user_message=data.content,
|
||||||
tools_enabled=data.tools_enabled
|
tools_enabled=False
|
||||||
)
|
)
|
||||||
|
|
||||||
if not response.get("success"):
|
if not response.get("success"):
|
||||||
return error_response(response.get("error", "生成响应失败"), 500)
|
return error_response(response.get("error", "Failed to generate response"), 500)
|
||||||
|
|
||||||
# 保存AI响应
|
|
||||||
ai_content = response.get("content", "")
|
ai_content = response.get("content", "")
|
||||||
|
|
||||||
ai_message = Message(
|
ai_message = Message(
|
||||||
id=generate_id("msg"),
|
id=generate_id("msg"),
|
||||||
conversation_id=data.conversation_id,
|
conversation_id=data.conversation_id,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=json.dumps({
|
content=ai_content,
|
||||||
"text": ai_content,
|
token_count=len(ai_content) // 4
|
||||||
"tool_calls": response.get("tool_calls")
|
|
||||||
}),
|
|
||||||
token_count=len(ai_content) // 4 # 粗略估算
|
|
||||||
)
|
)
|
||||||
db.add(ai_message)
|
db.add(ai_message)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
@ -128,77 +111,66 @@ async def create_message(
|
||||||
@router.post("/stream")
|
@router.post("/stream")
|
||||||
async def stream_message(
|
async def stream_message(
|
||||||
data: MessageCreate,
|
data: MessageCreate,
|
||||||
|
tools_enabled: bool = True,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""发送消息(流式响应 - SSE)"""
|
"""Send message (streaming response - SSE)"""
|
||||||
# 验证会话
|
|
||||||
conversation = db.query(Conversation).filter(
|
conversation = db.query(Conversation).filter(
|
||||||
Conversation.id == data.conversation_id,
|
Conversation.id == data.conversation_id,
|
||||||
Conversation.user_id == current_user.id
|
Conversation.user_id == current_user.id
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not conversation:
|
if not conversation:
|
||||||
return error_response("会话不存在", 404)
|
return error_response("Conversation not found", 404)
|
||||||
|
|
||||||
# 保存用户消息
|
|
||||||
user_message = Message(
|
user_message = Message(
|
||||||
id=generate_id("msg"),
|
id=generate_id("msg"),
|
||||||
conversation_id=data.conversation_id,
|
conversation_id=data.conversation_id,
|
||||||
role="user",
|
role="user",
|
||||||
content=json.dumps({"text": data.content})
|
content=data.content,
|
||||||
|
token_count=len(data.content) // 4
|
||||||
)
|
)
|
||||||
db.add(user_message)
|
db.add(user_message)
|
||||||
|
|
||||||
# 更新会话时间
|
|
||||||
from datetime import datetime
|
|
||||||
conversation.updated_at = datetime.utcnow()
|
conversation.updated_at = datetime.utcnow()
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(user_message)
|
|
||||||
|
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
"""SSE事件生成器"""
|
|
||||||
full_response = ""
|
full_response = ""
|
||||||
message_id = generate_id("msg")
|
|
||||||
|
|
||||||
async for event in chat_service.stream_response(
|
async for event in chat_service.stream_response(
|
||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
user_message=data.content,
|
user_message=data.content,
|
||||||
tools_enabled=data.tools_enabled
|
tools_enabled=tools_enabled
|
||||||
):
|
):
|
||||||
event_type = event.get("type")
|
event_type = event.get("type")
|
||||||
|
|
||||||
if event_type == "process_step":
|
if event_type == "text":
|
||||||
step_type = event.get("step_type")
|
content = event.get("content", "")
|
||||||
|
full_response += content
|
||||||
|
yield f"data: {json.dumps({'type': 'text', 'content': content})}\n\n"
|
||||||
|
|
||||||
if step_type == "text":
|
elif event_type == "tool_call":
|
||||||
content = event.get("content", "")
|
yield f"data: {json.dumps({'type': 'tool_call', 'data': event.get('data')})}\n\n"
|
||||||
full_response += content
|
|
||||||
yield f"data: {json.dumps({'type': 'text', 'content': content})}\n\n"
|
|
||||||
|
|
||||||
elif step_type == "tool_call":
|
elif event_type == "tool_result":
|
||||||
yield f"data: {json.dumps({'type': 'tool_call', 'tool_calls': event.get('tool_calls')})}\n\n"
|
yield f"data: {json.dumps({'type': 'tool_result', 'data': event.get('data')})}\n\n"
|
||||||
|
|
||||||
elif step_type == "tool_result":
|
|
||||||
yield f"data: {json.dumps({'type': 'tool_result', 'result': event.get('result')})}\n\n"
|
|
||||||
|
|
||||||
elif event_type == "done":
|
elif event_type == "done":
|
||||||
# 保存AI消息
|
|
||||||
try:
|
try:
|
||||||
ai_message = Message(
|
ai_message = Message(
|
||||||
id=message_id,
|
id=generate_id("msg"),
|
||||||
conversation_id=data.conversation_id,
|
conversation_id=data.conversation_id,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=json.dumps({"text": full_response}),
|
content=full_response,
|
||||||
token_count=len(full_response) // 4
|
token_count=len(full_response) // 4
|
||||||
)
|
)
|
||||||
db.add(ai_message)
|
db.add(ai_message)
|
||||||
db.commit()
|
db.commit()
|
||||||
except Exception as e:
|
except Exception:
|
||||||
db.rollback()
|
pass
|
||||||
|
|
||||||
yield f"data: {json.dumps({'type': 'done', 'message_id': message_id})}\n\n"
|
yield f"data: {json.dumps({'type': 'done', 'message_id': ai_message.id if 'ai_message' in dir() else None})}\n\n"
|
||||||
|
|
||||||
elif event_type == "error":
|
elif event_type == "error":
|
||||||
yield f"data: {json.dumps({'type': 'error', 'error': event.get('error')})}\n\n"
|
yield f"data: {json.dumps({'type': 'error', 'error': event.get('error')})}\n\n"
|
||||||
|
|
@ -222,17 +194,16 @@ def delete_message(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""删除消息"""
|
"""Delete message"""
|
||||||
# 获取消息及其会话
|
|
||||||
message = db.query(Message).join(Conversation).filter(
|
message = db.query(Message).join(Conversation).filter(
|
||||||
Message.id == message_id,
|
Message.id == message_id,
|
||||||
Conversation.user_id == current_user.id
|
Conversation.user_id == current_user.id
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not message:
|
if not message:
|
||||||
return error_response("消息不存在", 404)
|
return error_response("Message not found", 404)
|
||||||
|
|
||||||
db.delete(message)
|
db.delete(message)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
return success_response(message="消息删除成功")
|
return success_response(message="Message deleted successfully")
|
||||||
|
|
@ -0,0 +1,63 @@
|
||||||
|
"""Tool routes"""
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from fastapi import APIRouter, Depends, Body
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from luxx.database import get_db
|
||||||
|
from luxx.models import User
|
||||||
|
from luxx.routes.auth import get_current_user
|
||||||
|
from luxx.tools.core import registry
|
||||||
|
from luxx.utils.helpers import success_response
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/tools", tags=["Tools"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/", response_model=dict)
|
||||||
|
def list_tools(
|
||||||
|
category: Optional[str] = None,
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""Get available tools list"""
|
||||||
|
if category:
|
||||||
|
tools = registry.list_by_category(category)
|
||||||
|
else:
|
||||||
|
tools = registry.list_all()
|
||||||
|
|
||||||
|
categorized = {}
|
||||||
|
for tool in tools:
|
||||||
|
cat = tool.get("category", "other")
|
||||||
|
if cat not in categorized:
|
||||||
|
categorized[cat] = []
|
||||||
|
categorized[cat].append(tool)
|
||||||
|
|
||||||
|
return success_response(data={
|
||||||
|
"tools": tools,
|
||||||
|
"categorized": categorized,
|
||||||
|
"total": len(tools)
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{name}", response_model=dict)
|
||||||
|
def get_tool(
|
||||||
|
name: str,
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""Get tool details"""
|
||||||
|
tool = registry.get(name)
|
||||||
|
|
||||||
|
if not tool:
|
||||||
|
return {"success": False, "message": "Tool not found", "code": 404}
|
||||||
|
|
||||||
|
return success_response(data=tool.to_openai_format())
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{name}/execute", response_model=dict)
|
||||||
|
def execute_tool(
|
||||||
|
name: str,
|
||||||
|
arguments: Dict[str, Any] = Body(...),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""Execute tool manually"""
|
||||||
|
result = registry.execute(name, arguments)
|
||||||
|
return result
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""应用入口"""
|
"""Application entry point"""
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from alcor.config import config
|
from luxx.config import config
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""启动应用"""
|
"""Start the application"""
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"alcor:app",
|
"luxx:app",
|
||||||
host=config.app_host,
|
host=config.app_host,
|
||||||
port=config.app_port,
|
port=config.app_port,
|
||||||
reload=config.debug,
|
reload=config.debug,
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
"""Services module"""
|
||||||
|
from luxx.services.llm_client import LLMClient, llm_client, LLMResponse
|
||||||
|
from luxx.services.chat import ChatService, chat_service
|
||||||
|
|
@ -0,0 +1,194 @@
|
||||||
|
"""Chat service module"""
|
||||||
|
import json
|
||||||
|
from typing import List, Dict, Any, AsyncGenerator
|
||||||
|
|
||||||
|
from luxx.models import Conversation, Message
|
||||||
|
from luxx.tools.executor import ToolExecutor
|
||||||
|
from luxx.tools.core import registry
|
||||||
|
from luxx.services.llm_client import llm_client
|
||||||
|
|
||||||
|
|
||||||
|
# Maximum iterations to prevent infinite loops
|
||||||
|
MAX_ITERATIONS = 10
|
||||||
|
|
||||||
|
|
||||||
|
class ChatService:
|
||||||
|
"""Chat service"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.tool_executor = ToolExecutor()
|
||||||
|
|
||||||
|
def build_messages(
|
||||||
|
self,
|
||||||
|
conversation: Conversation,
|
||||||
|
include_system: bool = True
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
"""Build message list"""
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
if include_system and conversation.system_prompt:
|
||||||
|
messages.append({
|
||||||
|
"role": "system",
|
||||||
|
"content": conversation.system_prompt
|
||||||
|
})
|
||||||
|
|
||||||
|
for msg in conversation.messages.order_by(Message.created_at).all():
|
||||||
|
messages.append({
|
||||||
|
"role": msg.role,
|
||||||
|
"content": msg.content
|
||||||
|
})
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
async def stream_response(
|
||||||
|
self,
|
||||||
|
conversation: Conversation,
|
||||||
|
user_message: str,
|
||||||
|
tools_enabled: bool = True
|
||||||
|
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||||
|
"""
|
||||||
|
Streaming response generator
|
||||||
|
|
||||||
|
Event types:
|
||||||
|
- process_step: thinking/text/tool_call/tool_result step
|
||||||
|
- done: final response complete
|
||||||
|
- error: on error
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
messages = self.build_messages(conversation)
|
||||||
|
|
||||||
|
messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": user_message
|
||||||
|
})
|
||||||
|
|
||||||
|
tools = registry.list_all() if tools_enabled else None
|
||||||
|
|
||||||
|
iteration = 0
|
||||||
|
|
||||||
|
while iteration < MAX_ITERATIONS:
|
||||||
|
iteration += 1
|
||||||
|
|
||||||
|
tool_calls_this_round = None
|
||||||
|
|
||||||
|
async for event in llm_client.stream_call(
|
||||||
|
model=conversation.model,
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
temperature=conversation.temperature,
|
||||||
|
max_tokens=conversation.max_tokens
|
||||||
|
):
|
||||||
|
event_type = event.get("type")
|
||||||
|
|
||||||
|
if event_type == "content_delta":
|
||||||
|
content = event.get("content", "")
|
||||||
|
if content:
|
||||||
|
yield {"type": "text", "content": content}
|
||||||
|
|
||||||
|
elif event_type == "tool_call_delta":
|
||||||
|
tool_call = event.get("tool_call", {})
|
||||||
|
yield {"type": "tool_call", "data": tool_call}
|
||||||
|
|
||||||
|
elif event_type == "done":
|
||||||
|
tool_calls_this_round = event.get("tool_calls")
|
||||||
|
|
||||||
|
if tool_calls_this_round and tools_enabled:
|
||||||
|
yield {"type": "tool_call", "data": tool_calls_this_round}
|
||||||
|
|
||||||
|
tool_results = self.tool_executor.process_tool_calls_parallel(
|
||||||
|
tool_calls_this_round,
|
||||||
|
{}
|
||||||
|
)
|
||||||
|
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": tool_calls_this_round
|
||||||
|
})
|
||||||
|
|
||||||
|
for tr in tool_results:
|
||||||
|
messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tr.get("tool_call_id"),
|
||||||
|
"content": str(tr.get("result", ""))
|
||||||
|
})
|
||||||
|
|
||||||
|
yield {"type": "tool_result", "data": tool_results}
|
||||||
|
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not tool_calls_this_round or not tools_enabled:
|
||||||
|
break
|
||||||
|
|
||||||
|
yield {"type": "done"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
yield {"type": "error", "error": str(e)}
|
||||||
|
|
||||||
|
def non_stream_response(
|
||||||
|
self,
|
||||||
|
conversation: Conversation,
|
||||||
|
user_message: str,
|
||||||
|
tools_enabled: bool = False
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Non-streaming response"""
|
||||||
|
try:
|
||||||
|
messages = self.build_messages(conversation)
|
||||||
|
messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": user_message
|
||||||
|
})
|
||||||
|
|
||||||
|
tools = registry.list_all() if tools_enabled else None
|
||||||
|
|
||||||
|
iteration = 0
|
||||||
|
|
||||||
|
while iteration < MAX_ITERATIONS:
|
||||||
|
iteration += 1
|
||||||
|
|
||||||
|
response = llm_client.sync_call(
|
||||||
|
model=conversation.model,
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
temperature=conversation.temperature,
|
||||||
|
max_tokens=conversation.max_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_calls = response.tool_calls
|
||||||
|
|
||||||
|
if tool_calls and tools_enabled:
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": response.content,
|
||||||
|
"tool_calls": tool_calls
|
||||||
|
})
|
||||||
|
|
||||||
|
tool_results = self.tool_executor.process_tool_calls_parallel(tool_calls)
|
||||||
|
|
||||||
|
for tr in tool_results:
|
||||||
|
messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tr.get("tool_call_id"),
|
||||||
|
"content": str(tr.get("result", ""))
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"content": response.content
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"content": "Max iterations reached"
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Global chat service
|
||||||
|
chat_service = ChatService()
|
||||||
|
|
@ -0,0 +1,187 @@
|
||||||
|
"""LLM API client"""
|
||||||
|
import json
|
||||||
|
import httpx
|
||||||
|
from typing import Dict, Any, Optional, List, AsyncGenerator
|
||||||
|
|
||||||
|
from luxx.config import config
|
||||||
|
|
||||||
|
|
||||||
|
class LLMResponse:
|
||||||
|
"""LLM response"""
|
||||||
|
content: str
|
||||||
|
tool_calls: Optional[List[Dict]] = None
|
||||||
|
usage: Optional[Dict] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
content: str = "",
|
||||||
|
tool_calls: Optional[List[Dict]] = None,
|
||||||
|
usage: Optional[Dict] = None
|
||||||
|
):
|
||||||
|
self.content = content
|
||||||
|
self.tool_calls = tool_calls
|
||||||
|
self.usage = usage
|
||||||
|
|
||||||
|
|
||||||
|
class LLMClient:
|
||||||
|
"""LLM API client with multi-provider support"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.api_key = config.llm_api_key
|
||||||
|
self.api_url = config.llm_api_url
|
||||||
|
self.provider = self._detect_provider()
|
||||||
|
self._client: Optional[httpx.AsyncClient] = None
|
||||||
|
|
||||||
|
def _detect_provider(self) -> str:
|
||||||
|
"""Detect provider from URL"""
|
||||||
|
url = self.api_url.lower()
|
||||||
|
if "deepseek" in url:
|
||||||
|
return "deepseek"
|
||||||
|
elif "glm" in url or "zhipu" in url:
|
||||||
|
return "glm"
|
||||||
|
elif "openai" in url:
|
||||||
|
return "openai"
|
||||||
|
return "openai"
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Close client"""
|
||||||
|
if self._client:
|
||||||
|
await self._client.aclose()
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
def _build_headers(self) -> Dict[str, str]:
|
||||||
|
"""Build request headers"""
|
||||||
|
return {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.api_key}"
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_body(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[Dict],
|
||||||
|
tools: Optional[List[Dict]] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Build request body"""
|
||||||
|
body = {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
"stream": stream
|
||||||
|
}
|
||||||
|
|
||||||
|
if "temperature" in kwargs:
|
||||||
|
body["temperature"] = kwargs["temperature"]
|
||||||
|
|
||||||
|
if "max_tokens" in kwargs:
|
||||||
|
body["max_tokens"] = kwargs["max_tokens"]
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
body["tools"] = tools
|
||||||
|
|
||||||
|
return body
|
||||||
|
|
||||||
|
def _parse_response(self, data: Dict) -> LLMResponse:
|
||||||
|
"""Parse response"""
|
||||||
|
content = ""
|
||||||
|
tool_calls = None
|
||||||
|
usage = None
|
||||||
|
|
||||||
|
if "choices" in data:
|
||||||
|
choice = data["choices"][0]
|
||||||
|
content = choice.get("message", {}).get("content", "")
|
||||||
|
tool_calls = choice.get("message", {}).get("tool_calls")
|
||||||
|
|
||||||
|
if "usage" in data:
|
||||||
|
usage = data["usage"]
|
||||||
|
|
||||||
|
return LLMResponse(
|
||||||
|
content=content,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
usage=usage
|
||||||
|
)
|
||||||
|
|
||||||
|
async def client(self) -> httpx.AsyncClient:
|
||||||
|
"""Get HTTP client"""
|
||||||
|
if self._client is None:
|
||||||
|
self._client = httpx.AsyncClient(timeout=120.0)
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
async def sync_call(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[Dict],
|
||||||
|
tools: Optional[List[Dict]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""Call LLM API (non-streaming)"""
|
||||||
|
body = self._build_body(model, messages, tools, stream=False, **kwargs)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||||
|
response = await client.post(
|
||||||
|
self.api_url,
|
||||||
|
headers=self._build_headers(),
|
||||||
|
json=body
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
return self._parse_response(data)
|
||||||
|
|
||||||
|
async def stream_call(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[Dict],
|
||||||
|
tools: Optional[List[Dict]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||||
|
"""Stream call LLM API"""
|
||||||
|
body = self._build_body(model, messages, tools, stream=True, **kwargs)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
self.api_url,
|
||||||
|
headers=self._build_headers(),
|
||||||
|
json=body
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
if line.startswith("data: "):
|
||||||
|
data_str = line[6:]
|
||||||
|
|
||||||
|
if data_str == "[DONE]":
|
||||||
|
yield {"type": "done"}
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunk = json.loads(data_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "choices" not in chunk:
|
||||||
|
continue
|
||||||
|
|
||||||
|
delta = chunk.get("choices", [{}])[0].get("delta", {})
|
||||||
|
|
||||||
|
content_delta = delta.get("content", "")
|
||||||
|
if content_delta:
|
||||||
|
yield {"type": "content_delta", "content": content_delta}
|
||||||
|
|
||||||
|
tool_calls = delta.get("tool_calls", [])
|
||||||
|
if tool_calls:
|
||||||
|
yield {"type": "tool_call_delta", "tool_call": tool_calls}
|
||||||
|
|
||||||
|
finish_reason = chunk.get("choices", [{}])[0].get("finish_reason")
|
||||||
|
if finish_reason:
|
||||||
|
tool_calls_finish = chunk.get("choices", [{}])[0].get("message", {}).get("tool_calls")
|
||||||
|
yield {"type": "done", "tool_calls": tool_calls_finish}
|
||||||
|
|
||||||
|
|
||||||
|
# Global LLM client
|
||||||
|
llm_client = LLMClient()
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
"""Tool system module"""
|
||||||
|
from luxx.tools.core import (
|
||||||
|
ToolDefinition,
|
||||||
|
ToolResult,
|
||||||
|
ToolRegistry,
|
||||||
|
registry
|
||||||
|
)
|
||||||
|
from luxx.tools.factory import tool, tool_function
|
||||||
|
from luxx.tools.executor import ToolExecutor
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
"""Built-in tools module"""
|
||||||
|
# Import all built-in tools to register them
|
||||||
|
from luxx.tools.builtin import crawler
|
||||||
|
from luxx.tools.builtin import code
|
||||||
|
from luxx.tools.builtin import data
|
||||||
|
|
||||||
|
__all__ = ["crawler", "code", "data"]
|
||||||
|
|
@ -1,10 +1,9 @@
|
||||||
"""代码执行工具"""
|
"""Code execution tools"""
|
||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
import ast
|
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
from alcor.tools.factory import tool
|
from luxx.tools.factory import tool
|
||||||
|
|
||||||
|
|
||||||
@tool(
|
@tool(
|
||||||
|
|
@ -29,10 +28,10 @@ from alcor.tools.factory import tool
|
||||||
)
|
)
|
||||||
def python_execute(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
def python_execute(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
执行Python代码
|
Execute Python code
|
||||||
|
|
||||||
注意:这是一个简化的执行器,生产环境应使用更安全的隔离环境
|
Note: This is a simplified executor, production environments should use safer isolated environments
|
||||||
如:Docker容器、Pyodide等
|
such as: Docker containers, Pyodide, etc.
|
||||||
"""
|
"""
|
||||||
code = arguments.get("code", "")
|
code = arguments.get("code", "")
|
||||||
timeout = arguments.get("timeout", 30)
|
timeout = arguments.get("timeout", 30)
|
||||||
|
|
@ -40,16 +39,16 @@ def python_execute(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
if not code:
|
if not code:
|
||||||
return {"success": False, "error": "Code is required"}
|
return {"success": False, "error": "Code is required"}
|
||||||
|
|
||||||
# 创建执行环境(允许大多数操作)
|
# Create execution environment
|
||||||
namespace = {
|
namespace = {
|
||||||
"__builtins__": __builtins__
|
"__builtins__": __builtins__
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 编译并执行代码
|
# Compile and execute code
|
||||||
compiled = compile(code, "<string>", "exec")
|
compiled = compile(code, "<string>", "exec")
|
||||||
|
|
||||||
# 捕获输出
|
# Capture output
|
||||||
import io
|
import io
|
||||||
from contextlib import redirect_stdout
|
from contextlib import redirect_stdout
|
||||||
|
|
||||||
|
|
@ -60,7 +59,7 @@ def python_execute(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
|
||||||
result = output.getvalue()
|
result = output.getvalue()
|
||||||
|
|
||||||
# 尝试提取变量
|
# Try to extract variables
|
||||||
result_vars = {k: v for k, v in namespace.items()
|
result_vars = {k: v for k, v in namespace.items()
|
||||||
if not k.startswith("_") and k != "__builtins__"}
|
if not k.startswith("_") and k != "__builtins__"}
|
||||||
|
|
||||||
|
|
@ -100,7 +99,7 @@ def python_execute(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
category="code"
|
category="code"
|
||||||
)
|
)
|
||||||
def python_eval(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
def python_eval(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""评估Python表达式"""
|
"""Evaluate Python expression"""
|
||||||
expression = arguments.get("expression", "")
|
expression = arguments.get("expression", "")
|
||||||
|
|
||||||
if not expression:
|
if not expression:
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
"""网页爬虫工具"""
|
"""Web crawler tools"""
|
||||||
import requests
|
import requests
|
||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
from alcor.tools.factory import tool
|
from luxx.tools.factory import tool
|
||||||
|
|
||||||
|
|
||||||
@tool(
|
@tool(
|
||||||
|
|
@ -28,10 +28,10 @@ from alcor.tools.factory import tool
|
||||||
)
|
)
|
||||||
def web_search(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
def web_search(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
执行网络搜索
|
Execute web search
|
||||||
|
|
||||||
注意:这是一个占位实现,实际使用时需要接入真实的搜索API
|
Note: This is a placeholder implementation, real usage requires integrating with actual search APIs
|
||||||
如:Google Custom Search, DuckDuckGo, SerpAPI等
|
such as: Google Custom Search, DuckDuckGo, SerpAPI, etc.
|
||||||
"""
|
"""
|
||||||
query = arguments.get("query", "")
|
query = arguments.get("query", "")
|
||||||
max_results = arguments.get("max_results", 5)
|
max_results = arguments.get("max_results", 5)
|
||||||
|
|
@ -39,8 +39,8 @@ def web_search(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
if not query:
|
if not query:
|
||||||
return {"success": False, "error": "Query is required"}
|
return {"success": False, "error": "Query is required"}
|
||||||
|
|
||||||
# 模拟搜索结果
|
# Simulated search results
|
||||||
# 实际实现应接入真实搜索API
|
# Real implementation should integrate with actual search API
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"data": {
|
"data": {
|
||||||
|
|
@ -78,14 +78,14 @@ def web_search(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
category="crawler"
|
category="crawler"
|
||||||
)
|
)
|
||||||
def web_fetch(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
def web_fetch(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""获取并解析网页内容"""
|
"""Fetch and parse web page content"""
|
||||||
url = arguments.get("url", "")
|
url = arguments.get("url", "")
|
||||||
extract_text = arguments.get("extract_text", True)
|
extract_text = arguments.get("extract_text", True)
|
||||||
|
|
||||||
if not url:
|
if not url:
|
||||||
return {"success": False, "error": "URL is required"}
|
return {"success": False, "error": "URL is required"}
|
||||||
|
|
||||||
# 简单的URL验证
|
# Simple URL validation
|
||||||
if not url.startswith(("http://", "https://")):
|
if not url.startswith(("http://", "https://")):
|
||||||
url = "https://" + url
|
url = "https://" + url
|
||||||
|
|
||||||
|
|
@ -98,11 +98,11 @@ def web_fetch(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
|
||||||
if extract_text:
|
if extract_text:
|
||||||
soup = BeautifulSoup(response.text, "html.parser")
|
soup = BeautifulSoup(response.text, "html.parser")
|
||||||
# 移除script和style标签
|
# Remove script and style tags
|
||||||
for tag in soup(["script", "style"]):
|
for tag in soup(["script", "style"]):
|
||||||
tag.decompose()
|
tag.decompose()
|
||||||
text = soup.get_text(separator="\n", strip=True)
|
text = soup.get_text(separator="\n", strip=True)
|
||||||
# 清理多余空行
|
# Clean up extra blank lines
|
||||||
lines = [line.strip() for line in text.split("\n") if line.strip()]
|
lines = [line.strip() for line in text.split("\n") if line.strip()]
|
||||||
text = "\n".join(lines)
|
text = "\n".join(lines)
|
||||||
|
|
||||||
|
|
@ -111,7 +111,7 @@ def web_fetch(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"data": {
|
"data": {
|
||||||
"url": url,
|
"url": url,
|
||||||
"title": soup.title.string if soup.title else "",
|
"title": soup.title.string if soup.title else "",
|
||||||
"content": text[:10000] # 限制内容长度
|
"content": text[:10000] # Limit content length
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
|
@ -119,7 +119,7 @@ def web_fetch(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"success": True,
|
"success": True,
|
||||||
"data": {
|
"data": {
|
||||||
"url": url,
|
"url": url,
|
||||||
"html": response.text[:50000] # 限制HTML长度
|
"html": response.text[:50000] # Limit HTML length
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
|
|
@ -147,7 +147,7 @@ def web_fetch(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
category="crawler"
|
category="crawler"
|
||||||
)
|
)
|
||||||
def extract_links(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
def extract_links(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""提取网页中的所有链接"""
|
"""Extract all links from a web page"""
|
||||||
url = arguments.get("url", "")
|
url = arguments.get("url", "")
|
||||||
max_links = arguments.get("max_links", 20)
|
max_links = arguments.get("max_links", 20)
|
||||||
|
|
||||||
|
|
@ -169,7 +169,7 @@ def extract_links(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
|
||||||
for a_tag in soup.find_all("a", href=True)[:max_links]:
|
for a_tag in soup.find_all("a", href=True)[:max_links]:
|
||||||
href = a_tag["href"]
|
href = a_tag["href"]
|
||||||
# 处理相对URL
|
# Handle relative URLs
|
||||||
if href.startswith("/"):
|
if href.startswith("/"):
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
href = urljoin(url, href)
|
href = urljoin(url, href)
|
||||||
|
|
@ -0,0 +1,314 @@
|
||||||
|
"""Data processing tools"""
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import base64
|
||||||
|
from typing import Dict, Any
|
||||||
|
from urllib.parse import quote, unquote
|
||||||
|
|
||||||
|
from luxx.tools.factory import tool
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="calculate",
|
||||||
|
description="Execute mathematical calculations",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"expression": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Mathematical expression to evaluate"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["expression"]
|
||||||
|
},
|
||||||
|
category="data"
|
||||||
|
)
|
||||||
|
def calculate(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Execute mathematical calculation"""
|
||||||
|
expression = arguments.get("expression", "")
|
||||||
|
|
||||||
|
if not expression:
|
||||||
|
return {"success": False, "error": "Expression is required"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Safe replacement for math functions
|
||||||
|
safe_dict = {
|
||||||
|
"abs": abs,
|
||||||
|
"round": round,
|
||||||
|
"min": min,
|
||||||
|
"max": max,
|
||||||
|
"pow": pow,
|
||||||
|
"sqrt": lambda x: x ** 0.5,
|
||||||
|
"sin": lambda x: __import__('math').sin(x),
|
||||||
|
"cos": lambda x: __import__('math').cos(x),
|
||||||
|
"tan": lambda x: __import__('math').tan(x),
|
||||||
|
"log": lambda x: __import__('math').log(x),
|
||||||
|
"pi": __import__('math').pi,
|
||||||
|
"e": __import__('math').e
|
||||||
|
}
|
||||||
|
|
||||||
|
# Remove dangerous characters, only keep numbers and operators
|
||||||
|
safe_expr = re.sub(r"[^0-9+\-*/().%sqrtinsclogmaxminpowabsroundte, ]", "", expression)
|
||||||
|
|
||||||
|
result = eval(safe_expr, {"__builtins__": {}, **safe_dict})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"expression": expression,
|
||||||
|
"result": result,
|
||||||
|
"formatted": f"{result}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Calculation error: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="text_process",
|
||||||
|
description="Process and transform text",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Text to process"
|
||||||
|
},
|
||||||
|
"operation": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Operation to perform: uppercase, lowercase, title, strip, reverse, word_count, char_count",
|
||||||
|
"enum": ["uppercase", "lowercase", "title", "strip", "reverse", "word_count", "char_count"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["text", "operation"]
|
||||||
|
},
|
||||||
|
category="data"
|
||||||
|
)
|
||||||
|
def text_process(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Text processing"""
|
||||||
|
text = arguments.get("text", "")
|
||||||
|
operation = arguments.get("operation", "")
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
return {"success": False, "error": "Text is required"}
|
||||||
|
|
||||||
|
operations = {
|
||||||
|
"uppercase": text.upper(),
|
||||||
|
"lowercase": text.lower(),
|
||||||
|
"title": text.title(),
|
||||||
|
"strip": text.strip(),
|
||||||
|
"reverse": text[::-1],
|
||||||
|
"word_count": len(text.split()),
|
||||||
|
"char_count": len(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = operations.get(operation)
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
return {"success": False, "error": f"Unknown operation: {operation}"}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"original": text,
|
||||||
|
"operation": operation,
|
||||||
|
"result": result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="json_process",
|
||||||
|
description="Process and transform JSON data",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"data": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "JSON string or text to process"
|
||||||
|
},
|
||||||
|
"operation": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Operation: format, minify, validate",
|
||||||
|
"enum": ["format", "minify", "validate"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["data", "operation"]
|
||||||
|
},
|
||||||
|
category="data"
|
||||||
|
)
|
||||||
|
def json_process(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""JSON data processing"""
|
||||||
|
data = arguments.get("data", "")
|
||||||
|
operation = arguments.get("operation", "")
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
return {"success": False, "error": "Data is required"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = json.loads(data)
|
||||||
|
|
||||||
|
if operation == "format":
|
||||||
|
result = json.dumps(parsed, indent=2, ensure_ascii=False)
|
||||||
|
elif operation == "minify":
|
||||||
|
result = json.dumps(parsed, ensure_ascii=False)
|
||||||
|
elif operation == "validate":
|
||||||
|
result = "Valid JSON"
|
||||||
|
else:
|
||||||
|
return {"success": False, "error": f"Unknown operation: {operation}"}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"result": result,
|
||||||
|
"operation": operation
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return {"success": False, "error": f"Invalid JSON: {str(e)}"}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="hash_text",
|
||||||
|
description="Generate text hash",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Text to hash"
|
||||||
|
},
|
||||||
|
"algorithm": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Hash algorithm: md5, sha1, sha256, sha512",
|
||||||
|
"default": "sha256"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["text"]
|
||||||
|
},
|
||||||
|
category="data"
|
||||||
|
)
|
||||||
|
def hash_text(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Generate text hash"""
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
text = arguments.get("text", "")
|
||||||
|
algorithm = arguments.get("algorithm", "sha256")
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
return {"success": False, "error": "Text is required"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
hash_obj = hashlib.new(algorithm)
|
||||||
|
hash_obj.update(text.encode('utf-8'))
|
||||||
|
hash_value = hash_obj.hexdigest()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"text": text,
|
||||||
|
"algorithm": algorithm,
|
||||||
|
"hash": hash_value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": f"Hash error: {str(e)}"}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="url_encode_decode",
|
||||||
|
description="URL encoding/decoding",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Text to encode/decode"
|
||||||
|
},
|
||||||
|
"operation": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Operation: encode, decode",
|
||||||
|
"enum": ["encode", "decode"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["text", "operation"]
|
||||||
|
},
|
||||||
|
category="data"
|
||||||
|
)
|
||||||
|
def url_encode_decode(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""URL encoding/decoding"""
|
||||||
|
text = arguments.get("text", "")
|
||||||
|
operation = arguments.get("operation", "")
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
return {"success": False, "error": "Text is required"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if operation == "encode":
|
||||||
|
result = quote(text)
|
||||||
|
elif operation == "decode":
|
||||||
|
result = unquote(text)
|
||||||
|
else:
|
||||||
|
return {"success": False, "error": f"Unknown operation: {operation}"}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"original": text,
|
||||||
|
"operation": operation,
|
||||||
|
"result": result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": f"URL error: {str(e)}"}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="base64_encode_decode",
|
||||||
|
description="Base64 encoding/decoding",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Text to encode/decode"
|
||||||
|
},
|
||||||
|
"operation": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Operation: encode, decode",
|
||||||
|
"enum": ["encode", "decode"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["text", "operation"]
|
||||||
|
},
|
||||||
|
category="data"
|
||||||
|
)
|
||||||
|
def base64_encode_decode(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Base64 encoding/decoding"""
|
||||||
|
text = arguments.get("text", "")
|
||||||
|
operation = arguments.get("operation", "")
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
return {"success": False, "error": "Text is required"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if operation == "encode":
|
||||||
|
result = base64.b64encode(text.encode()).decode()
|
||||||
|
elif operation == "decode":
|
||||||
|
result = base64.b64decode(text.encode()).decode()
|
||||||
|
else:
|
||||||
|
return {"success": False, "error": f"Unknown operation: {operation}"}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"original": text,
|
||||||
|
"operation": operation,
|
||||||
|
"result": result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": f"Base64 error: {str(e)}"}
|
||||||
|
|
@ -1,19 +1,19 @@
|
||||||
"""工具系统核心模块"""
|
"""Tool system core module"""
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, Any, Callable, List, Optional, TypeVar, Generic
|
from typing import Callable, Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolDefinition:
|
class ToolDefinition:
|
||||||
"""工具定义"""
|
"""Tool definition"""
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
parameters: Dict[str, Any] # JSON Schema
|
parameters: Dict[str, Any]
|
||||||
handler: Callable
|
handler: Callable
|
||||||
category: str = "general"
|
category: str = "general"
|
||||||
|
|
||||||
def to_openai_format(self) -> Dict[str, Any]:
|
def to_openai_format(self) -> Dict[str, Any]:
|
||||||
"""转换为OpenAI格式"""
|
"""Convert to OpenAI format"""
|
||||||
return {
|
return {
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
|
|
@ -26,50 +26,51 @@ class ToolDefinition:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolResult:
|
class ToolResult:
|
||||||
"""工具执行结果"""
|
"""Tool execution result"""
|
||||||
success: bool
|
success: bool
|
||||||
data: Any = None
|
data: Any = None
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""转换为字典"""
|
"""Convert to dictionary"""
|
||||||
return {"success": self.success, "data": self.data, "error": self.error}
|
return {"success": self.success, "data": self.data, "error": self.error}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def ok(cls, data: Any) -> "ToolResult":
|
def ok(cls, data: Any) -> "ToolResult":
|
||||||
"""创建成功结果"""
|
"""Create success result"""
|
||||||
return cls(success=True, data=data)
|
return cls(success=True, data=data)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fail(cls, error: str) -> "ToolResult":
|
def fail(cls, error: str) -> "ToolResult":
|
||||||
"""创建失败结果"""
|
"""Create failure result"""
|
||||||
return cls(success=False, error=error)
|
return cls(success=False, error=error)
|
||||||
|
|
||||||
|
|
||||||
class ToolRegistry:
|
class ToolRegistry:
|
||||||
"""工具注册表(单例模式)"""
|
"""Tool registry (singleton pattern)"""
|
||||||
_instance: Optional["ToolRegistry"] = None
|
_instance: Optional["ToolRegistry"] = None
|
||||||
_tools: Dict[str, ToolDefinition] = {}
|
_tools: Dict[str, ToolDefinition] = {}
|
||||||
|
|
||||||
def __new__(cls):
|
def __new__(cls):
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
|
cls._instance._tools = {}
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def register(self, tool: ToolDefinition) -> None:
|
def register(self, tool: ToolDefinition) -> None:
|
||||||
"""注册工具"""
|
"""Register tool"""
|
||||||
self._tools[tool.name] = tool
|
self._tools[tool.name] = tool
|
||||||
|
|
||||||
def get(self, name: str) -> Optional[ToolDefinition]:
|
def get(self, name: str) -> Optional[ToolDefinition]:
|
||||||
"""获取工具定义"""
|
"""Get tool definition"""
|
||||||
return self._tools.get(name)
|
return self._tools.get(name)
|
||||||
|
|
||||||
def list_all(self) -> List[Dict[str, Any]]:
|
def list_all(self) -> List[Dict[str, Any]]:
|
||||||
"""列出所有工具"""
|
"""List all tools"""
|
||||||
return [t.to_openai_format() for t in self._tools.values()]
|
return [t.to_openai_format() for t in self._tools.values()]
|
||||||
|
|
||||||
def list_by_category(self, category: str) -> List[Dict[str, Any]]:
|
def list_by_category(self, category: str) -> List[Dict[str, Any]]:
|
||||||
"""按分类列出工具"""
|
"""List tools by category"""
|
||||||
return [
|
return [
|
||||||
t.to_openai_format()
|
t.to_openai_format()
|
||||||
for t in self._tools.values()
|
for t in self._tools.values()
|
||||||
|
|
@ -77,35 +78,34 @@ class ToolRegistry:
|
||||||
]
|
]
|
||||||
|
|
||||||
def execute(self, name: str, arguments: dict) -> Dict[str, Any]:
|
def execute(self, name: str, arguments: dict) -> Dict[str, Any]:
|
||||||
"""执行工具"""
|
"""Execute tool"""
|
||||||
tool = self.get(name)
|
tool = self.get(name)
|
||||||
if not tool:
|
if not tool:
|
||||||
return ToolResult.fail(f"Tool not found: {name}").to_dict()
|
return {"success": False, "error": f"Tool '{name}' not found"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = tool.handler(arguments)
|
result = tool.handler(arguments)
|
||||||
if isinstance(result, ToolResult):
|
if isinstance(result, ToolResult):
|
||||||
return result.to_dict()
|
return result.to_dict()
|
||||||
return ToolResult.ok(result).to_dict()
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return ToolResult.fail(str(e)).to_dict()
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""清空所有工具"""
|
"""Clear all tools"""
|
||||||
self._tools.clear()
|
self._tools.clear()
|
||||||
|
|
||||||
def remove(self, name: str) -> bool:
|
def remove(self, name: str) -> bool:
|
||||||
"""移除工具"""
|
"""Remove tool"""
|
||||||
if name in self._tools:
|
if name in self._tools:
|
||||||
del self._tools[name]
|
del self._tools[name]
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
|
||||||
def tool_count(self) -> int:
|
def tool_count(self) -> int:
|
||||||
"""工具数量"""
|
"""Tool count"""
|
||||||
return len(self._tools)
|
return len(self._tools)
|
||||||
|
|
||||||
|
|
||||||
# 全局注册表实例
|
# Global registry instance
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
|
|
@ -0,0 +1,177 @@
|
||||||
|
"""Tool executor"""
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
|
from luxx.tools.core import registry, ToolResult
|
||||||
|
|
||||||
|
|
||||||
|
class ToolExecutor:
|
||||||
|
"""Tool executor with caching and parallel execution support"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
enable_cache: bool = True,
|
||||||
|
cache_ttl: int = 300, # 5 minutes
|
||||||
|
max_workers: int = 4
|
||||||
|
):
|
||||||
|
self.enable_cache = enable_cache
|
||||||
|
self.cache_ttl = cache_ttl
|
||||||
|
self.max_workers = max_workers
|
||||||
|
self._cache: Dict[str, tuple] = {} # key: (result, timestamp)
|
||||||
|
self._call_history: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
def _make_cache_key(self, name: str, args: dict) -> str:
|
||||||
|
"""Generate cache key"""
|
||||||
|
args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
|
||||||
|
return f"{name}:{args_str}"
|
||||||
|
|
||||||
|
def _is_cache_valid(self, cache_key: str) -> bool:
|
||||||
|
"""Check if cache is valid"""
|
||||||
|
if cache_key not in self._cache:
|
||||||
|
return False
|
||||||
|
_, timestamp = self._cache[cache_key]
|
||||||
|
return time.time() - timestamp < self.cache_ttl
|
||||||
|
|
||||||
|
def _get_cached(self, cache_key: str) -> Optional[Dict]:
|
||||||
|
"""Get cached result"""
|
||||||
|
if self.enable_cache and self._is_cache_valid(cache_key):
|
||||||
|
return self._cache[cache_key][0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _set_cached(self, cache_key: str, result: Dict) -> None:
|
||||||
|
"""Set cache"""
|
||||||
|
if self.enable_cache:
|
||||||
|
self._cache[cache_key] = (result, time.time())
|
||||||
|
|
||||||
|
def _record_call(self, name: str, args: dict, result: Dict) -> None:
|
||||||
|
"""Record call history"""
|
||||||
|
self._call_history.append({
|
||||||
|
"name": name,
|
||||||
|
"args": args,
|
||||||
|
"result": result,
|
||||||
|
"timestamp": time.time()
|
||||||
|
})
|
||||||
|
|
||||||
|
# Limit history size
|
||||||
|
if len(self._call_history) > 1000:
|
||||||
|
self._call_history = self._call_history[-1000:]
|
||||||
|
|
||||||
|
def process_tool_calls(
|
||||||
|
self,
|
||||||
|
tool_calls: List[Dict[str, Any]],
|
||||||
|
context: Dict[str, Any]
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Process tool calls sequentially"""
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for call in tool_calls:
|
||||||
|
call_id = call.get("id", "")
|
||||||
|
name = call.get("function", {}).get("name", "")
|
||||||
|
|
||||||
|
# Parse JSON arguments
|
||||||
|
try:
|
||||||
|
args = json.loads(call.get("function", {}).get("arguments", "{}"))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
args = {}
|
||||||
|
|
||||||
|
# Check cache
|
||||||
|
cache_key = self._make_cache_key(name, args)
|
||||||
|
cached = self._get_cached(cache_key)
|
||||||
|
|
||||||
|
if cached is not None:
|
||||||
|
result = cached
|
||||||
|
else:
|
||||||
|
# Execute tool
|
||||||
|
result = registry.execute(name, args)
|
||||||
|
self._set_cached(cache_key, result)
|
||||||
|
|
||||||
|
# Record call
|
||||||
|
self._record_call(name, args, result)
|
||||||
|
|
||||||
|
# Create result message
|
||||||
|
results.append(self._create_tool_result(call_id, name, result))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def process_tool_calls_parallel(
|
||||||
|
self,
|
||||||
|
tool_calls: List[Dict[str, Any]],
|
||||||
|
context: Dict[str, Any]
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Process tool calls in parallel"""
|
||||||
|
if len(tool_calls) <= 1:
|
||||||
|
return self.process_tool_calls(tool_calls, context)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
|
futures = {}
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||||
|
for call in tool_calls:
|
||||||
|
call_id = call.get("id", "")
|
||||||
|
name = call.get("function", {}).get("name", "")
|
||||||
|
|
||||||
|
# Parse all arguments
|
||||||
|
try:
|
||||||
|
args = json.loads(call.get("function", {}).get("arguments", "{}"))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
args = {}
|
||||||
|
|
||||||
|
# Check cache
|
||||||
|
cache_key = self._make_cache_key(name, args)
|
||||||
|
cached = self._get_cached(cache_key)
|
||||||
|
|
||||||
|
if cached is not None:
|
||||||
|
futures[call_id] = (name, args, cached)
|
||||||
|
else:
|
||||||
|
# Submit task
|
||||||
|
future = executor.submit(registry.execute, name, args)
|
||||||
|
futures[future] = (call_id, name, args)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for future in as_completed(futures.keys()):
|
||||||
|
if future in futures:
|
||||||
|
call_id, name, args = futures[future]
|
||||||
|
result = future.result()
|
||||||
|
self._set_cached(self._make_cache_key(name, args), result)
|
||||||
|
self._record_call(name, args, result)
|
||||||
|
results.append(self._create_tool_result(call_id, name, result))
|
||||||
|
else:
|
||||||
|
call_id, name, args = futures[future]
|
||||||
|
result = future.result()
|
||||||
|
self._set_cached(self._make_cache_key(name, args), result)
|
||||||
|
self._record_call(name, args, result)
|
||||||
|
results.append(self._create_tool_result(call_id, name, result))
|
||||||
|
|
||||||
|
return results
|
||||||
|
except ImportError:
|
||||||
|
return self.process_tool_calls(tool_calls, context)
|
||||||
|
|
||||||
|
def _create_tool_result(self, call_id: str, name: str, result: Dict) -> Dict[str, Any]:
|
||||||
|
"""Create tool result message"""
|
||||||
|
return {
|
||||||
|
"tool_call_id": call_id,
|
||||||
|
"role": "tool",
|
||||||
|
"name": name,
|
||||||
|
"content": json.dumps(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_error_result(self, call_id: str, name: str, error: str) -> Dict[str, Any]:
|
||||||
|
"""Create error result message"""
|
||||||
|
return {
|
||||||
|
"tool_call_id": call_id,
|
||||||
|
"role": "tool",
|
||||||
|
"name": name,
|
||||||
|
"content": json.dumps({"success": False, "error": error})
|
||||||
|
}
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""Clear all cache"""
|
||||||
|
self._cache.clear()
|
||||||
|
|
||||||
|
def get_history(self, limit: int = 100) -> List[Dict[str, Any]]:
|
||||||
|
"""Get call history"""
|
||||||
|
return self._call_history[-limit:]
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""工具装饰器工厂"""
|
"""Tool decorator factory"""
|
||||||
from typing import Callable, Any, Dict
|
from typing import Callable, Any, Dict
|
||||||
from alcor.tools.core import ToolDefinition, registry
|
from luxx.tools.core import ToolDefinition, registry
|
||||||
|
|
||||||
|
|
||||||
def tool(
|
def tool(
|
||||||
|
|
@ -8,28 +8,28 @@ def tool(
|
||||||
description: str,
|
description: str,
|
||||||
parameters: Dict[str, Any],
|
parameters: Dict[str, Any],
|
||||||
category: str = "general"
|
category: str = "general"
|
||||||
) -> Callable:
|
):
|
||||||
"""
|
"""
|
||||||
工具注册装饰器
|
Tool registration decorator
|
||||||
|
|
||||||
用法示例:
|
Usage:
|
||||||
```python
|
```python
|
||||||
@tool(
|
@tool(
|
||||||
name="web_search",
|
name="my_tool",
|
||||||
description="Search the internet for information",
|
description="This is my tool",
|
||||||
parameters={
|
parameters={
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"query": {"type": "string", "description": "Search keywords"},
|
"arg1": {"type": "string"}
|
||||||
"max_results": {"type": "integer", "description": "Max results", "default": 5}
|
|
||||||
},
|
},
|
||||||
"required": ["query"]
|
"required": ["arg1"]
|
||||||
},
|
}
|
||||||
category="crawler"
|
|
||||||
)
|
)
|
||||||
def web_search(arguments: dict) -> dict:
|
def my_tool(arguments: dict) -> dict:
|
||||||
# 实现...
|
# Implementation...
|
||||||
return {"results": []}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
# The tool will be automatically registered
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def decorator(func: Callable) -> Callable:
|
def decorator(func: Callable) -> Callable:
|
||||||
|
|
@ -46,12 +46,12 @@ def tool(
|
||||||
|
|
||||||
|
|
||||||
def tool_function(
|
def tool_function(
|
||||||
name: str,
|
name: str = None,
|
||||||
description: str,
|
description: str = None,
|
||||||
parameters: Dict[str, Any],
|
parameters: Dict[str, Any] = None,
|
||||||
category: str = "general"
|
category: str = "general"
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
工具装饰器的别名,提供更语义化的命名
|
Alias for tool decorator, providing a more semantic naming
|
||||||
"""
|
"""
|
||||||
return tool(name=name, description=description, parameters=parameters, category=category)
|
return tool(name=name, description=description, parameters=parameters, category=category)
|
||||||
|
|
@ -0,0 +1,11 @@
|
||||||
|
"""Utility functions module"""
|
||||||
|
from luxx.utils.helpers import (
|
||||||
|
generate_id,
|
||||||
|
hash_password,
|
||||||
|
verify_password,
|
||||||
|
create_access_token,
|
||||||
|
decode_access_token,
|
||||||
|
success_response,
|
||||||
|
error_response,
|
||||||
|
paginate
|
||||||
|
)
|
||||||
|
|
@ -1,14 +1,14 @@
|
||||||
"""辅助工具模块"""
|
"""Utility helpers module"""
|
||||||
import shortuuid
|
import shortuuid
|
||||||
import jwt
|
import hashlib
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Optional, Dict, Any
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
from alcor.config import config
|
from luxx.config import config
|
||||||
|
|
||||||
|
|
||||||
def generate_id(prefix: str = "") -> str:
|
def generate_id(prefix: str = "") -> str:
|
||||||
"""生成唯一ID"""
|
"""Generate unique ID"""
|
||||||
unique_id = shortuuid.uuid()
|
unique_id = shortuuid.uuid()
|
||||||
if prefix:
|
if prefix:
|
||||||
return f"{prefix}_{unique_id}"
|
return f"{prefix}_{unique_id}"
|
||||||
|
|
@ -16,54 +16,42 @@ def generate_id(prefix: str = "") -> str:
|
||||||
|
|
||||||
|
|
||||||
def hash_password(password: str) -> str:
|
def hash_password(password: str) -> str:
|
||||||
"""密码哈希"""
|
"""Hash password"""
|
||||||
import bcrypt
|
import bcrypt
|
||||||
salt = bcrypt.gensalt()
|
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||||
return bcrypt.hashpw(password.encode(), salt).decode()
|
|
||||||
|
|
||||||
|
|
||||||
def verify_password(password: str, hashed: str) -> bool:
|
def verify_password(password: str, hashed: str) -> bool:
|
||||||
"""验证密码"""
|
"""Verify password"""
|
||||||
import bcrypt
|
import bcrypt
|
||||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||||
"""创建JWT访问令牌"""
|
"""Create JWT access token"""
|
||||||
|
from jose import jwt
|
||||||
to_encode = data.copy()
|
to_encode = data.copy()
|
||||||
|
|
||||||
if expires_delta:
|
if expires_delta:
|
||||||
expire = datetime.utcnow() + expires_delta
|
expire = datetime.utcnow() + expires_delta
|
||||||
else:
|
else:
|
||||||
expire = datetime.utcnow() + timedelta(hours=24)
|
expire = datetime.utcnow() + timedelta(hours=24)
|
||||||
|
to_encode.update({"exp": expire})
|
||||||
to_encode.update({"exp": expire, "iat": datetime.utcnow()})
|
encoded_jwt = jwt.encode(to_encode, config.secret_key, algorithm="HS256")
|
||||||
|
|
||||||
encoded_jwt = jwt.encode(
|
|
||||||
to_encode,
|
|
||||||
config.secret_key,
|
|
||||||
algorithm="HS256"
|
|
||||||
)
|
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
|
|
||||||
|
|
||||||
def decode_access_token(token: str) -> Optional[Dict[str, Any]]:
|
def decode_access_token(token: str) -> Optional[Dict[str, Any]]:
|
||||||
"""解码JWT令牌"""
|
"""Decode JWT token"""
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
from jose import jwt
|
||||||
token,
|
payload = jwt.decode(token, config.secret_key, algorithms=["HS256"])
|
||||||
config.secret_key,
|
|
||||||
algorithms=["HS256"]
|
|
||||||
)
|
|
||||||
return payload
|
return payload
|
||||||
except jwt.ExpiredSignatureError:
|
except Exception:
|
||||||
return None
|
|
||||||
except jwt.InvalidTokenError:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def success_response(data: Any = None, message: str = "Success") -> Dict[str, Any]:
|
def success_response(data: Any = None, message: str = "Success") -> Dict[str, Any]:
|
||||||
"""成功响应封装"""
|
"""Success response wrapper"""
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": message,
|
"message": message,
|
||||||
|
|
@ -72,7 +60,7 @@ def success_response(data: Any = None, message: str = "Success") -> Dict[str, An
|
||||||
|
|
||||||
|
|
||||||
def error_response(message: str, code: int = 400, errors: Any = None) -> Dict[str, Any]:
|
def error_response(message: str, code: int = 400, errors: Any = None) -> Dict[str, Any]:
|
||||||
"""错误响应封装"""
|
"""Error response wrapper"""
|
||||||
response = {
|
response = {
|
||||||
"success": False,
|
"success": False,
|
||||||
"message": message,
|
"message": message,
|
||||||
|
|
@ -84,14 +72,12 @@ def error_response(message: str, code: int = 400, errors: Any = None) -> Dict[st
|
||||||
|
|
||||||
|
|
||||||
def paginate(query, page: int = 1, page_size: int = 20):
|
def paginate(query, page: int = 1, page_size: int = 20):
|
||||||
"""分页辅助"""
|
"""Pagination helper"""
|
||||||
total = query.count()
|
total = query.count()
|
||||||
items = query.offset((page - 1) * page_size).limit(page_size).all()
|
items = query.offset((page - 1) * page_size).limit(page_size).all()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"items": items,
|
|
||||||
"total": total,
|
"total": total,
|
||||||
"page": page,
|
"page": page,
|
||||||
"page_size": page_size,
|
"page_size": page_size,
|
||||||
"total_pages": (total + page_size - 1) // page_size
|
"items": items
|
||||||
}
|
}
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "alcor"
|
name = "luxx"
|
||||||
version = "1.0.0"
|
version = "1.0.0"
|
||||||
description = "Alcor - FastAPI + SQLAlchemy"
|
description = "luxx - FastAPI + SQLAlchemy"
|
||||||
readme = "docs/README.md"
|
readme = "docs/README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue