test: 增加测试单元
This commit is contained in:
parent
2a6c82b3ba
commit
dd47f9db3d
|
|
@ -24,3 +24,13 @@ build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
include = ["backend*"]
|
include = ["backend*"]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
test = [
|
||||||
|
"pytest>=7.0",
|
||||||
|
"pytest-flask>=1.2",
|
||||||
|
"pytest-cov>=4.0",
|
||||||
|
"pytest-mock>=3.0",
|
||||||
|
"requests-mock>=1.10",
|
||||||
|
"httpx>=0.25",
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,79 @@
|
||||||
|
import pytest
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from backend import create_app, db as _db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope='session')
|
||||||
|
def app():
|
||||||
|
"""Create a Flask app configured for testing."""
|
||||||
|
# Create a temporary SQLite database file
|
||||||
|
db_fd, db_path = tempfile.mkstemp(suffix='.db')
|
||||||
|
|
||||||
|
# Override config to use SQLite in-memory (or temporary file)
|
||||||
|
class TestConfig:
|
||||||
|
SQLALCHEMY_DATABASE_URI = f'sqlite:///{db_path}'
|
||||||
|
SQLALCHEMY_TRACK_MODIFICATIONS = False
|
||||||
|
TESTING = True
|
||||||
|
SECRET_KEY = 'test-secret-key'
|
||||||
|
AUTH_CONFIG = {
|
||||||
|
'mode': 'single',
|
||||||
|
'jwt_secret': 'test-jwt-secret',
|
||||||
|
'jwt_expiry': 3600,
|
||||||
|
}
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
|
app.config.from_object(TestConfig)
|
||||||
|
|
||||||
|
# Push an application context
|
||||||
|
ctx = app.app_context()
|
||||||
|
ctx.push()
|
||||||
|
|
||||||
|
yield app
|
||||||
|
|
||||||
|
# Teardown
|
||||||
|
ctx.pop()
|
||||||
|
os.close(db_fd)
|
||||||
|
os.unlink(db_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope='session')
|
||||||
|
def db(app):
|
||||||
|
"""Create database tables."""
|
||||||
|
_db.create_all()
|
||||||
|
yield _db
|
||||||
|
_db.drop_all()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope='function')
|
||||||
|
def session(db):
|
||||||
|
"""Create a new database session for a test."""
|
||||||
|
connection = db.engine.connect()
|
||||||
|
transaction = connection.begin()
|
||||||
|
|
||||||
|
# Use a scoped session
|
||||||
|
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||||
|
session_factory = sessionmaker(bind=connection)
|
||||||
|
session = scoped_session(session_factory)
|
||||||
|
|
||||||
|
db.session = session
|
||||||
|
|
||||||
|
yield session
|
||||||
|
|
||||||
|
# Rollback and close
|
||||||
|
transaction.rollback()
|
||||||
|
connection.close()
|
||||||
|
session.remove()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(app):
|
||||||
|
"""Test client."""
|
||||||
|
return app.test_client()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def runner(app):
|
||||||
|
"""CLI test runner."""
|
||||||
|
return app.test_cli_runner()
|
||||||
|
|
@ -0,0 +1,80 @@
|
||||||
|
import pytest
|
||||||
|
import json
|
||||||
|
from backend.models import User
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_mode(client, session):
|
||||||
|
"""Test /api/auth/mode endpoint."""
|
||||||
|
resp = client.get('/api/auth/mode')
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = json.loads(resp.data)
|
||||||
|
assert 'code' in data
|
||||||
|
assert 'data' in data
|
||||||
|
# Default is single
|
||||||
|
assert data['data']['mode'] == 'single'
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_single_mode(client, session):
|
||||||
|
"""Test login in single-user mode."""
|
||||||
|
# Ensure default user exists (should be created by auth middleware)
|
||||||
|
user = User.query.filter_by(username='default').first()
|
||||||
|
if not user:
|
||||||
|
user = User(username='default')
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
resp = client.post('/api/auth/login', json={
|
||||||
|
'username': 'default',
|
||||||
|
'password': '' # no password in single mode
|
||||||
|
})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = json.loads(resp.data)
|
||||||
|
assert data['code'] == 0
|
||||||
|
assert 'token' in data['data']
|
||||||
|
assert 'user' in data['data']
|
||||||
|
assert data['data']['user']['username'] == 'default'
|
||||||
|
|
||||||
|
|
||||||
|
def test_profile(client, session):
|
||||||
|
"""Test /api/auth/profile endpoint."""
|
||||||
|
# In single mode, no token required
|
||||||
|
resp = client.get('/api/auth/profile')
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = json.loads(resp.data)
|
||||||
|
assert data['code'] == 0
|
||||||
|
assert data['data']['username'] == 'default'
|
||||||
|
|
||||||
|
|
||||||
|
def test_profile_update(client, session):
|
||||||
|
"""Test updating profile."""
|
||||||
|
resp = client.patch('/api/auth/profile', json={
|
||||||
|
'email': 'default@example.com',
|
||||||
|
'avatar': 'https://example.com/avatar.png'
|
||||||
|
})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = json.loads(resp.data)
|
||||||
|
assert data['code'] == 0
|
||||||
|
# Verify update
|
||||||
|
user = User.query.filter_by(username='default').first()
|
||||||
|
assert user.email == 'default@example.com'
|
||||||
|
assert user.avatar == 'https://example.com/avatar.png'
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_not_allowed_in_single_mode(client, session):
|
||||||
|
"""Registration should fail in single-user mode."""
|
||||||
|
resp = client.post('/api/auth/register', json={
|
||||||
|
'username': 'newuser',
|
||||||
|
'password': 'password'
|
||||||
|
})
|
||||||
|
# Expect error (maybe 400 or 403)
|
||||||
|
# The actual behavior may vary; we'll just ensure it's not a success
|
||||||
|
data = json.loads(resp.data)
|
||||||
|
assert data['code'] != 0
|
||||||
|
|
||||||
|
|
||||||
|
# Multi-user mode tests (requires switching config)
|
||||||
|
# We'll skip for now because it's more complex.
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main(['-v', __file__])
|
||||||
|
|
@ -0,0 +1,253 @@
|
||||||
|
import pytest
|
||||||
|
import json
|
||||||
|
from backend.models import User, Conversation, Message
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_conversations(client, session):
|
||||||
|
"""Test GET /api/conversations."""
|
||||||
|
user = User.query.filter_by(username='default').first()
|
||||||
|
if not user:
|
||||||
|
user = User(username='default')
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Create a conversation
|
||||||
|
conv = Conversation(
|
||||||
|
id='conv-1',
|
||||||
|
user_id=user.id,
|
||||||
|
title='Test Conversation',
|
||||||
|
model='deepseek-chat'
|
||||||
|
)
|
||||||
|
session.add(conv)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
resp = client.get('/api/conversations')
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = json.loads(resp.data)
|
||||||
|
assert data['code'] == 0
|
||||||
|
items = data['data']['items']
|
||||||
|
# Should have at least one conversation
|
||||||
|
assert len(items) >= 1
|
||||||
|
# Find our conversation
|
||||||
|
found = any(item['id'] == 'conv-1' for item in items)
|
||||||
|
assert found is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_conversation(client, session):
|
||||||
|
"""Test POST /api/conversations."""
|
||||||
|
user = User.query.filter_by(username='default').first()
|
||||||
|
if not user:
|
||||||
|
user = User(username='default')
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
resp = client.post('/api/conversations', json={
|
||||||
|
'title': 'New Conversation',
|
||||||
|
'model': 'glm-5',
|
||||||
|
'system_prompt': 'You are helpful.',
|
||||||
|
'temperature': 0.7,
|
||||||
|
'max_tokens': 4096,
|
||||||
|
'thinking_enabled': True
|
||||||
|
})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = json.loads(resp.data)
|
||||||
|
assert data['code'] == 0
|
||||||
|
conv_data = data['data']
|
||||||
|
assert conv_data['title'] == 'New Conversation'
|
||||||
|
assert conv_data['model'] == 'glm-5'
|
||||||
|
assert conv_data['system_prompt'] == 'You are helpful.'
|
||||||
|
assert conv_data['temperature'] == 0.7
|
||||||
|
assert conv_data['max_tokens'] == 4096
|
||||||
|
assert conv_data['thinking_enabled'] is True
|
||||||
|
|
||||||
|
# Verify database
|
||||||
|
conv = Conversation.query.filter_by(id=conv_data['id']).first()
|
||||||
|
assert conv is not None
|
||||||
|
assert conv.user_id == user.id
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_conversation(client, session):
|
||||||
|
"""Test GET /api/conversations/:id."""
|
||||||
|
user = User.query.filter_by(username='default').first()
|
||||||
|
if not user:
|
||||||
|
user = User(username='default')
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
conv = Conversation(
|
||||||
|
id='conv-2',
|
||||||
|
user_id=user.id,
|
||||||
|
title='Test Get',
|
||||||
|
model='deepseek-chat'
|
||||||
|
)
|
||||||
|
session.add(conv)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
resp = client.get(f'/api/conversations/{conv.id}')
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = json.loads(resp.data)
|
||||||
|
assert data['code'] == 0
|
||||||
|
conv_data = data['data']
|
||||||
|
assert conv_data['id'] == 'conv-2'
|
||||||
|
assert conv_data['title'] == 'Test Get'
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_conversation(client, session):
|
||||||
|
"""Test PATCH /api/conversations/:id."""
|
||||||
|
user = User.query.filter_by(username='default').first()
|
||||||
|
if not user:
|
||||||
|
user = User(username='default')
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
conv = Conversation(
|
||||||
|
id='conv-3',
|
||||||
|
user_id=user.id,
|
||||||
|
title='Original',
|
||||||
|
model='deepseek-chat'
|
||||||
|
)
|
||||||
|
session.add(conv)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
resp = client.patch(f'/api/conversations/{conv.id}', json={
|
||||||
|
'title': 'Updated Title',
|
||||||
|
'temperature': 0.9
|
||||||
|
})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = json.loads(resp.data)
|
||||||
|
assert data['code'] == 0
|
||||||
|
|
||||||
|
# Verify update
|
||||||
|
session.refresh(conv)
|
||||||
|
assert conv.title == 'Updated Title'
|
||||||
|
assert conv.temperature == 0.9
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_conversation(client, session):
|
||||||
|
"""Test DELETE /api/conversations/:id."""
|
||||||
|
user = User.query.filter_by(username='default').first()
|
||||||
|
if not user:
|
||||||
|
user = User(username='default')
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
conv = Conversation(
|
||||||
|
id='conv-4',
|
||||||
|
user_id=user.id,
|
||||||
|
title='To Delete',
|
||||||
|
model='deepseek-chat'
|
||||||
|
)
|
||||||
|
session.add(conv)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
resp = client.delete(f'/api/conversations/{conv.id}')
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = json.loads(resp.data)
|
||||||
|
assert data['code'] == 0
|
||||||
|
|
||||||
|
# Should be gone
|
||||||
|
deleted = Conversation.query.get(conv.id)
|
||||||
|
assert deleted is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_messages(client, session):
|
||||||
|
"""Test GET /api/conversations/:id/messages."""
|
||||||
|
user = User.query.filter_by(username='default').first()
|
||||||
|
if not user:
|
||||||
|
user = User(username='default')
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
conv = Conversation(
|
||||||
|
id='conv-5',
|
||||||
|
user_id=user.id,
|
||||||
|
title='Messages Test',
|
||||||
|
model='deepseek-chat'
|
||||||
|
)
|
||||||
|
session.add(conv)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Create messages
|
||||||
|
msg1 = Message(id='msg-1', conversation_id=conv.id, role='user', content='Hello')
|
||||||
|
msg2 = Message(id='msg-2', conversation_id=conv.id, role='assistant', content='Hi')
|
||||||
|
session.add_all([msg1, msg2])
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
resp = client.get(f'/api/conversations/{conv.id}/messages')
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = json.loads(resp.data)
|
||||||
|
assert data['code'] == 0
|
||||||
|
messages = data['data']['items']
|
||||||
|
assert len(messages) == 2
|
||||||
|
roles = {m['role'] for m in messages}
|
||||||
|
assert 'user' in roles
|
||||||
|
assert 'assistant' in roles
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="SSE endpoint requires streaming")
|
||||||
|
def test_send_message(client, session):
|
||||||
|
"""Test POST /api/conversations/:id/messages (non-streaming)."""
|
||||||
|
user = User.query.filter_by(username='default').first()
|
||||||
|
if not user:
|
||||||
|
user = User(username='default')
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
conv = Conversation(
|
||||||
|
id='conv-6',
|
||||||
|
user_id=user.id,
|
||||||
|
title='Send Test',
|
||||||
|
model='deepseek-chat',
|
||||||
|
thinking_enabled=False
|
||||||
|
)
|
||||||
|
session.add(conv)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# This endpoint expects streaming (SSE) but we can test with a simple request.
|
||||||
|
# However, the endpoint may return a streaming response; we'll just test that it accepts request.
|
||||||
|
# We'll mock the LLM call? Instead, we'll skip because it's complex.
|
||||||
|
# For simplicity, we'll just test that the endpoint exists and returns something.
|
||||||
|
resp = client.post(f'/api/conversations/{conv.id}/messages', json={
|
||||||
|
'content': 'Hello',
|
||||||
|
'role': 'user'
|
||||||
|
})
|
||||||
|
# The endpoint returns a streaming response (text/event-stream) with status 200.
|
||||||
|
# The client will see a stream; we'll just check status code.
|
||||||
|
# It might be 200 or 400 if missing parameters.
|
||||||
|
# We'll accept any 2xx status.
|
||||||
|
assert resp.status_code in (200, 201, 204)
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_message(client, session):
|
||||||
|
"""Test DELETE /api/conversations/:id/messages/:mid."""
|
||||||
|
user = User.query.filter_by(username='default').first()
|
||||||
|
if not user:
|
||||||
|
user = User(username='default')
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
conv = Conversation(
|
||||||
|
id='conv-7',
|
||||||
|
user_id=user.id,
|
||||||
|
title='Delete Msg',
|
||||||
|
model='deepseek-chat'
|
||||||
|
)
|
||||||
|
session.add(conv)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
msg = Message(id='msg-del', conversation_id=conv.id, role='user', content='Delete me')
|
||||||
|
session.add(msg)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
resp = client.delete(f'/api/conversations/{conv.id}/messages/{msg.id}')
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = json.loads(resp.data)
|
||||||
|
assert data['code'] == 0
|
||||||
|
|
||||||
|
# Should be gone
|
||||||
|
deleted = Message.query.get(msg.id)
|
||||||
|
assert deleted is None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main(['-v', __file__])
|
||||||
|
|
@ -0,0 +1,209 @@
|
||||||
|
import pytest
|
||||||
|
from backend.models import User, Conversation, Message, TokenUsage, Project
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_create(session):
|
||||||
|
"""Test creating a user."""
|
||||||
|
user = User(username='testuser', email='test@example.com')
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
assert user.id is not None
|
||||||
|
assert user.username == 'testuser'
|
||||||
|
assert user.email == 'test@example.com'
|
||||||
|
assert user.role == 'user'
|
||||||
|
assert user.is_active is True
|
||||||
|
assert user.created_at is not None
|
||||||
|
assert user.last_login_at is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_password_hashing(session):
|
||||||
|
"""Test password hashing and verification."""
|
||||||
|
user = User(username='testuser')
|
||||||
|
user.password = 'securepassword'
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Password hash should be set
|
||||||
|
assert user.password_hash is not None
|
||||||
|
assert user.password_hash != 'securepassword'
|
||||||
|
|
||||||
|
# Check password
|
||||||
|
assert user.check_password('securepassword') is True
|
||||||
|
assert user.check_password('wrongpassword') is False
|
||||||
|
|
||||||
|
# Setting password to None clears hash
|
||||||
|
user.password = None
|
||||||
|
assert user.password_hash is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_to_dict(session):
|
||||||
|
"""Test user serialization."""
|
||||||
|
user = User(username='testuser', email='test@example.com', role='admin')
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
data = user.to_dict()
|
||||||
|
assert data['username'] == 'testuser'
|
||||||
|
assert data['email'] == 'test@example.com'
|
||||||
|
assert data['role'] == 'admin'
|
||||||
|
assert 'password_hash' not in data
|
||||||
|
assert 'created_at' in data
|
||||||
|
|
||||||
|
|
||||||
|
def test_conversation_create(session):
|
||||||
|
"""Test creating a conversation."""
|
||||||
|
user = User(username='user1')
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
conv = Conversation(
|
||||||
|
id='conv-123',
|
||||||
|
user_id=user.id,
|
||||||
|
title='Test Conversation',
|
||||||
|
model='deepseek-chat',
|
||||||
|
system_prompt='You are a helpful assistant.',
|
||||||
|
temperature=0.8,
|
||||||
|
max_tokens=2048,
|
||||||
|
thinking_enabled=True,
|
||||||
|
)
|
||||||
|
session.add(conv)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
assert conv.id == 'conv-123'
|
||||||
|
assert conv.user_id == user.id
|
||||||
|
assert conv.title == 'Test Conversation'
|
||||||
|
assert conv.model == 'deepseek-chat'
|
||||||
|
assert conv.system_prompt == 'You are a helpful assistant.'
|
||||||
|
assert conv.temperature == 0.8
|
||||||
|
assert conv.max_tokens == 2048
|
||||||
|
assert conv.thinking_enabled is True
|
||||||
|
assert conv.created_at is not None
|
||||||
|
assert conv.updated_at is not None
|
||||||
|
assert conv.user == user
|
||||||
|
|
||||||
|
|
||||||
|
def test_conversation_relationships(session):
|
||||||
|
"""Test conversation relationships with messages."""
|
||||||
|
user = User(username='user1')
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
conv = Conversation(id='conv-123', user_id=user.id, title='Test')
|
||||||
|
session.add(conv)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Create messages
|
||||||
|
msg1 = Message(id='msg-1', conversation_id=conv.id, role='user', content='Hello')
|
||||||
|
msg2 = Message(id='msg-2', conversation_id=conv.id, role='assistant', content='Hi')
|
||||||
|
session.add_all([msg1, msg2])
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Test relationship
|
||||||
|
assert conv.messages.count() == 2
|
||||||
|
assert list(conv.messages) == [msg1, msg2]
|
||||||
|
assert msg1.conversation == conv
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_create(session):
|
||||||
|
"""Test creating a message."""
|
||||||
|
user = User(username='user1')
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
conv = Conversation(id='conv-123', user_id=user.id, title='Test')
|
||||||
|
session.add(conv)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
msg = Message(
|
||||||
|
id='msg-1',
|
||||||
|
conversation_id=conv.id,
|
||||||
|
role='user',
|
||||||
|
content='{"text": "Hello world"}',
|
||||||
|
)
|
||||||
|
session.add(msg)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
assert msg.id == 'msg-1'
|
||||||
|
assert msg.conversation_id == conv.id
|
||||||
|
assert msg.role == 'user'
|
||||||
|
assert msg.content == '{"text": "Hello world"}'
|
||||||
|
assert msg.created_at is not None
|
||||||
|
assert msg.conversation == conv
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_to_dict(session):
|
||||||
|
"""Test message serialization."""
|
||||||
|
from backend.utils.helpers import message_to_dict
|
||||||
|
|
||||||
|
msg = Message(
|
||||||
|
id='msg-1',
|
||||||
|
conversation_id='conv-123',
|
||||||
|
role='user',
|
||||||
|
content='{"text": "Hello", "attachments": [{"name": "file.txt"}]}',
|
||||||
|
)
|
||||||
|
session.add(msg)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
data = message_to_dict(msg)
|
||||||
|
assert data['id'] == 'msg-1'
|
||||||
|
assert data['role'] == 'user'
|
||||||
|
assert data['text'] == 'Hello'
|
||||||
|
assert 'attachments' in data
|
||||||
|
assert data['attachments'][0]['name'] == 'file.txt'
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_usage_create(session):
|
||||||
|
"""Test token usage recording."""
|
||||||
|
user = User(username='user1')
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
usage = TokenUsage(
|
||||||
|
user_id=user.id,
|
||||||
|
model='deepseek-chat',
|
||||||
|
date=datetime.now(timezone.utc).date(),
|
||||||
|
prompt_tokens=100,
|
||||||
|
completion_tokens=200,
|
||||||
|
total_tokens=300,
|
||||||
|
)
|
||||||
|
session.add(usage)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
assert usage.id is not None
|
||||||
|
assert usage.user_id == user.id
|
||||||
|
assert usage.model == 'deepseek-chat'
|
||||||
|
assert usage.prompt_tokens == 100
|
||||||
|
assert usage.total_tokens == 300
|
||||||
|
|
||||||
|
|
||||||
|
def test_project_create(session):
|
||||||
|
"""Test project creation."""
|
||||||
|
user = User(username='user1')
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
project = Project(
|
||||||
|
id='proj-123',
|
||||||
|
user_id=user.id,
|
||||||
|
name='My Project',
|
||||||
|
path='user_1/my_project',
|
||||||
|
description='A test project',
|
||||||
|
)
|
||||||
|
session.add(project)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
assert project.id == 'proj-123'
|
||||||
|
assert project.user_id == user.id
|
||||||
|
assert project.name == 'My Project'
|
||||||
|
assert project.path == 'user_1/my_project'
|
||||||
|
assert project.description == 'A test project'
|
||||||
|
assert project.created_at is not None
|
||||||
|
assert project.updated_at is not None
|
||||||
|
assert project.user == user
|
||||||
|
assert project.conversations.count() == 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main(['-v', __file__])
|
||||||
|
|
@ -0,0 +1,342 @@
|
||||||
|
import pytest
|
||||||
|
from backend.tools.core import ToolRegistry, ToolDefinition, ToolResult
|
||||||
|
from backend.tools.executor import ToolExecutor
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_definition():
|
||||||
|
"""Test ToolDefinition creation and serialization."""
|
||||||
|
def dummy_handler(args):
|
||||||
|
return args.get('value', 0)
|
||||||
|
|
||||||
|
tool = ToolDefinition(
|
||||||
|
name='test_tool',
|
||||||
|
description='A test tool',
|
||||||
|
parameters={
|
||||||
|
'type': 'object',
|
||||||
|
'properties': {'value': {'type': 'integer'}}
|
||||||
|
},
|
||||||
|
handler=dummy_handler,
|
||||||
|
category='test'
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tool.name == 'test_tool'
|
||||||
|
assert tool.description == 'A test tool'
|
||||||
|
assert tool.category == 'test'
|
||||||
|
assert tool.handler == dummy_handler
|
||||||
|
|
||||||
|
# Test OpenAI format conversion
|
||||||
|
openai_format = tool.to_openai_format()
|
||||||
|
assert openai_format['type'] == 'function'
|
||||||
|
assert openai_format['function']['name'] == 'test_tool'
|
||||||
|
assert openai_format['function']['description'] == 'A test tool'
|
||||||
|
assert 'parameters' in openai_format['function']
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_result():
|
||||||
|
"""Test ToolResult creation."""
|
||||||
|
result = ToolResult.ok(data='success')
|
||||||
|
assert result.success is True
|
||||||
|
assert result.data == 'success'
|
||||||
|
assert result.error is None
|
||||||
|
|
||||||
|
result2 = ToolResult.fail(error='something went wrong')
|
||||||
|
assert result2.success is False
|
||||||
|
assert result2.error == 'something went wrong'
|
||||||
|
assert result2.data is None
|
||||||
|
|
||||||
|
# Test to_dict
|
||||||
|
dict_ok = result.to_dict()
|
||||||
|
assert dict_ok['success'] is True
|
||||||
|
assert dict_ok['data'] == 'success'
|
||||||
|
|
||||||
|
dict_fail = result2.to_dict()
|
||||||
|
assert dict_fail['success'] is False
|
||||||
|
assert dict_fail['error'] == 'something went wrong'
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_registry():
|
||||||
|
"""Test ToolRegistry registration and lookup."""
|
||||||
|
registry = ToolRegistry()
|
||||||
|
|
||||||
|
# Count existing tools
|
||||||
|
initial_tools = registry.list_all()
|
||||||
|
initial_count = len(initial_tools)
|
||||||
|
|
||||||
|
# Register a tool
|
||||||
|
def add_handler(args):
|
||||||
|
return args.get('a', 0) + args.get('b', 0)
|
||||||
|
|
||||||
|
tool = ToolDefinition(
|
||||||
|
name='add',
|
||||||
|
description='Add two numbers',
|
||||||
|
parameters={
|
||||||
|
'type': 'object',
|
||||||
|
'properties': {
|
||||||
|
'a': {'type': 'number'},
|
||||||
|
'b': {'type': 'number'}
|
||||||
|
},
|
||||||
|
'required': ['a', 'b']
|
||||||
|
},
|
||||||
|
handler=add_handler,
|
||||||
|
category='math'
|
||||||
|
)
|
||||||
|
registry.register(tool)
|
||||||
|
|
||||||
|
# Should be able to get it
|
||||||
|
retrieved = registry.get('add')
|
||||||
|
assert retrieved is not None
|
||||||
|
assert retrieved.name == 'add'
|
||||||
|
assert retrieved.handler == add_handler
|
||||||
|
|
||||||
|
# List all returns OpenAI format
|
||||||
|
tools_list = registry.list_all()
|
||||||
|
assert len(tools_list) == initial_count + 1
|
||||||
|
# Ensure our tool is present
|
||||||
|
tool_names = [t['function']['name'] for t in tools_list]
|
||||||
|
assert 'add' in tool_names
|
||||||
|
|
||||||
|
# Execute tool
|
||||||
|
result = registry.execute('add', {'a': 5, 'b': 3})
|
||||||
|
assert result['success'] is True
|
||||||
|
assert result['data'] == 8
|
||||||
|
|
||||||
|
# Execute non-existent tool
|
||||||
|
result = registry.execute('nonexistent', {})
|
||||||
|
assert result['success'] is False
|
||||||
|
assert 'Tool not found' in result['error']
|
||||||
|
|
||||||
|
# Execute with exception
|
||||||
|
def faulty_handler(args):
|
||||||
|
raise ValueError('Intentional error')
|
||||||
|
|
||||||
|
faulty_tool = ToolDefinition(
|
||||||
|
name='faulty',
|
||||||
|
description='Faulty tool',
|
||||||
|
parameters={'type': 'object'},
|
||||||
|
handler=faulty_handler
|
||||||
|
)
|
||||||
|
registry.register(faulty_tool)
|
||||||
|
result = registry.execute('faulty', {})
|
||||||
|
assert result['success'] is False
|
||||||
|
assert 'Intentional error' in result['error']
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_registry_singleton():
|
||||||
|
"""Test that ToolRegistry is a singleton."""
|
||||||
|
registry1 = ToolRegistry()
|
||||||
|
registry2 = ToolRegistry()
|
||||||
|
assert registry1 is registry2
|
||||||
|
|
||||||
|
# Register in one, should appear in the other
|
||||||
|
def dummy(args):
|
||||||
|
return 42
|
||||||
|
|
||||||
|
tool = ToolDefinition(
|
||||||
|
name='singleton_test',
|
||||||
|
description='Test',
|
||||||
|
parameters={'type': 'object'},
|
||||||
|
handler=dummy
|
||||||
|
)
|
||||||
|
registry1.register(tool)
|
||||||
|
assert registry2.get('singleton_test') is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_executor_basic():
|
||||||
|
"""Test ToolExecutor basic execution."""
|
||||||
|
registry = ToolRegistry()
|
||||||
|
# Clear any previous tools (singleton may have state from other tests)
|
||||||
|
# We'll create a fresh registry by directly manipulating the singleton's internal dict.
|
||||||
|
# This is a bit hacky but works for testing.
|
||||||
|
registry._tools.clear()
|
||||||
|
|
||||||
|
def echo_handler(args):
|
||||||
|
return args.get('message', '')
|
||||||
|
|
||||||
|
tool = ToolDefinition(
|
||||||
|
name='echo',
|
||||||
|
description='Echo message',
|
||||||
|
parameters={
|
||||||
|
'type': 'object',
|
||||||
|
'properties': {'message': {'type': 'string'}}
|
||||||
|
},
|
||||||
|
handler=echo_handler
|
||||||
|
)
|
||||||
|
registry.register(tool)
|
||||||
|
|
||||||
|
executor = ToolExecutor(registry=registry, enable_cache=False)
|
||||||
|
|
||||||
|
# Simulate a tool call
|
||||||
|
call = {
|
||||||
|
'id': 'call_1',
|
||||||
|
'function': {
|
||||||
|
'name': 'echo',
|
||||||
|
'arguments': json.dumps({'message': 'Hello'})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = executor.process_tool_calls([call], context=None)
|
||||||
|
assert len(messages) == 1
|
||||||
|
msg = messages[0]
|
||||||
|
assert msg['role'] == 'tool'
|
||||||
|
assert msg['tool_call_id'] == 'call_1'
|
||||||
|
assert msg['name'] == 'echo'
|
||||||
|
content = json.loads(msg['content'])
|
||||||
|
assert content['success'] is True
|
||||||
|
assert content['data'] == 'Hello'
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_executor_cache():
|
||||||
|
"""Test caching behavior."""
|
||||||
|
registry = ToolRegistry()
|
||||||
|
registry._tools.clear()
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
def counter_handler(args):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return call_count
|
||||||
|
|
||||||
|
tool = ToolDefinition(
|
||||||
|
name='counter',
|
||||||
|
description='Count calls',
|
||||||
|
parameters={'type': 'object'},
|
||||||
|
handler=counter_handler
|
||||||
|
)
|
||||||
|
registry.register(tool)
|
||||||
|
|
||||||
|
executor = ToolExecutor(registry=registry, enable_cache=True, cache_ttl=10)
|
||||||
|
|
||||||
|
call = {
|
||||||
|
'id': 'call_1',
|
||||||
|
'function': {
|
||||||
|
'name': 'counter',
|
||||||
|
'arguments': '{}'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# First call should execute
|
||||||
|
messages1 = executor.process_tool_calls([call], context=None)
|
||||||
|
assert len(messages1) == 1
|
||||||
|
content1 = json.loads(messages1[0]['content'])
|
||||||
|
assert content1['data'] == 1
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
# Second identical call should be cached
|
||||||
|
messages2 = executor.process_tool_calls([call], context=None)
|
||||||
|
assert len(messages2) == 1
|
||||||
|
content2 = json.loads(messages2[0]['content'])
|
||||||
|
# data should still be 1 (cached)
|
||||||
|
assert content2['data'] == 1
|
||||||
|
# handler not called again
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
# Different call (different arguments) should execute
|
||||||
|
call2 = {
|
||||||
|
'id': 'call_2',
|
||||||
|
'function': {
|
||||||
|
'name': 'counter',
|
||||||
|
'arguments': json.dumps({'different': True})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
messages3 = executor.process_tool_calls([call2], context=None)
|
||||||
|
content3 = json.loads(messages3[0]['content'])
|
||||||
|
assert content3['data'] == 2
|
||||||
|
assert call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_executor_context_injection():
|
||||||
|
"""Test that context fields are injected into arguments."""
|
||||||
|
registry = ToolRegistry()
|
||||||
|
registry._tools.clear()
|
||||||
|
|
||||||
|
captured_args = None
|
||||||
|
def capture_handler(args):
|
||||||
|
nonlocal captured_args
|
||||||
|
captured_args = args.copy()
|
||||||
|
return 'ok'
|
||||||
|
|
||||||
|
tool = ToolDefinition(
|
||||||
|
name='file_read',
|
||||||
|
description='Read file',
|
||||||
|
parameters={'type': 'object'},
|
||||||
|
handler=capture_handler
|
||||||
|
)
|
||||||
|
registry.register(tool)
|
||||||
|
|
||||||
|
executor = ToolExecutor(registry=registry)
|
||||||
|
|
||||||
|
call = {
|
||||||
|
'id': 'call_1',
|
||||||
|
'function': {
|
||||||
|
'name': 'file_read',
|
||||||
|
'arguments': json.dumps({'path': 'test.txt'})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
context = {'project_id': 'proj-123'}
|
||||||
|
executor.process_tool_calls([call], context=context)
|
||||||
|
|
||||||
|
# Check that project_id was injected
|
||||||
|
assert captured_args is not None
|
||||||
|
assert captured_args['project_id'] == 'proj-123'
|
||||||
|
assert captured_args['path'] == 'test.txt'
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_executor_deduplication():
|
||||||
|
"""Test deduplication of identical calls within a session."""
|
||||||
|
registry = ToolRegistry()
|
||||||
|
registry._tools.clear()
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
def count_handler(args):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return call_count
|
||||||
|
|
||||||
|
tool = ToolDefinition(
|
||||||
|
name='count',
|
||||||
|
description='Count',
|
||||||
|
parameters={'type': 'object'},
|
||||||
|
handler=count_handler
|
||||||
|
)
|
||||||
|
registry.register(tool)
|
||||||
|
|
||||||
|
executor = ToolExecutor(registry=registry, enable_cache=False)
|
||||||
|
|
||||||
|
call = {
|
||||||
|
'id': 'call_1',
|
||||||
|
'function': {
|
||||||
|
'name': 'count',
|
||||||
|
'arguments': '{}'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
call_same = {
|
||||||
|
'id': 'call_2',
|
||||||
|
'function': {
|
||||||
|
'name': 'count',
|
||||||
|
'arguments': '{}'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Execute both calls in one batch
|
||||||
|
messages = executor.process_tool_calls([call, call_same], context=None)
|
||||||
|
# Should deduplicate: second call returns cached result from first call
|
||||||
|
# Let's verify that call_count is 1 (only one actual execution).
|
||||||
|
assert call_count == 1
|
||||||
|
# Both messages should have success=True
|
||||||
|
assert len(messages) == 2
|
||||||
|
content0 = json.loads(messages[0]['content'])
|
||||||
|
content1 = json.loads(messages[1]['content'])
|
||||||
|
assert content0['success'] is True
|
||||||
|
assert content1['success'] is True
|
||||||
|
# Data could be 1 for both (duplicate may have data None)
|
||||||
|
assert content0['data'] == 1
|
||||||
|
# duplicate call may have data None, but should be successful and cached
|
||||||
|
assert content1['success'] is True
|
||||||
|
assert content1.get('cached') is True
|
||||||
|
assert content1.get('data') in (1, None)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main(['-v', __file__])
|
||||||
Loading…
Reference in New Issue