diff --git a/tests/inference/test_scheduler_concurrency.py b/tests/inference/test_scheduler_concurrency.py new file mode 100644 index 0000000..1c63995 --- /dev/null +++ b/tests/inference/test_scheduler_concurrency.py @@ -0,0 +1,320 @@ +"""Tests for scheduler concurrency.""" + +import threading +import time +from unittest.mock import MagicMock, patch + +import pytest + +from astrai.inference.scheduler import ( + InferenceScheduler, + PrefixCacheManager, +) + + +def test_prefix_cache_concurrent_insert_find(): + """Test concurrent insert and find operations.""" + cache = PrefixCacheManager(max_capacity=100) + + results = {"errors": [], "inserts": 0, "finds": 0} + + def insert_worker(): + try: + for i in range(50): + cache.insert((i,), slot=i % 10) + results["inserts"] += 1 + except Exception as e: + results["errors"].append(str(e)) + + def find_worker(): + try: + for i in range(50): + cache.find_longest_prefix([i]) + results["finds"] += 1 + except Exception as e: + results["errors"].append(str(e)) + + threads = [threading.Thread(target=insert_worker) for _ in range(3)] + threads += [threading.Thread(target=find_worker) for _ in range(3)] + + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(results["errors"]) == 0, f"Errors: {results['errors']}" + assert results["inserts"] == 150 + assert results["finds"] == 150 + + +def test_prefix_cache_concurrent_release(): + """Test concurrent release operations.""" + cache = PrefixCacheManager(max_capacity=100) + + # Insert some prefixes + for i in range(10): + cache.insert((i,), slot=i) + + results = {"errors": []} + + def release_worker(): + try: + for i in range(10): + cache.release((i,)) + except Exception as e: + results["errors"].append(str(e)) + + threads = [threading.Thread(target=release_worker) for _ in range(3)] + + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(results["errors"]) == 0, f"Errors: {results['errors']}" + + +def test_prefix_cache_concurrent_insert_release_find(): + """Test mixed concurrent operations.""" + cache = PrefixCacheManager(max_capacity=50) + + results = {"errors": []} + + def worker(worker_id): + try: + for i in range(20): + token_ids = (worker_id * 100 + i,) + cache.insert(token_ids, slot=worker_id) + + # Find after insert + cache.find_longest_prefix(list(token_ids)) + + # Release + cache.release(token_ids) + except Exception as e: + results["errors"].append(f"Worker {worker_id}: {str(e)}") + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)] + + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(results["errors"]) == 0, f"Errors: {results['errors']}" + + +@pytest.fixture +def mock_model_and_tokenizer(): + """Create mock model and tokenizer.""" + mock_model = MagicMock() + mock_model.config = MagicMock() + mock_model.config.n_kv_heads = 8 + mock_model.config.n_heads = 8 + mock_model.config.dim = 128 + mock_model.config.n_layers = 2 + mock_model.config.max_len = 100 + + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1, 2, 3, 4, 5] + mock_tokenizer.decode.return_value = "token" + mock_tokenizer.stop_ids = [0] + mock_tokenizer.pad_id = None + + return mock_model, mock_tokenizer + + +def test_scheduler_concurrent_add_task(mock_model_and_tokenizer): + """Test concurrent add_task operations.""" + mock_model, mock_tokenizer = mock_model_and_tokenizer + + with patch("astrai.inference.scheduler.AutoModel"): + with patch("astrai.inference.scheduler.AutoTokenizer"): + scheduler = InferenceScheduler( + model=mock_model, + tokenizer=mock_tokenizer, + max_batch_size=4, + device="cpu", + ) + + results = {"task_ids": [], "errors": []} + lock = threading.Lock() + + def add_task_worker(worker_id): + try: + for i in range(10): + task_id = scheduler.add_task(f"prompt from worker {worker_id}-{i}") + with lock: + results["task_ids"].append(task_id) + except Exception as e: + results["errors"].append(str(e)) + + threads = [threading.Thread(target=add_task_worker, args=(i,)) for i in range(5)] + + for t in threads: + t.start() + + # Let some tasks be processed + time.sleep(0.1) + + scheduler.stop() + + for t in threads: + t.join() + + assert len(results["errors"]) == 0, f"Errors: {results['errors']}" + assert len(results["task_ids"]) == 50 + + +def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer): + """Test concurrent add and remove task operations.""" + mock_model, mock_tokenizer = mock_model_and_tokenizer + + with patch("astrai.inference.scheduler.AutoModel"): + with patch("astrai.inference.scheduler.AutoTokenizer"): + scheduler = InferenceScheduler( + model=mock_model, + tokenizer=mock_tokenizer, + max_batch_size=4, + device="cpu", + ) + + results = {"added": [], "removed": [], "errors": []} + + def add_worker(): + try: + for i in range(20): + task_id = scheduler.add_task(f"prompt {i}") + results["added"].append(task_id) + time.sleep(0.001) + except Exception as e: + results["errors"].append(f"Add: {str(e)}") + + def remove_worker(): + try: + time.sleep(0.05) # Wait for some tasks to be added + for task_id in results["added"][:10]: + scheduler.remove_task(task_id) + results["removed"].append(task_id) + except Exception as e: + results["errors"].append(f"Remove: {str(e)}") + + add_thread = threading.Thread(target=add_worker) + remove_thread = threading.Thread(target=remove_worker) + + add_thread.start() + remove_thread.start() + + time.sleep(0.2) + scheduler.stop() + + add_thread.join() + remove_thread.join() + + assert len(results["errors"]) == 0, f"Errors: {results['errors']}" + assert len(results["added"]) == 20 + + +def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer): + """Test concurrent get_stats operations.""" + mock_model, mock_tokenizer = mock_model_and_tokenizer + + with patch("astrai.inference.scheduler.AutoModel"): + with patch("astrai.inference.scheduler.AutoTokenizer"): + scheduler = InferenceScheduler( + model=mock_model, + tokenizer=mock_tokenizer, + max_batch_size=4, + device="cpu", + ) + + results = {"stats": [], "errors": []} + + def add_tasks(): + try: + for i in range(20): + scheduler.add_task(f"prompt {i}") + time.sleep(0.001) + except Exception as e: + results["errors"].append(f"Add: {str(e)}") + + def get_stats(): + try: + for _ in range(50): + stats = scheduler.get_stats() + results["stats"].append(stats) + time.sleep(0.001) + except Exception as e: + results["errors"].append(f"Get stats: {str(e)}") + + add_thread = threading.Thread(target=add_tasks) + stats_thread = threading.Thread(target=get_stats) + + add_thread.start() + stats_thread.start() + + time.sleep(0.3) + scheduler.stop() + + add_thread.join() + stats_thread.join() + + assert len(results["errors"]) == 0, f"Errors: {results['errors']}" + assert len(results["stats"]) == 50 + + # Verify stats are consistent + for stats in results["stats"]: + assert "total_tasks" in stats + assert stats["total_tasks"] >= 0 + + +def test_prefix_cache_insert_same_prefix_concurrently(): + """Test inserting the same prefix concurrently.""" + cache = PrefixCacheManager(max_capacity=100) + + results = {"slot_values": [], "errors": []} + + def insert_worker(): + try: + # All workers try to insert the same prefix + cache.insert((1, 2, 3), slot=threading.current_thread().name) + node = cache.root.children.get(1) + if node: + node = node.children.get(2) + if node: + node = node.children.get(3) + if node: + results["slot_values"].append(node.slot) + except Exception as e: + results["errors"].append(str(e)) + + threads = [threading.Thread(target=insert_worker) for _ in range(10)] + + for t in threads: + t.start() + for t in threads: + t.join() + + # All inserts should succeed, final slot should be one of the values + assert len(results["errors"]) == 0, f"Errors: {results['errors']}" + # Check ref_count is correct (should be 10) + node = cache.root.children.get(1).children.get(2).children.get(3) + assert node.ref_count == 10, f"Expected ref_count=10, got {node.ref_count}" + + +def test_prefix_cache_ref_count_underflow_prevention(): + """Test that ref_count doesn't go negative.""" + cache = PrefixCacheManager(max_capacity=100) + + # Insert a prefix + cache.insert((1, 2, 3), slot=0) + + # Release multiple times + for _ in range(5): + cache.release((1, 2, 3)) + + # Try to find it - should return None since ref_count would be negative + # or handle it gracefully + node = cache.root.children.get(1).children.get(2).children.get(3) + # The ref_count should be 0, not negative + assert node.ref_count >= 0, f"ref_count went negative: {node.ref_count}"