feat: 优化server 部分设置
This commit is contained in:
parent
70d52935f0
commit
26989e54aa
|
|
@ -1,4 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
|
@ -16,7 +17,51 @@ logger = logging.getLogger(__name__)
|
||||||
# Global model parameter (loaded once)
|
# Global model parameter (loaded once)
|
||||||
_model_param: Optional[ModelParameter] = None
|
_model_param: Optional[ModelParameter] = None
|
||||||
_project_root = Path(__file__).parent.parent.parent
|
_project_root = Path(__file__).parent.parent.parent
|
||||||
app = FastAPI(title="AstrAI Inference Server", version="0.1.0")
|
|
||||||
|
# Server configuration (set before running server)
|
||||||
|
_server_config: Dict[str, Any] = {
|
||||||
|
"device": "cuda",
|
||||||
|
"dtype": torch.bfloat16,
|
||||||
|
"param_path": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def configure_server(
|
||||||
|
device: str = "cuda",
|
||||||
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
param_path: Optional[Path] = None,
|
||||||
|
):
|
||||||
|
"""Configure server settings before starting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
|
||||||
|
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
|
||||||
|
param_path: Path to model parameters directory
|
||||||
|
"""
|
||||||
|
_server_config["device"] = device
|
||||||
|
_server_config["dtype"] = dtype
|
||||||
|
_server_config["param_path"] = param_path
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup and shutdown events.

    Startup: loads the model using the settings in ``_server_config``
    (populated via ``configure_server``). Shutdown: nothing to clean up.

    Raises:
        Exception: re-raised from ``load_model`` so the server refuses to
            start with no model loaded.
    """
    try:
        load_model(
            param_path=_server_config["param_path"],
            device=_server_config["device"],
            dtype=_server_config["dtype"],
        )
    except Exception:
        # logger.exception logs the full traceback, which logger.error(f"...")
        # discarded — the exception is re-raised, so no information is hidden.
        logger.exception("Failed to load model")
        raise
    yield
    # Shutdown: no cleanup currently required (dead `pass` removed).
|
||||||
|
|
||||||
|
|
||||||
|
# Application instance; startup/shutdown is handled by the `lifespan` context manager.
app = FastAPI(title="AstrAI Inference Server", version="0.1.0", lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
|
|
@ -94,16 +139,6 @@ def convert_messages_to_history(
|
||||||
return system_prompt, history if history else None
|
return system_prompt, history if history else None
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
|
||||||
async def startup_event():
|
|
||||||
"""Load model on server startup."""
|
|
||||||
try:
|
|
||||||
load_model()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to load model: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
async def health():
    """Liveness probe: always "ok", plus whether the global model is loaded."""
    model_loaded = _model_param is not None
    return {"status": "ok", "model_loaded": model_loaded}
|
||||||
|
|
@ -220,6 +255,23 @@ async def generate(
|
||||||
return {"response": result}
|
return {"response": result}
|
||||||
|
|
||||||
|
|
||||||
def run_server(
    host: str = "0.0.0.0",
    port: int = 8000,
    reload: bool = False,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    param_path: Optional[Path] = None,
):
    """Run the FastAPI server with uvicorn.

    Stores the model settings via ``configure_server`` first, so the lifespan
    handler picks them up when uvicorn imports and starts the app.

    Args:
        host: Server host address
        port: Server port number
        reload: Enable auto-reload for development
        device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
        dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
        param_path: Path to model parameters directory
    """
    configure_server(device=device, dtype=dtype, param_path=param_path)
    # NOTE(review): with reload=True uvicorn re-imports the app in a child
    # process, so the in-process _server_config set above may not survive —
    # confirm before combining reload with non-default device/dtype/param_path.
    app_ref = "astrai.inference.server:app"
    uvicorn.run(app_ref, host=host, port=port, reload=reload)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue