feat: 优化server 部分设置

This commit is contained in:
ViperEkura 2026-04-04 01:41:01 +08:00
parent 70d52935f0
commit 26989e54aa
1 changed file with 65 additions and 13 deletions

View File

@@ -1,4 +1,5 @@
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
@@ -16,7 +17,51 @@ logger = logging.getLogger(__name__)
# Global model parameter (loaded once at startup)
_model_param: Optional[ModelParameter] = None
# Repository root: three levels up from this file (src/package/module layout)
_project_root = Path(__file__).parent.parent.parent
# NOTE(review): this FastAPI instance appears superseded by the one created
# with `lifespan=` further below — confirm only one `app` assignment survives.
app = FastAPI(title="AstrAI Inference Server", version="0.1.0")
# Server configuration (set before running server, e.g. via configure_server)
_server_config: Dict[str, Any] = {
    "device": "cuda",        # device the model is loaded on
    "dtype": torch.bfloat16, # weight dtype used when loading
    "param_path": None,      # Optional[Path]: model parameters directory
}
def configure_server(
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    param_path: Optional[Path] = None,
):
    """Record server settings in the global config before startup.

    The lifespan startup hook reads these values when loading the model,
    so this must be called before the server starts.

    Args:
        device: Device to load the model on (e.g., "cuda", "cpu", "cuda:0").
        dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16).
        param_path: Path to the model parameters directory.
    """
    _server_config.update(device=device, dtype=dtype, param_path=param_path)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifecycle: load the model on startup, no-op on shutdown.

    Reads the settings previously recorded in the global server config;
    a failed load is logged and re-raised so the server refuses to start.
    """
    cfg = _server_config
    try:
        load_model(
            param_path=cfg["param_path"],
            device=cfg["device"],
            dtype=cfg["dtype"],
        )
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
    # Hand control to the running application; nothing to clean up afterwards.
    yield


app = FastAPI(title="AstrAI Inference Server", version="0.1.0", lifespan=lifespan)
def load_model(
@@ -94,16 +139,6 @@ def convert_messages_to_history(
return system_prompt, history if history else None
@app.on_event("startup")
async def startup_event():
    """Load model on server startup.

    NOTE(review): `@app.on_event` is FastAPI's deprecated startup hook and
    this appears redundant with a lifespan-based startup elsewhere in the
    file — confirm only one of the two mechanisms remains in use.
    A failed load is logged and re-raised so the server refuses to start.
    """
    try:
        load_model()
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
@app.get("/health")
async def health():
    """Liveness probe: report server status and whether the model is loaded."""
    model_loaded = _model_param is not None
    return {"status": "ok", "model_loaded": model_loaded}
@@ -220,6 +255,23 @@ async def generate(
return {"response": result}
def run_server(
    host: str = "0.0.0.0",
    port: int = 8000,
    reload: bool = False,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    param_path: Optional[Path] = None,
):
    """Run the FastAPI server with uvicorn.

    Args:
        host: Server host address.
        port: Server port number.
        reload: Enable auto-reload for development.
        device: Device to load model on (e.g., "cuda", "cpu", "cuda:0").
        dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16).
        param_path: Path to model parameters directory.
    """
    # Record model-loading settings so the lifespan startup hook picks them up.
    configure_server(device=device, dtype=dtype, param_path=param_path)
    # NOTE(review): with reload=True uvicorn re-imports this module in a child
    # process, where the configured globals reset to their defaults — confirm
    # reload is used for development only.
    uvicorn.run("astrai.inference.server:app", host=host, port=port, reload=reload)