From e58dbd7c57045c11e93df7e306be3601eff9be89 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 5 Apr 2026 21:16:38 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E7=B2=BE=E7=AE=80=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/inference/engine.py | 68 +++-------------------------------- astrai/inference/scheduler.py | 4 +-- 2 files changed, 6 insertions(+), 66 deletions(-) diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 996a16c..7f83047 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -2,7 +2,6 @@ import gc import logging -import signal import threading from typing import Any, Dict, Generator, List, Optional, Union @@ -38,12 +37,12 @@ class GenerationRequest: def _validate(self): """Validate request parameters.""" - if not isinstance(self.top_k, int) or self.top_k < 0: + if not (isinstance(self.top_k, int) and self.top_k >= 0): raise ValueError("top_k must be a non-negative integer") - if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0: + if not (0.0 <= self.top_p <= 1.0): raise ValueError("top_p must be a float between 0.0 and 1.0") - if not isinstance(self.temperature, float) or self.temperature < 0.0: - raise ValueError("temperature must be a non-negative float") + if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0): + raise ValueError("temperature must be a non-negative number") class _StreamingResult: @@ -75,7 +74,7 @@ class _NonStreamingResult: """Non-streaming result holder with event-based completion notification.""" def __init__(self, count: int): - self.results: List[str] = ["" for _ in range(count)] + self.results: List[str] = [""] * count self.done_flags: List[bool] = [False] * count self._completed_count = 0 self._event = threading.Event() @@ -156,14 +155,6 @@ class InferenceEngine: def __exit__(self, exc_type, exc_val, exc_tb): """Handle exceptions on exit.""" - if exc_type is not None: - # An exception occurred - try to save state - logger.warning(f"Exception {exc_type.__name__}: {exc_val}, saving state...") - try: - self.save_state("./inference_state") - except Exception: - pass - self.shutdown() return False @@ -297,56 +288,7 @@ class InferenceEngine: def shutdown(self) -> None: """Shutdown the engine and release all resources.""" - - # Stop scheduler first self.scheduler.stop() - if torch.cuda.is_available(): torch.cuda.empty_cache() - gc.collect() - - def force_stop(self) -> None: - """ - Force stop the engine immediately without saving state. - - Use this for emergency shutdown when graceful shutdown is not possible. - """ - - # Stop watching threads if any - if hasattr(self, "stop_watching"): - self.stop_watching() - - # Unregister signal handlers - if hasattr(self, "_original_sigint"): - signal.signal(signal.SIGINT, self._original_sigint) - if hasattr(self, "_original_sigterm"): - signal.signal(signal.SIGTERM, self._original_sigterm) - - # Force stop scheduler - self.scheduler._running = False - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - - gc.collect() - - @classmethod - def create_and_run(cls, model, tokenizer, **kwargs): - """ - Create engine, run generation, and shutdown automatically. - - This is a convenience method for simple scripts. - - Args: - model: The model to use - tokenizer: The tokenizer to use - **kwargs: Arguments passed to generate() - - Returns: - Generated text result - """ - with cls(model, tokenizer) as engine: - result = engine.generate(**kwargs) - return result diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index b8ce5e9..d972b20 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -55,9 +55,7 @@ class Task: """Check if task is finished.""" if self.output_ids and self.output_ids[-1] in stop_ids: return True - if self.output_tokens >= self.max_tokens: - return True - return False + return self.output_tokens >= self.max_tokens def apply_sampling_strategies(