From 4499c72ed8f444da695bbab38b42ad31b4bafd98 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 25 Mar 2026 10:15:40 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0sqlite=20=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/__init__.py | 40 ++++++++++++++++++++++++++++------- backend/models.py | 45 ++++++++++++++++++++++++++++++---------- backend/utils/helpers.py | 10 ++++----- 3 files changed, 72 insertions(+), 23 deletions(-) diff --git a/backend/__init__.py b/backend/__init__.py index 487ef59..50a2170 100644 --- a/backend/__init__.py +++ b/backend/__init__.py @@ -1,9 +1,10 @@ import os +from pathlib import Path + import yaml from flask import Flask -from flask_sqlalchemy import SQLAlchemy from flask_cors import CORS -from pathlib import Path +from flask_sqlalchemy import SQLAlchemy # Initialize db BEFORE importing models/routes that depend on it db = SQLAlchemy() @@ -15,15 +16,40 @@ def load_config(): return yaml.safe_load(f) +def _get_database_uri(cfg: dict) -> str: + """Build database URI based on database type.""" + db_type = cfg.get("db_type", "mysql").lower() + + if db_type == "sqlite": + # SQLite: sqlite:///path/to/database.db + db_file = cfg.get("db_sqlite_file", "app.db") + # Store in instance folder for better organization + instance_path = Path(__file__).parent.parent / "instance" + instance_path.mkdir(exist_ok=True) + db_path = instance_path / db_file + return f"sqlite:///{db_path}" + + elif db_type == "postgresql": + # PostgreSQL: postgresql://user:password@host:port/database + return ( + f"postgresql://{cfg['db_user']}:{cfg['db_password']}" + f"@{cfg.get('db_host', 'localhost')}:{cfg.get('db_port', 5432)}/{cfg['db_name']}" + ) + + else: # mysql (default) + # MySQL: mysql+pymysql://user:password@host:port/database?charset=utf8mb4 + return ( + f"mysql+pymysql://{cfg['db_user']}:{cfg['db_password']}" + f"@{cfg.get('db_host', 'localhost')}:{cfg.get('db_port', 3306)}/{cfg['db_name']}" + f"?charset=utf8mb4" + ) + + def create_app(): cfg = load_config() app = Flask(__name__) - app.config["SQLALCHEMY_DATABASE_URI"] = ( - f"mysql+pymysql://{cfg['db_user']}:{cfg['db_password']}" - f"@{cfg.get('db_host', 'localhost')}:{cfg.get('db_port', 3306)}/{cfg['db_name']}" - f"?charset=utf8mb4" - ) + app.config["SQLALCHEMY_DATABASE_URI"] = _get_database_uri(cfg) app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False # Enable CORS for all routes diff --git a/backend/models.py b/backend/models.py index f5ddb9f..c909200 100644 --- a/backend/models.py +++ b/backend/models.py @@ -1,12 +1,35 @@ -from datetime import datetime, timezone -from sqlalchemy.dialects.mysql import LONGTEXT from backend import db +from datetime import datetime, timezone +from flask import current_app +from sqlalchemy import Text +from sqlalchemy.dialects.mysql import LONGTEXT as MYSQL_LONGTEXT + + + +def get_longtext_type(): + """Get appropriate text type for long content based on database dialect.""" + db_uri = current_app.config.get("SQLALCHEMY_DATABASE_URI", "") + if db_uri.startswith("mysql"): + return MYSQL_LONGTEXT + return Text # SQLite and PostgreSQL use Text + + +# For model definitions, we'll use a callable that returns the right type +class LongText(db.TypeDecorator): + """Cross-database LONGTEXT type that works with MySQL, SQLite, and PostgreSQL.""" + impl = Text + cache_ok = True + + def load_dialect_impl(self, dialect): + if dialect.name == "mysql": + return dialect.type_descriptor(MYSQL_LONGTEXT) + return dialect.type_descriptor(Text) class User(db.Model): __tablename__ = "users" - id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) + id = db.Column(db.Integer, primary_key=True, autoincrement=True) username = db.Column(db.String(50), unique=True, nullable=False) password = db.Column(db.String(255), nullable=True) # Allow NULL for third-party login phone = db.Column(db.String(20)) @@ -20,7 +43,7 @@ class Conversation(db.Model): __tablename__ = "conversations" id = db.Column(db.String(64), primary_key=True) - user_id = db.Column(db.BigInteger, db.ForeignKey("users.id"), nullable=False, index=True) + user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False, index=True) title = db.Column(db.String(255), nullable=False, default="") model = db.Column(db.String(64), nullable=False, default="glm-5") system_prompt = db.Column(db.Text, default="") @@ -43,9 +66,9 @@ class Message(db.Model): conversation_id = db.Column(db.String(64), db.ForeignKey("conversations.id"), nullable=False, index=True) role = db.Column(db.String(16), nullable=False) # user, assistant, system, tool - content = db.Column(LONGTEXT, default="") # LONGTEXT for long conversations + content = db.Column(LongText, default="") # LongText for long conversations token_count = db.Column(db.Integer, default=0) - thinking_content = db.Column(LONGTEXT, default="") + thinking_content = db.Column(LongText, default="") created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc), index=True) # Tool call support - relation to ToolCall table @@ -58,14 +81,14 @@ class ToolCall(db.Model): """Tool call record - separate table, follows database normalization""" __tablename__ = "tool_calls" - id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) + id = db.Column(db.Integer, primary_key=True, autoincrement=True) message_id = db.Column(db.String(64), db.ForeignKey("messages.id"), nullable=False, index=True) call_id = db.Column(db.String(64), nullable=False) # Tool call ID call_index = db.Column(db.Integer, nullable=False, default=0) # Call order tool_name = db.Column(db.String(64), nullable=False) # Tool name - arguments = db.Column(LONGTEXT, nullable=False) # Call arguments JSON - result = db.Column(LONGTEXT) # Execution result JSON + arguments = db.Column(LongText, nullable=False) # Call arguments JSON + result = db.Column(LongText) # Execution result JSON execution_time = db.Column(db.Float, default=0) # Execution time (seconds) created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc)) @@ -77,8 +100,8 @@ class ToolCall(db.Model): class TokenUsage(db.Model): __tablename__ = "token_usage" - id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) - user_id = db.Column(db.BigInteger, db.ForeignKey("users.id"), + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False, index=True) date = db.Column(db.Date, nullable=False, index=True) model = db.Column(db.String(64), nullable=False) diff --git a/backend/utils/helpers.py b/backend/utils/helpers.py index 970de26..37e9073 100644 --- a/backend/utils/helpers.py +++ b/backend/utils/helpers.py @@ -1,11 +1,13 @@ """Common helper functions""" import json -from datetime import datetime, date +from datetime import date, datetime +from typing import Any +from flask import jsonify from backend import db -from backend.models import Conversation, Message, User, TokenUsage +from backend.models import Conversation, Message, TokenUsage, User -def get_or_create_default_user(): +def get_or_create_default_user() -> User: """Get or create default user""" user = User.query.filter_by(username="default").first() if not user: @@ -22,13 +24,11 @@ def ok(data=None, message=None): body["data"] = data if message is not None: body["message"] = message - from flask import jsonify return jsonify(body) def err(code, message): """Error response helper""" - from flask import jsonify return jsonify({"code": code, "message": message}), code