feat: 增加sqlite 数据库支持
This commit is contained in:
parent
8a23b1cd00
commit
4499c72ed8
|
|
@ -1,9 +1,10 @@
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
from flask_sqlalchemy import SQLAlchemy
|
|
||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
from pathlib import Path
|
from flask_sqlalchemy import SQLAlchemy
|
||||||
|
|
||||||
# Initialize db BEFORE importing models/routes that depend on it
|
# Initialize db BEFORE importing models/routes that depend on it
|
||||||
db = SQLAlchemy()
|
db = SQLAlchemy()
|
||||||
|
|
@ -15,15 +16,40 @@ def load_config():
|
||||||
return yaml.safe_load(f)
|
return yaml.safe_load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_database_uri(cfg: dict) -> str:
|
||||||
|
"""Build database URI based on database type."""
|
||||||
|
db_type = cfg.get("db_type", "mysql").lower()
|
||||||
|
|
||||||
|
if db_type == "sqlite":
|
||||||
|
# SQLite: sqlite:///path/to/database.db
|
||||||
|
db_file = cfg.get("db_sqlite_file", "app.db")
|
||||||
|
# Store in instance folder for better organization
|
||||||
|
instance_path = Path(__file__).parent.parent / "instance"
|
||||||
|
instance_path.mkdir(exist_ok=True)
|
||||||
|
db_path = instance_path / db_file
|
||||||
|
return f"sqlite:///{db_path}"
|
||||||
|
|
||||||
|
elif db_type == "postgresql":
|
||||||
|
# PostgreSQL: postgresql://user:password@host:port/database
|
||||||
|
return (
|
||||||
|
f"postgresql://{cfg['db_user']}:{cfg['db_password']}"
|
||||||
|
f"@{cfg.get('db_host', 'localhost')}:{cfg.get('db_port', 5432)}/{cfg['db_name']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
else: # mysql (default)
|
||||||
|
# MySQL: mysql+pymysql://user:password@host:port/database?charset=utf8mb4
|
||||||
|
return (
|
||||||
|
f"mysql+pymysql://{cfg['db_user']}:{cfg['db_password']}"
|
||||||
|
f"@{cfg.get('db_host', 'localhost')}:{cfg.get('db_port', 3306)}/{cfg['db_name']}"
|
||||||
|
f"?charset=utf8mb4"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_app():
|
def create_app():
|
||||||
cfg = load_config()
|
cfg = load_config()
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
app.config["SQLALCHEMY_DATABASE_URI"] = (
|
app.config["SQLALCHEMY_DATABASE_URI"] = _get_database_uri(cfg)
|
||||||
f"mysql+pymysql://{cfg['db_user']}:{cfg['db_password']}"
|
|
||||||
f"@{cfg.get('db_host', 'localhost')}:{cfg.get('db_port', 3306)}/{cfg['db_name']}"
|
|
||||||
f"?charset=utf8mb4"
|
|
||||||
)
|
|
||||||
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
|
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
|
||||||
|
|
||||||
# Enable CORS for all routes
|
# Enable CORS for all routes
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,35 @@
|
||||||
from datetime import datetime, timezone
|
|
||||||
from sqlalchemy.dialects.mysql import LONGTEXT
|
|
||||||
from backend import db
|
from backend import db
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from flask import current_app
|
||||||
|
from sqlalchemy import Text
|
||||||
|
from sqlalchemy.dialects.mysql import LONGTEXT as MYSQL_LONGTEXT
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_longtext_type():
|
||||||
|
"""Get appropriate text type for long content based on database dialect."""
|
||||||
|
db_uri = current_app.config.get("SQLALCHEMY_DATABASE_URI", "")
|
||||||
|
if db_uri.startswith("mysql"):
|
||||||
|
return MYSQL_LONGTEXT
|
||||||
|
return Text # SQLite and PostgreSQL use Text
|
||||||
|
|
||||||
|
|
||||||
|
# For model definitions, we'll use a callable that returns the right type
|
||||||
|
class LongText(db.TypeDecorator):
|
||||||
|
"""Cross-database LONGTEXT type that works with MySQL, SQLite, and PostgreSQL."""
|
||||||
|
impl = Text
|
||||||
|
cache_ok = True
|
||||||
|
|
||||||
|
def load_dialect_impl(self, dialect):
|
||||||
|
if dialect.name == "mysql":
|
||||||
|
return dialect.type_descriptor(MYSQL_LONGTEXT)
|
||||||
|
return dialect.type_descriptor(Text)
|
||||||
|
|
||||||
|
|
||||||
class User(db.Model):
|
class User(db.Model):
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
id = db.Column(db.BigInteger, primary_key=True, autoincrement=True)
|
id = db.Column(db.Integer, primary_key=True, autoincrement=True)
|
||||||
username = db.Column(db.String(50), unique=True, nullable=False)
|
username = db.Column(db.String(50), unique=True, nullable=False)
|
||||||
password = db.Column(db.String(255), nullable=True) # Allow NULL for third-party login
|
password = db.Column(db.String(255), nullable=True) # Allow NULL for third-party login
|
||||||
phone = db.Column(db.String(20))
|
phone = db.Column(db.String(20))
|
||||||
|
|
@ -20,7 +43,7 @@ class Conversation(db.Model):
|
||||||
__tablename__ = "conversations"
|
__tablename__ = "conversations"
|
||||||
|
|
||||||
id = db.Column(db.String(64), primary_key=True)
|
id = db.Column(db.String(64), primary_key=True)
|
||||||
user_id = db.Column(db.BigInteger, db.ForeignKey("users.id"), nullable=False, index=True)
|
user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False, index=True)
|
||||||
title = db.Column(db.String(255), nullable=False, default="")
|
title = db.Column(db.String(255), nullable=False, default="")
|
||||||
model = db.Column(db.String(64), nullable=False, default="glm-5")
|
model = db.Column(db.String(64), nullable=False, default="glm-5")
|
||||||
system_prompt = db.Column(db.Text, default="")
|
system_prompt = db.Column(db.Text, default="")
|
||||||
|
|
@ -43,9 +66,9 @@ class Message(db.Model):
|
||||||
conversation_id = db.Column(db.String(64), db.ForeignKey("conversations.id"),
|
conversation_id = db.Column(db.String(64), db.ForeignKey("conversations.id"),
|
||||||
nullable=False, index=True)
|
nullable=False, index=True)
|
||||||
role = db.Column(db.String(16), nullable=False) # user, assistant, system, tool
|
role = db.Column(db.String(16), nullable=False) # user, assistant, system, tool
|
||||||
content = db.Column(LONGTEXT, default="") # LONGTEXT for long conversations
|
content = db.Column(LongText, default="") # LongText for long conversations
|
||||||
token_count = db.Column(db.Integer, default=0)
|
token_count = db.Column(db.Integer, default=0)
|
||||||
thinking_content = db.Column(LONGTEXT, default="")
|
thinking_content = db.Column(LongText, default="")
|
||||||
created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc), index=True)
|
created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc), index=True)
|
||||||
|
|
||||||
# Tool call support - relation to ToolCall table
|
# Tool call support - relation to ToolCall table
|
||||||
|
|
@ -58,14 +81,14 @@ class ToolCall(db.Model):
|
||||||
"""Tool call record - separate table, follows database normalization"""
|
"""Tool call record - separate table, follows database normalization"""
|
||||||
__tablename__ = "tool_calls"
|
__tablename__ = "tool_calls"
|
||||||
|
|
||||||
id = db.Column(db.BigInteger, primary_key=True, autoincrement=True)
|
id = db.Column(db.Integer, primary_key=True, autoincrement=True)
|
||||||
message_id = db.Column(db.String(64), db.ForeignKey("messages.id"),
|
message_id = db.Column(db.String(64), db.ForeignKey("messages.id"),
|
||||||
nullable=False, index=True)
|
nullable=False, index=True)
|
||||||
call_id = db.Column(db.String(64), nullable=False) # Tool call ID
|
call_id = db.Column(db.String(64), nullable=False) # Tool call ID
|
||||||
call_index = db.Column(db.Integer, nullable=False, default=0) # Call order
|
call_index = db.Column(db.Integer, nullable=False, default=0) # Call order
|
||||||
tool_name = db.Column(db.String(64), nullable=False) # Tool name
|
tool_name = db.Column(db.String(64), nullable=False) # Tool name
|
||||||
arguments = db.Column(LONGTEXT, nullable=False) # Call arguments JSON
|
arguments = db.Column(LongText, nullable=False) # Call arguments JSON
|
||||||
result = db.Column(LONGTEXT) # Execution result JSON
|
result = db.Column(LongText) # Execution result JSON
|
||||||
execution_time = db.Column(db.Float, default=0) # Execution time (seconds)
|
execution_time = db.Column(db.Float, default=0) # Execution time (seconds)
|
||||||
created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc))
|
created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
|
@ -77,8 +100,8 @@ class ToolCall(db.Model):
|
||||||
class TokenUsage(db.Model):
|
class TokenUsage(db.Model):
|
||||||
__tablename__ = "token_usage"
|
__tablename__ = "token_usage"
|
||||||
|
|
||||||
id = db.Column(db.BigInteger, primary_key=True, autoincrement=True)
|
id = db.Column(db.Integer, primary_key=True, autoincrement=True)
|
||||||
user_id = db.Column(db.BigInteger, db.ForeignKey("users.id"),
|
user_id = db.Column(db.Integer, db.ForeignKey("users.id"),
|
||||||
nullable=False, index=True)
|
nullable=False, index=True)
|
||||||
date = db.Column(db.Date, nullable=False, index=True)
|
date = db.Column(db.Date, nullable=False, index=True)
|
||||||
model = db.Column(db.String(64), nullable=False)
|
model = db.Column(db.String(64), nullable=False)
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
"""Common helper functions"""
|
"""Common helper functions"""
|
||||||
import json
|
import json
|
||||||
from datetime import datetime, date
|
from datetime import date, datetime
|
||||||
|
from typing import Any
|
||||||
|
from flask import jsonify
|
||||||
from backend import db
|
from backend import db
|
||||||
from backend.models import Conversation, Message, User, TokenUsage
|
from backend.models import Conversation, Message, TokenUsage, User
|
||||||
|
|
||||||
|
|
||||||
def get_or_create_default_user():
|
def get_or_create_default_user() -> User:
|
||||||
"""Get or create default user"""
|
"""Get or create default user"""
|
||||||
user = User.query.filter_by(username="default").first()
|
user = User.query.filter_by(username="default").first()
|
||||||
if not user:
|
if not user:
|
||||||
|
|
@ -22,13 +24,11 @@ def ok(data=None, message=None):
|
||||||
body["data"] = data
|
body["data"] = data
|
||||||
if message is not None:
|
if message is not None:
|
||||||
body["message"] = message
|
body["message"] = message
|
||||||
from flask import jsonify
|
|
||||||
return jsonify(body)
|
return jsonify(body)
|
||||||
|
|
||||||
|
|
||||||
def err(code, message):
|
def err(code, message):
|
||||||
"""Error response helper"""
|
"""Error response helper"""
|
||||||
from flask import jsonify
|
|
||||||
return jsonify({"code": code, "message": message}), code
|
return jsonify({"code": code, "message": message}), code
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue