"""Tool executor."""
import json
import time
import hashlib
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Optional, Any

from alcor.tools.core import registry, ToolResult


class ToolExecutor:
    """Execute tool calls through ``registry`` with caching and optional parallelism.

    Each call's result is cached under a key derived from the tool name and
    its arguments for ``cache_ttl`` seconds, and every executed call is
    appended to an in-memory history. The cache and history are guarded by a
    re-entrant lock because ``process_tool_calls_parallel`` updates them from
    worker threads.

    NOTE(review): ``registry.execute`` is invoked concurrently from the
    parallel path; it is assumed to be thread-safe -- confirm.
    """

    # When the history grows past _HISTORY_LIMIT entries it is trimmed back to
    # the most recent _HISTORY_KEEP entries, so trimming happens only once per
    # (_HISTORY_LIMIT - _HISTORY_KEEP) appends instead of on every call.
    _HISTORY_LIMIT = 1000
    _HISTORY_KEEP = 500

    def __init__(
        self,
        enable_cache: bool = True,
        cache_ttl: int = 300,  # seconds (5 minutes)
        max_workers: int = 4,
    ):
        """Create an executor.

        Args:
            enable_cache: Whether results are cached and served from cache.
            cache_ttl: Lifetime of a cache entry in seconds.
            max_workers: Default thread count for ``process_tool_calls_parallel``.
        """
        self.enable_cache = enable_cache
        self.cache_ttl = cache_ttl
        self.max_workers = max_workers
        self._cache: Dict[str, tuple] = {}  # cache_key -> (result, timestamp)
        self._call_history: List[Dict[str, Any]] = []
        # Re-entrant so helpers that take the lock may call each other safely.
        self._lock = threading.RLock()

    def _make_cache_key(self, name: str, args: dict) -> str:
        """Return a stable cache key for invoking tool *name* with *args*.

        Arguments are serialized with sorted keys so that dicts with equal
        content but different insertion order map to the same key.
        """
        args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
        return hashlib.md5(f"{name}:{args_str}".encode()).hexdigest()

    def _is_cache_valid(self, cache_key: str) -> bool:
        """Return True if *cache_key* is cached and younger than ``cache_ttl``."""
        with self._lock:
            entry = self._cache.get(cache_key)
        if entry is None:
            return False
        _, timestamp = entry
        return (time.time() - timestamp) < self.cache_ttl

    def _get_cached(self, cache_key: str) -> Optional[Dict]:
        """Return the cached result for *cache_key*, or None on a miss.

        Returns None when caching is disabled, the key is absent, or the entry
        has expired. Expired entries are evicted so they do not accumulate.
        The stored object itself is returned (not a copy), so callers should
        not mutate it.
        """
        if not self.enable_cache:
            return None
        with self._lock:
            entry = self._cache.get(cache_key)
            if entry is None:
                return None
            result, timestamp = entry
            if (time.time() - timestamp) < self.cache_ttl:
                return result
            # Stale: drop it so the cache does not grow without bound.
            del self._cache[cache_key]
        return None

    def _set_cached(self, cache_key: str, result: Dict) -> None:
        """Store *result* under *cache_key* (no-op when caching is disabled)."""
        if self.enable_cache:
            with self._lock:
                self._cache[cache_key] = (result, time.time())

    def _record_call(self, name: str, args: dict, result: Dict) -> None:
        """Append a call to the history, trimming old entries when it grows too long."""
        entry = {
            "name": name,
            "args": args,
            "result": result,
            "timestamp": time.time(),
        }
        with self._lock:
            self._call_history.append(entry)
            if len(self._call_history) > self._HISTORY_LIMIT:
                self._call_history = self._call_history[-self._HISTORY_KEEP:]

    @staticmethod
    def _parse_arguments(raw: Any) -> Any:
        """Decode a tool call's ``arguments`` field.

        Strings are parsed as JSON, with an empty/blank string treated as no
        arguments (``{}``). Non-string values are assumed to be already decoded
        and are returned unchanged.

        Raises:
            json.JSONDecodeError: if *raw* is a non-blank string that is not
                valid JSON.
        """
        if isinstance(raw, str):
            return json.loads(raw) if raw.strip() else {}
        return raw

    def _execute(self, name: str, args: Any) -> Dict:
        """Run tool *name* with *args*, serving from and filling the cache.

        Every invocation (cache hit or not) is recorded in the history.
        Exceptions from ``registry.execute`` propagate to the caller.
        """
        cache_key = self._make_cache_key(name, args)
        result = self._get_cached(cache_key)
        if result is None:
            result = registry.execute(name, args)
            self._set_cached(cache_key, result)
        self._record_call(name, args, result)
        return result

    def _run_one(self, call: Dict[str, Any]) -> Dict[str, Any]:
        """Process a single tool call and always return a tool-role message.

        *call* is expected to look like
        ``{"id": ..., "function": {"name": ..., "arguments": ...}}``.
        Invalid JSON arguments and exceptions raised while executing the tool
        are converted into error messages rather than propagated, so one bad
        call never aborts a batch.
        """
        function = call.get("function") or {}
        call_id = call.get("id", "")
        name = function.get("name", "")
        try:
            args = self._parse_arguments(function.get("arguments", "{}"))
        except json.JSONDecodeError:
            return self._create_error_result(call_id, name, "Invalid JSON arguments")
        try:
            result = self._execute(name, args)
            return self._create_tool_result(call_id, name, result)
        except Exception as e:  # boundary: report any tool failure back as a result
            return self._create_error_result(call_id, name, str(e))

    def process_tool_calls(
        self,
        tool_calls: List[Dict[str, Any]],
        context: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Execute tool calls sequentially, in input order.

        Args:
            tool_calls: Tool call dicts (see ``_run_one`` for the shape read).
            context: Currently unused; accepted for interface compatibility.

        Returns:
            One tool-role message per call, in the same order as *tool_calls*.
            Calls with invalid JSON arguments or whose execution raises yield
            an error message (``{"success": false, "error": ...}`` content).
        """
        return [self._run_one(call) for call in tool_calls]

    def process_tool_calls_parallel(
        self,
        tool_calls: List[Dict[str, Any]],
        context: Optional[Dict[str, Any]] = None,
        max_workers: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Execute tool calls concurrently on a thread pool.

        Falls back to ``process_tool_calls`` when there is at most one call.
        Results are returned in the same order as *tool_calls*, with the same
        per-call error handling as the sequential path.

        Args:
            tool_calls: Tool call dicts (see ``_run_one`` for the shape read).
            context: Currently unused; accepted for interface compatibility.
            max_workers: Thread count override; a falsy value uses
                ``self.max_workers``.
        """
        if len(tool_calls) <= 1:
            return self.process_tool_calls(tool_calls, context)

        workers = max_workers or self.max_workers
        with ThreadPoolExecutor(max_workers=workers) as pool:
            # Executor.map yields results in submission order.
            return list(pool.map(self._run_one, tool_calls))

    def _create_tool_result(self, call_id: str, name: str, result: Dict) -> Dict[str, Any]:
        """Wrap a tool result as a tool-role message with JSON-encoded content.

        ``default=str`` stringifies any value ``json`` cannot serialize.
        """
        return {
            "role": "tool",
            "tool_call_id": call_id,
            "name": name,
            "content": json.dumps(result, ensure_ascii=False, default=str),
        }

    def _create_error_result(self, call_id: str, name: str, error: str) -> Dict[str, Any]:
        """Build a tool-role message whose content is ``{"success": false, "error": ...}``."""
        return {
            "role": "tool",
            "tool_call_id": call_id,
            "name": name,
            # ensure_ascii=False to match _create_tool_result (keeps non-ASCII readable).
            "content": json.dumps({"success": False, "error": error}, ensure_ascii=False),
        }

    def clear_cache(self) -> None:
        """Remove all cached results."""
        with self._lock:
            self._cache.clear()

    def get_history(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Return (a copy of) the most recent *limit* call-history entries.

        A non-positive *limit* returns an empty list. The old behaviour was
        that ``lst[-0:]`` returned the whole history.
        """
        if limit <= 0:
            return []
        with self._lock:
            return self._call_history[-limit:]