239 lines
7.0 KiB
Python
239 lines
7.0 KiB
Python
"""消息路由"""
|
||
import json
|
||
from typing import Optional, List
|
||
from fastapi import APIRouter, Depends, HTTPException
|
||
from fastapi.responses import StreamingResponse
|
||
from pydantic import BaseModel
|
||
from sqlalchemy.orm import Session
|
||
|
||
from alcor.database import get_db
|
||
from alcor.models import Conversation, Message, User
|
||
from alcor.routes.auth import get_current_user
|
||
from alcor.services.chat import chat_service
|
||
from alcor.utils.helpers import generate_id, success_response, error_response
|
||
|
||
|
||
# All endpoints in this module are mounted under the "/messages" prefix and
# grouped under the "消息" tag in the generated OpenAPI docs.
router = APIRouter(tags=["消息"], prefix="/messages")
|
||
|
||
|
||
class MessageCreate(BaseModel):
    """Request body for posting a user message into an existing conversation."""

    # Target conversation; each endpoint verifies it belongs to the caller.
    conversation_id: str
    # Raw user text; persisted as the JSON document {"text": content}.
    content: str
    # Passed through to chat_service as ``tools_enabled``; enabled by default.
    tools_enabled: bool = True
|
||
|
||
|
||
class MessageResponse(BaseModel):
    """Response schema describing one stored message."""

    id: str
    conversation_id: str
    # Sender role; the endpoints in this module write "user" and "assistant".
    role: str
    # Stored message body (NOTE(review): the endpoints store a JSON string here
    # -- confirm whether this schema expects the raw or decoded form).
    content: str
    # Approximate token count.
    token_count: int
    # Creation time rendered as a string (format decided by the serializer --
    # TODO confirm).
    created_at: str
|
||
|
||
|
||
@router.get("/{conversation_id}", response_model=dict)
def list_messages(
    conversation_id: str,
    limit: int = 100,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return the newest ``limit`` messages of a conversation in chronological order.

    The conversation must belong to the current user; otherwise the result is
    ``error_response("会话不存在", 404)``. The ``total`` field counts the
    messages in this page, not every message in the conversation.
    """
    owned = (
        db.query(Conversation)
        .filter(
            Conversation.id == conversation_id,
            Conversation.user_id == current_user.id,
        )
        .first()
    )
    if not owned:
        return error_response("会话不存在", 404)

    # Query newest-first so LIMIT keeps the most recent messages, then flip
    # the page back to oldest-first for the client.
    newest_first = (
        db.query(Message)
        .filter(Message.conversation_id == conversation_id)
        .order_by(Message.created_at.desc())
        .limit(limit)
        .all()
    )
    items = [message.to_dict() for message in newest_first[::-1]]

    payload = {"items": items, "total": len(items)}
    return success_response(data=payload)
|
||
|
||
|
||
@router.post("/", response_model=dict)
async def create_message(
    data: MessageCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Send a user message and return the assistant's complete reply (non-streaming).

    Flow:
      1. Look up the conversation, restricted to the current user; otherwise
         return ``error_response("会话不存在", 404)``.
      2. Store the user message as JSON ``{"text": ...}``, update
         ``conversation.updated_at`` and commit.
      3. Ask ``chat_service.non_stream_response`` for a reply. If the result's
         ``success`` flag is falsy, return ``error_response(<error>, 500)``;
         the already-committed user message is kept.
      4. Store the assistant reply (text plus ``tool_calls``) and return both
         messages serialized via ``to_dict()``.
    """
    # Function-scope imports, matching this module's existing local import
    # of ``datetime``.
    from datetime import datetime

    from fastapi.concurrency import run_in_threadpool

    conversation = db.query(Conversation).filter(
        Conversation.id == data.conversation_id,
        Conversation.user_id == current_user.id
    ).first()

    if not conversation:
        return error_response("会话不存在", 404)

    # Persist the user's message before generating a reply.
    user_message = Message(
        id=generate_id("msg"),
        conversation_id=data.conversation_id,
        role="user",
        content=json.dumps({"text": data.content})
    )
    db.add(user_message)

    # Record activity time on the conversation.
    conversation.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(user_message)

    # non_stream_response is a plain synchronous call: its result is used
    # directly and never awaited. Invoked inline from this ``async def``, it
    # would block the event loop, stalling every other request, for as long
    # as reply generation takes. Run it in the threadpool instead. ``db`` is
    # not used concurrently, because this coroutine waits for the call to
    # return before touching the session again.
    response = await run_in_threadpool(
        chat_service.non_stream_response,
        conversation=conversation,
        user_message=data.content,
        tools_enabled=data.tools_enabled,
    )

    if not response.get("success"):
        return error_response(response.get("error", "生成响应失败"), 500)

    # Persist the assistant reply.
    ai_content = response.get("content", "")
    ai_message = Message(
        id=generate_id("msg"),
        conversation_id=data.conversation_id,
        role="assistant",
        content=json.dumps({
            "text": ai_content,
            "tool_calls": response.get("tool_calls")
        }),
        token_count=len(ai_content) // 4  # rough estimate: ~4 characters per token
    )
    db.add(ai_message)
    db.commit()

    return success_response(data={
        "user_message": user_message.to_dict(),
        "assistant_message": ai_message.to_dict()
    })
|
||
|
||
|
||
@router.post("/stream")
async def stream_message(
    data: MessageCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Send a user message and stream the assistant's reply as Server-Sent Events.

    The conversation is checked first: if it is not the current user's, the
    result is ``error_response("会话不存在", 404)``. The user message is
    committed before streaming starts. Each SSE frame carries a JSON object
    whose ``type`` is one of:

    * ``text``: an incremental chunk of reply text (``content``);
    * ``tool_call`` / ``tool_result``: tool activity relayed from the service;
    * ``done``: the full reply was saved; ``message_id`` is its ID;
    * ``error``: the service reported an error, the stream raised, or the
      final reply could not be saved.

    The stream always ends with a literal ``data: [DONE]`` frame.
    """
    import logging
    from datetime import datetime

    logger = logging.getLogger(__name__)

    conversation = db.query(Conversation).filter(
        Conversation.id == data.conversation_id,
        Conversation.user_id == current_user.id
    ).first()

    if not conversation:
        return error_response("会话不存在", 404)

    # Persist the user's message before streaming the reply.
    user_message = Message(
        id=generate_id("msg"),
        conversation_id=data.conversation_id,
        role="user",
        content=json.dumps({"text": data.content})
    )
    db.add(user_message)

    # Record activity time on the conversation.
    conversation.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(user_message)

    # NOTE(review): the generator below keeps using ``db`` and ``conversation``
    # after this function has returned. Whether the request-scoped session from
    # ``get_db`` is still open at that point depends on how the installed
    # FastAPI version tears down yield-dependencies around StreamingResponse.
    # Confirm for the deployed version.

    async def event_generator():
        """Relay chat-service events as SSE frames; persist the reply on ``done``."""
        full_response = ""
        message_id = generate_id("msg")

        try:
            async for event in chat_service.stream_response(
                conversation=conversation,
                user_message=data.content,
                tools_enabled=data.tools_enabled
            ):
                event_type = event.get("type")

                if event_type == "process_step":
                    step_type = event.get("step_type")

                    if step_type == "text":
                        content = event.get("content", "")
                        # Accumulate text so the complete reply can be saved on "done".
                        full_response += content
                        yield f"data: {json.dumps({'type': 'text', 'content': content})}\n\n"

                    elif step_type == "tool_call":
                        yield f"data: {json.dumps({'type': 'tool_call', 'tool_calls': event.get('tool_calls')})}\n\n"

                    elif step_type == "tool_result":
                        yield f"data: {json.dumps({'type': 'tool_result', 'result': event.get('result')})}\n\n"

                elif event_type == "done":
                    # NOTE(review): unlike create_message, tool calls are not
                    # stored with the streamed reply. Confirm this is intended.
                    try:
                        ai_message = Message(
                            id=message_id,
                            conversation_id=data.conversation_id,
                            role="assistant",
                            content=json.dumps({"text": full_response}),
                            token_count=len(full_response) // 4  # rough estimate
                        )
                        db.add(ai_message)
                        db.commit()
                    except Exception:
                        db.rollback()
                        logger.exception("Failed to save assistant message %s", message_id)
                        # Never hand the client an ID for a row that was not saved.
                        yield f"data: {json.dumps({'type': 'error', 'error': '保存消息失败'})}\n\n"
                    else:
                        yield f"data: {json.dumps({'type': 'done', 'message_id': message_id})}\n\n"

                elif event_type == "error":
                    yield f"data: {json.dumps({'type': 'error', 'error': event.get('error')})}\n\n"
        except Exception:
            # Once streaming has begun, the HTTP status can no longer change.
            # Report the failure in-band so the client is not left with a
            # silently truncated stream.
            logger.exception("Chat stream failed for conversation %s", data.conversation_id)
            yield f"data: {json.dumps({'type': 'error', 'error': '生成响应失败'})}\n\n"

        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disable proxy response buffering so frames reach the client promptly.
            "X-Accel-Buffering": "no"
        }
    )
|
||
|
||
|
||
@router.delete("/{message_id}", response_model=dict)
def delete_message(
    message_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Delete a single message if its conversation belongs to the current user.

    A message in someone else's conversation is treated exactly like a missing
    one: both yield ``error_response("消息不存在", 404)``.
    """
    owned_message = (
        db.query(Message)
        .join(Conversation)
        .filter(
            Message.id == message_id,
            Conversation.user_id == current_user.id,
        )
        .first()
    )

    if owned_message:
        db.delete(owned_message)
        db.commit()
        return success_response(message="消息删除成功")

    return error_response("消息不存在", 404)
|