Luxx/alcor/routes/messages.py

239 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""消息路由"""
import json
import logging
from datetime import datetime
from typing import Optional, List

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session

from alcor.database import get_db
from alcor.models import Conversation, Message, User
from alcor.routes.auth import get_current_user
from alcor.services.chat import chat_service
from alcor.utils.helpers import generate_id, success_response, error_response
# Router shared by every message endpoint below; all paths are mounted under /messages.
router = APIRouter(tags=["消息"], prefix="/messages")
class MessageCreate(BaseModel):
    """Request body for sending a message to a conversation."""
    # Target conversation; the handlers in this module only accept it when it
    # belongs to the authenticated user.
    conversation_id: str
    # The user's message text; handlers store it as json.dumps({"text": content}).
    content: str
    # Passed through to chat_service as ``tools_enabled``.
    tools_enabled: bool = True
class MessageResponse(BaseModel):
    """Response schema describing a single stored message.

    NOTE(review): the handlers in this module do not reference this model;
    they return ``Message.to_dict()`` output instead. Confirm whether it is
    used elsewhere or can be removed.
    """
    id: str
    conversation_id: str
    # "user" or "assistant" for the messages created in this module.
    role: str
    # For messages created here this is a JSON-encoded payload such as
    # {"text": ...}; presumably to_dict() passes it through unchanged -- verify.
    content: str
    # For assistant messages created here this is a rough estimate, len(text) // 4.
    token_count: int
    # Timestamp rendered as a string; the format is not visible here
    # (presumably ISO 8601 -- confirm against Message.to_dict()).
    created_at: str
@router.get("/{conversation_id}", response_model=dict)
def list_messages(
    conversation_id: str,
    limit: int = 100,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return the newest ``limit`` messages of a conversation in chronological order.

    Args:
        conversation_id: Conversation whose messages are listed.
        limit: Maximum number of (most recent) messages to return.
        current_user: Authenticated caller (injected).
        db: Database session (injected).

    Returns:
        ``success_response`` whose data holds ``items`` (each message's
        ``to_dict()``, oldest first) and ``total``. ``total`` is the number of
        items returned, not the conversation's overall message count. If the
        caller does not own the conversation, a 404 ``error_response`` is
        returned instead.
    """
    owned_conversation = (
        db.query(Conversation)
        .filter(
            Conversation.id == conversation_id,
            Conversation.user_id == current_user.id,
        )
        .first()
    )
    if not owned_conversation:
        return error_response("会话不存在", 404)

    # Sort newest-first so LIMIT keeps the most recent rows...
    newest_first = (
        db.query(Message)
        .filter(Message.conversation_id == conversation_id)
        .order_by(Message.created_at.desc())
        .limit(limit)
        .all()
    )
    # ...then flip the page back into chronological order for the client.
    items = [row.to_dict() for row in newest_first[::-1]]

    payload = {"items": items, "total": len(items)}
    return success_response(data=payload)
@router.post("/", response_model=dict)
async def create_message(
    data: MessageCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Send a message and return the complete (non-streamed) assistant reply.

    Checks that the conversation belongs to the current user and stores the
    user's message. It then asks ``chat_service.non_stream_response`` for a
    reply and stores that reply as an assistant message.

    Args:
        data: Target conversation id, message text and tool toggle.
        current_user: Authenticated caller (injected).
        db: Database session (injected).

    Returns:
        ``success_response`` carrying ``user_message`` and ``assistant_message``
        dicts. A 404 ``error_response`` is returned if the caller does not own
        the conversation, and a 500 ``error_response`` if the chat service
        reports failure. The user's message is committed before generation,
        so it stays stored even when generation fails.
    """
    # Only the owner of the conversation may post into it.
    conversation = db.query(Conversation).filter(
        Conversation.id == data.conversation_id,
        Conversation.user_id == current_user.id
    ).first()
    if not conversation:
        return error_response("会话不存在", 404)

    # Persist the user's message and bump the conversation's activity time.
    user_message = Message(
        id=generate_id("msg"),
        conversation_id=data.conversation_id,
        role="user",
        content=json.dumps({"text": data.content})
    )
    db.add(user_message)
    conversation.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(user_message)

    # NOTE(review): this handler is ``async def``, but the ORM calls and
    # ``chat_service.non_stream_response`` are synchronous. The event loop
    # is therefore blocked while the reply is generated.
    response = chat_service.non_stream_response(
        conversation=conversation,
        user_message=data.content,
        tools_enabled=data.tools_enabled
    )
    if not response.get("success"):
        # ``or`` also covers an explicit ``"error": None`` in the result.
        return error_response(response.get("error") or "生成响应失败", 500)

    # ``or ""`` also covers ``"content": None``, which would crash ``len()`` below.
    ai_content = response.get("content") or ""
    ai_message = Message(
        id=generate_id("msg"),
        conversation_id=data.conversation_id,
        role="assistant",
        content=json.dumps({
            "text": ai_content,
            "tool_calls": response.get("tool_calls")
        }),
        token_count=len(ai_content) // 4  # rough estimate: ~4 characters per token
    )
    db.add(ai_message)
    db.commit()

    return success_response(data={
        "user_message": user_message.to_dict(),
        "assistant_message": ai_message.to_dict()
    })
@router.post("/stream")
async def stream_message(
    data: MessageCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Send a message and stream the assistant's reply as Server-Sent Events.

    Checks that the conversation belongs to the current user, stores the
    user's message, and then relays events from ``chat_service.stream_response``.
    Every frame has the form ``data: <json>``, where the JSON ``type`` is one
    of ``text``, ``tool_call``, ``tool_result``, ``done`` or ``error``. The
    stream always ends with a literal ``data: [DONE]`` frame.

    Args:
        data: Target conversation id, message text and tool toggle.
        current_user: Authenticated caller (injected).
        db: Database session (injected).

    Returns:
        A ``text/event-stream`` ``StreamingResponse``, or a 404
        ``error_response`` if the caller does not own the conversation.
    """
    # Only the owner of the conversation may post into it.
    conversation = db.query(Conversation).filter(
        Conversation.id == data.conversation_id,
        Conversation.user_id == current_user.id
    ).first()
    if not conversation:
        return error_response("会话不存在", 404)

    # Persist the user's message (and bump the conversation's activity time)
    # before any streaming starts.
    user_message = Message(
        id=generate_id("msg"),
        conversation_id=data.conversation_id,
        role="user",
        content=json.dumps({"text": data.content})
    )
    db.add(user_message)
    conversation.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(user_message)

    def _sse(payload):
        """Encode one JSON payload as a single SSE ``data:`` frame."""
        return f"data: {json.dumps(payload)}\n\n"

    async def event_generator():
        """Relay chat-service events as SSE frames and persist the final reply.

        NOTE(review): this generator runs after ``stream_message`` has
        returned. It therefore relies on ``db`` and ``conversation`` still
        being usable while the response streams. Confirm this against how
        ``get_db`` tears down its session in the FastAPI version in use.
        """
        full_response = ""
        # Generated up front so the id reported in the ``done`` frame is the
        # id of the stored row.
        message_id = generate_id("msg")
        try:
            async for event in chat_service.stream_response(
                conversation=conversation,
                user_message=data.content,
                tools_enabled=data.tools_enabled
            ):
                event_type = event.get("type")
                if event_type == "process_step":
                    step_type = event.get("step_type")
                    if step_type == "text":
                        # ``or ""`` guards against an explicit None chunk,
                        # which would break the string concatenation.
                        content = event.get("content") or ""
                        full_response += content
                        yield _sse({'type': 'text', 'content': content})
                    elif step_type == "tool_call":
                        yield _sse({'type': 'tool_call', 'tool_calls': event.get('tool_calls')})
                    elif step_type == "tool_result":
                        yield _sse({'type': 'tool_result', 'result': event.get('result')})
                elif event_type == "done":
                    # NOTE(review): unlike create_message, tool calls are not
                    # persisted with the streamed reply; only the text is.
                    try:
                        db.add(Message(
                            id=message_id,
                            conversation_id=data.conversation_id,
                            role="assistant",
                            content=json.dumps({"text": full_response}),
                            token_count=len(full_response) // 4  # rough estimate
                        ))
                        db.commit()
                    except Exception:
                        # Never report a message_id for a row that was not
                        # stored; log it and tell the client the save failed.
                        db.rollback()
                        logging.getLogger(__name__).exception(
                            "Failed to save assistant message %s", message_id
                        )
                        yield _sse({'type': 'error', 'error': '保存消息失败'})
                    else:
                        yield _sse({'type': 'done', 'message_id': message_id})
                elif event_type == "error":
                    yield _sse({'type': 'error', 'error': event.get('error')})
        except Exception:
            # Streaming boundary: log the failure and end the stream with an
            # error frame instead of dropping the connection mid-response.
            logging.getLogger(__name__).exception(
                "Streaming reply failed for conversation %s", data.conversation_id
            )
            yield _sse({'type': 'error', 'error': '生成响应失败'})
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Ask reverse proxies (e.g. nginx) not to buffer the event stream.
            "X-Accel-Buffering": "no"
        }
    )
@router.delete("/{message_id}", response_model=dict)
def delete_message(
    message_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Delete one message, provided its conversation belongs to the caller.

    Args:
        message_id: Id of the message to delete.
        current_user: Authenticated caller (injected).
        db: Database session (injected).

    Returns:
        ``success_response`` on deletion. If no message with that id exists
        in one of the caller's conversations, a 404 ``error_response`` is
        returned instead.
    """
    # Joining through Conversation means a message is only found when its
    # conversation is owned by the current user.
    target = (
        db.query(Message)
        .join(Conversation)
        .filter(Message.id == message_id, Conversation.user_id == current_user.id)
        .first()
    )
    if target:
        db.delete(target)
        db.commit()
        return success_response(message="消息删除成功")
    return error_response("消息不存在", 404)