refactor: 清理死代码,支持按模型配置独立的 api_url 和 api_key

This commit is contained in:
ViperEkura 2026-03-26 17:55:38 +08:00
parent a847d91886
commit a24eb8e24f
13 changed files with 47 additions and 122 deletions

View File

@ -1,4 +1,3 @@
import os
from pathlib import Path from pathlib import Path
import yaml import yaml

View File

@ -3,7 +3,21 @@ from backend import load_config
_cfg = load_config() _cfg = load_config()
API_URL = _cfg.get("api_url") # Global defaults
API_KEY = _cfg["api_key"] DEFAULT_API_URL = _cfg.get("default_api_url", "")
DEFAULT_API_KEY = _cfg.get("default_api_key", "")
# Model list (for /api/models endpoint)
MODELS = _cfg.get("models", []) MODELS = _cfg.get("models", [])
# Per-model config lookup: {model_id: {api_url, api_key}}
# Falls back to global defaults if not specified per model
MODEL_CONFIG = {}
for _m in MODELS:
_mid = _m["id"]
MODEL_CONFIG[_mid] = {
"api_url": _m.get("api_url", DEFAULT_API_URL),
"api_key": _m.get("api_key", DEFAULT_API_KEY),
}
DEFAULT_MODEL = _cfg.get("default_model", "glm-5") DEFAULT_MODEL = _cfg.get("default_model", "glm-5")

View File

@ -7,13 +7,13 @@ from backend.routes.tools import bp as tools_bp
from backend.routes.stats import bp as stats_bp from backend.routes.stats import bp as stats_bp
from backend.routes.projects import bp as projects_bp from backend.routes.projects import bp as projects_bp
from backend.services.glm_client import GLMClient from backend.services.glm_client import GLMClient
from backend.config import API_URL, API_KEY from backend.config import MODEL_CONFIG
def register_routes(app: Flask): def register_routes(app: Flask):
"""Register all route blueprints""" """Register all route blueprints"""
# Initialize GLM client and chat service # Initialize GLM client with per-model config
glm_client = GLMClient(API_URL, API_KEY) glm_client = GLMClient(MODEL_CONFIG)
init_chat_service(glm_client) init_chat_service(glm_client)
# Register blueprints # Register blueprints

View File

@ -1,6 +1,6 @@
"""Conversation API routes""" """Conversation API routes"""
import uuid import uuid
from datetime import datetime from datetime import datetime, timezone
from flask import Blueprint, request from flask import Blueprint, request
from backend import db from backend import db
from backend.models import Conversation, Project from backend.models import Conversation, Project
@ -62,7 +62,7 @@ def conversation_list():
if cursor: if cursor:
q = q.filter(Conversation.updated_at < ( q = q.filter(Conversation.updated_at < (
db.session.query(Conversation.updated_at).filter_by(id=cursor).scalar() or datetime.utcnow)) db.session.query(Conversation.updated_at).filter_by(id=cursor).scalar() or datetime.now(timezone.utc)))
rows = q.order_by(Conversation.updated_at.desc()).limit(limit + 1).all() rows = q.order_by(Conversation.updated_at.desc()).limit(limit + 1).all()
items = [_conv_to_dict(r, message_count=r.messages.count()) for r in rows[:limit]] items = [_conv_to_dict(r, message_count=r.messages.count()) for r in rows[:limit]]

View File

@ -1,11 +1,11 @@
"""Message API routes""" """Message API routes"""
import json import json
import uuid import uuid
from datetime import datetime from datetime import datetime, timezone
from flask import Blueprint, request from flask import Blueprint, request
from backend import db from backend import db
from backend.models import Conversation, Message from backend.models import Conversation, Message
from backend.utils.helpers import ok, err, to_dict, message_to_dict, get_or_create_default_user from backend.utils.helpers import ok, err, message_to_dict
from backend.services.chat import ChatService from backend.services.chat import ChatService
@ -34,7 +34,7 @@ def message_list(conv_id):
q = Message.query.filter_by(conversation_id=conv_id) q = Message.query.filter_by(conversation_id=conv_id)
if cursor: if cursor:
q = q.filter(Message.created_at < ( q = q.filter(Message.created_at < (
db.session.query(Message.created_at).filter_by(id=cursor).scalar() or datetime.utcnow)) db.session.query(Message.created_at).filter_by(id=cursor).scalar() or datetime.now(timezone.utc)))
rows = q.order_by(Message.created_at.asc()).limit(limit + 1).all() rows = q.order_by(Message.created_at.asc()).limit(limit + 1).all()
items = [message_to_dict(r) for r in rows[:limit]] items = [message_to_dict(r) for r in rows[:limit]]
@ -48,7 +48,6 @@ def message_list(conv_id):
d = request.json or {} d = request.json or {}
text = (d.get("text") or "").strip() text = (d.get("text") or "").strip()
attachments = d.get("attachments") # [{"name": "a.py", "extension": "py", "content": "..."}] attachments = d.get("attachments") # [{"name": "a.py", "extension": "py", "content": "..."}]
project_id = d.get("project_id") # Get project_id from request
if not text and not attachments: if not text and not attachments:
return err(400, "text or attachments is required") return err(400, "text or attachments is required")

View File

@ -2,8 +2,7 @@
import os import os
import uuid import uuid
import shutil import shutil
from flask import Blueprint, request, jsonify from flask import Blueprint, request
from werkzeug.utils import secure_filename
from backend import db from backend import db
from backend.models import Project, User from backend.models import Project, User

View File

@ -6,9 +6,18 @@ from typing import Optional, List
class GLMClient: class GLMClient:
"""GLM API client for chat completions""" """GLM API client for chat completions"""
def __init__(self, api_url: str, api_key: str): def __init__(self, model_config: dict):
self.api_url = api_url """Initialize with per-model config lookup.
self.api_key = api_key
Args:
model_config: {model_id: {"api_url": ..., "api_key": ...}}
"""
self.model_config = model_config
def _get_credentials(self, model: str):
"""Get api_url and api_key for a model, with fallback."""
cfg = self.model_config.get(model, {})
return cfg.get("api_url", ""), cfg.get("api_key", "")
def call( def call(
self, self,
@ -22,6 +31,7 @@ class GLMClient:
timeout: int = 120, timeout: int = 120,
): ):
"""Call GLM API""" """Call GLM API"""
api_url, api_key = self._get_credentials(model)
body = { body = {
"model": model, "model": model,
"messages": messages, "messages": messages,
@ -37,10 +47,10 @@ class GLMClient:
body["stream"] = True body["stream"] = True
return requests.post( return requests.post(
self.api_url, api_url,
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}" "Authorization": f"Bearer {api_key}"
}, },
json=body, json=body,
stream=stream, stream=stream,

View File

@ -1,8 +1,7 @@
"""File operation tools""" """File operation tools"""
import os
import json import json
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple from typing import Tuple
from backend.tools.factory import tool from backend.tools.factory import tool
from backend import db from backend import db
from backend.models import Project from backend.models import Project

View File

@ -1,5 +1,5 @@
"""Tool system core classes""" """Tool system core classes"""
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import Callable, Any, Dict, List, Optional from typing import Callable, Any, Dict, List, Optional
@ -69,14 +69,6 @@ class ToolRegistry:
"""List all tools in OpenAI format""" """List all tools in OpenAI format"""
return [t.to_openai_format() for t in self._tools.values()] return [t.to_openai_format() for t in self._tools.values()]
def list_by_category(self, category: str) -> List[dict]:
"""List tools by category"""
return [
t.to_openai_format()
for t in self._tools.values()
if t.category == category
]
def execute(self, name: str, arguments: dict) -> dict: def execute(self, name: str, arguments: dict) -> dict:
"""Execute a tool""" """Execute a tool"""
tool = self.get(name) tool = self.get(name)
@ -91,17 +83,6 @@ class ToolRegistry:
except Exception as e: except Exception as e:
return ToolResult.fail(str(e)).to_dict() return ToolResult.fail(str(e)).to_dict()
def remove(self, name: str) -> bool:
"""Remove a tool"""
if name in self._tools:
del self._tools[name]
return True
return False
def has(self, name: str) -> bool:
"""Check if tool exists"""
return name in self._tools
# Global registry instance # Global registry instance
registry = ToolRegistry() registry = ToolRegistry()

View File

@ -176,29 +176,5 @@ class ToolExecutor:
}, ensure_ascii=False) }, ensure_ascii=False)
} }
def build_request(
self,
messages: List[dict],
model: str = "glm-5",
tools: Optional[List[dict]] = None,
**kwargs
) -> dict:
"""
Build API request body
Args:
messages: Message list
model: Model name
tools: Tool list (default: all tools in registry)
**kwargs: Other parameters (temperature, max_tokens, etc.)
Returns:
Request body dict
"""
return {
"model": model,
"messages": messages,
"tools": tools or self.registry.list_all(),
"tool_choice": kwargs.get("tool_choice", "auto"),
**{k: v for k, v in kwargs.items() if k not in ["tool_choice"]}
}

View File

@ -1,5 +1,5 @@
"""Tool helper services""" """Tool helper services"""
from typing import List, Dict, Optional, Any from typing import List
from ddgs import DDGS from ddgs import DDGS
import re import re
@ -200,11 +200,6 @@ class ContentExtractor:
class CalculatorService: class CalculatorService:
"""Safe calculation service""" """Safe calculation service"""
ALLOWED_OPS = {
"add", "sub", "mul", "truediv", "floordiv",
"mod", "pow", "neg", "abs"
}
def evaluate(self, expression: str) -> dict: def evaluate(self, expression: str) -> dict:
""" """
Safely evaluate mathematical expression Safely evaluate mathematical expression

View File

@ -4,7 +4,7 @@ from datetime import date, datetime
from typing import Any from typing import Any
from flask import jsonify from flask import jsonify
from backend import db from backend import db
from backend.models import Conversation, Message, TokenUsage, User from backend.models import Message, TokenUsage, User
def get_or_create_default_user() -> User: def get_or_create_default_user() -> User:

View File

@ -1,9 +1,6 @@
"""Workspace path validation utilities""" """Workspace path validation utilities"""
import os
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import Optional
from flask import current_app
from backend import load_config from backend import load_config
@ -137,8 +134,6 @@ def delete_project_directory(project_path: str) -> bool:
return False return False
def save_uploaded_files(files, project_dir: Path) -> dict: def save_uploaded_files(files, project_dir: Path) -> dict:
""" """
Save uploaded files to project directory (for folder upload) Save uploaded files to project directory (for folder upload)
@ -187,45 +182,3 @@ def save_uploaded_files(files, project_dir: Path) -> dict:
"size": total_size "size": total_size
} }
def copy_folder_to_project(source_path: str, project_dir: Path, project_name: str) -> dict:
"""
Copy a folder to project directory (for folder upload)
Args:
source_path: Source folder path
project_dir: Target project directory
project_name: Project name
Returns:
Dict with copy statistics
"""
source = Path(source_path)
if not source.exists():
raise ValueError(f"Source path does not exist: {source_path}")
if not source.is_dir():
raise ValueError(f"Source path is not a directory: {source_path}")
# Security check: don't copy from sensitive system directories
sensitive_dirs = ["/etc", "/usr", "/bin", "/sbin", "/root", "/home"]
for sensitive in sensitive_dirs:
if str(source.resolve()).startswith(sensitive):
raise ValueError(f"Cannot copy from system directory: {sensitive}")
# Copy directory
if project_dir.exists():
shutil.rmtree(project_dir)
shutil.copytree(source, project_dir)
# Count files
file_count = sum(1 for _ in project_dir.rglob("*") if _.is_file())
dir_count = sum(1 for _ in project_dir.rglob("*") if _.is_dir())
return {
"files": file_count,
"directories": dir_count,
"size": sum(f.stat().st_size for f in project_dir.rglob("*") if f.is_file())
}