From 26989e54aa502254bb3627abba7c31ff86ae0885 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sat, 4 Apr 2026 01:41:01 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96server=20=E9=83=A8?=
 =?UTF-8?q?=E5=88=86=E8=AE=BE=E7=BD=AE?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Review notes (free-form area; not part of the commit message or the diff):

* What the patch does: replaces the @app.on_event("startup") hook with a
  lifespan context manager passed to FastAPI(), adds a module-level
  _server_config dict plus configure_server() to fill it, and extends
  run_server() with device / dtype / param_path arguments that are stored
  through configure_server() before uvicorn.run() is called.

* NOTE(review): run_server() stores the settings in this process, but hands
  uvicorn the import string "astrai.inference.server:app". With reload=True
  uvicorn serves the app from a separate process that imports the module
  afresh, so lifespan() there presumably sees the default _server_config
  ("cuda", torch.bfloat16, None) rather than the values passed to
  run_server(). TODO confirm, and if so carry the settings across the
  process boundary (for example via environment variables read at import).

* lifespan(): the trailing "pass" after the shutdown comment is redundant,
  the function body already contains statements.

* The defaults "cuda" / torch.bfloat16 / None are repeated in
  _server_config, configure_server() and run_server(); a single shared
  definition would keep them from drifting apart.

 astrai/inference/server.py | 78 +++++++++++++++++++++++++++++++-------
 1 file changed, 65 insertions(+), 13 deletions(-)

diff --git a/astrai/inference/server.py b/astrai/inference/server.py
index e273257..129392a 100644
--- a/astrai/inference/server.py
+++ b/astrai/inference/server.py
@@ -1,4 +1,5 @@
 import logging
+from contextlib import asynccontextmanager
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple
 
@@ -16,7 +17,51 @@ logger = logging.getLogger(__name__)
 # Global model parameter (loaded once)
 _model_param: Optional[ModelParameter] = None
 _project_root = Path(__file__).parent.parent.parent
-app = FastAPI(title="AstrAI Inference Server", version="0.1.0")
+
+# Server configuration (set before running server)
+_server_config: Dict[str, Any] = {
+    "device": "cuda",
+    "dtype": torch.bfloat16,
+    "param_path": None,
+}
+
+
+def configure_server(
+    device: str = "cuda",
+    dtype: torch.dtype = torch.bfloat16,
+    param_path: Optional[Path] = None,
+):
+    """Configure server settings before starting.
+
+    Args:
+        device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
+        dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
+        param_path: Path to model parameters directory
+    """
+    _server_config["device"] = device
+    _server_config["dtype"] = dtype
+    _server_config["param_path"] = param_path
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Lifespan context manager for startup and shutdown events."""
+    # Startup: Load model with configured settings
+    try:
+        load_model(
+            param_path=_server_config["param_path"],
+            device=_server_config["device"],
+            dtype=_server_config["dtype"],
+        )
+    except Exception as e:
+        logger.error(f"Failed to load model: {e}")
+        raise
+    yield
+    # Shutdown: Cleanup if needed
+    pass
+
+
+app = FastAPI(title="AstrAI Inference Server", version="0.1.0", lifespan=lifespan)
 
 
 def load_model(
@@ -94,16 +139,6 @@ def convert_messages_to_history(
     return system_prompt, history if history else None
 
 
-@app.on_event("startup")
-async def startup_event():
-    """Load model on server startup."""
-    try:
-        load_model()
-    except Exception as e:
-        logger.error(f"Failed to load model: {e}")
-        raise
-
-
 @app.get("/health")
 async def health():
     return {"status": "ok", "model_loaded": _model_param is not None}
@@ -220,6 +255,23 @@ async def generate(
     return {"response": result}
 
 
-def run_server(host: str = "0.0.0.0", port: int = 8000, reload: bool = False):
-    """Run the FastAPI server with uvicorn."""
+def run_server(
+    host: str = "0.0.0.0",
+    port: int = 8000,
+    reload: bool = False,
+    device: str = "cuda",
+    dtype: torch.dtype = torch.bfloat16,
+    param_path: Optional[Path] = None,
+):
+    """Run the FastAPI server with uvicorn.
+
+    Args:
+        host: Server host address
+        port: Server port number
+        reload: Enable auto-reload for development
+        device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
+        dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
+        param_path: Path to model parameters directory
+    """
+    configure_server(device=device, dtype=dtype, param_path=param_path)
     uvicorn.run("astrai.inference.server:app", host=host, port=port, reload=reload)