# Luxx/luxx/tools/executor.py
"""Tool executor"""
import json
import time
from typing import List, Dict, Any, Optional
from luxx.tools.core import registry, ToolResult, ToolContext
class ToolExecutor:
    """Tool executor with caching and parallel execution support.

    Tool results are memoized in an in-process TTL cache keyed by tool name
    plus canonicalized JSON arguments (additionally scoped by workspace when
    one is present), and every processed call is appended to a bounded
    call-history log.
    """

    def __init__(
        self,
        enable_cache: bool = True,
        cache_ttl: int = 300,  # seconds (5 minutes)
        max_workers: int = 4
    ):
        """Create an executor.

        Args:
            enable_cache: When False, _get_cached/_set_cached become no-ops.
            cache_ttl: Seconds a cached result stays valid.
            max_workers: Thread-pool size for the parallel path.
        """
        self.enable_cache = enable_cache
        self.cache_ttl = cache_ttl
        self.max_workers = max_workers
        # cache_key -> (result dict, unix timestamp of insertion)
        self._cache: Dict[str, tuple] = {}
        # Bounded log of processed calls; see _record_call for trimming.
        self._call_history: List[Dict[str, Any]] = []

    def _make_cache_key(self, name: str, args: dict) -> str:
        """Build a deterministic key from tool name + canonical JSON args."""
        # sort_keys makes the key independent of argument insertion order.
        args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
        return f"{name}:{args_str}"

    def _scoped_cache_key(self, name: str, args: dict, tool_ctx: "ToolContext") -> str:
        """Cache key additionally scoped by workspace.

        File operations in different workspaces must not share cache entries,
        so the workspace (when set) is appended to the base key.
        """
        cache_key = self._make_cache_key(name, args)
        if tool_ctx.workspace:
            cache_key = f"{cache_key}:{tool_ctx.workspace}"
        return cache_key

    def _is_cache_valid(self, cache_key: str) -> bool:
        """Return True if the key exists and its entry is within the TTL."""
        if cache_key not in self._cache:
            return False
        _, timestamp = self._cache[cache_key]
        return time.time() - timestamp < self.cache_ttl

    def _get_cached(self, cache_key: str) -> Optional[Dict]:
        """Return the cached result, or None when disabled/missing/expired."""
        if self.enable_cache and self._is_cache_valid(cache_key):
            return self._cache[cache_key][0]
        return None

    def _set_cached(self, cache_key: str, result: Dict) -> None:
        """Store a result with the current timestamp (no-op when disabled)."""
        if self.enable_cache:
            self._cache[cache_key] = (result, time.time())

    def _record_call(self, name: str, args: dict, result: Dict) -> None:
        """Append a call record; keep only the most recent 1000 entries."""
        self._call_history.append({
            "name": name,
            "args": args,
            "result": result,
            "timestamp": time.time()
        })
        # Limit history size
        if len(self._call_history) > 1000:
            self._call_history = self._call_history[-1000:]

    @staticmethod
    def _parse_args(call: Dict[str, Any]) -> dict:
        """Decode a call's JSON arguments; malformed JSON yields {}."""
        try:
            return json.loads(call.get("function", {}).get("arguments", "{}"))
        except json.JSONDecodeError:
            return {}

    @staticmethod
    def _build_tool_context(context: Dict[str, Any]) -> "ToolContext":
        """Build a ToolContext from a plain context dict.

        `user_permission_level` defaults to 1 and is merged into `extra`
        alongside any caller-supplied extra entries.
        """
        return ToolContext(
            workspace=context.get("workspace"),
            user_id=context.get("user_id"),
            username=context.get("username"),
            extra={
                "user_permission_level": context.get("user_permission_level", 1),
                **(context.get("extra", {}))
            }
        )

    def process_tool_calls(
        self,
        tool_calls: List[Dict[str, Any]],
        context: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Process tool calls sequentially.

        Args:
            tool_calls: OpenAI-style tool-call dicts ({"id", "function": {...}}).
            context: Plain dict with workspace/user_id/username/etc.

        Returns:
            One role="tool" result message per call, in input order.
        """
        tool_ctx = self._build_tool_context(context)
        results = []
        for call in tool_calls:
            call_id = call.get("id", "")
            name = call.get("function", {}).get("name", "")
            args = self._parse_args(call)
            cache_key = self._scoped_cache_key(name, args, tool_ctx)
            cached = self._get_cached(cache_key)
            if cached is not None:
                result = cached
            else:
                # Execute tool with context
                result = registry.execute(name, args, context=tool_ctx)
                self._set_cached(cache_key, result)
            # Cache hits are recorded too, to keep the history complete.
            self._record_call(name, args, result)
            results.append(self._create_tool_result(call_id, name, result))
        return results

    def process_tool_calls_parallel(
        self,
        tool_calls: List[Dict[str, Any]],
        context: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Process tool calls concurrently, preserving input order in the output.

        Cache hits are answered inline; misses are fanned out to a thread
        pool. A tool that raises produces an error result message for that
        call only, instead of aborting the whole batch.
        """
        if len(tool_calls) <= 1:
            return self.process_tool_calls(tool_calls, context)
        tool_ctx = self._build_tool_context(context)
        from concurrent.futures import ThreadPoolExecutor

        # Pre-size the output so each call writes into its own slot and the
        # returned order matches the input order.
        results: List[Optional[Dict[str, Any]]] = [None] * len(tool_calls)
        # future -> (slot index, call_id, name, args, cache_key).
        # BUG FIX: cached entries were previously keyed by call_id (a string)
        # in the same dict handed to as_completed(), which crashes on
        # non-Future keys and silently drops cached results from the output.
        # Cached calls are now resolved inline and never enter the pool.
        pending: Dict[Any, tuple] = {}
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            for idx, call in enumerate(tool_calls):
                call_id = call.get("id", "")
                name = call.get("function", {}).get("name", "")
                args = self._parse_args(call)
                # Workspace-scoped key everywhere (the old recovery branch
                # rebuilt the key WITHOUT the workspace suffix).
                cache_key = self._scoped_cache_key(name, args, tool_ctx)
                cached = self._get_cached(cache_key)
                if cached is not None:
                    self._record_call(name, args, cached)
                    results[idx] = self._create_tool_result(call_id, name, cached)
                else:
                    future = executor.submit(registry.execute, name, args, context=tool_ctx)
                    pending[future] = (idx, call_id, name, args, cache_key)
            for future, (idx, call_id, name, args, cache_key) in pending.items():
                try:
                    result = future.result()
                except Exception as exc:
                    # One failing tool must not discard the rest of the batch.
                    results[idx] = self._create_error_result(call_id, name, str(exc))
                    continue
                self._set_cached(cache_key, result)
                self._record_call(name, args, result)
                results[idx] = self._create_tool_result(call_id, name, result)
        return results

    def _create_tool_result(self, call_id: str, name: str, result: Dict) -> Dict[str, Any]:
        """Create a role="tool" message carrying a JSON-encoded result."""
        return {
            "tool_call_id": call_id,
            "role": "tool",
            "name": name,
            "content": json.dumps(result, ensure_ascii=False)
        }

    def _create_error_result(self, call_id: str, name: str, error: str) -> Dict[str, Any]:
        """Create a role="tool" message carrying a JSON-encoded error payload."""
        return {
            "tool_call_id": call_id,
            "role": "tool",
            "name": name,
            "content": json.dumps({"success": False, "error": error}, ensure_ascii=False)
        }

    def clear_cache(self) -> None:
        """Drop every cached tool result."""
        self._cache.clear()

    def get_history(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Return the most recent `limit` call records (oldest first)."""
        return self._call_history[-limit:]