Luxx/alcor/tools/executor.py

187 lines
6.6 KiB
Python

"""工具执行器"""
import hashlib
import json
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Optional, Any

from alcor.tools.core import registry, ToolResult


class ToolExecutor:
    """Run tool calls through ``registry`` with optional caching and parallelism.

    - Results are cached in memory for ``cache_ttl`` seconds, keyed on the
      tool name plus its JSON-serialized arguments.
    - Every call that yields a result (cache hit or fresh execution) is
      appended to a bounded call history.
    - :meth:`process_tool_calls` runs calls in order;
      :meth:`process_tool_calls_parallel` fans them out over a thread pool
      while preserving result order.

    The cache and the history are guarded by a lock because the parallel
    path reads and writes them from worker threads.
    """

    # Once the history grows past _HISTORY_MAX entries it is trimmed back to
    # the newest _HISTORY_TRIM_TO, so trimming happens in bulk, not per call.
    _HISTORY_MAX = 1000
    _HISTORY_TRIM_TO = 500
    # Error text used (by both execution paths) when ``arguments`` is not valid JSON.
    _INVALID_JSON_MESSAGE = "Invalid JSON arguments"
    # Sentinel returned by _parse_call for undecodable arguments; None cannot
    # serve this role because json.loads("null") legitimately yields None.
    _INVALID_ARGS = object()

    def __init__(
        self,
        enable_cache: bool = True,
        cache_ttl: int = 300,  # 5 minutes
        max_workers: int = 4
    ):
        """Create an executor.

        Args:
            enable_cache: Serve results from, and store results in, the cache.
            cache_ttl: Lifetime of a cache entry, in seconds.
            max_workers: Default thread-pool size for
                :meth:`process_tool_calls_parallel`.
        """
        self.enable_cache = enable_cache
        self.cache_ttl = cache_ttl
        self.max_workers = max_workers
        self._cache: Dict[str, tuple] = {}  # cache_key -> (result, timestamp)
        self._call_history: List[Dict[str, Any]] = []
        # Guards _cache and _call_history; the parallel path mutates both
        # from worker threads.
        self._lock = threading.Lock()

    def _make_cache_key(self, name: str, args: dict) -> str:
        """Return a stable cache key for calling tool ``name`` with ``args``.

        Name and arguments are serialized together as one JSON array with
        sorted keys, so argument dict ordering does not matter and no two
        distinct (name, args) pairs share key material. The digest only
        keeps keys short.

        Raises:
            TypeError: if ``args`` is not JSON-serializable.
        """
        material = json.dumps([name, args], sort_keys=True, ensure_ascii=False)
        return hashlib.sha256(material.encode("utf-8")).hexdigest()

    def _is_cache_valid(self, cache_key: str) -> bool:
        """Return True if ``cache_key`` is cached and younger than ``cache_ttl``.

        Not used by the executor itself (``_get_cached`` checks validity
        atomically); retained as part of the existing interface.
        """
        with self._lock:
            entry = self._cache.get(cache_key)
        return entry is not None and (time.time() - entry[1]) < self.cache_ttl

    def _get_cached(self, cache_key: str) -> Optional[Dict]:
        """Return the cached result for ``cache_key``, or None on a miss.

        Always returns None when caching is disabled. An expired entry found
        here is deleted so stale results do not linger.
        """
        if not self.enable_cache:
            return None
        with self._lock:
            entry = self._cache.get(cache_key)
            if entry is None:
                return None
            result, timestamp = entry
            if time.time() - timestamp < self.cache_ttl:
                return result
            del self._cache[cache_key]
            return None

    def _set_cached(self, cache_key: str, result: Dict) -> None:
        """Store ``result`` under ``cache_key``; no-op when caching is disabled.

        Every write also sweeps out expired entries, which bounds the cache
        to roughly the results produced within one TTL window.
        """
        if not self.enable_cache:
            return
        now = time.time()
        with self._lock:
            expired = [
                key for key, (_, stamp) in self._cache.items()
                if now - stamp >= self.cache_ttl
            ]
            for key in expired:
                del self._cache[key]
            self._cache[cache_key] = (result, now)

    @staticmethod
    def _is_cacheable(result: Any) -> bool:
        """Return False for results shaped like ``{"success": False, ...}``.

        Caching a failure would replay the same error for the whole TTL even
        if the failure was transient.
        """
        return not (isinstance(result, dict) and result.get("success") is False)

    def _record_call(self, name: str, args: dict, result: Dict) -> None:
        """Append a call record, trimming the history when it gets too long."""
        entry = {
            "name": name,
            "args": args,
            "result": result,
            "timestamp": time.time()
        }
        with self._lock:
            self._call_history.append(entry)
            if len(self._call_history) > self._HISTORY_MAX:
                del self._call_history[:-self._HISTORY_TRIM_TO]

    def _parse_call(self, call: Dict[str, Any]) -> tuple:
        """Split a tool call into ``(call_id, name, args)``.

        ``call`` is read as ``{"id": ..., "function": {"name": ...,
        "arguments": ...}}``; missing pieces default to ``""`` / ``"{}"``.
        ``arguments`` may be a JSON string or an already-decoded value. An
        empty or whitespace-only string means "no arguments" (``{}``). If the
        string is not valid JSON, ``args`` is ``self._INVALID_ARGS``.
        """
        function = call.get("function") or {}
        call_id = call.get("id", "")
        name = function.get("name", "")
        raw_args = function.get("arguments", "{}")
        if not isinstance(raw_args, str):
            return call_id, name, raw_args
        if not raw_args.strip():
            return call_id, name, {}
        try:
            return call_id, name, json.loads(raw_args)
        except json.JSONDecodeError:
            return call_id, name, self._INVALID_ARGS

    def _execute_one(self, call_id: str, name: str, args: Any) -> Dict[str, Any]:
        """Execute one parsed call (cache-aware) and build its tool message.

        Any exception raised while building the cache key, running the tool,
        or serializing the result becomes an error message, so one failing
        call never aborts the rest of a batch.
        """
        try:
            cache_key = self._make_cache_key(name, args)
            result = self._get_cached(cache_key)
            if result is None:
                result = registry.execute(name, args)
                if self._is_cacheable(result):
                    self._set_cached(cache_key, result)
            self._record_call(name, args, result)
            return self._create_tool_result(call_id, name, result)
        except Exception as exc:  # batch boundary: report, don't propagate
            return self._create_error_result(
                call_id, name, str(exc) or type(exc).__name__
            )

    def _process_one(self, call: Dict[str, Any]) -> Dict[str, Any]:
        """Parse and execute a single raw tool call, returning its tool message."""
        call_id, name, args = self._parse_call(call)
        if args is self._INVALID_ARGS:
            return self._create_error_result(call_id, name, self._INVALID_JSON_MESSAGE)
        return self._execute_one(call_id, name, args)

    def process_tool_calls(
        self,
        tool_calls: List[Dict[str, Any]],
        context: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Execute ``tool_calls`` one after another, in order.

        Args:
            tool_calls: Tool-call dicts shaped like
                ``{"id": ..., "function": {"name": ..., "arguments": ...}}``.
            context: Currently unused; accepted for interface compatibility.

        Returns:
            One ``role="tool"`` message per input call, in input order.
            Invalid JSON arguments and exceptions raised by a tool are
            reported as ``{"success": False, "error": ...}`` messages.
        """
        return [self._process_one(call) for call in tool_calls]

    def process_tool_calls_parallel(
        self,
        tool_calls: List[Dict[str, Any]],
        context: Optional[Dict[str, Any]] = None,
        max_workers: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """Execute ``tool_calls`` concurrently on a thread pool.

        Arguments are parsed up front; calls with invalid JSON arguments get
        an error message immediately and are never submitted. Zero or one
        call is handled by :meth:`process_tool_calls` directly.

        Args:
            tool_calls: Same shape as for :meth:`process_tool_calls`.
            context: Currently unused; accepted for interface compatibility.
            max_workers: Pool size override; falls back to ``self.max_workers``
                when falsy. Clamped to between 1 and the number of runnable
                calls.

        Returns:
            One tool message per input call, in the same order as ``tool_calls``.
        """
        if len(tool_calls) <= 1:
            return self.process_tool_calls(tool_calls, context)

        results: List[Optional[Dict[str, Any]]] = [None] * len(tool_calls)
        pending: Dict[int, tuple] = {}  # index in tool_calls -> (call_id, name, args)
        for index, call in enumerate(tool_calls):
            call_id, name, args = self._parse_call(call)
            if args is self._INVALID_ARGS:
                results[index] = self._create_error_result(
                    call_id, name, self._INVALID_JSON_MESSAGE
                )
            else:
                pending[index] = (call_id, name, args)

        if pending:
            workers = max_workers or self.max_workers or len(pending)
            workers = max(1, min(workers, len(pending)))
            with ThreadPoolExecutor(max_workers=workers) as pool:
                futures = {
                    pool.submit(self._execute_one, *task): index
                    for index, task in pending.items()
                }
                for future in as_completed(futures):
                    # _execute_one already converts tool failures into error messages.
                    results[futures[future]] = future.result()
        return results

    def _create_tool_result(self, call_id: str, name: str, result: Dict) -> Dict[str, Any]:
        """Wrap a tool result as a ``role="tool"`` message with JSON content.

        Non-JSON-serializable values inside ``result`` are stringified via
        ``default=str``.
        """
        return {
            "role": "tool",
            "tool_call_id": call_id,
            "name": name,
            "content": json.dumps(result, ensure_ascii=False, default=str)
        }

    def _create_error_result(self, call_id: str, name: str, error: str) -> Dict[str, Any]:
        """Build a ``role="tool"`` message whose content is ``{"success": false, "error": ...}``.

        Uses ``ensure_ascii=False`` like ``_create_tool_result`` so non-ASCII
        error text is kept readable.
        """
        return {
            "role": "tool",
            "tool_call_id": call_id,
            "name": name,
            "content": json.dumps({"success": False, "error": error}, ensure_ascii=False)
        }

    def clear_cache(self) -> None:
        """Drop every cached result."""
        with self._lock:
            self._cache.clear()

    def get_history(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to ``limit`` most recent call records, oldest first.

        A non-positive ``limit`` yields ``[]``; the guard is needed because
        the slice ``[-0:]`` would return the entire history.
        """
        if limit <= 0:
            return []
        with self._lock:
            return self._call_history[-limit:]