fix: 修复异常处理问题 (fix exception-handling issues)

This commit is contained in:
ViperEkura 2026-04-05 20:44:35 +08:00
parent d2b36cc85d
commit 23ce4bc3ae
2 changed files with 126 additions and 13 deletions

View File

@ -1,13 +1,19 @@
"""Unified inference engine.""" """Unified inference engine."""
import gc
import logging
import signal
import threading import threading
from typing import Any, Dict, Generator, List, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import Any, Dict, Generator, List, Optional, Union
from astrai.tokenize.tokenizer import TextTokenizer from astrai.tokenize.tokenizer import TextTokenizer
from astrai.inference.scheduler import InferenceScheduler from astrai.inference.scheduler import InferenceScheduler
logger = logging.getLogger(__name__)
class GenerationRequest: class GenerationRequest:
"""Request parameters for text generation.""" """Request parameters for text generation."""
@ -145,6 +151,22 @@ class InferenceEngine:
self.scheduler.start() self.scheduler.start()
def __enter__(self):
    """Enter the context manager; the engine itself is the context value."""
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    """Shut down the engine on context exit.

    If the block is exiting because of an exception, make a best-effort
    attempt to save the inference state before tearing down, so work can
    be recovered afterwards.

    Returns:
        False, always — the original exception (if any) is never
        suppressed and propagates to the caller.
    """
    if exc_type is not None:
        # Lazy %-style args: the message is only formatted if the
        # warning is actually emitted.
        logger.warning("Exception %s: %s, saving state...",
                       exc_type.__name__, exc_val)
        try:
            self.save_state("./inference_state")
        except Exception:
            # Best-effort only — but don't swallow the failure silently;
            # record it so operators know the state file may be missing.
            logger.exception("Failed to save inference state")
    self.shutdown()
    return False
def generate( def generate(
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]],
@ -153,14 +175,21 @@ class InferenceEngine:
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
abort_on_exception: bool = True,
) -> Union[Generator[str, None, None], str, List[str]]: ) -> Union[Generator[str, None, None], str, List[str]]:
"""Unified generation interface.""" """Unified generation interface.
Args:
abort_on_exception: If True, abort the generation when consumer
stops iterating (GeneratorExit/StopIteration). Default: True.
"""
is_batch = isinstance(prompt, list) is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt] prompts = prompt if is_batch else [prompt]
if stream: if stream:
return self._generate_streaming( return self._generate_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k prompts, is_batch, max_tokens, temperature, top_p, top_k,
abort_on_exception
) )
else: else:
return self._generate_non_streaming( return self._generate_non_streaming(
@ -191,14 +220,20 @@ class InferenceEngine:
temperature: float, temperature: float,
top_p: float, top_p: float,
top_k: int, top_k: int,
abort_on_exception: bool = True,
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]: ) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
"""Generate with streaming output.""" """Generate with streaming output.
Args:
abort_on_exception: If True, abort the task when generator is
stopped early by consumer (GeneratorExit/StopIteration).
"""
if is_batch: if is_batch:
raise NotImplementedError("Batch streaming is not implemented yet") raise NotImplementedError("Batch streaming is not implemented yet")
result = _StreamingResult() result = _StreamingResult()
self.scheduler.add_task( task_id = self.scheduler.add_task(
prompt=prompts[0], prompt=prompts[0],
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=temperature, temperature=temperature,
@ -208,14 +243,21 @@ class InferenceEngine:
) )
def gen(): def gen():
while True: try:
tokens = result.pop_all() while True:
for token in tokens: tokens = result.pop_all()
if token == "[DONE]": for token in tokens:
return if token == "[DONE]":
yield token return
result.wait(timeout=0.05) yield token
result.wait(timeout=0.05)
except Exception:
# Consumer stopped iterating - abort the task
if abort_on_exception:
self.scheduler.remove_task(task_id)
raise
gen.task_id = task_id
return gen() return gen()
def _generate_non_streaming( def _generate_non_streaming(
@ -249,5 +291,58 @@ class InferenceEngine:
return self.scheduler.get_stats() return self.scheduler.get_stats()
def shutdown(self) -> None:
    """Shutdown the engine and release all resources."""
    # Stop the scheduler first so no task touches GPU memory while we
    # free it below.
    self.scheduler.stop()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
def force_stop(self) -> None:
    """Force stop the engine immediately without saving state.

    Use this for emergency shutdown when graceful shutdown is not
    possible. Restores any signal handlers the engine installed,
    flips the scheduler's run flag directly (no thread join), and
    frees GPU memory.
    """
    # Stop watching threads if any were started.
    if hasattr(self, 'stop_watching'):
        self.stop_watching()
    # Restore the original signal handlers if we installed our own.
    if hasattr(self, '_original_sigint'):
        signal.signal(signal.SIGINT, self._original_sigint)
    if hasattr(self, '_original_sigterm'):
        signal.signal(signal.SIGTERM, self._original_sigterm)
    # Force-stop the scheduler loop. NOTE(review): reaches into the
    # scheduler's private flag; a public stop-now API would be cleaner.
    self.scheduler._running = False
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()
@classmethod
def create_and_run(cls, model, tokenizer, **kwargs):
    """Create an engine, run a single generation, and tear it down.

    Convenience wrapper for simple scripts: the engine is constructed,
    used as a context manager, and shut down automatically once the
    generation call returns.

    Args:
        model: The model to use.
        tokenizer: The tokenizer to use.
        **kwargs: Forwarded verbatim to generate().

    Returns:
        Whatever generate() returns for the given arguments.
        NOTE(review): with stream=True the result is presumably a
        generator consumed after shutdown — confirm before relying on it.
    """
    engine = cls(model, tokenizer)
    with engine as active:
        return active.generate(**kwargs)

View File

@ -68,6 +68,9 @@ def apply_sampling_strategies(
filter_value: float = -float("inf"), filter_value: float = -float("inf"),
) -> Tensor: ) -> Tensor:
"""Apply sampling strategies to the logits tensor.""" """Apply sampling strategies to the logits tensor."""
# Clone logits to avoid inplace updates on inference tensor
logits = logits.clone()
if temperature != 1.0: if temperature != 1.0:
logits = logits / temperature logits = logits / temperature
@ -378,6 +381,21 @@ class InferenceScheduler:
if hasattr(self, "_loop_thread"): if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=1.0) self._loop_thread.join(timeout=1.0)
# Clear KV cache to free GPU memory
if self.kv_cache is not None:
k_cache, v_cache = self.kv_cache
if k_cache is not None:
k_cache.detach()
if v_cache is not None:
v_cache.detach()
# Clear seq mask
self.seq_mask.detach()
# Clear task lists
self.waiting_queue.clear()
self.active_tasks.clear()
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
"""Get scheduler statistics.""" """Get scheduler statistics."""
return { return {