178 lines
6.4 KiB
Python
178 lines
6.4 KiB
Python
"""Tool executor"""
|
|
import json
|
|
import time
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
from luxx.tools.core import registry, ToolResult
|
|
|
|
|
|
class ToolExecutor:
    """Tool executor with caching and parallel execution support.

    Results are cached per ``(tool name, arguments)`` pair for ``cache_ttl``
    seconds when ``enable_cache`` is true. Every processed call, whether served
    from the cache or executed, is appended to an in-memory call history capped
    at ``max_history`` entries.
    """

    def __init__(
        self,
        enable_cache: bool = True,
        cache_ttl: int = 300,  # 5 minutes
        max_workers: int = 4,
        max_history: int = 1000,
    ):
        """Configure the executor.

        Args:
            enable_cache: Whether results are stored in and served from the cache.
            cache_ttl: Seconds a cached result stays valid.
            max_workers: Upper bound on worker threads used by
                ``process_tool_calls_parallel``.
            max_history: Maximum number of call-history entries kept; the
                oldest entries are discarded first.
        """
        self.enable_cache = enable_cache
        self.cache_ttl = cache_ttl
        self.max_workers = max_workers
        self.max_history = max_history
        # cache_key -> (result, time.monotonic() value at insertion)
        self._cache: Dict[str, tuple] = {}
        self._call_history: List[Dict[str, Any]] = []

    @staticmethod
    def _parse_tool_call(call: Dict[str, Any]) -> tuple:
        """Split a tool-call dict into ``(call_id, name, args)``.

        Reads the shape ``{"id": ..., "function": {"name": ..., "arguments": ...}}``.
        ``arguments`` may be a JSON string or an already-decoded dict. Malformed
        JSON, a missing/``None`` value, or any other type yields ``{}``.
        A missing or ``None`` ``function`` entry is treated as empty.
        """
        function = call.get("function") or {}
        call_id = call.get("id", "")
        name = function.get("name", "")
        raw_args = function.get("arguments", "{}")

        if isinstance(raw_args, dict):
            return call_id, name, raw_args
        if isinstance(raw_args, str):
            try:
                return call_id, name, json.loads(raw_args)
            except json.JSONDecodeError:
                return call_id, name, {}
        # None or an unexpected type: run the tool with no arguments.
        return call_id, name, {}

    def _make_cache_key(self, name: str, args: dict) -> str:
        """Build a cache key from the tool name and key-sorted JSON arguments."""
        # sort_keys makes logically identical argument dicts map to one key.
        args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
        return f"{name}:{args_str}"

    def _is_cache_valid(self, cache_key: str) -> bool:
        """Return True if ``cache_key`` is cached and younger than ``cache_ttl``."""
        entry = self._cache.get(cache_key)
        if entry is None:
            return False
        _, stored_at = entry
        # Monotonic clock: expiry is unaffected by wall-clock adjustments.
        return time.monotonic() - stored_at < self.cache_ttl

    def _get_cached(self, cache_key: str) -> Optional[Dict]:
        """Return the cached result for ``cache_key``, or None on a miss.

        Expired entries are evicted on lookup so they do not accumulate.
        Always returns None when caching is disabled.
        """
        if not self.enable_cache:
            return None
        if self._is_cache_valid(cache_key):
            return self._cache[cache_key][0]
        self._cache.pop(cache_key, None)
        return None

    def _set_cached(self, cache_key: str, result: Dict) -> None:
        """Store ``result`` under ``cache_key`` if caching is enabled."""
        if self.enable_cache:
            self._cache[cache_key] = (result, time.monotonic())

    def _record_call(self, name: str, args: dict, result: Dict) -> None:
        """Append a call to the history, dropping the oldest beyond ``max_history``."""
        self._call_history.append({
            "name": name,
            "args": args,
            "result": result,
            # Wall-clock time, so history entries carry a real-world timestamp.
            "timestamp": time.time(),
        })

        overflow = len(self._call_history) - self.max_history
        if overflow > 0:
            del self._call_history[:overflow]

    def process_tool_calls(
        self,
        tool_calls: List[Dict[str, Any]],
        context: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Process tool calls sequentially, in input order.

        Each call is served from the cache when possible; otherwise it is run
        via ``registry.execute`` and the result is cached. Every call is
        recorded in the history.

        Args:
            tool_calls: Tool-call dicts (see ``_parse_tool_call`` for the shape).
            context: Currently unused; accepted for interface compatibility.

        Returns:
            One tool-result message per input call, in input order.
        """
        results = []

        for call in tool_calls:
            call_id, name, args = self._parse_tool_call(call)
            cache_key = self._make_cache_key(name, args)

            result = self._get_cached(cache_key)
            if result is None:
                result = registry.execute(name, args)
                self._set_cached(cache_key, result)

            self._record_call(name, args, result)
            results.append(self._create_tool_result(call_id, name, result))

        return results

    def process_tool_calls_parallel(
        self,
        tool_calls: List[Dict[str, Any]],
        context: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Process tool calls concurrently, returning messages in input order.

        Cache hits are resolved up front; only misses are submitted to a thread
        pool. Cache writes and history recording happen on the calling thread
        after the pool has finished, so worker threads only run
        ``registry.execute``. Batches of zero or one call are delegated to
        ``process_tool_calls``.

        Args:
            tool_calls: Tool-call dicts (see ``_parse_tool_call`` for the shape).
            context: Currently unused; accepted for interface compatibility.

        Returns:
            One tool-result message per input call, in input order.

        Raises:
            Whatever ``registry.execute`` raised for the first failing call
            (in input order), after all submitted calls have finished.
        """
        if len(tool_calls) <= 1:
            return self.process_tool_calls(tool_calls, context)

        # Stdlib import kept local, as in the original; no ImportError fallback
        # so an ImportError raised by a tool is not misread as "no threads".
        from concurrent.futures import ThreadPoolExecutor

        parsed = [self._parse_tool_call(call) for call in tool_calls]
        keys = [self._make_cache_key(name, args) for _, name, args in parsed]
        outcomes: List[Any] = [None] * len(parsed)
        misses: List[int] = []

        for index, key in enumerate(keys):
            cached = self._get_cached(key)
            if cached is not None:
                outcomes[index] = cached
            else:
                misses.append(index)

        if misses:
            workers = max(1, min(self.max_workers, len(misses)))
            futures = {}
            with ThreadPoolExecutor(max_workers=workers) as executor:
                for index in misses:
                    _, name, args = parsed[index]
                    futures[index] = executor.submit(registry.execute, name, args)
            # Leaving the ``with`` block waits for every submitted call.
            for index, future in futures.items():
                result = future.result()
                self._set_cached(keys[index], result)
                outcomes[index] = result

        results = []
        for (call_id, name, args), result in zip(parsed, outcomes):
            self._record_call(name, args, result)
            results.append(self._create_tool_result(call_id, name, result))
        return results

    def _create_tool_result(self, call_id: str, name: str, result: Dict) -> Dict[str, Any]:
        """Wrap ``result`` as a ``role: "tool"`` message with JSON-encoded content."""
        return {
            "tool_call_id": call_id,
            "role": "tool",
            "name": name,
            "content": json.dumps(result, ensure_ascii=False)
        }

    def _create_error_result(self, call_id: str, name: str, error: str) -> Dict[str, Any]:
        """Build a ``role: "tool"`` message carrying ``{"success": False, "error": ...}``.

        Not called from within this class.
        """
        return {
            "tool_call_id": call_id,
            "role": "tool",
            "name": name,
            "content": json.dumps({"success": False, "error": error}, ensure_ascii=False)
        }

    def clear_cache(self) -> None:
        """Remove every cached result."""
        self._cache.clear()

    def get_history(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to ``limit`` most recent history entries, oldest first.

        A non-positive ``limit`` returns an empty list. The returned list is a
        copy; mutating it does not affect the stored history.
        """
        if limit <= 0:
            return []
        return self._call_history[-limit:]
|