chore: 精简实现代码部分
This commit is contained in:
parent
d2fe8afbd1
commit
e58dbd7c57
|
|
@ -2,7 +2,6 @@
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import signal
|
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Dict, Generator, List, Optional, Union
|
from typing import Any, Dict, Generator, List, Optional, Union
|
||||||
|
|
||||||
|
|
@ -38,12 +37,12 @@ class GenerationRequest:
|
||||||
|
|
||||||
def _validate(self):
|
def _validate(self):
|
||||||
"""Validate request parameters."""
|
"""Validate request parameters."""
|
||||||
if not isinstance(self.top_k, int) or self.top_k < 0:
|
if not (isinstance(self.top_k, int) and self.top_k >= 0):
|
||||||
raise ValueError("top_k must be a non-negative integer")
|
raise ValueError("top_k must be a non-negative integer")
|
||||||
if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0:
|
if not (0.0 <= self.top_p <= 1.0):
|
||||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||||
if not isinstance(self.temperature, float) or self.temperature < 0.0:
|
if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
|
||||||
raise ValueError("temperature must be a non-negative float")
|
raise ValueError("temperature must be a non-negative number")
|
||||||
|
|
||||||
|
|
||||||
class _StreamingResult:
|
class _StreamingResult:
|
||||||
|
|
@ -75,7 +74,7 @@ class _NonStreamingResult:
|
||||||
"""Non-streaming result holder with event-based completion notification."""
|
"""Non-streaming result holder with event-based completion notification."""
|
||||||
|
|
||||||
def __init__(self, count: int):
|
def __init__(self, count: int):
|
||||||
self.results: List[str] = ["" for _ in range(count)]
|
self.results: List[str] = [""] * count
|
||||||
self.done_flags: List[bool] = [False] * count
|
self.done_flags: List[bool] = [False] * count
|
||||||
self._completed_count = 0
|
self._completed_count = 0
|
||||||
self._event = threading.Event()
|
self._event = threading.Event()
|
||||||
|
|
@ -156,14 +155,6 @@ class InferenceEngine:
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
"""Handle exceptions on exit."""
|
"""Handle exceptions on exit."""
|
||||||
if exc_type is not None:
|
|
||||||
# An exception occurred - try to save state
|
|
||||||
logger.warning(f"Exception {exc_type.__name__}: {exc_val}, saving state...")
|
|
||||||
try:
|
|
||||||
self.save_state("./inference_state")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
self.shutdown()
|
self.shutdown()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -297,56 +288,7 @@ class InferenceEngine:
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
def shutdown(self) -> None:
|
||||||
"""Shutdown the engine and release all resources."""
|
"""Shutdown the engine and release all resources."""
|
||||||
|
|
||||||
# Stop scheduler first
|
|
||||||
self.scheduler.stop()
|
self.scheduler.stop()
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
def force_stop(self) -> None:
|
|
||||||
"""
|
|
||||||
Force stop the engine immediately without saving state.
|
|
||||||
|
|
||||||
Use this for emergency shutdown when graceful shutdown is not possible.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Stop watching threads if any
|
|
||||||
if hasattr(self, "stop_watching"):
|
|
||||||
self.stop_watching()
|
|
||||||
|
|
||||||
# Unregister signal handlers
|
|
||||||
if hasattr(self, "_original_sigint"):
|
|
||||||
signal.signal(signal.SIGINT, self._original_sigint)
|
|
||||||
if hasattr(self, "_original_sigterm"):
|
|
||||||
signal.signal(signal.SIGTERM, self._original_sigterm)
|
|
||||||
|
|
||||||
# Force stop scheduler
|
|
||||||
self.scheduler._running = False
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create_and_run(cls, model, tokenizer, **kwargs):
|
|
||||||
"""
|
|
||||||
Create engine, run generation, and shutdown automatically.
|
|
||||||
|
|
||||||
This is a convenience method for simple scripts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The model to use
|
|
||||||
tokenizer: The tokenizer to use
|
|
||||||
**kwargs: Arguments passed to generate()
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Generated text result
|
|
||||||
"""
|
|
||||||
with cls(model, tokenizer) as engine:
|
|
||||||
result = engine.generate(**kwargs)
|
|
||||||
return result
|
|
||||||
|
|
|
||||||
|
|
@ -55,9 +55,7 @@ class Task:
|
||||||
"""Check if task is finished."""
|
"""Check if task is finished."""
|
||||||
if self.output_ids and self.output_ids[-1] in stop_ids:
|
if self.output_ids and self.output_ids[-1] in stop_ids:
|
||||||
return True
|
return True
|
||||||
if self.output_tokens >= self.max_tokens:
|
return self.output_tokens >= self.max_tokens
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def apply_sampling_strategies(
|
def apply_sampling_strategies(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue