chore: 精简实现代码部分 (simplify/trim implementation code)

This commit is contained in:
ViperEkura 2026-04-05 21:16:38 +08:00
parent d2fe8afbd1
commit e58dbd7c57
2 changed files with 6 additions and 66 deletions

View File

@ -2,7 +2,6 @@
import gc import gc
import logging import logging
import signal
import threading import threading
from typing import Any, Dict, Generator, List, Optional, Union from typing import Any, Dict, Generator, List, Optional, Union
@ -38,12 +37,12 @@ class GenerationRequest:
def _validate(self): def _validate(self):
"""Validate request parameters.""" """Validate request parameters."""
if not isinstance(self.top_k, int) or self.top_k < 0: if not (isinstance(self.top_k, int) and self.top_k >= 0):
raise ValueError("top_k must be a non-negative integer") raise ValueError("top_k must be a non-negative integer")
if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0: if not (0.0 <= self.top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0") raise ValueError("top_p must be a float between 0.0 and 1.0")
if not isinstance(self.temperature, float) or self.temperature < 0.0: if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
raise ValueError("temperature must be a non-negative float") raise ValueError("temperature must be a non-negative number")
class _StreamingResult: class _StreamingResult:
@ -75,7 +74,7 @@ class _NonStreamingResult:
"""Non-streaming result holder with event-based completion notification.""" """Non-streaming result holder with event-based completion notification."""
def __init__(self, count: int): def __init__(self, count: int):
self.results: List[str] = ["" for _ in range(count)] self.results: List[str] = [""] * count
self.done_flags: List[bool] = [False] * count self.done_flags: List[bool] = [False] * count
self._completed_count = 0 self._completed_count = 0
self._event = threading.Event() self._event = threading.Event()
@ -156,14 +155,6 @@ class InferenceEngine:
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
"""Handle exceptions on exit.""" """Handle exceptions on exit."""
if exc_type is not None:
# An exception occurred - try to save state
logger.warning(f"Exception {exc_type.__name__}: {exc_val}, saving state...")
try:
self.save_state("./inference_state")
except Exception:
pass
self.shutdown() self.shutdown()
return False return False
@ -297,56 +288,7 @@ class InferenceEngine:
def shutdown(self) -> None: def shutdown(self) -> None:
"""Shutdown the engine and release all resources.""" """Shutdown the engine and release all resources."""
# Stop scheduler first
self.scheduler.stop() self.scheduler.stop()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
def force_stop(self) -> None:
"""
Force stop the engine immediately without saving state.
Use this for emergency shutdown when graceful shutdown is not possible.
"""
# Stop watching threads if any
if hasattr(self, "stop_watching"):
self.stop_watching()
# Unregister signal handlers
if hasattr(self, "_original_sigint"):
signal.signal(signal.SIGINT, self._original_sigint)
if hasattr(self, "_original_sigterm"):
signal.signal(signal.SIGTERM, self._original_sigterm)
# Force stop scheduler
self.scheduler._running = False
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
gc.collect()
@classmethod
def create_and_run(cls, model, tokenizer, **kwargs):
"""
Create engine, run generation, and shutdown automatically.
This is a convenience method for simple scripts.
Args:
model: The model to use
tokenizer: The tokenizer to use
**kwargs: Arguments passed to generate()
Returns:
Generated text result
"""
with cls(model, tokenizer) as engine:
result = engine.generate(**kwargs)
return result

View File

@ -55,9 +55,7 @@ class Task:
"""Check if task is finished.""" """Check if task is finished."""
if self.output_ids and self.output_ids[-1] in stop_ids: if self.output_ids and self.output_ids[-1] in stop_ids:
return True return True
if self.output_tokens >= self.max_tokens: return self.output_tokens >= self.max_tokens
return True
return False
def apply_sampling_strategies( def apply_sampling_strategies(