fix: 修复异常处理问题
This commit is contained in:
parent
d2b36cc85d
commit
23ce4bc3ae
|
|
@ -1,13 +1,19 @@
|
||||||
"""Unified inference engine."""
|
"""Unified inference engine."""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
import signal
|
||||||
import threading
|
import threading
|
||||||
|
from typing import Any, Dict, Generator, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import Any, Dict, Generator, List, Optional, Union
|
|
||||||
|
|
||||||
from astrai.tokenize.tokenizer import TextTokenizer
|
from astrai.tokenize.tokenizer import TextTokenizer
|
||||||
from astrai.inference.scheduler import InferenceScheduler
|
from astrai.inference.scheduler import InferenceScheduler
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class GenerationRequest:
|
class GenerationRequest:
|
||||||
"""Request parameters for text generation."""
|
"""Request parameters for text generation."""
|
||||||
|
|
@ -145,6 +151,22 @@ class InferenceEngine:
|
||||||
|
|
||||||
self.scheduler.start()
|
self.scheduler.start()
|
||||||
|
|
||||||
|
def __enter__(self):
    """Enter the context manager and return the engine itself."""
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    """Save state on error (best effort), then always shut the engine down.

    Args:
        exc_type: Exception class raised inside the ``with`` body, or None.
        exc_val: Exception instance raised inside the ``with`` body, or None.
        exc_tb: Traceback of that exception, or None.

    Returns:
        False, so an exception raised inside the ``with`` body is never
        suppressed and keeps propagating to the caller.
    """
    try:
        if exc_type is not None:
            # An exception occurred - try to save state
            logger.warning(
                "Exception %s: %s, saving state...", exc_type.__name__, exc_val
            )
            try:
                # NOTE(review): save_state is not defined in the visible part
                # of this file - confirm it exists and accepts a path.
                self.save_state("./inference_state")
            except Exception:
                # Saving is best effort and must not mask the original error,
                # but a failed save has to be visible in the logs.
                logger.exception("Failed to save inference state")
    finally:
        # Release resources no matter what happened while saving state.
        self.shutdown()
    return False
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompt: Union[str, List[str]],
|
prompt: Union[str, List[str]],
|
||||||
|
|
@ -153,14 +175,21 @@ class InferenceEngine:
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
|
abort_on_exception: bool = True,
|
||||||
) -> Union[Generator[str, None, None], str, List[str]]:
|
) -> Union[Generator[str, None, None], str, List[str]]:
|
||||||
"""Unified generation interface."""
|
"""Unified generation interface.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
abort_on_exception: If True, abort the generation when consumer
|
||||||
|
stops iterating (GeneratorExit/StopIteration). Default: True.
|
||||||
|
"""
|
||||||
is_batch = isinstance(prompt, list)
|
is_batch = isinstance(prompt, list)
|
||||||
prompts = prompt if is_batch else [prompt]
|
prompts = prompt if is_batch else [prompt]
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._generate_streaming(
|
return self._generate_streaming(
|
||||||
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
prompts, is_batch, max_tokens, temperature, top_p, top_k,
|
||||||
|
abort_on_exception
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self._generate_non_streaming(
|
return self._generate_non_streaming(
|
||||||
|
|
@ -191,14 +220,20 @@ class InferenceEngine:
|
||||||
temperature: float,
|
temperature: float,
|
||||||
top_p: float,
|
top_p: float,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
|
abort_on_exception: bool = True,
|
||||||
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
|
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
|
||||||
"""Generate with streaming output."""
|
"""Generate with streaming output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
abort_on_exception: If True, abort the task when generator is
|
||||||
|
stopped early by consumer (GeneratorExit/StopIteration).
|
||||||
|
"""
|
||||||
if is_batch:
|
if is_batch:
|
||||||
raise NotImplementedError("Batch streaming is not implemented yet")
|
raise NotImplementedError("Batch streaming is not implemented yet")
|
||||||
|
|
||||||
result = _StreamingResult()
|
result = _StreamingResult()
|
||||||
|
|
||||||
self.scheduler.add_task(
|
task_id = self.scheduler.add_task(
|
||||||
prompt=prompts[0],
|
prompt=prompts[0],
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
|
|
@ -208,14 +243,21 @@ class InferenceEngine:
|
||||||
)
|
)
|
||||||
|
|
||||||
def gen():
    """Yield streamed tokens until the "[DONE]" sentinel is received.

    Tokens are drained from ``result`` with ``pop_all()``; when nothing
    more is pending the generator waits up to 50 ms before polling again.

    If the generator ends any other way than by seeing "[DONE]" - the
    consumer calls ``close()``, the generator is garbage collected while
    suspended, or an error is raised while polling - the scheduler task
    is removed when ``abort_on_exception`` is true.
    """
    completed = False
    try:
        while True:
            for token in result.pop_all():
                if token == "[DONE]":
                    completed = True
                    return
                yield token
            result.wait(timeout=0.05)
    finally:
        # GeneratorExit derives from BaseException, so ``except Exception``
        # never sees a consumer that stops early; ``finally`` covers that
        # case as well as ordinary errors, which still propagate unchanged.
        if not completed and abort_on_exception:
            self.scheduler.remove_task(task_id)
|
||||||
|
|
||||||
|
gen.task_id = task_id
|
||||||
return gen()
|
return gen()
|
||||||
|
|
||||||
def _generate_non_streaming(
|
def _generate_non_streaming(
|
||||||
|
|
@ -249,5 +291,58 @@ class InferenceEngine:
|
||||||
return self.scheduler.get_stats()
|
return self.scheduler.get_stats()
|
||||||
|
|
||||||
def shutdown(self) -> None:
    """Shutdown the engine and release all resources.

    Stops the scheduler, then runs a garbage collection pass and, when
    CUDA is available, releases the unoccupied cached GPU memory. The
    cleanup runs even if stopping the scheduler raises; that exception
    still propagates to the caller afterwards.
    """
    try:
        # Stop scheduler first
        self.scheduler.stop()
    finally:
        # Collect first so memory owned by unreachable objects is already
        # free when the CUDA caching allocator is asked to give it back.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def force_stop(self) -> None:
    """
    Force stop the engine immediately without saving state.

    Use this for emergency shutdown when graceful shutdown is not possible.

    Steps, in order: call ``stop_watching()`` if the engine has one,
    restore the SIGINT/SIGTERM handlers remembered in ``_original_sigint``
    and ``_original_sigterm`` (when present), clear the scheduler's
    ``_running`` flag, then collect garbage and release cached CUDA memory.
    Failure to restore a signal handler is logged and does not interrupt
    the rest of the stop sequence.
    """
    # Stop watching threads if any
    if hasattr(self, "stop_watching"):
        self.stop_watching()

    # Unregister signal handlers
    saved_handlers = (
        (signal.SIGINT, "_original_sigint"),
        (signal.SIGTERM, "_original_sigterm"),
    )
    for signum, attr in saved_handlers:
        previous = getattr(self, attr, None)
        if previous is None:
            # Attribute missing, or the saved handler was not installed
            # from Python (signal.getsignal() returned None); passing None
            # to signal.signal() would raise TypeError.
            continue
        try:
            signal.signal(signum, previous)
        except ValueError:
            # signal.signal() may only be called from the main thread; an
            # emergency stop from a worker thread must still go through.
            logging.getLogger(__name__).warning(
                "Could not restore handler for signal %s outside the main thread",
                signum,
            )

    # Force stop scheduler
    # NOTE(review): this only clears the flag; presumably the scheduler
    # loop polls it - confirm against the scheduler implementation.
    self.scheduler._running = False

    # Collect first so memory owned by unreachable objects is already free
    # when the CUDA caching allocator is asked to give it back.
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
|
||||||
|
|
||||||
|
@classmethod
def create_and_run(cls, model, tokenizer, **kwargs):
    """
    Create engine, run generation, and shutdown automatically.

    This is a convenience method for simple scripts.

    Args:
        model: The model to use
        tokenizer: The tokenizer to use
        **kwargs: Arguments passed to generate()

    Returns:
        Generated text result. When ``stream=True`` is passed, a generator
        is returned instead; the engine is created on the first iteration
        and shut down once the generator is exhausted or closed, so any
        error from generate() surfaces on first iteration in that mode.
    """
    if kwargs.get("stream"):
        # A streaming result is lazy: returning it from inside the ``with``
        # block would shut the engine down before any token is consumed.
        # Keep the engine open for as long as the consumer is iterating.
        def _stream():
            with cls(model, tokenizer) as engine:
                yield from engine.generate(**kwargs)

        return _stream()

    with cls(model, tokenizer) as engine:
        return engine.generate(**kwargs)
|
||||||
|
|
|
||||||
|
|
@ -68,6 +68,9 @@ def apply_sampling_strategies(
|
||||||
filter_value: float = -float("inf"),
|
filter_value: float = -float("inf"),
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""Apply sampling strategies to the logits tensor."""
|
"""Apply sampling strategies to the logits tensor."""
|
||||||
|
# Clone logits to avoid inplace updates on inference tensor
|
||||||
|
logits = logits.clone()
|
||||||
|
|
||||||
if temperature != 1.0:
|
if temperature != 1.0:
|
||||||
logits = logits / temperature
|
logits = logits / temperature
|
||||||
|
|
||||||
|
|
@ -378,6 +381,21 @@ class InferenceScheduler:
|
||||||
if hasattr(self, "_loop_thread"):
|
if hasattr(self, "_loop_thread"):
|
||||||
self._loop_thread.join(timeout=1.0)
|
self._loop_thread.join(timeout=1.0)
|
||||||
|
|
||||||
|
# Clear KV cache to free GPU memory
|
||||||
|
if self.kv_cache is not None:
|
||||||
|
k_cache, v_cache = self.kv_cache
|
||||||
|
if k_cache is not None:
|
||||||
|
k_cache.detach()
|
||||||
|
if v_cache is not None:
|
||||||
|
v_cache.detach()
|
||||||
|
|
||||||
|
# Clear seq mask
|
||||||
|
self.seq_mask.detach()
|
||||||
|
|
||||||
|
# Clear task lists
|
||||||
|
self.waiting_queue.clear()
|
||||||
|
self.active_tasks.clear()
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
"""Get scheduler statistics."""
|
"""Get scheduler statistics."""
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue