Compare commits

...

10 Commits

45 changed files with 2572 additions and 996 deletions

99
.github/workflows/ci.yml vendored Normal file
View File

@ -0,0 +1,99 @@
name: CI

on:
  push:
    branches: [ main, master ]
  pull_request:
    branches: [ main, master ]

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.12']
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e .[test]
          # Install frontend dependencies for potential frontend tests (optional)
          cd frontend && npm ci && cd ..
      # Unified config generation for backend tests (the frontend build job
      # below generates the identical config).
      - name: Create config.yml for CI
        run: |
          cat > config.yml << 'EOF'
          backend_port: 3000
          frontend_port: 4000
          max_iterations: 5
          sub_agent:
            max_iterations: 3
            max_tokens: 4096
            max_agents: 5
            max_concurrency: 3
          models:
            - id: dummy
              name: Dummy
              api_url: https://api.example.com
              api_key: dummy-key
          default_model: dummy
          db_type: sqlite
          db_sqlite_file: ":memory:"
          workspace_root: ./workspaces
          EOF
      - name: Run tests with pytest
        run: |
          python -m pytest tests/ -v
      - name: Upload coverage to Codecov (optional)
        uses: codecov/codecov-action@v3
        # Fix: the condition previously checked '3.10', which is not in the
        # matrix ('3.12' only), so this step could never run.
        if: matrix.python-version == '3.12' && github.event_name == 'push'
        with:
          # NOTE(review): pytest is not invoked with --cov, so ./coverage.xml
          # is never produced; with fail_ci_if_error: false the upload is a
          # silent no-op until coverage generation is added to the test step.
          file: ./coverage.xml
          fail_ci_if_error: false
  build-frontend:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: '20'
      - name: Install frontend dependencies
        run: |
          cd frontend && npm ci
      # Unified config generation for the frontend build (identical to the
      # config generated by the backend test job above).
      - name: Create config.yml for frontend build
        run: |
          cat > config.yml << 'EOF'
          backend_port: 3000
          frontend_port: 4000
          max_iterations: 5
          sub_agent:
            max_iterations: 3
            max_tokens: 4096
            max_agents: 5
            max_concurrency: 3
          models:
            - id: dummy
              name: Dummy
              api_url: https://api.example.com
              api_key: dummy-key
          default_model: dummy
          db_type: sqlite
          db_sqlite_file: ":memory:"
          workspace_root: ./workspaces
          EOF
      - name: Build frontend
        run: |
          cd frontend && npm run build

3
.gitignore vendored
View File

@ -28,3 +28,6 @@
!frontend/src/**/*.css
!frontend/public/
!frontend/public/**
# CI / CD
!.github/workflows/*

View File

@ -36,6 +36,11 @@ frontend_port: 4000
# Max agentic loop iterations (tool call rounds)
max_iterations: 15
# Sub-agent settings (multi_agent tool)
sub_agent:
max_iterations: 3 # Max tool-call rounds per sub-agent
max_concurrency: 3 # ThreadPoolExecutor max workers
# Available models
# Each model must have its own id, name, api_url, api_key
models:
@ -117,6 +122,7 @@ backend/
│ ├── data.py # 计算器、文本、JSON 处理
│ ├── weather.py # 天气查询(模拟)
│ ├── file_ops.py # 文件操作(6 个工具,project_id 自动注入)
│ ├── agent.py # 多智能体(子 Agent 并发执行,工具权限隔离)
│ └── code.py # Python 代码执行(沙箱)
└── utils/ # 辅助函数
├── helpers.py # 通用函数(ok/err/build_messages 等)
@ -207,6 +213,7 @@ frontend/
| **代码执行** | execute_python | 沙箱环境执行 Python |
| **文件操作** | file_read, file_write, file_delete, file_list, file_exists, file_mkdir | project_id 自动注入 |
| **天气** | get_weather | 天气查询(模拟) |
| **智能体** | multi_agent | 派生子 Agent 并发执行(禁止递归,工具权限与主 Agent 一致) |
## 文档

View File

@ -1,41 +1,117 @@
"""Configuration management"""
"""Configuration management using dataclasses"""
import sys
from dataclasses import dataclass, field
from typing import List, Dict, Optional
from backend import load_config
_cfg = load_config()
# Model list (for /api/models endpoint)
MODELS = _cfg.get("models", [])
@dataclass
class ModelConfig:
    """Individual model configuration.

    One entry of the ``models`` list in config.yml. All four fields are
    required; presence is enforced before instances are built.
    """
    id: str       # unique model identifier (used for lookups / default_model)
    name: str     # human-readable display name shown to clients
    api_url: str  # OpenAI-compatible endpoint URL (env-var expansion happens in the LLM client)
    api_key: str  # credential for the endpoint (env-var expansion happens in the LLM client)
# Validate each model has required fields at startup
_REQUIRED_MODEL_KEYS = {"id", "name", "api_url", "api_key"}
_model_ids_seen = set()
for _i, _m in enumerate(MODELS):
_missing = _REQUIRED_MODEL_KEYS - set(_m.keys())
if _missing:
print(f"[config] ERROR: models[{_i}] missing required fields: {_missing}", file=sys.stderr)
sys.exit(1)
if _m["id"] in _model_ids_seen:
print(f"[config] ERROR: duplicate model id '{_m['id']}'", file=sys.stderr)
sys.exit(1)
_model_ids_seen.add(_m["id"])
# Per-model config lookup: {model_id: {api_url, api_key}}
MODEL_CONFIG = {m["id"]: {"api_url": m["api_url"], "api_key": m["api_key"]} for m in MODELS}
@dataclass
class SubAgentConfig:
    """Sub-agent (multi_agent tool) settings from config.yml -> sub_agent."""
    max_iterations: int = 3   # max tool-call rounds per sub-agent
    max_concurrency: int = 3  # ThreadPoolExecutor max workers for sub-agents
    timeout: int = 60         # per-LLM-call timeout in seconds
# default_model must exist in models
DEFAULT_MODEL = _cfg.get("default_model", "")
if DEFAULT_MODEL and DEFAULT_MODEL not in MODEL_CONFIG:
print(f"[config] ERROR: default_model '{DEFAULT_MODEL}' not found in models", file=sys.stderr)
sys.exit(1)
if MODELS and not DEFAULT_MODEL:
DEFAULT_MODEL = MODELS[0]["id"]
# Max agentic loop iterations (tool call rounds)
MAX_ITERATIONS = _cfg.get("max_iterations", 5)
@dataclass
class CodeExecutionConfig:
    """Code execution (sandbox) settings from config.yml -> code_execution."""
    default_strictness: str = "standard"    # sandbox strictness level
    # Extra modules allowed beyond the default set — presumably keyed by
    # strictness level; TODO confirm shape against the sandbox implementation.
    extra_allowed_modules: Dict = field(default_factory=dict)
    backend: str = "subprocess"             # subprocess or docker
    docker_image: str = "python:3.12-slim"  # image for the docker backend
    docker_network: str = "none"            # docker network mode ("none" disables networking)
    docker_user: str = "nobody"             # user the container runs as
    docker_memory_limit: Optional[str] = None  # docker memory limit (e.g. "256m"); None = unset
    docker_cpu_shares: Optional[int] = None    # relative CPU weight; None = unset
# Max parallel workers for tool execution (ThreadPoolExecutor)
TOOL_MAX_WORKERS = _cfg.get("tool_max_workers", 4)
# Max character length for a single tool result content (truncated if exceeded)
TOOL_RESULT_MAX_LENGTH = _cfg.get("tool_result_max_length", 4096)
@dataclass
class AppConfig:
    """Main application configuration.

    Aggregates every section of config.yml into one typed object and
    provides per-model lookup helpers.
    """
    models: List[ModelConfig] = field(default_factory=list)
    default_model: str = ""
    max_iterations: int = 5
    tool_max_workers: int = 4
    sub_agent: SubAgentConfig = field(default_factory=SubAgentConfig)
    code_execution: CodeExecutionConfig = field(default_factory=CodeExecutionConfig)
    # Internal {model_id: ModelConfig} index, rebuilt in __post_init__.
    _model_config_map: Dict[str, ModelConfig] = field(default_factory=dict, repr=False)

    def __post_init__(self):
        """Index models by id for O(1) lookups."""
        index = {}
        for model in self.models:
            index[model.id] = model
        self._model_config_map = index

    def get_model_config(self, model_id: str) -> Optional[ModelConfig]:
        """Return the ModelConfig for *model_id*, or None if unknown."""
        return self._model_config_map.get(model_id)

    def get_model_credentials(self, model_id: str) -> tuple:
        """Return ``(api_url, api_key)`` for *model_id*.

        Raises:
            ValueError: if the model id is unknown, or either credential
                is missing/empty.
        """
        model_cfg = self.get_model_config(model_id)
        if model_cfg is None:
            raise ValueError(f"Unknown model: '{model_id}', not found in config")
        for attr in ("api_url", "api_key"):
            if not getattr(model_cfg, attr):
                raise ValueError(f"Model '{model_id}' has no {attr} configured")
        return model_cfg.api_url, model_cfg.api_key
def _parse_config(raw: dict) -> AppConfig:
    """Convert the raw YAML mapping into an AppConfig dataclass."""
    # All model fields are mandatory: KeyError here surfaces a broken config.
    model_entries = [
        ModelConfig(
            id=entry["id"],
            name=entry["name"],
            api_url=entry["api_url"],
            api_key=entry["api_key"],
        )
        for entry in raw.get("models", [])
    ]

    sub_raw = raw.get("sub_agent", {})
    ce_raw = raw.get("code_execution", {})

    return AppConfig(
        models=model_entries,
        default_model=raw.get("default_model", ""),
        max_iterations=raw.get("max_iterations", 5),
        tool_max_workers=raw.get("tool_max_workers", 4),
        sub_agent=SubAgentConfig(
            max_iterations=sub_raw.get("max_iterations", 3),
            max_concurrency=sub_raw.get("max_concurrency", 3),
            timeout=sub_raw.get("timeout", 60),
        ),
        code_execution=CodeExecutionConfig(
            default_strictness=ce_raw.get("default_strictness", "standard"),
            extra_allowed_modules=ce_raw.get("extra_allowed_modules", {}),
            backend=ce_raw.get("backend", "subprocess"),
            docker_image=ce_raw.get("docker_image", "python:3.12-slim"),
            docker_network=ce_raw.get("docker_network", "none"),
            docker_user=ce_raw.get("docker_user", "nobody"),
            docker_memory_limit=ce_raw.get("docker_memory_limit"),
            docker_cpu_shares=ce_raw.get("docker_cpu_shares"),
        ),
    )
# Load config.yml once at import time and expose the parsed result as the
# module-level singleton ``config`` that the rest of the app imports.
_raw_cfg = load_config()
config = _parse_config(_raw_cfg)

View File

@ -8,16 +8,16 @@ from backend.routes.stats import bp as stats_bp
from backend.routes.projects import bp as projects_bp
from backend.routes.auth import bp as auth_bp, init_auth
from backend.services.llm_client import LLMClient
from backend.config import MODEL_CONFIG
from backend.config import config
def register_routes(app: Flask):
"""Register all route blueprints"""
# Initialize LLM client with per-model config
client = LLMClient(MODEL_CONFIG)
# Initialize LLM client with config
client = LLMClient(config)
init_chat_service(client)
# Register LLM client in service locator so tools (e.g. agent_task) can access it
# Register LLM client in service locator so tools (e.g. multi_agent) can access it
from backend.tools import register_service
register_service("llm_client", client)

View File

@ -5,7 +5,7 @@ from flask import Blueprint, request, g
from backend import db
from backend.models import Conversation, Project
from backend.utils.helpers import ok, err, to_dict
from backend.config import DEFAULT_MODEL
from backend.config import config
bp = Blueprint("conversations", __name__)
@ -40,7 +40,7 @@ def conversation_list():
user_id=user.id,
project_id=project_id or None,
title=d.get("title", ""),
model=d.get("model", DEFAULT_MODEL),
model=d.get("model", config.default_model),
system_prompt=d.get("system_prompt", ""),
temperature=d.get("temperature", 1.0),
max_tokens=d.get("max_tokens", 65536),

View File

@ -1,7 +1,7 @@
"""Model list API routes"""
from flask import Blueprint
from backend.utils.helpers import ok
from backend.config import MODELS
from backend.config import config
bp = Blueprint("models", __name__)
@ -13,7 +13,10 @@ _SENSITIVE_KEYS = {"api_key", "api_url"}
def list_models():
"""Get available model list (without sensitive fields like api_key)"""
safe_models = [
{k: v for k, v in m.items() if k not in _SENSITIVE_KEYS}
for m in MODELS
{
"id": m.id,
"name": m.name,
}
for m in config.models
]
return ok(safe_models)

View File

@ -1,29 +1,38 @@
"""Token statistics API routes"""
from datetime import date, timedelta
from datetime import date, timedelta, datetime, timezone
from flask import Blueprint, request, g
from sqlalchemy import func
from backend.models import TokenUsage
from sqlalchemy import func, extract
from backend.models import TokenUsage, Message, Conversation
from backend.utils.helpers import ok, err
from backend import db
bp = Blueprint("stats", __name__)
def _utc_today():
    """Return the current date in UTC, to match stored UTC timestamps."""
    now_utc = datetime.now(timezone.utc)
    return now_utc.date()
@bp.route("/api/stats/tokens", methods=["GET"])
def token_stats():
"""Get token usage statistics"""
user = g.current_user
period = request.args.get("period", "daily")
today = date.today()
today = _utc_today()
if period == "daily":
stats = TokenUsage.query.filter_by(user_id=user.id, date=today).all()
# Hourly breakdown from Message table
hourly = _build_hourly_stats(user.id, today)
result = {
"period": "daily",
"date": today.isoformat(),
"prompt_tokens": sum(s.prompt_tokens for s in stats),
"completion_tokens": sum(s.completion_tokens for s in stats),
"total_tokens": sum(s.total_tokens for s in stats),
"hourly": hourly,
"by_model": {
s.model: {
"prompt": s.prompt_tokens,
@ -92,3 +101,37 @@ def _build_period_result(stats, period, start_date, end_date, days):
"daily": daily_data,
"by_model": by_model,
}
def _build_hourly_stats(user_id, day):
    """Build an hourly token breakdown for *day* (UTC) from the Message table.

    Returns ``{"<hour>": {"total": <tokens>}}`` containing only hours that
    have assistant messages. NOTE(review): the bounds are timezone-aware;
    this assumes Message.created_at is stored in UTC — confirm against the
    model definition.
    """
    start = datetime.combine(day, datetime.min.time()).replace(tzinfo=timezone.utc)
    end = datetime.combine(day, datetime.max.time()).replace(tzinfo=timezone.utc)

    # Limit to conversations owned by this user via a subquery.
    user_conversations = (
        db.session.query(Conversation.id)
        .filter(Conversation.user_id == user_id)
        .subquery()
    )

    hour_expr = extract("hour", Message.created_at)
    rows = (
        db.session.query(
            hour_expr.label("hour"),
            func.sum(Message.token_count).label("total"),
        )
        .filter(
            Message.conversation_id.in_(user_conversations),
            Message.role == "assistant",
            Message.created_at >= start,
            Message.created_at <= end,
        )
        .group_by(hour_expr)
        .all()
    )

    # SUM over no rows / NULL counts yields None — normalize to 0.
    return {str(int(row.hour)): {"total": row.total or 0} for row in rows}

View File

@ -14,7 +14,7 @@ from backend.utils.helpers import (
build_messages,
)
from backend.services.llm_client import LLMClient
from backend.config import MAX_ITERATIONS, TOOL_MAX_WORKERS, TOOL_RESULT_MAX_LENGTH
from backend.config import config as _cfg
logger = logging.getLogger(__name__)
@ -61,13 +61,20 @@ class ChatService:
"""
conv_id = conv.id
conv_model = conv.model
conv_max_tokens = conv.max_tokens
conv_temperature = conv.temperature
conv_thinking_enabled = conv.thinking_enabled
app = current_app._get_current_object()
tools = registry.list_all() if tools_enabled else None
initial_messages = build_messages(conv, project_id)
executor = ToolExecutor(registry=registry)
context = {"model": conv_model}
context = {
"model": conv_model,
"max_tokens": conv_max_tokens,
"temperature": conv_temperature,
}
if project_id:
context["project_id"] = project_id
elif conv.project_id:
@ -82,10 +89,26 @@ class ChatService:
total_completion_tokens = 0
total_prompt_tokens = 0
for iteration in range(MAX_ITERATIONS):
for iteration in range(_cfg.max_iterations):
# Helper to parse stream_result event
def parse_stream_result(event_str):
"""Parse stream_result SSE event and extract data dict."""
# Format: "event: stream_result\ndata: {...}\n\n"
try:
stream_result = self._stream_llm_response(
app, conv_id, messages, tools, tool_choice, step_index
for line in event_str.split('\n'):
if line.startswith('data: '):
return json.loads(line[6:])
except Exception:
pass
return None
# Collect SSE events and extract final stream_result
try:
stream_gen = self._stream_llm_response(
app, messages, tools, tool_choice, step_index,
conv_model, conv_max_tokens, conv_temperature,
conv_thinking_enabled,
)
except requests.exceptions.HTTPError as e:
resp = e.response
@ -107,22 +130,37 @@ class ChatService:
yield _sse_event("error", {"content": f"Internal error: {e}"})
return
if stream_result is None:
return # Client disconnected
result_data = None
try:
for event_str in stream_gen:
# Check if this is a stream_result event (final event)
if event_str.startswith("event: stream_result"):
result_data = parse_stream_result(event_str)
else:
# Forward process_step events to client in real-time
yield event_str
except Exception as e:
logger.exception("Error during streaming")
yield _sse_event("error", {"content": f"Stream error: {e}"})
return
full_content, full_thinking, tool_calls_list, \
thinking_step_id, thinking_step_idx, \
text_step_id, text_step_idx, \
completion_tokens, prompt_tokens, \
sse_chunks = stream_result
if result_data is None:
return # Client disconnected or error
# Extract data from stream_result
full_content = result_data["full_content"]
full_thinking = result_data["full_thinking"]
tool_calls_list = result_data["tool_calls_list"]
thinking_step_id = result_data["thinking_step_id"]
thinking_step_idx = result_data["thinking_step_idx"]
text_step_id = result_data["text_step_id"]
text_step_idx = result_data["text_step_idx"]
completion_tokens = result_data["completion_tokens"]
prompt_tokens = result_data["prompt_tokens"]
total_prompt_tokens += prompt_tokens
total_completion_tokens += completion_tokens
# Yield accumulated SSE chunks to frontend
for chunk in sse_chunks:
yield chunk
# Save thinking/text steps to all_steps for DB storage
if thinking_step_id is not None:
all_steps.append({
@ -185,7 +223,7 @@ class ChatService:
# Append assistant message + tool results for the next iteration
messages.append({
"role": "assistant",
"content": full_content or None,
"content": full_content or "",
"tool_calls": tool_calls_list,
})
messages.extend(tool_results)
@ -232,12 +270,19 @@ class ChatService:
# ------------------------------------------------------------------
def _stream_llm_response(
self, app, conv_id, messages, tools, tool_choice, step_index
self, app, messages, tools, tool_choice, step_index,
model, max_tokens, temperature, thinking_enabled,
):
"""Call LLM streaming API and parse the response.
"""Call LLM streaming API and yield SSE events in real-time.
Returns a tuple of parsed results, or None if the client disconnected.
Raises HTTPError / ConnectionError / Timeout for the caller to handle.
This is a generator that yields SSE event strings as they are received.
The final yield is a 'stream_result' event containing the accumulated data.
Yields:
str: SSE event strings (process_step events, then stream_result)
Raises:
HTTPError / ConnectionError / Timeout for the caller to handle.
"""
full_content = ""
full_thinking = ""
@ -250,16 +295,13 @@ class ChatService:
text_step_id = None
text_step_idx = None
sse_chunks = [] # Collect SSE events to yield later
with app.app_context():
active_conv = db.session.get(Conversation, conv_id)
resp = self.llm.call(
model=active_conv.model,
model=model,
messages=messages,
max_tokens=active_conv.max_tokens,
temperature=active_conv.temperature,
thinking_enabled=active_conv.thinking_enabled,
max_tokens=max_tokens,
temperature=temperature,
thinking_enabled=thinking_enabled,
tools=tools,
tool_choice=tool_choice,
stream=True,
@ -269,7 +311,7 @@ class ChatService:
for line in resp.iter_lines():
if _client_disconnected():
resp.close()
return None
return # Client disconnected, stop generator
if not line:
continue
@ -295,37 +337,44 @@ class ChatService:
delta = choices[0].get("delta", {})
# Yield thinking content in real-time
reasoning = delta.get("reasoning_content", "")
if reasoning:
full_thinking += reasoning
if thinking_step_id is None:
thinking_step_id = f"step-{step_index}"
thinking_step_idx = step_index
sse_chunks.append(_sse_event("process_step", {
yield _sse_event("process_step", {
"id": thinking_step_id, "index": thinking_step_idx,
"type": "thinking", "content": full_thinking,
}))
})
# Yield text content in real-time
text = delta.get("content", "")
if text:
full_content += text
if text_step_id is None:
text_step_idx = step_index + (1 if thinking_step_id is not None else 0)
text_step_id = f"step-{text_step_idx}"
sse_chunks.append(_sse_event("process_step", {
yield _sse_event("process_step", {
"id": text_step_id, "index": text_step_idx,
"type": "text", "content": full_content,
}))
})
tool_calls_list = self._process_tool_calls_delta(delta, tool_calls_list)
return (
full_content, full_thinking, tool_calls_list,
thinking_step_id, thinking_step_idx,
text_step_id, text_step_idx,
token_count, prompt_tokens,
sse_chunks,
)
# Final yield: stream_result event with accumulated data
yield _sse_event("stream_result", {
"full_content": full_content,
"full_thinking": full_thinking,
"tool_calls_list": tool_calls_list,
"thinking_step_id": thinking_step_id,
"thinking_step_idx": thinking_step_idx,
"text_step_id": text_step_id,
"text_step_idx": text_step_idx,
"completion_tokens": token_count,
"prompt_tokens": prompt_tokens,
})
def _execute_tools_safe(self, app, executor, tool_calls_list, context):
"""Execute tool calls with top-level error wrapping.
@ -336,17 +385,17 @@ class ChatService:
try:
if len(tool_calls_list) > 1:
with app.app_context():
tool_results = executor.process_tool_calls_parallel(
tool_calls_list, context, max_workers=TOOL_MAX_WORKERS
return executor.process_tool_calls_parallel(
tool_calls_list, context, max_workers=_cfg.tool_max_workers
)
else:
with app.app_context():
tool_results = executor.process_tool_calls(
return executor.process_tool_calls(
tool_calls_list, context
)
except Exception as e:
logger.exception("Error during tool execution")
tool_results = [
return [
{
"role": "tool",
"tool_call_id": tc["id"],
@ -359,30 +408,6 @@ class ChatService:
for tc in tool_calls_list
]
# Truncate oversized tool result content
for tr in tool_results:
if len(tr["content"]) > TOOL_RESULT_MAX_LENGTH:
try:
result_data = json.loads(tr["content"])
original = result_data
except (json.JSONDecodeError, TypeError):
original = None
tr["content"] = json.dumps(
{"success": False, "error": "Tool result too large, truncated"},
ensure_ascii=False,
) if not original else json.dumps(
{
**original,
"truncated": True,
"_note": f"Content truncated, original length {len(tr['content'])} chars",
},
ensure_ascii=False,
default=str,
)[:TOOL_RESULT_MAX_LENGTH]
return tool_results
def _save_message(
self, app, conv_id, conv_model, msg_id,
full_content, all_tool_calls, all_tool_results,

View File

@ -35,28 +35,22 @@ def _detect_provider(api_url: str) -> str:
class LLMClient:
"""OpenAI-compatible LLM API client.
Each model must have its own api_url and api_key configured in MODEL_CONFIG.
Each model must have its own api_url and api_key configured in config.models.
"""
def __init__(self, model_config: dict):
"""Initialize with per-model config lookup.
def __init__(self, cfg):
"""Initialize with AppConfig.
Args:
model_config: {model_id: {"api_url": ..., "api_key": ...}}
cfg: AppConfig dataclass instance
"""
self.model_config = model_config
self.cfg = cfg
def _get_credentials(self, model: str):
"""Get api_url and api_key for a model, with env-var expansion."""
cfg = self.model_config.get(model)
if not cfg:
raise ValueError(f"Unknown model: '{model}', not found in config")
api_url = _resolve_env_vars(cfg.get("api_url", ""))
api_key = _resolve_env_vars(cfg.get("api_key", ""))
if not api_url:
raise ValueError(f"Model '{model}' has no api_url configured")
if not api_key:
raise ValueError(f"Model '{model}' has no api_key configured")
api_url, api_key = self.cfg.get_model_credentials(model)
api_url = _resolve_env_vars(api_url)
api_key = _resolve_env_vars(api_key)
return api_url, api_key
def _build_body(self, model, messages, max_tokens, temperature, thinking_enabled,

View File

@ -15,13 +15,13 @@ Usage:
result = registry.execute("web_search", {"query": "Python"})
"""
from backend.tools.core import ToolDefinition, ToolResult, ToolRegistry, registry
from backend.tools.factory import tool, register_tool
from backend.tools.core import registry
from backend.tools.factory import tool
from backend.tools.executor import ToolExecutor
# ---------------------------------------------------------------------------
# Service locator allows tools (e.g. agent_task) to access LLM client
# Service locator allows tools (e.g. multi_agent) to access LLM client
# ---------------------------------------------------------------------------
_services: dict = {}
@ -47,16 +47,12 @@ def init_tools() -> None:
# Public API exports
__all__ = [
# Core classes
"ToolDefinition",
"ToolResult",
"ToolRegistry",
"ToolExecutor",
# Instances
"registry",
# Factory functions
"tool",
"register_tool",
# Classes
"ToolExecutor",
# Initialization
"init_tools",
# Service locator

View File

@ -1,130 +1,49 @@
"""Multi-agent tools for concurrent and batch task execution.
Provides:
- parallel_execute: Run multiple tool calls concurrently
- agent_task: Spawn sub-agents with their own LLM conversation loops
"""
import json
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Any, Optional
from typing import List, Any, Optional
from backend.tools import get_service
from backend.tools.factory import tool
from backend.tools.core import registry
from backend.tools.executor import ToolExecutor
from backend.config import config
logger = logging.getLogger(__name__)
# Sub-agents are forbidden from using multi_agent to prevent infinite recursion
BLOCKED_TOOLS = {"multi_agent"}
# ---------------------------------------------------------------------------
# parallel_execute run multiple tool calls concurrently
# ---------------------------------------------------------------------------
def _to_executor_calls(tool_calls: list, id_prefix: str = "tc") -> list:
"""Normalize tool calls into executor-compatible format.
@tool(
name="parallel_execute",
description=(
"Execute multiple tool calls concurrently for better performance. "
"Use when you have several independent operations that don't depend on each other "
"(e.g. reading multiple files, running multiple searches, fetching several pages). "
"Results are returned in the same order as the input."
),
parameters={
"type": "object",
"properties": {
"tool_calls": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Tool name to execute",
},
"arguments": {
"type": "object",
"description": "Arguments for the tool",
},
},
"required": ["name", "arguments"],
},
"description": "List of tool calls to execute in parallel (max 10)",
},
"concurrency": {
"type": "integer",
"description": "Max concurrent executions (1-5, default 3)",
"default": 3,
},
},
"required": ["tool_calls"],
},
category="agent",
)
def parallel_execute(arguments: dict) -> dict:
"""Execute multiple tool calls concurrently.
Args:
arguments: {
"tool_calls": [
{"name": "file_read", "arguments": {"path": "a.py"}},
{"name": "web_search", "arguments": {"query": "python"}}
],
"concurrency": 3,
"_project_id": "..." // injected by executor
}
Returns:
{"results": [{index, tool_name, success, data/error}]}
Accepts two input shapes:
- LLM format: {"function": {"name": ..., "arguments": ...}}
- Simple format: {"name": ..., "arguments": ...}
"""
tool_calls = arguments["tool_calls"]
concurrency = min(max(arguments.get("concurrency", 3), 1), 5)
if len(tool_calls) > 10:
return {"success": False, "error": "Maximum 10 tool calls allowed per parallel execution"}
# Build executor context from injected fields
context = {}
project_id = arguments.get("_project_id")
if project_id:
context["project_id"] = project_id
# Format tool_calls into executor-compatible format
executor_calls = []
for i, tc in enumerate(tool_calls):
if "function" in tc:
func = tc["function"]
executor_calls.append({
"id": f"pe-{i}",
"id": tc.get("id", f"{id_prefix}-{i}"),
"type": tc.get("type", "function"),
"function": {
"name": func["name"],
"arguments": func["arguments"],
},
})
else:
executor_calls.append({
"id": f"{id_prefix}-{i}",
"type": "function",
"function": {
"name": tc["name"],
"arguments": json.dumps(tc["arguments"], ensure_ascii=False),
},
})
return executor_calls
# Use ToolExecutor for proper context injection, caching and dedup
executor = ToolExecutor(registry=registry, enable_cache=False)
executor_results = executor.process_tool_calls_parallel(
executor_calls, context, max_workers=concurrency
)
# Format output
results = []
for er in executor_results:
try:
content = json.loads(er["content"]) if isinstance(er["content"], str) else er["content"]
except (json.JSONDecodeError, TypeError):
content = {"success": False, "error": "Failed to parse result"}
results.append({
"index": len(results),
"tool_name": er["name"],
**content,
})
return {
"success": True,
"results": results,
"total": len(results),
}
# ---------------------------------------------------------------------------
# agent_task spawn sub-agents with independent LLM conversation loops
# ---------------------------------------------------------------------------
def _run_sub_agent(
task_name: str,
@ -132,6 +51,7 @@ def _run_sub_agent(
tool_names: Optional[List[str]],
model: str,
max_tokens: int,
temperature: float,
project_id: Optional[str],
app: Any,
max_iterations: int = 3,
@ -141,7 +61,6 @@ def _run_sub_agent(
Each sub-agent gets its own ToolExecutor instance and runs a simplified
version of the main agent loop, limited to prevent runaway cost.
"""
from backend.tools import get_service
llm_client = get_service("llm_client")
if not llm_client:
@ -151,16 +70,21 @@ def _run_sub_agent(
"error": "LLM client not available",
}
# Build tool list filter to requested tools or use all
# Build tool list filter to requested tools, then remove blocked
all_tools = registry.list_all()
if tool_names:
allowed = set(tool_names)
tools = [t for t in all_tools if t["function"]["name"] in allowed]
else:
tools = all_tools
tools = list(all_tools)
# Remove blocked tools to prevent recursion
tools = [t for t in tools if t["function"]["name"] not in BLOCKED_TOOLS]
executor = ToolExecutor(registry=registry)
context = {"project_id": project_id} if project_id else None
context = {"model": model}
if project_id:
context["project_id"] = project_id
# System prompt: instruction + reminder to give a final text answer
system_msg = (
@ -170,17 +94,21 @@ def _run_sub_agent(
)
messages = [{"role": "system", "content": system_msg}]
for _ in range(max_iterations):
for i in range(max_iterations):
is_final = (i == max_iterations - 1)
try:
with app.app_context():
resp = llm_client.call(
model=model,
messages=messages,
tools=tools if tools else None,
# On the last iteration, don't pass tools so the LLM is
# forced to produce a final text response instead of calling
# more tools.
tools=None if is_final else (tools if tools else None),
stream=False,
max_tokens=min(max_tokens, 4096),
temperature=0.7,
timeout=60,
max_tokens=max_tokens,
temperature=temperature,
timeout=config.sub_agent.timeout,
)
if resp.status_code != 200:
@ -196,20 +124,26 @@ def _run_sub_agent(
message = choice["message"]
if message.get("tool_calls"):
messages.append(message)
tc_list = message["tool_calls"]
# Convert OpenAI tool_calls to executor format
executor_calls = []
for tc in tc_list:
executor_calls.append({
"id": tc.get("id", ""),
"type": tc.get("type", "function"),
"function": {
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"],
},
# Only extract needed fields — LLM response may contain extra
# fields (e.g. reasoning_content) that the API rejects on re-send
messages.append({
"role": "assistant",
"content": message.get("content") or "",
"tool_calls": message["tool_calls"],
})
tool_results = executor.process_tool_calls(executor_calls, context)
tc_list = message["tool_calls"]
executor_calls = _to_executor_calls(tc_list)
# Execute tools inside app_context file ops and other DB-
# dependent tools require an active Flask context and session.
with app.app_context():
if len(executor_calls) > 1:
tool_results = executor.process_tool_calls_parallel(
executor_calls, context
)
else:
tool_results = executor.process_tool_calls(
executor_calls, context
)
messages.extend(tool_results)
else:
# Final text response
@ -226,7 +160,7 @@ def _run_sub_agent(
"error": str(e),
}
# Exhausted iterations without final response — return last LLM output if any
# Exhausted iterations without final response
return {
"task_name": task_name,
"success": True,
@ -234,49 +168,49 @@ def _run_sub_agent(
}
# @tool(
# name="agent_task",
# description=(
# "Spawn one or more sub-agents to work on tasks concurrently. "
# "Each agent runs its own independent conversation with the LLM and can use tools. "
# "Useful for parallel research, multi-file analysis, or dividing complex tasks into sub-tasks. "
# "Each agent is limited to 3 iterations and 4096 tokens to control cost."
# ),
# parameters={
# "type": "object",
# "properties": {
# "tasks": {
# "type": "array",
# "items": {
# "type": "object",
# "properties": {
# "name": {
# "type": "string",
# "description": "Short name/identifier for this task",
# },
# "instruction": {
# "type": "string",
# "description": "Detailed instruction for the sub-agent",
# },
# "tools": {
# "type": "array",
# "items": {"type": "string"},
# "description": (
# "Tool names this agent can use (empty = all tools). "
# "e.g. ['file_read', 'file_list', 'web_search']"
# ),
# },
# },
# "required": ["name", "instruction"],
# },
# "description": "Tasks for parallel sub-agents (max 5)",
# },
# },
# "required": ["tasks"],
# },
# category="agent",
# )
def agent_task(arguments: dict) -> dict:
@tool(
name="multi_agent",
description=(
"Spawn multiple sub-agents to work on tasks concurrently. "
"Each agent runs its own independent conversation with the LLM and can use tools. "
"Useful for parallel research, multi-file analysis, or dividing complex tasks into sub-tasks. "
"Resource limits (iterations, tokens, concurrency) are configured in config.yml -> sub_agent."
),
parameters={
"type": "object",
"properties": {
"tasks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Short name/identifier for this task",
},
"instruction": {
"type": "string",
"description": "Detailed instruction for the sub-agent",
},
"tools": {
"type": "array",
"items": {"type": "string"},
"description": (
"Tool names this agent can use (empty = all tools). "
"e.g. ['file_read', 'file_list', 'web_search']"
),
},
},
"required": ["name", "instruction"],
},
"description": "Tasks for parallel sub-agents (max 5)",
},
},
"required": ["tasks"],
},
category="agent",
)
def multi_agent(arguments: dict) -> dict:
"""Spawn sub-agents to work on tasks concurrently.
Args:
@ -296,7 +230,7 @@ def agent_task(arguments: dict) -> dict:
}
Returns:
{"success": true, "results": [{task_name, success, response/error}]}
{"success": true, "results": [{task_name, success, response/error}], "total": int}
"""
from flask import current_app
@ -309,11 +243,13 @@ def agent_task(arguments: dict) -> dict:
app = current_app._get_current_object()
# Use injected model/project_id from executor context, fall back to defaults
model = arguments.get("_model", "glm-5")
model = arguments.get("_model") or config.default_model
project_id = arguments.get("_project_id")
max_tokens = arguments.get("_max_tokens", 65536)
temperature = arguments.get("_temperature", 0.7)
# Execute agents concurrently (max 3 at a time)
concurrency = min(len(tasks), 3)
# Execute agents concurrently
concurrency = min(len(tasks), config.sub_agent.max_concurrency)
results = [None] * len(tasks)
with ThreadPoolExecutor(max_workers=concurrency) as pool:
@ -324,9 +260,11 @@ def agent_task(arguments: dict) -> dict:
task["instruction"],
task.get("tools"),
model,
4096,
max_tokens,
temperature,
project_id,
app,
config.sub_agent.max_iterations,
): i
for i, task in enumerate(tasks)
}

View File

@ -1,55 +1,131 @@
"""Safe code execution tool with sandboxing"""
"""Safe code execution tool with sandboxing and strictness levels"""
import ast
import subprocess
import sys
import tempfile
import textwrap
from pathlib import Path
from typing import Dict, List, Set
from backend.tools.factory import tool
from backend.config import config
from backend.tools.docker_executor import DockerExecutor
# Blacklist of dangerous modules - all other modules are allowed
BLOCKED_MODULES = {
# System-level access
"os", "sys", "subprocess", "shutil", "signal", "ctypes",
"multiprocessing", "threading", "_thread",
# Network access
"socket", "http", "urllib", "requests", "ftplib", "smtplib",
"telnetlib", "xmlrpc", "asyncio",
# File system / I/O
"pathlib", "io", "glob", "tempfile", "shutil", "fnmatch",
# Code execution / introspection
"importlib", "pkgutil", "code", "codeop", "compileall",
"runpy", "pdb", "profile", "cProfile",
# Dangerous stdlib
"webbrowser", "antigravity", "turtle",
# IPC / persistence
"pickle", "shelve", "marshal", "sqlite3", "dbm",
# Process / shell
"commands", "pipes", "pty", "posix", "posixpath",
}
# Strictness profiles configuration
# - lenient: no restrictions at all
# - standard: allowlist based, only safe modules permitted
# - strict: minimal allowlist, only pure computation modules
STRICTNESS_PROFILES: Dict[str, dict] = {
"lenient": {
"timeout": 30,
"description": "No restrictions, all modules and builtins allowed",
"allowlist_modules": None, # None means all allowed
"blocked_builtins": set(),
},
# Blacklist of dangerous builtins
BLOCKED_BUILTINS = {
"eval", "exec", "compile", "open", "input",
"__import__", "globals", "locals", "vars",
"standard": {
"timeout": 10,
"description": "Allowlist based, only safe modules and builtins permitted",
"allowlist_modules": {
# Data types & serialization
"json", "csv", "re", "typing",
# Data structures
"collections", "itertools", "functools", "operator", "heapq", "bisect",
"array", "copy", "pprint", "enum",
# Math & numbers
"math", "cmath", "statistics", "random", "fractions", "decimal", "numbers",
# Date & time
"datetime", "time", "calendar",
# Text processing
"string", "textwrap", "unicodedata", "difflib",
# Data formats
"base64", "binascii", "quopri", "uu", "html", "xml.etree.ElementTree",
# Functional & concurrency helpers
"dataclasses", "hashlib", "hmac",
# Common utilities
"abc", "contextlib", "warnings", "logging",
},
"blocked_builtins": {
"eval", "exec", "compile", "__import__",
"open", "input", "globals", "locals", "vars",
"breakpoint", "exit", "quit",
"memoryview", "bytearray",
"getattr", "setattr", "delattr",
},
},
"strict": {
"timeout": 5,
"description": "Minimal allowlist, only pure computation modules",
"allowlist_modules": {
# Pure data structures
"collections", "itertools", "functools", "operator",
"array", "copy", "enum",
# Pure math
"math", "cmath", "numbers", "fractions", "decimal",
"random", "statistics",
# Pure text
"string", "textwrap", "unicodedata",
# Type hints
"typing",
# Utilities (no I/O)
"dataclasses", "abc", "contextlib",
},
"blocked_builtins": {
"eval", "exec", "compile", "__import__",
"open", "input", "globals", "locals", "vars",
"breakpoint", "exit", "quit",
"memoryview", "bytearray",
"dir", "hasattr", "getattr", "setattr", "delattr",
"type", "isinstance", "issubclass",
},
},
}
def register_extra_modules(strictness: str, modules: Set[str] | List[str]) -> None:
    """Add extra module names to a strictness level's allowlist.

    Args:
        strictness: One of "lenient", "standard", "strict".
        modules: Module names to add to the allowlist.

    Raises:
        ValueError: If *strictness* is not a known profile name.
    """
    if strictness in STRICTNESS_PROFILES:
        allowlist = STRICTNESS_PROFILES[strictness].get("allowlist_modules")
        # Lenient mode has no allowlist (None means everything is
        # permitted), so there is nothing to extend.
        if allowlist is not None:
            allowlist.update(modules)
        return
    raise ValueError(f"Invalid strictness level: {strictness}. Must be one of: {', '.join(STRICTNESS_PROFILES.keys())}")
# Apply extra modules from config.yml on module load
for _level, _mods in config.code_execution.extra_allowed_modules.items():
if isinstance(_mods, list) and _mods:
register_extra_modules(_level, _mods)
@tool(
name="execute_python",
description="Execute Python code in a sandboxed environment. Most standard library modules are allowed, with dangerous modules (os, subprocess, socket, etc.) blocked. Max execution time: 10 seconds.",
description="Execute Python code in a sandboxed environment with configurable strictness levels (lenient/standard/strict). "
"Default: 'standard' mode - balances security and flexibility with 10s timeout. "
"Use 'lenient' for data processing tasks (30s timeout, more modules allowed). "
"Use 'strict' for basic calculations only (5s timeout, minimal module access).",
parameters={
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "Python code to execute. Dangerous modules (os, subprocess, socket, etc.) are blocked."
"description": "Python code to execute. Available modules depend on strictness level."
},
"strictness": {
"type": "string",
"enum": ["lenient", "standard", "strict"],
"description": "Optional. Security strictness level (default: standard). "
"lenient: 30s timeout, most modules allowed; "
"standard: 10s timeout, balanced security; "
"strict: 5s timeout, minimal permissions."
}
},
"required": ["code"]
@ -61,36 +137,76 @@ def execute_python(arguments: dict) -> dict:
Execute Python code safely with sandboxing.
Security measures:
1. Blocked dangerous imports (blacklist)
2. Blocked dangerous builtins
3. Timeout limit (10s)
4. No file system access
5. No network access
1. Lenient mode: no restrictions
2. Standard/strict mode: allowlist based module restrictions
3. Configurable blocked builtins based on strictness level
4. Timeout limit (5s/10s/30s based on strictness)
5. Subprocess isolation
"""
code = arguments["code"]
strictness = arguments.get("strictness", config.code_execution.default_strictness)
# Security check: detect dangerous imports
dangerous_imports = _check_dangerous_imports(code)
if dangerous_imports:
# Validate strictness level
if strictness not in STRICTNESS_PROFILES:
return {
"success": False,
"error": f"Blocked imports: {', '.join(dangerous_imports)}. These modules are not allowed for security reasons."
"error": f"Invalid strictness level: {strictness}. Must be one of: {', '.join(STRICTNESS_PROFILES.keys())}"
}
# Security check: detect dangerous function calls
dangerous_calls = _check_dangerous_calls(code)
# Get profile configuration
profile = STRICTNESS_PROFILES[strictness]
allowlist_modules = profile.get("allowlist_modules")
blocked_builtins = profile["blocked_builtins"]
timeout = profile["timeout"]
# Parse and validate code syntax first
try:
tree = ast.parse(code)
except SyntaxError as e:
return {"success": False, "error": f"Syntax error in code: {e}"}
# Security check: detect disallowed imports
disallowed_imports = _check_disallowed_imports(tree, allowlist_modules)
if disallowed_imports:
return {
"success": False,
"error": f"Blocked imports: {', '.join(disallowed_imports)}. These modules are not allowed in '{strictness}' mode."
}
# Security check: detect dangerous function calls (skip if no restrictions)
if blocked_builtins:
dangerous_calls = _check_dangerous_calls(tree, blocked_builtins)
if dangerous_calls:
return {
"success": False,
"error": f"Blocked functions: {', '.join(dangerous_calls)}"
"error": f"Blocked functions: {', '.join(dangerous_calls)}. These functions are not allowed in '{strictness}' mode."
}
# Execute in isolated subprocess
# Choose execution backend
backend = config.code_execution.backend
if backend == "docker":
# Use Docker executor
executor = DockerExecutor(
image=config.code_execution.docker_image,
network=config.code_execution.docker_network,
user=config.code_execution.docker_user,
memory_limit=config.code_execution.docker_memory_limit,
cpu_shares=config.code_execution.docker_cpu_shares,
)
result = executor.execute(
code=code,
timeout=timeout,
strictness=strictness,
)
# Docker executor already returns the same dict structure
return result
else:
# Default subprocess backend
try:
result = subprocess.run(
[sys.executable, "-c", _build_safe_code(code)],
[sys.executable, "-c", _build_safe_code(code, blocked_builtins, allowlist_modules)],
capture_output=True,
timeout=10,
timeout=timeout,
cwd=tempfile.gettempdir(),
encoding="utf-8",
env={ # Clear environment variables
@ -99,18 +215,25 @@ def execute_python(arguments: dict) -> dict:
)
if result.returncode == 0:
return {"success": True, "output": result.stdout}
return {
"success": True,
"output": result.stdout,
"strictness": strictness,
"timeout": timeout
}
else:
return {"success": False, "error": result.stderr or "Execution failed"}
except subprocess.TimeoutExpired:
return {"success": False, "error": "Execution timeout (10s limit)"}
return {"success": False, "error": f"Execution timeout ({timeout}s limit in '{strictness}' mode)"}
except Exception as e:
return {"success": False, "error": f"Execution error: {str(e)}"}
def _build_safe_code(code: str) -> str:
"""Build sandboxed code with restricted globals"""
def _build_safe_code(code: str, blocked_builtins: Set[str],
allowlist_modules: Set[str] | None = None) -> str:
"""Build sandboxed code with restricted globals and runtime import hook."""
allowlist_repr = "None" if allowlist_modules is None else repr(allowlist_modules)
template = textwrap.dedent('''
import builtins
@ -118,6 +241,20 @@ def _build_safe_code(code: str) -> str:
_BLOCKED = %r
_safe_builtins = {k: getattr(builtins, k) for k in dir(builtins) if k not in _BLOCKED}
# Runtime import hook for allowlist enforcement
_ALLOWLIST = %s
if _ALLOWLIST is not None:
_original_import = builtins.__import__
def _restricted_import(name, *args, **kwargs):
top_level = name.split(".")[0]
if top_level not in _ALLOWLIST:
raise ImportError(
f"'{top_level}' is not allowed in the current strictness mode"
)
return _original_import(name, *args, **kwargs)
builtins.__import__ = _restricted_import
_safe_builtins["__import__"] = _restricted_import
# Create safe namespace
_safe_globals = {
"__builtins__": _safe_builtins,
@ -128,44 +265,43 @@ def _build_safe_code(code: str) -> str:
exec(%r, _safe_globals)
''').strip()
return template % (BLOCKED_BUILTINS, code)
return template % (blocked_builtins, allowlist_repr, code)
def _check_dangerous_imports(code: str) -> list:
"""Check for blocked (blacklisted) imports"""
try:
tree = ast.parse(code)
except SyntaxError:
def _check_disallowed_imports(tree: ast.AST, allowlist_modules: Set[str] | None) -> List[str]:
"""Check for imports not in allowlist. None allowlist means everything is allowed."""
if allowlist_modules is None:
return []
dangerous = []
disallowed = []
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
module = alias.name.split(".")[0]
if module in BLOCKED_MODULES:
dangerous.append(module)
if module not in allowlist_modules:
disallowed.append(module)
elif isinstance(node, ast.ImportFrom):
if node.module:
module = node.module.split(".")[0]
if module in BLOCKED_MODULES:
dangerous.append(module)
if module not in allowlist_modules:
disallowed.append(module)
return dangerous
return list(dict.fromkeys(disallowed)) # deduplicate while preserving order
def _check_dangerous_calls(code: str) -> list:
"""Check for blocked function calls"""
try:
tree = ast.parse(code)
except SyntaxError:
return []
def _check_dangerous_calls(tree: ast.AST, blocked_builtins: Set[str]) -> List[str]:
"""Check for blocked function calls including attribute access patterns."""
dangerous = []
for node in ast.walk(tree):
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
if node.func.id in BLOCKED_BUILTINS:
# Direct call: eval("...")
if node.func.id in blocked_builtins:
dangerous.append(node.func.id)
elif isinstance(node.func, ast.Attribute):
# Attribute call: builtins.open(...) or os.system(...)
attr_name = node.func.attr
if attr_name in blocked_builtins:
dangerous.append(attr_name)
return dangerous
return list(dict.fromkeys(dangerous))

View File

@ -0,0 +1,156 @@
"""Docker-based code execution with isolation."""
import subprocess
import tempfile
from typing import Optional, Dict, Any
from pathlib import Path
from backend.config import config
class DockerExecutor:
    """Execute Python code in isolated Docker containers.

    Isolation comes from the container runtime itself: networking disabled
    by default (``--network=none``), an unprivileged user, a read-only bind
    mount for the code, and optional memory / CPU limits.
    """

    def __init__(
        self,
        image: str = "python:3.12-slim",
        network: str = "none",
        user: str = "nobody",
        workdir: str = "/workspace",
        memory_limit: Optional[str] = None,
        cpu_shares: Optional[int] = None,
    ):
        """
        Args:
            image: Docker image used to run the code.
            network: Docker network mode ("none" = no network access).
            user: User inside the container (non-root by default).
            workdir: Container working directory; the code file is mounted here.
            memory_limit: Optional ``--memory`` value, e.g. "256m".
            cpu_shares: Optional ``--cpu-shares`` relative CPU weight.
        """
        self.image = image
        self.network = network
        self.user = user
        self.workdir = workdir
        self.memory_limit = memory_limit
        self.cpu_shares = cpu_shares

    def execute(
        self,
        code: str,
        timeout: int,
        strictness: str,
        extra_env: Optional[Dict[str, str]] = None,
        mount_src: Optional[str] = None,
        mount_dst: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Execute Python code in a Docker container.

        Args:
            code: Python code to execute.
            timeout: Maximum execution time in seconds (enforced via the
                subprocess timeout on the docker client).
            strictness: Strictness level (lenient/standard/strict), echoed
                back in the result for logging.
            extra_env: Additional environment variables.
            mount_src: Host path to mount into container (optional).
            mount_dst: Container mount path for mount_src (optional).

        Returns:
            Dictionary with keys:
                success: bool
                output: str if success else empty
                error: str if not success else empty
                container_id: container name, for debugging
                strictness / timeout: echoed inputs
        """
        # Write the code into a temp directory so it can be bind-mounted
        # read-only into the container.
        with tempfile.TemporaryDirectory() as tmpdir:
            code_path = Path(tmpdir) / "code.py"
            code_path.write_text(code, encoding="utf-8")

            # Name the container explicitly: with ``--rm`` the container
            # keeps running if only the *client* process dies (which is
            # exactly what happens on subprocess timeout), so we need a
            # handle to kill it. The tempdir basename is already unique.
            container_name = f"sandbox-{Path(tmpdir).name}"

            # Build docker run command
            cmd = [
                "docker", "run",
                "--rm",
                f"--name={container_name}",
                f"--network={self.network}",
                f"--user={self.user}",
                f"--workdir={self.workdir}",
                "--env=PYTHONIOENCODING=utf-8",
            ]
            # Optional resource limits
            if self.memory_limit:
                cmd.append(f"--memory={self.memory_limit}")
            if self.cpu_shares:
                cmd.append(f"--cpu-shares={self.cpu_shares}")

            # Mount the temporary directory as the workdir (read-only)
            cmd.extend(["-v", f"{tmpdir}:{self.workdir}:ro"])
            # Additional mount if provided
            if mount_src and mount_dst:
                cmd.extend(["-v", f"{mount_src}:{mount_dst}:ro"])

            # Add environment variables
            for key, value in (extra_env or {}).items():
                cmd.extend(["-e", f"{key}={value}"])

            # Run the *mounted* file instead of passing the source through
            # ``python -c``: avoids argv length limits and keeps the code
            # out of the host's process list (and is why we mounted it).
            cmd.append(self.image)
            cmd.extend(["python", f"{self.workdir}/code.py"])

            # Execute docker run with timeout
            try:
                result = subprocess.run(
                    cmd,
                    capture_output=True,
                    timeout=timeout,
                    encoding="utf-8",
                    errors="ignore",
                )
            except subprocess.TimeoutExpired:
                # The docker client was killed, but the container may still
                # be running; stop it best-effort (--rm then removes it).
                subprocess.run(
                    ["docker", "kill", container_name],
                    capture_output=True,
                )
                return {
                    "success": False,
                    "output": "",
                    "error": f"Execution timeout ({timeout}s limit in '{strictness}' mode)",
                    "container_id": container_name,
                    "strictness": strictness,
                    "timeout": timeout,
                }
            except Exception as e:
                return {
                    "success": False,
                    "output": "",
                    "error": f"Docker execution error: {str(e)}",
                    "container_id": container_name,
                    "strictness": strictness,
                    "timeout": timeout,
                }

            if result.returncode == 0:
                return {
                    "success": True,
                    "output": result.stdout,
                    "error": "",
                    "container_id": container_name,
                    "strictness": strictness,
                    "timeout": timeout,
                }
            return {
                "success": False,
                "output": "",
                "error": result.stderr or f"Container exited with code {result.returncode}",
                "container_id": container_name,
                "strictness": strictness,
                "timeout": timeout,
            }
# Module-level default executor (singleton), configured with class defaults.
_default_executor = DockerExecutor()


def execute_in_docker(code: str, timeout: int, strictness: str, **kwargs) -> Dict[str, Any]:
    """Convenience function using default executor."""
    executor = _default_executor
    return executor.execute(code=code, timeout=timeout, strictness=strictness, **kwargs)

View File

@ -51,30 +51,89 @@ class ToolExecutor:
return record["result"]
return None
def clear_history(self) -> None:
"""Clear call history (call this at start of new conversation turn)"""
self._call_history.clear()
@staticmethod
def _inject_context(name: str, args: dict, context: Optional[dict]) -> None:
"""Inject context fields into tool arguments in-place.
- file_* tools: inject project_id
- agent_task: inject model and project_id (prefixed with _ to avoid collisions)
- parallel_execute: inject project_id (prefixed with _ to avoid collisions)
- agent tools (multi_agent): inject _model and _project_id
"""
if not context:
return
if name.startswith("file_") and "project_id" in context:
args["project_id"] = context["project_id"]
if name == "agent_task":
if name == "multi_agent":
if "model" in context:
args["_model"] = context["model"]
if "project_id" in context:
args["_project_id"] = context["project_id"]
if name == "parallel_execute":
if "project_id" in context:
args["_project_id"] = context["project_id"]
if "max_tokens" in context:
args["_max_tokens"] = context["max_tokens"]
if "temperature" in context:
args["_temperature"] = context["temperature"]
def _prepare_call(
self,
call: dict,
context: Optional[dict],
seen_calls: set,
) -> tuple:
"""Parse, inject context, check dedup/cache for a single tool call.
Returns a tagged tuple:
("error", call_id, name, error_msg)
("cached", call_id, name, result_dict) -- dedup or cache hit
("execute", call_id, name, args, cache_key)
"""
name = call["function"]["name"]
args_str = call["function"]["arguments"]
call_id = call["id"]
# Parse JSON arguments
try:
args = json.loads(args_str) if isinstance(args_str, str) else args_str
except json.JSONDecodeError:
return ("error", call_id, name, "Invalid JSON arguments")
# Inject context
self._inject_context(name, args, context)
# Dedup within same batch
call_key = f"{name}:{json.dumps(args, sort_keys=True)}"
if call_key in seen_calls:
return ("cached", call_id, name,
{"success": True, "data": None, "cached": True, "duplicate": True})
seen_calls.add(call_key)
# History dedup
history_result = self._check_duplicate_in_history(name, args)
if history_result is not None:
return ("cached", call_id, name, {**history_result, "cached": True})
# Cache check
cache_key = self._make_cache_key(name, args)
cached_result = self._get_cached(cache_key)
if cached_result is not None:
return ("cached", call_id, name, {**cached_result, "cached": True})
return ("execute", call_id, name, args, cache_key)
def _execute_and_record(
self,
name: str,
args: dict,
cache_key: str,
) -> dict:
"""Execute a tool, cache result, record history, and return raw result dict."""
result = self._execute_tool(name, args)
if result.get("success"):
self._set_cache(cache_key, result)
self._call_history.append({
"name": name,
"args_str": json.dumps(args, sort_keys=True, ensure_ascii=False),
"result": result,
})
return result
def process_tool_calls_parallel(
self,
@ -85,10 +144,6 @@ class ToolExecutor:
"""
Process tool calls concurrently and return message list (ordered by input).
Identical logic to process_tool_calls but uses ThreadPoolExecutor so that
independent tool calls (e.g. reading 3 files, running 2 searches) execute
in parallel instead of sequentially.
Args:
tool_calls: Tool call list returned by LLM
context: Optional context info (user_id, project_id, etc.)
@ -102,80 +157,31 @@ class ToolExecutor:
max_workers = min(max(max_workers, 1), 6)
# Phase 1: prepare each call (parse args, inject context, check dedup/cache)
# This phase is fast and sequential — it must be done before parallelism
# to avoid race conditions on seen_calls / _call_history / _cache.
prepared: List[Optional[tuple]] = [None] * len(tool_calls)
seen_calls: set = set()
# Phase 1: prepare (sequential avoids race conditions on shared state)
prepared = [self._prepare_call(call, context, set()) for call in tool_calls]
for i, call in enumerate(tool_calls):
name = call["function"]["name"]
args_str = call["function"]["arguments"]
call_id = call["id"]
# Parse JSON arguments
try:
args = json.loads(args_str) if isinstance(args_str, str) else args_str
except json.JSONDecodeError:
prepared[i] = self._create_error_result(call_id, name, "Invalid JSON arguments")
continue
# Inject context into tool arguments
self._inject_context(name, args, context)
# Dedup within same batch
call_key = f"{name}:{json.dumps(args, sort_keys=True)}"
if call_key in seen_calls:
prepared[i] = self._create_tool_result(
call_id, name,
{"success": True, "data": None, "cached": True, "duplicate": True}
)
continue
seen_calls.add(call_key)
# History dedup
history_result = self._check_duplicate_in_history(name, args)
if history_result is not None:
prepared[i] = self._create_tool_result(call_id, name, {**history_result, "cached": True})
continue
# Cache check
cache_key = self._make_cache_key(name, args)
cached_result = self._get_cached(cache_key)
if cached_result is not None:
prepared[i] = self._create_tool_result(call_id, name, {**cached_result, "cached": True})
continue
# Mark as needing actual execution
prepared[i] = ("execute", call_id, name, args, cache_key)
# Separate pre-resolved results from tasks needing execution
# Phase 2: separate pre-resolved from tasks needing execution
results: List[dict] = [None] * len(tool_calls)
exec_tasks: Dict[int, tuple] = {} # index -> (call_id, name, args, cache_key)
exec_tasks: Dict[int, tuple] = {}
for i, item in enumerate(prepared):
if isinstance(item, dict):
results[i] = item
elif isinstance(item, tuple) and item[0] == "execute":
tag = item[0]
if tag == "error":
_, call_id, name, error_msg = item
results[i] = self._create_error_result(call_id, name, error_msg)
elif tag == "cached":
_, call_id, name, result_dict = item
results[i] = self._create_tool_result(call_id, name, result_dict)
else: # "execute"
_, call_id, name, args, cache_key = item
exec_tasks[i] = (call_id, name, args, cache_key)
# Phase 2: execute remaining calls in parallel
# Phase 3: execute remaining calls in parallel
if exec_tasks:
def _run(idx: int, call_id: str, name: str, args: dict, cache_key: str) -> tuple:
t0 = time.time()
result = self._execute_tool(name, args)
result = self._execute_and_record(name, args, cache_key)
elapsed = time.time() - t0
if result.get("success"):
self._set_cache(cache_key, result)
self._call_history.append({
"name": name,
"args_str": json.dumps(args, sort_keys=True, ensure_ascii=False),
"result": result,
})
return idx, self._create_tool_result(call_id, name, result, execution_time=elapsed)
with ThreadPoolExecutor(max_workers=max_workers) as pool:
@ -205,64 +211,21 @@ class ToolExecutor:
Tool response message list, can be appended to messages
"""
results = []
seen_calls = set() # Track calls within this batch
seen_calls: set = set()
for call in tool_calls:
name = call["function"]["name"]
args_str = call["function"]["arguments"]
call_id = call["id"]
try:
args = json.loads(args_str) if isinstance(args_str, str) else args_str
except json.JSONDecodeError:
results.append(self._create_error_result(
call_id, name, "Invalid JSON arguments"
))
continue
# Inject context into tool arguments
self._inject_context(name, args, context)
# Check for duplicate within same batch
call_key = f"{name}:{json.dumps(args, sort_keys=True)}"
if call_key in seen_calls:
# Skip duplicate, but still return a result
results.append(self._create_tool_result(
call_id, name,
{"success": True, "data": None, "cached": True, "duplicate": True}
))
continue
seen_calls.add(call_key)
# Check history for previous call in this session
history_result = self._check_duplicate_in_history(name, args)
if history_result is not None:
result = {**history_result, "cached": True}
results.append(self._create_tool_result(call_id, name, result))
continue
# Check cache
cache_key = self._make_cache_key(name, args)
cached_result = self._get_cached(cache_key)
if cached_result is not None:
result = {**cached_result, "cached": True}
results.append(self._create_tool_result(call_id, name, result))
continue
# Execute tool with retry
result = self._execute_tool(name, args)
# Cache the result (only cache successful results)
if result.get("success"):
self._set_cache(cache_key, result)
# Add to history
self._call_history.append({
"name": name,
"args_str": json.dumps(args, sort_keys=True, ensure_ascii=False),
"result": result
})
prepared = self._prepare_call(call, context, seen_calls)
tag = prepared[0]
if tag == "error":
_, call_id, name, error_msg = prepared
results.append(self._create_error_result(call_id, name, error_msg))
elif tag == "cached":
_, call_id, name, result_dict = prepared
results.append(self._create_tool_result(call_id, name, result_dict))
else: # "execute"
_, call_id, name, args, cache_key = prepared
result = self._execute_and_record(name, args, cache_key)
results.append(self._create_tool_result(call_id, name, result))
return results

View File

@ -34,30 +34,3 @@ def tool(
return func
return decorator
def register_tool(
    name: str,
    handler: Callable,
    description: str,
    parameters: dict,
    category: str = "general"
) -> None:
    """
    Register a tool directly (without decorator)

    Usage:
        register_tool(
            name="my_tool",
            handler=my_function,
            description="Description",
            parameters={...}
        )
    """
    # Build the definition inline and hand it straight to the registry.
    registry.register(
        ToolDefinition(
            name=name,
            description=description,
            parameters=parameters,
            handler=handler,
            category=category,
        )
    )

View File

@ -1,11 +1,11 @@
"""Backend utilities"""
from backend.utils.helpers import ok, err, to_dict, get_or_create_default_user, record_token_usage, build_messages
from backend.utils.helpers import ok, err, to_dict, message_to_dict, record_token_usage, build_messages
__all__ = [
"ok",
"err",
"to_dict",
"get_or_create_default_user",
"message_to_dict",
"record_token_usage",
"build_messages",
]

View File

@ -1,6 +1,6 @@
"""Common helper functions"""
import json
from datetime import date, datetime
from datetime import date, datetime, timezone
from typing import Any
from flask import jsonify
from backend import db
@ -97,7 +97,7 @@ def message_to_dict(msg: Message) -> dict:
def record_token_usage(user_id, model, prompt_tokens, completion_tokens):
"""Record token usage"""
today = date.today()
today = datetime.now(timezone.utc).date()
usage = TokenUsage.query.filter_by(
user_id=user_id, date=today, model=model
).first()
@ -133,6 +133,13 @@ def build_messages(conv, project_id=None):
# Query messages directly to avoid detached instance warning
messages = Message.query.filter_by(conversation_id=conv.id).order_by(Message.created_at.asc()).all()
for m in messages:
# Skip tool messages — they are ephemeral intermediate results, not
# meant to be replayed as conversation history (would violate the API
# protocol that requires tool messages to follow an assistant message
# with matching tool_calls).
if m.role == "tool":
continue
# Build full content from JSON structure
full_content = m.content
try:

View File

@ -66,6 +66,7 @@ backend/
│ ├── data.py # 计算器、文本、JSON
│ ├── weather.py # 天气查询
│ ├── file_ops.py # 文件操作project_id 自动注入)
│ ├── agent.py # 多智能体(子 Agent 并发执行,工具权限隔离)
│ └── code.py # 代码执行
├── utils/ # 辅助函数
@ -266,8 +267,8 @@ classDiagram
-ToolRegistry registry
-dict _cache
-list _call_history
+process_tool_calls(calls, context) list
+clear_history() void
+process_tool_calls(list, dict) list
+process_tool_calls_parallel(list, dict, int) list
}
ChatService --> LLMClient : 使用
@ -295,18 +296,17 @@ classDiagram
+register(ToolDefinition) void
+get(str name) ToolDefinition?
+list_all() list~dict~
+list_by_category(str) list~dict~
+execute(str name, dict args) dict
+remove(str name) bool
+has(str name) bool
}
class ToolExecutor {
-ToolRegistry registry
-bool enable_cache
-int cache_ttl
-dict _cache
-list _call_history
+process_tool_calls(list, dict) list
+clear_history() void
+process_tool_calls_parallel(list, dict, int) list
}
class ToolResult {
@ -394,18 +394,19 @@ def validate_path_in_project(path: str, project_dir: Path) -> Path:
工具执行器自动为文件工具注入 `project_id`
```python
# backend/tools/executor.py
# backend/tools/executor.py — _inject_context()
def process_tool_calls(self, tool_calls, context=None):
for call in tool_calls:
name = call["function"]["name"]
args = json.loads(call["function"]["arguments"])
# 自动注入 project_id
if context and name.startswith("file_") and "project_id" in context:
@staticmethod
def _inject_context(name: str, args: dict, context: Optional[dict]) -> None:
# file_* 工具: 注入 project_id
if name.startswith("file_") and "project_id" in context:
args["project_id"] = context["project_id"]
result = self.registry.execute(name, args)
# agent 工具: 注入 _model 和 _project_id
if name == "multi_agent":
if "model" in context:
args["_model"] = context["model"]
if "project_id" in context:
args["_project_id"] = context["project_id"]
```
---
@ -1020,6 +1021,12 @@ frontend_port: 4000
# 智能体循环最大迭代次数(工具调用轮次上限,默认 5
max_iterations: 15
# 子代理资源配置multi_agent 工具)
# max_tokens 和 temperature 与主 Agent 共用,无需单独配置
sub_agent:
max_iterations: 3 # 每个子代理的最大工具调用轮数
max_concurrency: 3 # 并发线程数
# 可用模型列表(每个模型必须指定 api_url 和 api_key
# 支持任何 OpenAI 兼容 APIDeepSeek、GLM、OpenAI、Moonshot、Qwen 等)
models:

View File

@ -27,19 +27,20 @@ classDiagram
+register(ToolDefinition tool) void
+get(str name) ToolDefinition?
+list_all() list~dict~
+list_by_category(str category) list~dict~
+execute(str name, dict args) dict
+remove(str name) bool
+has(str name) bool
}
class ToolExecutor {
-ToolRegistry registry
-bool enable_cache
-int cache_ttl
-dict _cache
-list _call_history
+process_tool_calls(list tool_calls, dict context) list~dict~
+build_request(list messages, str model, list tools, dict kwargs) dict
+clear_history() void
+process_tool_calls_parallel(list tool_calls, dict context, int max_workers) list~dict~
-_prepare_call(dict call, dict context, set seen_calls) tuple
-_execute_and_record(str name, dict args, str cache_key) dict
-_inject_context(str name, dict args, dict context) void
}
class ToolResult {
@ -88,32 +89,26 @@ classDiagram
### context 参数
`process_tool_calls()` 接受 `context` 参数,用于自动注入工具参数:
`process_tool_calls()` / `process_tool_calls_parallel()` 接受 `context` 参数,用于自动注入工具参数:
```python
# backend/tools/executor.py
# backend/tools/executor.py — _inject_context()
def process_tool_calls(
self,
tool_calls: List[dict],
context: Optional[dict] = None
) -> List[dict]:
@staticmethod
def _inject_context(name: str, args: dict, context: Optional[dict]) -> None:
"""
Args:
tool_calls: LLM 返回的工具调用列表
context: 上下文信息,支持:
- project_id: 自动注入到文件工具
- file_* 工具: 注入 project_id
- agent 工具 (multi_agent): 注入 _model 和 _project_id
"""
for call in tool_calls:
name = call["function"]["name"]
args = json.loads(call["function"]["arguments"])
# 自动注入 project_id 到文件工具
if context:
if not context:
return
if name.startswith("file_") and "project_id" in context:
args["project_id"] = context["project_id"]
result = self.registry.execute(name, args)
if name == "multi_agent":
if "model" in context:
args["_model"] = context["model"]
if "project_id" in context:
args["_project_id"] = context["project_id"]
```
### 使用示例
@ -122,12 +117,12 @@ def process_tool_calls(
# backend/services/chat.py
def stream_response(self, conv, tools_enabled=True, project_id=None):
# 构建上下文(优先使用请求传递的 project_id否则回退到对话绑定的
context = None
# 构建上下文(包含 model 和 project_id
context = {"model": conv.model}
if project_id:
context = {"project_id": project_id}
context["project_id"] = project_id
elif conv.project_id:
context = {"project_id": conv.project_id}
context["project_id"] = conv.project_id
# 处理工具调用时自动注入
tool_results = self.executor.process_tool_calls(tool_calls, context)
@ -222,14 +217,68 @@ file_read({"path": "src/main.py", "project_id": "xxx"})
| 工具名称 | 描述 | 参数 |
|---------|------|------|
| `execute_python` | 在沙箱环境中执行 Python 代码 | `code`: Python 代码 |
| `execute_python` | 在沙箱环境中执行 Python 代码 | `code`: Python 代码<br>`strictness`: 可选严格等级lenient/standard/strict |
安全措施:
- 白名单模块限制
- 危险内置函数禁止
- 10 秒超时限制
- 无文件系统访问
- 无网络访问
**严格等级配置:**
| 等级 | 超时 | 策略 | 适用场景 |
|------|------|------|---------|
| `lenient` | 30s | 无限制,所有模块和内置函数均可使用 | 数据处理、需要完整标准库 |
| `standard` | 10s | 白名单机制,仅允许安全模块(默认) | 通用场景 |
| `strict` | 5s | 精简白名单,仅允许纯计算模块 | 基础计算 |
**standard 白名单模块:** json, csv, re, typing, collections, itertools, functools, operator, heapq, bisect, array, copy, pprint, enum, math, cmath, statistics, random, fractions, decimal, numbers, datetime, time, calendar, string, textwrap, unicodedata, difflib, base64, binascii, quopri, uu, html, xml.etree.ElementTree, dataclasses, hashlib, hmac, abc, contextlib, warnings, logging
**strict 白名单模块:** collections, itertools, functools, operator, array, copy, enum, math, cmath, numbers, fractions, decimal, random, statistics, string, textwrap, unicodedata, typing, dataclasses, abc, contextlib
**内置函数限制:**
- standard 禁止eval, exec, compile, \_\_import\_\_, open, input, globals, locals, vars, breakpoint, exit, quit, memoryview, bytearray
- strict 额外禁止dir, hasattr, getattr, setattr, delattr, type, isinstance, issubclass
**白名单扩展方式:**
1. **config.yml 配置(持久化):**
```yaml
code_execution:
default_strictness: standard
extra_allowed_modules:
standard: [numpy, pandas]
strict: [numpy]
```
2. **代码 API插件/运行时):**
```python
from backend.tools.builtin.code import register_extra_modules
register_extra_modules("standard", {"numpy", "pandas"})
register_extra_modules("strict", {"numpy"})
```
**使用示例:**
```python
# 默认 standard 模式(白名单限制)
execute_python({"code": "import json; print(json.dumps({'key': 'value'}))"})
# lenient 模式 - 无限制
execute_python({
"code": "import os; print(os.getcwd())",
"strictness": "lenient"
})
# strict 模式 - 仅纯计算
execute_python({
"code": "result = sum([1, 2, 3, 4, 5]); print(result)",
"strictness": "strict"
})
```
**安全措施:**
- standard/strict: 白名单模块限制(默认拒绝,仅显式允许)
- lenient: 无限制
- 危险内置函数按等级禁止
- 可配置超时限制5s/10s/30s
- subprocess 隔离执行
### 5.4 文件操作工具 (file)
@ -250,6 +299,31 @@ file_read({"path": "src/main.py", "project_id": "xxx"})
|---------|------|------|
| `get_weather` | 查询天气信息(模拟) | `city`: 城市名称 |
### 5.6 多智能体工具 (agent)
| 工具名称 | 描述 | 参数 |
|---------|------|------|
| `multi_agent` | 派生子 Agent 并发执行任务 | `tasks`: 任务数组name, instruction, tools<br>`_model`: 模型名称(自动注入)<br>`_project_id`: 项目 ID自动注入 |
**`multi_agent` 工作原理:**
1. 接收任务数组,每个任务指定 name、instruction 和可选的 tools 列表
2. 子 Agent **禁止使用 `multi_agent` 工具**`BLOCKED_TOOLS`),防止无限递归
3. 子 Agent 工具权限与主 Agent 一致(除 multi_agent 外的所有已注册工具),支持并行工具执行
4. 为每个子 Agent 创建独立线程,各自拥有 LLM 对话循环
5. 子 Agent 在 `app.app_context()` 中运行 LLM 调用和工具执行,确保数据库等依赖正常工作
6. 通过 Service Locator 获取 `llm_client` 实例
7. 返回 `{success, results: [{task_name, success, response/error}], total}`
**资源配置**`config.yml` → `sub_agent`
| 配置项 | 默认值 | 说明 |
|--------|--------|------|
| `max_iterations` | 3 | 每个子代理的最大工具调用轮数 |
| `max_concurrency` | 3 | ThreadPoolExecutor 并发线程数 |
> - `max_tokens``temperature` 与主 Agent 共用,从对话配置中获取,无需单独配置。
> - 子代理禁止调用 `multi_agent` 工具,防止无限递归。
---
## 六、核心特性
@ -285,7 +359,6 @@ def my_tool(arguments: dict) -> dict:
- **批次内去重**:同一批次中相同工具+参数的调用会被跳过
- **历史去重**:同一会话内已调用过的工具会直接返回缓存结果
- **自动清理**:新会话开始时调用 `clear_history()` 清理历史
### 6.4 无自动重试
@ -308,13 +381,45 @@ def my_tool(arguments: dict) -> dict:
def init_tools() -> None:
"""初始化所有内置工具"""
from backend.tools.builtin import (
code, crawler, data, weather, file_ops
code, crawler, data, weather, file_ops, agent
)
```
---
## 八、扩展新工具
## 八、Service Locator
工具系统提供 Service Locator 模式,允许工具访问共享服务(如 LLM 客户端):
```python
# backend/tools/__init__.py
_services: dict = {}
def register_service(name: str, service) -> None:
"""注册共享服务"""
_services[name] = service
def get_service(name: str):
"""获取已注册的服务,不存在则返回 None"""
return _services.get(name)
```
### 使用方式
```python
# 在应用初始化时注册routes/__init__.py
from backend.tools import register_service
register_service("llm_client", llm_client)
# 在工具中使用agent.py
from backend.tools import get_service
llm_client = get_service("llm_client")
```
---
## 九、扩展新工具
### 添加新工具

View File

@ -22,6 +22,7 @@
"@codemirror/lang-xml": "^6.1.0",
"@codemirror/lang-yaml": "^6.1.2",
"@codemirror/theme-one-dark": "^6.1.2",
"chart.js": "^4.5.1",
"codemirror": "^6.0.1",
"highlight.js": "^11.11.1",
"katex": "^0.16.40",
@ -791,6 +792,12 @@
"integrity": "sha512-cYQ9310grqxueWbl+WuIUIaiUaDcj7WOq5fVhEljNVgRfOUhY9fy2zTvfoqWsnebh8Sl70VScFbICvJnLKB0Og==",
"license": "MIT"
},
"node_modules/@kurkle/color": {
"version": "0.3.4",
"resolved": "https://registry.npmmirror.com/@kurkle/color/-/color-0.3.4.tgz",
"integrity": "sha512-M5UknZPHRu3DEDWoipU6sE8PdkZ6Z/S+v4dD+Ke8IaNlpdSQah50lz1KtcFBa2vsdOnwbbnxJwVM4wty6udA5w==",
"license": "MIT"
},
"node_modules/@lezer/common": {
"version": "1.5.1",
"resolved": "https://registry.npmmirror.com/@lezer/common/-/common-1.5.1.tgz",
@ -1430,6 +1437,18 @@
"dev": true,
"license": "Python-2.0"
},
"node_modules/chart.js": {
"version": "4.5.1",
"resolved": "https://registry.npmmirror.com/chart.js/-/chart.js-4.5.1.tgz",
"integrity": "sha512-GIjfiT9dbmHRiYi6Nl2yFCq7kkwdkp1W/lp2J99rX0yo9tgJGn3lKQATztIjb5tVtevcBtIdICNWqlq5+E8/Pw==",
"license": "MIT",
"dependencies": {
"@kurkle/color": "^0.3.0"
},
"engines": {
"pnpm": ">=8"
}
},
"node_modules/codemirror": {
"version": "6.0.2",
"resolved": "https://registry.npmmirror.com/codemirror/-/codemirror-6.0.2.tgz",

View File

@ -9,21 +9,22 @@
"preview": "vite preview"
},
"dependencies": {
"codemirror": "^6.0.1",
"@codemirror/theme-one-dark": "^6.1.2",
"@codemirror/lang-markdown": "^6.3.2",
"@codemirror/lang-javascript": "^6.2.3",
"@codemirror/lang-python": "^6.1.7",
"@codemirror/lang-html": "^6.4.9",
"@codemirror/lang-css": "^6.3.1",
"@codemirror/lang-json": "^6.0.1",
"@codemirror/lang-yaml": "^6.1.2",
"@codemirror/lang-java": "^6.0.1",
"@codemirror/lang-cpp": "^6.0.2",
"@codemirror/lang-rust": "^6.0.1",
"@codemirror/lang-css": "^6.3.1",
"@codemirror/lang-go": "^6.0.1",
"@codemirror/lang-html": "^6.4.9",
"@codemirror/lang-java": "^6.0.1",
"@codemirror/lang-javascript": "^6.2.3",
"@codemirror/lang-json": "^6.0.1",
"@codemirror/lang-markdown": "^6.3.2",
"@codemirror/lang-python": "^6.1.7",
"@codemirror/lang-rust": "^6.0.1",
"@codemirror/lang-sql": "^6.8.0",
"@codemirror/lang-xml": "^6.1.0",
"@codemirror/lang-yaml": "^6.1.2",
"@codemirror/theme-one-dark": "^6.1.2",
"chart.js": "^4.5.1",
"codemirror": "^6.0.1",
"highlight.js": "^11.11.1",
"katex": "^0.16.40",
"marked": "^15.0.12",

View File

@ -42,6 +42,7 @@
:messages="messages"
:streaming="streaming"
:streaming-process-steps="streamProcessSteps"
:model-name-map="modelNameMap"
:has-more-messages="hasMoreMessages"
:loading-more="loadingMessages"
:tools-enabled="toolsEnabled"
@ -59,8 +60,11 @@
<div v-if="showSettings" class="modal-overlay" @click.self="showSettings = false">
<div class="modal-content">
<SettingsPanel
:key="currentConvId || '__none__'"
:visible="showSettings"
:conversation="currentConv"
:models="models"
:default-model="defaultModel"
@close="showSettings = false"
@save="saveSettings"
/>
@ -117,13 +121,37 @@ import ModalDialog from './components/ModalDialog.vue'
import ToastContainer from './components/ToastContainer.vue'
import { icons } from './utils/icons'
import { useModal } from './composables/useModal'
import {
DEFAULT_CONVERSATION_PAGE_SIZE,
DEFAULT_MESSAGE_PAGE_SIZE,
LS_KEY_TOOLS_ENABLED,
} from './constants'
const SettingsPanel = defineAsyncComponent(() => import('./components/SettingsPanel.vue'))
const StatsPanel = defineAsyncComponent(() => import('./components/StatsPanel.vue'))
import { conversationApi, messageApi, projectApi } from './api'
import { conversationApi, messageApi, projectApi, modelApi } from './api'
const modal = useModal()
// -- Models state (preloaded) --
const models = ref([])
const modelNameMap = ref({})
const defaultModel = computed(() => models.value.length > 0 ? models.value[0].id : '')
async function loadModels() {
try {
const res = await modelApi.getCached()
models.value = res.data || []
const map = {}
for (const m of models.value) {
if (m.id && m.name) map[m.id] = m.name
}
modelNameMap.value = map
} catch (e) {
console.error('Failed to load models:', e)
}
}
// -- Conversations state --
const conversations = shallowRef([])
const currentConvId = ref(null)
@ -203,7 +231,7 @@ function updateStreamField(convId, field, ref, valueOrUpdater) {
// -- UI state --
const showSettings = ref(false)
const showStats = ref(false)
const toolsEnabled = ref(localStorage.getItem('tools_enabled') !== 'false')
const toolsEnabled = ref(localStorage.getItem(LS_KEY_TOOLS_ENABLED) !== 'false')
const currentProject = ref(null)
const showFileExplorer = ref(false)
const showCreateModal = ref(false)
@ -227,7 +255,7 @@ async function loadConversations(reset = true) {
if (loadingConvs.value) return
loadingConvs.value = true
try {
const res = await conversationApi.list(reset ? null : nextConvCursor.value, 20)
const res = await conversationApi.list(reset ? null : nextConvCursor.value, DEFAULT_CONVERSATION_PAGE_SIZE)
if (reset) {
conversations.value = res.data.items
} else {
@ -258,6 +286,7 @@ async function createConversationInProject(project) {
const res = await conversationApi.create({
title: '新对话',
project_id: project.id || null,
model: defaultModel.value || undefined,
})
conversations.value = [res.data, ...conversations.value]
await selectConversation(res.data.id)
@ -412,7 +441,7 @@ function createStreamCallbacks(convId, { updateConvList = true } = {}) {
}
} else {
try {
const res = await messageApi.list(convId, null, 50)
const res = await messageApi.list(convId, null, DEFAULT_MESSAGE_PAGE_SIZE)
const idx = conversations.value.findIndex(c => c.id === convId)
if (idx >= 0) {
const conv = conversations.value[idx]
@ -533,7 +562,7 @@ async function saveSettings(data) {
// -- Update tools enabled --
function updateToolsEnabled(val) {
toolsEnabled.value = val
localStorage.setItem('tools_enabled', String(val))
localStorage.setItem(LS_KEY_TOOLS_ENABLED, String(val))
}
// -- Browse project files --
@ -602,7 +631,8 @@ async function deleteProject(project) {
}
// -- Init --
onMounted(() => {
onMounted(async () => {
await loadModels()
loadProjects()
loadConversations()
})

View File

@ -1,4 +1,11 @@
const BASE = '/api'
import {
API_BASE_URL,
CONTENT_TYPE_JSON,
LS_KEY_MODELS_CACHE,
DEFAULT_CONVERSATION_PAGE_SIZE,
DEFAULT_MESSAGE_PAGE_SIZE,
DEFAULT_PROJECT_PAGE_SIZE,
} from '../constants'
// Cache for models list
let modelsCache = null
@ -13,8 +20,8 @@ function buildQueryParams(params) {
}
async function request(url, options = {}) {
const res = await fetch(`${BASE}${url}`, {
headers: { 'Content-Type': 'application/json' },
const res = await fetch(`${API_BASE_URL}${url}`, {
headers: { 'Content-Type': CONTENT_TYPE_JSON },
...options,
body: options.body ? JSON.stringify(options.body) : undefined,
})
@ -37,9 +44,9 @@ function createSSEStream(url, body, { onProcessStep, onDone, onError }) {
const promise = (async () => {
try {
const res = await fetch(`${BASE}${url}`, {
const res = await fetch(`${API_BASE_URL}${url}`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
headers: { 'Content-Type': CONTENT_TYPE_JSON },
body: JSON.stringify(body),
signal: controller.signal,
})
@ -107,7 +114,7 @@ export const modelApi = {
}
// Try localStorage cache first
const cached = localStorage.getItem('models_cache')
const cached = localStorage.getItem(LS_KEY_MODELS_CACHE)
if (cached) {
try {
modelsCache = JSON.parse(cached)
@ -118,7 +125,7 @@ export const modelApi = {
// Fetch from server
const res = await this.list()
modelsCache = res.data
localStorage.setItem('models_cache', JSON.stringify(modelsCache))
localStorage.setItem(LS_KEY_MODELS_CACHE, JSON.stringify(modelsCache))
return res
},
@ -131,7 +138,7 @@ export const statsApi = {
}
export const conversationApi = {
list(cursor, limit = 20, projectId = null) {
list(cursor, limit = DEFAULT_CONVERSATION_PAGE_SIZE, projectId = null) {
return request(`/conversations${buildQueryParams({ cursor, limit, project_id: projectId })}`)
},
@ -159,7 +166,7 @@ export const conversationApi = {
}
export const messageApi = {
list(convId, cursor, limit = 50) {
list(convId, cursor, limit = DEFAULT_MESSAGE_PAGE_SIZE) {
return request(`/conversations/${convId}/messages${buildQueryParams({ cursor, limit })}`)
},
@ -186,7 +193,7 @@ export const messageApi = {
}
export const projectApi = {
list(cursor, limit = 20) {
list(cursor, limit = DEFAULT_PROJECT_PAGE_SIZE) {
return request(`/projects${buildQueryParams({ cursor, limit })}`)
},
@ -210,7 +217,7 @@ export const projectApi = {
},
readFileRaw(projectId, filepath) {
return fetch(`${BASE}/projects/${projectId}/files/${filepath}`).then(res => {
return fetch(`${API_BASE_URL}/projects/${projectId}/files/${filepath}`).then(res => {
if (!res.ok) throw new Error(`HTTP ${res.status}`)
return res
})

View File

@ -79,13 +79,13 @@ import MessageBubble from './MessageBubble.vue'
import MessageInput from './MessageInput.vue'
import MessageNav from './MessageNav.vue'
import ProcessBlock from './ProcessBlock.vue'
import { modelApi } from '../api'
const props = defineProps({
conversation: { type: Object, default: null },
messages: { type: Array, required: true },
streaming: { type: Boolean, default: false },
streamingProcessSteps: { type: Array, default: () => [] },
modelNameMap: { type: Object, default: () => ({}) },
hasMoreMessages: { type: Boolean, default: false },
loadingMore: { type: Boolean, default: false },
toolsEnabled: { type: Boolean, default: true },
@ -95,27 +95,15 @@ const emit = defineEmits(['sendMessage', 'stopStreaming', 'deleteMessage', 'rege
const scrollContainer = ref(null)
const inputRef = ref(null)
const modelNameMap = ref({})
const activeMessageId = ref(null)
let scrollObserver = null
const observedElements = new WeakSet()
function formatModelName(modelId) {
return modelNameMap.value[modelId] || modelId
return props.modelNameMap[modelId] || modelId
}
onMounted(async () => {
try {
const res = await modelApi.getCached()
const map = {}
for (const m of res.data) {
if (m.id && m.name) map[m.id] = m.name
}
modelNameMap.value = map
} catch (e) {
console.warn('Failed to load model names:', e)
}
onMounted(() => {
if (scrollContainer.value) {
scrollObserver = new IntersectionObserver(
(entries) => {
@ -257,16 +245,6 @@ watch(() => props.conversation?.id, () => {
line-height: 1;
}
.thinking-badge {
background: rgba(245, 158, 11, 0.12);
color: #d97706;
}
[data-theme="dark"] .thinking-badge {
background: rgba(245, 158, 11, 0.18);
color: #fbbf24;
}
.messages-container {
flex: 1 1 auto;
overflow-y: auto;
@ -313,4 +291,5 @@ watch(() => props.conversation?.id, () => {
</style>

View File

@ -24,7 +24,7 @@
<input
ref="fileInputRef"
type="file"
accept=".txt,.md,.json,.xml,.html,.css,.js,.ts,.jsx,.tsx,.py,.java,.c,.cpp,.h,.hpp,.yaml,.yml,.toml,.ini,.csv,.sql,.sh,.bat,.log,.vue,.svelte,.go,.rs,.rb,.php,.swift,.kt,.scala,.lua,.r,.dart"
accept=ALLOWED_UPLOAD_EXTENSIONS
@change="handleFileUpload"
style="display: none"
/>
@ -65,6 +65,7 @@
<script setup>
import { ref, computed, nextTick } from 'vue'
import { icons } from '../utils/icons'
import { TEXTAREA_MAX_HEIGHT_PX, ALLOWED_UPLOAD_EXTENSIONS } from '../constants'
const props = defineProps({
disabled: { type: Boolean, default: false },
@ -83,7 +84,7 @@ function autoResize() {
const el = textareaRef.value
if (!el) return
el.style.height = 'auto'
el.style.height = Math.min(el.scrollHeight, 200) + 'px'
el.style.height = Math.min(el.scrollHeight, TEXTAREA_MAX_HEIGHT_PX) + 'px'
}
function onKeydown(e) {

View File

@ -17,6 +17,7 @@
<script setup>
import { computed } from 'vue'
import { DEFAULT_TRUNCATE_LENGTH } from '../constants'
const props = defineProps({
messages: { type: Array, required: true },
@ -30,7 +31,7 @@ const userMessages = computed(() => props.messages.filter(m => m.role === 'user'
function preview(msg) {
if (!msg.text) return '...'
const clean = msg.text.replace(/[#*`~>\-\[\]()]/g, '').replace(/\s+/g, ' ').trim()
return clean.length > 60 ? clean.slice(0, 60) + '...' : clean
return clean.length > DEFAULT_TRUNCATE_LENGTH ? clean.slice(0, DEFAULT_TRUNCATE_LENGTH) + '...' : clean
}
</script>

View File

@ -41,7 +41,10 @@
</div>
<div v-if="item.result" class="tool-detail">
<span class="detail-label">返回结果:</span>
<pre>{{ item.result }}</pre>
<pre>{{ expandedResultKeys[item.key] ? item.result : item.resultPreview }}</pre>
<button v-if="item.resultTruncated" class="btn-expand-result" @click.stop="toggleResultExpand(item.key)">
{{ expandedResultKeys[item.key] ? '收起' : `展开全部 (${item.resultLength} 字符)` }}
</button>
</div>
</div>
</div>
@ -67,6 +70,19 @@ import { ref, computed, watch } from 'vue'
import { renderMarkdown } from '../utils/markdown'
import { formatJson, truncate } from '../utils/format'
import { useCodeEnhancement } from '../composables/useCodeEnhancement'
import { RESULT_PREVIEW_LIMIT } from '../constants'
function buildResultFields(rawContent) {
const formatted = formatJson(rawContent)
const len = formatted.length
const truncated = len > RESULT_PREVIEW_LIMIT
return {
result: formatted,
resultPreview: truncated ? formatted.slice(0, RESULT_PREVIEW_LIMIT) + '\n...' : formatted,
resultTruncated: truncated,
resultLength: len,
}
}
const props = defineProps({
toolCalls: { type: Array, default: () => [] },
@ -75,10 +91,12 @@ const props = defineProps({
})
const expandedKeys = ref({})
const expandedResultKeys = ref({})
// Auto-collapse all items when a new stream starts
watch(() => props.streaming, (v) => {
if (v) expandedKeys.value = {}
expandedResultKeys.value = {}
})
const processRef = ref(null)
@ -87,6 +105,10 @@ function toggleItem(key) {
expandedKeys.value[key] = !expandedKeys.value[key]
}
function toggleResultExpand(key) {
expandedResultKeys.value[key] = !expandedResultKeys.value[key]
}
function getResultSummary(result) {
try {
const parsed = typeof result === 'string' ? JSON.parse(result) : result
@ -138,7 +160,7 @@ const processItems = computed(() => {
const summary = getResultSummary(step.content)
const match = items.findLast(it => it.type === 'tool_call' && it.id === toolId)
if (match) {
match.result = formatJson(step.content)
Object.assign(match, buildResultFields(step.content))
match.resultSummary = summary.text
match.isSuccess = summary.success
match.loading = false
@ -165,7 +187,8 @@ const processItems = computed(() => {
if (props.toolCalls && props.toolCalls.length > 0) {
props.toolCalls.forEach((call, i) => {
const toolName = call.function?.name || '未知工具'
const result = call.result ? getResultSummary(call.result) : null
const resultSummary = call.result ? getResultSummary(call.result) : null
const resultFields = call.result ? buildResultFields(call.result) : { result: null, resultPreview: null, resultTruncated: false, resultLength: 0 }
items.push({
type: 'tool_call',
toolName,
@ -174,9 +197,9 @@ const processItems = computed(() => {
id: call.id,
key: `tool_call-${call.id || i}`,
loading: !call.result && props.streaming,
result: call.result ? formatJson(call.result) : null,
resultSummary: result ? result.text : null,
isSuccess: result ? result.success : undefined,
...resultFields,
resultSummary: resultSummary ? resultSummary.text : null,
isSuccess: resultSummary ? resultSummary.success : undefined,
})
})
}
@ -345,6 +368,23 @@ watch(() => props.processSteps?.length, () => {
word-break: break-word;
}
.btn-expand-result {
display: inline-block;
margin-top: 6px;
padding: 3px 10px;
font-size: 11px;
color: var(--tool-color);
background: var(--tool-bg);
border: 1px solid var(--tool-border);
border-radius: 4px;
cursor: pointer;
transition: background 0.15s;
}
.btn-expand-result:hover {
background: var(--tool-bg-hover);
}
/* Text content — rendered as markdown */
.text-content {
padding: 0;

View File

@ -123,19 +123,21 @@
<script setup>
import { reactive, ref, watch, onMounted } from 'vue'
import { modelApi, conversationApi } from '../api'
import { conversationApi } from '../api'
import { useTheme } from '../composables/useTheme'
import { icons } from '../utils/icons'
import { SETTINGS_AUTO_SAVE_DEBOUNCE_MS } from '../constants'
const props = defineProps({
visible: { type: Boolean, default: false },
conversation: { type: Object, default: null },
models: { type: Array, default: () => [] },
defaultModel: { type: String, default: '' },
})
const emit = defineEmits(['close', 'save'])
const { isDark, toggleTheme } = useTheme()
const models = ref([])
const tabs = [
{ value: 'basic', label: '基本' },
@ -154,15 +156,6 @@ const form = reactive({
thinking_enabled: false,
})
async function loadModels() {
try {
const res = await modelApi.getCached()
models.value = res.data || []
} catch (e) {
console.error('Failed to load models:', e)
}
}
function syncFormFromConversation() {
if (props.conversation) {
form.title = props.conversation.title || ''
@ -170,29 +163,63 @@ function syncFormFromConversation() {
form.temperature = props.conversation.temperature ?? 1.0
form.max_tokens = props.conversation.max_tokens ?? 65536
form.thinking_enabled = props.conversation.thinking_enabled ?? false
// model: 使 conversation models
// model: 使 conversation defaultModel models
if (props.conversation.model) {
form.model = props.conversation.model
} else if (models.value.length > 0) {
form.model = models.value[0].id
} else if (props.defaultModel) {
form.model = props.defaultModel
} else if (props.models.length > 0) {
form.model = props.models[0].id
}
}
}
// Sync form when panel opens or conversation changes
watch([() => props.visible, () => props.conversation, models], () => {
if (props.visible) {
activeTab.value = 'basic'
syncFormFromConversation()
}
}, { deep: true })
// Track which conversation the form is synced to, to avoid saving stale data
let syncedConvId = null
let isSyncing = false
// Auto-save with debounce when form changes
function doSync() {
if (!props.conversation) return
isSyncing = true
syncFormFromConversation()
syncedConvId = props.conversation.id
// Defer resetting flag to after all watchers flush
setTimeout(() => { isSyncing = false }, 0)
}
// Sync form when panel opens or conversation switches
watch([() => props.visible, () => props.conversation?.id, () => props.models, () => props.defaultModel], () => {
if (props.visible && props.conversation) {
activeTab.value = 'basic'
if (saveTimer) clearTimeout(saveTimer)
saveTimer = null
doSync()
} else if (!props.visible) {
syncedConvId = null
}
})
// Sync when conversation data updates (e.g. auto-generated title after stream)
watch(
() => props.conversation,
(conv) => {
if (!props.visible || !conv || syncedConvId !== conv.id) return
doSync()
},
{ deep: true },
)
// Initial sync on mount (component may be recreated via :key)
onMounted(() => {
if (props.visible && props.conversation) doSync()
})
// Auto-save with debounce when user edits form
let saveTimer = null
watch(form, () => {
if (props.visible && props.conversation) {
if (props.visible && props.conversation && syncedConvId === props.conversation.id && !isSyncing) {
if (saveTimer) clearTimeout(saveTimer)
saveTimer = setTimeout(saveChanges, 500)
saveTimer = setTimeout(saveChanges, SETTINGS_AUTO_SAVE_DEBOUNCE_MS)
}
}, { deep: true })
@ -205,8 +232,6 @@ async function saveChanges() {
console.error('Failed to save settings:', e)
}
}
onMounted(loadModels)
</script>
<style scoped>

View File

@ -113,6 +113,7 @@
import { computed, reactive } from 'vue'
import { formatTime } from '../utils/format'
import { icons } from '../utils/icons'
import { INFINITE_SCROLL_THRESHOLD_PX } from '../constants'
const props = defineProps({
conversations: { type: Array, required: true },
@ -171,7 +172,7 @@ function toggleGroup(id) {
function onScroll(e) {
const el = e.target
if (el.scrollTop + el.clientHeight >= el.scrollHeight - 50) {
if (el.scrollTop + el.clientHeight >= el.scrollHeight - INFINITE_SCROLL_THRESHOLD_PX) {
emit('loadMore')
}
}

View File

@ -60,107 +60,10 @@
</div>
<!-- 趋势图 -->
<div v-if="period !== 'daily' && stats.daily && chartData.length > 0" class="stats-chart">
<div class="chart-title">每日趋势</div>
<div v-if="chartData.length > 0" class="stats-chart">
<div class="chart-title">{{ period === 'daily' ? '今日趋势' : '每日趋势' }}</div>
<div class="chart-container">
<svg class="line-chart" :viewBox="`0 0 ${chartWidth} ${chartHeight}`">
<defs>
<linearGradient id="areaGradient" x1="0%" y1="0%" x2="0%" y2="100%">
<stop offset="0%" :stop-color="accentColor" stop-opacity="0.25"/>
<stop offset="100%" :stop-color="accentColor" stop-opacity="0.02"/>
</linearGradient>
</defs>
<!-- 网格线 -->
<line
v-for="i in 4"
:key="'grid-' + i"
:x1="padding"
:y1="padding + (chartHeight - 2 * padding) * (i - 1) / 3"
:x2="chartWidth - padding"
:y2="padding + (chartHeight - 2 * padding) * (i - 1) / 3"
stroke="var(--border-light)"
stroke-dasharray="3,3"
/>
<!-- Y轴标签 -->
<text
v-for="i in 4"
:key="'yl-' + i"
:x="padding - 4"
:y="padding + (chartHeight - 2 * padding) * (i - 1) / 3 + 3"
text-anchor="end"
class="y-label"
>{{ formatNumber(maxValue - (maxValue * (i - 1)) / 3) }}</text>
<!-- 填充区域 -->
<path :d="areaPath" fill="url(#areaGradient)"/>
<!-- 折线 -->
<path
:d="linePath"
fill="none"
:stroke="accentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
/>
<!-- 数据点 -->
<circle
v-for="(point, idx) in chartPoints"
:key="idx"
:cx="point.x"
:cy="point.y"
r="3"
:fill="accentColor"
stroke="var(--bg-primary)"
stroke-width="2"
class="data-point"
@mouseenter="hoveredPoint = idx"
@mouseleave="hoveredPoint = null"
/>
<!-- 竖线指示 -->
<line
v-if="hoveredPoint !== null && chartPoints[hoveredPoint]"
:x1="chartPoints[hoveredPoint].x"
:y1="padding"
:x2="chartPoints[hoveredPoint].x"
:y2="chartHeight - padding"
stroke="var(--border-medium)"
stroke-dasharray="3,3"
/>
</svg>
<!-- X轴标签 -->
<div class="x-labels">
<span
v-for="(point, idx) in chartPoints"
:key="idx"
class="x-label"
:class="{ active: hoveredPoint === idx }"
>
{{ formatDateLabel(point.date) }}
</span>
</div>
<!-- 悬浮提示 -->
<Transition name="fade">
<div
v-if="hoveredPoint !== null && chartPoints[hoveredPoint]"
class="tooltip"
:style="{
left: chartPoints[hoveredPoint].x + 'px',
top: (chartPoints[hoveredPoint].y - 52) + 'px'
}"
>
<div class="tooltip-date">{{ formatFullDate(chartPoints[hoveredPoint].date) }}</div>
<div class="tooltip-row">
<span class="tooltip-dot prompt"></span>
输入 {{ formatNumber(chartPoints[hoveredPoint].prompt) }}
</div>
<div class="tooltip-row">
<span class="tooltip-dot completion"></span>
输出 {{ formatNumber(chartPoints[hoveredPoint].completion) }}
</div>
<div class="tooltip-total">{{ formatNumber(chartPoints[hoveredPoint].value) }} tokens</div>
</div>
</Transition>
<canvas ref="chartCanvas"></canvas>
</div>
</div>
@ -197,11 +100,14 @@
</template>
<script setup>
import { ref, computed, onMounted } from 'vue'
import { ref, computed, watch, onMounted, onBeforeUnmount, nextTick } from 'vue'
import { Chart, registerables } from 'chart.js'
import { statsApi } from '../api'
import { formatNumber } from '../utils/format'
import { icons } from '../utils/icons'
Chart.register(...registerables)
defineEmits(['close'])
const periods = [
@ -213,15 +119,8 @@ const periods = [
const period = ref('daily')
const stats = ref(null)
const loading = ref(false)
const hoveredPoint = ref(null)
const accentColor = computed(() => {
return getComputedStyle(document.documentElement).getPropertyValue('--accent-primary').trim() || '#2563eb'
})
const chartWidth = 320
const chartHeight = 140
const padding = 32
const chartCanvas = ref(null)
let chartInstance = null
const sortedDaily = computed(() => {
if (!stats.value?.daily) return {}
@ -231,18 +130,46 @@ const sortedDaily = computed(() => {
})
const chartData = computed(() => {
if (period.value === 'daily' && stats.value?.hourly) {
const hourly = stats.value.hourly
// Backend returns UTC hours convert to local timezone for display.
const offset = -new Date().getTimezoneOffset() / 60 // e.g. +8 for UTC+8
const localHourly = {}
for (const [utcH, val] of Object.entries(hourly)) {
const localH = ((parseInt(utcH) + offset) % 24 + 24) % 24
localHourly[localH] = val
}
let minH = 24, maxH = -1
for (const h of Object.keys(localHourly)) {
const hour = parseInt(h)
if (hour < minH) minH = hour
if (hour > maxH) maxH = hour
}
if (minH > maxH) return []
const start = Math.max(0, minH)
const end = Math.min(23, maxH)
return Array.from({ length: end - start + 1 }, (_, i) => {
const h = start + i
return {
label: `${h}:00`,
value: localHourly[String(h)]?.total || 0,
}
})
}
const data = sortedDaily.value
return Object.entries(data).map(([date, val]) => ({
date,
return Object.entries(data).map(([date, val]) => {
// date is "YYYY-MM-DD" from backend parse directly to avoid
// new Date() timezone shift (parsed as UTC midnight then
// getMonth/getDate applies local offset, potentially off by one day).
const [year, month, day] = date.split('-')
return {
label: `${parseInt(month)}/${parseInt(day)}`,
value: val.total,
prompt: val.prompt || 0,
completion: val.completion || 0,
}))
})
const maxValue = computed(() => {
if (chartData.value.length === 0) return 100
return Math.max(100, ...chartData.value.map(d => d.value))
}
})
})
const maxModelTokens = computed(() => {
@ -250,54 +177,136 @@ const maxModelTokens = computed(() => {
return Math.max(1, ...Object.values(stats.value.by_model).map(d => d.total))
})
const chartPoints = computed(() => {
const data = chartData.value
if (data.length === 0) return []
const xRange = chartWidth - 2 * padding
const yRange = chartHeight - 2 * padding
return data.map((d, i) => ({
x: data.length === 1
? chartWidth / 2
: padding + (i / Math.max(1, data.length - 1)) * xRange,
y: chartHeight - padding - (d.value / maxValue.value) * yRange,
date: d.date,
value: d.value,
prompt: d.prompt,
completion: d.completion,
}))
})
const linePath = computed(() => {
const points = chartPoints.value
if (points.length === 0) return ''
return points.map((p, i) => `${i === 0 ? 'M' : 'L'} ${p.x} ${p.y}`).join(' ')
})
const areaPath = computed(() => {
const points = chartPoints.value
if (points.length === 0) return ''
const baseY = chartHeight - padding
let path = `M ${points[0].x} ${baseY} `
path += points.map(p => `L ${p.x} ${p.y}`).join(' ')
path += ` L ${points[points.length - 1].x} ${baseY} Z`
return path
})
function formatDateLabel(dateStr) {
const d = new Date(dateStr)
return `${d.getMonth() + 1}/${d.getDate()}`
function getAccentColor() {
return getComputedStyle(document.documentElement).getPropertyValue('--accent-primary').trim() || '#2563eb'
}
function formatFullDate(dateStr) {
const d = new Date(dateStr)
return `${d.getMonth() + 1}${d.getDate()}`
function getTextColor(alpha = 1) {
const c = getComputedStyle(document.documentElement).getPropertyValue('--text-tertiary').trim() || '#888'
if (alpha === 1) return c
// Convert hex to rgba
if (c.startsWith('#')) {
const r = parseInt(c.slice(1, 3), 16)
const g = parseInt(c.slice(3, 5), 16)
const b = parseInt(c.slice(5, 7), 16)
return `rgba(${r},${g},${b},${alpha})`
}
return c
}
function destroyChart() {
if (chartInstance) {
chartInstance.destroy()
chartInstance = null
}
}
// Build (or rebuild) the Chart.js line chart for the current chartData.
// Destroys any previous instance first; no-op when the canvas is not
// mounted yet or there is nothing to plot.
function buildChart() {
  if (!chartCanvas.value || chartData.value.length === 0) return
  destroyChart()
  const accent = getAccentColor()
  const ctx = chartCanvas.value.getContext('2d')
  // Gradient fill under the line: accent color fading towards the bottom.
  // '40'/'05' are 2-digit hex alpha suffixes appended to the 6-digit accent.
  const gradient = ctx.createLinearGradient(0, 0, 0, 200)
  gradient.addColorStop(0, accent + '40')
  gradient.addColorStop(1, accent + '05')
  const labels = chartData.value.map(d => d.label)
  const values = chartData.value.map(d => d.value)
  // Determine max ticks for x-axis: show every label when there are few
  // points, otherwise cap at 6 to avoid crowding.
  const maxTicks = chartData.value.length <= 8 ? chartData.value.length : 6
  chartInstance = new Chart(ctx, {
    type: 'line',
    data: {
      labels,
      datasets: [{
        data: values,
        borderColor: accent,
        backgroundColor: gradient,
        borderWidth: 2,
        pointRadius: 0, // points invisible at rest, shown on hover only
        pointHoverRadius: 4,
        pointHoverBackgroundColor: accent,
        pointHoverBorderColor: '#fff',
        pointHoverBorderWidth: 2,
        fill: true,
        tension: 0, // straight segments, no curve smoothing
      }],
    },
    options: {
      responsive: true,
      maintainAspectRatio: false,
      animation: { duration: 300 },
      layout: {
        padding: { top: 4, right: 4, bottom: 0, left: 0 },
      },
      scales: {
        x: {
          grid: { display: false },
          border: { display: false },
          ticks: {
            color: getTextColor(),
            font: { size: 10 },
            maxTicksLimit: maxTicks,
            maxRotation: 0,
          },
        },
        y: {
          beginAtZero: true,
          grid: {
            color: getTextColor(0.15), // faint horizontal grid lines
            drawBorder: false,
          },
          border: { display: false },
          ticks: {
            color: getTextColor(),
            font: { size: 9 },
            maxTicksLimit: 4,
            callback: (v) => formatNumber(v),
          },
        },
      },
      plugins: {
        legend: { display: false },
        tooltip: {
          backgroundColor: 'rgba(0,0,0,0.8)',
          titleColor: '#fff',
          bodyColor: '#ccc',
          titleFont: { size: 11, weight: '500' },
          bodyFont: { size: 11 },
          padding: 8,
          cornerRadius: 6,
          displayColors: false,
          callbacks: {
            // In the daily view each label is parsed as an hour and shown
            // as an "HH - HH+1:00" range; other periods show the label as-is.
            title: (items) => {
              const idx = items[0].dataIndex
              const d = chartData.value[idx]
              if (period.value === 'daily') {
                return `${d.label} - ${parseInt(d.label) + 1}:00`
              }
              return d.label
            },
            label: (item) => `${formatNumber(item.raw)} tokens`,
          },
        },
      },
      interaction: {
        mode: 'index', // hovering any x position activates the nearest point
        intersect: false,
      },
    },
  })
}
// Rebuild the chart whenever the aggregated data changes; nextTick ensures
// the canvas element has been (re)rendered before Chart.js touches it.
watch(chartData, () => {
  nextTick(buildChart)
})
async function loadStats() {
loading.value = true
try {
@ -312,11 +321,11 @@ async function loadStats() {
// Switch the stats aggregation period and reload the data for it.
function changePeriod(p) {
  period.value = p
  hoveredPoint.value = null // clear any stale tooltip from the old period
  loadStats()
}
// Fetch stats on mount; dispose of the Chart.js instance on teardown.
onMounted(loadStats)
onBeforeUnmount(destroyChart)
</script>
<style scoped>
@ -324,8 +333,6 @@ onMounted(loadStats)
padding: 0;
}
/* panel-header, panel-title, header-actions now in global.css */
.stats-loading {
display: flex;
align-items: center;
@ -430,99 +437,9 @@ onMounted(loadStats)
background: var(--bg-input);
border: 1px solid var(--border-light);
border-radius: 10px;
padding: 12px 8px 8px 8px;
padding: 10px;
position: relative;
overflow: hidden;
}
.line-chart {
width: 100%;
height: 140px;
}
.y-label {
fill: var(--text-tertiary);
font-size: 9px;
}
.data-point {
cursor: pointer;
transition: r 0.15s;
}
.data-point:hover {
r: 5;
}
.x-labels {
display: flex;
justify-content: space-between;
margin-top: 6px;
padding: 0 28px 0 32px;
}
.x-label {
font-size: 10px;
color: var(--text-tertiary);
transition: color 0.15s;
}
.x-label.active {
color: var(--text-primary);
font-weight: 500;
}
/* Tooltip */
.tooltip {
position: absolute;
background: var(--bg-primary);
border: 1px solid var(--border-medium);
padding: 8px 10px;
border-radius: 8px;
font-size: 11px;
pointer-events: none;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.12);
transform: translateX(-50%);
z-index: 10;
min-width: 120px;
}
.tooltip-date {
color: var(--text-tertiary);
font-size: 10px;
margin-bottom: 4px;
}
.tooltip-row {
display: flex;
align-items: center;
gap: 4px;
font-size: 11px;
color: var(--text-secondary);
}
.tooltip-dot {
width: 6px;
height: 6px;
border-radius: 50%;
flex-shrink: 0;
}
.tooltip-dot.prompt {
background: #3b82f6;
}
.tooltip-dot.completion {
background: #a855f7;
}
.tooltip-total {
margin-top: 4px;
padding-top: 4px;
border-top: 1px solid var(--border-light);
font-weight: 600;
color: var(--text-primary);
font-size: 12px;
height: 180px;
}
/* Model distribution */

View File

@ -1,5 +1,6 @@
import { watch, onMounted, nextTick, onUnmounted } from 'vue'
import { enhanceCodeBlocks } from '../utils/markdown'
import { CODE_ENHANCE_DEBOUNCE_MS } from '../constants'
/**
* Composable for enhancing code blocks in a container element.
@ -18,7 +19,7 @@ export function useCodeEnhancement(templateRef, dep, watchOpts) {
function debouncedEnhance() {
if (debounceTimer) clearTimeout(debounceTimer)
debounceTimer = setTimeout(() => nextTick(enhance), 150)
debounceTimer = setTimeout(() => nextTick(enhance), CODE_ENHANCE_DEBOUNCE_MS)
}
onMounted(enhance)

View File

@ -1,4 +1,5 @@
import { reactive } from 'vue'
import { TOAST_DEFAULT_DURATION } from '../constants'
const state = reactive({
toasts: [],
@ -6,7 +7,7 @@ const state = reactive({
})
export function useToast() {
function add(type, message, duration = 1500) {
function add(type, message, duration = TOAST_DEFAULT_DURATION) {
const id = ++state._id
state.toasts.push({ id, type, message })
setTimeout(() => {

37
frontend/src/constants.js Normal file
View File

@ -0,0 +1,37 @@
/**
* Frontend constants
*/
// === Tool Result Display ===
/** Max characters shown in tool result preview before truncation */
export const RESULT_PREVIEW_LIMIT = 2048
// === API ===
export const API_BASE_URL = '/api'
export const CONTENT_TYPE_JSON = 'application/json'
// === Pagination ===
export const DEFAULT_CONVERSATION_PAGE_SIZE = 20
export const DEFAULT_MESSAGE_PAGE_SIZE = 50
export const DEFAULT_PROJECT_PAGE_SIZE = 20
// === Timers (ms) ===
export const TOAST_DEFAULT_DURATION = 1500
export const CODE_ENHANCE_DEBOUNCE_MS = 150
export const SETTINGS_AUTO_SAVE_DEBOUNCE_MS = 500
export const COPY_BUTTON_RESET_MS = 1500
// === Truncation ===
export const DEFAULT_TRUNCATE_LENGTH = 60
// === UI Limits ===
export const TEXTAREA_MAX_HEIGHT_PX = 200
export const INFINITE_SCROLL_THRESHOLD_PX = 50
// === LocalStorage Keys ===
export const LS_KEY_THEME = 'theme'
export const LS_KEY_TOOLS_ENABLED = 'tools_enabled'
export const LS_KEY_MODELS_CACHE = 'models_cache'
// === File Upload ===
export const ALLOWED_UPLOAD_EXTENSIONS = '.txt,.md,.json,.xml,.html,.css,.js,.ts,.jsx,.tsx,.py,.java,.c,.cpp,.h,.hpp,.yaml,.yml,.toml,.ini,.csv,.sql,.sh,.bat,.log,.vue,.svelte,.go,.rs,.rb,.php,.swift,.kt,.scala,.lua,.r,.dart'

View File

@ -3,9 +3,10 @@ import App from './App.vue'
import './styles/global.css'
import './styles/highlight.css'
import 'katex/dist/katex.min.css'
import { LS_KEY_THEME } from './constants'
// Initialize theme before app mounts to avoid flash when lazy-loading useTheme
const savedTheme = localStorage.getItem('theme')
const savedTheme = localStorage.getItem(LS_KEY_THEME)
if (savedTheme === 'dark' || savedTheme === 'light') {
document.documentElement.setAttribute('data-theme', savedTheme)
}

View File

@ -1,3 +1,5 @@
import { DEFAULT_TRUNCATE_LENGTH } from '../constants'
/**
* Format ISO date string to a short time string.
* - Today: "14:30"
@ -36,7 +38,7 @@ export function formatJson(value) {
/**
* Truncate text to max characters with ellipsis.
*/
export function truncate(text, max = 60) {
export function truncate(text, max = DEFAULT_TRUNCATE_LENGTH) {
if (!text) return ''
const str = text.replace(/\s+/g, ' ').trim()
return str.length > max ? str.slice(0, max) + '\u2026' : str

View File

@ -2,6 +2,7 @@ import { marked } from 'marked'
import { markedHighlight } from 'marked-highlight'
import katex from 'katex'
import { highlightCode } from './highlight'
import { COPY_BUTTON_RESET_MS } from '../constants'
function renderMath(text, displayMode) {
try {
@ -108,7 +109,7 @@ export function enhanceCodeBlocks(container) {
copyBtn.addEventListener('click', () => {
const raw = code?.textContent || ''
const copy = () => { copyBtn.innerHTML = CHECK_SVG; setTimeout(() => { copyBtn.innerHTML = COPY_SVG }, 1500) }
const copy = () => { copyBtn.innerHTML = CHECK_SVG; setTimeout(() => { copyBtn.innerHTML = COPY_SVG }, COPY_BUTTON_RESET_MS) }
if (navigator.clipboard) {
navigator.clipboard.writeText(raw).then(copy)
} else {

View File

@ -8,17 +8,25 @@ import { dirname, resolve } from 'path'
const __filename = fileURLToPath(import.meta.url)
const __dirname = dirname(__filename)
// Load optional project config; fall back to defaults when the file is
// missing or unparseable so a fresh checkout still runs.
const configPath = resolve(__dirname, '..', 'config.yml');
let config = {};
try {
  // yaml.load returns undefined for an empty file — normalize to {} so
  // the property accesses below never throw.
  config = yaml.load(fs.readFileSync(configPath, 'utf-8')) || {};
} catch (e) {
  // Report the actual failure (missing file OR parse error) instead of
  // unconditionally claiming the file was not found.
  console.warn(`Could not load config at ${configPath} (${e.message}); using defaults.`);
  config = {};
}

// Env vars take precedence over config.yml, then hard-coded defaults.
const frontend_port = process.env.VITE_FRONTEND_PORT || config.frontend_port || 4000;
const backend_port = process.env.VITE_BACKEND_PORT || config.backend_port || 3000;
export default defineConfig({
plugins: [vue()],
server: {
port: config.frontend_port,
port: frontend_port,
proxy: {
'/api': {
target: `http://localhost:${config.backend_port}`,
target: `http://localhost:${backend_port}`,
changeOrigin: true,
},
},

View File

@ -24,3 +24,13 @@ build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
include = ["backend*"]
[project.optional-dependencies]
test = [
"pytest>=7.0",
"pytest-flask>=1.2",
"pytest-cov>=4.0",
"pytest-mock>=3.0",
"requests-mock>=1.10",
"httpx>=0.25",
]

0
tests/__init__.py Normal file
View File

79
tests/conftest.py Normal file
View File

@ -0,0 +1,79 @@
import pytest
import tempfile
import os
from pathlib import Path
from backend import create_app, db as _db
@pytest.fixture(scope='session')
def app():
    """Create a Flask app configured for testing.

    Uses a temporary on-disk SQLite database (one per test session) and
    keeps an application context pushed for the whole session.

    NOTE(review): the config override happens *after* create_app(); if
    create_app() already bound SQLAlchemy to the production URI, this
    override may not take effect — confirm the backend reads
    SQLALCHEMY_DATABASE_URI lazily.
    """
    # Temporary SQLite database file; removed in the teardown below.
    db_fd, db_path = tempfile.mkstemp(suffix='.db')

    # Minimal test configuration consumed via config.from_object().
    class TestConfig:
        SQLALCHEMY_DATABASE_URI = f'sqlite:///{db_path}'
        SQLALCHEMY_TRACK_MODIFICATIONS = False
        TESTING = True
        SECRET_KEY = 'test-secret-key'
        AUTH_CONFIG = {
            'mode': 'single',
            'jwt_secret': 'test-jwt-secret',
            'jwt_expiry': 3600,
        }

    app = create_app()
    app.config.from_object(TestConfig)

    # Push an application context so DB work is possible outside requests.
    ctx = app.app_context()
    ctx.push()

    yield app

    # Teardown: pop the context and remove the temp database file.
    ctx.pop()
    os.close(db_fd)
    os.unlink(db_path)
@pytest.fixture(scope='session')
def db(app):
    """Create all database tables once per session; drop them at the end."""
    _db.create_all()
    yield _db
    _db.drop_all()
@pytest.fixture(scope='function')
def session(db):
    """Create a new database session for a test.

    Each test runs inside an outer transaction on a dedicated connection;
    the transaction is rolled back afterwards so tests cannot leak data
    into one another.

    NOTE(review): db.session is replaced but never restored to the
    original session object — this relies on every test using this
    fixture; confirm that is intended.
    """
    connection = db.engine.connect()
    transaction = connection.begin()

    # Scoped session bound to this connection so all model queries in the
    # test participate in the outer (rolled-back) transaction.
    from sqlalchemy.orm import scoped_session, sessionmaker
    session_factory = sessionmaker(bind=connection)
    session = scoped_session(session_factory)
    db.session = session

    yield session

    # Roll back everything the test did and release the connection.
    transaction.rollback()
    connection.close()
    session.remove()
@pytest.fixture
def client(app):
    """Flask test client for issuing requests against the test app."""
    test_client = app.test_client()
    return test_client
@pytest.fixture
def runner(app):
    """CLI runner bound to the test app, for exercising Flask CLI commands."""
    cli_runner = app.test_cli_runner()
    return cli_runner

80
tests/test_auth.py Normal file
View File

@ -0,0 +1,80 @@
import pytest
import json
from backend.models import User
def test_auth_mode(client, session):
    """GET /api/auth/mode reports the configured auth mode (single by default)."""
    resp = client.get('/api/auth/mode')
    assert resp.status_code == 200

    payload = json.loads(resp.data)
    # Standard response envelope: {code, data}.
    assert 'code' in payload
    assert 'data' in payload
    assert payload['data']['mode'] == 'single'
def test_login_single_mode(client, session):
    """Test login in single-user mode."""
    # Ensure default user exists (should be created by auth middleware);
    # create it here so the test does not depend on middleware ordering.
    user = User.query.filter_by(username='default').first()
    if not user:
        user = User(username='default')
        session.add(user)
        session.commit()

    resp = client.post('/api/auth/login', json={
        'username': 'default',
        'password': ''  # no password in single mode
    })
    assert resp.status_code == 200
    data = json.loads(resp.data)
    assert data['code'] == 0
    # A successful login returns a token plus the serialized user.
    assert 'token' in data['data']
    assert 'user' in data['data']
    assert data['data']['user']['username'] == 'default'
def test_profile(client, session):
    """GET /api/auth/profile returns the default user; no token in single mode."""
    resp = client.get('/api/auth/profile')
    assert resp.status_code == 200

    payload = json.loads(resp.data)
    assert payload['code'] == 0
    assert payload['data']['username'] == 'default'
def test_profile_update(client, session):
    """Test updating profile."""
    resp = client.patch('/api/auth/profile', json={
        'email': 'default@example.com',
        'avatar': 'https://example.com/avatar.png'
    })
    assert resp.status_code == 200
    data = json.loads(resp.data)
    assert data['code'] == 0

    # Verify the change was persisted, not merely echoed back in the response.
    user = User.query.filter_by(username='default').first()
    assert user.email == 'default@example.com'
    assert user.avatar == 'https://example.com/avatar.png'
def test_register_not_allowed_in_single_mode(client, session):
    """Registration should fail in single-user mode."""
    resp = client.post('/api/auth/register', json={
        'username': 'newuser',
        'password': 'password'
    })
    # Expect error (maybe 400 or 403).
    # The actual behavior may vary; we'll just ensure it's not a success —
    # only the envelope's non-zero error code is asserted, not the HTTP
    # status, since the backend's exact status choice is unspecified here.
    data = json.loads(resp.data)
    assert data['code'] != 0
# Multi-user mode tests (requires switching config)
# We'll skip for now because it's more complex.
# Allow running this module directly: `python tests/test_auth.py`.
if __name__ == '__main__':
    pytest.main(['-v', __file__])

253
tests/test_conversations.py Normal file
View File

@ -0,0 +1,253 @@
import pytest
import json
from backend.models import User, Conversation, Message
def test_list_conversations(client, session):
    """Test GET /api/conversations."""
    # Get-or-create the default user (auth middleware may have made it).
    user = User.query.filter_by(username='default').first()
    if not user:
        user = User(username='default')
        session.add(user)
        session.commit()

    # Seed one conversation owned by that user.
    conv = Conversation(
        id='conv-1',
        user_id=user.id,
        title='Test Conversation',
        model='deepseek-chat'
    )
    session.add(conv)
    session.commit()

    resp = client.get('/api/conversations')
    assert resp.status_code == 200
    data = json.loads(resp.data)
    assert data['code'] == 0
    items = data['data']['items']
    # Should have at least one conversation
    assert len(items) >= 1
    # Find our conversation (other tests may have created more).
    found = any(item['id'] == 'conv-1' for item in items)
    assert found is True
def test_create_conversation(client, session):
    """Test POST /api/conversations."""
    # Get-or-create the default user that will own the conversation.
    user = User.query.filter_by(username='default').first()
    if not user:
        user = User(username='default')
        session.add(user)
        session.commit()

    resp = client.post('/api/conversations', json={
        'title': 'New Conversation',
        'model': 'glm-5',
        'system_prompt': 'You are helpful.',
        'temperature': 0.7,
        'max_tokens': 4096,
        'thinking_enabled': True
    })
    assert resp.status_code == 200
    data = json.loads(resp.data)
    assert data['code'] == 0
    # The response should echo back every field we sent.
    conv_data = data['data']
    assert conv_data['title'] == 'New Conversation'
    assert conv_data['model'] == 'glm-5'
    assert conv_data['system_prompt'] == 'You are helpful.'
    assert conv_data['temperature'] == 0.7
    assert conv_data['max_tokens'] == 4096
    assert conv_data['thinking_enabled'] is True

    # Verify database: row exists and is attributed to the default user.
    conv = Conversation.query.filter_by(id=conv_data['id']).first()
    assert conv is not None
    assert conv.user_id == user.id
def test_get_conversation(client, session):
    """GET /api/conversations/:id returns the requested conversation."""
    # Get-or-create the default user that owns the seeded conversation.
    owner = User.query.filter_by(username='default').first()
    if owner is None:
        owner = User(username='default')
        session.add(owner)
        session.commit()

    seeded = Conversation(
        id='conv-2',
        user_id=owner.id,
        title='Test Get',
        model='deepseek-chat'
    )
    session.add(seeded)
    session.commit()

    resp = client.get(f'/api/conversations/{seeded.id}')
    assert resp.status_code == 200

    payload = json.loads(resp.data)
    assert payload['code'] == 0
    assert payload['data']['id'] == 'conv-2'
    assert payload['data']['title'] == 'Test Get'
def test_update_conversation(client, session):
    """Test PATCH /api/conversations/:id."""
    # Get-or-create the default user.
    user = User.query.filter_by(username='default').first()
    if not user:
        user = User(username='default')
        session.add(user)
        session.commit()

    conv = Conversation(
        id='conv-3',
        user_id=user.id,
        title='Original',
        model='deepseek-chat'
    )
    session.add(conv)
    session.commit()

    # Partial update: only title and temperature change.
    resp = client.patch(f'/api/conversations/{conv.id}', json={
        'title': 'Updated Title',
        'temperature': 0.9
    })
    assert resp.status_code == 200
    data = json.loads(resp.data)
    assert data['code'] == 0

    # Verify update — refresh pulls the latest row state into the ORM object.
    session.refresh(conv)
    assert conv.title == 'Updated Title'
    assert conv.temperature == 0.9
def test_delete_conversation(client, session):
    """Test DELETE /api/conversations/:id."""
    # Get-or-create the default user.
    user = User.query.filter_by(username='default').first()
    if not user:
        user = User(username='default')
        session.add(user)
        session.commit()

    conv = Conversation(
        id='conv-4',
        user_id=user.id,
        title='To Delete',
        model='deepseek-chat'
    )
    session.add(conv)
    session.commit()

    resp = client.delete(f'/api/conversations/{conv.id}')
    assert resp.status_code == 200
    data = json.loads(resp.data)
    assert data['code'] == 0

    # Should be gone.
    # NOTE(review): Query.get() is legacy in SQLAlchemy 2.x — consider
    # session.get(Conversation, ...) if/when the backend migrates.
    deleted = Conversation.query.get(conv.id)
    assert deleted is None
def test_list_messages(client, session):
    """Test GET /api/conversations/:id/messages."""
    # Get-or-create the default user.
    user = User.query.filter_by(username='default').first()
    if not user:
        user = User(username='default')
        session.add(user)
        session.commit()

    conv = Conversation(
        id='conv-5',
        user_id=user.id,
        title='Messages Test',
        model='deepseek-chat'
    )
    session.add(conv)
    session.commit()

    # Create one user message and one assistant message in the conversation.
    msg1 = Message(id='msg-1', conversation_id=conv.id, role='user', content='Hello')
    msg2 = Message(id='msg-2', conversation_id=conv.id, role='assistant', content='Hi')
    session.add_all([msg1, msg2])
    session.commit()

    resp = client.get(f'/api/conversations/{conv.id}/messages')
    assert resp.status_code == 200
    data = json.loads(resp.data)
    assert data['code'] == 0
    messages = data['data']['items']
    assert len(messages) == 2
    # Order is not asserted — only that both roles are present.
    roles = {m['role'] for m in messages}
    assert 'user' in roles
    assert 'assistant' in roles
@pytest.mark.skip(reason="SSE endpoint requires streaming")
def test_send_message(client, session):
    """Test POST /api/conversations/:id/messages (non-streaming).

    Skipped: the endpoint streams Server-Sent Events and would need the
    LLM call mocked; kept as a template for a future streaming test.
    """
    # Get-or-create the default user.
    user = User.query.filter_by(username='default').first()
    if not user:
        user = User(username='default')
        session.add(user)
        session.commit()

    conv = Conversation(
        id='conv-6',
        user_id=user.id,
        title='Send Test',
        model='deepseek-chat',
        thinking_enabled=False
    )
    session.add(conv)
    session.commit()

    # This endpoint expects streaming (SSE) but we can test with a simple request.
    # However, the endpoint may return a streaming response; we'll just test that it accepts request.
    # We'll mock the LLM call? Instead, we'll skip because it's complex.
    # For simplicity, we'll just test that the endpoint exists and returns something.
    resp = client.post(f'/api/conversations/{conv.id}/messages', json={
        'content': 'Hello',
        'role': 'user'
    })
    # The endpoint returns a streaming response (text/event-stream) with status 200.
    # The client will see a stream; we'll just check status code.
    # It might be 200 or 400 if missing parameters.
    # We'll accept any 2xx status.
    assert resp.status_code in (200, 201, 204)
def test_delete_message(client, session):
    """Test DELETE /api/conversations/:id/messages/:mid."""
    # Get-or-create the default user.
    user = User.query.filter_by(username='default').first()
    if not user:
        user = User(username='default')
        session.add(user)
        session.commit()

    conv = Conversation(
        id='conv-7',
        user_id=user.id,
        title='Delete Msg',
        model='deepseek-chat'
    )
    session.add(conv)
    session.commit()

    msg = Message(id='msg-del', conversation_id=conv.id, role='user', content='Delete me')
    session.add(msg)
    session.commit()

    resp = client.delete(f'/api/conversations/{conv.id}/messages/{msg.id}')
    assert resp.status_code == 200
    data = json.loads(resp.data)
    assert data['code'] == 0

    # Should be gone.
    # NOTE(review): Query.get() is legacy in SQLAlchemy 2.x.
    deleted = Message.query.get(msg.id)
    assert deleted is None
# Allow running this module directly: `python tests/test_conversations.py`.
if __name__ == '__main__':
    pytest.main(['-v', __file__])

209
tests/test_models.py Normal file
View File

@ -0,0 +1,209 @@
import pytest
from backend.models import User, Conversation, Message, TokenUsage, Project
from datetime import datetime, timezone
def test_user_create(session):
    """Creating a user persists it with sane model defaults."""
    new_user = User(username='testuser', email='test@example.com')
    session.add(new_user)
    session.commit()

    # Identity and explicitly-set fields.
    assert new_user.id is not None
    assert new_user.username == 'testuser'
    assert new_user.email == 'test@example.com'
    # Defaults applied by the model.
    assert new_user.role == 'user'
    assert new_user.is_active is True
    assert new_user.created_at is not None
    assert new_user.last_login_at is None
def test_user_password_hashing(session):
    """Test password hashing and verification."""
    user = User(username='testuser')
    user.password = 'securepassword'
    session.add(user)
    session.commit()

    # Password hash should be set and must never equal the plaintext.
    assert user.password_hash is not None
    assert user.password_hash != 'securepassword'

    # Check password round-trips through the hash.
    assert user.check_password('securepassword') is True
    assert user.check_password('wrongpassword') is False

    # Setting password to None clears hash.
    user.password = None
    assert user.password_hash is None
def test_user_to_dict(session):
    """Serialization exposes the public fields and hides the password hash."""
    admin = User(username='testuser', email='test@example.com', role='admin')
    session.add(admin)
    session.commit()

    serialized = admin.to_dict()
    assert serialized['username'] == 'testuser'
    assert serialized['email'] == 'test@example.com'
    assert serialized['role'] == 'admin'
    # Sensitive material must never be serialized.
    assert 'password_hash' not in serialized
    assert 'created_at' in serialized
def test_conversation_create(session):
    """Test creating a conversation."""
    user = User(username='user1')
    session.add(user)
    session.commit()

    # All optional tuning fields set explicitly so each is asserted below.
    conv = Conversation(
        id='conv-123',
        user_id=user.id,
        title='Test Conversation',
        model='deepseek-chat',
        system_prompt='You are a helpful assistant.',
        temperature=0.8,
        max_tokens=2048,
        thinking_enabled=True,
    )
    session.add(conv)
    session.commit()

    assert conv.id == 'conv-123'
    assert conv.user_id == user.id
    assert conv.title == 'Test Conversation'
    assert conv.model == 'deepseek-chat'
    assert conv.system_prompt == 'You are a helpful assistant.'
    assert conv.temperature == 0.8
    assert conv.max_tokens == 2048
    assert conv.thinking_enabled is True
    # Timestamps are populated by the model on insert.
    assert conv.created_at is not None
    assert conv.updated_at is not None
    # Relationship back to the owning user.
    assert conv.user == user
def test_conversation_relationships(session):
    """Test conversation relationships with messages."""
    user = User(username='user1')
    session.add(user)
    session.commit()

    conv = Conversation(id='conv-123', user_id=user.id, title='Test')
    session.add(conv)
    session.commit()

    # Create messages attached to the conversation.
    msg1 = Message(id='msg-1', conversation_id=conv.id, role='user', content='Hello')
    msg2 = Message(id='msg-2', conversation_id=conv.id, role='assistant', content='Hi')
    session.add_all([msg1, msg2])
    session.commit()

    # Test relationship in both directions; .count() implies a dynamic
    # (query-returning) relationship on Conversation.messages.
    assert conv.messages.count() == 2
    assert list(conv.messages) == [msg1, msg2]
    assert msg1.conversation == conv
def test_message_create(session):
    """Test creating a message."""
    user = User(username='user1')
    session.add(user)
    session.commit()

    conv = Conversation(id='conv-123', user_id=user.id, title='Test')
    session.add(conv)
    session.commit()

    # Content is stored as a raw JSON string, not a parsed structure.
    msg = Message(
        id='msg-1',
        conversation_id=conv.id,
        role='user',
        content='{"text": "Hello world"}',
    )
    session.add(msg)
    session.commit()

    assert msg.id == 'msg-1'
    assert msg.conversation_id == conv.id
    assert msg.role == 'user'
    assert msg.content == '{"text": "Hello world"}'
    assert msg.created_at is not None
    assert msg.conversation == conv
def test_message_to_dict(session):
    """Test message serialization."""
    from backend.utils.helpers import message_to_dict

    # Note: no parent Conversation row is created here — serialization is
    # exercised on the Message alone.
    msg = Message(
        id='msg-1',
        conversation_id='conv-123',
        role='user',
        content='{"text": "Hello", "attachments": [{"name": "file.txt"}]}',
    )
    session.add(msg)
    session.commit()

    # The helper is expected to parse the JSON content into top-level fields.
    data = message_to_dict(msg)
    assert data['id'] == 'msg-1'
    assert data['role'] == 'user'
    assert data['text'] == 'Hello'
    assert 'attachments' in data
    assert data['attachments'][0]['name'] == 'file.txt'
def test_token_usage_create(session):
    """Test token usage recording."""
    user = User(username='user1')
    session.add(user)
    session.commit()

    # One aggregated usage row per (user, model, date).
    usage = TokenUsage(
        user_id=user.id,
        model='deepseek-chat',
        date=datetime.now(timezone.utc).date(),
        prompt_tokens=100,
        completion_tokens=200,
        total_tokens=300,
    )
    session.add(usage)
    session.commit()

    assert usage.id is not None
    assert usage.user_id == user.id
    assert usage.model == 'deepseek-chat'
    assert usage.prompt_tokens == 100
    assert usage.total_tokens == 300
def test_project_create(session):
    """Test project creation."""
    user = User(username='user1')
    session.add(user)
    session.commit()

    project = Project(
        id='proj-123',
        user_id=user.id,
        name='My Project',
        path='user_1/my_project',  # workspace-relative path
        description='A test project',
    )
    session.add(project)
    session.commit()

    assert project.id == 'proj-123'
    assert project.user_id == user.id
    assert project.name == 'My Project'
    assert project.path == 'user_1/my_project'
    assert project.description == 'A test project'
    assert project.created_at is not None
    assert project.updated_at is not None
    assert project.user == user
    # A fresh project has no conversations linked yet.
    assert project.conversations.count() == 0
# Allow running this module directly: `python tests/test_models.py`.
if __name__ == '__main__':
    pytest.main(['-v', __file__])

342
tests/test_tools.py Normal file
View File

@ -0,0 +1,342 @@
import pytest
from backend.tools.core import ToolRegistry, ToolDefinition, ToolResult
from backend.tools.executor import ToolExecutor
import json
def test_tool_definition():
    """Test ToolDefinition creation and serialization."""
    def dummy_handler(args):
        return args.get('value', 0)

    tool = ToolDefinition(
        name='test_tool',
        description='A test tool',
        parameters={
            'type': 'object',
            'properties': {'value': {'type': 'integer'}}
        },
        handler=dummy_handler,
        category='test'
    )

    # Constructor arguments are stored verbatim.
    assert tool.name == 'test_tool'
    assert tool.description == 'A test tool'
    assert tool.category == 'test'
    assert tool.handler == dummy_handler

    # Test OpenAI format conversion ({"type": "function", "function": {...}}).
    openai_format = tool.to_openai_format()
    assert openai_format['type'] == 'function'
    assert openai_format['function']['name'] == 'test_tool'
    assert openai_format['function']['description'] == 'A test tool'
    assert 'parameters' in openai_format['function']
def test_tool_result():
    """ToolResult.ok/.fail set success, data and error consistently."""
    ok_result = ToolResult.ok(data='success')
    assert ok_result.success is True
    assert ok_result.data == 'success'
    assert ok_result.error is None

    fail_result = ToolResult.fail(error='something went wrong')
    assert fail_result.success is False
    assert fail_result.error == 'something went wrong'
    assert fail_result.data is None

    # to_dict() mirrors the attributes for both variants.
    ok_dict = ok_result.to_dict()
    assert ok_dict['success'] is True
    assert ok_dict['data'] == 'success'

    fail_dict = fail_result.to_dict()
    assert fail_dict['success'] is False
    assert fail_dict['error'] == 'something went wrong'
def test_tool_registry():
    """Test ToolRegistry registration and lookup."""
    registry = ToolRegistry()

    # Count existing tools — the registry is a singleton, so built-in tools
    # (and tools registered by earlier tests) may already be present.
    initial_tools = registry.list_all()
    initial_count = len(initial_tools)

    # Register a tool
    def add_handler(args):
        return args.get('a', 0) + args.get('b', 0)

    tool = ToolDefinition(
        name='add',
        description='Add two numbers',
        parameters={
            'type': 'object',
            'properties': {
                'a': {'type': 'number'},
                'b': {'type': 'number'}
            },
            'required': ['a', 'b']
        },
        handler=add_handler,
        category='math'
    )
    registry.register(tool)

    # Should be able to get it
    retrieved = registry.get('add')
    assert retrieved is not None
    assert retrieved.name == 'add'
    assert retrieved.handler == add_handler

    # List all returns OpenAI format
    tools_list = registry.list_all()
    assert len(tools_list) == initial_count + 1
    # Ensure our tool is present
    tool_names = [t['function']['name'] for t in tools_list]
    assert 'add' in tool_names

    # Execute tool — returns a plain dict (ToolResult serialized).
    result = registry.execute('add', {'a': 5, 'b': 3})
    assert result['success'] is True
    assert result['data'] == 8

    # Execute non-existent tool
    result = registry.execute('nonexistent', {})
    assert result['success'] is False
    assert 'Tool not found' in result['error']

    # Execute with exception — handler errors are caught and reported,
    # not propagated to the caller.
    def faulty_handler(args):
        raise ValueError('Intentional error')

    faulty_tool = ToolDefinition(
        name='faulty',
        description='Faulty tool',
        parameters={'type': 'object'},
        handler=faulty_handler
    )
    registry.register(faulty_tool)
    result = registry.execute('faulty', {})
    assert result['success'] is False
    assert 'Intentional error' in result['error']
def test_tool_registry_singleton():
    """Test that ToolRegistry is a singleton."""
    registry1 = ToolRegistry()
    registry2 = ToolRegistry()
    # Identity, not just equality: both constructors return the same object.
    assert registry1 is registry2

    # Register in one, should appear in the other
    def dummy(args):
        return 42

    tool = ToolDefinition(
        name='singleton_test',
        description='Test',
        parameters={'type': 'object'},
        handler=dummy
    )
    registry1.register(tool)
    assert registry2.get('singleton_test') is not None
def test_tool_executor_basic():
    """Test ToolExecutor basic execution."""
    registry = ToolRegistry()
    # Clear any previous tools (singleton may have state from other tests).
    # We'll create a fresh registry by directly manipulating the singleton's
    # internal dict. This is a bit hacky but works for testing.
    registry._tools.clear()

    def echo_handler(args):
        return args.get('message', '')

    tool = ToolDefinition(
        name='echo',
        description='Echo message',
        parameters={
            'type': 'object',
            'properties': {'message': {'type': 'string'}}
        },
        handler=echo_handler
    )
    registry.register(tool)

    executor = ToolExecutor(registry=registry, enable_cache=False)

    # Simulate a tool call in the OpenAI tool-call wire format:
    # arguments arrive as a JSON-encoded string, not a dict.
    call = {
        'id': 'call_1',
        'function': {
            'name': 'echo',
            'arguments': json.dumps({'message': 'Hello'})
        }
    }
    messages = executor.process_tool_calls([call], context=None)
    assert len(messages) == 1
    # Each result is a chat message of role 'tool' echoing the call id/name.
    msg = messages[0]
    assert msg['role'] == 'tool'
    assert msg['tool_call_id'] == 'call_1'
    assert msg['name'] == 'echo'
    content = json.loads(msg['content'])
    assert content['success'] is True
    assert content['data'] == 'Hello'
def test_tool_executor_cache():
    """Test caching behavior."""
    registry = ToolRegistry()
    # Reset singleton state left over from other tests.
    registry._tools.clear()

    call_count = 0

    # Counts invocations so cache hits are observable.
    def counter_handler(args):
        nonlocal call_count
        call_count += 1
        return call_count

    tool = ToolDefinition(
        name='counter',
        description='Count calls',
        parameters={'type': 'object'},
        handler=counter_handler
    )
    registry.register(tool)

    executor = ToolExecutor(registry=registry, enable_cache=True, cache_ttl=10)

    call = {
        'id': 'call_1',
        'function': {
            'name': 'counter',
            'arguments': '{}'
        }
    }
    # First call should execute
    messages1 = executor.process_tool_calls([call], context=None)
    assert len(messages1) == 1
    content1 = json.loads(messages1[0]['content'])
    assert content1['data'] == 1
    assert call_count == 1

    # Second identical call should be cached
    messages2 = executor.process_tool_calls([call], context=None)
    assert len(messages2) == 1
    content2 = json.loads(messages2[0]['content'])
    # data should still be 1 (cached)
    assert content2['data'] == 1
    # handler not called again
    assert call_count == 1

    # Different call (different arguments) should execute — the cache key
    # evidently includes the argument payload, not just the tool name.
    call2 = {
        'id': 'call_2',
        'function': {
            'name': 'counter',
            'arguments': json.dumps({'different': True})
        }
    }
    messages3 = executor.process_tool_calls([call2], context=None)
    content3 = json.loads(messages3[0]['content'])
    assert content3['data'] == 2
    assert call_count == 2
def test_tool_executor_context_injection():
    """Test that context fields are injected into arguments."""
    registry = ToolRegistry()
    # Reset singleton state left over from other tests.
    registry._tools.clear()

    captured_args = None

    # Captures the arguments the executor actually passes to the handler.
    def capture_handler(args):
        nonlocal captured_args
        captured_args = args.copy()
        return 'ok'

    tool = ToolDefinition(
        name='file_read',
        description='Read file',
        parameters={'type': 'object'},
        handler=capture_handler
    )
    registry.register(tool)

    executor = ToolExecutor(registry=registry)

    call = {
        'id': 'call_1',
        'function': {
            'name': 'file_read',
            'arguments': json.dumps({'path': 'test.txt'})
        }
    }
    context = {'project_id': 'proj-123'}
    executor.process_tool_calls([call], context=context)

    # Check that project_id was injected alongside the caller-supplied args.
    assert captured_args is not None
    assert captured_args['project_id'] == 'proj-123'
    assert captured_args['path'] == 'test.txt'
def test_tool_executor_deduplication():
    """Test deduplication of identical calls within a session."""
    registry = ToolRegistry()
    # Reset singleton state left over from other tests.
    registry._tools.clear()

    call_count = 0

    def count_handler(args):
        nonlocal call_count
        call_count += 1
        return call_count

    tool = ToolDefinition(
        name='count',
        description='Count',
        parameters={'type': 'object'},
        handler=count_handler
    )
    registry.register(tool)

    # Cache disabled: deduplication here is within a single batch, distinct
    # from the TTL cache exercised in test_tool_executor_cache.
    executor = ToolExecutor(registry=registry, enable_cache=False)

    # Two calls with different ids but identical name+arguments.
    call = {
        'id': 'call_1',
        'function': {
            'name': 'count',
            'arguments': '{}'
        }
    }
    call_same = {
        'id': 'call_2',
        'function': {
            'name': 'count',
            'arguments': '{}'
        }
    }
    # Execute both calls in one batch
    messages = executor.process_tool_calls([call, call_same], context=None)
    # Should deduplicate: second call returns cached result from first call.
    # Let's verify that call_count is 1 (only one actual execution).
    assert call_count == 1
    # Both messages should have success=True
    assert len(messages) == 2
    content0 = json.loads(messages[0]['content'])
    content1 = json.loads(messages[1]['content'])
    assert content0['success'] is True
    assert content1['success'] is True
    # Data could be 1 for both (duplicate may have data None)
    assert content0['data'] == 1
    # duplicate call may have data None, but should be successful and cached
    assert content1['success'] is True
    assert content1.get('cached') is True
    assert content1.get('data') in (1, None)
# Allow running this module directly: `python tests/test_tools.py`.
if __name__ == '__main__':
    pytest.main(['-v', __file__])