179 lines
5.0 KiB
Python
179 lines
5.0 KiB
Python
"""Message routes"""
|
|
import json
|
|
from fastapi import APIRouter, Depends, Response
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel
|
|
from sqlalchemy.orm import Session
|
|
from datetime import datetime
|
|
|
|
from luxx.database import get_db
|
|
from luxx.models import Conversation, Message, User
|
|
from luxx.routes.auth import get_current_user
|
|
from luxx.services.chat import chat_service
|
|
from luxx.utils.helpers import generate_id, success_response, error_response
|
|
|
|
|
|
# Router for the message endpoints; every route below is served under /messages.
router = APIRouter(tags=["Messages"], prefix="/messages")
|
|
|
|
|
|
class MessageCreate(BaseModel):
    """Request body used to post a new message into a conversation."""

    # Identifier of the conversation the message is posted to.
    conversation_id: str
    # Text of the user's message.
    content: str
|
|
|
|
|
|
class MessageResponse(BaseModel):
    """Serialized representation of a single stored message."""

    # Message identifier.
    id: str
    # Author role; routes in this module write "user" or "assistant".
    role: str
    # Message text.
    content: str
    # Approximate token count as stored on the message.
    token_count: int
|
|
|
|
|
|
@router.get("/", response_model=dict)
def list_messages(
    conversation_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return every message of one of the current user's conversations.

    Args:
        conversation_id: Conversation whose messages are listed.
        current_user: Authenticated user; the conversation must belong to them.
        db: Database session.

    Returns:
        ``success_response`` whose data holds the messages ordered by
        ``created_at``, the conversation title, and a preview of the first
        user message (at most 50 characters, with "..." appended when
        truncated, or None when there is no user message). Returns
        ``error_response`` with 404 when the conversation is not found for
        this user.
    """
    owned = (
        db.query(Conversation)
        .filter(
            Conversation.id == conversation_id,
            Conversation.user_id == current_user.id,
        )
        .first()
    )
    if not owned:
        return error_response("Conversation not found", 404)

    history = (
        db.query(Message)
        .filter(Message.conversation_id == conversation_id)
        .order_by(Message.created_at)
        .all()
    )
    serialized = [msg.to_dict() for msg in history]

    # Preview of the earliest user-authored message, truncated to 50 chars.
    first_user = next((msg for msg in history if msg.role == 'user'), None)
    if first_user is None:
        preview = None
    else:
        text = first_user.content
        preview = text if len(text) <= 50 else text[:50] + '...'

    return success_response(data={
        "messages": serialized,
        "title": owned.title,
        "first_message": preview,
    })
|
|
|
|
|
|
@router.post("/", response_model=dict)
def send_message(
    data: MessageCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Send a message and return the assistant's complete reply (non-streaming).

    Stores the user's message, asks ``chat_service`` for a reply with tools
    disabled, stores the reply, and commits both together.

    Args:
        data: Target conversation id and the user's message text.
        current_user: Authenticated user; the conversation must belong to them.
        db: Database session.

    Returns:
        ``success_response`` with ``user_message`` and ``assistant_message``
        dicts on success. Returns ``error_response`` with 404 when the
        conversation is not found for this user, or with 500 (carrying the
        service's error text when provided) when generation fails.
    """
    conversation = db.query(Conversation).filter(
        Conversation.id == data.conversation_id,
        Conversation.user_id == current_user.id,
    ).first()
    if not conversation:
        return error_response("Conversation not found", 404)

    user_message = Message(
        id=generate_id("msg"),
        conversation_id=data.conversation_id,
        role="user",
        content=data.content,
        token_count=len(data.content) // 4,  # rough estimate: ~4 chars per token
    )
    db.add(user_message)
    # Same naive timestamp as stream_message, so values written by both
    # endpoints to this column are comparable (no aware/naive mixing).
    # NOTE(review): confirm this matches the time basis used by the models.
    conversation.updated_at = datetime.now()

    response = chat_service.non_stream_response(
        conversation=conversation,
        user_message=data.content,
        tools_enabled=False,
    )
    if not response.get("success"):
        # Nothing has been committed yet; discard the pending user message and
        # timestamp explicitly rather than leaving them staged in the session.
        db.rollback()
        return error_response(
            response.get("error", "Failed to generate response"), 500
        )

    # `or ""` also covers a service response whose content is explicitly None.
    ai_content = response.get("content") or ""

    ai_message = Message(
        id=generate_id("msg"),
        conversation_id=data.conversation_id,
        role="assistant",
        content=ai_content,
        token_count=len(ai_content) // 4,  # same rough estimate as above
    )
    db.add(ai_message)
    db.commit()

    return success_response(data={
        "user_message": user_message.to_dict(),
        "assistant_message": ai_message.to_dict(),
    })
|
|
|
|
|
|
@router.post("/stream")
def stream_message(
    data: MessageCreate,
    tools_enabled: bool = True,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Send a message and stream the assistant's reply as Server-Sent Events.

    Declared as a plain ``def`` (not ``async def``) because every database
    call here is synchronous SQLAlchemy; FastAPI therefore runs this handler in
    its threadpool instead of blocking the event loop. The returned async
    generator is still consumed by ``StreamingResponse`` on the event loop.

    Args:
        data: Target conversation id and the user's message text.
        tools_enabled: Query flag forwarded to ``chat_service.stream_response``.
        current_user: Authenticated user; the conversation must belong to them.
        db: Database session.

    Returns:
        A ``text/event-stream`` ``StreamingResponse`` relaying the SSE strings
        produced by ``chat_service``. Returns ``error_response`` with 404 when
        the conversation is not found for this user.
    """
    conversation = db.query(Conversation).filter(
        Conversation.id == data.conversation_id,
        Conversation.user_id == current_user.id,
    ).first()
    if not conversation:
        return error_response("Conversation not found", 404)

    user_message = Message(
        id=generate_id("msg"),
        conversation_id=data.conversation_id,
        role="user",
        content=data.content,
        token_count=len(data.content) // 4,  # rough estimate: ~4 chars per token
    )
    db.add(user_message)
    conversation.updated_at = datetime.now()
    # The user's message is committed before the response starts streaming.
    db.commit()

    # NOTE(review): chat_service uses the ORM `conversation` while the body is
    # streamed, i.e. after this function returns. Whether `db` is still open
    # then depends on FastAPI's teardown timing for yield-dependencies with
    # streaming responses (changed across FastAPI versions) — confirm.
    async def event_generator():
        async for sse_str in chat_service.stream_response(
            conversation=conversation,
            user_message=data.content,
            tools_enabled=tools_enabled,
        ):
            # Chat service returns raw SSE strings (including the done event),
            # so they are forwarded unchanged.
            yield sse_str

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Ask nginx-style reverse proxies not to buffer the event stream.
            "X-Accel-Buffering": "no",
        },
    )
|
|
|
|
|
|
@router.delete("/{message_id}", response_model=dict)
def delete_message(
    message_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Delete a single message, provided its conversation belongs to the caller.

    Args:
        message_id: Identifier of the message to delete.
        current_user: Authenticated user; must own the message's conversation.
        db: Database session.

    Returns:
        ``success_response`` confirming deletion, or ``error_response`` with
        404 when no such message exists in the user's conversations.
    """
    # Join through Conversation so ownership is enforced in the same query.
    owned_message = (
        db.query(Message)
        .join(Conversation)
        .filter(
            Message.id == message_id,
            Conversation.user_id == current_user.id,
        )
        .first()
    )

    if owned_message:
        db.delete(owned_message)
        db.commit()
        return success_response(message="Message deleted successfully")

    return error_response("Message not found", 404)
|