refactor: 清理死代码,支持按模型配置独立的 api_url 和 api_key
This commit is contained in:
parent
a847d91886
commit
a24eb8e24f
|
|
@ -1,4 +1,3 @@
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,21 @@ from backend import load_config
|
||||||
|
|
||||||
_cfg = load_config()
|
_cfg = load_config()
|
||||||
|
|
||||||
API_URL = _cfg.get("api_url")
|
# Global defaults
|
||||||
API_KEY = _cfg["api_key"]
|
DEFAULT_API_URL = _cfg.get("default_api_url", "")
|
||||||
|
DEFAULT_API_KEY = _cfg.get("default_api_key", "")
|
||||||
|
|
||||||
|
# Model list (for /api/models endpoint)
|
||||||
MODELS = _cfg.get("models", [])
|
MODELS = _cfg.get("models", [])
|
||||||
|
|
||||||
|
# Per-model config lookup: {model_id: {api_url, api_key}}
|
||||||
|
# Falls back to global defaults if not specified per model
|
||||||
|
MODEL_CONFIG = {}
|
||||||
|
for _m in MODELS:
|
||||||
|
_mid = _m["id"]
|
||||||
|
MODEL_CONFIG[_mid] = {
|
||||||
|
"api_url": _m.get("api_url", DEFAULT_API_URL),
|
||||||
|
"api_key": _m.get("api_key", DEFAULT_API_KEY),
|
||||||
|
}
|
||||||
|
|
||||||
DEFAULT_MODEL = _cfg.get("default_model", "glm-5")
|
DEFAULT_MODEL = _cfg.get("default_model", "glm-5")
|
||||||
|
|
|
||||||
|
|
@ -7,13 +7,13 @@ from backend.routes.tools import bp as tools_bp
|
||||||
from backend.routes.stats import bp as stats_bp
|
from backend.routes.stats import bp as stats_bp
|
||||||
from backend.routes.projects import bp as projects_bp
|
from backend.routes.projects import bp as projects_bp
|
||||||
from backend.services.glm_client import GLMClient
|
from backend.services.glm_client import GLMClient
|
||||||
from backend.config import API_URL, API_KEY
|
from backend.config import MODEL_CONFIG
|
||||||
|
|
||||||
|
|
||||||
def register_routes(app: Flask):
|
def register_routes(app: Flask):
|
||||||
"""Register all route blueprints"""
|
"""Register all route blueprints"""
|
||||||
# Initialize GLM client and chat service
|
# Initialize GLM client with per-model config
|
||||||
glm_client = GLMClient(API_URL, API_KEY)
|
glm_client = GLMClient(MODEL_CONFIG)
|
||||||
init_chat_service(glm_client)
|
init_chat_service(glm_client)
|
||||||
|
|
||||||
# Register blueprints
|
# Register blueprints
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""Conversation API routes"""
|
"""Conversation API routes"""
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from flask import Blueprint, request
|
from flask import Blueprint, request
|
||||||
from backend import db
|
from backend import db
|
||||||
from backend.models import Conversation, Project
|
from backend.models import Conversation, Project
|
||||||
|
|
@ -62,7 +62,7 @@ def conversation_list():
|
||||||
|
|
||||||
if cursor:
|
if cursor:
|
||||||
q = q.filter(Conversation.updated_at < (
|
q = q.filter(Conversation.updated_at < (
|
||||||
db.session.query(Conversation.updated_at).filter_by(id=cursor).scalar() or datetime.utcnow))
|
db.session.query(Conversation.updated_at).filter_by(id=cursor).scalar() or datetime.now(timezone.utc)))
|
||||||
rows = q.order_by(Conversation.updated_at.desc()).limit(limit + 1).all()
|
rows = q.order_by(Conversation.updated_at.desc()).limit(limit + 1).all()
|
||||||
|
|
||||||
items = [_conv_to_dict(r, message_count=r.messages.count()) for r in rows[:limit]]
|
items = [_conv_to_dict(r, message_count=r.messages.count()) for r in rows[:limit]]
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
"""Message API routes"""
|
"""Message API routes"""
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from flask import Blueprint, request
|
from flask import Blueprint, request
|
||||||
from backend import db
|
from backend import db
|
||||||
from backend.models import Conversation, Message
|
from backend.models import Conversation, Message
|
||||||
from backend.utils.helpers import ok, err, to_dict, message_to_dict, get_or_create_default_user
|
from backend.utils.helpers import ok, err, message_to_dict
|
||||||
from backend.services.chat import ChatService
|
from backend.services.chat import ChatService
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -34,7 +34,7 @@ def message_list(conv_id):
|
||||||
q = Message.query.filter_by(conversation_id=conv_id)
|
q = Message.query.filter_by(conversation_id=conv_id)
|
||||||
if cursor:
|
if cursor:
|
||||||
q = q.filter(Message.created_at < (
|
q = q.filter(Message.created_at < (
|
||||||
db.session.query(Message.created_at).filter_by(id=cursor).scalar() or datetime.utcnow))
|
db.session.query(Message.created_at).filter_by(id=cursor).scalar() or datetime.now(timezone.utc)))
|
||||||
rows = q.order_by(Message.created_at.asc()).limit(limit + 1).all()
|
rows = q.order_by(Message.created_at.asc()).limit(limit + 1).all()
|
||||||
|
|
||||||
items = [message_to_dict(r) for r in rows[:limit]]
|
items = [message_to_dict(r) for r in rows[:limit]]
|
||||||
|
|
@ -48,7 +48,6 @@ def message_list(conv_id):
|
||||||
d = request.json or {}
|
d = request.json or {}
|
||||||
text = (d.get("text") or "").strip()
|
text = (d.get("text") or "").strip()
|
||||||
attachments = d.get("attachments") # [{"name": "a.py", "extension": "py", "content": "..."}]
|
attachments = d.get("attachments") # [{"name": "a.py", "extension": "py", "content": "..."}]
|
||||||
project_id = d.get("project_id") # Get project_id from request
|
|
||||||
|
|
||||||
if not text and not attachments:
|
if not text and not attachments:
|
||||||
return err(400, "text or attachments is required")
|
return err(400, "text or attachments is required")
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,7 @@
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
import shutil
|
import shutil
|
||||||
from flask import Blueprint, request, jsonify
|
from flask import Blueprint, request
|
||||||
from werkzeug.utils import secure_filename
|
|
||||||
|
|
||||||
from backend import db
|
from backend import db
|
||||||
from backend.models import Project, User
|
from backend.models import Project, User
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,20 @@ from typing import Optional, List
|
||||||
|
|
||||||
class GLMClient:
|
class GLMClient:
|
||||||
"""GLM API client for chat completions"""
|
"""GLM API client for chat completions"""
|
||||||
|
|
||||||
def __init__(self, api_url: str, api_key: str):
|
def __init__(self, model_config: dict):
|
||||||
self.api_url = api_url
|
"""Initialize with per-model config lookup.
|
||||||
self.api_key = api_key
|
|
||||||
|
Args:
|
||||||
|
model_config: {model_id: {"api_url": ..., "api_key": ...}}
|
||||||
|
"""
|
||||||
|
self.model_config = model_config
|
||||||
|
|
||||||
|
def _get_credentials(self, model: str):
|
||||||
|
"""Get api_url and api_key for a model, with fallback."""
|
||||||
|
cfg = self.model_config.get(model, {})
|
||||||
|
return cfg.get("api_url", ""), cfg.get("api_key", "")
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
|
@ -22,6 +31,7 @@ class GLMClient:
|
||||||
timeout: int = 120,
|
timeout: int = 120,
|
||||||
):
|
):
|
||||||
"""Call GLM API"""
|
"""Call GLM API"""
|
||||||
|
api_url, api_key = self._get_credentials(model)
|
||||||
body = {
|
body = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
|
|
@ -35,12 +45,12 @@ class GLMClient:
|
||||||
body["tool_choice"] = "auto"
|
body["tool_choice"] = "auto"
|
||||||
if stream:
|
if stream:
|
||||||
body["stream"] = True
|
body["stream"] = True
|
||||||
|
|
||||||
return requests.post(
|
return requests.post(
|
||||||
self.api_url,
|
api_url,
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {self.api_key}"
|
"Authorization": f"Bearer {api_key}"
|
||||||
},
|
},
|
||||||
json=body,
|
json=body,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
"""File operation tools"""
|
"""File operation tools"""
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Tuple
|
from typing import Tuple
|
||||||
from backend.tools.factory import tool
|
from backend.tools.factory import tool
|
||||||
from backend import db
|
from backend import db
|
||||||
from backend.models import Project
|
from backend.models import Project
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
"""Tool system core classes"""
|
"""Tool system core classes"""
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import Callable, Any, Dict, List, Optional
|
from typing import Callable, Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -69,14 +69,6 @@ class ToolRegistry:
|
||||||
"""List all tools in OpenAI format"""
|
"""List all tools in OpenAI format"""
|
||||||
return [t.to_openai_format() for t in self._tools.values()]
|
return [t.to_openai_format() for t in self._tools.values()]
|
||||||
|
|
||||||
def list_by_category(self, category: str) -> List[dict]:
|
|
||||||
"""List tools by category"""
|
|
||||||
return [
|
|
||||||
t.to_openai_format()
|
|
||||||
for t in self._tools.values()
|
|
||||||
if t.category == category
|
|
||||||
]
|
|
||||||
|
|
||||||
def execute(self, name: str, arguments: dict) -> dict:
|
def execute(self, name: str, arguments: dict) -> dict:
|
||||||
"""Execute a tool"""
|
"""Execute a tool"""
|
||||||
tool = self.get(name)
|
tool = self.get(name)
|
||||||
|
|
@ -91,17 +83,6 @@ class ToolRegistry:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return ToolResult.fail(str(e)).to_dict()
|
return ToolResult.fail(str(e)).to_dict()
|
||||||
|
|
||||||
def remove(self, name: str) -> bool:
|
|
||||||
"""Remove a tool"""
|
|
||||||
if name in self._tools:
|
|
||||||
del self._tools[name]
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def has(self, name: str) -> bool:
|
|
||||||
"""Check if tool exists"""
|
|
||||||
return name in self._tools
|
|
||||||
|
|
||||||
|
|
||||||
# Global registry instance
|
# Global registry instance
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
|
|
|
||||||
|
|
@ -176,29 +176,5 @@ class ToolExecutor:
|
||||||
}, ensure_ascii=False)
|
}, ensure_ascii=False)
|
||||||
}
|
}
|
||||||
|
|
||||||
def build_request(
|
|
||||||
self,
|
|
||||||
messages: List[dict],
|
|
||||||
model: str = "glm-5",
|
|
||||||
tools: Optional[List[dict]] = None,
|
|
||||||
**kwargs
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Build API request body
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: Message list
|
|
||||||
model: Model name
|
|
||||||
tools: Tool list (default: all tools in registry)
|
|
||||||
**kwargs: Other parameters (temperature, max_tokens, etc.)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Request body dict
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
"model": model,
|
|
||||||
"messages": messages,
|
|
||||||
"tools": tools or self.registry.list_all(),
|
|
||||||
"tool_choice": kwargs.get("tool_choice", "auto"),
|
|
||||||
**{k: v for k, v in kwargs.items() if k not in ["tool_choice"]}
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
"""Tool helper services"""
|
"""Tool helper services"""
|
||||||
from typing import List, Dict, Optional, Any
|
from typing import List
|
||||||
from ddgs import DDGS
|
from ddgs import DDGS
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
@ -200,11 +200,6 @@ class ContentExtractor:
|
||||||
class CalculatorService:
|
class CalculatorService:
|
||||||
"""Safe calculation service"""
|
"""Safe calculation service"""
|
||||||
|
|
||||||
ALLOWED_OPS = {
|
|
||||||
"add", "sub", "mul", "truediv", "floordiv",
|
|
||||||
"mod", "pow", "neg", "abs"
|
|
||||||
}
|
|
||||||
|
|
||||||
def evaluate(self, expression: str) -> dict:
|
def evaluate(self, expression: str) -> dict:
|
||||||
"""
|
"""
|
||||||
Safely evaluate mathematical expression
|
Safely evaluate mathematical expression
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from datetime import date, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from flask import jsonify
|
from flask import jsonify
|
||||||
from backend import db
|
from backend import db
|
||||||
from backend.models import Conversation, Message, TokenUsage, User
|
from backend.models import Message, TokenUsage, User
|
||||||
|
|
||||||
|
|
||||||
def get_or_create_default_user() -> User:
|
def get_or_create_default_user() -> User:
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,6 @@
|
||||||
"""Workspace path validation utilities"""
|
"""Workspace path validation utilities"""
|
||||||
import os
|
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
from flask import current_app
|
|
||||||
|
|
||||||
from backend import load_config
|
from backend import load_config
|
||||||
|
|
||||||
|
|
@ -137,8 +134,6 @@ def delete_project_directory(project_path: str) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def save_uploaded_files(files, project_dir: Path) -> dict:
|
def save_uploaded_files(files, project_dir: Path) -> dict:
|
||||||
"""
|
"""
|
||||||
Save uploaded files to project directory (for folder upload)
|
Save uploaded files to project directory (for folder upload)
|
||||||
|
|
@ -187,45 +182,3 @@ def save_uploaded_files(files, project_dir: Path) -> dict:
|
||||||
"size": total_size
|
"size": total_size
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def copy_folder_to_project(source_path: str, project_dir: Path, project_name: str) -> dict:
|
|
||||||
"""
|
|
||||||
Copy a folder to project directory (for folder upload)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_path: Source folder path
|
|
||||||
project_dir: Target project directory
|
|
||||||
project_name: Project name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with copy statistics
|
|
||||||
"""
|
|
||||||
source = Path(source_path)
|
|
||||||
|
|
||||||
if not source.exists():
|
|
||||||
raise ValueError(f"Source path does not exist: {source_path}")
|
|
||||||
|
|
||||||
if not source.is_dir():
|
|
||||||
raise ValueError(f"Source path is not a directory: {source_path}")
|
|
||||||
|
|
||||||
# Security check: don't copy from sensitive system directories
|
|
||||||
sensitive_dirs = ["/etc", "/usr", "/bin", "/sbin", "/root", "/home"]
|
|
||||||
for sensitive in sensitive_dirs:
|
|
||||||
if str(source.resolve()).startswith(sensitive):
|
|
||||||
raise ValueError(f"Cannot copy from system directory: {sensitive}")
|
|
||||||
|
|
||||||
# Copy directory
|
|
||||||
if project_dir.exists():
|
|
||||||
shutil.rmtree(project_dir)
|
|
||||||
|
|
||||||
shutil.copytree(source, project_dir)
|
|
||||||
|
|
||||||
# Count files
|
|
||||||
file_count = sum(1 for _ in project_dir.rglob("*") if _.is_file())
|
|
||||||
dir_count = sum(1 for _ in project_dir.rglob("*") if _.is_dir())
|
|
||||||
|
|
||||||
return {
|
|
||||||
"files": file_count,
|
|
||||||
"directories": dir_count,
|
|
||||||
"size": sum(f.stat().st_size for f in project_dir.rglob("*") if f.is_file())
|
|
||||||
}
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue