feat: 优化server 部分设置
This commit is contained in:
parent
70d52935f0
commit
26989e54aa
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
|
@ -16,7 +17,51 @@ logger = logging.getLogger(__name__)
|
|||
# Global model parameter (loaded once)
|
||||
_model_param: Optional[ModelParameter] = None
|
||||
_project_root = Path(__file__).parent.parent.parent
|
||||
app = FastAPI(title="AstrAI Inference Server", version="0.1.0")
|
||||
|
||||
# Server configuration (set before running server)
|
||||
_server_config: Dict[str, Any] = {
|
||||
"device": "cuda",
|
||||
"dtype": torch.bfloat16,
|
||||
"param_path": None,
|
||||
}
|
||||
|
||||
|
||||
def configure_server(
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
param_path: Optional[Path] = None,
|
||||
):
|
||||
"""Configure server settings before starting.
|
||||
|
||||
Args:
|
||||
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
|
||||
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
|
||||
param_path: Path to model parameters directory
|
||||
"""
|
||||
_server_config["device"] = device
|
||||
_server_config["dtype"] = dtype
|
||||
_server_config["param_path"] = param_path
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for startup and shutdown events."""
|
||||
# Startup: Load model with configured settings
|
||||
try:
|
||||
load_model(
|
||||
param_path=_server_config["param_path"],
|
||||
device=_server_config["device"],
|
||||
dtype=_server_config["dtype"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise
|
||||
yield
|
||||
# Shutdown: Cleanup if needed
|
||||
pass
|
||||
|
||||
|
||||
app = FastAPI(title="AstrAI Inference Server", version="0.1.0", lifespan=lifespan)
|
||||
|
||||
|
||||
def load_model(
|
||||
|
|
@ -94,16 +139,6 @@ def convert_messages_to_history(
|
|||
return system_prompt, history if history else None
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Load model on server startup."""
|
||||
try:
|
||||
load_model()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok", "model_loaded": _model_param is not None}
|
||||
|
|
@ -220,6 +255,23 @@ async def generate(
|
|||
return {"response": result}
|
||||
|
||||
|
||||
def run_server(host: str = "0.0.0.0", port: int = 8000, reload: bool = False):
|
||||
"""Run the FastAPI server with uvicorn."""
|
||||
def run_server(
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8000,
|
||||
reload: bool = False,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
param_path: Optional[Path] = None,
|
||||
):
|
||||
"""Run the FastAPI server with uvicorn.
|
||||
|
||||
Args:
|
||||
host: Server host address
|
||||
port: Server port number
|
||||
reload: Enable auto-reload for development
|
||||
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
|
||||
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
|
||||
param_path: Path to model parameters directory
|
||||
"""
|
||||
configure_server(device=device, dtype=dtype, param_path=param_path)
|
||||
uvicorn.run("astrai.inference.server:app", host=host, port=port, reload=reload)
|
||||
|
|
|
|||
Loading…
Reference in New Issue