feat: 优化server 部分设置

This commit is contained in:
ViperEkura 2026-04-04 01:41:01 +08:00
parent 70d52935f0
commit 26989e54aa
1 changed file with 65 additions and 13 deletions

View File

@ -1,4 +1,5 @@
import logging import logging
from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
@ -16,7 +17,51 @@ logger = logging.getLogger(__name__)
# Global model parameter, populated once by load_model() at startup.
_model_param: Optional[ModelParameter] = None

# Three levels above this file; given the uvicorn target
# "astrai.inference.server:app" this is presumably the repo root — TODO confirm.
_project_root = Path(__file__).parent.parent.parent
app = FastAPI(title="AstrAI Inference Server", version="0.1.0")
# Server configuration (set before running server)
_server_config: Dict[str, Any] = {
"device": "cuda",
"dtype": torch.bfloat16,
"param_path": None,
}
def configure_server(
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    param_path: Optional[Path] = None,
):
    """Record server settings for the next model load.

    Call this before starting the server; the values are stored in the
    module-level ``_server_config`` and consumed when the model is loaded
    at startup.

    Args:
        device: Target device for the model (e.g., "cuda", "cpu", "cuda:0").
        dtype: Weight precision (e.g., torch.bfloat16, torch.float16).
        param_path: Directory containing the model parameters, or None.
    """
    _server_config.update(
        device=device,
        dtype=dtype,
        param_path=param_path,
    )
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup and shutdown events.

    On startup, loads the model using the device/dtype/param_path values
    stored in the module-level ``_server_config``; configure_server() must
    have been called first for non-default settings to apply.

    Raises:
        Exception: re-raised from load_model so the server refuses to
            start when the model cannot be loaded.
    """
    # Startup: load model with the configured settings.
    try:
        load_model(
            param_path=_server_config["param_path"],
            device=_server_config["device"],
            dtype=_server_config["dtype"],
        )
    except Exception as e:
        # logger.exception also records the traceback (logger.error did not),
        # which is essential for diagnosing a failed startup.
        logger.exception(f"Failed to load model: {e}")
        raise
    yield
    # Shutdown: no cleanup currently required.
# Application instance; model loading is handled by the lifespan context
# manager (replacing the removed @app.on_event("startup") hook).
app = FastAPI(title="AstrAI Inference Server", version="0.1.0", lifespan=lifespan)
def load_model( def load_model(
@ -94,16 +139,6 @@ def convert_messages_to_history(
return system_prompt, history if history else None return system_prompt, history if history else None
@app.on_event("startup")
async def startup_event():
"""Load model on server startup."""
try:
load_model()
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
@app.get("/health")
async def health():
    """Liveness probe: report server status and whether the model is loaded."""
    return {"status": "ok", "model_loaded": _model_param is not None}
@ -220,6 +255,23 @@ async def generate(
return {"response": result} return {"response": result}
def run_server(
    host: str = "0.0.0.0",
    port: int = 8000,
    reload: bool = False,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    param_path: Optional[Path] = None,
):
    """Run the FastAPI server with uvicorn.

    Args:
        host: Server host address.
        port: Server port number.
        reload: Enable auto-reload for development.
        device: Device to load model on (e.g., "cuda", "cpu", "cuda:0").
        dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16).
        param_path: Path to model parameters directory.
    """
    configure_server(device=device, dtype=dtype, param_path=param_path)
    # NOTE(review): with reload=True uvicorn spawns a fresh process that
    # re-imports this module, so the configure_server() call above is lost
    # there and the module defaults apply — confirm whether reload mode
    # needs the settings plumbed through environment variables instead.
    uvicorn.run("astrai.inference.server:app", host=host, port=port, reload=reload)