feat: 增加缓存处理

This commit is contained in:
ViperEkura 2026-04-08 20:54:14 +08:00
parent b0eff02446
commit ab5e207f42
3 changed files with 236 additions and 4 deletions

View File

@ -112,6 +112,8 @@ class InferenceEngine:
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
max_batch_size: int = 1, max_batch_size: int = 1,
max_seq_len: Optional[int] = None, max_seq_len: Optional[int] = None,
max_prefix_len: int = 512,
cache_capacity: int = 1000,
): ):
""" """
Initialize inference engine with separate model and tokenizer. Initialize inference engine with separate model and tokenizer.
@ -122,6 +124,8 @@ class InferenceEngine:
config: Model configuration config: Model configuration
max_batch_size: Maximum batch size for continuous batching max_batch_size: Maximum batch size for continuous batching
max_seq_len: Maximum sequence length (defaults to config.max_len) max_seq_len: Maximum sequence length (defaults to config.max_len)
max_prefix_len: Maximum prefix length for cache (default: 512)
cache_capacity: Maximum number of cached prefixes (default: 1000)
""" """
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
@ -141,6 +145,8 @@ class InferenceEngine:
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
max_prefix_len=max_prefix_len,
cache_capacity=cache_capacity,
device=device, device=device,
dtype=dtype, dtype=dtype,
) )

View File

@ -3,7 +3,7 @@
import threading import threading
import time import time
import uuid import uuid
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional, Tuple
import torch import torch
from torch import Tensor from torch import Tensor
@ -12,6 +12,120 @@ from astrai.model.automodel import AutoModel
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
class RadixNode:
    """A single node of the prefix-cache radix tree."""

    def __init__(self):
        # Full token sequence from the root down to this node; kept so a
        # hash match can be verified exactly (guards against collisions).
        self.token_sequence: list = []
        # Last-touched timestamp, consumed by the LRU eviction policy.
        self.last_access: float = 0.0
        # Number of in-flight tasks currently pinning this prefix.
        self.ref_count: int = 0
        # KV-cache slot bound to this prefix; -1 means "no slot bound".
        self.slot: int = -1
        # Rolling polynomial hash of the prefix ending at this node.
        self.hash: Optional[int] = None
        # Outgoing edges, keyed by the next token id along the prefix.
        self.children: Dict[int, "RadixNode"] = {}
class PrefixCacheManager:
    """Prefix cache manager using a radix tree with LRU eviction.

    Cached prefixes are token-id sequences, each bound to a KV-cache slot.
    A rolling polynomial hash stored on every node acts as a fast filter;
    an exact token-sequence comparison rules out hash collisions.
    """

    def __init__(self, max_capacity: int = 1000, base: int = 131, mod: int = 10**9 + 7):
        """
        Args:
            max_capacity: Maximum number of tracked LRU entries before eviction.
            base: Base of the rolling polynomial hash.
            mod: Modulus of the rolling polynomial hash.
        """
        self.root = RadixNode()
        self.base = base
        self.mod = mod
        self.max_capacity = max_capacity
        # (last_access timestamp, node) pairs; most recently used at the end.
        self.lru: List[Tuple[float, RadixNode]] = []

    def insert(self, token_ids: Tuple[int, ...], slot: int) -> None:
        """Insert a prefix; bump ref_count if it exists, create nodes otherwise."""
        node = self.root
        path: List[int] = []
        h = 0
        for token_id in token_ids:
            if token_id not in node.children:
                node.children[token_id] = RadixNode()
            node = node.children[token_id]
            h = (h * self.base + token_id) % self.mod
            node.hash = h
            path.append(token_id)
            # Full sequence stored for exact verification on lookup.
            node.token_sequence = list(path)
        # Terminal node: bind the KV-cache slot unless one is already bound.
        if node.slot == -1:
            node.slot = slot
        node.ref_count += 1
        node.last_access = time.time()
        self._update_lru(node)
        self._evict_if_needed()

    def find_longest_prefix(self, token_ids: List[int]) -> Optional[Tuple[int, int]]:
        """Find the longest cached prefix; return (prefix_len, slot) or None.

        The per-token rolling hash is compared with each node's stored hash;
        on a hash match the full token sequence is verified to rule out
        collisions.  Only nodes that actually hold a KV-cache slot
        (slot != -1) are eligible — bugfix: the original could return a
        slot of -1 from an intermediate node, which callers would treat
        as a valid cache hit.
        """
        node = self.root
        best_len = 0
        best_slot = -1
        best_node: Optional[RadixNode] = None
        h = 0
        for i, token_id in enumerate(token_ids):
            child = node.children.get(token_id)
            if child is None:
                break
            node = child
            h = (h * self.base + token_id) % self.mod
            # Hash filter first, then exact comparison.
            if (
                node.hash == h
                and node.slot != -1
                and node.token_sequence == token_ids[: i + 1]
            ):
                best_len = i + 1
                best_slot = node.slot
                best_node = node
        if best_node is None:
            return None
        # Refresh LRU for the node actually matched (not merely traversed).
        best_node.last_access = time.time()
        self._update_lru(best_node)
        return (best_len, best_slot)

    def release(self, token_ids: Tuple[int, ...]) -> None:
        """Drop one reference to a prefix; at zero refs the slot is reusable."""
        node = self.root
        for token_id in token_ids:
            if token_id not in node.children:
                return  # Unknown prefix: nothing to release.
            node = node.children[token_id]
        if node.ref_count > 0:
            node.ref_count -= 1
            if node.ref_count == 0:
                node.slot = -1  # slot can be reused

    def _update_lru(self, node: RadixNode) -> None:
        """Move node to the most-recently-used end of the LRU list."""
        self.lru = [(ts, n) for (ts, n) in self.lru if n is not node]
        self.lru.append((node.last_access, node))

    def _evict_if_needed(self) -> None:
        """Evict least-recently-used unreferenced entries down to capacity.

        Bugfix: the original removed items from ``self.lru`` while iterating
        it, which skips entries; this rebuilds the list instead.
        """
        if len(self.lru) <= self.max_capacity:
            return
        self.lru.sort(key=lambda entry: entry[0])  # oldest first
        kept: List[Tuple[float, RadixNode]] = []
        evicted = 0
        for ts, node in self.lru:
            over_capacity = len(self.lru) - evicted > self.max_capacity
            if over_capacity and node.ref_count == 0:
                self._remove_node(node)
                evicted += 1
            else:
                kept.append((ts, node))
        self.lru = kept

    def _remove_node(self, node: RadixNode) -> None:
        """Detach a node's cache payload (simplified: tree branch not pruned)."""
        node.slot = -1
        node.hash = None
        node.token_sequence = []
class TaskStatus: class TaskStatus:
"""Task state for continuous batching.""" """Task state for continuous batching."""
@ -46,6 +160,7 @@ class Task:
self.input_tokens: int = 0 self.input_tokens: int = 0
self.output_tokens: int = 0 self.output_tokens: int = 0
self.slot: int = -1 self.slot: int = -1
self.prefix_len: int = 0 # prefix cache matched length
self.arrival_time = time.time() self.arrival_time = time.time()
self.finish_time: Optional[float] = None self.finish_time: Optional[float] = None
@ -104,6 +219,8 @@ class InferenceScheduler:
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
max_batch_size: int = 16, max_batch_size: int = 16,
max_seq_len: Optional[int] = None, max_seq_len: Optional[int] = None,
max_prefix_len: int = 512,
cache_capacity: int = 1000,
device: str = "cuda", device: str = "cuda",
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.bfloat16,
): ):
@ -113,9 +230,13 @@ class InferenceScheduler:
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.max_batch_size = max_batch_size self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len self.max_seq_len = max_seq_len or config.max_len
self.max_prefix_len = max_prefix_len
self.device = device or next(model.parameters()).device self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype self.dtype = dtype or next(model.parameters()).dtype
# Initialize prefix cache
self.prefix_cache = PrefixCacheManager(max_capacity=cache_capacity)
num_kv_heads = config.n_kv_heads num_kv_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads head_dim = config.dim // config.n_heads
n_layers = config.n_layers n_layers = config.n_layers
@ -170,6 +291,10 @@ class InferenceScheduler:
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}" task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt) prompt_ids = self.tokenizer.encode(prompt)
# Truncate if exceeds max_prefix_len
if len(prompt_ids) > self.max_prefix_len:
prompt_ids = prompt_ids[: self.max_prefix_len]
task = Task( task = Task(
task_id=task_id, task_id=task_id,
prompt_ids=prompt_ids, prompt_ids=prompt_ids,
@ -180,6 +305,16 @@ class InferenceScheduler:
stream_callback=stream_callback, stream_callback=stream_callback,
) )
# Find longest matching prefix from cache
match = self.prefix_cache.find_longest_prefix(prompt_ids)
if match:
prefix_len, slot = match
task.prefix_len = prefix_len
task.slot = slot
else:
task.prefix_len = 0
task.slot = -1
with self._lock: with self._lock:
self.waiting_queue.append(task) self.waiting_queue.append(task)
self._total_tasks += 1 self._total_tasks += 1
@ -207,6 +342,11 @@ class InferenceScheduler:
slot = task.slot slot = task.slot
if slot >= 0 and slot < len(self.active_tasks): if slot >= 0 and slot < len(self.active_tasks):
self.seq_mask[slot, :] = False self.seq_mask[slot, :] = False
# Release prefix cache reference
if task.prefix_len > 0:
self.prefix_cache.release(tuple(task.prompt_ids[: task.prefix_len]))
task.slot = -1 task.slot = -1
self.active_tasks = [ self.active_tasks = [
@ -235,7 +375,46 @@ class InferenceScheduler:
self.active_tasks.append(task) self.active_tasks.append(task)
def _execute_prefill(self, tasks: List[Task]) -> None:
    """Execute the prefill phase, dispatching by prefix-cache coverage."""
    if not tasks:
        return

    # Partition tasks by how much of each prompt the cache already covers.
    cached: List[Task] = []   # whole prompt cached -> no forward pass needed
    partial: List[Task] = []  # leading part cached -> incremental prefill
    fresh: List[Task] = []    # nothing cached -> full prefill

    for task in tasks:
        prompt_len = len(task.prompt_ids)
        if task.prefix_len == prompt_len:
            # Complete match: prompt KV is already in the cache slot.
            task.input_tokens = prompt_len
            task.output_tokens = 0
            cached.append(task)
        elif task.prefix_len > 0:
            partial.append(task)
        else:
            fresh.append(task)

    # Fully cached prompts only need their sequence mask enabled.
    for task in cached:
        if task.slot >= 0:
            self.seq_mask[task.slot, : task.input_tokens] = True

    if fresh:
        self._execute_full_prefill(fresh)
    if partial:
        self._execute_partial_prefill(partial)
def _execute_full_prefill(self, tasks: List[Task]) -> None:
"""Execute full prefill for tasks without prefix cache."""
if not tasks: if not tasks:
return return
@ -271,11 +450,59 @@ class InferenceScheduler:
for i, task in enumerate(tasks): for i, task in enumerate(tasks):
task.input_tokens = prompt_lens[i] task.input_tokens = prompt_lens[i]
task.output_tokens = 0 task.output_tokens = 0
# Insert new prefix into cache
self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
for task in tasks: for task in tasks:
if task.slot >= 0: if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True self.seq_mask[task.slot, : task.input_tokens] = True
def _execute_partial_prefill(self, tasks: List[Task]) -> None:
    """Execute incremental prefill for tasks with partial prefix cache match.

    Only the tokens beyond the cached prefix are run through the model;
    the cached prefix's KV entries are reused by starting the forward
    pass at ``start_pos=prefix_len``.
    """
    for task in tasks:
        total_len = len(task.prompt_ids)
        prefix_len = task.prefix_len
        # Defensive: prefix already covers the whole prompt -> nothing to do.
        if prefix_len >= total_len:
            task.input_tokens = total_len
            task.output_tokens = 0
            continue
        # Get new tokens that need prefill
        new_ids = task.prompt_ids[prefix_len:]
        new_len = len(new_ids)
        # Unreachable after the check above (prefix_len < total_len implies
        # new_len > 0), but kept as a guard.
        if new_len == 0:
            task.input_tokens = total_len
            task.output_tokens = 0
            continue
        # Build input for incremental prefill
        input_ids = torch.tensor([new_ids], dtype=torch.long, device=self.device)
        # Input mask should cover from position 0 to prefix_len + new_len
        # The prefix part uses cached KV, new part needs computation
        input_mask = torch.ones(
            (1, prefix_len + new_len), dtype=torch.bool, device=self.device
        )
        # NOTE(review): this forward pass runs one task at a time with batch
        # size 1 and never passes task.slot — presumably the model routes KV
        # writes by batch position; verify it targets this task's cache slot.
        with torch.inference_mode():
            self.model(
                input_ids,
                input_mask=input_mask,
                start_pos=prefix_len,
                persistent_key_values=self.kv_cache,
            )
        task.input_tokens = total_len
        task.output_tokens = 0
        # Insert full prefix into cache (ref_count already increased in add_task)
        self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
        if task.slot >= 0:
            self.seq_mask[task.slot, : task.input_tokens] = True
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None: def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
"""Execute Decode phase.""" """Execute Decode phase."""
if not tasks: if not tasks:

View File

@ -6,8 +6,7 @@ import json
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors from tokenizers import Tokenizer
from tokenizers.models import BPE
from astrai.tokenize.chat_template import ChatTemplate from astrai.tokenize.chat_template import ChatTemplate