chore: 精简实现代码部分
This commit is contained in:
parent
d2fe8afbd1
commit
e58dbd7c57
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import gc
|
||||
import logging
|
||||
import signal
|
||||
import threading
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
|
||||
|
|
@ -38,12 +37,12 @@ class GenerationRequest:
|
|||
|
||||
def _validate(self):
|
||||
"""Validate request parameters."""
|
||||
if not isinstance(self.top_k, int) or self.top_k < 0:
|
||||
if not (isinstance(self.top_k, int) and self.top_k >= 0):
|
||||
raise ValueError("top_k must be a non-negative integer")
|
||||
if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0:
|
||||
if not (0.0 <= self.top_p <= 1.0):
|
||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||
if not isinstance(self.temperature, float) or self.temperature < 0.0:
|
||||
raise ValueError("temperature must be a non-negative float")
|
||||
if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
|
||||
raise ValueError("temperature must be a non-negative number")
|
||||
|
||||
|
||||
class _StreamingResult:
|
||||
|
|
@ -75,7 +74,7 @@ class _NonStreamingResult:
|
|||
"""Non-streaming result holder with event-based completion notification."""
|
||||
|
||||
def __init__(self, count: int):
|
||||
self.results: List[str] = ["" for _ in range(count)]
|
||||
self.results: List[str] = [""] * count
|
||||
self.done_flags: List[bool] = [False] * count
|
||||
self._completed_count = 0
|
||||
self._event = threading.Event()
|
||||
|
|
@ -156,14 +155,6 @@ class InferenceEngine:
|
|||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Handle exceptions on exit."""
|
||||
if exc_type is not None:
|
||||
# An exception occurred - try to save state
|
||||
logger.warning(f"Exception {exc_type.__name__}: {exc_val}, saving state...")
|
||||
try:
|
||||
self.save_state("./inference_state")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.shutdown()
|
||||
return False
|
||||
|
||||
|
|
@ -297,56 +288,7 @@ class InferenceEngine:
|
|||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the engine and release all resources."""
|
||||
|
||||
# Stop scheduler first
|
||||
self.scheduler.stop()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
gc.collect()
|
||||
|
||||
def force_stop(self) -> None:
|
||||
"""
|
||||
Force stop the engine immediately without saving state.
|
||||
|
||||
Use this for emergency shutdown when graceful shutdown is not possible.
|
||||
"""
|
||||
|
||||
# Stop watching threads if any
|
||||
if hasattr(self, "stop_watching"):
|
||||
self.stop_watching()
|
||||
|
||||
# Unregister signal handlers
|
||||
if hasattr(self, "_original_sigint"):
|
||||
signal.signal(signal.SIGINT, self._original_sigint)
|
||||
if hasattr(self, "_original_sigterm"):
|
||||
signal.signal(signal.SIGTERM, self._original_sigterm)
|
||||
|
||||
# Force stop scheduler
|
||||
self.scheduler._running = False
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
gc.collect()
|
||||
|
||||
@classmethod
|
||||
def create_and_run(cls, model, tokenizer, **kwargs):
|
||||
"""
|
||||
Create engine, run generation, and shutdown automatically.
|
||||
|
||||
This is a convenience method for simple scripts.
|
||||
|
||||
Args:
|
||||
model: The model to use
|
||||
tokenizer: The tokenizer to use
|
||||
**kwargs: Arguments passed to generate()
|
||||
|
||||
Returns:
|
||||
Generated text result
|
||||
"""
|
||||
with cls(model, tokenizer) as engine:
|
||||
result = engine.generate(**kwargs)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -55,9 +55,7 @@ class Task:
|
|||
"""Check if task is finished."""
|
||||
if self.output_ids and self.output_ids[-1] in stop_ids:
|
||||
return True
|
||||
if self.output_tokens >= self.max_tokens:
|
||||
return True
|
||||
return False
|
||||
return self.output_tokens >= self.max_tokens
|
||||
|
||||
|
||||
def apply_sampling_strategies(
|
||||
|
|
|
|||
Loading…
Reference in New Issue