chore: 精简实现代码部分

This commit is contained in:
ViperEkura 2026-04-05 21:16:38 +08:00
parent d2fe8afbd1
commit e58dbd7c57
2 changed files with 6 additions and 66 deletions

View File

@ -2,7 +2,6 @@
import gc
import logging
import signal
import threading
from typing import Any, Dict, Generator, List, Optional, Union
@ -38,12 +37,12 @@ class GenerationRequest:
def _validate(self):
"""Validate request parameters."""
if not isinstance(self.top_k, int) or self.top_k < 0:
if not (isinstance(self.top_k, int) and self.top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0:
if not (0.0 <= self.top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not isinstance(self.temperature, float) or self.temperature < 0.0:
raise ValueError("temperature must be a non-negative float")
if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
raise ValueError("temperature must be a non-negative number")
class _StreamingResult:
@ -75,7 +74,7 @@ class _NonStreamingResult:
"""Non-streaming result holder with event-based completion notification."""
def __init__(self, count: int):
self.results: List[str] = ["" for _ in range(count)]
self.results: List[str] = [""] * count
self.done_flags: List[bool] = [False] * count
self._completed_count = 0
self._event = threading.Event()
@ -156,14 +155,6 @@ class InferenceEngine:
def __exit__(self, exc_type, exc_val, exc_tb):
"""Handle exceptions on exit."""
if exc_type is not None:
# An exception occurred - try to save state
logger.warning(f"Exception {exc_type.__name__}: {exc_val}, saving state...")
try:
self.save_state("./inference_state")
except Exception:
pass
self.shutdown()
return False
@ -297,56 +288,7 @@ class InferenceEngine:
def shutdown(self) -> None:
"""Shutdown the engine and release all resources."""
# Stop scheduler first
self.scheduler.stop()
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
def force_stop(self) -> None:
"""
Force stop the engine immediately without saving state.
Use this for emergency shutdown when graceful shutdown is not possible.
"""
# Stop watching threads if any
if hasattr(self, "stop_watching"):
self.stop_watching()
# Unregister signal handlers
if hasattr(self, "_original_sigint"):
signal.signal(signal.SIGINT, self._original_sigint)
if hasattr(self, "_original_sigterm"):
signal.signal(signal.SIGTERM, self._original_sigterm)
# Force stop scheduler
self.scheduler._running = False
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
gc.collect()
@classmethod
def create_and_run(cls, model, tokenizer, **kwargs):
"""
Create engine, run generation, and shutdown automatically.
This is a convenience method for simple scripts.
Args:
model: The model to use
tokenizer: The tokenizer to use
**kwargs: Arguments passed to generate()
Returns:
Generated text result
"""
with cls(model, tokenizer) as engine:
result = engine.generate(**kwargs)
return result

View File

@ -55,9 +55,7 @@ class Task:
"""Check if task is finished."""
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
if self.output_tokens >= self.max_tokens:
return True
return False
return self.output_tokens >= self.max_tokens
def apply_sampling_strategies(