fix: 修复异常处理问题
This commit is contained in:
parent
d2b36cc85d
commit
23ce4bc3ae
|
|
@ -1,13 +1,19 @@
|
|||
"""Unified inference engine."""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import signal
|
||||
import threading
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
|
||||
from astrai.tokenize.tokenizer import TextTokenizer
|
||||
from astrai.inference.scheduler import InferenceScheduler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GenerationRequest:
|
||||
"""Request parameters for text generation."""
|
||||
|
|
@ -145,6 +151,22 @@ class InferenceEngine:
|
|||
|
||||
self.scheduler.start()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Handle exceptions on exit."""
|
||||
if exc_type is not None:
|
||||
# An exception occurred - try to save state
|
||||
logger.warning(f"Exception {exc_type.__name__}: {exc_val}, saving state...")
|
||||
try:
|
||||
self.save_state("./inference_state")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.shutdown()
|
||||
return False
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
|
|
@ -153,14 +175,21 @@ class InferenceEngine:
|
|||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
abort_on_exception: bool = True,
|
||||
) -> Union[Generator[str, None, None], str, List[str]]:
|
||||
"""Unified generation interface."""
|
||||
"""Unified generation interface.
|
||||
|
||||
Args:
|
||||
abort_on_exception: If True, abort the generation when consumer
|
||||
stops iterating (GeneratorExit/StopIteration). Default: True.
|
||||
"""
|
||||
is_batch = isinstance(prompt, list)
|
||||
prompts = prompt if is_batch else [prompt]
|
||||
|
||||
if stream:
|
||||
return self._generate_streaming(
|
||||
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
||||
prompts, is_batch, max_tokens, temperature, top_p, top_k,
|
||||
abort_on_exception
|
||||
)
|
||||
else:
|
||||
return self._generate_non_streaming(
|
||||
|
|
@ -191,14 +220,20 @@ class InferenceEngine:
|
|||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
abort_on_exception: bool = True,
|
||||
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
|
||||
"""Generate with streaming output."""
|
||||
"""Generate with streaming output.
|
||||
|
||||
Args:
|
||||
abort_on_exception: If True, abort the task when generator is
|
||||
stopped early by consumer (GeneratorExit/StopIteration).
|
||||
"""
|
||||
if is_batch:
|
||||
raise NotImplementedError("Batch streaming is not implemented yet")
|
||||
|
||||
result = _StreamingResult()
|
||||
|
||||
self.scheduler.add_task(
|
||||
task_id = self.scheduler.add_task(
|
||||
prompt=prompts[0],
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
|
|
@ -208,14 +243,21 @@ class InferenceEngine:
|
|||
)
|
||||
|
||||
def gen():
|
||||
while True:
|
||||
tokens = result.pop_all()
|
||||
for token in tokens:
|
||||
if token == "[DONE]":
|
||||
return
|
||||
yield token
|
||||
result.wait(timeout=0.05)
|
||||
try:
|
||||
while True:
|
||||
tokens = result.pop_all()
|
||||
for token in tokens:
|
||||
if token == "[DONE]":
|
||||
return
|
||||
yield token
|
||||
result.wait(timeout=0.05)
|
||||
except Exception:
|
||||
# Consumer stopped iterating - abort the task
|
||||
if abort_on_exception:
|
||||
self.scheduler.remove_task(task_id)
|
||||
raise
|
||||
|
||||
gen.task_id = task_id
|
||||
return gen()
|
||||
|
||||
def _generate_non_streaming(
|
||||
|
|
@ -249,5 +291,58 @@ class InferenceEngine:
|
|||
return self.scheduler.get_stats()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the engine."""
|
||||
"""Shutdown the engine and release all resources."""
|
||||
|
||||
# Stop scheduler first
|
||||
self.scheduler.stop()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
gc.collect()
|
||||
|
||||
def force_stop(self) -> None:
|
||||
"""
|
||||
Force stop the engine immediately without saving state.
|
||||
|
||||
Use this for emergency shutdown when graceful shutdown is not possible.
|
||||
"""
|
||||
import os
|
||||
|
||||
# Stop watching threads if any
|
||||
if hasattr(self, 'stop_watching'):
|
||||
self.stop_watching()
|
||||
|
||||
# Unregister signal handlers
|
||||
if hasattr(self, '_original_sigint'):
|
||||
signal.signal(signal.SIGINT, self._original_sigint)
|
||||
if hasattr(self, '_original_sigterm'):
|
||||
signal.signal(signal.SIGTERM, self._original_sigterm)
|
||||
|
||||
# Force stop scheduler
|
||||
self.scheduler._running = False
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
gc.collect()
|
||||
|
||||
@classmethod
|
||||
def create_and_run(cls, model, tokenizer, **kwargs):
|
||||
"""
|
||||
Create engine, run generation, and shutdown automatically.
|
||||
|
||||
This is a convenience method for simple scripts.
|
||||
|
||||
Args:
|
||||
model: The model to use
|
||||
tokenizer: The tokenizer to use
|
||||
**kwargs: Arguments passed to generate()
|
||||
|
||||
Returns:
|
||||
Generated text result
|
||||
"""
|
||||
with cls(model, tokenizer) as engine:
|
||||
result = engine.generate(**kwargs)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -68,6 +68,9 @@ def apply_sampling_strategies(
|
|||
filter_value: float = -float("inf"),
|
||||
) -> Tensor:
|
||||
"""Apply sampling strategies to the logits tensor."""
|
||||
# Clone logits to avoid inplace updates on inference tensor
|
||||
logits = logits.clone()
|
||||
|
||||
if temperature != 1.0:
|
||||
logits = logits / temperature
|
||||
|
||||
|
|
@ -377,6 +380,21 @@ class InferenceScheduler:
|
|||
self._running = False
|
||||
if hasattr(self, "_loop_thread"):
|
||||
self._loop_thread.join(timeout=1.0)
|
||||
|
||||
# Clear KV cache to free GPU memory
|
||||
if self.kv_cache is not None:
|
||||
k_cache, v_cache = self.kv_cache
|
||||
if k_cache is not None:
|
||||
k_cache.detach()
|
||||
if v_cache is not None:
|
||||
v_cache.detach()
|
||||
|
||||
# Clear seq mask
|
||||
self.seq_mask.detach()
|
||||
|
||||
# Clear task lists
|
||||
self.waiting_queue.clear()
|
||||
self.active_tasks.clear()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get scheduler statistics."""
|
||||
|
|
|
|||
Loading…
Reference in New Issue