AstrAI/astrai/inference/server.py

226 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GenerationRequest, GeneratorFactory
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Global model parameter (loaded once)
# Stays None until load_model() assigns it; the endpoints below check it for
# None and answer 503 while no model is available.
_model_param: Optional[ModelParameter] = None
# Two levels above the directory that contains this file; load_model() looks
# for a "params" entry here when no explicit path is given.
_project_root = Path(__file__).parent.parent.parent
# FastAPI application object that the route decorators below register on.
app = FastAPI(title="AstrAI Inference Server", version="0.1.0")
def load_model(
    param_path: Optional[Path] = None,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
):
    """Load model parameters into the module-level ``_model_param``.

    Args:
        param_path: Location of the saved parameters. Defaults to
            ``_project_root / "params"``. A plain string path is tolerated
            and converted to a ``Path``.
        device: Target device forwarded to ``ModelParameter.to``.
        dtype: Target dtype forwarded to ``ModelParameter.to``.

    Raises:
        FileNotFoundError: If ``param_path`` does not exist.
    """
    global _model_param

    # Coerce so that callers passing a str do not fail on ``.exists()``.
    if param_path is None:
        param_path = _project_root / "params"
    else:
        param_path = Path(param_path)

    if not param_path.exists():
        raise FileNotFoundError(f"Parameter directory not found: {param_path}")

    # Publish the global only after the device/dtype move succeeded, so a
    # failed move never leaves a half-prepared model reported as loaded.
    model_param = ModelParameter.load(param_path, disable_init=True)
    model_param.to(device=device, dtype=dtype)
    _model_param = model_param

    logger.info("Model loaded on %s with dtype %s", device, dtype)
# Pydantic models for API request/response
class ChatMessage(BaseModel):
    """One turn of an OpenAI-style ``messages`` list."""

    # Speaker of the turn: "user", "assistant", "system"
    role: str
    # Text of the turn
    content: str
class ChatCompletionRequest(BaseModel):
    """Request body of the chat completion endpoint."""

    # Conversation turns, oldest first
    messages: List[ChatMessage]

    # Sampling controls; the bounds are enforced by pydantic validation
    temperature: float = Field(default=0.8, ge=0.0, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
    top_k: int = Field(default=50, ge=0)

    # Upper bound on generated length (at least 1)
    max_tokens: int = Field(default=2048, ge=1)

    # Request a streamed response instead of a single JSON body
    stream: bool = False

    # Optional system prompt supplied next to the messages
    system_prompt: Optional[str] = None
class CompletionResponse(BaseModel):
    """Non-streaming chat completion response in OpenAI-style layout."""

    # Identifier of this completion
    id: str = Field(default="chatcmpl-default")
    # Object type tag
    object: str = Field(default="chat.completion")
    # Creation timestamp as an integer
    created: int = Field(default=0)
    # Model name reported to the client
    model: str = Field(default="astrai")
    # Required list of choice dictionaries
    choices: List[Dict[str, Any]]
class StreamCompletionResponse(BaseModel):
    """Single chunk of a streamed chat completion in OpenAI-style layout."""

    # Identifier of the completion this chunk belongs to
    id: str = Field(default="chatcmpl-default")
    # Object type tag for a streamed chunk
    object: str = Field(default="chat.completion.chunk")
    # Creation timestamp as an integer
    created: int = Field(default=0)
    # Model name reported to the client
    model: str = Field(default="astrai")
    # Required list of choice dictionaries
    choices: List[Dict[str, Any]]
def convert_messages_to_history(
    messages: "List[ChatMessage]",
) -> Tuple[Optional[str], Optional[List[Tuple[str, str]]]]:
    """Convert OpenAI-style messages to system_prompt and history.

    Messages are scanned in order. A ``system`` message sets the system
    prompt (the last one wins). ``user`` and ``assistant`` messages are
    collected into ``(user, assistant)`` pairs; a pair is closed when the
    next ``user`` message arrives after at least one ``assistant`` message.
    Consecutive messages of the same role are joined with a newline so their
    texts do not run into each other.

    Whatever is still buffered at the end (typically the latest user query)
    is deliberately left out of the history; the caller picks the query
    itself.

    Args:
        messages: Objects exposing ``role`` and ``content`` attributes.

    Returns:
        ``(system_prompt, history)`` where ``system_prompt`` is ``None`` when
        no system message was seen and ``history`` is ``None`` when no
        complete pair was collected.
    """
    system_prompt: Optional[str] = None
    history: List[Tuple[str, str]] = []
    user_parts: List[str] = []
    assistant_parts: List[str] = []

    for msg in messages:
        if msg.role == "system":
            system_prompt = msg.content
        elif msg.role == "user":
            if assistant_parts:
                # A new user turn after an assistant reply closes the pair.
                history.append(("\n".join(user_parts), "\n".join(assistant_parts)))
                user_parts, assistant_parts = [], []
            user_parts.append(msg.content)
        elif msg.role == "assistant":
            assistant_parts.append(msg.content)
        else:
            # Unknown roles are skipped, not rejected.
            logger.warning("Unknown role %s", msg.role)

    return system_prompt, history or None
@app.on_event("startup")
async def startup_event():
    """Load model on server startup.

    A failure is logged together with its traceback and then re-raised, so
    startup does not silently continue without a model.
    """
    try:
        load_model()
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("Failed to load model: %s", e)
        raise
@app.get("/health")
async def health():
    """Report a fixed "ok" status plus whether the model global is set."""
    is_loaded = _model_param is not None
    return {"status": "ok", "model_loaded": is_loaded}
def _format_sse_event(payload: Any) -> str:
    """Frame *payload* as one server-sent event, one ``data:`` line per text line."""
    # A raw newline inside a single "data:" line would end the field (or the
    # whole event) early, so every text line gets its own "data:" prefix.
    text = str(payload).replace("\r\n", "\n").replace("\r", "\n")
    return "".join(f"data: {line}\n" for line in text.split("\n")) + "\n"


@app.post("/v1/chat/completions", response_model=CompletionResponse)
async def chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion endpoint.

    Supports both streaming and non-streaming modes.

    The last ``user`` message is used as the query; earlier complete
    user/assistant pairs become the history. The system prompt comes from a
    ``system`` message when present, otherwise from ``request.system_prompt``.

    Returns:
        A ``StreamingResponse`` of server-sent events when ``request.stream``
        is true, otherwise a ``CompletionResponse``.

    Raises:
        HTTPException: 503 when no model is loaded, 400 when the request
            contains no user message.
    """
    import time
    import uuid

    if _model_param is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Convert messages to query/history
    system_prompt, history = convert_messages_to_history(request.messages)
    if system_prompt is None:
        # Honour the explicit request field when no system message was sent.
        system_prompt = request.system_prompt

    # Extract last user message as query
    user_messages = [m.content for m in request.messages if m.role == "user"]
    if not user_messages:
        raise HTTPException(status_code=400, detail="No user message found")
    query = user_messages[-1]

    gen_request = GenerationRequest(
        query=query,
        temperature=request.temperature,
        top_p=request.top_p,
        top_k=request.top_k,
        max_len=request.max_tokens,
        history=history,
        system_prompt=system_prompt,
        stream=request.stream,
    )

    if request.stream:

        def generate_stream():
            generator = GeneratorFactory.create(_model_param, gen_request)
            for chunk in generator.generate(gen_request):
                # NOTE(review): each chunk is forwarded as raw text (reported
                # upstream to be the cumulative response). OpenAI clients
                # expect JSON "chat.completion.chunk" deltas - confirm the
                # generator's chunk semantics before switching the payload.
                yield _format_sse_event(chunk)
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    # Non-streaming
    generator = GeneratorFactory.create(_model_param, gen_request)
    # NOTE(review): generate() is called synchronously inside this coroutine;
    # consider a worker thread once concurrent use of the model is confirmed
    # to be safe.
    response_text = generator.generate(gen_request)

    # Build OpenAI-style response. A random id avoids collisions between
    # requests that complete within the same second.
    return CompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex}",
        created=int(time.time()),
        choices=[
            {
                "index": 0,
                "message": {"role": "assistant", "content": response_text},
                "finish_reason": "stop",
            }
        ],
    )
@app.post("/generate")
async def generate(
    query: str,
    history: Optional[List[List[str]]] = None,
    temperature: float = 0.8,
    top_p: float = 0.95,
    top_k: int = 50,
    max_len: int = 2048,
    stream: bool = False,
):
    """Simple generation endpoint compatible with existing GenerationRequest.

    Args:
        query: Prompt text to answer.
        history: Optional earlier turns, each item being ``[user, assistant]``.
        temperature: Forwarded to ``GenerationRequest``.
        top_p: Forwarded to ``GenerationRequest``.
        top_k: Forwarded to ``GenerationRequest``.
        max_len: Forwarded to ``GenerationRequest``.
        stream: When true, chunks are streamed as newline-terminated text.

    Returns:
        A ``text/plain`` ``StreamingResponse`` when ``stream`` is true,
        otherwise ``{"response": result}``.

    Raises:
        HTTPException: 503 when no model is loaded, 400 when a history item
            has fewer than two entries.
    """
    if _model_param is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Convert history format: lists of [user, assistant] become tuples.
    hist: Optional[List[Tuple[str, str]]] = None
    if history:
        hist = []
        for index, turn in enumerate(history):
            if len(turn) < 2:
                # Reject malformed input explicitly instead of failing with
                # an IndexError further down.
                raise HTTPException(
                    status_code=400,
                    detail=f"history[{index}] must contain [user, assistant]",
                )
            hist.append((turn[0], turn[1]))

    gen_request = GenerationRequest(
        query=query,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_len=max_len,
        history=hist,
        stream=stream,
    )

    if stream:

        def stream_generator():
            generator = GeneratorFactory.create(_model_param, gen_request)
            for chunk in generator.generate(gen_request):
                yield chunk + "\n"

        return StreamingResponse(stream_generator(), media_type="text/plain")

    generator = GeneratorFactory.create(_model_param, gen_request)
    result = generator.generate(gen_request)
    return {"response": result}
def run_server(host: str = "0.0.0.0", port: int = 8000, reload: bool = False):
    """Start uvicorn for this module's ``app``.

    The application is handed over as an import string, not as an object.

    Args:
        host: Interface to bind; the default listens on all interfaces.
        port: TCP port to listen on.
        reload: Forwarded to uvicorn's ``reload`` option.
    """
    app_import_path = "astrai.inference.server:app"
    server_options = {"host": host, "port": port, "reload": reload}
    uvicorn.run(app_import_path, **server_options)