From 23ce4bc3aecc6ffc5f63f349e57ecc8965b09ce6 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sun, 5 Apr 2026 20:44:35 +0800
Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E5=BC=82=E5=B8=B8?=
 =?UTF-8?q?=E5=A4=84=E7=90=86=E9=97=AE=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Review notes (ignored by git am, not part of the commit message):
- gen(): also catch GeneratorExit, which is not an Exception subclass, so
  closing the stream early really removes the scheduler task.
- __exit__: log a failed state save instead of silently swallowing it.
- shutdown()/force_stop(): collect garbage before emptying the CUDA cache.
- force_stop(): drop the unused "import os".
- create_and_run(): reject stream=True, because the engine is already shut
  down by the time a returned generator could be consumed.
- scheduler.stop(): the discarded detach() calls are flagged, not changed.
- Context-line whitespace was reconstructed from the hunk line counts, and
  the post-image blob ids on the "index" lines predate these edits.

 astrai/inference/engine.py    | 121 ++++++++++++++++++++++++++++++----
 astrai/inference/scheduler.py |  18 +++++
 2 files changed, 126 insertions(+), 13 deletions(-)

diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py
index 14cc073..5f06ca4 100644
--- a/astrai/inference/engine.py
+++ b/astrai/inference/engine.py
@@ -1,13 +1,19 @@
 """Unified inference engine."""
 
+import gc
+import logging
+import signal
 import threading
+from typing import Any, Dict, Generator, List, Optional, Union
+
 import torch
 import torch.nn as nn
-from typing import Any, Dict, Generator, List, Optional, Union
 
 from astrai.tokenize.tokenizer import TextTokenizer
 from astrai.inference.scheduler import InferenceScheduler
 
+logger = logging.getLogger(__name__)
+
 
 class GenerationRequest:
     """Request parameters for text generation."""
@@ -145,6 +151,22 @@ class InferenceEngine:
 
         self.scheduler.start()
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Try to save state on error, then shut down; never suppresses."""
+        if exc_type is not None:
+            # Best effort; NOTE(review): save_state() is not added by this patch
+            logger.warning(f"Exception {exc_type.__name__}: {exc_val}, saving state...")
+            try:
+                self.save_state("./inference_state")
+            except Exception:
+                logger.exception("Failed to save inference state")
+
+        self.shutdown()
+        return False
+
     def generate(
         self,
         prompt: Union[str, List[str]],
@@ -153,14 +175,21 @@ class InferenceEngine:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = 50,
+        abort_on_exception: bool = True,
     ) -> Union[Generator[str, None, None], str, List[str]]:
-        """Unified generation interface."""
+        """Unified generation interface.
+
+        Args:
+            abort_on_exception: Streaming only. If True, remove the scheduler
+                task when the stream is closed early or fails. Default: True.
+        """
         is_batch = isinstance(prompt, list)
         prompts = prompt if is_batch else [prompt]
 
         if stream:
             return self._generate_streaming(
-                prompts, is_batch, max_tokens, temperature, top_p, top_k
+                prompts, is_batch, max_tokens, temperature, top_p, top_k,
+                abort_on_exception
             )
         else:
             return self._generate_non_streaming(
@@ -191,14 +220,20 @@ class InferenceEngine:
         temperature: float,
         top_p: float,
         top_k: int,
+        abort_on_exception: bool = True,
     ) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
-        """Generate with streaming output."""
+        """Generate with streaming output.
+
+        Args:
+            abort_on_exception: If True, remove the scheduler task when the
+                generator is closed early (GeneratorExit) or raises.
+        """
         if is_batch:
             raise NotImplementedError("Batch streaming is not implemented yet")
 
         result = _StreamingResult()
 
-        self.scheduler.add_task(
+        task_id = self.scheduler.add_task(
             prompt=prompts[0],
             max_tokens=max_tokens,
             temperature=temperature,
@@ -208,14 +243,21 @@ class InferenceEngine:
         )
 
         def gen():
-            while True:
-                tokens = result.pop_all()
-                for token in tokens:
-                    if token == "[DONE]":
-                        return
-                    yield token
-                result.wait(timeout=0.05)
+            try:
+                while True:
+                    tokens = result.pop_all()
+                    for token in tokens:
+                        if token == "[DONE]":
+                            return
+                        yield token
+                    result.wait(timeout=0.05)
+            except (Exception, GeneratorExit):
+                # GeneratorExit is a BaseException: name it or close() is missed
+                if abort_on_exception:
+                    self.scheduler.remove_task(task_id)
+                raise
 
+        gen.task_id = task_id  # NOTE(review): lives on the function, not on gen()
         return gen()
 
     def _generate_non_streaming(
@@ -249,5 +291,58 @@ class InferenceEngine:
         return self.scheduler.get_stats()
 
     def shutdown(self) -> None:
-        """Shutdown the engine."""
+        """Shutdown the engine and release all resources."""
+
+        # Stop scheduler first
         self.scheduler.stop()
+
+        # Collect garbage first so dropped tensors can be released below
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    def force_stop(self) -> None:
+        """
+        Force stop the engine immediately without saving state.
+
+        Use this for emergency shutdown when graceful shutdown is not possible.
+        """
+        # Stop watching threads if any
+        if hasattr(self, 'stop_watching'):
+            self.stop_watching()
+
+        # Unregister signal handlers
+        if hasattr(self, '_original_sigint'):
+            signal.signal(signal.SIGINT, self._original_sigint)
+        if hasattr(self, '_original_sigterm'):
+            signal.signal(signal.SIGTERM, self._original_sigterm)
+
+        # Force stop scheduler
+        self.scheduler._running = False
+
+        gc.collect()
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+            torch.cuda.empty_cache()
+
+    @classmethod
+    def create_and_run(cls, model, tokenizer, **kwargs):
+        """
+        Create engine, run generation, and shutdown automatically.
+
+        Convenience for simple scripts; raises ValueError if stream=True.
+
+        Args:
+            model: The model to use
+            tokenizer: The tokenizer to use
+            **kwargs: Arguments passed to generate()
+
+        Returns:
+            Generated text result
+        """
+        if kwargs.get("stream"):
+            raise ValueError("create_and_run() does not support stream=True")
+        with cls(model, tokenizer) as engine:
+            result = engine.generate(**kwargs)
+        return result
diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py
index 7e44b07..0e6b2dc 100644
--- a/astrai/inference/scheduler.py
+++ b/astrai/inference/scheduler.py
@@ -68,6 +68,9 @@ def apply_sampling_strategies(
     filter_value: float = -float("inf"),
 ) -> Tensor:
     """Apply sampling strategies to the logits tensor."""
+    # Clone logits to avoid inplace updates on inference tensor
+    logits = logits.clone()
+
     if temperature != 1.0:
         logits = logits / temperature
 
@@ -377,6 +380,21 @@ class InferenceScheduler:
         self._running = False
         if hasattr(self, "_loop_thread"):
             self._loop_thread.join(timeout=1.0)
+
+        # NOTE(review): detach() returns a new tensor; these calls free nothing
+        if self.kv_cache is not None:
+            k_cache, v_cache = self.kv_cache
+            if k_cache is not None:
+                k_cache.detach()
+            if v_cache is not None:
+                v_cache.detach()
+
+        # NOTE(review): result is discarded - seq_mask is left unchanged
+        self.seq_mask.detach()
+
+        # Clear task lists
+        self.waiting_queue.clear()
+        self.active_tasks.clear()
 
     def get_stats(self) -> Dict[str, Any]:
         """Get scheduler statistics."""