# Tests for backend.tools: ToolDefinition, ToolResult, ToolRegistry and ToolExecutor.
import pytest
|
|
from backend.tools.core import ToolRegistry, ToolDefinition, ToolResult
|
|
from backend.tools.executor import ToolExecutor
|
|
import json
|
|
|
|
|
|
def test_tool_definition():
    """Check that ToolDefinition keeps its fields and exports an OpenAI-style spec."""
    def dummy_handler(args):
        return args.get('value', 0)

    schema = {
        'type': 'object',
        'properties': {'value': {'type': 'integer'}},
    }
    definition = ToolDefinition(
        name='test_tool',
        description='A test tool',
        parameters=schema,
        handler=dummy_handler,
        category='test',
    )

    # Constructor arguments must be readable back as attributes.
    assert definition.name == 'test_tool'
    assert definition.description == 'A test tool'
    assert definition.category == 'test'
    assert definition.handler == dummy_handler

    # The OpenAI export wraps the tool under a 'function' entry.
    exported = definition.to_openai_format()
    assert exported['type'] == 'function'
    function_spec = exported['function']
    assert function_spec['name'] == 'test_tool'
    assert function_spec['description'] == 'A test tool'
    assert 'parameters' in function_spec
|
|
|
|
|
|
def test_tool_result():
    """Check ToolResult.ok / ToolResult.fail and their to_dict() output."""
    # Successful result: carries data, no error.
    succeeded = ToolResult.ok(data='success')
    assert succeeded.success is True
    assert succeeded.data == 'success'
    assert succeeded.error is None

    # Failed result: carries the error message, no data.
    failed = ToolResult.fail(error='something went wrong')
    assert failed.success is False
    assert failed.error == 'something went wrong'
    assert failed.data is None

    # Dict serialization mirrors the attributes checked above.
    ok_payload = succeeded.to_dict()
    assert ok_payload['success'] is True
    assert ok_payload['data'] == 'success'

    fail_payload = failed.to_dict()
    assert fail_payload['success'] is False
    assert fail_payload['error'] == 'something went wrong'
|
|
|
|
|
|
def test_tool_registry():
    """Test ToolRegistry registration, lookup, listing and execution.

    Covers a successful call, a call to an unknown tool, and a handler that
    raises. ToolRegistry() is a shared singleton (see
    test_tool_registry_singleton), so the tools registered here are removed
    again in the ``finally`` block to keep other tests order-independent.
    """
    registry = ToolRegistry()
    # NOTE(review): assumes registry._tools is the internal name -> tool dict,
    # as the executor tests in this file already treat it -- confirm.
    saved_tools = dict(registry._tools)
    try:
        # Count existing tools so the assertion below is relative, not absolute.
        initial_count = len(registry.list_all())

        def add_handler(args):
            return args.get('a', 0) + args.get('b', 0)

        tool = ToolDefinition(
            name='add',
            description='Add two numbers',
            parameters={
                'type': 'object',
                'properties': {
                    'a': {'type': 'number'},
                    'b': {'type': 'number'},
                },
                'required': ['a', 'b'],
            },
            handler=add_handler,
            category='math',
        )
        registry.register(tool)

        # The registered tool can be looked up by name.
        retrieved = registry.get('add')
        assert retrieved is not None
        assert retrieved.name == 'add'
        assert retrieved.handler == add_handler

        # list_all() returns OpenAI-format entries and now includes 'add'.
        tools_list = registry.list_all()
        assert len(tools_list) == initial_count + 1
        tool_names = [t['function']['name'] for t in tools_list]
        assert 'add' in tool_names

        # Successful execution returns a dict with the handler's value.
        result = registry.execute('add', {'a': 5, 'b': 3})
        assert result['success'] is True
        assert result['data'] == 8

        # Unknown tool names are reported as a failure, not raised.
        result = registry.execute('nonexistent', {})
        assert result['success'] is False
        assert 'Tool not found' in result['error']

        # Exceptions raised by a handler are captured in the result.
        def faulty_handler(args):
            raise ValueError('Intentional error')

        faulty_tool = ToolDefinition(
            name='faulty',
            description='Faulty tool',
            parameters={'type': 'object'},
            handler=faulty_handler,
        )
        registry.register(faulty_tool)
        result = registry.execute('faulty', {})
        assert result['success'] is False
        assert 'Intentional error' in result['error']
    finally:
        # Restore the singleton exactly as it was before this test ran.
        registry._tools.clear()
        registry._tools.update(saved_tools)
|
|
|
|
|
|
def test_tool_registry_singleton():
    """Test that ToolRegistry() always returns the same shared instance.

    A tool registered through one reference must be visible through the
    other. The registration is undone in ``finally`` so the shared registry
    is not left polluted for later tests.
    """
    registry1 = ToolRegistry()
    registry2 = ToolRegistry()
    assert registry1 is registry2

    # NOTE(review): assumes registry._tools is the internal name -> tool dict,
    # as the executor tests in this file already treat it -- confirm.
    saved_tools = dict(registry1._tools)
    try:
        def dummy(args):
            return 42

        tool = ToolDefinition(
            name='singleton_test',
            description='Test',
            parameters={'type': 'object'},
            handler=dummy,
        )
        # Register in one, should appear in the other.
        registry1.register(tool)
        assert registry2.get('singleton_test') is not None
    finally:
        # Restore the singleton exactly as it was before this test ran.
        registry1._tools.clear()
        registry1._tools.update(saved_tools)
|
|
|
|
|
|
def test_tool_executor_basic():
    """Test that ToolExecutor runs one tool call and returns a 'tool' message.

    The registry is a singleton that may hold tools from other tests, so it
    is emptied for the duration of the test. Its previous contents are put
    back in ``finally``; without that, every test running afterwards would
    see an empty registry.
    """
    registry = ToolRegistry()
    # Reaching into the private dict is hacky but gives a clean slate.
    saved_tools = dict(registry._tools)
    registry._tools.clear()
    try:
        def echo_handler(args):
            return args.get('message', '')

        tool = ToolDefinition(
            name='echo',
            description='Echo message',
            parameters={
                'type': 'object',
                'properties': {'message': {'type': 'string'}},
            },
            handler=echo_handler,
        )
        registry.register(tool)

        executor = ToolExecutor(registry=registry, enable_cache=False)

        # Simulate a tool call; arguments arrive as a JSON-encoded string.
        call = {
            'id': 'call_1',
            'function': {
                'name': 'echo',
                'arguments': json.dumps({'message': 'Hello'}),
            },
        }

        messages = executor.process_tool_calls([call], context=None)
        assert len(messages) == 1
        msg = messages[0]
        assert msg['role'] == 'tool'
        assert msg['tool_call_id'] == 'call_1'
        assert msg['name'] == 'echo'

        # The message content is itself a JSON document with the result.
        content = json.loads(msg['content'])
        assert content['success'] is True
        assert content['data'] == 'Hello'
    finally:
        # Restore the singleton exactly as it was before this test ran.
        registry._tools.clear()
        registry._tools.update(saved_tools)
|
|
|
|
|
|
def test_tool_executor_cache():
    """Test that identical calls are served from the executor's cache.

    A counting handler shows whether the tool really ran: a repeated
    identical call must not increment the counter, while a call with
    different arguments must. The shared registry is emptied for the test
    and restored in ``finally`` so later tests are unaffected.
    """
    registry = ToolRegistry()
    saved_tools = dict(registry._tools)
    registry._tools.clear()
    try:
        call_count = 0

        def counter_handler(args):
            nonlocal call_count
            call_count += 1
            return call_count

        tool = ToolDefinition(
            name='counter',
            description='Count calls',
            parameters={'type': 'object'},
            handler=counter_handler,
        )
        registry.register(tool)

        executor = ToolExecutor(registry=registry, enable_cache=True, cache_ttl=10)

        call = {
            'id': 'call_1',
            'function': {
                'name': 'counter',
                'arguments': '{}',
            },
        }

        # First call should execute the handler.
        messages1 = executor.process_tool_calls([call], context=None)
        assert len(messages1) == 1
        content1 = json.loads(messages1[0]['content'])
        assert content1['data'] == 1
        assert call_count == 1

        # Second identical call should be cached: same data, handler not re-run.
        messages2 = executor.process_tool_calls([call], context=None)
        assert len(messages2) == 1
        content2 = json.loads(messages2[0]['content'])
        assert content2['data'] == 1
        assert call_count == 1

        # Different arguments must miss the cache and execute again.
        call2 = {
            'id': 'call_2',
            'function': {
                'name': 'counter',
                'arguments': json.dumps({'different': True}),
            },
        }
        messages3 = executor.process_tool_calls([call2], context=None)
        content3 = json.loads(messages3[0]['content'])
        assert content3['data'] == 2
        assert call_count == 2
    finally:
        # Restore the singleton exactly as it was before this test ran.
        registry._tools.clear()
        registry._tools.update(saved_tools)
|
|
|
|
|
|
def test_tool_executor_context_injection():
    """Test that context fields are injected into the handler's arguments.

    The handler records the arguments it receives; after the call they must
    contain both the caller-supplied 'path' and the 'project_id' taken from
    the context. The shared registry is emptied for the test and restored in
    ``finally`` so later tests are unaffected.
    """
    registry = ToolRegistry()
    saved_tools = dict(registry._tools)
    registry._tools.clear()
    try:
        captured_args = None

        def capture_handler(args):
            nonlocal captured_args
            # Copy so later mutation of args cannot change what was captured.
            captured_args = args.copy()
            return 'ok'

        tool = ToolDefinition(
            name='file_read',
            description='Read file',
            parameters={'type': 'object'},
            handler=capture_handler,
        )
        registry.register(tool)

        executor = ToolExecutor(registry=registry)

        call = {
            'id': 'call_1',
            'function': {
                'name': 'file_read',
                'arguments': json.dumps({'path': 'test.txt'}),
            },
        }

        context = {'project_id': 'proj-123'}
        executor.process_tool_calls([call], context=context)

        # The handler ran and saw both the injected and the original argument.
        assert captured_args is not None
        assert captured_args['project_id'] == 'proj-123'
        assert captured_args['path'] == 'test.txt'
    finally:
        # Restore the singleton exactly as it was before this test ran.
        registry._tools.clear()
        registry._tools.update(saved_tools)
|
|
|
|
|
|
def test_tool_executor_deduplication():
    """Test deduplication of identical calls within one batch.

    Two calls with different ids but the same tool name and arguments are
    submitted together with the TTL cache disabled. The handler must run
    only once, and the second message must be flagged as cached. The shared
    registry is emptied for the test and restored in ``finally`` so later
    tests are unaffected.
    """
    registry = ToolRegistry()
    saved_tools = dict(registry._tools)
    registry._tools.clear()
    try:
        call_count = 0

        def count_handler(args):
            nonlocal call_count
            call_count += 1
            return call_count

        tool = ToolDefinition(
            name='count',
            description='Count',
            parameters={'type': 'object'},
            handler=count_handler,
        )
        registry.register(tool)

        executor = ToolExecutor(registry=registry, enable_cache=False)

        call = {
            'id': 'call_1',
            'function': {
                'name': 'count',
                'arguments': '{}',
            },
        }
        call_same = {
            'id': 'call_2',
            'function': {
                'name': 'count',
                'arguments': '{}',
            },
        }

        # Execute both calls in one batch.
        messages = executor.process_tool_calls([call, call_same], context=None)

        # Only one real execution, but one reply message per call.
        assert call_count == 1
        assert len(messages) == 2
        content0 = json.loads(messages[0]['content'])
        content1 = json.loads(messages[1]['content'])

        # The first call actually ran the handler.
        assert content0['success'] is True
        assert content0['data'] == 1

        # The duplicate succeeds and is marked cached; its data may be the
        # first call's value or omitted (None).
        assert content1['success'] is True
        assert content1.get('cached') is True
        assert content1.get('data') in (1, None)
    finally:
        # Restore the singleton exactly as it was before this test ran.
        registry._tools.clear()
        registry._tools.update(saved_tools)
|
|
|
|
|
|
if __name__ == '__main__':
    # Allow running this file directly; propagate pytest's exit code so a
    # failing run is visible to the shell / CI instead of always exiting 0.
    raise SystemExit(pytest.main(['-v', __file__]))