refactor: 清理死代码,支持按模型配置独立的 api_url 和 api_key
This commit is contained in:
parent
a847d91886
commit
a24eb8e24f
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
|
|
|||
|
|
@ -3,7 +3,21 @@ from backend import load_config
|
|||
|
||||
_cfg = load_config()
|
||||
|
||||
API_URL = _cfg.get("api_url")
|
||||
API_KEY = _cfg["api_key"]
|
||||
# Global defaults
|
||||
DEFAULT_API_URL = _cfg.get("default_api_url", "")
|
||||
DEFAULT_API_KEY = _cfg.get("default_api_key", "")
|
||||
|
||||
# Model list (for /api/models endpoint)
|
||||
MODELS = _cfg.get("models", [])
|
||||
|
||||
# Per-model config lookup: {model_id: {api_url, api_key}}
|
||||
# Falls back to global defaults if not specified per model
|
||||
MODEL_CONFIG = {}
|
||||
for _m in MODELS:
|
||||
_mid = _m["id"]
|
||||
MODEL_CONFIG[_mid] = {
|
||||
"api_url": _m.get("api_url", DEFAULT_API_URL),
|
||||
"api_key": _m.get("api_key", DEFAULT_API_KEY),
|
||||
}
|
||||
|
||||
DEFAULT_MODEL = _cfg.get("default_model", "glm-5")
|
||||
|
|
|
|||
|
|
@ -7,13 +7,13 @@ from backend.routes.tools import bp as tools_bp
|
|||
from backend.routes.stats import bp as stats_bp
|
||||
from backend.routes.projects import bp as projects_bp
|
||||
from backend.services.glm_client import GLMClient
|
||||
from backend.config import API_URL, API_KEY
|
||||
from backend.config import MODEL_CONFIG
|
||||
|
||||
|
||||
def register_routes(app: Flask):
|
||||
"""Register all route blueprints"""
|
||||
# Initialize GLM client and chat service
|
||||
glm_client = GLMClient(API_URL, API_KEY)
|
||||
# Initialize GLM client with per-model config
|
||||
glm_client = GLMClient(MODEL_CONFIG)
|
||||
init_chat_service(glm_client)
|
||||
|
||||
# Register blueprints
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Conversation API routes"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from flask import Blueprint, request
|
||||
from backend import db
|
||||
from backend.models import Conversation, Project
|
||||
|
|
@ -62,7 +62,7 @@ def conversation_list():
|
|||
|
||||
if cursor:
|
||||
q = q.filter(Conversation.updated_at < (
|
||||
db.session.query(Conversation.updated_at).filter_by(id=cursor).scalar() or datetime.utcnow))
|
||||
db.session.query(Conversation.updated_at).filter_by(id=cursor).scalar() or datetime.now(timezone.utc)))
|
||||
rows = q.order_by(Conversation.updated_at.desc()).limit(limit + 1).all()
|
||||
|
||||
items = [_conv_to_dict(r, message_count=r.messages.count()) for r in rows[:limit]]
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
"""Message API routes"""
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from flask import Blueprint, request
|
||||
from backend import db
|
||||
from backend.models import Conversation, Message
|
||||
from backend.utils.helpers import ok, err, to_dict, message_to_dict, get_or_create_default_user
|
||||
from backend.utils.helpers import ok, err, message_to_dict
|
||||
from backend.services.chat import ChatService
|
||||
|
||||
|
||||
|
|
@ -34,7 +34,7 @@ def message_list(conv_id):
|
|||
q = Message.query.filter_by(conversation_id=conv_id)
|
||||
if cursor:
|
||||
q = q.filter(Message.created_at < (
|
||||
db.session.query(Message.created_at).filter_by(id=cursor).scalar() or datetime.utcnow))
|
||||
db.session.query(Message.created_at).filter_by(id=cursor).scalar() or datetime.now(timezone.utc)))
|
||||
rows = q.order_by(Message.created_at.asc()).limit(limit + 1).all()
|
||||
|
||||
items = [message_to_dict(r) for r in rows[:limit]]
|
||||
|
|
@ -48,7 +48,6 @@ def message_list(conv_id):
|
|||
d = request.json or {}
|
||||
text = (d.get("text") or "").strip()
|
||||
attachments = d.get("attachments") # [{"name": "a.py", "extension": "py", "content": "..."}]
|
||||
project_id = d.get("project_id") # Get project_id from request
|
||||
|
||||
if not text and not attachments:
|
||||
return err(400, "text or attachments is required")
|
||||
|
|
|
|||
|
|
@ -2,8 +2,7 @@
|
|||
import os
|
||||
import uuid
|
||||
import shutil
|
||||
from flask import Blueprint, request, jsonify
|
||||
from werkzeug.utils import secure_filename
|
||||
from flask import Blueprint, request
|
||||
|
||||
from backend import db
|
||||
from backend.models import Project, User
|
||||
|
|
|
|||
|
|
@ -5,11 +5,20 @@ from typing import Optional, List
|
|||
|
||||
class GLMClient:
|
||||
"""GLM API client for chat completions"""
|
||||
|
||||
def __init__(self, api_url: str, api_key: str):
|
||||
self.api_url = api_url
|
||||
self.api_key = api_key
|
||||
|
||||
|
||||
def __init__(self, model_config: dict):
|
||||
"""Initialize with per-model config lookup.
|
||||
|
||||
Args:
|
||||
model_config: {model_id: {"api_url": ..., "api_key": ...}}
|
||||
"""
|
||||
self.model_config = model_config
|
||||
|
||||
def _get_credentials(self, model: str):
|
||||
"""Get api_url and api_key for a model, with fallback."""
|
||||
cfg = self.model_config.get(model, {})
|
||||
return cfg.get("api_url", ""), cfg.get("api_key", "")
|
||||
|
||||
def call(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -22,6 +31,7 @@ class GLMClient:
|
|||
timeout: int = 120,
|
||||
):
|
||||
"""Call GLM API"""
|
||||
api_url, api_key = self._get_credentials(model)
|
||||
body = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
|
|
@ -35,12 +45,12 @@ class GLMClient:
|
|||
body["tool_choice"] = "auto"
|
||||
if stream:
|
||||
body["stream"] = True
|
||||
|
||||
|
||||
return requests.post(
|
||||
self.api_url,
|
||||
api_url,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
},
|
||||
json=body,
|
||||
stream=stream,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
"""File operation tools"""
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from typing import Tuple
|
||||
from backend.tools.factory import tool
|
||||
from backend import db
|
||||
from backend.models import Project
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""Tool system core classes"""
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Any, Dict, List, Optional
|
||||
|
||||
|
||||
|
|
@ -69,14 +69,6 @@ class ToolRegistry:
|
|||
"""List all tools in OpenAI format"""
|
||||
return [t.to_openai_format() for t in self._tools.values()]
|
||||
|
||||
def list_by_category(self, category: str) -> List[dict]:
|
||||
"""List tools by category"""
|
||||
return [
|
||||
t.to_openai_format()
|
||||
for t in self._tools.values()
|
||||
if t.category == category
|
||||
]
|
||||
|
||||
def execute(self, name: str, arguments: dict) -> dict:
|
||||
"""Execute a tool"""
|
||||
tool = self.get(name)
|
||||
|
|
@ -91,17 +83,6 @@ class ToolRegistry:
|
|||
except Exception as e:
|
||||
return ToolResult.fail(str(e)).to_dict()
|
||||
|
||||
def remove(self, name: str) -> bool:
|
||||
"""Remove a tool"""
|
||||
if name in self._tools:
|
||||
del self._tools[name]
|
||||
return True
|
||||
return False
|
||||
|
||||
def has(self, name: str) -> bool:
|
||||
"""Check if tool exists"""
|
||||
return name in self._tools
|
||||
|
||||
|
||||
# Global registry instance
|
||||
registry = ToolRegistry()
|
||||
|
|
|
|||
|
|
@ -176,29 +176,5 @@ class ToolExecutor:
|
|||
}, ensure_ascii=False)
|
||||
}
|
||||
|
||||
def build_request(
|
||||
self,
|
||||
messages: List[dict],
|
||||
model: str = "glm-5",
|
||||
tools: Optional[List[dict]] = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
"""
|
||||
Build API request body
|
||||
|
||||
Args:
|
||||
messages: Message list
|
||||
model: Model name
|
||||
tools: Tool list (default: all tools in registry)
|
||||
**kwargs: Other parameters (temperature, max_tokens, etc.)
|
||||
|
||||
Returns:
|
||||
Request body dict
|
||||
"""
|
||||
return {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"tools": tools or self.registry.list_all(),
|
||||
"tool_choice": kwargs.get("tool_choice", "auto"),
|
||||
**{k: v for k, v in kwargs.items() if k not in ["tool_choice"]}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""Tool helper services"""
|
||||
from typing import List, Dict, Optional, Any
|
||||
from typing import List
|
||||
from ddgs import DDGS
|
||||
import re
|
||||
|
||||
|
|
@ -200,11 +200,6 @@ class ContentExtractor:
|
|||
class CalculatorService:
|
||||
"""Safe calculation service"""
|
||||
|
||||
ALLOWED_OPS = {
|
||||
"add", "sub", "mul", "truediv", "floordiv",
|
||||
"mod", "pow", "neg", "abs"
|
||||
}
|
||||
|
||||
def evaluate(self, expression: str) -> dict:
|
||||
"""
|
||||
Safely evaluate mathematical expression
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from datetime import date, datetime
|
|||
from typing import Any
|
||||
from flask import jsonify
|
||||
from backend import db
|
||||
from backend.models import Conversation, Message, TokenUsage, User
|
||||
from backend.models import Message, TokenUsage, User
|
||||
|
||||
|
||||
def get_or_create_default_user() -> User:
|
||||
|
|
|
|||
|
|
@ -1,9 +1,6 @@
|
|||
"""Workspace path validation utilities"""
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from flask import current_app
|
||||
|
||||
from backend import load_config
|
||||
|
||||
|
|
@ -137,8 +134,6 @@ def delete_project_directory(project_path: str) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
|
||||
|
||||
def save_uploaded_files(files, project_dir: Path) -> dict:
|
||||
"""
|
||||
Save uploaded files to project directory (for folder upload)
|
||||
|
|
@ -187,45 +182,3 @@ def save_uploaded_files(files, project_dir: Path) -> dict:
|
|||
"size": total_size
|
||||
}
|
||||
|
||||
|
||||
def copy_folder_to_project(source_path: str, project_dir: Path, project_name: str) -> dict:
|
||||
"""
|
||||
Copy a folder to project directory (for folder upload)
|
||||
|
||||
Args:
|
||||
source_path: Source folder path
|
||||
project_dir: Target project directory
|
||||
project_name: Project name
|
||||
|
||||
Returns:
|
||||
Dict with copy statistics
|
||||
"""
|
||||
source = Path(source_path)
|
||||
|
||||
if not source.exists():
|
||||
raise ValueError(f"Source path does not exist: {source_path}")
|
||||
|
||||
if not source.is_dir():
|
||||
raise ValueError(f"Source path is not a directory: {source_path}")
|
||||
|
||||
# Security check: don't copy from sensitive system directories
|
||||
sensitive_dirs = ["/etc", "/usr", "/bin", "/sbin", "/root", "/home"]
|
||||
for sensitive in sensitive_dirs:
|
||||
if str(source.resolve()).startswith(sensitive):
|
||||
raise ValueError(f"Cannot copy from system directory: {sensitive}")
|
||||
|
||||
# Copy directory
|
||||
if project_dir.exists():
|
||||
shutil.rmtree(project_dir)
|
||||
|
||||
shutil.copytree(source, project_dir)
|
||||
|
||||
# Count files
|
||||
file_count = sum(1 for _ in project_dir.rglob("*") if _.is_file())
|
||||
dir_count = sum(1 for _ in project_dir.rglob("*") if _.is_dir())
|
||||
|
||||
return {
|
||||
"files": file_count,
|
||||
"directories": dir_count,
|
||||
"size": sum(f.stat().st_size for f in project_dir.rglob("*") if f.is_file())
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue