chore: 增加并发测试
This commit is contained in:
parent
29beb174a5
commit
a2ae742988
|
|
@ -0,0 +1,320 @@
|
|||
"""Tests for scheduler concurrency."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from astrai.inference.scheduler import (
|
||||
InferenceScheduler,
|
||||
PrefixCacheManager,
|
||||
)
|
||||
|
||||
|
||||
def test_prefix_cache_concurrent_insert_find():
    """Concurrently insert and look up prefixes; expect no errors and exact counts.

    The shared ``results`` counters are incremented from several threads, so
    a lock guards every mutation: ``d[k] += 1`` is a read-modify-write and is
    NOT atomic under CPython's bytecode interleaving, which made the original
    final-count assertions flaky in principle.
    """
    cache = PrefixCacheManager(max_capacity=100)

    results = {"errors": [], "inserts": 0, "finds": 0}
    results_lock = threading.Lock()

    def insert_worker():
        try:
            for i in range(50):
                cache.insert((i,), slot=i % 10)
                with results_lock:
                    results["inserts"] += 1
        except Exception as e:
            with results_lock:
                results["errors"].append(str(e))

    def find_worker():
        try:
            for i in range(50):
                cache.find_longest_prefix([i])
                with results_lock:
                    results["finds"] += 1
        except Exception as e:
            with results_lock:
                results["errors"].append(str(e))

    threads = [threading.Thread(target=insert_worker) for _ in range(3)]
    threads += [threading.Thread(target=find_worker) for _ in range(3)]

    for t in threads:
        t.start()
    for t in threads:
        t.join()

    assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
    assert results["inserts"] == 150
    assert results["finds"] == 150
|
||||
|
||||
|
||||
def test_prefix_cache_concurrent_release():
    """Releasing the same set of prefixes from several threads must not raise."""
    cache = PrefixCacheManager(max_capacity=100)

    # Pre-populate the cache with ten single-token prefixes.
    for idx in range(10):
        cache.insert((idx,), slot=idx)

    results = {"errors": []}

    def release_worker():
        try:
            for idx in range(10):
                cache.release((idx,))
        except Exception as exc:
            results["errors"].append(str(exc))

    workers = [threading.Thread(target=release_worker) for _ in range(3)]

    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||
|
||||
|
||||
def test_prefix_cache_concurrent_insert_release_find():
    """Interleaved insert/find/release from several workers must not raise."""
    cache = PrefixCacheManager(max_capacity=50)

    results = {"errors": []}

    def worker(worker_id):
        try:
            for step in range(20):
                # Each worker uses its own token range to avoid key overlap.
                token_ids = (worker_id * 100 + step,)
                cache.insert(token_ids, slot=worker_id)

                # Find after insert
                cache.find_longest_prefix(list(token_ids))

                # Release
                cache.release(token_ids)
        except Exception as exc:
            results["errors"].append(f"Worker {worker_id}: {str(exc)}")

    pool = [threading.Thread(target=worker, args=(wid,)) for wid in range(5)]

    for thread in pool:
        thread.start()
    for thread in pool:
        thread.join()

    assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||
|
||||
|
||||
@pytest.fixture
def mock_model_and_tokenizer():
    """Build a (mock model, mock tokenizer) pair for the scheduler tests."""
    model = MagicMock()
    model.config = MagicMock()
    # Minimal config attributes the scheduler reads — presumably to size its
    # KV cache and batch limits; confirm against InferenceScheduler.__init__.
    model.config.n_kv_heads = 8
    model.config.n_heads = 8
    model.config.dim = 128
    model.config.n_layers = 2
    model.config.max_len = 100

    tokenizer = MagicMock()
    tokenizer.encode.return_value = [1, 2, 3, 4, 5]
    tokenizer.decode.return_value = "token"
    tokenizer.stop_ids = [0]
    tokenizer.pad_id = None

    return model, tokenizer
|
||||
|
||||
|
||||
def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
    """Adding tasks from several threads must be safe and lose no task ids.

    All mutation of the shared ``results`` dict goes through one lock: the
    original guarded ``task_ids`` with the lock but appended to ``errors``
    without it, which was inconsistent synchronization of the same shared
    structure.
    """
    mock_model, mock_tokenizer = mock_model_and_tokenizer

    with patch("astrai.inference.scheduler.AutoModel"):
        with patch("astrai.inference.scheduler.AutoTokenizer"):
            scheduler = InferenceScheduler(
                model=mock_model,
                tokenizer=mock_tokenizer,
                max_batch_size=4,
                device="cpu",
            )

            results = {"task_ids": [], "errors": []}
            lock = threading.Lock()

            def add_task_worker(worker_id):
                try:
                    for i in range(10):
                        task_id = scheduler.add_task(f"prompt from worker {worker_id}-{i}")
                        with lock:
                            results["task_ids"].append(task_id)
                except Exception as e:
                    with lock:
                        results["errors"].append(str(e))

            threads = [threading.Thread(target=add_task_worker, args=(i,)) for i in range(5)]

            for t in threads:
                t.start()

            # Let some tasks be processed
            time.sleep(0.1)

            scheduler.stop()

            for t in threads:
                t.join()

            assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
            assert len(results["task_ids"]) == 50
|
||||
|
||||
|
||||
def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
    """Concurrent add and remove of tasks must not raise and must keep all adds.

    ``results["added"]`` is written by the add thread and read by the remove
    thread; the original shared it with no synchronization, so the remover's
    ``[:10]`` slice could observe the list mid-update. A lock now guards every
    access, and the remover takes a snapshot before iterating.
    """
    mock_model, mock_tokenizer = mock_model_and_tokenizer

    with patch("astrai.inference.scheduler.AutoModel"):
        with patch("astrai.inference.scheduler.AutoTokenizer"):
            scheduler = InferenceScheduler(
                model=mock_model,
                tokenizer=mock_tokenizer,
                max_batch_size=4,
                device="cpu",
            )

            results = {"added": [], "removed": [], "errors": []}
            lock = threading.Lock()

            def add_worker():
                try:
                    for i in range(20):
                        task_id = scheduler.add_task(f"prompt {i}")
                        with lock:
                            results["added"].append(task_id)
                        time.sleep(0.001)
                except Exception as e:
                    with lock:
                        results["errors"].append(f"Add: {str(e)}")

            def remove_worker():
                try:
                    time.sleep(0.05)  # Wait for some tasks to be added
                    # Snapshot under the lock, then remove outside it so we
                    # never hold the lock across scheduler calls.
                    with lock:
                        to_remove = results["added"][:10]
                    for task_id in to_remove:
                        scheduler.remove_task(task_id)
                        with lock:
                            results["removed"].append(task_id)
                except Exception as e:
                    with lock:
                        results["errors"].append(f"Remove: {str(e)}")

            add_thread = threading.Thread(target=add_worker)
            remove_thread = threading.Thread(target=remove_worker)

            add_thread.start()
            remove_thread.start()

            time.sleep(0.2)
            scheduler.stop()

            add_thread.join()
            remove_thread.join()

            assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
            assert len(results["added"]) == 20
|
||||
|
||||
|
||||
def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
    """get_stats() must return consistent snapshots while tasks are added."""
    mock_model, mock_tokenizer = mock_model_and_tokenizer

    with patch("astrai.inference.scheduler.AutoModel"):
        with patch("astrai.inference.scheduler.AutoTokenizer"):
            scheduler = InferenceScheduler(
                model=mock_model,
                tokenizer=mock_tokenizer,
                max_batch_size=4,
                device="cpu",
            )

            results = {"stats": [], "errors": []}

            def add_tasks():
                try:
                    for i in range(20):
                        scheduler.add_task(f"prompt {i}")
                        time.sleep(0.001)
                except Exception as exc:
                    results["errors"].append(f"Add: {str(exc)}")

            def get_stats():
                try:
                    for _ in range(50):
                        snapshot = scheduler.get_stats()
                        results["stats"].append(snapshot)
                        time.sleep(0.001)
                except Exception as exc:
                    results["errors"].append(f"Get stats: {str(exc)}")

            producer = threading.Thread(target=add_tasks)
            observer = threading.Thread(target=get_stats)

            producer.start()
            observer.start()

            time.sleep(0.3)
            scheduler.stop()

            producer.join()
            observer.join()

            assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
            assert len(results["stats"]) == 50

            # Every snapshot must carry a sane, non-negative task count.
            for snapshot in results["stats"]:
                assert "total_tasks" in snapshot
                assert snapshot["total_tasks"] >= 0
|
||||
|
||||
|
||||
def test_prefix_cache_insert_same_prefix_concurrently():
    """Ten threads inserting the identical prefix must all succeed.

    Afterwards the node's ref_count must equal the number of inserts.
    """
    cache = PrefixCacheManager(max_capacity=100)

    results = {"slot_values": [], "errors": []}

    def insert_worker():
        try:
            # All workers try to insert the same prefix
            cache.insert((1, 2, 3), slot=threading.current_thread().name)
            # Walk the trie down to the inserted leaf, if visible yet.
            node = cache.root
            for token in (1, 2, 3):
                node = node.children.get(token)
                if not node:
                    break
            else:
                results["slot_values"].append(node.slot)
        except Exception as exc:
            results["errors"].append(str(exc))

    pool = [threading.Thread(target=insert_worker) for _ in range(10)]

    for thread in pool:
        thread.start()
    for thread in pool:
        thread.join()

    # All inserts should succeed, final slot should be one of the values
    assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
    # Check ref_count is correct (should be 10)
    node = cache.root.children.get(1).children.get(2).children.get(3)
    assert node.ref_count == 10, f"Expected ref_count=10, got {node.ref_count}"
|
||||
|
||||
|
||||
def test_prefix_cache_ref_count_underflow_prevention():
    """Releasing a prefix more often than it was inserted must not underflow."""
    cache = PrefixCacheManager(max_capacity=100)

    # Insert a single prefix exactly once...
    cache.insert((1, 2, 3), slot=0)

    # ...then release it five times — four releases too many.
    for _ in range(5):
        cache.release((1, 2, 3))

    # The surplus releases must be absorbed gracefully: the node's ref_count
    # is expected to clamp at zero rather than going negative.
    node = cache.root.children.get(1).children.get(2).children.get(3)
    assert node.ref_count >= 0, f"ref_count went negative: {node.ref_count}"
|
||||
Loading…
Reference in New Issue