321 lines
9.4 KiB
Python
321 lines
9.4 KiB
Python
"""Tests for scheduler concurrency."""
|
|
|
|
import threading
|
|
import time
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from astrai.inference.scheduler import (
|
|
InferenceScheduler,
|
|
PrefixCacheManager,
|
|
)
|
|
|
|
|
|
def test_prefix_cache_concurrent_insert_find():
    """Run inserts and prefix lookups against one cache from several threads.

    Three threads each insert 50 single-token prefixes while three other
    threads each perform 50 lookups. The test passes when no worker raised
    and every operation was counted.
    """
    cache = PrefixCacheManager(max_capacity=100)

    results = {"errors": [], "inserts": 0, "finds": 0}
    # ``results[key] += 1`` is a read-modify-write; when threads interleave,
    # increments can be lost and the final counts come up short. Guard the
    # test's own bookkeeping with a lock, but keep the cache calls outside it
    # so they still run concurrently.
    results_lock = threading.Lock()

    def insert_worker():
        try:
            for i in range(50):
                cache.insert((i,), slot=i % 10)
                with results_lock:
                    results["inserts"] += 1
        except Exception as e:
            with results_lock:
                results["errors"].append(str(e))

    def find_worker():
        try:
            for i in range(50):
                cache.find_longest_prefix([i])
                with results_lock:
                    results["finds"] += 1
        except Exception as e:
            with results_lock:
                results["errors"].append(str(e))

    threads = [threading.Thread(target=insert_worker) for _ in range(3)]
    threads += [threading.Thread(target=find_worker) for _ in range(3)]

    for t in threads:
        t.start()
    for t in threads:
        t.join()

    assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
    assert results["inserts"] == 150
    assert results["finds"] == 150
|
|
|
|
|
|
def test_prefix_cache_concurrent_release():
    """Release the same ten prefixes from three threads at once.

    Passes when none of the worker threads raised.
    """
    cache = PrefixCacheManager(max_capacity=100)

    # Seed the cache with ten single-token prefixes, one slot each.
    for token in range(10):
        cache.insert((token,), slot=token)

    results = {"errors": []}

    def release_worker():
        try:
            for token in range(10):
                cache.release((token,))
        except Exception as exc:
            results["errors"].append(str(exc))

    workers = []
    for _ in range(3):
        workers.append(threading.Thread(target=release_worker))

    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    assert not results["errors"], f"Errors: {results['errors']}"
|
|
|
|
|
|
def test_prefix_cache_concurrent_insert_release_find():
    """Interleave insert, lookup and release from five threads.

    Each worker operates on its own range of single-token prefixes
    (``worker_id * 100 + i``). Passes when no worker raised.
    """
    cache = PrefixCacheManager(max_capacity=50)
    results = {"errors": []}

    def worker(worker_id):
        base = worker_id * 100
        try:
            for offset in range(20):
                token_ids = (base + offset,)
                # Full cycle for one prefix: insert, look it up, release.
                cache.insert(token_ids, slot=worker_id)
                cache.find_longest_prefix(list(token_ids))
                cache.release(token_ids)
        except Exception as e:
            results["errors"].append(f"Worker {worker_id}: {str(e)}")

    threads = []
    for worker_id in range(5):
        threads.append(threading.Thread(target=worker, args=(worker_id,)))

    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    assert not results["errors"], f"Errors: {results['errors']}"
|
|
|
|
|
|
@pytest.fixture
def mock_model_and_tokenizer():
    """Return a ``(model, tokenizer)`` pair of pre-configured MagicMocks."""
    # Model stub: only the config attributes set below are given fixed values.
    model = MagicMock()
    config = MagicMock()
    config.n_kv_heads = 8
    config.n_heads = 8
    config.dim = 128
    config.n_layers = 2
    config.max_len = 100
    model.config = config

    # Tokenizer stub: fixed encode/decode results, stop id 0, no pad id.
    tokenizer = MagicMock()
    tokenizer.encode.return_value = [1, 2, 3, 4, 5]
    tokenizer.decode.return_value = "token"
    tokenizer.stop_ids = [0]
    tokenizer.pad_id = None

    return model, tokenizer
|
|
|
|
|
|
def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
    """Submit tasks to one scheduler from five threads.

    Each of the five workers calls ``add_task`` ten times. The test passes
    when no worker raised and 50 task ids were collected.
    """
    mock_model, mock_tokenizer = mock_model_and_tokenizer

    # Both loader names are patched only while the scheduler is constructed.
    with patch("astrai.inference.scheduler.AutoModel"), patch(
        "astrai.inference.scheduler.AutoTokenizer"
    ):
        scheduler = InferenceScheduler(
            model=mock_model,
            tokenizer=mock_tokenizer,
            max_batch_size=4,
            device="cpu",
        )

    results = {"task_ids": [], "errors": []}
    lock = threading.Lock()

    def add_task_worker(worker_id):
        try:
            for i in range(10):
                prompt = f"prompt from worker {worker_id}-{i}"
                task_id = scheduler.add_task(prompt)
                with lock:
                    results["task_ids"].append(task_id)
        except Exception as exc:
            results["errors"].append(str(exc))

    threads = []
    for worker_id in range(5):
        threads.append(
            threading.Thread(target=add_task_worker, args=(worker_id,))
        )

    for thread in threads:
        thread.start()

    # Give the workers a short window before the scheduler is stopped.
    time.sleep(0.1)

    scheduler.stop()

    for thread in threads:
        thread.join()

    assert not results["errors"], f"Errors: {results['errors']}"
    assert len(results["task_ids"]) == 50
|
|
|
|
|
|
def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
    """Add tasks in one thread while another thread removes some of them.

    The adder submits 20 prompts. After a short delay the remover takes a
    snapshot of up to the first ten ids added so far and removes each one.
    Passes when neither thread raised and all 20 adds were recorded.
    """
    mock_model, mock_tokenizer = mock_model_and_tokenizer

    # Both loader names are patched only while the scheduler is constructed.
    with patch("astrai.inference.scheduler.AutoModel"), patch(
        "astrai.inference.scheduler.AutoTokenizer"
    ):
        scheduler = InferenceScheduler(
            model=mock_model,
            tokenizer=mock_tokenizer,
            max_batch_size=4,
            device="cpu",
        )

    results = {"added": [], "removed": [], "errors": []}

    def add_worker():
        try:
            for i in range(20):
                results["added"].append(scheduler.add_task(f"prompt {i}"))
                time.sleep(0.001)
        except Exception as e:
            results["errors"].append(f"Add: {str(e)}")

    def remove_worker():
        try:
            # Delay so the adder has had a chance to submit some tasks.
            time.sleep(0.05)
            snapshot = results["added"][:10]
            for task_id in snapshot:
                scheduler.remove_task(task_id)
                results["removed"].append(task_id)
        except Exception as e:
            results["errors"].append(f"Remove: {str(e)}")

    adder = threading.Thread(target=add_worker)
    remover = threading.Thread(target=remove_worker)

    adder.start()
    remover.start()

    time.sleep(0.2)
    scheduler.stop()

    adder.join()
    remover.join()

    assert not results["errors"], f"Errors: {results['errors']}"
    assert len(results["added"]) == 20
|
|
|
|
|
|
def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
    """Poll ``get_stats`` in one thread while another thread adds tasks.

    Passes when neither thread raised, 50 snapshots were collected, and
    every snapshot carries a non-negative ``"total_tasks"`` entry.
    """
    mock_model, mock_tokenizer = mock_model_and_tokenizer

    # Both loader names are patched only while the scheduler is constructed.
    with patch("astrai.inference.scheduler.AutoModel"), patch(
        "astrai.inference.scheduler.AutoTokenizer"
    ):
        scheduler = InferenceScheduler(
            model=mock_model,
            tokenizer=mock_tokenizer,
            max_batch_size=4,
            device="cpu",
        )

    results = {"stats": [], "errors": []}

    def add_tasks():
        try:
            for i in range(20):
                scheduler.add_task(f"prompt {i}")
                time.sleep(0.001)
        except Exception as e:
            results["errors"].append(f"Add: {str(e)}")

    def get_stats():
        try:
            for _ in range(50):
                results["stats"].append(scheduler.get_stats())
                time.sleep(0.001)
        except Exception as e:
            results["errors"].append(f"Get stats: {str(e)}")

    producer = threading.Thread(target=add_tasks)
    poller = threading.Thread(target=get_stats)

    producer.start()
    poller.start()

    time.sleep(0.3)
    scheduler.stop()

    producer.join()
    poller.join()

    assert not results["errors"], f"Errors: {results['errors']}"
    assert len(results["stats"]) == 50

    # Each collected snapshot must expose a non-negative task total.
    for snapshot in results["stats"]:
        assert "total_tasks" in snapshot
        assert snapshot["total_tasks"] >= 0
|
|
|
|
|
|
def test_prefix_cache_insert_same_prefix_concurrently():
    """Insert the prefix ``(1, 2, 3)`` from ten threads at once.

    Passes when no worker raised and the node reached through
    ``root -> 1 -> 2 -> 3`` ends up with ``ref_count == 10``.
    """
    cache = PrefixCacheManager(max_capacity=100)
    results = {"slot_values": [], "errors": []}

    def insert_worker():
        try:
            # Every worker inserts the same prefix, tagged with its own
            # thread name as the slot value.
            cache.insert((1, 2, 3), slot=threading.current_thread().name)

            # Walk root -> 1 -> 2 -> 3; record the slot only when the whole
            # path exists.
            node = cache.root
            for token in (1, 2, 3):
                node = node.children.get(token)
                if not node:
                    break
            else:
                results["slot_values"].append(node.slot)
        except Exception as exc:
            results["errors"].append(str(exc))

    threads = []
    for _ in range(10):
        threads.append(threading.Thread(target=insert_worker))

    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # No insert may have raised.
    assert not results["errors"], f"Errors: {results['errors']}"

    # One reference per inserting thread is expected on the final node.
    first = cache.root.children.get(1)
    second = first.children.get(2)
    node = second.children.get(3)
    assert node.ref_count == 10, f"Expected ref_count=10, got {node.ref_count}"
|
|
|
|
|
|
def test_prefix_cache_ref_count_underflow_prevention():
    """Check that releasing more often than inserting keeps ref_count >= 0."""
    cache = PrefixCacheManager(max_capacity=100)
    prefix = (1, 2, 3)

    # One insert followed by five releases of the same prefix.
    cache.insert(prefix, slot=0)
    release_count = 5
    while release_count > 0:
        cache.release(prefix)
        release_count -= 1

    # Reach the node for the prefix directly through the trie children.
    first = cache.root.children.get(1)
    second = first.children.get(2)
    node = second.children.get(3)

    # The extra releases must not have pushed the counter below zero.
    assert node.ref_count >= 0, f"ref_count went negative: {node.ref_count}"
|