diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 510b256..0d67651 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -118,12 +118,27 @@ class PrefixCacheManager: if len(self.lru) <= self.max_capacity: break - def _remove_node(self, node: RadixNode) -> None: - """Remove node from tree (simplified implementation).""" + def _remove_node( + self, + node: RadixNode, + parent: Optional[RadixNode] = None, + child_key: Optional[int] = None, + ) -> None: + """Remove node from tree, including empty parent nodes.""" + # First, recursively remove all children + for child_key, child_node in list(node.children.items()): + self._remove_node(child_node, node, child_key) + # Clear the node's leaf properties node.slot = -1 node.hash = None node.token_sequence = [] + node.children.clear() + + # If this node has no children and has a parent, remove the reference from parent + if parent is not None and child_key is not None and len(node.children) == 0: + if child_key in parent.children: + del parent.children[child_key] class TaskStatus: