feat: 增加缓存处理
This commit is contained in:
parent
b0eff02446
commit
ab5e207f42
|
|
@ -112,6 +112,8 @@ class InferenceEngine:
|
||||||
tokenizer: AutoTokenizer,
|
tokenizer: AutoTokenizer,
|
||||||
max_batch_size: int = 1,
|
max_batch_size: int = 1,
|
||||||
max_seq_len: Optional[int] = None,
|
max_seq_len: Optional[int] = None,
|
||||||
|
max_prefix_len: int = 512,
|
||||||
|
cache_capacity: int = 1000,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize inference engine with separate model and tokenizer.
|
Initialize inference engine with separate model and tokenizer.
|
||||||
|
|
@ -122,6 +124,8 @@ class InferenceEngine:
|
||||||
config: Model configuration
|
config: Model configuration
|
||||||
max_batch_size: Maximum batch size for continuous batching
|
max_batch_size: Maximum batch size for continuous batching
|
||||||
max_seq_len: Maximum sequence length (defaults to config.max_len)
|
max_seq_len: Maximum sequence length (defaults to config.max_len)
|
||||||
|
max_prefix_len: Maximum prefix length for cache (default: 512)
|
||||||
|
cache_capacity: Maximum number of cached prefixes (default: 1000)
|
||||||
"""
|
"""
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
@ -141,6 +145,8 @@ class InferenceEngine:
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
max_batch_size=max_batch_size,
|
max_batch_size=max_batch_size,
|
||||||
max_seq_len=max_seq_len,
|
max_seq_len=max_seq_len,
|
||||||
|
max_prefix_len=max_prefix_len,
|
||||||
|
cache_capacity=cache_capacity,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
@ -12,6 +12,120 @@ from astrai.model.automodel import AutoModel
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class RadixNode:
|
||||||
|
"""Radix tree node for prefix cache."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.children: Dict[int, "RadixNode"] = {} # token_id -> child node
|
||||||
|
self.hash: Optional[int] = None # 64-bit hash of the prefix
|
||||||
|
self.slot: int = -1 # KV Cache slot, valid only for leaf nodes
|
||||||
|
self.ref_count: int = 0 # number of tasks referencing this prefix
|
||||||
|
self.last_access: float = 0.0 # timestamp for LRU
|
||||||
|
self.token_sequence: list = [] # full token sequence from root to this node
|
||||||
|
|
||||||
|
|
||||||
|
class PrefixCacheManager:
|
||||||
|
"""Prefix cache manager using Radix tree with LRU eviction."""
|
||||||
|
|
||||||
|
def __init__(self, max_capacity: int = 1000, base: int = 131, mod: int = 10**9 + 7):
|
||||||
|
self.root = RadixNode()
|
||||||
|
self.base = base
|
||||||
|
self.mod = mod
|
||||||
|
self.max_capacity = max_capacity
|
||||||
|
self.lru: List[Tuple[float, RadixNode]] = [] # (timestamp, node) for LRU
|
||||||
|
|
||||||
|
def insert(self, token_ids: Tuple[int, ...], slot: int) -> None:
|
||||||
|
"""Insert a prefix, increase ref_count if already exists, otherwise create new node."""
|
||||||
|
node = self.root
|
||||||
|
path = []
|
||||||
|
h = 0
|
||||||
|
for i, token_id in enumerate(token_ids):
|
||||||
|
if token_id not in node.children:
|
||||||
|
node.children[token_id] = RadixNode()
|
||||||
|
node = node.children[token_id]
|
||||||
|
h = (h * self.base + token_id) % self.mod
|
||||||
|
node.hash = h
|
||||||
|
path.append(token_id)
|
||||||
|
node.token_sequence = list(
|
||||||
|
path
|
||||||
|
) # store full sequence for exact verification
|
||||||
|
|
||||||
|
# Leaf node: set slot and increase ref_count
|
||||||
|
if node.slot == -1:
|
||||||
|
node.slot = slot
|
||||||
|
node.ref_count += 1
|
||||||
|
node.last_access = time.time()
|
||||||
|
self._update_lru(node)
|
||||||
|
self._evict_if_needed()
|
||||||
|
|
||||||
|
def find_longest_prefix(self, token_ids: List[int]) -> Optional[Tuple[int, int]]:
|
||||||
|
"""Find longest matching prefix, return (prefix_len, slot).
|
||||||
|
|
||||||
|
During traversal, compute hash per token and compare with node hash.
|
||||||
|
If hash matches, perform full token sequence verification to avoid
|
||||||
|
hash collision errors.
|
||||||
|
"""
|
||||||
|
node = self.root
|
||||||
|
best_len = 0
|
||||||
|
best_slot = -1
|
||||||
|
h = 0
|
||||||
|
|
||||||
|
for i, token_id in enumerate(token_ids):
|
||||||
|
if token_id not in node.children:
|
||||||
|
break
|
||||||
|
node = node.children[token_id]
|
||||||
|
h = (h * self.base + token_id) % self.mod
|
||||||
|
if node.hash == h: # hash matches
|
||||||
|
# Exact verification: compare full token sequence
|
||||||
|
if node.token_sequence == token_ids[: i + 1]:
|
||||||
|
best_len = i + 1
|
||||||
|
best_slot = node.slot
|
||||||
|
node.last_access = time.time()
|
||||||
|
self._update_lru(node)
|
||||||
|
|
||||||
|
if best_len > 0:
|
||||||
|
return (best_len, best_slot)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def release(self, token_ids: Tuple[int, ...]) -> None:
|
||||||
|
"""Release reference to a prefix, decrease ref_count. If zero, mark as evictable."""
|
||||||
|
node = self.root
|
||||||
|
for token_id in token_ids:
|
||||||
|
if token_id not in node.children:
|
||||||
|
return
|
||||||
|
node = node.children[token_id]
|
||||||
|
if node.ref_count > 0:
|
||||||
|
node.ref_count -= 1
|
||||||
|
if node.ref_count == 0:
|
||||||
|
node.slot = -1 # slot can be reused
|
||||||
|
|
||||||
|
def _update_lru(self, node: RadixNode) -> None:
|
||||||
|
"""Update LRU list, move node to most recently used position."""
|
||||||
|
self.lru = [(ts, n) for (ts, n) in self.lru if n is not node]
|
||||||
|
self.lru.append((node.last_access, node))
|
||||||
|
|
||||||
|
def _evict_if_needed(self) -> None:
|
||||||
|
"""If cache entries exceed capacity, evict least recently used leaf nodes (ref_count must be 0)."""
|
||||||
|
if len(self.lru) <= self.max_capacity:
|
||||||
|
return
|
||||||
|
# Sort by timestamp
|
||||||
|
self.lru.sort(key=lambda x: x[0])
|
||||||
|
for ts, node in self.lru:
|
||||||
|
if node.ref_count == 0:
|
||||||
|
# Remove leaf node from tree (need to recursively delete empty branches)
|
||||||
|
self._remove_node(node)
|
||||||
|
self.lru.remove((ts, node))
|
||||||
|
if len(self.lru) <= self.max_capacity:
|
||||||
|
break
|
||||||
|
|
||||||
|
def _remove_node(self, node: RadixNode) -> None:
|
||||||
|
"""Remove node from tree (simplified implementation)."""
|
||||||
|
# Clear the node's leaf properties
|
||||||
|
node.slot = -1
|
||||||
|
node.hash = None
|
||||||
|
node.token_sequence = []
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus:
|
class TaskStatus:
|
||||||
"""Task state for continuous batching."""
|
"""Task state for continuous batching."""
|
||||||
|
|
||||||
|
|
@ -46,6 +160,7 @@ class Task:
|
||||||
self.input_tokens: int = 0
|
self.input_tokens: int = 0
|
||||||
self.output_tokens: int = 0
|
self.output_tokens: int = 0
|
||||||
self.slot: int = -1
|
self.slot: int = -1
|
||||||
|
self.prefix_len: int = 0 # prefix cache matched length
|
||||||
self.arrival_time = time.time()
|
self.arrival_time = time.time()
|
||||||
self.finish_time: Optional[float] = None
|
self.finish_time: Optional[float] = None
|
||||||
|
|
||||||
|
|
@ -104,6 +219,8 @@ class InferenceScheduler:
|
||||||
tokenizer: AutoTokenizer,
|
tokenizer: AutoTokenizer,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
max_seq_len: Optional[int] = None,
|
max_seq_len: Optional[int] = None,
|
||||||
|
max_prefix_len: int = 512,
|
||||||
|
cache_capacity: int = 1000,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
):
|
):
|
||||||
|
|
@ -113,9 +230,13 @@ class InferenceScheduler:
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.max_batch_size = max_batch_size
|
self.max_batch_size = max_batch_size
|
||||||
self.max_seq_len = max_seq_len or config.max_len
|
self.max_seq_len = max_seq_len or config.max_len
|
||||||
|
self.max_prefix_len = max_prefix_len
|
||||||
self.device = device or next(model.parameters()).device
|
self.device = device or next(model.parameters()).device
|
||||||
self.dtype = dtype or next(model.parameters()).dtype
|
self.dtype = dtype or next(model.parameters()).dtype
|
||||||
|
|
||||||
|
# Initialize prefix cache
|
||||||
|
self.prefix_cache = PrefixCacheManager(max_capacity=cache_capacity)
|
||||||
|
|
||||||
num_kv_heads = config.n_kv_heads
|
num_kv_heads = config.n_kv_heads
|
||||||
head_dim = config.dim // config.n_heads
|
head_dim = config.dim // config.n_heads
|
||||||
n_layers = config.n_layers
|
n_layers = config.n_layers
|
||||||
|
|
@ -170,6 +291,10 @@ class InferenceScheduler:
|
||||||
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||||
prompt_ids = self.tokenizer.encode(prompt)
|
prompt_ids = self.tokenizer.encode(prompt)
|
||||||
|
|
||||||
|
# Truncate if exceeds max_prefix_len
|
||||||
|
if len(prompt_ids) > self.max_prefix_len:
|
||||||
|
prompt_ids = prompt_ids[: self.max_prefix_len]
|
||||||
|
|
||||||
task = Task(
|
task = Task(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
prompt_ids=prompt_ids,
|
prompt_ids=prompt_ids,
|
||||||
|
|
@ -180,6 +305,16 @@ class InferenceScheduler:
|
||||||
stream_callback=stream_callback,
|
stream_callback=stream_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Find longest matching prefix from cache
|
||||||
|
match = self.prefix_cache.find_longest_prefix(prompt_ids)
|
||||||
|
if match:
|
||||||
|
prefix_len, slot = match
|
||||||
|
task.prefix_len = prefix_len
|
||||||
|
task.slot = slot
|
||||||
|
else:
|
||||||
|
task.prefix_len = 0
|
||||||
|
task.slot = -1
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self.waiting_queue.append(task)
|
self.waiting_queue.append(task)
|
||||||
self._total_tasks += 1
|
self._total_tasks += 1
|
||||||
|
|
@ -207,6 +342,11 @@ class InferenceScheduler:
|
||||||
slot = task.slot
|
slot = task.slot
|
||||||
if slot >= 0 and slot < len(self.active_tasks):
|
if slot >= 0 and slot < len(self.active_tasks):
|
||||||
self.seq_mask[slot, :] = False
|
self.seq_mask[slot, :] = False
|
||||||
|
|
||||||
|
# Release prefix cache reference
|
||||||
|
if task.prefix_len > 0:
|
||||||
|
self.prefix_cache.release(tuple(task.prompt_ids[: task.prefix_len]))
|
||||||
|
|
||||||
task.slot = -1
|
task.slot = -1
|
||||||
|
|
||||||
self.active_tasks = [
|
self.active_tasks = [
|
||||||
|
|
@ -235,7 +375,46 @@ class InferenceScheduler:
|
||||||
self.active_tasks.append(task)
|
self.active_tasks.append(task)
|
||||||
|
|
||||||
def _execute_prefill(self, tasks: List[Task]) -> None:
|
def _execute_prefill(self, tasks: List[Task]) -> None:
|
||||||
"""Execute Prefill phase."""
|
"""Execute Prefill phase with incremental prefill support."""
|
||||||
|
if not tasks:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Group tasks by their prefix_len to handle different prefill scenarios
|
||||||
|
fully_cached_tasks = [] # prefix_len == total_len, skip prefill
|
||||||
|
partial_prefill_tasks = [] # prefix_len > 0, need incremental prefill
|
||||||
|
full_prefill_tasks = [] # prefix_len == 0, full prefill
|
||||||
|
|
||||||
|
for task in tasks:
|
||||||
|
total_len = len(task.prompt_ids)
|
||||||
|
prefix_len = task.prefix_len
|
||||||
|
|
||||||
|
if prefix_len == total_len:
|
||||||
|
# Scenario 1: complete match, skip prefill
|
||||||
|
task.input_tokens = total_len
|
||||||
|
task.output_tokens = 0
|
||||||
|
fully_cached_tasks.append(task)
|
||||||
|
elif prefix_len > 0:
|
||||||
|
# Scenario 2: partial match, incremental prefill
|
||||||
|
partial_prefill_tasks.append(task)
|
||||||
|
else:
|
||||||
|
# Scenario 3: no match, full prefill
|
||||||
|
full_prefill_tasks.append(task)
|
||||||
|
|
||||||
|
# Handle fully cached tasks - update seq_mask
|
||||||
|
for task in fully_cached_tasks:
|
||||||
|
if task.slot >= 0:
|
||||||
|
self.seq_mask[task.slot, : task.input_tokens] = True
|
||||||
|
|
||||||
|
# Execute full prefill for new prefixes
|
||||||
|
if full_prefill_tasks:
|
||||||
|
self._execute_full_prefill(full_prefill_tasks)
|
||||||
|
|
||||||
|
# Execute incremental prefill for partial matches
|
||||||
|
if partial_prefill_tasks:
|
||||||
|
self._execute_partial_prefill(partial_prefill_tasks)
|
||||||
|
|
||||||
|
def _execute_full_prefill(self, tasks: List[Task]) -> None:
|
||||||
|
"""Execute full prefill for tasks without prefix cache."""
|
||||||
if not tasks:
|
if not tasks:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -271,11 +450,59 @@ class InferenceScheduler:
|
||||||
for i, task in enumerate(tasks):
|
for i, task in enumerate(tasks):
|
||||||
task.input_tokens = prompt_lens[i]
|
task.input_tokens = prompt_lens[i]
|
||||||
task.output_tokens = 0
|
task.output_tokens = 0
|
||||||
|
# Insert new prefix into cache
|
||||||
|
self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
|
||||||
|
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
if task.slot >= 0:
|
if task.slot >= 0:
|
||||||
self.seq_mask[task.slot, : task.input_tokens] = True
|
self.seq_mask[task.slot, : task.input_tokens] = True
|
||||||
|
|
||||||
|
def _execute_partial_prefill(self, tasks: List[Task]) -> None:
|
||||||
|
"""Execute incremental prefill for tasks with partial prefix cache match."""
|
||||||
|
for task in tasks:
|
||||||
|
total_len = len(task.prompt_ids)
|
||||||
|
prefix_len = task.prefix_len
|
||||||
|
|
||||||
|
if prefix_len >= total_len:
|
||||||
|
task.input_tokens = total_len
|
||||||
|
task.output_tokens = 0
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get new tokens that need prefill
|
||||||
|
new_ids = task.prompt_ids[prefix_len:]
|
||||||
|
new_len = len(new_ids)
|
||||||
|
|
||||||
|
if new_len == 0:
|
||||||
|
task.input_tokens = total_len
|
||||||
|
task.output_tokens = 0
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Build input for incremental prefill
|
||||||
|
input_ids = torch.tensor([new_ids], dtype=torch.long, device=self.device)
|
||||||
|
|
||||||
|
# Input mask should cover from position 0 to prefix_len + new_len
|
||||||
|
# The prefix part uses cached KV, new part needs computation
|
||||||
|
input_mask = torch.ones(
|
||||||
|
(1, prefix_len + new_len), dtype=torch.bool, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
self.model(
|
||||||
|
input_ids,
|
||||||
|
input_mask=input_mask,
|
||||||
|
start_pos=prefix_len,
|
||||||
|
persistent_key_values=self.kv_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
task.input_tokens = total_len
|
||||||
|
task.output_tokens = 0
|
||||||
|
|
||||||
|
# Insert full prefix into cache (ref_count already increased in add_task)
|
||||||
|
self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
|
||||||
|
|
||||||
|
if task.slot >= 0:
|
||||||
|
self.seq_mask[task.slot, : task.input_tokens] = True
|
||||||
|
|
||||||
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
||||||
"""Execute Decode phase."""
|
"""Execute Decode phase."""
|
||||||
if not tasks:
|
if not tasks:
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,7 @@ import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
|
from tokenizers import Tokenizer
|
||||||
from tokenizers.models import BPE
|
|
||||||
|
|
||||||
from astrai.tokenize.chat_template import ChatTemplate
|
from astrai.tokenize.chat_template import ChatTemplate
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue