From ab5e207f42ae54bec0bc8f370b20a352ec6fbbb6 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 8 Apr 2026 20:54:14 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E7=BC=93=E5=AD=98?= =?UTF-8?q?=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/inference/engine.py | 6 + astrai/inference/scheduler.py | 231 +++++++++++++++++++++++++++++++++- astrai/tokenize/tokenizer.py | 3 +- 3 files changed, 236 insertions(+), 4 deletions(-) diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 8e0a2c1..6177d75 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -112,6 +112,8 @@ class InferenceEngine: tokenizer: AutoTokenizer, max_batch_size: int = 1, max_seq_len: Optional[int] = None, + max_prefix_len: int = 512, + cache_capacity: int = 1000, ): """ Initialize inference engine with separate model and tokenizer. @@ -122,6 +124,8 @@ class InferenceEngine: config: Model configuration max_batch_size: Maximum batch size for continuous batching max_seq_len: Maximum sequence length (defaults to config.max_len) + max_prefix_len: Maximum prefix length for cache (default: 512) + cache_capacity: Maximum number of cached prefixes (default: 1000) """ self.model = model self.tokenizer = tokenizer @@ -141,6 +145,8 @@ class InferenceEngine: tokenizer=self.tokenizer, max_batch_size=max_batch_size, max_seq_len=max_seq_len, + max_prefix_len=max_prefix_len, + cache_capacity=cache_capacity, device=device, dtype=dtype, ) diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 6d15334..24f4804 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -3,7 +3,7 @@ import threading import time import uuid -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from torch import Tensor @@ -12,6 +12,120 @@ from astrai.model.automodel import AutoModel from astrai.tokenize import AutoTokenizer +class RadixNode: + """Radix tree node for prefix cache.""" + + def __init__(self): + self.children: Dict[int, "RadixNode"] = {} # token_id -> child node + self.hash: Optional[int] = None # 64-bit hash of the prefix + self.slot: int = -1 # KV Cache slot, valid only for leaf nodes + self.ref_count: int = 0 # number of tasks referencing this prefix + self.last_access: float = 0.0 # timestamp for LRU + self.token_sequence: list = [] # full token sequence from root to this node + + +class PrefixCacheManager: + """Prefix cache manager using Radix tree with LRU eviction.""" + + def __init__(self, max_capacity: int = 1000, base: int = 131, mod: int = 10**9 + 7): + self.root = RadixNode() + self.base = base + self.mod = mod + self.max_capacity = max_capacity + self.lru: List[Tuple[float, RadixNode]] = [] # (timestamp, node) for LRU + + def insert(self, token_ids: Tuple[int, ...], slot: int) -> None: + """Insert a prefix, increase ref_count if already exists, otherwise create new node.""" + node = self.root + path = [] + h = 0 + for i, token_id in enumerate(token_ids): + if token_id not in node.children: + node.children[token_id] = RadixNode() + node = node.children[token_id] + h = (h * self.base + token_id) % self.mod + node.hash = h + path.append(token_id) + node.token_sequence = list( + path + ) # store full sequence for exact verification + + # Leaf node: set slot and increase ref_count + if node.slot == -1: + node.slot = slot + node.ref_count += 1 + node.last_access = time.time() + self._update_lru(node) + self._evict_if_needed() + + def find_longest_prefix(self, token_ids: List[int]) -> Optional[Tuple[int, int]]: + """Find longest matching prefix, return (prefix_len, slot). + + During traversal, compute hash per token and compare with node hash. + If hash matches, perform full token sequence verification to avoid + hash collision errors. + """ + node = self.root + best_len = 0 + best_slot = -1 + h = 0 + + for i, token_id in enumerate(token_ids): + if token_id not in node.children: + break + node = node.children[token_id] + h = (h * self.base + token_id) % self.mod + if node.hash == h: # hash matches + # Exact verification: compare full token sequence + if node.token_sequence == token_ids[: i + 1]: + best_len = i + 1 + best_slot = node.slot + node.last_access = time.time() + self._update_lru(node) + + if best_len > 0: + return (best_len, best_slot) + return None + + def release(self, token_ids: Tuple[int, ...]) -> None: + """Release reference to a prefix, decrease ref_count. If zero, mark as evictable.""" + node = self.root + for token_id in token_ids: + if token_id not in node.children: + return + node = node.children[token_id] + if node.ref_count > 0: + node.ref_count -= 1 + if node.ref_count == 0: + node.slot = -1 # slot can be reused + + def _update_lru(self, node: RadixNode) -> None: + """Update LRU list, move node to most recently used position.""" + self.lru = [(ts, n) for (ts, n) in self.lru if n is not node] + self.lru.append((node.last_access, node)) + + def _evict_if_needed(self) -> None: + """If cache entries exceed capacity, evict least recently used leaf nodes (ref_count must be 0).""" + if len(self.lru) <= self.max_capacity: + return + # Sort by timestamp + self.lru.sort(key=lambda x: x[0]) + for ts, node in self.lru: + if node.ref_count == 0: + # Remove leaf node from tree (need to recursively delete empty branches) + self._remove_node(node) + self.lru.remove((ts, node)) + if len(self.lru) <= self.max_capacity: + break + + def _remove_node(self, node: RadixNode) -> None: + """Remove node from tree (simplified implementation).""" + # Clear the node's leaf properties + node.slot = -1 + node.hash = None + node.token_sequence = [] + + class TaskStatus: """Task state for continuous batching.""" @@ -46,6 +160,7 @@ class Task: self.input_tokens: int = 0 self.output_tokens: int = 0 self.slot: int = -1 + self.prefix_len: int = 0 # prefix cache matched length self.arrival_time = time.time() self.finish_time: Optional[float] = None @@ -104,6 +219,8 @@ class InferenceScheduler: tokenizer: AutoTokenizer, max_batch_size: int = 16, max_seq_len: Optional[int] = None, + max_prefix_len: int = 512, + cache_capacity: int = 1000, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, ): @@ -113,9 +230,13 @@ class InferenceScheduler: self.tokenizer = tokenizer self.max_batch_size = max_batch_size self.max_seq_len = max_seq_len or config.max_len + self.max_prefix_len = max_prefix_len self.device = device or next(model.parameters()).device self.dtype = dtype or next(model.parameters()).dtype + # Initialize prefix cache + self.prefix_cache = PrefixCacheManager(max_capacity=cache_capacity) + num_kv_heads = config.n_kv_heads head_dim = config.dim // config.n_heads n_layers = config.n_layers @@ -170,6 +291,10 @@ class InferenceScheduler: task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}" prompt_ids = self.tokenizer.encode(prompt) + # Truncate if exceeds max_prefix_len + if len(prompt_ids) > self.max_prefix_len: + prompt_ids = prompt_ids[: self.max_prefix_len] + task = Task( task_id=task_id, prompt_ids=prompt_ids, @@ -180,6 +305,16 @@ class InferenceScheduler: stream_callback=stream_callback, ) + # Find longest matching prefix from cache + match = self.prefix_cache.find_longest_prefix(prompt_ids) + if match: + prefix_len, slot = match + task.prefix_len = prefix_len + task.slot = slot + else: + task.prefix_len = 0 + task.slot = -1 + with self._lock: self.waiting_queue.append(task) self._total_tasks += 1 @@ -207,6 +342,11 @@ class InferenceScheduler: slot = task.slot if slot >= 0 and slot < len(self.active_tasks): self.seq_mask[slot, :] = False + + # Release prefix cache reference + if task.prefix_len > 0: + self.prefix_cache.release(tuple(task.prompt_ids[: task.prefix_len])) + task.slot = -1 self.active_tasks = [ @@ -235,7 +375,46 @@ class InferenceScheduler: self.active_tasks.append(task) def _execute_prefill(self, tasks: List[Task]) -> None: - """Execute Prefill phase.""" + """Execute Prefill phase with incremental prefill support.""" + if not tasks: + return + + # Group tasks by their prefix_len to handle different prefill scenarios + fully_cached_tasks = [] # prefix_len == total_len, skip prefill + partial_prefill_tasks = [] # prefix_len > 0, need incremental prefill + full_prefill_tasks = [] # prefix_len == 0, full prefill + + for task in tasks: + total_len = len(task.prompt_ids) + prefix_len = task.prefix_len + + if prefix_len == total_len: + # Scenario 1: complete match, skip prefill + task.input_tokens = total_len + task.output_tokens = 0 + fully_cached_tasks.append(task) + elif prefix_len > 0: + # Scenario 2: partial match, incremental prefill + partial_prefill_tasks.append(task) + else: + # Scenario 3: no match, full prefill + full_prefill_tasks.append(task) + + # Handle fully cached tasks - update seq_mask + for task in fully_cached_tasks: + if task.slot >= 0: + self.seq_mask[task.slot, : task.input_tokens] = True + + # Execute full prefill for new prefixes + if full_prefill_tasks: + self._execute_full_prefill(full_prefill_tasks) + + # Execute incremental prefill for partial matches + if partial_prefill_tasks: + self._execute_partial_prefill(partial_prefill_tasks) + + def _execute_full_prefill(self, tasks: List[Task]) -> None: + """Execute full prefill for tasks without prefix cache.""" if not tasks: return @@ -271,11 +450,59 @@ class InferenceScheduler: for i, task in enumerate(tasks): task.input_tokens = prompt_lens[i] task.output_tokens = 0 + # Insert new prefix into cache + self.prefix_cache.insert(tuple(task.prompt_ids), task.slot) for task in tasks: if task.slot >= 0: self.seq_mask[task.slot, : task.input_tokens] = True + def _execute_partial_prefill(self, tasks: List[Task]) -> None: + """Execute incremental prefill for tasks with partial prefix cache match.""" + for task in tasks: + total_len = len(task.prompt_ids) + prefix_len = task.prefix_len + + if prefix_len >= total_len: + task.input_tokens = total_len + task.output_tokens = 0 + continue + + # Get new tokens that need prefill + new_ids = task.prompt_ids[prefix_len:] + new_len = len(new_ids) + + if new_len == 0: + task.input_tokens = total_len + task.output_tokens = 0 + continue + + # Build input for incremental prefill + input_ids = torch.tensor([new_ids], dtype=torch.long, device=self.device) + + # Input mask should cover from position 0 to prefix_len + new_len + # The prefix part uses cached KV, new part needs computation + input_mask = torch.ones( + (1, prefix_len + new_len), dtype=torch.bool, device=self.device + ) + + with torch.inference_mode(): + self.model( + input_ids, + input_mask=input_mask, + start_pos=prefix_len, + persistent_key_values=self.kv_cache, + ) + + task.input_tokens = total_len + task.output_tokens = 0 + + # Insert full prefix into cache (ref_count already increased in add_task) + self.prefix_cache.insert(tuple(task.prompt_ids), task.slot) + + if task.slot >= 0: + self.seq_mask[task.slot, : task.input_tokens] = True + def _execute_decode(self, tasks: List[Task], start_pos: int) -> None: """Execute Decode phase.""" if not tasks: diff --git a/astrai/tokenize/tokenizer.py b/astrai/tokenize/tokenizer.py index c3e02ef..b40ca0f 100644 --- a/astrai/tokenize/tokenizer.py +++ b/astrai/tokenize/tokenizer.py @@ -6,8 +6,7 @@ import json from pathlib import Path from typing import Dict, List, Optional, Union -from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors -from tokenizers.models import BPE +from tokenizers import Tokenizer from astrai.tokenize.chat_template import ChatTemplate