Compare commits
10 Commits
6aea98554f
...
9b7468ea4e
| Author | SHA1 | Date |
|---|---|---|
|
|
9b7468ea4e | |
|
|
7da142fccb | |
|
|
dd47f9db3d | |
|
|
2a6c82b3ba | |
|
|
ae73559fd2 | |
|
|
cc639a979a | |
|
|
3970c0b9a0 | |
|
|
24e8497230 | |
|
|
57e998f896 | |
|
|
836ee8ac9d |
|
|
@ -0,0 +1,99 @@
|
|||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.12']
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -e .[test]
|
||||
# Install frontend dependencies for potential frontend tests (optional)
|
||||
cd frontend && npm ci && cd ..
|
||||
|
||||
# 统一配置生成,用于后端测试(与前端构建共享相同配置)
|
||||
- name: Create config.yml for CI
|
||||
run: |
|
||||
cat > config.yml << 'EOF'
|
||||
backend_port: 3000
|
||||
frontend_port: 4000
|
||||
max_iterations: 5
|
||||
sub_agent:
|
||||
max_iterations: 3
|
||||
max_tokens: 4096
|
||||
max_agents: 5
|
||||
max_concurrency: 3
|
||||
models:
|
||||
- id: dummy
|
||||
name: Dummy
|
||||
api_url: https://api.example.com
|
||||
api_key: dummy-key
|
||||
default_model: dummy
|
||||
db_type: sqlite
|
||||
db_sqlite_file: ":memory:"
|
||||
workspace_root: ./workspaces
|
||||
EOF
|
||||
|
||||
- name: Run tests with pytest
|
||||
run: |
|
||||
python -m pytest tests/ -v
|
||||
|
||||
- name: Upload coverage to Codecov (optional)
  uses: codecov/codecov-action@v3
  # The job matrix only contains '3.12'. The previous comparison against
  # '3.10' could never be true, so this step was silently skipped on every run.
  # NOTE(review): the pytest step above is invoked without coverage flags, so
  # ./coverage.xml may not exist -- confirm a coverage report is produced.
  if: matrix.python-version == '3.12' && github.event_name == 'push'
  with:
    file: ./coverage.xml
    fail_ci_if_error: false
|
||||
build-frontend:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
- name: Install frontend dependencies
|
||||
run: |
|
||||
cd frontend && npm ci
|
||||
# 统一配置生成,用于前端构建(与后端测试共享相同配置)
|
||||
- name: Create config.yml for frontend build
|
||||
run: |
|
||||
cat > config.yml << 'EOF'
|
||||
backend_port: 3000
|
||||
frontend_port: 4000
|
||||
max_iterations: 5
|
||||
sub_agent:
|
||||
max_iterations: 3
|
||||
max_tokens: 4096
|
||||
max_agents: 5
|
||||
max_concurrency: 3
|
||||
models:
|
||||
- id: dummy
|
||||
name: Dummy
|
||||
api_url: https://api.example.com
|
||||
api_key: dummy-key
|
||||
default_model: dummy
|
||||
db_type: sqlite
|
||||
db_sqlite_file: ":memory:"
|
||||
workspace_root: ./workspaces
|
||||
EOF
|
||||
- name: Build frontend
|
||||
run: |
|
||||
cd frontend && npm run build
|
||||
|
|
@ -28,3 +28,6 @@
|
|||
!frontend/src/**/*.css
|
||||
!frontend/public/
|
||||
!frontend/public/**
|
||||
|
||||
# CI / CD
|
||||
!.github/workflows/*
|
||||
|
|
|
|||
|
|
@ -36,6 +36,11 @@ frontend_port: 4000
|
|||
# Max agentic loop iterations (tool call rounds)
|
||||
max_iterations: 15
|
||||
|
||||
# Sub-agent settings (multi_agent tool)
|
||||
sub_agent:
|
||||
max_iterations: 3 # Max tool-call rounds per sub-agent
|
||||
max_concurrency: 3 # ThreadPoolExecutor max workers
|
||||
|
||||
# Available models
|
||||
# Each model must have its own id, name, api_url, api_key
|
||||
models:
|
||||
|
|
@ -117,6 +122,7 @@ backend/
|
|||
│ ├── data.py # 计算器、文本、JSON 处理
|
||||
│ ├── weather.py # 天气查询(模拟)
|
||||
│ ├── file_ops.py # 文件操作(6 个工具,project_id 自动注入)
|
||||
│ ├── agent.py # 多智能体(子 Agent 并发执行,工具权限隔离)
|
||||
│ └── code.py # Python 代码执行(沙箱)
|
||||
└── utils/ # 辅助函数
|
||||
├── helpers.py # 通用函数(ok/err/build_messages 等)
|
||||
|
|
@ -207,6 +213,7 @@ frontend/
|
|||
| **代码执行** | execute_python | 沙箱环境执行 Python |
|
||||
| **文件操作** | file_read, file_write, file_delete, file_list, file_exists, file_mkdir | project_id 自动注入 |
|
||||
| **天气** | get_weather | 天气查询(模拟) |
|
||||
| **智能体** | multi_agent | 派生子 Agent 并发执行(禁止递归,工具权限与主 Agent 一致) |
|
||||
|
||||
## 文档
|
||||
|
||||
|
|
|
|||
|
|
@ -1,41 +1,117 @@
|
|||
"""Configuration management"""
|
||||
"""Configuration management using dataclasses"""
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from backend import load_config
|
||||
|
||||
_cfg = load_config()
|
||||
|
||||
# Model list (for /api/models endpoint)
|
||||
MODELS = _cfg.get("models", [])
|
||||
@dataclass
class ModelConfig:
    """Individual model configuration.

    All four fields are required. The API key is kept out of the generated
    ``repr`` so that logging or printing a config object does not expose it.
    """

    # Identifier used to look the model up by id.
    id: str
    # Display name of the model.
    name: str
    # Endpoint URL used for requests to this model.
    api_url: str
    # Secret credential: excluded from repr() to avoid leaking it into logs
    # and tracebacks. It is still a normal required constructor argument.
    api_key: str = field(repr=False)
|
||||
|
||||
# Validate each model has required fields at startup
|
||||
_REQUIRED_MODEL_KEYS = {"id", "name", "api_url", "api_key"}
|
||||
_model_ids_seen = set()
|
||||
for _i, _m in enumerate(MODELS):
|
||||
_missing = _REQUIRED_MODEL_KEYS - set(_m.keys())
|
||||
if _missing:
|
||||
print(f"[config] ERROR: models[{_i}] missing required fields: {_missing}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
if _m["id"] in _model_ids_seen:
|
||||
print(f"[config] ERROR: duplicate model id '{_m['id']}'", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
_model_ids_seen.add(_m["id"])
|
||||
|
||||
# Per-model config lookup: {model_id: {api_url, api_key}}
|
||||
MODEL_CONFIG = {m["id"]: {"api_url": m["api_url"], "api_key": m["api_key"]} for m in MODELS}
|
||||
@dataclass
class SubAgentConfig:
    """Sub-agent (multi_agent tool) settings.

    Every field has a default, so the class can be built with no arguments.
    """
    # Max tool-call rounds per sub-agent.
    max_iterations: int = 3
    # Max number of sub-agents run concurrently (the sample config describes
    # this as the ThreadPoolExecutor max workers).
    max_concurrency: int = 3
    # NOTE(review): the unit is not stated here -- presumably seconds; confirm
    # against the code that consumes this value.
    timeout: int = 60
|
||||
|
||||
# default_model must exist in models
|
||||
DEFAULT_MODEL = _cfg.get("default_model", "")
|
||||
if DEFAULT_MODEL and DEFAULT_MODEL not in MODEL_CONFIG:
|
||||
print(f"[config] ERROR: default_model '{DEFAULT_MODEL}' not found in models", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
if MODELS and not DEFAULT_MODEL:
|
||||
DEFAULT_MODEL = MODELS[0]["id"]
|
||||
|
||||
# Max agentic loop iterations (tool call rounds)
|
||||
MAX_ITERATIONS = _cfg.get("max_iterations", 5)
|
||||
@dataclass
class CodeExecutionConfig:
    """Code execution settings.

    Every field has a default, so the class can be built with no arguments.
    """
    # NOTE(review): the set of accepted strictness names is not visible here;
    # "standard" is the only value this definition shows.
    default_strictness: str = "standard"
    # Additional modules to allow; the expected key/value schema is not visible
    # here -- confirm against the code that consumes it.
    extra_allowed_modules: Dict = field(default_factory=dict)
    backend: str = "subprocess"  # subprocess or docker
    # The docker_* options below presumably apply only when backend == "docker"
    # -- TODO confirm.
    docker_image: str = "python:3.12-slim"
    docker_network: str = "none"
    docker_user: str = "nobody"
    # None means no value was configured for the limit.
    docker_memory_limit: Optional[str] = None
    docker_cpu_shares: Optional[int] = None
|
||||
|
||||
# Max parallel workers for tool execution (ThreadPoolExecutor)
|
||||
TOOL_MAX_WORKERS = _cfg.get("tool_max_workers", 4)
|
||||
|
||||
# Max character length for a single tool result content (truncated if exceeded)
|
||||
TOOL_RESULT_MAX_LENGTH = _cfg.get("tool_result_max_length", 4096)
|
||||
@dataclass
class AppConfig:
    """Main application configuration.

    Holds the configured models plus the global, sub-agent and code-execution
    settings, and offers id-based lookup helpers for the models.
    """
    models: List[ModelConfig] = field(default_factory=list)
    default_model: str = ""
    max_iterations: int = 5
    tool_max_workers: int = 4
    sub_agent: SubAgentConfig = field(default_factory=SubAgentConfig)
    code_execution: CodeExecutionConfig = field(default_factory=CodeExecutionConfig)

    # Per-model config lookup: {model_id: ModelConfig}. Rebuilt from ``models``
    # in __post_init__, so any value passed to the constructor is replaced.
    _model_config_map: Dict[str, ModelConfig] = field(default_factory=dict, repr=False)

    def __post_init__(self):
        """Index ``models`` by id once the dataclass fields are set."""
        by_id = {}
        for model in self.models:
            # A later entry with the same id replaces an earlier one.
            by_id[model.id] = model
        self._model_config_map = by_id

    def get_model_config(self, model_id: str) -> Optional[ModelConfig]:
        """Return the ModelConfig registered under ``model_id``, or None."""
        try:
            return self._model_config_map[model_id]
        except KeyError:
            return None

    def get_model_credentials(self, model_id: str) -> tuple:
        """Return ``(api_url, api_key)`` for ``model_id``.

        Raises:
            ValueError: if the model is unknown, or its api_url or api_key
                is empty.
        """
        entry = self.get_model_config(model_id)
        if not entry:
            raise ValueError(f"Unknown model: '{model_id}', not found in config")
        # Check api_url before api_key so the first missing one is reported.
        for attr in ("api_url", "api_key"):
            if not getattr(entry, attr):
                raise ValueError(f"Model '{model_id}' has no {attr} configured")
        return entry.api_url, entry.api_key
|
||||
|
||||
|
||||
def _parse_config(raw: dict) -> AppConfig:
    """Parse raw YAML config into an AppConfig dataclass.

    Args:
        raw: Mapping loaded from the YAML config file. ``None`` (an empty
            file) and sections whose value is null are treated as empty.

    Returns:
        AppConfig populated from ``raw``. Any option missing from ``raw``
        keeps the default declared on the corresponding dataclass. When no
        ``default_model`` is configured, the first configured model is used.

    Raises:
        KeyError: if a ``models`` entry lacks ``id``, ``name``, ``api_url``
            or ``api_key``.
    """
    # An empty YAML document loads as None; a key with no value loads as None
    # too, so every section is normalised with ``or`` before use.
    raw = raw or {}

    models = [
        ModelConfig(
            id=m["id"],
            name=m["name"],
            api_url=m["api_url"],
            api_key=m["api_key"],
        )
        for m in raw.get("models") or []
    ]

    def _present(section: dict, keys: tuple) -> dict:
        """Return only the known keys that are actually set in ``section``."""
        return {key: section[key] for key in keys if key in section}

    # Pass through only the options that are present, so the dataclass
    # defaults remain the single source of truth for missing values.
    sub_agent = SubAgentConfig(**_present(
        raw.get("sub_agent") or {},
        ("max_iterations", "max_concurrency", "timeout"),
    ))

    code_execution = CodeExecutionConfig(**_present(
        raw.get("code_execution") or {},
        (
            "default_strictness",
            "extra_allowed_modules",
            "backend",
            "docker_image",
            "docker_network",
            "docker_user",
            "docker_memory_limit",
            "docker_cpu_shares",
        ),
    ))

    # Fall back to the first configured model when none is named, so callers
    # that read ``default_model`` do not get an empty model id.
    default_model = raw.get("default_model") or ""
    if not default_model and models:
        default_model = models[0].id
    # NOTE(review): duplicate model ids and a default_model that is not among
    # the configured models are not rejected here -- confirm whether startup
    # validation should fail fast on those.

    return AppConfig(
        models=models,
        default_model=default_model,
        sub_agent=sub_agent,
        code_execution=code_execution,
        **_present(raw, ("max_iterations", "tool_max_workers")),
    )
|
||||
|
||||
|
||||
# Load and validate configuration at startup
|
||||
_raw_cfg = load_config()
|
||||
config = _parse_config(_raw_cfg)
|
||||
|
|
|
|||
|
|
@ -8,16 +8,16 @@ from backend.routes.stats import bp as stats_bp
|
|||
from backend.routes.projects import bp as projects_bp
|
||||
from backend.routes.auth import bp as auth_bp, init_auth
|
||||
from backend.services.llm_client import LLMClient
|
||||
from backend.config import MODEL_CONFIG
|
||||
from backend.config import config
|
||||
|
||||
|
||||
def register_routes(app: Flask):
|
||||
"""Register all route blueprints"""
|
||||
# Initialize LLM client with per-model config
|
||||
client = LLMClient(MODEL_CONFIG)
|
||||
# Initialize LLM client with config
|
||||
client = LLMClient(config)
|
||||
init_chat_service(client)
|
||||
|
||||
# Register LLM client in service locator so tools (e.g. agent_task) can access it
|
||||
# Register LLM client in service locator so tools (e.g. multi_agent) can access it
|
||||
from backend.tools import register_service
|
||||
register_service("llm_client", client)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from flask import Blueprint, request, g
|
|||
from backend import db
|
||||
from backend.models import Conversation, Project
|
||||
from backend.utils.helpers import ok, err, to_dict
|
||||
from backend.config import DEFAULT_MODEL
|
||||
from backend.config import config
|
||||
|
||||
bp = Blueprint("conversations", __name__)
|
||||
|
||||
|
|
@ -40,7 +40,7 @@ def conversation_list():
|
|||
user_id=user.id,
|
||||
project_id=project_id or None,
|
||||
title=d.get("title", ""),
|
||||
model=d.get("model", DEFAULT_MODEL),
|
||||
model=d.get("model", config.default_model),
|
||||
system_prompt=d.get("system_prompt", ""),
|
||||
temperature=d.get("temperature", 1.0),
|
||||
max_tokens=d.get("max_tokens", 65536),
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""Model list API routes"""
|
||||
from flask import Blueprint
|
||||
from backend.utils.helpers import ok
|
||||
from backend.config import MODELS
|
||||
from backend.config import config
|
||||
|
||||
bp = Blueprint("models", __name__)
|
||||
|
||||
|
|
@ -13,7 +13,10 @@ _SENSITIVE_KEYS = {"api_key", "api_url"}
|
|||
def list_models():
|
||||
"""Get available model list (without sensitive fields like api_key)"""
|
||||
safe_models = [
|
||||
{k: v for k, v in m.items() if k not in _SENSITIVE_KEYS}
|
||||
for m in MODELS
|
||||
{
|
||||
"id": m.id,
|
||||
"name": m.name,
|
||||
}
|
||||
for m in config.models
|
||||
]
|
||||
return ok(safe_models)
|
||||
|
|
@ -1,29 +1,38 @@
|
|||
"""Token statistics API routes"""
|
||||
from datetime import date, timedelta
|
||||
from datetime import date, timedelta, datetime, timezone
|
||||
from flask import Blueprint, request, g
|
||||
from sqlalchemy import func
|
||||
from backend.models import TokenUsage
|
||||
from sqlalchemy import func, extract
|
||||
from backend.models import TokenUsage, Message, Conversation
|
||||
from backend.utils.helpers import ok, err
|
||||
from backend import db
|
||||
|
||||
bp = Blueprint("stats", __name__)
|
||||
|
||||
|
||||
def _utc_today():
|
||||
"""Get today's date in UTC to match stored timestamps."""
|
||||
return datetime.now(timezone.utc).date()
|
||||
|
||||
|
||||
@bp.route("/api/stats/tokens", methods=["GET"])
|
||||
def token_stats():
|
||||
"""Get token usage statistics"""
|
||||
user = g.current_user
|
||||
period = request.args.get("period", "daily")
|
||||
|
||||
today = date.today()
|
||||
today = _utc_today()
|
||||
|
||||
if period == "daily":
|
||||
stats = TokenUsage.query.filter_by(user_id=user.id, date=today).all()
|
||||
# Hourly breakdown from Message table
|
||||
hourly = _build_hourly_stats(user.id, today)
|
||||
result = {
|
||||
"period": "daily",
|
||||
"date": today.isoformat(),
|
||||
"prompt_tokens": sum(s.prompt_tokens for s in stats),
|
||||
"completion_tokens": sum(s.completion_tokens for s in stats),
|
||||
"total_tokens": sum(s.total_tokens for s in stats),
|
||||
"hourly": hourly,
|
||||
"by_model": {
|
||||
s.model: {
|
||||
"prompt": s.prompt_tokens,
|
||||
|
|
@ -92,3 +101,37 @@ def _build_period_result(stats, period, start_date, end_date, days):
|
|||
"daily": daily_data,
|
||||
"by_model": by_model,
|
||||
}
|
||||
|
||||
|
||||
def _build_hourly_stats(user_id, day):
    """Build an hourly token breakdown for one day from the Message table.

    Sums ``Message.token_count`` per hour of ``Message.created_at`` over the
    assistant messages in the user's conversations created on ``day``.

    Args:
        user_id: Id of the user whose conversations are included.
        day: Calendar date to aggregate; its bounds are taken in UTC.

    Returns:
        dict mapping the hour as a string (e.g. "13") to ``{"total": <sum>}``.
        Hours with no matching messages are absent.
    """
    # Inclusive bounds: 00:00:00.000000 through 23:59:59.999999, tagged as UTC.
    # NOTE(review): assumes Message.created_at holds UTC values -- confirm.
    window_start = datetime(day.year, day.month, day.day, tzinfo=timezone.utc)
    window_end = window_start.replace(
        hour=23, minute=59, second=59, microsecond=999999
    )

    # Ids of all conversations owned by the user.
    owned_conversations = (
        db.session.query(Conversation.id)
        .filter(Conversation.user_id == user_id)
        .subquery()
    )

    hour_of_day = extract("hour", Message.created_at)
    query = db.session.query(
        hour_of_day.label("hour"),
        func.sum(Message.token_count).label("total"),
    )
    query = query.filter(
        Message.conversation_id.in_(owned_conversations),
        Message.role == "assistant",
        Message.created_at >= window_start,
        Message.created_at <= window_end,
    )
    buckets = query.group_by(hour_of_day).all()

    # The extracted hour is normalised through int() before use as a key; a
    # NULL sum is reported as 0.
    return {
        str(int(bucket.hour)): {"total": bucket.total or 0}
        for bucket in buckets
    }
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from backend.utils.helpers import (
|
|||
build_messages,
|
||||
)
|
||||
from backend.services.llm_client import LLMClient
|
||||
from backend.config import MAX_ITERATIONS, TOOL_MAX_WORKERS, TOOL_RESULT_MAX_LENGTH
|
||||
from backend.config import config as _cfg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -61,13 +61,20 @@ class ChatService:
|
|||
"""
|
||||
conv_id = conv.id
|
||||
conv_model = conv.model
|
||||
conv_max_tokens = conv.max_tokens
|
||||
conv_temperature = conv.temperature
|
||||
conv_thinking_enabled = conv.thinking_enabled
|
||||
app = current_app._get_current_object()
|
||||
tools = registry.list_all() if tools_enabled else None
|
||||
initial_messages = build_messages(conv, project_id)
|
||||
|
||||
executor = ToolExecutor(registry=registry)
|
||||
|
||||
context = {"model": conv_model}
|
||||
context = {
|
||||
"model": conv_model,
|
||||
"max_tokens": conv_max_tokens,
|
||||
"temperature": conv_temperature,
|
||||
}
|
||||
if project_id:
|
||||
context["project_id"] = project_id
|
||||
elif conv.project_id:
|
||||
|
|
@ -82,10 +89,26 @@ class ChatService:
|
|||
total_completion_tokens = 0
|
||||
total_prompt_tokens = 0
|
||||
|
||||
for iteration in range(MAX_ITERATIONS):
|
||||
|
||||
for iteration in range(_cfg.max_iterations):
|
||||
# Helper to parse stream_result event
|
||||
def parse_stream_result(event_str):
|
||||
"""Parse stream_result SSE event and extract data dict."""
|
||||
# Format: "event: stream_result\ndata: {...}\n\n"
|
||||
try:
|
||||
stream_result = self._stream_llm_response(
|
||||
app, conv_id, messages, tools, tool_choice, step_index
|
||||
for line in event_str.split('\n'):
|
||||
if line.startswith('data: '):
|
||||
return json.loads(line[6:])
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
# Collect SSE events and extract final stream_result
|
||||
try:
|
||||
stream_gen = self._stream_llm_response(
|
||||
app, messages, tools, tool_choice, step_index,
|
||||
conv_model, conv_max_tokens, conv_temperature,
|
||||
conv_thinking_enabled,
|
||||
)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
resp = e.response
|
||||
|
|
@ -107,22 +130,37 @@ class ChatService:
|
|||
yield _sse_event("error", {"content": f"Internal error: {e}"})
|
||||
return
|
||||
|
||||
if stream_result is None:
|
||||
return # Client disconnected
|
||||
result_data = None
|
||||
try:
|
||||
for event_str in stream_gen:
|
||||
# Check if this is a stream_result event (final event)
|
||||
if event_str.startswith("event: stream_result"):
|
||||
result_data = parse_stream_result(event_str)
|
||||
else:
|
||||
# Forward process_step events to client in real-time
|
||||
yield event_str
|
||||
except Exception as e:
|
||||
logger.exception("Error during streaming")
|
||||
yield _sse_event("error", {"content": f"Stream error: {e}"})
|
||||
return
|
||||
|
||||
full_content, full_thinking, tool_calls_list, \
|
||||
thinking_step_id, thinking_step_idx, \
|
||||
text_step_id, text_step_idx, \
|
||||
completion_tokens, prompt_tokens, \
|
||||
sse_chunks = stream_result
|
||||
if result_data is None:
|
||||
return # Client disconnected or error
|
||||
|
||||
# Extract data from stream_result
|
||||
full_content = result_data["full_content"]
|
||||
full_thinking = result_data["full_thinking"]
|
||||
tool_calls_list = result_data["tool_calls_list"]
|
||||
thinking_step_id = result_data["thinking_step_id"]
|
||||
thinking_step_idx = result_data["thinking_step_idx"]
|
||||
text_step_id = result_data["text_step_id"]
|
||||
text_step_idx = result_data["text_step_idx"]
|
||||
completion_tokens = result_data["completion_tokens"]
|
||||
prompt_tokens = result_data["prompt_tokens"]
|
||||
|
||||
total_prompt_tokens += prompt_tokens
|
||||
total_completion_tokens += completion_tokens
|
||||
|
||||
# Yield accumulated SSE chunks to frontend
|
||||
for chunk in sse_chunks:
|
||||
yield chunk
|
||||
|
||||
# Save thinking/text steps to all_steps for DB storage
|
||||
if thinking_step_id is not None:
|
||||
all_steps.append({
|
||||
|
|
@ -185,7 +223,7 @@ class ChatService:
|
|||
# Append assistant message + tool results for the next iteration
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": full_content or None,
|
||||
"content": full_content or "",
|
||||
"tool_calls": tool_calls_list,
|
||||
})
|
||||
messages.extend(tool_results)
|
||||
|
|
@ -232,12 +270,19 @@ class ChatService:
|
|||
# ------------------------------------------------------------------
|
||||
|
||||
def _stream_llm_response(
|
||||
self, app, conv_id, messages, tools, tool_choice, step_index
|
||||
self, app, messages, tools, tool_choice, step_index,
|
||||
model, max_tokens, temperature, thinking_enabled,
|
||||
):
|
||||
"""Call LLM streaming API and parse the response.
|
||||
"""Call LLM streaming API and yield SSE events in real-time.
|
||||
|
||||
Returns a tuple of parsed results, or None if the client disconnected.
|
||||
Raises HTTPError / ConnectionError / Timeout for the caller to handle.
|
||||
This is a generator that yields SSE event strings as they are received.
|
||||
The final yield is a 'stream_result' event containing the accumulated data.
|
||||
|
||||
Yields:
|
||||
str: SSE event strings (process_step events, then stream_result)
|
||||
|
||||
Raises:
|
||||
HTTPError / ConnectionError / Timeout for the caller to handle.
|
||||
"""
|
||||
full_content = ""
|
||||
full_thinking = ""
|
||||
|
|
@ -250,16 +295,13 @@ class ChatService:
|
|||
text_step_id = None
|
||||
text_step_idx = None
|
||||
|
||||
sse_chunks = [] # Collect SSE events to yield later
|
||||
|
||||
with app.app_context():
|
||||
active_conv = db.session.get(Conversation, conv_id)
|
||||
resp = self.llm.call(
|
||||
model=active_conv.model,
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=active_conv.max_tokens,
|
||||
temperature=active_conv.temperature,
|
||||
thinking_enabled=active_conv.thinking_enabled,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
thinking_enabled=thinking_enabled,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=True,
|
||||
|
|
@ -269,7 +311,7 @@ class ChatService:
|
|||
for line in resp.iter_lines():
|
||||
if _client_disconnected():
|
||||
resp.close()
|
||||
return None
|
||||
return # Client disconnected, stop generator
|
||||
|
||||
if not line:
|
||||
continue
|
||||
|
|
@ -295,37 +337,44 @@ class ChatService:
|
|||
|
||||
delta = choices[0].get("delta", {})
|
||||
|
||||
# Yield thinking content in real-time
|
||||
reasoning = delta.get("reasoning_content", "")
|
||||
if reasoning:
|
||||
full_thinking += reasoning
|
||||
if thinking_step_id is None:
|
||||
thinking_step_id = f"step-{step_index}"
|
||||
thinking_step_idx = step_index
|
||||
sse_chunks.append(_sse_event("process_step", {
|
||||
yield _sse_event("process_step", {
|
||||
"id": thinking_step_id, "index": thinking_step_idx,
|
||||
"type": "thinking", "content": full_thinking,
|
||||
}))
|
||||
})
|
||||
|
||||
# Yield text content in real-time
|
||||
text = delta.get("content", "")
|
||||
if text:
|
||||
full_content += text
|
||||
if text_step_id is None:
|
||||
text_step_idx = step_index + (1 if thinking_step_id is not None else 0)
|
||||
text_step_id = f"step-{text_step_idx}"
|
||||
sse_chunks.append(_sse_event("process_step", {
|
||||
yield _sse_event("process_step", {
|
||||
"id": text_step_id, "index": text_step_idx,
|
||||
"type": "text", "content": full_content,
|
||||
}))
|
||||
})
|
||||
|
||||
tool_calls_list = self._process_tool_calls_delta(delta, tool_calls_list)
|
||||
|
||||
return (
|
||||
full_content, full_thinking, tool_calls_list,
|
||||
thinking_step_id, thinking_step_idx,
|
||||
text_step_id, text_step_idx,
|
||||
token_count, prompt_tokens,
|
||||
sse_chunks,
|
||||
)
|
||||
# Final yield: stream_result event with accumulated data
|
||||
yield _sse_event("stream_result", {
|
||||
"full_content": full_content,
|
||||
"full_thinking": full_thinking,
|
||||
"tool_calls_list": tool_calls_list,
|
||||
"thinking_step_id": thinking_step_id,
|
||||
"thinking_step_idx": thinking_step_idx,
|
||||
"text_step_id": text_step_id,
|
||||
"text_step_idx": text_step_idx,
|
||||
"completion_tokens": token_count,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
})
|
||||
|
||||
def _execute_tools_safe(self, app, executor, tool_calls_list, context):
|
||||
"""Execute tool calls with top-level error wrapping.
|
||||
|
|
@ -336,17 +385,17 @@ class ChatService:
|
|||
try:
|
||||
if len(tool_calls_list) > 1:
|
||||
with app.app_context():
|
||||
tool_results = executor.process_tool_calls_parallel(
|
||||
tool_calls_list, context, max_workers=TOOL_MAX_WORKERS
|
||||
return executor.process_tool_calls_parallel(
|
||||
tool_calls_list, context, max_workers=_cfg.tool_max_workers
|
||||
)
|
||||
else:
|
||||
with app.app_context():
|
||||
tool_results = executor.process_tool_calls(
|
||||
return executor.process_tool_calls(
|
||||
tool_calls_list, context
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error during tool execution")
|
||||
tool_results = [
|
||||
return [
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc["id"],
|
||||
|
|
@ -359,30 +408,6 @@ class ChatService:
|
|||
for tc in tool_calls_list
|
||||
]
|
||||
|
||||
# Truncate oversized tool result content
|
||||
for tr in tool_results:
|
||||
if len(tr["content"]) > TOOL_RESULT_MAX_LENGTH:
|
||||
try:
|
||||
result_data = json.loads(tr["content"])
|
||||
original = result_data
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
original = None
|
||||
|
||||
tr["content"] = json.dumps(
|
||||
{"success": False, "error": "Tool result too large, truncated"},
|
||||
ensure_ascii=False,
|
||||
) if not original else json.dumps(
|
||||
{
|
||||
**original,
|
||||
"truncated": True,
|
||||
"_note": f"Content truncated, original length {len(tr['content'])} chars",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
default=str,
|
||||
)[:TOOL_RESULT_MAX_LENGTH]
|
||||
|
||||
return tool_results
|
||||
|
||||
def _save_message(
|
||||
self, app, conv_id, conv_model, msg_id,
|
||||
full_content, all_tool_calls, all_tool_results,
|
||||
|
|
|
|||
|
|
@ -35,28 +35,22 @@ def _detect_provider(api_url: str) -> str:
|
|||
class LLMClient:
|
||||
"""OpenAI-compatible LLM API client.
|
||||
|
||||
Each model must have its own api_url and api_key configured in MODEL_CONFIG.
|
||||
Each model must have its own api_url and api_key configured in config.models.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: dict):
|
||||
"""Initialize with per-model config lookup.
|
||||
def __init__(self, cfg):
|
||||
"""Initialize with AppConfig.
|
||||
|
||||
Args:
|
||||
model_config: {model_id: {"api_url": ..., "api_key": ...}}
|
||||
cfg: AppConfig dataclass instance
|
||||
"""
|
||||
self.model_config = model_config
|
||||
self.cfg = cfg
|
||||
|
||||
def _get_credentials(self, model: str):
|
||||
"""Get api_url and api_key for a model, with env-var expansion."""
|
||||
cfg = self.model_config.get(model)
|
||||
if not cfg:
|
||||
raise ValueError(f"Unknown model: '{model}', not found in config")
|
||||
api_url = _resolve_env_vars(cfg.get("api_url", ""))
|
||||
api_key = _resolve_env_vars(cfg.get("api_key", ""))
|
||||
if not api_url:
|
||||
raise ValueError(f"Model '{model}' has no api_url configured")
|
||||
if not api_key:
|
||||
raise ValueError(f"Model '{model}' has no api_key configured")
|
||||
api_url, api_key = self.cfg.get_model_credentials(model)
|
||||
api_url = _resolve_env_vars(api_url)
|
||||
api_key = _resolve_env_vars(api_key)
|
||||
return api_url, api_key
|
||||
|
||||
def _build_body(self, model, messages, max_tokens, temperature, thinking_enabled,
|
||||
|
|
|
|||
|
|
@ -15,13 +15,13 @@ Usage:
|
|||
result = registry.execute("web_search", {"query": "Python"})
|
||||
"""
|
||||
|
||||
from backend.tools.core import ToolDefinition, ToolResult, ToolRegistry, registry
|
||||
from backend.tools.factory import tool, register_tool
|
||||
from backend.tools.core import registry
|
||||
from backend.tools.factory import tool
|
||||
from backend.tools.executor import ToolExecutor
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service locator – allows tools (e.g. agent_task) to access LLM client
|
||||
# Service locator – allows tools (e.g. multi_agent) to access LLM client
|
||||
# ---------------------------------------------------------------------------
|
||||
_services: dict = {}
|
||||
|
||||
|
|
@ -47,16 +47,12 @@ def init_tools() -> None:
|
|||
|
||||
# Public API exports
|
||||
__all__ = [
|
||||
# Core classes
|
||||
"ToolDefinition",
|
||||
"ToolResult",
|
||||
"ToolRegistry",
|
||||
"ToolExecutor",
|
||||
# Instances
|
||||
"registry",
|
||||
# Factory functions
|
||||
"tool",
|
||||
"register_tool",
|
||||
# Classes
|
||||
"ToolExecutor",
|
||||
# Initialization
|
||||
"init_tools",
|
||||
# Service locator
|
||||
|
|
|
|||
|
|
@ -1,130 +1,49 @@
|
|||
"""Multi-agent tools for concurrent and batch task execution.
|
||||
|
||||
Provides:
|
||||
- parallel_execute: Run multiple tool calls concurrently
|
||||
- agent_task: Spawn sub-agents with their own LLM conversation loops
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from typing import List, Any, Optional
|
||||
from backend.tools import get_service
|
||||
from backend.tools.factory import tool
|
||||
from backend.tools.core import registry
|
||||
from backend.tools.executor import ToolExecutor
|
||||
from backend.config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Sub-agents are forbidden from using multi_agent to prevent infinite recursion
|
||||
BLOCKED_TOOLS = {"multi_agent"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parallel_execute – run multiple tool calls concurrently
|
||||
# ---------------------------------------------------------------------------
|
||||
def _to_executor_calls(tool_calls: list, id_prefix: str = "tc") -> list:
|
||||
"""Normalize tool calls into executor-compatible format.
|
||||
|
||||
@tool(
|
||||
name="parallel_execute",
|
||||
description=(
|
||||
"Execute multiple tool calls concurrently for better performance. "
|
||||
"Use when you have several independent operations that don't depend on each other "
|
||||
"(e.g. reading multiple files, running multiple searches, fetching several pages). "
|
||||
"Results are returned in the same order as the input."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tool_calls": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Tool name to execute",
|
||||
},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"description": "Arguments for the tool",
|
||||
},
|
||||
},
|
||||
"required": ["name", "arguments"],
|
||||
},
|
||||
"description": "List of tool calls to execute in parallel (max 10)",
|
||||
},
|
||||
"concurrency": {
|
||||
"type": "integer",
|
||||
"description": "Max concurrent executions (1-5, default 3)",
|
||||
"default": 3,
|
||||
},
|
||||
},
|
||||
"required": ["tool_calls"],
|
||||
},
|
||||
category="agent",
|
||||
)
|
||||
def parallel_execute(arguments: dict) -> dict:
|
||||
"""Execute multiple tool calls concurrently.
|
||||
|
||||
Args:
|
||||
arguments: {
|
||||
"tool_calls": [
|
||||
{"name": "file_read", "arguments": {"path": "a.py"}},
|
||||
{"name": "web_search", "arguments": {"query": "python"}}
|
||||
],
|
||||
"concurrency": 3,
|
||||
"_project_id": "..." // injected by executor
|
||||
}
|
||||
|
||||
Returns:
|
||||
{"results": [{index, tool_name, success, data/error}]}
|
||||
Accepts two input shapes:
|
||||
- LLM format: {"function": {"name": ..., "arguments": ...}}
|
||||
- Simple format: {"name": ..., "arguments": ...}
|
||||
"""
|
||||
tool_calls = arguments["tool_calls"]
|
||||
concurrency = min(max(arguments.get("concurrency", 3), 1), 5)
|
||||
|
||||
if len(tool_calls) > 10:
|
||||
return {"success": False, "error": "Maximum 10 tool calls allowed per parallel execution"}
|
||||
|
||||
# Build executor context from injected fields
|
||||
context = {}
|
||||
project_id = arguments.get("_project_id")
|
||||
if project_id:
|
||||
context["project_id"] = project_id
|
||||
|
||||
# Format tool_calls into executor-compatible format
|
||||
executor_calls = []
|
||||
for i, tc in enumerate(tool_calls):
|
||||
if "function" in tc:
|
||||
func = tc["function"]
|
||||
executor_calls.append({
|
||||
"id": f"pe-{i}",
|
||||
"id": tc.get("id", f"{id_prefix}-{i}"),
|
||||
"type": tc.get("type", "function"),
|
||||
"function": {
|
||||
"name": func["name"],
|
||||
"arguments": func["arguments"],
|
||||
},
|
||||
})
|
||||
else:
|
||||
executor_calls.append({
|
||||
"id": f"{id_prefix}-{i}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc["name"],
|
||||
"arguments": json.dumps(tc["arguments"], ensure_ascii=False),
|
||||
},
|
||||
})
|
||||
return executor_calls
|
||||
|
||||
# Use ToolExecutor for proper context injection, caching and dedup
|
||||
executor = ToolExecutor(registry=registry, enable_cache=False)
|
||||
executor_results = executor.process_tool_calls_parallel(
|
||||
executor_calls, context, max_workers=concurrency
|
||||
)
|
||||
|
||||
# Format output
|
||||
results = []
|
||||
for er in executor_results:
|
||||
try:
|
||||
content = json.loads(er["content"]) if isinstance(er["content"], str) else er["content"]
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
content = {"success": False, "error": "Failed to parse result"}
|
||||
results.append({
|
||||
"index": len(results),
|
||||
"tool_name": er["name"],
|
||||
**content,
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"results": results,
|
||||
"total": len(results),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# agent_task – spawn sub-agents with independent LLM conversation loops
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _run_sub_agent(
|
||||
task_name: str,
|
||||
|
|
@ -132,6 +51,7 @@ def _run_sub_agent(
|
|||
tool_names: Optional[List[str]],
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
project_id: Optional[str],
|
||||
app: Any,
|
||||
max_iterations: int = 3,
|
||||
|
|
@ -141,7 +61,6 @@ def _run_sub_agent(
|
|||
Each sub-agent gets its own ToolExecutor instance and runs a simplified
|
||||
version of the main agent loop, limited to prevent runaway cost.
|
||||
"""
|
||||
from backend.tools import get_service
|
||||
|
||||
llm_client = get_service("llm_client")
|
||||
if not llm_client:
|
||||
|
|
@ -151,16 +70,21 @@ def _run_sub_agent(
|
|||
"error": "LLM client not available",
|
||||
}
|
||||
|
||||
# Build tool list – filter to requested tools or use all
|
||||
# Build tool list – filter to requested tools, then remove blocked
|
||||
all_tools = registry.list_all()
|
||||
if tool_names:
|
||||
allowed = set(tool_names)
|
||||
tools = [t for t in all_tools if t["function"]["name"] in allowed]
|
||||
else:
|
||||
tools = all_tools
|
||||
tools = list(all_tools)
|
||||
|
||||
# Remove blocked tools to prevent recursion
|
||||
tools = [t for t in tools if t["function"]["name"] not in BLOCKED_TOOLS]
|
||||
|
||||
executor = ToolExecutor(registry=registry)
|
||||
context = {"project_id": project_id} if project_id else None
|
||||
context = {"model": model}
|
||||
if project_id:
|
||||
context["project_id"] = project_id
|
||||
|
||||
# System prompt: instruction + reminder to give a final text answer
|
||||
system_msg = (
|
||||
|
|
@ -170,17 +94,21 @@ def _run_sub_agent(
|
|||
)
|
||||
messages = [{"role": "system", "content": system_msg}]
|
||||
|
||||
for _ in range(max_iterations):
|
||||
for i in range(max_iterations):
|
||||
is_final = (i == max_iterations - 1)
|
||||
try:
|
||||
with app.app_context():
|
||||
resp = llm_client.call(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools if tools else None,
|
||||
# On the last iteration, don't pass tools so the LLM is
|
||||
# forced to produce a final text response instead of calling
|
||||
# more tools.
|
||||
tools=None if is_final else (tools if tools else None),
|
||||
stream=False,
|
||||
max_tokens=min(max_tokens, 4096),
|
||||
temperature=0.7,
|
||||
timeout=60,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
timeout=config.sub_agent.timeout,
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
|
|
@ -196,20 +124,26 @@ def _run_sub_agent(
|
|||
message = choice["message"]
|
||||
|
||||
if message.get("tool_calls"):
|
||||
messages.append(message)
|
||||
tc_list = message["tool_calls"]
|
||||
# Convert OpenAI tool_calls to executor format
|
||||
executor_calls = []
|
||||
for tc in tc_list:
|
||||
executor_calls.append({
|
||||
"id": tc.get("id", ""),
|
||||
"type": tc.get("type", "function"),
|
||||
"function": {
|
||||
"name": tc["function"]["name"],
|
||||
"arguments": tc["function"]["arguments"],
|
||||
},
|
||||
# Only extract needed fields — LLM response may contain extra
|
||||
# fields (e.g. reasoning_content) that the API rejects on re-send
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": message.get("content") or "",
|
||||
"tool_calls": message["tool_calls"],
|
||||
})
|
||||
tool_results = executor.process_tool_calls(executor_calls, context)
|
||||
tc_list = message["tool_calls"]
|
||||
executor_calls = _to_executor_calls(tc_list)
|
||||
# Execute tools inside app_context – file ops and other DB-
|
||||
# dependent tools require an active Flask context and session.
|
||||
with app.app_context():
|
||||
if len(executor_calls) > 1:
|
||||
tool_results = executor.process_tool_calls_parallel(
|
||||
executor_calls, context
|
||||
)
|
||||
else:
|
||||
tool_results = executor.process_tool_calls(
|
||||
executor_calls, context
|
||||
)
|
||||
messages.extend(tool_results)
|
||||
else:
|
||||
# Final text response
|
||||
|
|
@ -226,7 +160,7 @@ def _run_sub_agent(
|
|||
"error": str(e),
|
||||
}
|
||||
|
||||
# Exhausted iterations without final response — return last LLM output if any
|
||||
# Exhausted iterations without final response
|
||||
return {
|
||||
"task_name": task_name,
|
||||
"success": True,
|
||||
|
|
@ -234,49 +168,49 @@ def _run_sub_agent(
|
|||
}
|
||||
|
||||
|
||||
# @tool(
|
||||
# name="agent_task",
|
||||
# description=(
|
||||
# "Spawn one or more sub-agents to work on tasks concurrently. "
|
||||
# "Each agent runs its own independent conversation with the LLM and can use tools. "
|
||||
# "Useful for parallel research, multi-file analysis, or dividing complex tasks into sub-tasks. "
|
||||
# "Each agent is limited to 3 iterations and 4096 tokens to control cost."
|
||||
# ),
|
||||
# parameters={
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "tasks": {
|
||||
# "type": "array",
|
||||
# "items": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "name": {
|
||||
# "type": "string",
|
||||
# "description": "Short name/identifier for this task",
|
||||
# },
|
||||
# "instruction": {
|
||||
# "type": "string",
|
||||
# "description": "Detailed instruction for the sub-agent",
|
||||
# },
|
||||
# "tools": {
|
||||
# "type": "array",
|
||||
# "items": {"type": "string"},
|
||||
# "description": (
|
||||
# "Tool names this agent can use (empty = all tools). "
|
||||
# "e.g. ['file_read', 'file_list', 'web_search']"
|
||||
# ),
|
||||
# },
|
||||
# },
|
||||
# "required": ["name", "instruction"],
|
||||
# },
|
||||
# "description": "Tasks for parallel sub-agents (max 5)",
|
||||
# },
|
||||
# },
|
||||
# "required": ["tasks"],
|
||||
# },
|
||||
# category="agent",
|
||||
# )
|
||||
def agent_task(arguments: dict) -> dict:
|
||||
@tool(
|
||||
name="multi_agent",
|
||||
description=(
|
||||
"Spawn multiple sub-agents to work on tasks concurrently. "
|
||||
"Each agent runs its own independent conversation with the LLM and can use tools. "
|
||||
"Useful for parallel research, multi-file analysis, or dividing complex tasks into sub-tasks. "
|
||||
"Resource limits (iterations, tokens, concurrency) are configured in config.yml -> sub_agent."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Short name/identifier for this task",
|
||||
},
|
||||
"instruction": {
|
||||
"type": "string",
|
||||
"description": "Detailed instruction for the sub-agent",
|
||||
},
|
||||
"tools": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"Tool names this agent can use (empty = all tools). "
|
||||
"e.g. ['file_read', 'file_list', 'web_search']"
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["name", "instruction"],
|
||||
},
|
||||
"description": "Tasks for parallel sub-agents (max 5)",
|
||||
},
|
||||
},
|
||||
"required": ["tasks"],
|
||||
},
|
||||
category="agent",
|
||||
)
|
||||
def multi_agent(arguments: dict) -> dict:
|
||||
"""Spawn sub-agents to work on tasks concurrently.
|
||||
|
||||
Args:
|
||||
|
|
@ -296,7 +230,7 @@ def agent_task(arguments: dict) -> dict:
|
|||
}
|
||||
|
||||
Returns:
|
||||
{"success": true, "results": [{task_name, success, response/error}]}
|
||||
{"success": true, "results": [{task_name, success, response/error}], "total": int}
|
||||
"""
|
||||
from flask import current_app
|
||||
|
||||
|
|
@ -309,11 +243,13 @@ def agent_task(arguments: dict) -> dict:
|
|||
app = current_app._get_current_object()
|
||||
|
||||
# Use injected model/project_id from executor context, fall back to defaults
|
||||
model = arguments.get("_model", "glm-5")
|
||||
model = arguments.get("_model") or config.default_model
|
||||
project_id = arguments.get("_project_id")
|
||||
max_tokens = arguments.get("_max_tokens", 65536)
|
||||
temperature = arguments.get("_temperature", 0.7)
|
||||
|
||||
# Execute agents concurrently (max 3 at a time)
|
||||
concurrency = min(len(tasks), 3)
|
||||
# Execute agents concurrently
|
||||
concurrency = min(len(tasks), config.sub_agent.max_concurrency)
|
||||
results = [None] * len(tasks)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=concurrency) as pool:
|
||||
|
|
@ -324,9 +260,11 @@ def agent_task(arguments: dict) -> dict:
|
|||
task["instruction"],
|
||||
task.get("tools"),
|
||||
model,
|
||||
4096,
|
||||
max_tokens,
|
||||
temperature,
|
||||
project_id,
|
||||
app,
|
||||
config.sub_agent.max_iterations,
|
||||
): i
|
||||
for i, task in enumerate(tasks)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,55 +1,131 @@
|
|||
"""Safe code execution tool with sandboxing"""
|
||||
"""Safe code execution tool with sandboxing and strictness levels"""
|
||||
import ast
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from backend.tools.factory import tool
|
||||
from backend.config import config
|
||||
from backend.tools.docker_executor import DockerExecutor
|
||||
|
||||
|
||||
# Blacklist of dangerous modules - all other modules are allowed
|
||||
BLOCKED_MODULES = {
|
||||
# System-level access
|
||||
"os", "sys", "subprocess", "shutil", "signal", "ctypes",
|
||||
"multiprocessing", "threading", "_thread",
|
||||
# Network access
|
||||
"socket", "http", "urllib", "requests", "ftplib", "smtplib",
|
||||
"telnetlib", "xmlrpc", "asyncio",
|
||||
# File system / I/O
|
||||
"pathlib", "io", "glob", "tempfile", "shutil", "fnmatch",
|
||||
# Code execution / introspection
|
||||
"importlib", "pkgutil", "code", "codeop", "compileall",
|
||||
"runpy", "pdb", "profile", "cProfile",
|
||||
# Dangerous stdlib
|
||||
"webbrowser", "antigravity", "turtle",
|
||||
# IPC / persistence
|
||||
"pickle", "shelve", "marshal", "sqlite3", "dbm",
|
||||
# Process / shell
|
||||
"commands", "pipes", "pty", "posix", "posixpath",
|
||||
}
|
||||
# Strictness profiles configuration
|
||||
# - lenient: no restrictions at all
|
||||
# - standard: allowlist based, only safe modules permitted
|
||||
# - strict: minimal allowlist, only pure computation modules
|
||||
STRICTNESS_PROFILES: Dict[str, dict] = {
|
||||
"lenient": {
|
||||
"timeout": 30,
|
||||
"description": "No restrictions, all modules and builtins allowed",
|
||||
"allowlist_modules": None, # None means all allowed
|
||||
"blocked_builtins": set(),
|
||||
},
|
||||
|
||||
|
||||
# Blacklist of dangerous builtins
|
||||
BLOCKED_BUILTINS = {
|
||||
"eval", "exec", "compile", "open", "input",
|
||||
"__import__", "globals", "locals", "vars",
|
||||
"standard": {
|
||||
"timeout": 10,
|
||||
"description": "Allowlist based, only safe modules and builtins permitted",
|
||||
"allowlist_modules": {
|
||||
# Data types & serialization
|
||||
"json", "csv", "re", "typing",
|
||||
# Data structures
|
||||
"collections", "itertools", "functools", "operator", "heapq", "bisect",
|
||||
"array", "copy", "pprint", "enum",
|
||||
# Math & numbers
|
||||
"math", "cmath", "statistics", "random", "fractions", "decimal", "numbers",
|
||||
# Date & time
|
||||
"datetime", "time", "calendar",
|
||||
# Text processing
|
||||
"string", "textwrap", "unicodedata", "difflib",
|
||||
# Data formats
|
||||
"base64", "binascii", "quopri", "uu", "html", "xml.etree.ElementTree",
|
||||
# Functional & concurrency helpers
|
||||
"dataclasses", "hashlib", "hmac",
|
||||
# Common utilities
|
||||
"abc", "contextlib", "warnings", "logging",
|
||||
},
|
||||
"blocked_builtins": {
|
||||
"eval", "exec", "compile", "__import__",
|
||||
"open", "input", "globals", "locals", "vars",
|
||||
"breakpoint", "exit", "quit",
|
||||
"memoryview", "bytearray",
|
||||
"getattr", "setattr", "delattr",
|
||||
},
|
||||
},
|
||||
|
||||
"strict": {
|
||||
"timeout": 5,
|
||||
"description": "Minimal allowlist, only pure computation modules",
|
||||
"allowlist_modules": {
|
||||
# Pure data structures
|
||||
"collections", "itertools", "functools", "operator",
|
||||
"array", "copy", "enum",
|
||||
# Pure math
|
||||
"math", "cmath", "numbers", "fractions", "decimal",
|
||||
"random", "statistics",
|
||||
# Pure text
|
||||
"string", "textwrap", "unicodedata",
|
||||
# Type hints
|
||||
"typing",
|
||||
# Utilities (no I/O)
|
||||
"dataclasses", "abc", "contextlib",
|
||||
},
|
||||
"blocked_builtins": {
|
||||
"eval", "exec", "compile", "__import__",
|
||||
"open", "input", "globals", "locals", "vars",
|
||||
"breakpoint", "exit", "quit",
|
||||
"memoryview", "bytearray",
|
||||
"dir", "hasattr", "getattr", "setattr", "delattr",
|
||||
"type", "isinstance", "issubclass",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def register_extra_modules(strictness: str, modules: Set[str] | List[str]) -> None:
    """Add extra module names to the import allowlist of one strictness level.

    The profile stored in ``STRICTNESS_PROFILES`` is mutated in place, so the
    added names apply to every later use of that profile.

    Args:
        strictness: Key of the profile to extend; must be present in
            ``STRICTNESS_PROFILES``.
        modules: Module names to add to the allowlist. A single bare string is
            treated as one module name.

    Raises:
        ValueError: If ``strictness`` is not a key of ``STRICTNESS_PROFILES``.
    """
    if strictness not in STRICTNESS_PROFILES:
        raise ValueError(f"Invalid strictness level: {strictness}. Must be one of: {', '.join(STRICTNESS_PROFILES.keys())}")

    allowlist = STRICTNESS_PROFILES[strictness].get("allowlist_modules")
    if allowlist is None:
        return  # a None allowlist means everything is allowed; nothing to add

    # set.update() iterates its argument, so a bare "numpy" would otherwise be
    # registered as the single-letter modules "n", "u", "m", "p", "y".
    if isinstance(modules, str):
        modules = [modules]

    allowlist.update(modules)
|
||||
|
||||
|
||||
# Apply extra modules from config.yml on module load
|
||||
for _level, _mods in config.code_execution.extra_allowed_modules.items():
|
||||
if isinstance(_mods, list) and _mods:
|
||||
register_extra_modules(_level, _mods)
|
||||
|
||||
|
||||
@tool(
|
||||
name="execute_python",
|
||||
description="Execute Python code in a sandboxed environment. Most standard library modules are allowed, with dangerous modules (os, subprocess, socket, etc.) blocked. Max execution time: 10 seconds.",
|
||||
description="Execute Python code in a sandboxed environment with configurable strictness levels (lenient/standard/strict). "
|
||||
"Default: 'standard' mode - balances security and flexibility with 10s timeout. "
|
||||
"Use 'lenient' for data processing tasks (30s timeout, more modules allowed). "
|
||||
"Use 'strict' for basic calculations only (5s timeout, minimal module access).",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "Python code to execute. Dangerous modules (os, subprocess, socket, etc.) are blocked."
|
||||
|
||||
"description": "Python code to execute. Available modules depend on strictness level."
|
||||
},
|
||||
"strictness": {
|
||||
"type": "string",
|
||||
"enum": ["lenient", "standard", "strict"],
|
||||
"description": "Optional. Security strictness level (default: standard). "
|
||||
"lenient: 30s timeout, most modules allowed; "
|
||||
"standard: 10s timeout, balanced security; "
|
||||
"strict: 5s timeout, minimal permissions."
|
||||
}
|
||||
},
|
||||
"required": ["code"]
|
||||
|
|
@ -61,36 +137,76 @@ def execute_python(arguments: dict) -> dict:
|
|||
Execute Python code safely with sandboxing.
|
||||
|
||||
Security measures:
|
||||
1. Blocked dangerous imports (blacklist)
|
||||
2. Blocked dangerous builtins
|
||||
3. Timeout limit (10s)
|
||||
4. No file system access
|
||||
5. No network access
|
||||
1. Lenient mode: no restrictions
|
||||
2. Standard/strict mode: allowlist based module restrictions
|
||||
3. Configurable blocked builtins based on strictness level
|
||||
4. Timeout limit (5s/10s/30s based on strictness)
|
||||
5. Subprocess isolation
|
||||
"""
|
||||
code = arguments["code"]
|
||||
strictness = arguments.get("strictness", config.code_execution.default_strictness)
|
||||
|
||||
# Security check: detect dangerous imports
|
||||
dangerous_imports = _check_dangerous_imports(code)
|
||||
if dangerous_imports:
|
||||
# Validate strictness level
|
||||
if strictness not in STRICTNESS_PROFILES:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Blocked imports: {', '.join(dangerous_imports)}. These modules are not allowed for security reasons."
|
||||
"error": f"Invalid strictness level: {strictness}. Must be one of: {', '.join(STRICTNESS_PROFILES.keys())}"
|
||||
}
|
||||
|
||||
# Security check: detect dangerous function calls
|
||||
dangerous_calls = _check_dangerous_calls(code)
|
||||
# Get profile configuration
|
||||
profile = STRICTNESS_PROFILES[strictness]
|
||||
allowlist_modules = profile.get("allowlist_modules")
|
||||
blocked_builtins = profile["blocked_builtins"]
|
||||
timeout = profile["timeout"]
|
||||
|
||||
# Parse and validate code syntax first
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError as e:
|
||||
return {"success": False, "error": f"Syntax error in code: {e}"}
|
||||
|
||||
# Security check: detect disallowed imports
|
||||
disallowed_imports = _check_disallowed_imports(tree, allowlist_modules)
|
||||
if disallowed_imports:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Blocked imports: {', '.join(disallowed_imports)}. These modules are not allowed in '{strictness}' mode."
|
||||
}
|
||||
|
||||
# Security check: detect dangerous function calls (skip if no restrictions)
|
||||
if blocked_builtins:
|
||||
dangerous_calls = _check_dangerous_calls(tree, blocked_builtins)
|
||||
if dangerous_calls:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Blocked functions: {', '.join(dangerous_calls)}"
|
||||
"error": f"Blocked functions: {', '.join(dangerous_calls)}. These functions are not allowed in '{strictness}' mode."
|
||||
}
|
||||
|
||||
# Execute in isolated subprocess
|
||||
# Choose execution backend
|
||||
backend = config.code_execution.backend
|
||||
if backend == "docker":
|
||||
# Use Docker executor
|
||||
executor = DockerExecutor(
|
||||
image=config.code_execution.docker_image,
|
||||
network=config.code_execution.docker_network,
|
||||
user=config.code_execution.docker_user,
|
||||
memory_limit=config.code_execution.docker_memory_limit,
|
||||
cpu_shares=config.code_execution.docker_cpu_shares,
|
||||
)
|
||||
result = executor.execute(
|
||||
code=code,
|
||||
timeout=timeout,
|
||||
strictness=strictness,
|
||||
)
|
||||
# Docker executor already returns the same dict structure
|
||||
return result
|
||||
else:
|
||||
# Default subprocess backend
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-c", _build_safe_code(code)],
|
||||
[sys.executable, "-c", _build_safe_code(code, blocked_builtins, allowlist_modules)],
|
||||
capture_output=True,
|
||||
timeout=10,
|
||||
timeout=timeout,
|
||||
cwd=tempfile.gettempdir(),
|
||||
encoding="utf-8",
|
||||
env={ # Clear environment variables
|
||||
|
|
@ -99,18 +215,25 @@ def execute_python(arguments: dict) -> dict:
|
|||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
return {"success": True, "output": result.stdout}
|
||||
return {
|
||||
"success": True,
|
||||
"output": result.stdout,
|
||||
"strictness": strictness,
|
||||
"timeout": timeout
|
||||
}
|
||||
else:
|
||||
return {"success": False, "error": result.stderr or "Execution failed"}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {"success": False, "error": "Execution timeout (10s limit)"}
|
||||
return {"success": False, "error": f"Execution timeout ({timeout}s limit in '{strictness}' mode)"}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": f"Execution error: {str(e)}"}
|
||||
|
||||
|
||||
def _build_safe_code(code: str) -> str:
|
||||
"""Build sandboxed code with restricted globals"""
|
||||
def _build_safe_code(code: str, blocked_builtins: Set[str],
|
||||
allowlist_modules: Set[str] | None = None) -> str:
|
||||
"""Build sandboxed code with restricted globals and runtime import hook."""
|
||||
allowlist_repr = "None" if allowlist_modules is None else repr(allowlist_modules)
|
||||
template = textwrap.dedent('''
|
||||
import builtins
|
||||
|
||||
|
|
@ -118,6 +241,20 @@ def _build_safe_code(code: str) -> str:
|
|||
_BLOCKED = %r
|
||||
_safe_builtins = {k: getattr(builtins, k) for k in dir(builtins) if k not in _BLOCKED}
|
||||
|
||||
# Runtime import hook for allowlist enforcement
|
||||
_ALLOWLIST = %s
|
||||
if _ALLOWLIST is not None:
|
||||
_original_import = builtins.__import__
|
||||
def _restricted_import(name, *args, **kwargs):
|
||||
top_level = name.split(".")[0]
|
||||
if top_level not in _ALLOWLIST:
|
||||
raise ImportError(
|
||||
f"'{top_level}' is not allowed in the current strictness mode"
|
||||
)
|
||||
return _original_import(name, *args, **kwargs)
|
||||
builtins.__import__ = _restricted_import
|
||||
_safe_builtins["__import__"] = _restricted_import
|
||||
|
||||
# Create safe namespace
|
||||
_safe_globals = {
|
||||
"__builtins__": _safe_builtins,
|
||||
|
|
@ -128,44 +265,43 @@ def _build_safe_code(code: str) -> str:
|
|||
exec(%r, _safe_globals)
|
||||
''').strip()
|
||||
|
||||
return template % (BLOCKED_BUILTINS, code)
|
||||
return template % (blocked_builtins, allowlist_repr, code)
|
||||
|
||||
|
||||
def _check_dangerous_imports(code: str) -> list:
|
||||
"""Check for blocked (blacklisted) imports"""
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError:
|
||||
def _check_disallowed_imports(tree: ast.AST, allowlist_modules: Set[str] | None) -> List[str]:
|
||||
"""Check for imports not in allowlist. None allowlist means everything is allowed."""
|
||||
if allowlist_modules is None:
|
||||
return []
|
||||
|
||||
dangerous = []
|
||||
disallowed = []
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
module = alias.name.split(".")[0]
|
||||
if module in BLOCKED_MODULES:
|
||||
dangerous.append(module)
|
||||
if module not in allowlist_modules:
|
||||
disallowed.append(module)
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.module:
|
||||
module = node.module.split(".")[0]
|
||||
if module in BLOCKED_MODULES:
|
||||
dangerous.append(module)
|
||||
if module not in allowlist_modules:
|
||||
disallowed.append(module)
|
||||
|
||||
return dangerous
|
||||
return list(dict.fromkeys(disallowed)) # deduplicate while preserving order
|
||||
|
||||
|
||||
def _check_dangerous_calls(code: str) -> list:
|
||||
"""Check for blocked function calls"""
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError:
|
||||
return []
|
||||
|
||||
def _check_dangerous_calls(tree: ast.AST, blocked_builtins: Set[str]) -> List[str]:
|
||||
"""Check for blocked function calls including attribute access patterns."""
|
||||
dangerous = []
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Call):
|
||||
if isinstance(node.func, ast.Name):
|
||||
if node.func.id in BLOCKED_BUILTINS:
|
||||
# Direct call: eval("...")
|
||||
if node.func.id in blocked_builtins:
|
||||
dangerous.append(node.func.id)
|
||||
elif isinstance(node.func, ast.Attribute):
|
||||
# Attribute call: builtins.open(...) or os.system(...)
|
||||
attr_name = node.func.attr
|
||||
if attr_name in blocked_builtins:
|
||||
dangerous.append(attr_name)
|
||||
|
||||
return dangerous
|
||||
return list(dict.fromkeys(dangerous))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,156 @@
|
|||
"""Docker-based code execution with isolation."""
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import Optional, Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from backend.config import config
|
||||
|
||||
|
||||
class DockerExecutor:
    """Execute Python code in isolated Docker containers.

    Every call to :meth:`execute` shells out to ``docker run --rm`` using the
    image, network, user, working directory and optional resource limits
    chosen at construction time.
    """

    def __init__(
        self,
        image: str = "python:3.12-slim",
        network: str = "none",
        user: str = "nobody",
        workdir: str = "/workspace",
        memory_limit: Optional[str] = None,
        cpu_shares: Optional[int] = None,
    ):
        """Store the container settings used for every execution.

        Args:
            image: Docker image to run.
            network: Value passed to ``--network``.
            user: Value passed to ``--user``.
            workdir: Value passed to ``--workdir``; the temporary code
                directory is mounted here read-only.
            memory_limit: Value passed to ``--memory``; omitted when falsy.
            cpu_shares: Value passed to ``--cpu-shares``; omitted when falsy.
        """
        self.image = image
        self.network = network
        self.user = user
        self.workdir = workdir
        self.memory_limit = memory_limit
        self.cpu_shares = cpu_shares

    @staticmethod
    def _kill_container(name: str) -> None:
        """Best-effort ``docker kill`` of the container called *name*."""
        try:
            subprocess.run(["docker", "kill", name], capture_output=True, timeout=10)
        except (OSError, subprocess.SubprocessError):
            # Deliberately best-effort: the caller is already reporting a
            # timeout, and the container may have exited on its own.
            pass

    def execute(
        self,
        code: str,
        timeout: int,
        strictness: str,
        extra_env: Optional[Dict[str, str]] = None,
        mount_src: Optional[str] = None,
        mount_dst: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Execute Python code in a Docker container.

        The code is written to ``code.py`` in a temporary directory that is
        mounted read-only at ``workdir``; the interpreter itself receives the
        code through ``python -c``.

        Args:
            code: Python code to execute.
            timeout: Maximum execution time in seconds, enforced on the
                ``docker run`` client process.
            strictness: Strictness level label; only echoed back in the result
                and used in the timeout message.
            extra_env: Additional environment variables for the container.
            mount_src: Host path to mount read-only into the container.
            mount_dst: Container path for ``mount_src``. The extra mount is
                added only when both ``mount_src`` and ``mount_dst`` are given.

        Returns:
            Dictionary with keys:
                success: bool
                output: captured stdout on success, otherwise ""
                error: error text on failure, otherwise ""
                container_id: always "" (not available with --rm)
                strictness: the ``strictness`` argument
                timeout: the ``timeout`` argument
        """
        import uuid  # local import: only needed to give the container a unique name

        def _result(success: bool, output: str = "", error: str = "") -> Dict[str, Any]:
            # Single place that defines the shape of the returned dictionary.
            return {
                "success": success,
                "output": output,
                "error": error,
                "container_id": "",  # not available with --rm
                "strictness": strictness,
                "timeout": timeout,
            }

        # A known name lets us kill the container if the client times out.
        container_name = f"code-exec-{uuid.uuid4().hex}"

        # Create temporary file with code inside a temporary directory
        # so we can mount it into container
        with tempfile.TemporaryDirectory() as tmpdir:
            code_path = Path(tmpdir) / "code.py"
            code_path.write_text(code, encoding="utf-8")

            # Build docker run command
            cmd = [
                "docker", "run",
                "--rm",
                f"--name={container_name}",
                f"--network={self.network}",
                f"--user={self.user}",
                f"--workdir={self.workdir}",
                "--env=PYTHONIOENCODING=utf-8",
            ]

            # Add memory limit if specified
            if self.memory_limit:
                cmd.append(f"--memory={self.memory_limit}")

            # Add CPU shares if specified
            if self.cpu_shares:
                cmd.append(f"--cpu-shares={self.cpu_shares}")

            # The subprocess timeout below is the real limit; --stop-timeout
            # is only kept as a backup with 2 seconds of grace.
            stop_timeout = timeout + 2
            cmd.append(f"--stop-timeout={stop_timeout}")

            # Mount the temporary directory at the workdir (read-only)
            cmd.extend(["-v", f"{tmpdir}:{self.workdir}:ro"])

            # Additional mount if provided
            if mount_src and mount_dst:
                cmd.extend(["-v", f"{mount_src}:{mount_dst}:ro"])

            # Add environment variables
            for key, value in (extra_env or {}).items():
                cmd.extend(["-e", f"{key}={value}"])

            # Finally, image and command to run
            cmd.append(self.image)
            cmd.extend(["python", "-c", code])

            # Execute docker run with timeout
            try:
                result = subprocess.run(
                    cmd,
                    capture_output=True,
                    timeout=timeout,
                    encoding="utf-8",
                    errors="ignore",
                )
            except subprocess.TimeoutExpired:
                # subprocess.run() only kills the docker CLI client; the
                # container would keep running, so stop it explicitly.
                self._kill_container(container_name)
                return _result(
                    False,
                    error=f"Execution timeout ({timeout}s limit in '{strictness}' mode)",
                )
            except Exception as e:
                return _result(False, error=f"Docker execution error: {str(e)}")

            if result.returncode == 0:
                return _result(True, output=result.stdout)
            return _result(
                False,
                error=result.stderr or f"Container exited with code {result.returncode}",
            )
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_default_executor = DockerExecutor()
|
||||
|
||||
|
||||
def execute_in_docker(code: str, timeout: int, strictness: str, **kwargs) -> Dict[str, Any]:
    """Run *code* through the module-level default ``DockerExecutor``.

    Thin wrapper: the three positional values and any extra keyword
    arguments are forwarded unchanged to ``_default_executor.execute`` and
    its result dictionary is returned as-is.
    """
    run = _default_executor.execute
    return run(code, timeout, strictness, **kwargs)
|
||||
|
|
@ -51,30 +51,89 @@ class ToolExecutor:
|
|||
return record["result"]
|
||||
return None
|
||||
|
||||
def clear_history(self) -> None:
|
||||
"""Clear call history (call this at start of new conversation turn)"""
|
||||
self._call_history.clear()
|
||||
|
||||
@staticmethod
|
||||
def _inject_context(name: str, args: dict, context: Optional[dict]) -> None:
|
||||
"""Inject context fields into tool arguments in-place.
|
||||
|
||||
- file_* tools: inject project_id
|
||||
- agent_task: inject model and project_id (prefixed with _ to avoid collisions)
|
||||
- parallel_execute: inject project_id (prefixed with _ to avoid collisions)
|
||||
- agent tools (multi_agent): inject _model and _project_id
|
||||
"""
|
||||
if not context:
|
||||
return
|
||||
if name.startswith("file_") and "project_id" in context:
|
||||
args["project_id"] = context["project_id"]
|
||||
if name == "agent_task":
|
||||
if name == "multi_agent":
|
||||
if "model" in context:
|
||||
args["_model"] = context["model"]
|
||||
if "project_id" in context:
|
||||
args["_project_id"] = context["project_id"]
|
||||
if name == "parallel_execute":
|
||||
if "project_id" in context:
|
||||
args["_project_id"] = context["project_id"]
|
||||
if "max_tokens" in context:
|
||||
args["_max_tokens"] = context["max_tokens"]
|
||||
if "temperature" in context:
|
||||
args["_temperature"] = context["temperature"]
|
||||
|
||||
def _prepare_call(
    self,
    call: dict,
    context: Optional[dict],
    seen_calls: set,
) -> tuple:
    """Parse, inject context, check dedup/cache for a single tool call.

    Args:
        call: Tool call with ``"id"`` and ``"function"`` (``"name"`` and
            ``"arguments"``); the arguments may be a JSON string or an
            already-decoded value.
        context: Optional context handed to ``_inject_context``.
        seen_calls: Keys of calls already prepared in the current batch; the
            key of this call is added when it is not a duplicate.

    Returns a tagged tuple:
        ("error", call_id, name, error_msg)
        ("cached", call_id, name, result_dict)   -- dedup or cache hit
        ("execute", call_id, name, args, cache_key)
    """
    name = call["function"]["name"]
    args_str = call["function"]["arguments"]
    call_id = call["id"]

    # Parse JSON arguments
    try:
        args = json.loads(args_str) if isinstance(args_str, str) else args_str
    except json.JSONDecodeError:
        return ("error", call_id, name, "Invalid JSON arguments")

    # Valid JSON is not necessarily an object ("[1, 2]", "null", "3"). Context
    # injection assigns keys on the arguments, so anything but a dict is
    # rejected here instead of failing later with a TypeError.
    if not isinstance(args, dict):
        return ("error", call_id, name, "Tool arguments must be a JSON object")

    # Inject context
    self._inject_context(name, args, context)

    # Dedup within same batch
    call_key = f"{name}:{json.dumps(args, sort_keys=True)}"
    if call_key in seen_calls:
        return ("cached", call_id, name,
                {"success": True, "data": None, "cached": True, "duplicate": True})
    seen_calls.add(call_key)

    # History dedup
    history_result = self._check_duplicate_in_history(name, args)
    if history_result is not None:
        return ("cached", call_id, name, {**history_result, "cached": True})

    # Cache check
    cache_key = self._make_cache_key(name, args)
    cached_result = self._get_cached(cache_key)
    if cached_result is not None:
        return ("cached", call_id, name, {**cached_result, "cached": True})

    return ("execute", call_id, name, args, cache_key)
|
||||
|
||||
def _execute_and_record(
    self,
    name: str,
    args: dict,
    cache_key: str,
) -> dict:
    """Run one tool, then cache and log its outcome.

    Args:
        name: Tool name handed to ``_execute_tool``.
        args: Decoded tool arguments.
        cache_key: Key under which a successful result is stored.

    Returns:
        The raw result dict produced by ``_execute_tool``.
    """
    outcome = self._execute_tool(name, args)

    # Only results flagged as successful are written to the cache.
    succeeded = outcome.get("success")
    if succeeded:
        self._set_cache(cache_key, outcome)

    # Every executed call is appended to the history.
    serialized_args = json.dumps(args, sort_keys=True, ensure_ascii=False)
    history_entry = {
        "name": name,
        "args_str": serialized_args,
        "result": outcome,
    }
    self._call_history.append(history_entry)

    return outcome
|
||||
|
||||
def process_tool_calls_parallel(
|
||||
self,
|
||||
|
|
@ -85,10 +144,6 @@ class ToolExecutor:
|
|||
"""
|
||||
Process tool calls concurrently and return message list (ordered by input).
|
||||
|
||||
Identical logic to process_tool_calls but uses ThreadPoolExecutor so that
|
||||
independent tool calls (e.g. reading 3 files, running 2 searches) execute
|
||||
in parallel instead of sequentially.
|
||||
|
||||
Args:
|
||||
tool_calls: Tool call list returned by LLM
|
||||
context: Optional context info (user_id, project_id, etc.)
|
||||
|
|
@ -102,80 +157,31 @@ class ToolExecutor:
|
|||
|
||||
max_workers = min(max(max_workers, 1), 6)
|
||||
|
||||
# Phase 1: prepare each call (parse args, inject context, check dedup/cache)
|
||||
# This phase is fast and sequential – it must be done before parallelism
|
||||
# to avoid race conditions on seen_calls / _call_history / _cache.
|
||||
prepared: List[Optional[tuple]] = [None] * len(tool_calls)
|
||||
seen_calls: set = set()
|
||||
# Phase 1: prepare (sequential – avoids race conditions on shared state)
|
||||
prepared = [self._prepare_call(call, context, set()) for call in tool_calls]
|
||||
|
||||
for i, call in enumerate(tool_calls):
|
||||
name = call["function"]["name"]
|
||||
args_str = call["function"]["arguments"]
|
||||
call_id = call["id"]
|
||||
|
||||
# Parse JSON arguments
|
||||
try:
|
||||
args = json.loads(args_str) if isinstance(args_str, str) else args_str
|
||||
except json.JSONDecodeError:
|
||||
prepared[i] = self._create_error_result(call_id, name, "Invalid JSON arguments")
|
||||
continue
|
||||
|
||||
# Inject context into tool arguments
|
||||
self._inject_context(name, args, context)
|
||||
|
||||
# Dedup within same batch
|
||||
call_key = f"{name}:{json.dumps(args, sort_keys=True)}"
|
||||
if call_key in seen_calls:
|
||||
prepared[i] = self._create_tool_result(
|
||||
call_id, name,
|
||||
{"success": True, "data": None, "cached": True, "duplicate": True}
|
||||
)
|
||||
continue
|
||||
seen_calls.add(call_key)
|
||||
|
||||
# History dedup
|
||||
history_result = self._check_duplicate_in_history(name, args)
|
||||
if history_result is not None:
|
||||
prepared[i] = self._create_tool_result(call_id, name, {**history_result, "cached": True})
|
||||
continue
|
||||
|
||||
# Cache check
|
||||
cache_key = self._make_cache_key(name, args)
|
||||
cached_result = self._get_cached(cache_key)
|
||||
if cached_result is not None:
|
||||
prepared[i] = self._create_tool_result(call_id, name, {**cached_result, "cached": True})
|
||||
continue
|
||||
|
||||
# Mark as needing actual execution
|
||||
prepared[i] = ("execute", call_id, name, args, cache_key)
|
||||
|
||||
# Separate pre-resolved results from tasks needing execution
|
||||
# Phase 2: separate pre-resolved from tasks needing execution
|
||||
results: List[dict] = [None] * len(tool_calls)
|
||||
exec_tasks: Dict[int, tuple] = {} # index -> (call_id, name, args, cache_key)
|
||||
exec_tasks: Dict[int, tuple] = {}
|
||||
|
||||
for i, item in enumerate(prepared):
|
||||
if isinstance(item, dict):
|
||||
results[i] = item
|
||||
elif isinstance(item, tuple) and item[0] == "execute":
|
||||
tag = item[0]
|
||||
if tag == "error":
|
||||
_, call_id, name, error_msg = item
|
||||
results[i] = self._create_error_result(call_id, name, error_msg)
|
||||
elif tag == "cached":
|
||||
_, call_id, name, result_dict = item
|
||||
results[i] = self._create_tool_result(call_id, name, result_dict)
|
||||
else: # "execute"
|
||||
_, call_id, name, args, cache_key = item
|
||||
exec_tasks[i] = (call_id, name, args, cache_key)
|
||||
|
||||
# Phase 2: execute remaining calls in parallel
|
||||
# Phase 3: execute remaining calls in parallel
|
||||
if exec_tasks:
|
||||
def _run(idx: int, call_id: str, name: str, args: dict, cache_key: str) -> tuple:
|
||||
t0 = time.time()
|
||||
result = self._execute_tool(name, args)
|
||||
result = self._execute_and_record(name, args, cache_key)
|
||||
elapsed = time.time() - t0
|
||||
|
||||
if result.get("success"):
|
||||
self._set_cache(cache_key, result)
|
||||
|
||||
self._call_history.append({
|
||||
"name": name,
|
||||
"args_str": json.dumps(args, sort_keys=True, ensure_ascii=False),
|
||||
"result": result,
|
||||
})
|
||||
|
||||
return idx, self._create_tool_result(call_id, name, result, execution_time=elapsed)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
|
|
@ -205,64 +211,21 @@ class ToolExecutor:
|
|||
Tool response message list, can be appended to messages
|
||||
"""
|
||||
results = []
|
||||
seen_calls = set() # Track calls within this batch
|
||||
seen_calls: set = set()
|
||||
|
||||
for call in tool_calls:
|
||||
name = call["function"]["name"]
|
||||
args_str = call["function"]["arguments"]
|
||||
call_id = call["id"]
|
||||
|
||||
try:
|
||||
args = json.loads(args_str) if isinstance(args_str, str) else args_str
|
||||
except json.JSONDecodeError:
|
||||
results.append(self._create_error_result(
|
||||
call_id, name, "Invalid JSON arguments"
|
||||
))
|
||||
continue
|
||||
|
||||
# Inject context into tool arguments
|
||||
self._inject_context(name, args, context)
|
||||
|
||||
# Check for duplicate within same batch
|
||||
call_key = f"{name}:{json.dumps(args, sort_keys=True)}"
|
||||
if call_key in seen_calls:
|
||||
# Skip duplicate, but still return a result
|
||||
results.append(self._create_tool_result(
|
||||
call_id, name,
|
||||
{"success": True, "data": None, "cached": True, "duplicate": True}
|
||||
))
|
||||
continue
|
||||
seen_calls.add(call_key)
|
||||
|
||||
# Check history for previous call in this session
|
||||
history_result = self._check_duplicate_in_history(name, args)
|
||||
if history_result is not None:
|
||||
result = {**history_result, "cached": True}
|
||||
results.append(self._create_tool_result(call_id, name, result))
|
||||
continue
|
||||
|
||||
# Check cache
|
||||
cache_key = self._make_cache_key(name, args)
|
||||
cached_result = self._get_cached(cache_key)
|
||||
if cached_result is not None:
|
||||
result = {**cached_result, "cached": True}
|
||||
results.append(self._create_tool_result(call_id, name, result))
|
||||
continue
|
||||
|
||||
# Execute tool with retry
|
||||
result = self._execute_tool(name, args)
|
||||
|
||||
# Cache the result (only cache successful results)
|
||||
if result.get("success"):
|
||||
self._set_cache(cache_key, result)
|
||||
|
||||
# Add to history
|
||||
self._call_history.append({
|
||||
"name": name,
|
||||
"args_str": json.dumps(args, sort_keys=True, ensure_ascii=False),
|
||||
"result": result
|
||||
})
|
||||
prepared = self._prepare_call(call, context, seen_calls)
|
||||
tag = prepared[0]
|
||||
|
||||
if tag == "error":
|
||||
_, call_id, name, error_msg = prepared
|
||||
results.append(self._create_error_result(call_id, name, error_msg))
|
||||
elif tag == "cached":
|
||||
_, call_id, name, result_dict = prepared
|
||||
results.append(self._create_tool_result(call_id, name, result_dict))
|
||||
else: # "execute"
|
||||
_, call_id, name, args, cache_key = prepared
|
||||
result = self._execute_and_record(name, args, cache_key)
|
||||
results.append(self._create_tool_result(call_id, name, result))
|
||||
|
||||
return results
|
||||
|
|
|
|||
|
|
@ -34,30 +34,3 @@ def tool(
|
|||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
def register_tool(
    name: str,
    handler: Callable,
    description: str,
    parameters: dict,
    category: str = "general"
) -> None:
    """Register a tool directly, without using the decorator.

    Builds a ``ToolDefinition`` from the arguments and hands it to the
    module-level ``registry``.

    Example:
        register_tool(
            name="my_tool",
            handler=my_function,
            description="Description",
            parameters={...}
        )
    """
    registry.register(
        ToolDefinition(
            name=name,
            description=description,
            parameters=parameters,
            handler=handler,
            category=category,
        )
    )
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
"""Backend utilities"""
|
||||
from backend.utils.helpers import ok, err, to_dict, get_or_create_default_user, record_token_usage, build_messages
|
||||
from backend.utils.helpers import ok, err, to_dict, message_to_dict, record_token_usage, build_messages
|
||||
|
||||
__all__ = [
|
||||
"ok",
|
||||
"err",
|
||||
"to_dict",
|
||||
"get_or_create_default_user",
|
||||
"message_to_dict",
|
||||
"record_token_usage",
|
||||
"build_messages",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Common helper functions"""
|
||||
import json
|
||||
from datetime import date, datetime
|
||||
from datetime import date, datetime, timezone
|
||||
from typing import Any
|
||||
from flask import jsonify
|
||||
from backend import db
|
||||
|
|
@ -97,7 +97,7 @@ def message_to_dict(msg: Message) -> dict:
|
|||
|
||||
def record_token_usage(user_id, model, prompt_tokens, completion_tokens):
|
||||
"""Record token usage"""
|
||||
today = date.today()
|
||||
today = datetime.now(timezone.utc).date()
|
||||
usage = TokenUsage.query.filter_by(
|
||||
user_id=user_id, date=today, model=model
|
||||
).first()
|
||||
|
|
@ -133,6 +133,13 @@ def build_messages(conv, project_id=None):
|
|||
# Query messages directly to avoid detached instance warning
|
||||
messages = Message.query.filter_by(conversation_id=conv.id).order_by(Message.created_at.asc()).all()
|
||||
for m in messages:
|
||||
# Skip tool messages — they are ephemeral intermediate results, not
|
||||
# meant to be replayed as conversation history (would violate the API
|
||||
# protocol that requires tool messages to follow an assistant message
|
||||
# with matching tool_calls).
|
||||
if m.role == "tool":
|
||||
continue
|
||||
|
||||
# Build full content from JSON structure
|
||||
full_content = m.content
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -66,6 +66,7 @@ backend/
|
|||
│ ├── data.py # 计算器、文本、JSON
|
||||
│ ├── weather.py # 天气查询
|
||||
│ ├── file_ops.py # 文件操作(project_id 自动注入)
|
||||
│ ├── agent.py # 多智能体(子 Agent 并发执行,工具权限隔离)
|
||||
│ └── code.py # 代码执行
|
||||
│
|
||||
├── utils/ # 辅助函数
|
||||
|
|
@ -266,8 +267,8 @@ classDiagram
|
|||
-ToolRegistry registry
|
||||
-dict _cache
|
||||
-list _call_history
|
||||
+process_tool_calls(calls, context) list
|
||||
+clear_history() void
|
||||
+process_tool_calls(list, dict) list
|
||||
+process_tool_calls_parallel(list, dict, int) list
|
||||
}
|
||||
|
||||
ChatService --> LLMClient : 使用
|
||||
|
|
@ -295,18 +296,17 @@ classDiagram
|
|||
+register(ToolDefinition) void
|
||||
+get(str name) ToolDefinition?
|
||||
+list_all() list~dict~
|
||||
+list_by_category(str) list~dict~
|
||||
+execute(str name, dict args) dict
|
||||
+remove(str name) bool
|
||||
+has(str name) bool
|
||||
}
|
||||
|
||||
class ToolExecutor {
|
||||
-ToolRegistry registry
|
||||
-bool enable_cache
|
||||
-int cache_ttl
|
||||
-dict _cache
|
||||
-list _call_history
|
||||
+process_tool_calls(list, dict) list
|
||||
+clear_history() void
|
||||
+process_tool_calls_parallel(list, dict, int) list
|
||||
}
|
||||
|
||||
class ToolResult {
|
||||
|
|
@ -394,18 +394,19 @@ def validate_path_in_project(path: str, project_dir: Path) -> Path:
|
|||
工具执行器自动为文件工具注入 `project_id`:
|
||||
|
||||
```python
|
||||
# backend/tools/executor.py
|
||||
# backend/tools/executor.py — _inject_context()
|
||||
|
||||
def process_tool_calls(self, tool_calls, context=None):
|
||||
for call in tool_calls:
|
||||
name = call["function"]["name"]
|
||||
args = json.loads(call["function"]["arguments"])
|
||||
|
||||
# 自动注入 project_id
|
||||
if context and name.startswith("file_") and "project_id" in context:
|
||||
@staticmethod
|
||||
def _inject_context(name: str, args: dict, context: Optional[dict]) -> None:
|
||||
# file_* 工具: 注入 project_id
|
||||
if name.startswith("file_") and "project_id" in context:
|
||||
args["project_id"] = context["project_id"]
|
||||
|
||||
result = self.registry.execute(name, args)
|
||||
# agent 工具: 注入 _model 和 _project_id
|
||||
if name == "multi_agent":
|
||||
if "model" in context:
|
||||
args["_model"] = context["model"]
|
||||
if "project_id" in context:
|
||||
args["_project_id"] = context["project_id"]
|
||||
```
|
||||
|
||||
---
|
||||
|
|
@ -1020,6 +1021,12 @@ frontend_port: 4000
|
|||
# 智能体循环最大迭代次数(工具调用轮次上限,默认 5)
|
||||
max_iterations: 15
|
||||
|
||||
# 子代理资源配置(multi_agent 工具)
|
||||
# max_tokens 和 temperature 与主 Agent 共用,无需单独配置
|
||||
sub_agent:
|
||||
max_iterations: 3 # 每个子代理的最大工具调用轮数
|
||||
max_concurrency: 3 # 并发线程数
|
||||
|
||||
# 可用模型列表(每个模型必须指定 api_url 和 api_key)
|
||||
# 支持任何 OpenAI 兼容 API(DeepSeek、GLM、OpenAI、Moonshot、Qwen 等)
|
||||
models:
|
||||
|
|
|
|||
|
|
@ -27,19 +27,20 @@ classDiagram
|
|||
+register(ToolDefinition tool) void
|
||||
+get(str name) ToolDefinition?
|
||||
+list_all() list~dict~
|
||||
+list_by_category(str category) list~dict~
|
||||
+execute(str name, dict args) dict
|
||||
+remove(str name) bool
|
||||
+has(str name) bool
|
||||
}
|
||||
|
||||
class ToolExecutor {
|
||||
-ToolRegistry registry
|
||||
-bool enable_cache
|
||||
-int cache_ttl
|
||||
-dict _cache
|
||||
-list _call_history
|
||||
+process_tool_calls(list tool_calls, dict context) list~dict~
|
||||
+build_request(list messages, str model, list tools, dict kwargs) dict
|
||||
+clear_history() void
|
||||
+process_tool_calls_parallel(list tool_calls, dict context, int max_workers) list~dict~
|
||||
-_prepare_call(dict call, dict context, set seen_calls) tuple
|
||||
-_execute_and_record(str name, dict args, str cache_key) dict
|
||||
-_inject_context(str name, dict args, dict context) void
|
||||
}
|
||||
|
||||
class ToolResult {
|
||||
|
|
@ -88,32 +89,26 @@ classDiagram
|
|||
|
||||
### context 参数
|
||||
|
||||
`process_tool_calls()` 接受 `context` 参数,用于自动注入工具参数:
|
||||
`process_tool_calls()` / `process_tool_calls_parallel()` 接受 `context` 参数,用于自动注入工具参数:
|
||||
|
||||
```python
|
||||
# backend/tools/executor.py
|
||||
# backend/tools/executor.py — _inject_context()
|
||||
|
||||
def process_tool_calls(
|
||||
self,
|
||||
tool_calls: List[dict],
|
||||
context: Optional[dict] = None
|
||||
) -> List[dict]:
|
||||
@staticmethod
|
||||
def _inject_context(name: str, args: dict, context: Optional[dict]) -> None:
|
||||
"""
|
||||
Args:
|
||||
tool_calls: LLM 返回的工具调用列表
|
||||
context: 上下文信息,支持:
|
||||
- project_id: 自动注入到文件工具
|
||||
- file_* 工具: 注入 project_id
|
||||
- agent 工具 (multi_agent): 注入 _model 和 _project_id
|
||||
"""
|
||||
for call in tool_calls:
|
||||
name = call["function"]["name"]
|
||||
args = json.loads(call["function"]["arguments"])
|
||||
|
||||
# 自动注入 project_id 到文件工具
|
||||
if context:
|
||||
if not context:
|
||||
return
|
||||
if name.startswith("file_") and "project_id" in context:
|
||||
args["project_id"] = context["project_id"]
|
||||
|
||||
result = self.registry.execute(name, args)
|
||||
if name == "multi_agent":
|
||||
if "model" in context:
|
||||
args["_model"] = context["model"]
|
||||
if "project_id" in context:
|
||||
args["_project_id"] = context["project_id"]
|
||||
```
|
||||
|
||||
### 使用示例
|
||||
|
|
@ -122,12 +117,12 @@ def process_tool_calls(
|
|||
# backend/services/chat.py
|
||||
|
||||
def stream_response(self, conv, tools_enabled=True, project_id=None):
|
||||
# 构建上下文(优先使用请求传递的 project_id,否则回退到对话绑定的)
|
||||
context = None
|
||||
# 构建上下文(包含 model 和 project_id)
|
||||
context = {"model": conv.model}
|
||||
if project_id:
|
||||
context = {"project_id": project_id}
|
||||
context["project_id"] = project_id
|
||||
elif conv.project_id:
|
||||
context = {"project_id": conv.project_id}
|
||||
context["project_id"] = conv.project_id
|
||||
|
||||
# 处理工具调用时自动注入
|
||||
tool_results = self.executor.process_tool_calls(tool_calls, context)
|
||||
|
|
@ -222,14 +217,68 @@ file_read({"path": "src/main.py", "project_id": "xxx"})
|
|||
|
||||
| 工具名称 | 描述 | 参数 |
|
||||
|---------|------|------|
|
||||
| `execute_python` | 在沙箱环境中执行 Python 代码 | `code`: Python 代码 |
|
||||
| `execute_python` | 在沙箱环境中执行 Python 代码 | `code`: Python 代码<br>`strictness`: 可选,严格等级(lenient/standard/strict) |
|
||||
|
||||
安全措施:
|
||||
- 白名单模块限制
|
||||
- 危险内置函数禁止
|
||||
- 10 秒超时限制
|
||||
- 无文件系统访问
|
||||
- 无网络访问
|
||||
**严格等级配置:**
|
||||
|
||||
| 等级 | 超时 | 策略 | 适用场景 |
|
||||
|------|------|------|---------|
|
||||
| `lenient` | 30s | 无限制,所有模块和内置函数均可使用 | 数据处理、需要完整标准库 |
|
||||
| `standard` | 10s | 白名单机制,仅允许安全模块(默认) | 通用场景 |
|
||||
| `strict` | 5s | 精简白名单,仅允许纯计算模块 | 基础计算 |
|
||||
|
||||
**standard 白名单模块:** json, csv, re, typing, collections, itertools, functools, operator, heapq, bisect, array, copy, pprint, enum, math, cmath, statistics, random, fractions, decimal, numbers, datetime, time, calendar, string, textwrap, unicodedata, difflib, base64, binascii, quopri, uu, html, xml.etree.ElementTree, dataclasses, hashlib, hmac, abc, contextlib, warnings, logging
|
||||
|
||||
**strict 白名单模块:** collections, itertools, functools, operator, array, copy, enum, math, cmath, numbers, fractions, decimal, random, statistics, string, textwrap, unicodedata, typing, dataclasses, abc, contextlib
|
||||
|
||||
**内置函数限制:**
|
||||
- standard 禁止:eval, exec, compile, \_\_import\_\_, open, input, globals, locals, vars, breakpoint, exit, quit, memoryview, bytearray
|
||||
- strict 额外禁止:dir, hasattr, getattr, setattr, delattr, type, isinstance, issubclass
|
||||
|
||||
**白名单扩展方式:**
|
||||
|
||||
1. **config.yml 配置(持久化):**
|
||||
```yaml
|
||||
code_execution:
|
||||
default_strictness: standard
|
||||
extra_allowed_modules:
|
||||
standard: [numpy, pandas]
|
||||
strict: [numpy]
|
||||
```
|
||||
|
||||
2. **代码 API(插件/运行时):**
|
||||
```python
|
||||
from backend.tools.builtin.code import register_extra_modules
|
||||
|
||||
register_extra_modules("standard", {"numpy", "pandas"})
|
||||
register_extra_modules("strict", {"numpy"})
|
||||
```
|
||||
|
||||
**使用示例:**
|
||||
|
||||
```python
|
||||
# 默认 standard 模式(白名单限制)
|
||||
execute_python({"code": "import json; print(json.dumps({'key': 'value'}))"})
|
||||
|
||||
# lenient 模式 - 无限制
|
||||
execute_python({
|
||||
"code": "import os; print(os.getcwd())",
|
||||
"strictness": "lenient"
|
||||
})
|
||||
|
||||
# strict 模式 - 仅纯计算
|
||||
execute_python({
|
||||
"code": "result = sum([1, 2, 3, 4, 5]); print(result)",
|
||||
"strictness": "strict"
|
||||
})
|
||||
```
|
||||
|
||||
**安全措施:**
|
||||
- standard/strict: 白名单模块限制(默认拒绝,仅显式允许)
|
||||
- lenient: 无限制
|
||||
- 危险内置函数按等级禁止
|
||||
- 可配置超时限制(5s/10s/30s)
|
||||
- subprocess 隔离执行
|
||||
|
||||
### 5.4 文件操作工具 (file)
|
||||
|
||||
|
|
@ -250,6 +299,31 @@ file_read({"path": "src/main.py", "project_id": "xxx"})
|
|||
|---------|------|------|
|
||||
| `get_weather` | 查询天气信息(模拟) | `city`: 城市名称 |
|
||||
|
||||
### 5.6 多智能体工具 (agent)
|
||||
|
||||
| 工具名称 | 描述 | 参数 |
|
||||
|---------|------|------|
|
||||
| `multi_agent` | 派生子 Agent 并发执行任务 | `tasks`: 任务数组(name, instruction, tools)<br>`_model`: 模型名称(自动注入)<br>`_project_id`: 项目 ID(自动注入) |
|
||||
|
||||
**`multi_agent` 工作原理:**
|
||||
1. 接收任务数组,每个任务指定 name、instruction 和可选的 tools 列表
|
||||
2. 子 Agent **禁止使用 `multi_agent` 工具**(`BLOCKED_TOOLS`),防止无限递归
|
||||
3. 子 Agent 工具权限与主 Agent 一致(除 multi_agent 外的所有已注册工具),支持并行工具执行
|
||||
4. 为每个子 Agent 创建独立线程,各自拥有 LLM 对话循环
|
||||
5. 子 Agent 在 `app.app_context()` 中运行 LLM 调用和工具执行,确保数据库等依赖正常工作
|
||||
6. 通过 Service Locator 获取 `llm_client` 实例
|
||||
7. 返回 `{success, results: [{task_name, success, response/error}], total}`
|
||||
|
||||
**资源配置**(`config.yml` → `sub_agent`):
|
||||
|
||||
| 配置项 | 默认值 | 说明 |
|
||||
|--------|--------|------|
|
||||
| `max_iterations` | 3 | 每个子代理的最大工具调用轮数 |
|
||||
| `max_concurrency` | 3 | ThreadPoolExecutor 并发线程数 |
|
||||
|
||||
> - `max_tokens` 和 `temperature` 与主 Agent 共用,从对话配置中获取,无需单独配置。
|
||||
> - 子代理禁止调用 `multi_agent` 工具,防止无限递归。
|
||||
|
||||
---
|
||||
|
||||
## 六、核心特性
|
||||
|
|
@ -285,7 +359,6 @@ def my_tool(arguments: dict) -> dict:
|
|||
|
||||
- **批次内去重**:同一批次中相同工具+参数的调用会被跳过
|
||||
- **历史去重**:同一会话内已调用过的工具会直接返回缓存结果
|
||||
- **自动清理**:新会话开始时调用 `clear_history()` 清理历史
|
||||
|
||||
### 6.4 无自动重试
|
||||
|
||||
|
|
@ -308,13 +381,45 @@ def my_tool(arguments: dict) -> dict:
|
|||
def init_tools() -> None:
|
||||
"""初始化所有内置工具"""
|
||||
from backend.tools.builtin import (
|
||||
code, crawler, data, weather, file_ops
|
||||
code, crawler, data, weather, file_ops, agent
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 八、扩展新工具
|
||||
## 八、Service Locator
|
||||
|
||||
工具系统提供 Service Locator 模式,允许工具访问共享服务(如 LLM 客户端):
|
||||
|
||||
```python
|
||||
# backend/tools/__init__.py
|
||||
|
||||
_services: dict = {}
|
||||
|
||||
def register_service(name: str, service) -> None:
|
||||
"""注册共享服务"""
|
||||
_services[name] = service
|
||||
|
||||
def get_service(name: str):
|
||||
"""获取已注册的服务,不存在则返回 None"""
|
||||
return _services.get(name)
|
||||
```
|
||||
|
||||
### 使用方式
|
||||
|
||||
```python
|
||||
# 在应用初始化时注册(routes/__init__.py)
|
||||
from backend.tools import register_service
|
||||
register_service("llm_client", llm_client)
|
||||
|
||||
# 在工具中使用(agent.py)
|
||||
from backend.tools import get_service
|
||||
llm_client = get_service("llm_client")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 九、扩展新工具
|
||||
|
||||
### 添加新工具
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@
|
|||
"@codemirror/lang-xml": "^6.1.0",
|
||||
"@codemirror/lang-yaml": "^6.1.2",
|
||||
"@codemirror/theme-one-dark": "^6.1.2",
|
||||
"chart.js": "^4.5.1",
|
||||
"codemirror": "^6.0.1",
|
||||
"highlight.js": "^11.11.1",
|
||||
"katex": "^0.16.40",
|
||||
|
|
@ -791,6 +792,12 @@
|
|||
"integrity": "sha512-cYQ9310grqxueWbl+WuIUIaiUaDcj7WOq5fVhEljNVgRfOUhY9fy2zTvfoqWsnebh8Sl70VScFbICvJnLKB0Og==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@kurkle/color": {
|
||||
"version": "0.3.4",
|
||||
"resolved": "https://registry.npmmirror.com/@kurkle/color/-/color-0.3.4.tgz",
|
||||
"integrity": "sha512-M5UknZPHRu3DEDWoipU6sE8PdkZ6Z/S+v4dD+Ke8IaNlpdSQah50lz1KtcFBa2vsdOnwbbnxJwVM4wty6udA5w==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@lezer/common": {
|
||||
"version": "1.5.1",
|
||||
"resolved": "https://registry.npmmirror.com/@lezer/common/-/common-1.5.1.tgz",
|
||||
|
|
@ -1430,6 +1437,18 @@
|
|||
"dev": true,
|
||||
"license": "Python-2.0"
|
||||
},
|
||||
"node_modules/chart.js": {
|
||||
"version": "4.5.1",
|
||||
"resolved": "https://registry.npmmirror.com/chart.js/-/chart.js-4.5.1.tgz",
|
||||
"integrity": "sha512-GIjfiT9dbmHRiYi6Nl2yFCq7kkwdkp1W/lp2J99rX0yo9tgJGn3lKQATztIjb5tVtevcBtIdICNWqlq5+E8/Pw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@kurkle/color": "^0.3.0"
|
||||
},
|
||||
"engines": {
|
||||
"pnpm": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/codemirror": {
|
||||
"version": "6.0.2",
|
||||
"resolved": "https://registry.npmmirror.com/codemirror/-/codemirror-6.0.2.tgz",
|
||||
|
|
|
|||
|
|
@ -9,21 +9,22 @@
|
|||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"codemirror": "^6.0.1",
|
||||
"@codemirror/theme-one-dark": "^6.1.2",
|
||||
"@codemirror/lang-markdown": "^6.3.2",
|
||||
"@codemirror/lang-javascript": "^6.2.3",
|
||||
"@codemirror/lang-python": "^6.1.7",
|
||||
"@codemirror/lang-html": "^6.4.9",
|
||||
"@codemirror/lang-css": "^6.3.1",
|
||||
"@codemirror/lang-json": "^6.0.1",
|
||||
"@codemirror/lang-yaml": "^6.1.2",
|
||||
"@codemirror/lang-java": "^6.0.1",
|
||||
"@codemirror/lang-cpp": "^6.0.2",
|
||||
"@codemirror/lang-rust": "^6.0.1",
|
||||
"@codemirror/lang-css": "^6.3.1",
|
||||
"@codemirror/lang-go": "^6.0.1",
|
||||
"@codemirror/lang-html": "^6.4.9",
|
||||
"@codemirror/lang-java": "^6.0.1",
|
||||
"@codemirror/lang-javascript": "^6.2.3",
|
||||
"@codemirror/lang-json": "^6.0.1",
|
||||
"@codemirror/lang-markdown": "^6.3.2",
|
||||
"@codemirror/lang-python": "^6.1.7",
|
||||
"@codemirror/lang-rust": "^6.0.1",
|
||||
"@codemirror/lang-sql": "^6.8.0",
|
||||
"@codemirror/lang-xml": "^6.1.0",
|
||||
"@codemirror/lang-yaml": "^6.1.2",
|
||||
"@codemirror/theme-one-dark": "^6.1.2",
|
||||
"chart.js": "^4.5.1",
|
||||
"codemirror": "^6.0.1",
|
||||
"highlight.js": "^11.11.1",
|
||||
"katex": "^0.16.40",
|
||||
"marked": "^15.0.12",
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@
|
|||
:messages="messages"
|
||||
:streaming="streaming"
|
||||
:streaming-process-steps="streamProcessSteps"
|
||||
:model-name-map="modelNameMap"
|
||||
:has-more-messages="hasMoreMessages"
|
||||
:loading-more="loadingMessages"
|
||||
:tools-enabled="toolsEnabled"
|
||||
|
|
@ -59,8 +60,11 @@
|
|||
<div v-if="showSettings" class="modal-overlay" @click.self="showSettings = false">
|
||||
<div class="modal-content">
|
||||
<SettingsPanel
|
||||
:key="currentConvId || '__none__'"
|
||||
:visible="showSettings"
|
||||
:conversation="currentConv"
|
||||
:models="models"
|
||||
:default-model="defaultModel"
|
||||
@close="showSettings = false"
|
||||
@save="saveSettings"
|
||||
/>
|
||||
|
|
@ -117,13 +121,37 @@ import ModalDialog from './components/ModalDialog.vue'
|
|||
import ToastContainer from './components/ToastContainer.vue'
|
||||
import { icons } from './utils/icons'
|
||||
import { useModal } from './composables/useModal'
|
||||
import {
|
||||
DEFAULT_CONVERSATION_PAGE_SIZE,
|
||||
DEFAULT_MESSAGE_PAGE_SIZE,
|
||||
LS_KEY_TOOLS_ENABLED,
|
||||
} from './constants'
|
||||
|
||||
const SettingsPanel = defineAsyncComponent(() => import('./components/SettingsPanel.vue'))
|
||||
const StatsPanel = defineAsyncComponent(() => import('./components/StatsPanel.vue'))
|
||||
import { conversationApi, messageApi, projectApi } from './api'
|
||||
import { conversationApi, messageApi, projectApi, modelApi } from './api'
|
||||
|
||||
const modal = useModal()
|
||||
|
||||
// -- Models state (preloaded) --
|
||||
const models = ref([])
|
||||
const modelNameMap = ref({})
|
||||
const defaultModel = computed(() => models.value.length > 0 ? models.value[0].id : '')
|
||||
|
||||
// Load the (cached) model list once and derive the id -> display-name map
// from it. Failures are logged and leave the current state untouched.
async function loadModels() {
  try {
    const response = await modelApi.getCached()
    models.value = response.data || []

    // Entries missing either an id or a name are left out of the map.
    const nameById = {}
    for (const entry of models.value) {
      if (!entry.id || !entry.name) continue
      nameById[entry.id] = entry.name
    }
    modelNameMap.value = nameById
  } catch (e) {
    console.error('Failed to load models:', e)
  }
}
|
||||
|
||||
// -- Conversations state --
|
||||
const conversations = shallowRef([])
|
||||
const currentConvId = ref(null)
|
||||
|
|
@ -203,7 +231,7 @@ function updateStreamField(convId, field, ref, valueOrUpdater) {
|
|||
// -- UI state --
|
||||
const showSettings = ref(false)
|
||||
const showStats = ref(false)
|
||||
const toolsEnabled = ref(localStorage.getItem('tools_enabled') !== 'false')
|
||||
const toolsEnabled = ref(localStorage.getItem(LS_KEY_TOOLS_ENABLED) !== 'false')
|
||||
const currentProject = ref(null)
|
||||
const showFileExplorer = ref(false)
|
||||
const showCreateModal = ref(false)
|
||||
|
|
@ -227,7 +255,7 @@ async function loadConversations(reset = true) {
|
|||
if (loadingConvs.value) return
|
||||
loadingConvs.value = true
|
||||
try {
|
||||
const res = await conversationApi.list(reset ? null : nextConvCursor.value, 20)
|
||||
const res = await conversationApi.list(reset ? null : nextConvCursor.value, DEFAULT_CONVERSATION_PAGE_SIZE)
|
||||
if (reset) {
|
||||
conversations.value = res.data.items
|
||||
} else {
|
||||
|
|
@ -258,6 +286,7 @@ async function createConversationInProject(project) {
|
|||
const res = await conversationApi.create({
|
||||
title: '新对话',
|
||||
project_id: project.id || null,
|
||||
model: defaultModel.value || undefined,
|
||||
})
|
||||
conversations.value = [res.data, ...conversations.value]
|
||||
await selectConversation(res.data.id)
|
||||
|
|
@ -412,7 +441,7 @@ function createStreamCallbacks(convId, { updateConvList = true } = {}) {
|
|||
}
|
||||
} else {
|
||||
try {
|
||||
const res = await messageApi.list(convId, null, 50)
|
||||
const res = await messageApi.list(convId, null, DEFAULT_MESSAGE_PAGE_SIZE)
|
||||
const idx = conversations.value.findIndex(c => c.id === convId)
|
||||
if (idx >= 0) {
|
||||
const conv = conversations.value[idx]
|
||||
|
|
@ -533,7 +562,7 @@ async function saveSettings(data) {
|
|||
// -- Update tools enabled --
|
||||
function updateToolsEnabled(val) {
|
||||
toolsEnabled.value = val
|
||||
localStorage.setItem('tools_enabled', String(val))
|
||||
localStorage.setItem(LS_KEY_TOOLS_ENABLED, String(val))
|
||||
}
|
||||
|
||||
// -- Browse project files --
|
||||
|
|
@ -602,7 +631,8 @@ async function deleteProject(project) {
|
|||
}
|
||||
|
||||
// -- Init --
|
||||
onMounted(() => {
|
||||
onMounted(async () => {
|
||||
await loadModels()
|
||||
loadProjects()
|
||||
loadConversations()
|
||||
})
|
||||
|
|
|
|||
|
|
@ -1,4 +1,11 @@
|
|||
const BASE = '/api'
|
||||
import {
|
||||
API_BASE_URL,
|
||||
CONTENT_TYPE_JSON,
|
||||
LS_KEY_MODELS_CACHE,
|
||||
DEFAULT_CONVERSATION_PAGE_SIZE,
|
||||
DEFAULT_MESSAGE_PAGE_SIZE,
|
||||
DEFAULT_PROJECT_PAGE_SIZE,
|
||||
} from '../constants'
|
||||
|
||||
// Cache for models list
|
||||
let modelsCache = null
|
||||
|
|
@ -13,8 +20,8 @@ function buildQueryParams(params) {
|
|||
}
|
||||
|
||||
async function request(url, options = {}) {
|
||||
const res = await fetch(`${BASE}${url}`, {
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
const res = await fetch(`${API_BASE_URL}${url}`, {
|
||||
headers: { 'Content-Type': CONTENT_TYPE_JSON },
|
||||
...options,
|
||||
body: options.body ? JSON.stringify(options.body) : undefined,
|
||||
})
|
||||
|
|
@ -37,9 +44,9 @@ function createSSEStream(url, body, { onProcessStep, onDone, onError }) {
|
|||
|
||||
const promise = (async () => {
|
||||
try {
|
||||
const res = await fetch(`${BASE}${url}`, {
|
||||
const res = await fetch(`${API_BASE_URL}${url}`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
headers: { 'Content-Type': CONTENT_TYPE_JSON },
|
||||
body: JSON.stringify(body),
|
||||
signal: controller.signal,
|
||||
})
|
||||
|
|
@ -107,7 +114,7 @@ export const modelApi = {
|
|||
}
|
||||
|
||||
// Try localStorage cache first
|
||||
const cached = localStorage.getItem('models_cache')
|
||||
const cached = localStorage.getItem(LS_KEY_MODELS_CACHE)
|
||||
if (cached) {
|
||||
try {
|
||||
modelsCache = JSON.parse(cached)
|
||||
|
|
@ -118,7 +125,7 @@ export const modelApi = {
|
|||
// Fetch from server
|
||||
const res = await this.list()
|
||||
modelsCache = res.data
|
||||
localStorage.setItem('models_cache', JSON.stringify(modelsCache))
|
||||
localStorage.setItem(LS_KEY_MODELS_CACHE, JSON.stringify(modelsCache))
|
||||
return res
|
||||
},
|
||||
|
||||
|
|
@ -131,7 +138,7 @@ export const statsApi = {
|
|||
}
|
||||
|
||||
export const conversationApi = {
|
||||
list(cursor, limit = 20, projectId = null) {
|
||||
list(cursor, limit = DEFAULT_CONVERSATION_PAGE_SIZE, projectId = null) {
|
||||
return request(`/conversations${buildQueryParams({ cursor, limit, project_id: projectId })}`)
|
||||
},
|
||||
|
||||
|
|
@ -159,7 +166,7 @@ export const conversationApi = {
|
|||
}
|
||||
|
||||
export const messageApi = {
|
||||
list(convId, cursor, limit = 50) {
|
||||
list(convId, cursor, limit = DEFAULT_MESSAGE_PAGE_SIZE) {
|
||||
return request(`/conversations/${convId}/messages${buildQueryParams({ cursor, limit })}`)
|
||||
},
|
||||
|
||||
|
|
@ -186,7 +193,7 @@ export const messageApi = {
|
|||
}
|
||||
|
||||
export const projectApi = {
|
||||
list(cursor, limit = 20) {
|
||||
list(cursor, limit = DEFAULT_PROJECT_PAGE_SIZE) {
|
||||
return request(`/projects${buildQueryParams({ cursor, limit })}`)
|
||||
},
|
||||
|
||||
|
|
@ -210,7 +217,7 @@ export const projectApi = {
|
|||
},
|
||||
|
||||
readFileRaw(projectId, filepath) {
|
||||
return fetch(`${BASE}/projects/${projectId}/files/${filepath}`).then(res => {
|
||||
return fetch(`${API_BASE_URL}/projects/${projectId}/files/${filepath}`).then(res => {
|
||||
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
||||
return res
|
||||
})
|
||||
|
|
|
|||
|
|
@ -79,13 +79,13 @@ import MessageBubble from './MessageBubble.vue'
|
|||
import MessageInput from './MessageInput.vue'
|
||||
import MessageNav from './MessageNav.vue'
|
||||
import ProcessBlock from './ProcessBlock.vue'
|
||||
import { modelApi } from '../api'
|
||||
|
||||
const props = defineProps({
|
||||
conversation: { type: Object, default: null },
|
||||
messages: { type: Array, required: true },
|
||||
streaming: { type: Boolean, default: false },
|
||||
streamingProcessSteps: { type: Array, default: () => [] },
|
||||
modelNameMap: { type: Object, default: () => ({}) },
|
||||
hasMoreMessages: { type: Boolean, default: false },
|
||||
loadingMore: { type: Boolean, default: false },
|
||||
toolsEnabled: { type: Boolean, default: true },
|
||||
|
|
@ -95,27 +95,15 @@ const emit = defineEmits(['sendMessage', 'stopStreaming', 'deleteMessage', 'rege
|
|||
|
||||
const scrollContainer = ref(null)
|
||||
const inputRef = ref(null)
|
||||
const modelNameMap = ref({})
|
||||
const activeMessageId = ref(null)
|
||||
let scrollObserver = null
|
||||
const observedElements = new WeakSet()
|
||||
|
||||
function formatModelName(modelId) {
|
||||
return modelNameMap.value[modelId] || modelId
|
||||
}
|
||||
|
||||
onMounted(async () => {
|
||||
try {
|
||||
const res = await modelApi.getCached()
|
||||
const map = {}
|
||||
for (const m of res.data) {
|
||||
if (m.id && m.name) map[m.id] = m.name
|
||||
}
|
||||
modelNameMap.value = map
|
||||
} catch (e) {
|
||||
console.warn('Failed to load model names:', e)
|
||||
return props.modelNameMap[modelId] || modelId
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
if (scrollContainer.value) {
|
||||
scrollObserver = new IntersectionObserver(
|
||||
(entries) => {
|
||||
|
|
@ -257,16 +245,6 @@ watch(() => props.conversation?.id, () => {
|
|||
line-height: 1;
|
||||
}
|
||||
|
||||
.thinking-badge {
|
||||
background: rgba(245, 158, 11, 0.12);
|
||||
color: #d97706;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .thinking-badge {
|
||||
background: rgba(245, 158, 11, 0.18);
|
||||
color: #fbbf24;
|
||||
}
|
||||
|
||||
.messages-container {
|
||||
flex: 1 1 auto;
|
||||
overflow-y: auto;
|
||||
|
|
@ -313,4 +291,5 @@ watch(() => props.conversation?.id, () => {
|
|||
|
||||
|
||||
|
||||
|
||||
</style>
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@
|
|||
<input
|
||||
ref="fileInputRef"
|
||||
type="file"
|
||||
accept=".txt,.md,.json,.xml,.html,.css,.js,.ts,.jsx,.tsx,.py,.java,.c,.cpp,.h,.hpp,.yaml,.yml,.toml,.ini,.csv,.sql,.sh,.bat,.log,.vue,.svelte,.go,.rs,.rb,.php,.swift,.kt,.scala,.lua,.r,.dart"
|
||||
:accept="ALLOWED_UPLOAD_EXTENSIONS"
|
||||
@change="handleFileUpload"
|
||||
style="display: none"
|
||||
/>
|
||||
|
|
@ -65,6 +65,7 @@
|
|||
<script setup>
|
||||
import { ref, computed, nextTick } from 'vue'
|
||||
import { icons } from '../utils/icons'
|
||||
import { TEXTAREA_MAX_HEIGHT_PX, ALLOWED_UPLOAD_EXTENSIONS } from '../constants'
|
||||
|
||||
const props = defineProps({
|
||||
disabled: { type: Boolean, default: false },
|
||||
|
|
@ -83,7 +84,7 @@ function autoResize() {
|
|||
const el = textareaRef.value
|
||||
if (!el) return
|
||||
el.style.height = 'auto'
|
||||
el.style.height = Math.min(el.scrollHeight, 200) + 'px'
|
||||
el.style.height = Math.min(el.scrollHeight, TEXTAREA_MAX_HEIGHT_PX) + 'px'
|
||||
}
|
||||
|
||||
function onKeydown(e) {
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
<script setup>
|
||||
import { computed } from 'vue'
|
||||
import { DEFAULT_TRUNCATE_LENGTH } from '../constants'
|
||||
|
||||
const props = defineProps({
|
||||
messages: { type: Array, required: true },
|
||||
|
|
@ -30,7 +31,7 @@ const userMessages = computed(() => props.messages.filter(m => m.role === 'user'
|
|||
function preview(msg) {
|
||||
if (!msg.text) return '...'
|
||||
const clean = msg.text.replace(/[#*`~>\-\[\]()]/g, '').replace(/\s+/g, ' ').trim()
|
||||
return clean.length > 60 ? clean.slice(0, 60) + '...' : clean
|
||||
return clean.length > DEFAULT_TRUNCATE_LENGTH ? clean.slice(0, DEFAULT_TRUNCATE_LENGTH) + '...' : clean
|
||||
}
|
||||
</script>
|
||||
|
||||
|
|
|
|||
|
|
@ -41,7 +41,10 @@
|
|||
</div>
|
||||
<div v-if="item.result" class="tool-detail">
|
||||
<span class="detail-label">返回结果:</span>
|
||||
<pre>{{ item.result }}</pre>
|
||||
<pre>{{ expandedResultKeys[item.key] ? item.result : item.resultPreview }}</pre>
|
||||
<button v-if="item.resultTruncated" class="btn-expand-result" @click.stop="toggleResultExpand(item.key)">
|
||||
{{ expandedResultKeys[item.key] ? '收起' : `展开全部 (${item.resultLength} 字符)` }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -67,6 +70,19 @@ import { ref, computed, watch } from 'vue'
|
|||
import { renderMarkdown } from '../utils/markdown'
|
||||
import { formatJson, truncate } from '../utils/format'
|
||||
import { useCodeEnhancement } from '../composables/useCodeEnhancement'
|
||||
import { RESULT_PREVIEW_LIMIT } from '../constants'
|
||||
|
||||
/**
 * Derive the display fields for a tool result.
 *
 * The raw content is formatted with `formatJson`; when the formatted text is
 * longer than `RESULT_PREVIEW_LIMIT` characters a shortened preview ending in
 * "\n..." is produced alongside the full text.
 *
 * @param {*} rawContent - Tool result as received (passed straight to `formatJson`).
 * @returns {{result: string, resultPreview: string, resultTruncated: boolean, resultLength: number}}
 */
function buildResultFields(rawContent) {
  const text = formatJson(rawContent)
  const resultLength = text.length
  const resultTruncated = resultLength > RESULT_PREVIEW_LIMIT

  let resultPreview = text
  if (resultTruncated) {
    resultPreview = text.slice(0, RESULT_PREVIEW_LIMIT) + '\n...'
  }

  return { result: text, resultPreview, resultTruncated, resultLength }
}
|
||||
|
||||
const props = defineProps({
|
||||
toolCalls: { type: Array, default: () => [] },
|
||||
|
|
@ -75,10 +91,12 @@ const props = defineProps({
|
|||
})
|
||||
|
||||
const expandedKeys = ref({})
|
||||
const expandedResultKeys = ref({})
|
||||
|
||||
// Auto-collapse all items (and any expanded tool results) when a new stream
// starts. Both maps are reset only on the false -> true transition, so what
// the user expanded while reading is not collapsed again when streaming stops.
watch(() => props.streaming, (v) => {
  if (v) {
    expandedKeys.value = {}
    expandedResultKeys.value = {}
  }
})
|
||||
|
||||
const processRef = ref(null)
|
||||
|
|
@ -87,6 +105,10 @@ function toggleItem(key) {
|
|||
expandedKeys.value[key] = !expandedKeys.value[key]
|
||||
}
|
||||
|
||||
/**
 * Flip the "show full result" flag for one process item.
 *
 * @param {string} key - Item key as used in `expandedResultKeys`.
 */
function toggleResultExpand(key) {
  const expanded = expandedResultKeys.value
  expanded[key] = !expanded[key]
}
|
||||
|
||||
function getResultSummary(result) {
|
||||
try {
|
||||
const parsed = typeof result === 'string' ? JSON.parse(result) : result
|
||||
|
|
@ -138,7 +160,7 @@ const processItems = computed(() => {
|
|||
const summary = getResultSummary(step.content)
|
||||
const match = items.findLast(it => it.type === 'tool_call' && it.id === toolId)
|
||||
if (match) {
|
||||
match.result = formatJson(step.content)
|
||||
Object.assign(match, buildResultFields(step.content))
|
||||
match.resultSummary = summary.text
|
||||
match.isSuccess = summary.success
|
||||
match.loading = false
|
||||
|
|
@ -165,7 +187,8 @@ const processItems = computed(() => {
|
|||
if (props.toolCalls && props.toolCalls.length > 0) {
|
||||
props.toolCalls.forEach((call, i) => {
|
||||
const toolName = call.function?.name || '未知工具'
|
||||
const result = call.result ? getResultSummary(call.result) : null
|
||||
const resultSummary = call.result ? getResultSummary(call.result) : null
|
||||
const resultFields = call.result ? buildResultFields(call.result) : { result: null, resultPreview: null, resultTruncated: false, resultLength: 0 }
|
||||
items.push({
|
||||
type: 'tool_call',
|
||||
toolName,
|
||||
|
|
@ -174,9 +197,9 @@ const processItems = computed(() => {
|
|||
id: call.id,
|
||||
key: `tool_call-${call.id || i}`,
|
||||
loading: !call.result && props.streaming,
|
||||
result: call.result ? formatJson(call.result) : null,
|
||||
resultSummary: result ? result.text : null,
|
||||
isSuccess: result ? result.success : undefined,
|
||||
...resultFields,
|
||||
resultSummary: resultSummary ? resultSummary.text : null,
|
||||
isSuccess: resultSummary ? resultSummary.success : undefined,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
@ -345,6 +368,23 @@ watch(() => props.processSteps?.length, () => {
|
|||
word-break: break-word;
|
||||
}
|
||||
|
||||
.btn-expand-result {
|
||||
display: inline-block;
|
||||
margin-top: 6px;
|
||||
padding: 3px 10px;
|
||||
font-size: 11px;
|
||||
color: var(--tool-color);
|
||||
background: var(--tool-bg);
|
||||
border: 1px solid var(--tool-border);
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.btn-expand-result:hover {
|
||||
background: var(--tool-bg-hover);
|
||||
}
|
||||
|
||||
/* Text content — rendered as markdown */
|
||||
.text-content {
|
||||
padding: 0;
|
||||
|
|
|
|||
|
|
@ -123,19 +123,21 @@
|
|||
|
||||
<script setup>
|
||||
import { reactive, ref, watch, onMounted } from 'vue'
|
||||
import { modelApi, conversationApi } from '../api'
|
||||
import { conversationApi } from '../api'
|
||||
import { useTheme } from '../composables/useTheme'
|
||||
import { icons } from '../utils/icons'
|
||||
import { SETTINGS_AUTO_SAVE_DEBOUNCE_MS } from '../constants'
|
||||
|
||||
const props = defineProps({
|
||||
visible: { type: Boolean, default: false },
|
||||
conversation: { type: Object, default: null },
|
||||
models: { type: Array, default: () => [] },
|
||||
defaultModel: { type: String, default: '' },
|
||||
})
|
||||
|
||||
const emit = defineEmits(['close', 'save'])
|
||||
|
||||
const { isDark, toggleTheme } = useTheme()
|
||||
const models = ref([])
|
||||
|
||||
const tabs = [
|
||||
{ value: 'basic', label: '基本' },
|
||||
|
|
@ -154,15 +156,6 @@ const form = reactive({
|
|||
thinking_enabled: false,
|
||||
})
|
||||
|
||||
async function loadModels() {
|
||||
try {
|
||||
const res = await modelApi.getCached()
|
||||
models.value = res.data || []
|
||||
} catch (e) {
|
||||
console.error('Failed to load models:', e)
|
||||
}
|
||||
}
|
||||
|
||||
function syncFormFromConversation() {
|
||||
if (props.conversation) {
|
||||
form.title = props.conversation.title || ''
|
||||
|
|
@ -170,29 +163,63 @@ function syncFormFromConversation() {
|
|||
form.temperature = props.conversation.temperature ?? 1.0
|
||||
form.max_tokens = props.conversation.max_tokens ?? 65536
|
||||
form.thinking_enabled = props.conversation.thinking_enabled ?? false
|
||||
// model: 优先使用 conversation 的值,其次 models 列表第一个
|
||||
// model: 优先使用 conversation 的值,其次 defaultModel,最后 models 列表第一个
|
||||
if (props.conversation.model) {
|
||||
form.model = props.conversation.model
|
||||
} else if (models.value.length > 0) {
|
||||
form.model = models.value[0].id
|
||||
} else if (props.defaultModel) {
|
||||
form.model = props.defaultModel
|
||||
} else if (props.models.length > 0) {
|
||||
form.model = props.models[0].id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sync form when panel opens or conversation changes
|
||||
watch([() => props.visible, () => props.conversation, models], () => {
|
||||
if (props.visible) {
|
||||
activeTab.value = 'basic'
|
||||
// Track which conversation the form is synced to, to avoid saving stale data
|
||||
let syncedConvId = null
|
||||
let isSyncing = false
|
||||
|
||||
/**
 * Copy the current conversation into the form and remember which
 * conversation the form now reflects. No-op when there is no conversation.
 *
 * `isSyncing` is raised for the duration of the sync so the form watcher
 * can tell programmatic writes apart from user edits.
 */
function doSync() {
  const conv = props.conversation
  if (!conv) return

  isSyncing = true
  syncFormFromConversation()
  syncedConvId = conv.id

  // Lower the flag on a later macrotask, after all watchers have flushed.
  setTimeout(() => {
    isSyncing = false
  }, 0)
}
|
||||
}, { deep: true })
|
||||
|
||||
// Auto-save with debounce when form changes
|
||||
// Re-sync the form whenever the panel opens, the conversation id changes, or
// the model list / default model props change. Closing the panel forgets
// which conversation the form was synced to.
watch(
  [
    () => props.visible,
    () => props.conversation?.id,
    () => props.models,
    () => props.defaultModel,
  ],
  () => {
    if (!props.visible) {
      syncedConvId = null
      return
    }
    if (!props.conversation) return

    activeTab.value = 'basic'
    // Drop any pending auto-save so it cannot fire after the re-sync.
    if (saveTimer) clearTimeout(saveTimer)
    saveTimer = null
    doSync()
  },
)
|
||||
|
||||
// Deep-watch the conversation object so in-place updates (e.g. an
// auto-generated title after a stream) reach the form — but only while the
// panel is open and the form is already synced to that same conversation.
watch(
  () => props.conversation,
  (conv) => {
    const isSyncedConversation = props.visible && conv && syncedConvId === conv.id
    if (isSyncedConversation) {
      doSync()
    }
  },
  { deep: true },
)
|
||||
|
||||
// Sync once on mount as well: the component may be recreated via :key while
// the panel is already visible.
onMounted(function syncOnMount() {
  if (!props.visible || !props.conversation) return
  doSync()
})
|
||||
|
||||
// Auto-save with debounce when user edits form
|
||||
let saveTimer = null
|
||||
watch(form, () => {
|
||||
if (props.visible && props.conversation) {
|
||||
if (props.visible && props.conversation && syncedConvId === props.conversation.id && !isSyncing) {
|
||||
if (saveTimer) clearTimeout(saveTimer)
|
||||
saveTimer = setTimeout(saveChanges, 500)
|
||||
saveTimer = setTimeout(saveChanges, SETTINGS_AUTO_SAVE_DEBOUNCE_MS)
|
||||
}
|
||||
}, { deep: true })
|
||||
|
||||
|
|
@ -205,8 +232,6 @@ async function saveChanges() {
|
|||
console.error('Failed to save settings:', e)
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(loadModels)
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
|
|
|||
|
|
@ -113,6 +113,7 @@
|
|||
import { computed, reactive } from 'vue'
|
||||
import { formatTime } from '../utils/format'
|
||||
import { icons } from '../utils/icons'
|
||||
import { INFINITE_SCROLL_THRESHOLD_PX } from '../constants'
|
||||
|
||||
const props = defineProps({
|
||||
conversations: { type: Array, required: true },
|
||||
|
|
@ -171,7 +172,7 @@ function toggleGroup(id) {
|
|||
|
||||
function onScroll(e) {
|
||||
const el = e.target
|
||||
if (el.scrollTop + el.clientHeight >= el.scrollHeight - 50) {
|
||||
if (el.scrollTop + el.clientHeight >= el.scrollHeight - INFINITE_SCROLL_THRESHOLD_PX) {
|
||||
emit('loadMore')
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -60,107 +60,10 @@
|
|||
</div>
|
||||
|
||||
<!-- 趋势图 -->
|
||||
<div v-if="period !== 'daily' && stats.daily && chartData.length > 0" class="stats-chart">
|
||||
<div class="chart-title">每日趋势</div>
|
||||
<div v-if="chartData.length > 0" class="stats-chart">
|
||||
<div class="chart-title">{{ period === 'daily' ? '今日趋势' : '每日趋势' }}</div>
|
||||
<div class="chart-container">
|
||||
<svg class="line-chart" :viewBox="`0 0 ${chartWidth} ${chartHeight}`">
|
||||
<defs>
|
||||
<linearGradient id="areaGradient" x1="0%" y1="0%" x2="0%" y2="100%">
|
||||
<stop offset="0%" :stop-color="accentColor" stop-opacity="0.25"/>
|
||||
<stop offset="100%" :stop-color="accentColor" stop-opacity="0.02"/>
|
||||
</linearGradient>
|
||||
</defs>
|
||||
<!-- 网格线 -->
|
||||
<line
|
||||
v-for="i in 4"
|
||||
:key="'grid-' + i"
|
||||
:x1="padding"
|
||||
:y1="padding + (chartHeight - 2 * padding) * (i - 1) / 3"
|
||||
:x2="chartWidth - padding"
|
||||
:y2="padding + (chartHeight - 2 * padding) * (i - 1) / 3"
|
||||
stroke="var(--border-light)"
|
||||
stroke-dasharray="3,3"
|
||||
/>
|
||||
<!-- Y轴标签 -->
|
||||
<text
|
||||
v-for="i in 4"
|
||||
:key="'yl-' + i"
|
||||
:x="padding - 4"
|
||||
:y="padding + (chartHeight - 2 * padding) * (i - 1) / 3 + 3"
|
||||
text-anchor="end"
|
||||
class="y-label"
|
||||
>{{ formatNumber(maxValue - (maxValue * (i - 1)) / 3) }}</text>
|
||||
<!-- 填充区域 -->
|
||||
<path :d="areaPath" fill="url(#areaGradient)"/>
|
||||
<!-- 折线 -->
|
||||
<path
|
||||
:d="linePath"
|
||||
fill="none"
|
||||
:stroke="accentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
/>
|
||||
<!-- 数据点 -->
|
||||
<circle
|
||||
v-for="(point, idx) in chartPoints"
|
||||
:key="idx"
|
||||
:cx="point.x"
|
||||
:cy="point.y"
|
||||
r="3"
|
||||
:fill="accentColor"
|
||||
stroke="var(--bg-primary)"
|
||||
stroke-width="2"
|
||||
class="data-point"
|
||||
@mouseenter="hoveredPoint = idx"
|
||||
@mouseleave="hoveredPoint = null"
|
||||
/>
|
||||
<!-- 竖线指示 -->
|
||||
<line
|
||||
v-if="hoveredPoint !== null && chartPoints[hoveredPoint]"
|
||||
:x1="chartPoints[hoveredPoint].x"
|
||||
:y1="padding"
|
||||
:x2="chartPoints[hoveredPoint].x"
|
||||
:y2="chartHeight - padding"
|
||||
stroke="var(--border-medium)"
|
||||
stroke-dasharray="3,3"
|
||||
/>
|
||||
</svg>
|
||||
|
||||
<!-- X轴标签 -->
|
||||
<div class="x-labels">
|
||||
<span
|
||||
v-for="(point, idx) in chartPoints"
|
||||
:key="idx"
|
||||
class="x-label"
|
||||
:class="{ active: hoveredPoint === idx }"
|
||||
>
|
||||
{{ formatDateLabel(point.date) }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- 悬浮提示 -->
|
||||
<Transition name="fade">
|
||||
<div
|
||||
v-if="hoveredPoint !== null && chartPoints[hoveredPoint]"
|
||||
class="tooltip"
|
||||
:style="{
|
||||
left: chartPoints[hoveredPoint].x + 'px',
|
||||
top: (chartPoints[hoveredPoint].y - 52) + 'px'
|
||||
}"
|
||||
>
|
||||
<div class="tooltip-date">{{ formatFullDate(chartPoints[hoveredPoint].date) }}</div>
|
||||
<div class="tooltip-row">
|
||||
<span class="tooltip-dot prompt"></span>
|
||||
输入 {{ formatNumber(chartPoints[hoveredPoint].prompt) }}
|
||||
</div>
|
||||
<div class="tooltip-row">
|
||||
<span class="tooltip-dot completion"></span>
|
||||
输出 {{ formatNumber(chartPoints[hoveredPoint].completion) }}
|
||||
</div>
|
||||
<div class="tooltip-total">{{ formatNumber(chartPoints[hoveredPoint].value) }} tokens</div>
|
||||
</div>
|
||||
</Transition>
|
||||
<canvas ref="chartCanvas"></canvas>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
|
@ -197,11 +100,14 @@
|
|||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { ref, computed, watch, onMounted, onBeforeUnmount, nextTick } from 'vue'
|
||||
import { Chart, registerables } from 'chart.js'
|
||||
import { statsApi } from '../api'
|
||||
import { formatNumber } from '../utils/format'
|
||||
import { icons } from '../utils/icons'
|
||||
|
||||
Chart.register(...registerables)
|
||||
|
||||
defineEmits(['close'])
|
||||
|
||||
const periods = [
|
||||
|
|
@ -213,15 +119,8 @@ const periods = [
|
|||
const period = ref('daily')
|
||||
const stats = ref(null)
|
||||
const loading = ref(false)
|
||||
const hoveredPoint = ref(null)
|
||||
|
||||
const accentColor = computed(() => {
|
||||
return getComputedStyle(document.documentElement).getPropertyValue('--accent-primary').trim() || '#2563eb'
|
||||
})
|
||||
|
||||
const chartWidth = 320
|
||||
const chartHeight = 140
|
||||
const padding = 32
|
||||
const chartCanvas = ref(null)
|
||||
let chartInstance = null
|
||||
|
||||
const sortedDaily = computed(() => {
|
||||
if (!stats.value?.daily) return {}
|
||||
|
|
@ -231,18 +130,46 @@ const sortedDaily = computed(() => {
|
|||
})
|
||||
|
||||
/**
 * Points for the trend chart.
 *
 * - Period "daily" with hourly stats available: one `{ label, value }` entry
 *   per local hour between the first and last hour that has data
 *   (hours without data in that range get value 0).
 * - Otherwise: one `{ label, value, prompt, completion }` entry per day of
 *   `sortedDaily`, labelled "M/D".
 */
const chartData = computed(() => {
  if (period.value === 'daily' && stats.value?.hourly) {
    const hourly = stats.value.hourly
    // Backend returns UTC hours — convert to local timezone for display.
    // The offset is floored to whole hours: in zones such as UTC+5:30 a
    // fractional offset would create bucket keys like "13.5" that the integer
    // lookup below never matches, leaving every point at 0.
    const offset = Math.floor(-new Date().getTimezoneOffset() / 60) // e.g. +8 for UTC+8
    const localHourly = {}
    for (const [utcH, val] of Object.entries(hourly)) {
      const localH = ((parseInt(utcH, 10) + offset) % 24 + 24) % 24
      localHourly[localH] = val
    }

    // Find the first and last local hour that has data.
    let minH = 24
    let maxH = -1
    for (const h of Object.keys(localHourly)) {
      const hour = parseInt(h, 10)
      if (hour < minH) minH = hour
      if (hour > maxH) maxH = hour
    }
    if (minH > maxH) return []

    const start = Math.max(0, minH)
    const end = Math.min(23, maxH)
    return Array.from({ length: end - start + 1 }, (_, i) => {
      const h = start + i
      return {
        label: `${h}:00`,
        value: localHourly[String(h)]?.total || 0,
      }
    })
  }

  const data = sortedDaily.value
  return Object.entries(data).map(([date, val]) => {
    // date is "YYYY-MM-DD" from backend — parse directly to avoid
    // new Date() timezone shift (parsed as UTC midnight then
    // getMonth/getDate applies local offset, potentially off by one day).
    const [, month, day] = date.split('-')
    return {
      label: `${parseInt(month, 10)}/${parseInt(day, 10)}`,
      value: val.total,
      prompt: val.prompt || 0,
      completion: val.completion || 0,
    }
  })
})
|
||||
|
||||
const maxValue = computed(() => {
|
||||
if (chartData.value.length === 0) return 100
|
||||
return Math.max(100, ...chartData.value.map(d => d.value))
|
||||
})
|
||||
|
||||
const maxModelTokens = computed(() => {
|
||||
|
|
@ -250,53 +177,135 @@ const maxModelTokens = computed(() => {
|
|||
return Math.max(1, ...Object.values(stats.value.by_model).map(d => d.total))
|
||||
})
|
||||
|
||||
const chartPoints = computed(() => {
|
||||
const data = chartData.value
|
||||
if (data.length === 0) return []
|
||||
|
||||
const xRange = chartWidth - 2 * padding
|
||||
const yRange = chartHeight - 2 * padding
|
||||
|
||||
return data.map((d, i) => ({
|
||||
x: data.length === 1
|
||||
? chartWidth / 2
|
||||
: padding + (i / Math.max(1, data.length - 1)) * xRange,
|
||||
y: chartHeight - padding - (d.value / maxValue.value) * yRange,
|
||||
date: d.date,
|
||||
value: d.value,
|
||||
prompt: d.prompt,
|
||||
completion: d.completion,
|
||||
}))
|
||||
})
|
||||
|
||||
const linePath = computed(() => {
|
||||
const points = chartPoints.value
|
||||
if (points.length === 0) return ''
|
||||
return points.map((p, i) => `${i === 0 ? 'M' : 'L'} ${p.x} ${p.y}`).join(' ')
|
||||
})
|
||||
|
||||
const areaPath = computed(() => {
|
||||
const points = chartPoints.value
|
||||
if (points.length === 0) return ''
|
||||
|
||||
const baseY = chartHeight - padding
|
||||
|
||||
let path = `M ${points[0].x} ${baseY} `
|
||||
path += points.map(p => `L ${p.x} ${p.y}`).join(' ')
|
||||
path += ` L ${points[points.length - 1].x} ${baseY} Z`
|
||||
|
||||
return path
|
||||
})
|
||||
|
||||
function formatDateLabel(dateStr) {
|
||||
const d = new Date(dateStr)
|
||||
return `${d.getMonth() + 1}/${d.getDate()}`
|
||||
/**
 * Read the `--accent-primary` CSS custom property from the document root.
 *
 * @returns {string} The trimmed property value, or '#2563eb' when it is empty.
 */
function getAccentColor() {
  const rootStyle = getComputedStyle(document.documentElement)
  const accent = rootStyle.getPropertyValue('--accent-primary').trim()
  return accent || '#2563eb'
}
|
||||
|
||||
function formatFullDate(dateStr) {
|
||||
const d = new Date(dateStr)
|
||||
return `${d.getMonth() + 1}月${d.getDate()}日`
|
||||
/**
 * Read the `--text-tertiary` CSS custom property from the document root,
 * optionally applying an alpha value.
 *
 * @param {number} [alpha=1] - Opacity to apply. With 1 the colour is returned
 *   exactly as declared.
 * @returns {string} The declared colour (fallback '#888'); for hex colours and
 *   `alpha !== 1`, an `rgba(r,g,b,alpha)` string. Non-hex colours are returned
 *   unchanged because they cannot be converted here.
 */
function getTextColor(alpha = 1) {
  const c = getComputedStyle(document.documentElement).getPropertyValue('--text-tertiary').trim() || '#888'
  if (alpha === 1) return c

  // Accept #rgb, #rgba, #rrggbb and #rrggbbaa. Slicing fixed 2-character
  // channels out of a short form (including the '#888' fallback above)
  // would yield NaN and an invalid colour.
  const match = /^#([0-9a-f]{3,4}|[0-9a-f]{6}|[0-9a-f]{8})$/i.exec(c)
  if (!match) return c

  let digits = match[1]
  if (digits.length <= 4) {
    // Expand shorthand: each digit is doubled (#abc -> #aabbcc).
    digits = digits.replace(/./g, '$&$&')
  }

  // Any alpha embedded in the hex value is ignored in favour of `alpha`.
  const r = parseInt(digits.slice(0, 2), 16)
  const g = parseInt(digits.slice(2, 4), 16)
  const b = parseInt(digits.slice(4, 6), 16)
  return `rgba(${r},${g},${b},${alpha})`
}
|
||||
|
||||
/**
 * Destroy the current Chart.js instance, if any, and clear the reference.
 */
function destroyChart() {
  if (!chartInstance) return
  chartInstance.destroy()
  chartInstance = null
}
|
||||
|
||||
/**
 * (Re)create the line chart on the canvas from the current `chartData`.
 *
 * Does nothing when the canvas is not mounted or there is no data. Any
 * existing chart instance is destroyed first.
 */
function buildChart() {
  const canvas = chartCanvas.value
  if (!canvas || chartData.value.length === 0) return

  destroyChart()

  const points = chartData.value
  const accent = getAccentColor()
  const ctx = canvas.getContext('2d')

  // Vertical gradient under the line: accent colour with a hex alpha suffix.
  const areaFill = ctx.createLinearGradient(0, 0, 0, 200)
  areaFill.addColorStop(0, accent + '40')
  areaFill.addColorStop(1, accent + '05')

  // Show every x label for short series, otherwise cap at 6.
  const xTickLimit = points.length <= 8 ? points.length : 6

  const dataset = {
    data: points.map((p) => p.value),
    borderColor: accent,
    backgroundColor: areaFill,
    borderWidth: 2,
    pointRadius: 0,
    pointHoverRadius: 4,
    pointHoverBackgroundColor: accent,
    pointHoverBorderColor: '#fff',
    pointHoverBorderWidth: 2,
    fill: true,
    tension: 0,
  }

  const xAxis = {
    grid: { display: false },
    border: { display: false },
    ticks: {
      color: getTextColor(),
      font: { size: 10 },
      maxTicksLimit: xTickLimit,
      maxRotation: 0,
    },
  }

  const yAxis = {
    beginAtZero: true,
    grid: {
      color: getTextColor(0.15),
      drawBorder: false,
    },
    border: { display: false },
    ticks: {
      color: getTextColor(),
      font: { size: 9 },
      maxTicksLimit: 4,
      callback: (v) => formatNumber(v),
    },
  }

  const tooltip = {
    backgroundColor: 'rgba(0,0,0,0.8)',
    titleColor: '#fff',
    bodyColor: '#ccc',
    titleFont: { size: 11, weight: '500' },
    bodyFont: { size: 11 },
    padding: 8,
    cornerRadius: 6,
    displayColors: false,
    callbacks: {
      // Read chartData/period at hover time rather than capturing them here.
      title: (items) => {
        const hovered = chartData.value[items[0].dataIndex]
        if (period.value !== 'daily') return hovered.label
        // Hourly view: show the hour as a range, e.g. "9:00 - 10:00".
        return `${hovered.label} - ${parseInt(hovered.label) + 1}:00`
      },
      label: (item) => `${formatNumber(item.raw)} tokens`,
    },
  }

  chartInstance = new Chart(ctx, {
    type: 'line',
    data: {
      labels: points.map((p) => p.label),
      datasets: [dataset],
    },
    options: {
      responsive: true,
      maintainAspectRatio: false,
      animation: { duration: 300 },
      layout: {
        padding: { top: 4, right: 4, bottom: 0, left: 0 },
      },
      scales: { x: xAxis, y: yAxis },
      plugins: {
        legend: { display: false },
        tooltip,
      },
      interaction: {
        mode: 'index',
        intersect: false,
      },
    },
  })
}
|
||||
|
||||
// Rebuild the chart whenever the data changes. The build is deferred to the
// next tick so the canvas has been rendered by then.
watch(chartData, function rebuildChartAfterRender() {
  nextTick(buildChart)
})
|
||||
|
||||
async function loadStats() {
|
||||
loading.value = true
|
||||
|
|
@ -312,11 +321,11 @@ async function loadStats() {
|
|||
|
||||
function changePeriod(p) {
|
||||
period.value = p
|
||||
hoveredPoint.value = null
|
||||
loadStats()
|
||||
}
|
||||
|
||||
onMounted(loadStats)
|
||||
onBeforeUnmount(destroyChart)
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
|
@ -324,8 +333,6 @@ onMounted(loadStats)
|
|||
padding: 0;
|
||||
}
|
||||
|
||||
/* panel-header, panel-title, header-actions now in global.css */
|
||||
|
||||
.stats-loading {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
|
|
@ -430,99 +437,9 @@ onMounted(loadStats)
|
|||
background: var(--bg-input);
|
||||
border: 1px solid var(--border-light);
|
||||
border-radius: 10px;
|
||||
padding: 12px 8px 8px 8px;
|
||||
padding: 10px;
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.line-chart {
|
||||
width: 100%;
|
||||
height: 140px;
|
||||
}
|
||||
|
||||
.y-label {
|
||||
fill: var(--text-tertiary);
|
||||
font-size: 9px;
|
||||
}
|
||||
|
||||
.data-point {
|
||||
cursor: pointer;
|
||||
transition: r 0.15s;
|
||||
}
|
||||
|
||||
.data-point:hover {
|
||||
r: 5;
|
||||
}
|
||||
|
||||
.x-labels {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
margin-top: 6px;
|
||||
padding: 0 28px 0 32px;
|
||||
}
|
||||
|
||||
.x-label {
|
||||
font-size: 10px;
|
||||
color: var(--text-tertiary);
|
||||
transition: color 0.15s;
|
||||
}
|
||||
|
||||
.x-label.active {
|
||||
color: var(--text-primary);
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
/* 提示框 */
|
||||
.tooltip {
|
||||
position: absolute;
|
||||
background: var(--bg-primary);
|
||||
border: 1px solid var(--border-medium);
|
||||
padding: 8px 10px;
|
||||
border-radius: 8px;
|
||||
font-size: 11px;
|
||||
pointer-events: none;
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.12);
|
||||
transform: translateX(-50%);
|
||||
z-index: 10;
|
||||
min-width: 120px;
|
||||
}
|
||||
|
||||
.tooltip-date {
|
||||
color: var(--text-tertiary);
|
||||
font-size: 10px;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
.tooltip-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
font-size: 11px;
|
||||
color: var(--text-secondary);
|
||||
}
|
||||
|
||||
.tooltip-dot {
|
||||
width: 6px;
|
||||
height: 6px;
|
||||
border-radius: 50%;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.tooltip-dot.prompt {
|
||||
background: #3b82f6;
|
||||
}
|
||||
|
||||
.tooltip-dot.completion {
|
||||
background: #a855f7;
|
||||
}
|
||||
|
||||
.tooltip-total {
|
||||
margin-top: 4px;
|
||||
padding-top: 4px;
|
||||
border-top: 1px solid var(--border-light);
|
||||
font-weight: 600;
|
||||
color: var(--text-primary);
|
||||
font-size: 12px;
|
||||
height: 180px;
|
||||
}
|
||||
|
||||
/* 模型分布 */
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import { watch, onMounted, nextTick, onUnmounted } from 'vue'
|
||||
import { enhanceCodeBlocks } from '../utils/markdown'
|
||||
import { CODE_ENHANCE_DEBOUNCE_MS } from '../constants'
|
||||
|
||||
/**
|
||||
* Composable for enhancing code blocks in a container element.
|
||||
|
|
@ -18,7 +19,7 @@ export function useCodeEnhancement(templateRef, dep, watchOpts) {
|
|||
|
||||
function debouncedEnhance() {
|
||||
if (debounceTimer) clearTimeout(debounceTimer)
|
||||
debounceTimer = setTimeout(() => nextTick(enhance), 150)
|
||||
debounceTimer = setTimeout(() => nextTick(enhance), CODE_ENHANCE_DEBOUNCE_MS)
|
||||
}
|
||||
|
||||
onMounted(enhance)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import { reactive } from 'vue'
|
||||
import { TOAST_DEFAULT_DURATION } from '../constants'
|
||||
|
||||
const state = reactive({
|
||||
toasts: [],
|
||||
|
|
@ -6,7 +7,7 @@ const state = reactive({
|
|||
})
|
||||
|
||||
export function useToast() {
|
||||
function add(type, message, duration = 1500) {
|
||||
function add(type, message, duration = TOAST_DEFAULT_DURATION) {
|
||||
const id = ++state._id
|
||||
state.toasts.push({ id, type, message })
|
||||
setTimeout(() => {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,37 @@
|
|||
/**
 * Shared frontend constants.
 *
 * Collects the literal values (limits, timings, storage keys) that were
 * previously inlined at their call sites.
 */

// ---------------------------------------------------------------------------
// Tool result display
// ---------------------------------------------------------------------------

/** Max characters shown in a tool result preview before truncation. */
export const RESULT_PREVIEW_LIMIT = 2048

// ---------------------------------------------------------------------------
// API
// ---------------------------------------------------------------------------

/** Path prefix prepended to every backend request URL. */
export const API_BASE_URL = '/api'

/** `Content-Type` header value for JSON request bodies. */
export const CONTENT_TYPE_JSON = 'application/json'

// ---------------------------------------------------------------------------
// Pagination (default `limit` per list request)
// ---------------------------------------------------------------------------

export const DEFAULT_CONVERSATION_PAGE_SIZE = 20
export const DEFAULT_MESSAGE_PAGE_SIZE = 50
export const DEFAULT_PROJECT_PAGE_SIZE = 20

// ---------------------------------------------------------------------------
// Timers (all values in milliseconds)
// ---------------------------------------------------------------------------

/** Default time a toast stays on screen. */
export const TOAST_DEFAULT_DURATION = 1500

/** Debounce delay before code blocks are re-enhanced. */
export const CODE_ENHANCE_DEBOUNCE_MS = 150

/** Debounce delay before edited settings are auto-saved. */
export const SETTINGS_AUTO_SAVE_DEBOUNCE_MS = 500

/** Time the copy button shows the check icon before reverting. */
export const COPY_BUTTON_RESET_MS = 1500

// ---------------------------------------------------------------------------
// Truncation
// ---------------------------------------------------------------------------

/** Default max length, in characters, for truncated text previews. */
export const DEFAULT_TRUNCATE_LENGTH = 60

// ---------------------------------------------------------------------------
// UI limits (pixels)
// ---------------------------------------------------------------------------

/** Tallest the auto-resizing message textarea may grow. */
export const TEXTAREA_MAX_HEIGHT_PX = 200

/** Distance from the bottom of the list at which more items are requested. */
export const INFINITE_SCROLL_THRESHOLD_PX = 50

// ---------------------------------------------------------------------------
// localStorage keys
// ---------------------------------------------------------------------------

export const LS_KEY_THEME = 'theme'
export const LS_KEY_TOOLS_ENABLED = 'tools_enabled'
export const LS_KEY_MODELS_CACHE = 'models_cache'

// ---------------------------------------------------------------------------
// File upload
// ---------------------------------------------------------------------------

/** Comma-separated extension list for the file input's `accept` attribute. */
export const ALLOWED_UPLOAD_EXTENSIONS = [
  '.txt', '.md', '.json', '.xml', '.html', '.css',
  '.js', '.ts', '.jsx', '.tsx', '.py', '.java',
  '.c', '.cpp', '.h', '.hpp',
  '.yaml', '.yml', '.toml', '.ini', '.csv', '.sql',
  '.sh', '.bat', '.log', '.vue', '.svelte',
  '.go', '.rs', '.rb', '.php', '.swift', '.kt', '.scala', '.lua', '.r', '.dart',
].join(',')
|
||||
|
|
@ -3,9 +3,10 @@ import App from './App.vue'
|
|||
import './styles/global.css'
|
||||
import './styles/highlight.css'
|
||||
import 'katex/dist/katex.min.css'
|
||||
import { LS_KEY_THEME } from './constants'
|
||||
|
||||
// Initialize theme before app mounts to avoid flash when lazy-loading useTheme
|
||||
const savedTheme = localStorage.getItem('theme')
|
||||
const savedTheme = localStorage.getItem(LS_KEY_THEME)
|
||||
if (savedTheme === 'dark' || savedTheme === 'light') {
|
||||
document.documentElement.setAttribute('data-theme', savedTheme)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import { DEFAULT_TRUNCATE_LENGTH } from '../constants'
|
||||
|
||||
/**
|
||||
* Format ISO date string to a short time string.
|
||||
* - Today: "14:30"
|
||||
|
|
@ -36,7 +38,7 @@ export function formatJson(value) {
|
|||
/**
|
||||
* Truncate text to max characters with ellipsis.
|
||||
*/
|
||||
export function truncate(text, max = 60) {
|
||||
export function truncate(text, max = DEFAULT_TRUNCATE_LENGTH) {
|
||||
if (!text) return ''
|
||||
const str = text.replace(/\s+/g, ' ').trim()
|
||||
return str.length > max ? str.slice(0, max) + '\u2026' : str
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import { marked } from 'marked'
|
|||
import { markedHighlight } from 'marked-highlight'
|
||||
import katex from 'katex'
|
||||
import { highlightCode } from './highlight'
|
||||
import { COPY_BUTTON_RESET_MS } from '../constants'
|
||||
|
||||
function renderMath(text, displayMode) {
|
||||
try {
|
||||
|
|
@ -108,7 +109,7 @@ export function enhanceCodeBlocks(container) {
|
|||
|
||||
copyBtn.addEventListener('click', () => {
|
||||
const raw = code?.textContent || ''
|
||||
const copy = () => { copyBtn.innerHTML = CHECK_SVG; setTimeout(() => { copyBtn.innerHTML = COPY_SVG }, 1500) }
|
||||
const copy = () => { copyBtn.innerHTML = CHECK_SVG; setTimeout(() => { copyBtn.innerHTML = COPY_SVG }, COPY_BUTTON_RESET_MS) }
|
||||
if (navigator.clipboard) {
|
||||
navigator.clipboard.writeText(raw).then(copy)
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -8,17 +8,25 @@ import { dirname, resolve } from 'path'
|
|||
const __filename = fileURLToPath(import.meta.url)
|
||||
const __dirname = dirname(__filename)
|
||||
|
||||
const config = yaml.load(
|
||||
fs.readFileSync(resolve(__dirname, '..', 'config.yml'), 'utf-8')
|
||||
)
|
||||
const configPath = resolve(__dirname, '..', 'config.yml');

// Load the shared config.yml. Any failure falls back to an empty config so
// that the environment variables / built-in defaults below still apply.
let config = {};
try {
  // yaml.load() yields undefined for an empty document; fall back to {} so
  // the property reads below cannot throw.
  config = yaml.load(fs.readFileSync(configPath, 'utf-8')) || {};
} catch (e) {
  if (e.code === 'ENOENT') {
    console.warn(`Config file not found at ${configPath}, using defaults.`);
  } else {
    // Unreadable or malformed file: report the real cause instead of
    // claiming the file is missing.
    console.warn(`Failed to load config file at ${configPath}: ${e.message}. Using defaults.`);
  }
  config = {};
}

// Precedence: environment variable, then config.yml, then built-in default.
const frontend_port = process.env.VITE_FRONTEND_PORT || config.frontend_port || 4000;
const backend_port = process.env.VITE_BACKEND_PORT || config.backend_port || 3000;
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [vue()],
|
||||
server: {
|
||||
port: config.frontend_port,
|
||||
port: frontend_port,
|
||||
proxy: {
|
||||
'/api': {
|
||||
target: `http://localhost:${config.backend_port}`,
|
||||
target: `http://localhost:${backend_port}`,
|
||||
changeOrigin: true,
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -24,3 +24,13 @@ build-backend = "setuptools.build_meta"
|
|||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["backend*"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"pytest>=7.0",
|
||||
"pytest-flask>=1.2",
|
||||
"pytest-cov>=4.0",
|
||||
"pytest-mock>=3.0",
|
||||
"requests-mock>=1.10",
|
||||
"httpx>=0.25",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,79 @@
|
|||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
from pathlib import Path
|
||||
from backend import create_app, db as _db
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
def app():
    """Create a Flask app configured for testing.

    A temporary SQLite file is created for the test session and an
    application context stays pushed while the fixture is active. The
    temporary file is removed afterwards, including when app creation
    itself fails, so an aborted setup does not leak it.
    """
    db_fd, db_path = tempfile.mkstemp(suffix='.db')
    # Only the path is used below; release the descriptor right away so it
    # cannot leak if setup fails.
    os.close(db_fd)

    class TestConfig:
        # Point SQLAlchemy at the temporary SQLite file.
        SQLALCHEMY_DATABASE_URI = f'sqlite:///{db_path}'
        SQLALCHEMY_TRACK_MODIFICATIONS = False
        TESTING = True
        SECRET_KEY = 'test-secret-key'
        AUTH_CONFIG = {
            'mode': 'single',
            'jwt_secret': 'test-jwt-secret',
            'jwt_expiry': 3600,
        }

    try:
        # NOTE(review): the overrides are applied after create_app() returns;
        # if the factory binds the database while it runs, they may come too
        # late -- confirm against backend.create_app.
        app = create_app()
        app.config.from_object(TestConfig)

        # Keep an application context active for the whole session.
        with app.app_context():
            yield app
    finally:
        os.unlink(db_path)
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
def db(app):
    """Create all tables for the session and drop them when it ends."""
    database = _db
    database.create_all()
    yield database
    database.drop_all()
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
def session(db):
    """Provide a database session whose changes are discarded after the test.

    The session is bound to a dedicated connection with an open
    transaction, and is installed as ``db.session`` for the duration of
    the test. On teardown the transaction is rolled back and the previous
    ``db.session`` is put back, so later tests do not inherit a session
    tied to a closed connection.
    """
    from sqlalchemy.orm import scoped_session, sessionmaker

    connection = db.engine.connect()
    transaction = connection.begin()

    # Scoped session bound to the connection that owns the transaction.
    session = scoped_session(sessionmaker(bind=connection))

    original_session = db.session
    db.session = session

    yield session

    # Release the session first so it lets go of the connection before the
    # outer transaction is rolled back and the connection is closed.
    session.remove()
    transaction.rollback()
    connection.close()
    db.session = original_session
|
||||
|
||||
|
||||
@pytest.fixture
def client(app):
    """Return a Flask test client for the app."""
    test_client = app.test_client()
    return test_client
|
||||
|
||||
|
||||
@pytest.fixture
def runner(app):
    """Return a runner for invoking the app's CLI commands."""
    cli_runner = app.test_cli_runner()
    return cli_runner
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
import pytest
|
||||
import json
|
||||
from backend.models import User
|
||||
|
||||
|
||||
def test_auth_mode(client, session):
    """GET /api/auth/mode reports the active authentication mode."""
    response = client.get('/api/auth/mode')
    assert response.status_code == 200

    payload = json.loads(response.data)
    for key in ('code', 'data'):
        assert key in payload

    # The test configuration runs in single-user mode.
    assert payload['data']['mode'] == 'single'
|
||||
|
||||
|
||||
def test_login_single_mode(client, session):
    """Logging in as the default user in single-user mode returns a token."""
    # The auth middleware is expected to create the default user; add it
    # here if it has not done so yet.
    existing = User.query.filter_by(username='default').first()
    if not existing:
        existing = User(username='default')
        session.add(existing)
        session.commit()

    credentials = {
        'username': 'default',
        'password': '',  # no password in single mode
    }
    response = client.post('/api/auth/login', json=credentials)
    assert response.status_code == 200

    payload = json.loads(response.data)
    assert payload['code'] == 0

    body = payload['data']
    for key in ('token', 'user'):
        assert key in body
    assert body['user']['username'] == 'default'
|
||||
|
||||
|
||||
def test_profile(client, session):
    """GET /api/auth/profile returns the default user's profile."""
    # No token is sent: single-user mode does not require one.
    response = client.get('/api/auth/profile')
    assert response.status_code == 200

    payload = json.loads(response.data)
    assert payload['code'] == 0
    assert payload['data']['username'] == 'default'
|
||||
|
||||
|
||||
def test_profile_update(client, session):
    """PATCH /api/auth/profile persists the new email and avatar."""
    changes = {
        'email': 'default@example.com',
        'avatar': 'https://example.com/avatar.png',
    }
    response = client.patch('/api/auth/profile', json=changes)
    assert response.status_code == 200

    payload = json.loads(response.data)
    assert payload['code'] == 0

    # Read the user back and check the stored values.
    stored = User.query.filter_by(username='default').first()
    assert stored.email == 'default@example.com'
    assert stored.avatar == 'https://example.com/avatar.png'
|
||||
|
||||
|
||||
def test_register_not_allowed_in_single_mode(client, session):
    """Registering a new account is rejected in single-user mode."""
    new_account = {
        'username': 'newuser',
        'password': 'password',
    }
    response = client.post('/api/auth/register', json=new_account)

    # The HTTP status is deliberately not checked; only the application
    # level result code must signal a failure.
    payload = json.loads(response.data)
    assert payload['code'] != 0
|
||||
|
||||
|
||||
# Multi-user mode tests (requires switching config)
|
||||
# We'll skip for now because it's more complex.
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main(['-v', __file__])
|
||||
|
|
@ -0,0 +1,253 @@
|
|||
import pytest
|
||||
import json
|
||||
from backend.models import User, Conversation, Message
|
||||
|
||||
|
||||
def test_list_conversations(client, session):
    """GET /api/conversations includes a conversation owned by the default user."""
    owner = User.query.filter_by(username='default').first()
    if not owner:
        owner = User(username='default')
        session.add(owner)
        session.commit()

    # Seed one conversation to look for in the listing.
    seeded = Conversation(
        id='conv-1',
        user_id=owner.id,
        title='Test Conversation',
        model='deepseek-chat',
    )
    session.add(seeded)
    session.commit()

    response = client.get('/api/conversations')
    assert response.status_code == 200

    payload = json.loads(response.data)
    assert payload['code'] == 0

    listed = payload['data']['items']
    assert len(listed) >= 1

    # The seeded conversation must be among the returned items.
    found = any(entry['id'] == 'conv-1' for entry in listed)
    assert found is True
|
||||
|
||||
|
||||
def test_create_conversation(client, session):
    """POST /api/conversations echoes the submitted fields and stores the row."""
    owner = User.query.filter_by(username='default').first()
    if not owner:
        owner = User(username='default')
        session.add(owner)
        session.commit()

    request_body = {
        'title': 'New Conversation',
        'model': 'glm-5',
        'system_prompt': 'You are helpful.',
        'temperature': 0.7,
        'max_tokens': 4096,
        'thinking_enabled': True,
    }
    response = client.post('/api/conversations', json=request_body)
    assert response.status_code == 200

    payload = json.loads(response.data)
    assert payload['code'] == 0

    created = payload['data']
    assert created['title'] == 'New Conversation'
    assert created['model'] == 'glm-5'
    assert created['system_prompt'] == 'You are helpful.'
    assert created['temperature'] == 0.7
    assert created['max_tokens'] == 4096
    assert created['thinking_enabled'] is True

    # The conversation must exist in the database and belong to the user.
    stored = Conversation.query.filter_by(id=created['id']).first()
    assert stored is not None
    assert stored.user_id == owner.id
|
||||
|
||||
|
||||
def test_get_conversation(client, session):
    """GET /api/conversations/:id returns the requested conversation."""
    owner = User.query.filter_by(username='default').first()
    if not owner:
        owner = User(username='default')
        session.add(owner)
        session.commit()

    seeded = Conversation(
        id='conv-2',
        user_id=owner.id,
        title='Test Get',
        model='deepseek-chat',
    )
    session.add(seeded)
    session.commit()

    response = client.get(f'/api/conversations/{seeded.id}')
    assert response.status_code == 200

    payload = json.loads(response.data)
    assert payload['code'] == 0

    returned = payload['data']
    assert returned['id'] == 'conv-2'
    assert returned['title'] == 'Test Get'
|
||||
|
||||
|
||||
def test_update_conversation(client, session):
    """PATCH /api/conversations/:id stores the new title and temperature."""
    owner = User.query.filter_by(username='default').first()
    if not owner:
        owner = User(username='default')
        session.add(owner)
        session.commit()

    target = Conversation(
        id='conv-3',
        user_id=owner.id,
        title='Original',
        model='deepseek-chat',
    )
    session.add(target)
    session.commit()

    changes = {
        'title': 'Updated Title',
        'temperature': 0.9,
    }
    response = client.patch(f'/api/conversations/{target.id}', json=changes)
    assert response.status_code == 200

    payload = json.loads(response.data)
    assert payload['code'] == 0

    # Reload the row so the assertions see the persisted values.
    session.refresh(target)
    assert target.title == 'Updated Title'
    assert target.temperature == 0.9
|
||||
|
||||
|
||||
def test_delete_conversation(client, session):
    """Test DELETE /api/conversations/:id."""
    user = User.query.filter_by(username='default').first()
    if not user:
        user = User(username='default')
        session.add(user)
        session.commit()

    conv = Conversation(
        id='conv-4',
        user_id=user.id,
        title='To Delete',
        model='deepseek-chat',
    )
    session.add(conv)
    session.commit()

    # Capture the id up front: once the row is deleted, reading attributes
    # from the ORM instance is no longer reliable.
    conv_id = conv.id

    resp = client.delete(f'/api/conversations/{conv_id}')
    assert resp.status_code == 200
    data = json.loads(resp.data)
    assert data['code'] == 0

    # The conversation should no longer be retrievable.
    deleted = Conversation.query.get(conv_id)
    assert deleted is None
|
||||
|
||||
|
||||
def test_list_messages(client, session):
    """GET /api/conversations/:id/messages returns the conversation's messages."""
    owner = User.query.filter_by(username='default').first()
    if not owner:
        owner = User(username='default')
        session.add(owner)
        session.commit()

    thread = Conversation(
        id='conv-5',
        user_id=owner.id,
        title='Messages Test',
        model='deepseek-chat',
    )
    session.add(thread)
    session.commit()

    # Seed one user message and one assistant reply.
    question = Message(id='msg-1', conversation_id=thread.id, role='user', content='Hello')
    answer = Message(id='msg-2', conversation_id=thread.id, role='assistant', content='Hi')
    session.add_all([question, answer])
    session.commit()

    response = client.get(f'/api/conversations/{thread.id}/messages')
    assert response.status_code == 200

    payload = json.loads(response.data)
    assert payload['code'] == 0

    listed = payload['data']['items']
    assert len(listed) == 2

    seen_roles = {entry['role'] for entry in listed}
    assert 'user' in seen_roles
    assert 'assistant' in seen_roles
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="SSE endpoint requires streaming")
def test_send_message(client, session):
    """POST /api/conversations/:id/messages accepts a user message.

    Currently skipped (see the skip reason). When enabled, it only checks
    that the endpoint answers with a success status.
    """
    owner = User.query.filter_by(username='default').first()
    if not owner:
        owner = User(username='default')
        session.add(owner)
        session.commit()

    thread = Conversation(
        id='conv-6',
        user_id=owner.id,
        title='Send Test',
        model='deepseek-chat',
        thinking_enabled=False,
    )
    session.add(thread)
    session.commit()

    # The LLM call is not mocked and the response body is not inspected;
    # this is a smoke check of the status code only.
    outgoing = {
        'content': 'Hello',
        'role': 'user',
    }
    response = client.post(f'/api/conversations/{thread.id}/messages', json=outgoing)

    # Only these success statuses are accepted.
    assert response.status_code in (200, 201, 204)
|
||||
|
||||
|
||||
def test_delete_message(client, session):
    """Test DELETE /api/conversations/:id/messages/:mid."""
    user = User.query.filter_by(username='default').first()
    if not user:
        user = User(username='default')
        session.add(user)
        session.commit()

    conv = Conversation(
        id='conv-7',
        user_id=user.id,
        title='Delete Msg',
        model='deepseek-chat',
    )
    session.add(conv)
    session.commit()

    msg = Message(id='msg-del', conversation_id=conv.id, role='user', content='Delete me')
    session.add(msg)
    session.commit()

    # Capture the ids up front: once the message row is deleted, reading
    # attributes from its ORM instance is no longer reliable.
    conv_id = conv.id
    msg_id = msg.id

    resp = client.delete(f'/api/conversations/{conv_id}/messages/{msg_id}')
    assert resp.status_code == 200
    data = json.loads(resp.data)
    assert data['code'] == 0

    # The message should no longer be retrievable.
    deleted = Message.query.get(msg_id)
    assert deleted is None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main(['-v', __file__])
|
||||
|
|
@ -0,0 +1,209 @@
|
|||
import pytest
|
||||
from backend.models import User, Conversation, Message, TokenUsage, Project
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def test_user_create(session):
    """A newly stored user gets an id and the expected default values."""
    account = User(username='testuser', email='test@example.com')
    session.add(account)
    session.commit()

    # Values assigned by the database / model.
    assert account.id is not None

    # Values passed to the constructor.
    assert account.username == 'testuser'
    assert account.email == 'test@example.com'

    # Defaults applied for fields that were not provided.
    assert account.role == 'user'
    assert account.is_active is True
    assert account.created_at is not None
    assert account.last_login_at is None
|
||||
|
||||
|
||||
def test_user_password_hashing(session):
    """Assigning a password stores a hash that check_password can verify."""
    plaintext = 'securepassword'

    account = User(username='testuser')
    account.password = plaintext
    session.add(account)
    session.commit()

    # A hash is stored, and it is not the plaintext itself.
    stored_hash = account.password_hash
    assert stored_hash is not None
    assert stored_hash != 'securepassword'

    # Only the right password verifies.
    assert account.check_password('securepassword') is True
    assert account.check_password('wrongpassword') is False

    # Assigning None removes the stored hash.
    account.password = None
    assert account.password_hash is None
|
||||
|
||||
|
||||
def test_user_to_dict(session):
    """User.to_dict exposes public fields and leaves out the password hash."""
    account = User(username='testuser', email='test@example.com', role='admin')
    session.add(account)
    session.commit()

    serialized = account.to_dict()

    assert serialized['username'] == 'testuser'
    assert serialized['email'] == 'test@example.com'
    assert serialized['role'] == 'admin'
    assert 'password_hash' not in serialized
    assert 'created_at' in serialized
|
||||
|
||||
|
||||
def test_conversation_create(session):
    """A stored conversation keeps its fields and links back to its user."""
    owner = User(username='user1')
    session.add(owner)
    session.commit()

    thread = Conversation(
        id='conv-123',
        user_id=owner.id,
        title='Test Conversation',
        model='deepseek-chat',
        system_prompt='You are a helpful assistant.',
        temperature=0.8,
        max_tokens=2048,
        thinking_enabled=True,
    )
    session.add(thread)
    session.commit()

    # Constructor values survive the round trip.
    assert thread.id == 'conv-123'
    assert thread.user_id == owner.id
    assert thread.title == 'Test Conversation'
    assert thread.model == 'deepseek-chat'
    assert thread.system_prompt == 'You are a helpful assistant.'
    assert thread.temperature == 0.8
    assert thread.max_tokens == 2048
    assert thread.thinking_enabled is True

    # Timestamps are filled in and the relationship resolves to the owner.
    assert thread.created_at is not None
    assert thread.updated_at is not None
    assert thread.user == owner
|
||||
|
||||
|
||||
def test_conversation_relationships(session):
    """Messages are reachable from their conversation and vice versa."""
    owner = User(username='user1')
    session.add(owner)
    session.commit()

    thread = Conversation(id='conv-123', user_id=owner.id, title='Test')
    session.add(thread)
    session.commit()

    # Attach two messages to the conversation.
    first = Message(id='msg-1', conversation_id=thread.id, role='user', content='Hello')
    second = Message(id='msg-2', conversation_id=thread.id, role='assistant', content='Hi')
    session.add_all([first, second])
    session.commit()

    # Both directions of the relationship must resolve.
    assert thread.messages.count() == 2
    assert list(thread.messages) == [first, second]
    assert first.conversation == thread
|
||||
|
||||
|
||||
def test_message_create(session):
    """A stored message keeps its fields and links to its conversation."""
    owner = User(username='user1')
    session.add(owner)
    session.commit()

    thread = Conversation(id='conv-123', user_id=owner.id, title='Test')
    session.add(thread)
    session.commit()

    raw_content = '{"text": "Hello world"}'
    entry = Message(
        id='msg-1',
        conversation_id=thread.id,
        role='user',
        content=raw_content,
    )
    session.add(entry)
    session.commit()

    assert entry.id == 'msg-1'
    assert entry.conversation_id == thread.id
    assert entry.role == 'user'
    assert entry.content == '{"text": "Hello world"}'
    assert entry.created_at is not None
    assert entry.conversation == thread
|
||||
|
||||
|
||||
def test_message_to_dict(session):
    """message_to_dict exposes the text and attachments from the JSON content."""
    from backend.utils.helpers import message_to_dict

    entry = Message(
        id='msg-1',
        conversation_id='conv-123',
        role='user',
        content='{"text": "Hello", "attachments": [{"name": "file.txt"}]}',
    )
    session.add(entry)
    session.commit()

    serialized = message_to_dict(entry)

    assert serialized['id'] == 'msg-1'
    assert serialized['role'] == 'user'
    assert serialized['text'] == 'Hello'
    assert 'attachments' in serialized
    assert serialized['attachments'][0]['name'] == 'file.txt'
|
||||
|
||||
|
||||
def test_token_usage_create(session):
    """A token usage record can be stored for a user and model."""
    owner = User(username='user1')
    session.add(owner)
    session.commit()

    today_utc = datetime.now(timezone.utc).date()
    record = TokenUsage(
        user_id=owner.id,
        model='deepseek-chat',
        date=today_utc,
        prompt_tokens=100,
        completion_tokens=200,
        total_tokens=300,
    )
    session.add(record)
    session.commit()

    assert record.id is not None
    assert record.user_id == owner.id
    assert record.model == 'deepseek-chat'
    assert record.prompt_tokens == 100
    assert record.total_tokens == 300
|
||||
|
||||
|
||||
def test_project_create(session):
    """A stored project keeps its fields, links to its user, and starts empty."""
    owner = User(username='user1')
    session.add(owner)
    session.commit()

    workspace = Project(
        id='proj-123',
        user_id=owner.id,
        name='My Project',
        path='user_1/my_project',
        description='A test project',
    )
    session.add(workspace)
    session.commit()

    # Constructor values survive the round trip.
    assert workspace.id == 'proj-123'
    assert workspace.user_id == owner.id
    assert workspace.name == 'My Project'
    assert workspace.path == 'user_1/my_project'
    assert workspace.description == 'A test project'

    # Timestamps are filled in; relationships resolve; no conversations yet.
    assert workspace.created_at is not None
    assert workspace.updated_at is not None
    assert workspace.user == owner
    assert workspace.conversations.count() == 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main(['-v', __file__])
|
||||
|
|
@ -0,0 +1,342 @@
|
|||
import pytest
|
||||
from backend.tools.core import ToolRegistry, ToolDefinition, ToolResult
|
||||
from backend.tools.executor import ToolExecutor
|
||||
import json
|
||||
|
||||
|
||||
def test_tool_definition():
    """A ToolDefinition keeps its fields and converts to the OpenAI format."""
    def dummy_handler(args):
        return args.get('value', 0)

    schema = {
        'type': 'object',
        'properties': {'value': {'type': 'integer'}},
    }
    definition = ToolDefinition(
        name='test_tool',
        description='A test tool',
        parameters=schema,
        handler=dummy_handler,
        category='test',
    )

    assert definition.name == 'test_tool'
    assert definition.description == 'A test tool'
    assert definition.category == 'test'
    assert definition.handler == dummy_handler

    # The OpenAI representation wraps the tool as a "function" entry.
    exported = definition.to_openai_format()
    assert exported['type'] == 'function'

    function_spec = exported['function']
    assert function_spec['name'] == 'test_tool'
    assert function_spec['description'] == 'A test tool'
    assert 'parameters' in function_spec
|
||||
|
||||
|
||||
def test_tool_result():
    """ToolResult.ok / ToolResult.fail set their fields and serialize to dicts."""
    success = ToolResult.ok(data='success')
    assert success.success is True
    assert success.data == 'success'
    assert success.error is None

    failure = ToolResult.fail(error='something went wrong')
    assert failure.success is False
    assert failure.error == 'something went wrong'
    assert failure.data is None

    # Dictionary form of the successful result.
    success_dict = success.to_dict()
    assert success_dict['success'] is True
    assert success_dict['data'] == 'success'

    # Dictionary form of the failed result.
    failure_dict = failure.to_dict()
    assert failure_dict['success'] is False
    assert failure_dict['error'] == 'something went wrong'
|
||||
|
||||
|
||||
def test_tool_registry():
    """Test ToolRegistry registration, lookup, listing, and execution.

    The tools registered here are removed again afterwards so they do not
    leak into other tests through the shared registry.
    """
    registry = ToolRegistry()

    # ToolRegistry() returns a shared instance, so snapshot its contents and
    # restore them at the end.
    # NOTE(review): assumes `_tools` is the name-keyed dict that the
    # executor tests in this module clear -- confirm.
    saved_tools = dict(registry._tools)
    try:
        # Count existing tools
        initial_count = len(registry.list_all())

        # Register a tool
        def add_handler(args):
            return args.get('a', 0) + args.get('b', 0)

        tool = ToolDefinition(
            name='add',
            description='Add two numbers',
            parameters={
                'type': 'object',
                'properties': {
                    'a': {'type': 'number'},
                    'b': {'type': 'number'},
                },
                'required': ['a', 'b'],
            },
            handler=add_handler,
            category='math',
        )
        registry.register(tool)

        # Should be able to get it
        retrieved = registry.get('add')
        assert retrieved is not None
        assert retrieved.name == 'add'
        assert retrieved.handler == add_handler

        # List all returns OpenAI format
        tools_list = registry.list_all()
        assert len(tools_list) == initial_count + 1
        tool_names = [t['function']['name'] for t in tools_list]
        assert 'add' in tool_names

        # Execute tool
        result = registry.execute('add', {'a': 5, 'b': 3})
        assert result['success'] is True
        assert result['data'] == 8

        # Execute non-existent tool
        result = registry.execute('nonexistent', {})
        assert result['success'] is False
        assert 'Tool not found' in result['error']

        # A handler that raises is reported as a failed result
        def faulty_handler(args):
            raise ValueError('Intentional error')

        faulty_tool = ToolDefinition(
            name='faulty',
            description='Faulty tool',
            parameters={'type': 'object'},
            handler=faulty_handler,
        )
        registry.register(faulty_tool)
        result = registry.execute('faulty', {})
        assert result['success'] is False
        assert 'Intentional error' in result['error']
    finally:
        registry._tools.clear()
        registry._tools.update(saved_tools)
|
||||
|
||||
|
||||
def test_tool_registry_singleton():
    """Test that ToolRegistry is a singleton.

    The tool registered for the check is removed again afterwards so it
    does not leak into other tests.
    """
    registry1 = ToolRegistry()
    registry2 = ToolRegistry()
    assert registry1 is registry2

    # NOTE(review): assumes `_tools` is the name-keyed dict that the
    # executor tests in this module clear -- confirm.
    saved_tools = dict(registry1._tools)
    try:
        # Register in one, should appear in the other
        def dummy(args):
            return 42

        tool = ToolDefinition(
            name='singleton_test',
            description='Test',
            parameters={'type': 'object'},
            handler=dummy,
        )
        registry1.register(tool)
        assert registry2.get('singleton_test') is not None
    finally:
        registry1._tools.clear()
        registry1._tools.update(saved_tools)
|
||||
|
||||
|
||||
def test_tool_executor_basic():
    """Test ToolExecutor basic execution.

    The shared registry is emptied for the test and its previous contents
    are restored afterwards, so other tests still see the tools that were
    registered before.
    """
    registry = ToolRegistry()

    # The registry is a singleton and may hold tools from elsewhere. Start
    # from an empty internal dict (hacky, but sufficient for testing) and
    # put the original contents back when done.
    saved_tools = dict(registry._tools)
    registry._tools.clear()
    try:
        def echo_handler(args):
            return args.get('message', '')

        tool = ToolDefinition(
            name='echo',
            description='Echo message',
            parameters={
                'type': 'object',
                'properties': {'message': {'type': 'string'}},
            },
            handler=echo_handler,
        )
        registry.register(tool)

        executor = ToolExecutor(registry=registry, enable_cache=False)

        # Simulate a tool call
        call = {
            'id': 'call_1',
            'function': {
                'name': 'echo',
                'arguments': json.dumps({'message': 'Hello'}),
            },
        }

        messages = executor.process_tool_calls([call], context=None)
        assert len(messages) == 1
        msg = messages[0]
        assert msg['role'] == 'tool'
        assert msg['tool_call_id'] == 'call_1'
        assert msg['name'] == 'echo'
        content = json.loads(msg['content'])
        assert content['success'] is True
        assert content['data'] == 'Hello'
    finally:
        registry._tools.clear()
        registry._tools.update(saved_tools)
|
||||
|
||||
|
||||
def test_tool_executor_cache():
    """Test caching behavior.

    The shared registry is emptied for the test and its previous contents
    are restored afterwards.
    """
    registry = ToolRegistry()

    # Isolate the singleton registry for this test; restored in `finally`.
    saved_tools = dict(registry._tools)
    registry._tools.clear()
    try:
        call_count = 0

        def counter_handler(args):
            nonlocal call_count
            call_count += 1
            return call_count

        tool = ToolDefinition(
            name='counter',
            description='Count calls',
            parameters={'type': 'object'},
            handler=counter_handler,
        )
        registry.register(tool)

        executor = ToolExecutor(registry=registry, enable_cache=True, cache_ttl=10)

        call = {
            'id': 'call_1',
            'function': {
                'name': 'counter',
                'arguments': '{}',
            },
        }

        # First call should execute
        messages1 = executor.process_tool_calls([call], context=None)
        assert len(messages1) == 1
        content1 = json.loads(messages1[0]['content'])
        assert content1['data'] == 1
        assert call_count == 1

        # Second identical call should be served from the cache:
        # same data, and the handler is not invoked again.
        messages2 = executor.process_tool_calls([call], context=None)
        assert len(messages2) == 1
        content2 = json.loads(messages2[0]['content'])
        assert content2['data'] == 1
        assert call_count == 1

        # Different call (different arguments) should execute
        call2 = {
            'id': 'call_2',
            'function': {
                'name': 'counter',
                'arguments': json.dumps({'different': True}),
            },
        }
        messages3 = executor.process_tool_calls([call2], context=None)
        content3 = json.loads(messages3[0]['content'])
        assert content3['data'] == 2
        assert call_count == 2
    finally:
        registry._tools.clear()
        registry._tools.update(saved_tools)
|
||||
|
||||
|
||||
def test_tool_executor_context_injection():
    """Test that context fields are injected into arguments.

    The shared registry is emptied for the test and its previous contents
    are restored afterwards.
    """
    registry = ToolRegistry()

    # Isolate the singleton registry for this test; restored in `finally`.
    saved_tools = dict(registry._tools)
    registry._tools.clear()
    try:
        captured_args = None

        def capture_handler(args):
            nonlocal captured_args
            # Copy so later changes to `args` cannot alter what was recorded.
            captured_args = args.copy()
            return 'ok'

        tool = ToolDefinition(
            name='file_read',
            description='Read file',
            parameters={'type': 'object'},
            handler=capture_handler,
        )
        registry.register(tool)

        executor = ToolExecutor(registry=registry)

        call = {
            'id': 'call_1',
            'function': {
                'name': 'file_read',
                'arguments': json.dumps({'path': 'test.txt'}),
            },
        }

        context = {'project_id': 'proj-123'}
        executor.process_tool_calls([call], context=context)

        # The handler must have seen both the injected project_id and the
        # argument supplied in the call itself.
        assert captured_args is not None
        assert captured_args['project_id'] == 'proj-123'
        assert captured_args['path'] == 'test.txt'
    finally:
        registry._tools.clear()
        registry._tools.update(saved_tools)
|
||||
|
||||
|
||||
def test_tool_executor_deduplication():
    """Test deduplication of identical calls within a single batch.

    The shared registry is emptied for the test and its previous contents
    are restored afterwards.
    """
    registry = ToolRegistry()

    # Isolate the singleton registry for this test; restored in `finally`.
    saved_tools = dict(registry._tools)
    registry._tools.clear()
    try:
        call_count = 0

        def count_handler(args):
            nonlocal call_count
            call_count += 1
            return call_count

        tool = ToolDefinition(
            name='count',
            description='Count',
            parameters={'type': 'object'},
            handler=count_handler,
        )
        registry.register(tool)

        executor = ToolExecutor(registry=registry, enable_cache=False)

        # Two calls with the same tool name and arguments but different ids.
        call = {
            'id': 'call_1',
            'function': {
                'name': 'count',
                'arguments': '{}',
            },
        }
        call_same = {
            'id': 'call_2',
            'function': {
                'name': 'count',
                'arguments': '{}',
            },
        }

        # Execute both calls in one batch
        messages = executor.process_tool_calls([call, call_same], context=None)

        # Only one real execution: the duplicate must not reach the handler.
        assert call_count == 1

        # Each call still gets its own successful result message.
        assert len(messages) == 2
        content0 = json.loads(messages[0]['content'])
        content1 = json.loads(messages[1]['content'])
        assert content0['success'] is True
        assert content1['success'] is True
        assert content0['data'] == 1

        # The duplicate is flagged as cached; its data may be the first
        # call's value or omitted (None).
        assert content1.get('cached') is True
        assert content1.get('data') in (1, None)
    finally:
        registry._tools.clear()
        registry._tools.update(saved_tools)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main(['-v', __file__])
|
||||
Loading…
Reference in New Issue