187 lines
6.6 KiB
Python
187 lines
6.6 KiB
Python
"""工具执行器"""
|
|
import hashlib
import json
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional

from alcor.tools.core import registry, ToolResult
|
|
|
|
|
|
class ToolExecutor:
    """Tool executor with result caching and optional parallel execution.

    Tool results are cached under an MD5 hash of ``(name, arguments)`` for
    ``cache_ttl`` seconds, and every invocation is appended to a bounded
    call history.  Because :meth:`process_tool_calls_parallel` runs tools
    on a thread pool, all shared mutable state (the cache and the history)
    is guarded by a single lock.
    """

    def __init__(
        self,
        enable_cache: bool = True,
        cache_ttl: int = 300,  # seconds; default is 5 minutes
        max_workers: int = 4
    ):
        """Create an executor.

        Args:
            enable_cache: When False, every call goes straight to the tool.
            cache_ttl: Seconds a cached result stays valid.
            max_workers: Default thread-pool size for parallel execution.
        """
        self.enable_cache = enable_cache
        self.cache_ttl = cache_ttl
        self.max_workers = max_workers
        # cache_key -> (result, timestamp); see _make_cache_key.
        self._cache: Dict[str, tuple] = {}
        self._call_history: List[Dict[str, Any]] = []
        # Guards _cache and _call_history: process_tool_calls_parallel
        # mutates both from worker threads concurrently.
        self._lock = threading.Lock()

    def _make_cache_key(self, name: str, args: dict) -> str:
        """Build a deterministic cache key from tool name and arguments."""
        # sort_keys makes the key independent of argument ordering.
        args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
        # MD5 is acceptable here: this is a cache identifier, not a
        # security-sensitive hash.
        return hashlib.md5(f"{name}:{args_str}".encode()).hexdigest()

    def _is_cache_valid(self, cache_key: str) -> bool:
        """Return True if the key exists and its entry is within the TTL.

        Caller must hold ``self._lock`` (or be single-threaded).
        """
        entry = self._cache.get(cache_key)
        if entry is None:
            return False
        return (time.time() - entry[1]) < self.cache_ttl

    def _get_cached(self, cache_key: str) -> Optional[Dict]:
        """Return the cached result, or None on miss/expiry.

        Expired entries are evicted on lookup so the cache does not grow
        without bound under a changing key stream.
        """
        if not self.enable_cache:
            return None
        with self._lock:
            if self._is_cache_valid(cache_key):
                return self._cache[cache_key][0]
            self._cache.pop(cache_key, None)  # evict stale entry, if any
            return None

    def _set_cached(self, cache_key: str, result: Dict) -> None:
        """Store a result together with the current timestamp."""
        if self.enable_cache:
            with self._lock:
                self._cache[cache_key] = (result, time.time())

    def _record_call(self, name: str, args: dict, result: Dict) -> None:
        """Append a call record and trim the history to a bounded size."""
        with self._lock:
            self._call_history.append({
                "name": name,
                "args": args,
                "result": result,
                "timestamp": time.time()
            })
            # Keep memory bounded: once past 1000 entries, keep the
            # most recent 500.
            if len(self._call_history) > 1000:
                self._call_history = self._call_history[-500:]

    def process_tool_calls(
        self,
        tool_calls: List[Dict[str, Any]],
        context: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Process tool calls sequentially.

        Args:
            tool_calls: OpenAI-style tool-call dicts with ``id`` and a
                ``function`` object holding ``name`` and JSON ``arguments``.
            context: Reserved for callers; currently unused.

        Returns:
            One ``role="tool"`` result message per input call, in order.
            Calls whose arguments are not valid JSON yield an error message
            instead of being executed.
        """
        results = []

        for call in tool_calls:
            name = call.get("function", {}).get("name", "")
            args_str = call.get("function", {}).get("arguments", "{}")
            call_id = call.get("id", "")

            # Arguments may arrive pre-parsed (dict) or as a JSON string.
            try:
                args = json.loads(args_str) if isinstance(args_str, str) else args_str
            except json.JSONDecodeError:
                results.append(self._create_error_result(call_id, name, "Invalid JSON arguments"))
                continue

            # Serve from cache when possible; otherwise execute and cache.
            cache_key = self._make_cache_key(name, args)
            cached_result = self._get_cached(cache_key)

            if cached_result is not None:
                result = cached_result
            else:
                result = registry.execute(name, args)
                self._set_cached(cache_key, result)

            # Cache hits are recorded too, so history reflects every call.
            self._record_call(name, args, result)

            results.append(self._create_tool_result(call_id, name, result))

        return results

    def process_tool_calls_parallel(
        self,
        tool_calls: List[Dict[str, Any]],
        context: Optional[Dict[str, Any]] = None,
        max_workers: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """Process tool calls on a thread pool, preserving input order.

        Args:
            tool_calls: Same shape as :meth:`process_tool_calls`.
            context: Reserved for callers; currently unused.
            max_workers: Pool size override; defaults to ``self.max_workers``.

        Returns:
            Result messages in the same order as ``tool_calls``.  A worker
            that raises is converted into an error result for its slot.
        """
        # A single call gains nothing from a pool; reuse the serial path.
        if len(tool_calls) <= 1:
            return self.process_tool_calls(tool_calls, context)

        workers = max_workers or self.max_workers
        results = [None] * len(tool_calls)
        exec_tasks = {}

        # Parse all arguments up front; bad JSON fails fast without
        # occupying a worker slot.
        for i, call in enumerate(tool_calls):
            try:
                name = call.get("function", {}).get("name", "")
                args_str = call.get("function", {}).get("arguments", "{}")
                call_id = call.get("id", "")
                args = json.loads(args_str) if isinstance(args_str, str) else args_str
                exec_tasks[i] = (call_id, name, args)
            except json.JSONDecodeError:
                results[i] = self._create_error_result(
                    call.get("id", ""),
                    call.get("function", {}).get("name", ""),
                    "Invalid JSON"
                )

        def run(call_id: str, name: str, args: dict) -> Dict[str, Any]:
            # Cache/history helpers are lock-protected, so this closure is
            # safe to run concurrently.  NOTE: two workers may still execute
            # the same uncached tool simultaneously (no per-key locking);
            # the later finisher simply overwrites the cache entry.
            cache_key = self._make_cache_key(name, args)
            cached_result = self._get_cached(cache_key)

            if cached_result is not None:
                result = cached_result
            else:
                result = registry.execute(name, args)
                self._set_cached(cache_key, result)

            self._record_call(name, args, result)
            return self._create_tool_result(call_id, name, result)

        with ThreadPoolExecutor(max_workers=workers) as pool:
            futures = {
                pool.submit(run, cid, n, a): i
                for i, (cid, n, a) in exec_tasks.items()
            }
            for future in as_completed(futures):
                idx = futures[future]
                try:
                    results[idx] = future.result()
                except Exception as e:
                    # Map the worker's exception back to its input slot.
                    results[idx] = self._create_error_result(
                        exec_tasks[idx][0] if idx in exec_tasks else "",
                        exec_tasks[idx][1] if idx in exec_tasks else "",
                        str(e)
                    )

        return results

    def _create_tool_result(self, call_id: str, name: str, result: Dict) -> Dict[str, Any]:
        """Build a ``role="tool"`` message wrapping a successful result."""
        return {
            "role": "tool",
            "tool_call_id": call_id,
            "name": name,
            # default=str keeps non-JSON-native values (e.g. datetimes)
            # from raising during serialization.
            "content": json.dumps(result, ensure_ascii=False, default=str)
        }

    def _create_error_result(self, call_id: str, name: str, error: str) -> Dict[str, Any]:
        """Build a ``role="tool"`` message wrapping an error."""
        return {
            "role": "tool",
            "tool_call_id": call_id,
            "name": name,
            "content": json.dumps({"success": False, "error": error})
        }

    def clear_cache(self) -> None:
        """Drop every cached result."""
        with self._lock:
            self._cache.clear()

    def get_history(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to ``limit`` most recent call records (newest last)."""
        with self._lock:
            # Slicing produces a fresh list, so callers cannot mutate
            # internal state through the return value.
            return self._call_history[-limit:]
|