142 lines
4.9 KiB
Python
142 lines
4.9 KiB
Python
"""Unit tests for the inference HTTP server."""
|
||
|
||
import pytest
|
||
|
||
|
||
def test_health_no_model(client, monkeypatch):
    """GET /health should return 200 even when the model is not loaded.

    ``monkeypatch`` sets ``astrai.inference.server._model_param`` to
    ``None`` for the duration of the test; the endpoint is then expected
    to report ``status == "ok"`` and ``model_loaded`` as ``False``.
    """
    monkeypatch.setattr("astrai.inference.server._model_param", None)
    response = client.get("/health")
    assert response.status_code == 200
    data = response.json()
    assert data["status"] == "ok"
    # Identity check: a falsy non-bool such as 0 must not satisfy the test.
    assert data["model_loaded"] is False
|
||
|
||
|
||
def test_health_with_model(client, loaded_model):
    """Verify that GET /health answers 200 and reports a loaded model."""
    health = client.get("/health")
    assert health.status_code == 200

    # The whole payload is compared, so extra or missing keys fail the test.
    expected_body = {"status": "ok", "model_loaded": True}
    assert health.json() == expected_body
|
||
|
||
|
||
def test_generate_non_stream(client, loaded_model, mock_generator):
    """Check that POST /generate with stream=false replies with JSON."""
    factory, generator = mock_generator
    generator.generate.return_value = "Test response"

    query_params = {
        "query": "Hello",
        "temperature": 0.8,
        "top_p": 0.95,
        "top_k": 50,
        "max_len": 100,
        "stream": False,
    }
    reply = client.post("/generate", params=query_params)

    assert reply.status_code == 200
    body = reply.json()
    assert body["response"] == "Test response"
    # The factory's create() must have been called exactly once.
    factory.create.assert_called_once()
|
||
|
||
|
||
def test_generate_stream(client, loaded_model, mock_generator):
    """Check that POST /generate with stream=true replies with plain text."""
    _factory, generator = mock_generator
    # generate() is mocked to return two chunks for the streaming path.
    generator.generate.return_value = ["chunk1", "chunk2"]

    query_params = {
        "query": "Hello",
        "temperature": 0.8,
        "top_p": 0.95,
        "top_k": 50,
        "max_len": 100,
        "stream": True,
    }
    reply = client.post(
        "/generate",
        params=query_params,
        headers={"Accept": "text/plain"},
    )

    assert reply.status_code == 200
    assert reply.headers["content-type"] == "text/plain; charset=utf-8"

    # Both mocked chunks must show up somewhere in the decoded body.
    streamed_text = reply.content.decode("utf-8")
    assert "chunk1" in streamed_text
    assert "chunk2" in streamed_text
|
||
|
||
|
||
def test_chat_completions_non_stream(client, loaded_model, mock_generator):
    """Check that a non-streaming chat completion returns OpenAI-style JSON."""
    _factory, generator = mock_generator
    generator.generate.return_value = "Assistant reply"

    request_body = {
        "messages": [{"role": "user", "content": "Hello"}],
        "temperature": 0.8,
        "top_p": 0.95,
        "top_k": 50,
        "max_tokens": 100,
        "stream": False,
    }
    reply = client.post("/v1/chat/completions", json=request_body)

    assert reply.status_code == 200
    payload = reply.json()
    assert payload["object"] == "chat.completion"

    # Exactly one choice is expected, carrying the mocked reply text.
    choices = payload["choices"]
    assert len(choices) == 1
    assert choices[0]["message"]["content"] == "Assistant reply"
|
||
|
||
|
||
def test_chat_completions_stream(client, loaded_model, mock_generator):
    """POST /v1/chat/completions with stream=true returns an SSE stream.

    The mocked generator returns two values; the test checks the status
    code, the ``text/event-stream`` content type, and that both values
    appear in the streamed body.
    """
    MockFactory, mock_gen = mock_generator
    # Simulate a streaming generator that yields cumulative responses
    mock_gen.generate.return_value = ["cumulative1", "cumulative2"]
    response = client.post(
        "/v1/chat/completions",
        json={
            "messages": [{"role": "user", "content": "Hello"}],
            "temperature": 0.8,
            "top_p": 0.95,
            "top_k": 50,
            "max_tokens": 100,
            "stream": True,
        },
        headers={"Accept": "text/event-stream"},
    )
    assert response.status_code == 200
    assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
    # Split the body into lines and keep only the non-blank ones. The filter
    # runs on the stripped text so whitespace-only lines (e.g. a lone "\r"
    # left by CRLF separators) are dropped instead of kept as "".
    body = response.content.decode("utf-8")
    stripped_lines = (raw.strip() for raw in body.split("\n"))
    lines = [line for line in stripped_lines if line]
    # Both mocked values must appear in the stream. Any terminating marker
    # the server may send is not asserted here.
    assert any("cumulative1" in line for line in lines)
    assert any("cumulative2" in line for line in lines)
|
||
|
||
|
||
def test_generate_with_history(client, loaded_model, mock_generator):
    """Exercise POST /generate when a history parameter is supplied."""
    factory, generator = mock_generator
    generator.generate.return_value = "Response with history"

    query_params = {
        "query": "Hi",
        "history": [["user1", "assistant1"], ["user2", "assistant2"]],
        "stream": False,
    }
    reply = client.post("/generate", params=query_params)

    assert reply.status_code == 200
    factory.create.assert_called_once()

    # Inspect the request handed to the factory: the second positional
    # argument of create() is the GenerationRequest.
    positional_args = factory.create.call_args[0]
    generation_request = positional_args[1]
    # History is currently not parsed from query params (FastAPI
    # limitation), so the request is expected to carry None.
    assert generation_request.history is None
|
||
|
||
|
||
if __name__ == "__main__":
    # pytest.main() returns the run's exit code; raise it as SystemExit so
    # running this file directly exits non-zero when a test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))
|