refactor: 拆分engine.py 文件
This commit is contained in:
parent
861d33b1a1
commit
2b26f03bd3
|
|
@ -1,21 +1,10 @@
|
||||||
"""
|
"""Inference module for continuous batching."""
|
||||||
AstrAI Inference Module
|
|
||||||
|
|
||||||
Provides inference components for text generation with continuous batching support.
|
|
||||||
|
|
||||||
Main Components:
|
|
||||||
- InferenceEngine: Unified inference engine for continuous batching
|
|
||||||
- InferenceScheduler: Task scheduling with dynamic batch composition
|
|
||||||
- Task, TaskStatus: Task management for continuous batching
|
|
||||||
- GenerationRequest: Request parameters for generation
|
|
||||||
- apply_sampling_strategies: Sampling utilities for text generation
|
|
||||||
|
|
||||||
Author: AstrAI Team
|
|
||||||
"""
|
|
||||||
|
|
||||||
from astrai.inference.engine import (
|
from astrai.inference.engine import (
|
||||||
GenerationRequest,
|
GenerationRequest,
|
||||||
InferenceEngine,
|
InferenceEngine,
|
||||||
|
)
|
||||||
|
from astrai.inference.scheduler import (
|
||||||
InferenceScheduler,
|
InferenceScheduler,
|
||||||
Task,
|
Task,
|
||||||
TaskStatus,
|
TaskStatus,
|
||||||
|
|
@ -25,9 +14,11 @@ from astrai.inference.engine import (
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Engine
|
# Engine
|
||||||
"InferenceEngine",
|
"InferenceEngine",
|
||||||
|
# Scheduler
|
||||||
"InferenceScheduler",
|
"InferenceScheduler",
|
||||||
"Task",
|
"Task",
|
||||||
"TaskStatus",
|
"TaskStatus",
|
||||||
|
# Request
|
||||||
"GenerationRequest",
|
"GenerationRequest",
|
||||||
# Sampling
|
# Sampling
|
||||||
"apply_sampling_strategies",
|
"apply_sampling_strategies",
|
||||||
|
|
|
||||||
|
|
@ -1,49 +1,41 @@
|
||||||
"""
|
"""Unified inference engine."""
|
||||||
Continuous Batching Inference Engine
|
|
||||||
|
|
||||||
This module provides the main continuous batching components:
|
|
||||||
- Task: Individual generation task with state management
|
|
||||||
- TaskStatus: Task state enumeration
|
|
||||||
- InferenceScheduler: Handles request scheduling and KV cache management
|
|
||||||
- InferenceEngine: Unified inference engine
|
|
||||||
|
|
||||||
Author: AstrAI Team
|
|
||||||
"""
|
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
from typing import Any, Dict, Generator, List, Optional, Union
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from enum import Enum, auto
|
|
||||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
from astrai.config import ModelParameter
|
||||||
from torch import Tensor
|
from astrai.tokenize.chat_template import build_prompt
|
||||||
|
|
||||||
from astrai.config import ModelConfig, ModelParameter
|
from astrai.inference.scheduler import InferenceScheduler
|
||||||
from astrai.tokenize.chat_template import HistoryType, build_prompt
|
|
||||||
|
|
||||||
|
|
||||||
# Use print for debugging instead of logging
|
|
||||||
def _debug(*args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GenerationRequest:
|
class GenerationRequest:
|
||||||
"""Request parameters for text generation."""
|
"""Request parameters for text generation."""
|
||||||
|
|
||||||
top_k: int
|
def __init__(
|
||||||
top_p: float
|
self,
|
||||||
temperature: float
|
query: Union[str, List[str]],
|
||||||
max_len: int
|
top_k: int = 50,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
max_len: int = 1024,
|
||||||
|
history: Optional[Any] = None,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
):
|
||||||
|
self.query = query
|
||||||
|
self.top_k = top_k
|
||||||
|
self.top_p = top_p
|
||||||
|
self.temperature = temperature
|
||||||
|
self.max_len = max_len
|
||||||
|
self.history = history
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
self.stream = stream
|
||||||
|
|
||||||
query: Union[str, List[str]]
|
self._validate()
|
||||||
history: Optional[Union[HistoryType, List[HistoryType]]] = None
|
|
||||||
system_prompt: Optional[str] = None
|
|
||||||
stream: bool = False
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def _validate(self):
|
||||||
|
"""Validate request parameters."""
|
||||||
if not isinstance(self.top_k, int) or self.top_k < 0:
|
if not isinstance(self.top_k, int) or self.top_k < 0:
|
||||||
raise ValueError("top_k must be a non-negative integer")
|
raise ValueError("top_k must be a non-negative integer")
|
||||||
if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0:
|
if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0:
|
||||||
|
|
@ -52,482 +44,66 @@ class GenerationRequest:
|
||||||
raise ValueError("temperature must be a non-negative float")
|
raise ValueError("temperature must be a non-negative float")
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus(Enum):
|
class _StreamingResult:
|
||||||
"""Task state enumeration for continuous batching.
|
"""Streaming result holder with event-based notification."""
|
||||||
|
|
||||||
States:
|
def __init__(self):
|
||||||
PENDING: Task is waiting to be scheduled
|
self.tokens: List[str] = []
|
||||||
RUNNING: Task is currently being processed
|
self._event = threading.Event()
|
||||||
FINISHED: Task completed successfully
|
|
||||||
ABORTED: Task was cancelled or failed
|
|
||||||
"""
|
|
||||||
|
|
||||||
PENDING = auto()
|
|
||||||
RUNNING = auto()
|
|
||||||
FINISHED = auto()
|
|
||||||
ABORTED = auto()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Task:
|
|
||||||
"""Individual task for continuous batching.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
task_id: Unique task identifier
|
|
||||||
prompt_ids: Input token IDs
|
|
||||||
max_tokens: Maximum tokens to generate
|
|
||||||
temperature: Sampling temperature
|
|
||||||
top_p: Top-p sampling parameter
|
|
||||||
top_k: Top-k sampling parameter
|
|
||||||
status: Current task status
|
|
||||||
output_ids: Generated token IDs
|
|
||||||
input_tokens: Number of input tokens
|
|
||||||
output_tokens: Number of output tokens generated
|
|
||||||
slot: Batch slot position (-1 if not assigned)
|
|
||||||
arrival_time: Task arrival timestamp
|
|
||||||
finish_time: Task completion timestamp
|
|
||||||
stream_callback: Callback for streaming output
|
|
||||||
"""
|
|
||||||
|
|
||||||
task_id: str
|
|
||||||
prompt_ids: List[int]
|
|
||||||
max_tokens: int = 1024
|
|
||||||
temperature: float = 1.0
|
|
||||||
top_p: float = 1.0
|
|
||||||
top_k: int = 50
|
|
||||||
|
|
||||||
status: TaskStatus = TaskStatus.PENDING
|
|
||||||
output_ids: List[int] = field(default_factory=list)
|
|
||||||
input_tokens: int = 0
|
|
||||||
output_tokens: int = 0
|
|
||||||
slot: int = -1
|
|
||||||
arrival_time: float = field(default_factory=time.time)
|
|
||||||
finish_time: Optional[float] = None
|
|
||||||
|
|
||||||
stream_callback: Optional[Callable[[str], None]] = None
|
|
||||||
|
|
||||||
def is_finished(self, stop_ids: List[int]) -> bool:
|
|
||||||
"""Check if task is finished."""
|
|
||||||
if self.output_ids and self.output_ids[-1] in stop_ids:
|
|
||||||
return True
|
|
||||||
if self.output_tokens >= self.max_tokens:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def apply_sampling_strategies(
|
|
||||||
logits: Tensor,
|
|
||||||
temperature: float,
|
|
||||||
top_k: int,
|
|
||||||
top_p: float,
|
|
||||||
filter_value: float = -float("inf"),
|
|
||||||
) -> Tensor:
|
|
||||||
"""Apply sampling strategies to the logits tensor."""
|
|
||||||
if temperature != 1.0:
|
|
||||||
logits = logits / temperature
|
|
||||||
|
|
||||||
if top_k > 0:
|
|
||||||
top_k = min(top_k, logits.size(-1))
|
|
||||||
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
|
|
||||||
logits[indices_to_remove] = filter_value
|
|
||||||
|
|
||||||
if top_p < 1.0:
|
|
||||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
|
||||||
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
|
||||||
|
|
||||||
sorted_indices_to_remove = cumulative_probs > top_p
|
|
||||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
|
||||||
sorted_indices_to_remove[..., 0] = 0
|
|
||||||
|
|
||||||
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
|
||||||
indices_to_remove.scatter_(
|
|
||||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
|
||||||
)
|
|
||||||
|
|
||||||
logits[indices_to_remove] = filter_value
|
|
||||||
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
class InferenceScheduler:
|
|
||||||
"""Inference scheduler with continuous batching support.
|
|
||||||
|
|
||||||
Manages request scheduling, KV cache allocation, and generation loop.
|
|
||||||
Supports dynamic batch composition where new requests can join at any time
|
|
||||||
and completed requests are immediately released.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
tokenizer,
|
|
||||||
config: ModelConfig,
|
|
||||||
max_batch_size: int = 16,
|
|
||||||
max_seq_len: Optional[int] = None,
|
|
||||||
device: str = "cuda",
|
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
|
||||||
):
|
|
||||||
self.model = model
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.config = config
|
|
||||||
self.max_batch_size = max_batch_size
|
|
||||||
self.max_seq_len = max_seq_len or config.max_len
|
|
||||||
self.device = device
|
|
||||||
self.dtype = dtype
|
|
||||||
|
|
||||||
num_heads = config.n_kv_heads
|
|
||||||
head_dim = config.dim // config.n_heads
|
|
||||||
n_layers = config.n_layers
|
|
||||||
|
|
||||||
k_cache = torch.empty(
|
|
||||||
(
|
|
||||||
max_batch_size,
|
|
||||||
self.max_seq_len,
|
|
||||||
n_layers,
|
|
||||||
num_heads,
|
|
||||||
head_dim,
|
|
||||||
),
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
v_cache = torch.empty(
|
|
||||||
(
|
|
||||||
max_batch_size,
|
|
||||||
self.max_seq_len,
|
|
||||||
n_layers,
|
|
||||||
num_heads,
|
|
||||||
head_dim,
|
|
||||||
),
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
self.kv_cache = (k_cache, v_cache)
|
|
||||||
self.seq_mask = torch.ones(
|
|
||||||
(max_batch_size, self.max_seq_len), device=device, dtype=torch.bool
|
|
||||||
)
|
|
||||||
|
|
||||||
self.waiting_queue: List[Task] = []
|
|
||||||
self.active_tasks: List[Task] = []
|
|
||||||
|
|
||||||
self._running = False
|
|
||||||
self._task_event = threading.Event()
|
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
self._total_tasks = 0
|
def append(self, token: str):
|
||||||
self._total_tokens = 0
|
|
||||||
|
|
||||||
def add_task(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
max_tokens: int = 1024,
|
|
||||||
temperature: float = 1.0,
|
|
||||||
top_p: float = 1.0,
|
|
||||||
top_k: int = 50,
|
|
||||||
stream_callback: Optional[Callable[[str], None]] = None,
|
|
||||||
) -> str:
|
|
||||||
"""Add a new task to the waiting queue."""
|
|
||||||
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
|
||||||
prompt_ids = self.tokenizer.encode(prompt)
|
|
||||||
|
|
||||||
_debug(
|
|
||||||
f"add_task: task_id={task_id}, prompt_len={len(prompt_ids)}, has_callback={stream_callback is not None}"
|
|
||||||
)
|
|
||||||
|
|
||||||
task = Task(
|
|
||||||
task_id=task_id,
|
|
||||||
prompt_ids=prompt_ids,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
temperature=temperature,
|
|
||||||
top_p=top_p,
|
|
||||||
top_k=top_k,
|
|
||||||
stream_callback=stream_callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self.waiting_queue.append(task)
|
self.tokens.append(token)
|
||||||
self._total_tasks += 1
|
self._event.set()
|
||||||
|
|
||||||
self._task_event.set()
|
def pop_all(self) -> List[str]:
|
||||||
return task_id
|
|
||||||
|
|
||||||
def remove_task(self, task_id: str) -> None:
|
|
||||||
"""Remove a task from the scheduler."""
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
|
tokens = self.tokens.copy()
|
||||||
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
|
self.tokens.clear()
|
||||||
|
if not tokens:
|
||||||
|
self._event.clear()
|
||||||
|
return tokens
|
||||||
|
|
||||||
def _remove_finished_tasks(self) -> None:
|
def wait(self, timeout: float = None) -> bool:
|
||||||
"""Remove finished tasks from active batch and update caches."""
|
return self._event.wait(timeout=timeout)
|
||||||
finished = []
|
|
||||||
for task in self.active_tasks:
|
|
||||||
if task.is_finished(self.tokenizer.stop_ids):
|
|
||||||
task.status = TaskStatus.FINISHED
|
|
||||||
task.finish_time = time.time()
|
|
||||||
finished.append(task)
|
|
||||||
self._total_tokens += task.output_tokens
|
|
||||||
|
|
||||||
for task in finished:
|
|
||||||
slot = task.slot
|
|
||||||
if slot >= 0 and slot < len(self.active_tasks):
|
|
||||||
self.seq_mask[slot, :] = False
|
|
||||||
task.slot = -1
|
|
||||||
|
|
||||||
self.active_tasks = [
|
class _NonStreamingResult:
|
||||||
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
|
"""Non-streaming result holder with event-based completion notification."""
|
||||||
]
|
|
||||||
|
|
||||||
def _refill_active_batch(self) -> None:
|
def __init__(self, count: int):
|
||||||
"""Refill active batch with waiting tasks."""
|
self.results: List[str] = ["" for _ in range(count)]
|
||||||
available_slots = self.max_batch_size - len(self.active_tasks)
|
self.done_flags: List[bool] = [False] * count
|
||||||
if available_slots <= 0:
|
self._completed_count = 0
|
||||||
return
|
self._event = threading.Event()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def append(self, idx: int, token: str):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
to_add = []
|
if token == "[DONE]":
|
||||||
for _ in range(min(available_slots, len(self.waiting_queue))):
|
if not self.done_flags[idx]:
|
||||||
if self.waiting_queue:
|
self.done_flags[idx] = True
|
||||||
task = self.waiting_queue.pop(0)
|
self._completed_count += 1
|
||||||
task.status = TaskStatus.RUNNING
|
if self._completed_count == len(self.results):
|
||||||
to_add.append(task)
|
self._event.set()
|
||||||
|
|
||||||
for task in to_add:
|
|
||||||
for i in range(self.max_batch_size):
|
|
||||||
if all(t.slot != i for t in self.active_tasks):
|
|
||||||
task.slot = i
|
|
||||||
break
|
|
||||||
self.active_tasks.append(task)
|
|
||||||
|
|
||||||
def _execute_prefill(self, tasks: List[Task]) -> None:
|
|
||||||
"""Execute Prefill phase: process entire prompt at once."""
|
|
||||||
if not tasks:
|
|
||||||
return
|
|
||||||
|
|
||||||
_debug(f"_execute_prefill: processing {len(tasks)} tasks")
|
|
||||||
|
|
||||||
# Sort tasks by slot to ensure correct batch indexing with KV cache
|
|
||||||
tasks = sorted(tasks, key=lambda t: t.slot)
|
|
||||||
|
|
||||||
prompt_lens = [len(task.prompt_ids) for task in tasks]
|
|
||||||
max_len = max(prompt_lens)
|
|
||||||
|
|
||||||
input_ids = torch.zeros(
|
|
||||||
len(tasks), max_len, dtype=torch.long, device=self.device
|
|
||||||
)
|
|
||||||
for i, task in enumerate(tasks):
|
|
||||||
if len(task.prompt_ids) > 0:
|
|
||||||
input_ids[i, : len(task.prompt_ids)] = torch.tensor(
|
|
||||||
task.prompt_ids, device=self.device
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create boolean mask for attention
|
|
||||||
if self.tokenizer.pad_id is not None:
|
|
||||||
input_mask = torch.ne(input_ids, self.tokenizer.pad_id)
|
|
||||||
else:
|
|
||||||
input_mask = torch.ones(
|
|
||||||
input_ids.shape, dtype=torch.bool, device=self.device
|
|
||||||
)
|
|
||||||
|
|
||||||
_debug(
|
|
||||||
f"_execute_prefill: input_ids shape={input_ids.shape}, max_len={max_len}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with torch.inference_mode():
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids,
|
|
||||||
input_mask=input_mask,
|
|
||||||
start_pos=0,
|
|
||||||
persistent_key_values=self.kv_cache,
|
|
||||||
)
|
|
||||||
_debug(
|
|
||||||
f"_execute_prefill: model forward done, output keys={outputs.keys() if hasattr(outputs, 'keys') else 'no keys'}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
_debug(f"_execute_prefill: ERROR: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
for i, task in enumerate(tasks):
|
|
||||||
task.input_tokens = prompt_lens[i]
|
|
||||||
task.output_tokens = 0
|
|
||||||
_debug(
|
|
||||||
f" task {task.task_id}: input_tokens={task.input_tokens}, output_tokens={task.output_tokens}"
|
|
||||||
)
|
|
||||||
|
|
||||||
for task in tasks:
|
|
||||||
if task.slot >= 0:
|
|
||||||
self.seq_mask[task.slot, : task.input_tokens] = True
|
|
||||||
|
|
||||||
_debug(f"_execute_prefill: done, {len(tasks)} tasks marked as prefill complete")
|
|
||||||
|
|
||||||
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
|
||||||
"""Execute Decode phase: generate one token at a time."""
|
|
||||||
if not tasks:
|
|
||||||
return
|
|
||||||
|
|
||||||
_debug(f"_execute_decode: processing {len(tasks)} tasks, start_pos={start_pos}")
|
|
||||||
|
|
||||||
# Sort tasks by slot to ensure batch index aligns with slot (KV cache position)
|
|
||||||
# Task at slot 0 → batch index 0 → KV stored at cache[0]
|
|
||||||
# Task at slot 1 → batch index 1 → KV stored at cache[1]
|
|
||||||
tasks = sorted(tasks, key=lambda t: t.slot)
|
|
||||||
|
|
||||||
input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device)
|
|
||||||
for i, task in enumerate(tasks):
|
|
||||||
if task.output_ids:
|
|
||||||
input_ids[i] = task.output_ids[-1]
|
|
||||||
else:
|
else:
|
||||||
input_ids[i] = task.prompt_ids[-1]
|
self.results[idx] += token
|
||||||
|
|
||||||
input_tensor = input_ids.unsqueeze(1) # shape: (batch, 1)
|
def is_all_done(self) -> bool:
|
||||||
|
with self._lock:
|
||||||
|
return all(self.done_flags)
|
||||||
|
|
||||||
# Create 2D attention mask: (batch, seq_len)
|
def wait(self, timeout: float = None) -> bool:
|
||||||
active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device)
|
return self._event.wait(timeout=timeout)
|
||||||
_debug(
|
|
||||||
f"_execute_decode: input_tensor shape={input_tensor.shape}, active_mask shape={active_mask.shape}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
def get_results(self) -> List[str]:
|
||||||
with torch.inference_mode():
|
with self._lock:
|
||||||
outputs = self.model(
|
return self.results.copy()
|
||||||
input_tensor,
|
|
||||||
input_mask=active_mask,
|
|
||||||
persistent_key_values=self.kv_cache,
|
|
||||||
start_pos=start_pos,
|
|
||||||
)
|
|
||||||
_debug(
|
|
||||||
f"_execute_decode: model forward done, logits shape={outputs['logits'].shape}"
|
|
||||||
)
|
|
||||||
logits = outputs["logits"][:, -1, :]
|
|
||||||
except Exception as e:
|
|
||||||
_debug(f"_execute_decode: ERROR: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
next_token_ids = []
|
|
||||||
for i, task in enumerate(tasks):
|
|
||||||
logit = logits[i : i + 1]
|
|
||||||
logit = apply_sampling_strategies(
|
|
||||||
logit,
|
|
||||||
task.temperature,
|
|
||||||
task.top_k,
|
|
||||||
task.top_p,
|
|
||||||
)
|
|
||||||
probs = torch.softmax(logit, dim=-1)
|
|
||||||
next_token = torch.multinomial(probs, num_samples=1)
|
|
||||||
next_token_ids.append(next_token.item())
|
|
||||||
|
|
||||||
_debug(f"_execute_decode: next_tokens={next_token_ids}")
|
|
||||||
|
|
||||||
for task, next_token in zip(tasks, next_token_ids):
|
|
||||||
task.output_ids.append(next_token)
|
|
||||||
task.output_tokens += 1
|
|
||||||
|
|
||||||
pos = task.input_tokens + task.output_tokens
|
|
||||||
if task.slot >= 0 and pos < self.max_seq_len:
|
|
||||||
self.seq_mask[task.slot, pos] = True
|
|
||||||
|
|
||||||
if task.stream_callback:
|
|
||||||
token_str = self.tokenizer.decode([next_token])
|
|
||||||
task.stream_callback(token_str)
|
|
||||||
|
|
||||||
# Check if any task reached max_tokens or stop token
|
|
||||||
for task in tasks:
|
|
||||||
if task.output_tokens >= task.max_tokens or (
|
|
||||||
task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids
|
|
||||||
):
|
|
||||||
_debug(
|
|
||||||
f"decode: task {task.task_id} finished, output_tokens={task.output_tokens}, max_tokens={task.max_tokens}"
|
|
||||||
)
|
|
||||||
if task.stream_callback:
|
|
||||||
task.stream_callback("[DONE]")
|
|
||||||
|
|
||||||
def _run_generation_loop(self) -> None:
|
|
||||||
"""Main generation loop with continuous batching."""
|
|
||||||
_debug("generation_loop: started")
|
|
||||||
while self._running:
|
|
||||||
self._remove_finished_tasks()
|
|
||||||
self._refill_active_batch()
|
|
||||||
|
|
||||||
if not self.active_tasks:
|
|
||||||
self._task_event.wait(timeout=0.01)
|
|
||||||
self._task_event.clear()
|
|
||||||
continue
|
|
||||||
|
|
||||||
_debug(
|
|
||||||
f"generation_loop: active={len(self.active_tasks)}, waiting={len(self.waiting_queue)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
new_tasks = [t for t in self.active_tasks if t.output_tokens == 0]
|
|
||||||
decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0]
|
|
||||||
|
|
||||||
_debug(
|
|
||||||
f"generation_loop: new_tasks={len(new_tasks)}, decode_tasks={len(decode_tasks)}"
|
|
||||||
)
|
|
||||||
for t in self.active_tasks:
|
|
||||||
_debug(
|
|
||||||
f" active task {t.task_id}: output_tokens={t.output_tokens}, input_tokens={t.input_tokens}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if decode_tasks:
|
|
||||||
start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks)
|
|
||||||
else:
|
|
||||||
start_pos = 0
|
|
||||||
|
|
||||||
# First run prefill for new tasks
|
|
||||||
if new_tasks:
|
|
||||||
_debug(f"generation_loop: running prefill for {len(new_tasks)} tasks")
|
|
||||||
self._execute_prefill(new_tasks)
|
|
||||||
_debug(f"generation_loop: prefill done")
|
|
||||||
|
|
||||||
# After prefill, convert these tasks to decode tasks in the same iteration
|
|
||||||
decode_tasks = new_tasks
|
|
||||||
start_pos = max(t.input_tokens for t in decode_tasks)
|
|
||||||
_debug(
|
|
||||||
f"generation_loop: after prefill, decode_tasks={len(decode_tasks)}, start_pos={start_pos}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if decode_tasks:
|
|
||||||
_debug(
|
|
||||||
f"generation_loop: running decode for {len(decode_tasks)} tasks, start_pos={start_pos}"
|
|
||||||
)
|
|
||||||
self._execute_decode(decode_tasks, start_pos)
|
|
||||||
_debug(f"generation_loop: decode done")
|
|
||||||
|
|
||||||
if not self.active_tasks and not self.waiting_queue:
|
|
||||||
time.sleep(0.001)
|
|
||||||
|
|
||||||
def start(self) -> None:
|
|
||||||
"""Start the generation loop in a background thread."""
|
|
||||||
if not self._running:
|
|
||||||
_debug("InferenceScheduler.start: starting loop thread")
|
|
||||||
self._running = True
|
|
||||||
self._loop_thread = threading.Thread(target=self._run_generation_loop)
|
|
||||||
self._loop_thread.daemon = True
|
|
||||||
self._loop_thread.start()
|
|
||||||
_debug("InferenceScheduler.start: loop thread started")
|
|
||||||
|
|
||||||
def stop(self) -> None:
|
|
||||||
"""Stop the generation loop."""
|
|
||||||
self._running = False
|
|
||||||
if hasattr(self, "_loop_thread"):
|
|
||||||
self._loop_thread.join(timeout=1.0)
|
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
|
||||||
"""Get scheduler statistics."""
|
|
||||||
return {
|
|
||||||
"total_tasks": self._total_tasks,
|
|
||||||
"total_tokens": self._total_tokens,
|
|
||||||
"active_tasks": len(self.active_tasks),
|
|
||||||
"waiting_queue": len(self.waiting_queue),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class InferenceEngine:
|
class InferenceEngine:
|
||||||
"""Unified inference engine for continuous batching.
|
"""Unified inference engine for continuous batching."""
|
||||||
|
|
||||||
Provides a single interface for:
|
|
||||||
- Single request generation (streaming or non-streaming)
|
|
||||||
- Batch request generation (streaming or non-streaming)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -556,9 +132,7 @@ class InferenceEngine:
|
||||||
self.kv_cache = self.scheduler.kv_cache
|
self.kv_cache = self.scheduler.kv_cache
|
||||||
self.seq_mask = self.scheduler.seq_mask
|
self.seq_mask = self.scheduler.seq_mask
|
||||||
|
|
||||||
_debug("InferenceEngine: starting scheduler")
|
|
||||||
self.scheduler.start()
|
self.scheduler.start()
|
||||||
_debug("InferenceEngine: scheduler started")
|
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
|
|
@ -606,43 +180,29 @@ class InferenceEngine:
|
||||||
top_p: float,
|
top_p: float,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
|
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
|
||||||
"""Generate with streaming output (synchronous)."""
|
"""Generate with streaming output."""
|
||||||
results = []
|
|
||||||
_debug(f"_generate_streaming: prompts={len(prompts)}")
|
|
||||||
|
|
||||||
if is_batch:
|
if is_batch:
|
||||||
raise NotImplementedError("Batch streaming is not implemented yet")
|
raise NotImplementedError("Batch streaming is not implemented yet")
|
||||||
|
|
||||||
def make_callback(idx: int):
|
result = _StreamingResult()
|
||||||
def cb(token: str):
|
|
||||||
_debug(f"callback[{idx}]: token={token!r}")
|
|
||||||
results.append(token)
|
|
||||||
|
|
||||||
return cb
|
self.scheduler.add_task(
|
||||||
|
prompt=prompts[0],
|
||||||
for i, p in enumerate(prompts):
|
max_tokens=max_tokens,
|
||||||
_debug(f"_generate_streaming: adding task {i}: {p[:30]}...")
|
temperature=temperature,
|
||||||
self.scheduler.add_task(
|
top_p=top_p,
|
||||||
prompt=p,
|
top_k=top_k,
|
||||||
max_tokens=max_tokens,
|
stream_callback=result.append,
|
||||||
temperature=temperature,
|
)
|
||||||
top_p=top_p,
|
|
||||||
top_k=top_k,
|
|
||||||
stream_callback=make_callback(i),
|
|
||||||
)
|
|
||||||
|
|
||||||
def gen():
|
def gen():
|
||||||
_debug("generator: start yielding")
|
|
||||||
while True:
|
while True:
|
||||||
# Yield accumulated tokens
|
tokens = result.pop_all()
|
||||||
while results:
|
for token in tokens:
|
||||||
token = results.pop(0)
|
|
||||||
if token == "[DONE]":
|
if token == "[DONE]":
|
||||||
_debug("generator: got [DONE]")
|
|
||||||
return
|
return
|
||||||
_debug(f"generator: yielding {token!r}")
|
|
||||||
yield token
|
yield token
|
||||||
time.sleep(0.01)
|
result.wait(timeout=0.05)
|
||||||
|
|
||||||
return gen()
|
return gen()
|
||||||
|
|
||||||
|
|
@ -656,19 +216,7 @@ class InferenceEngine:
|
||||||
top_k: int,
|
top_k: int,
|
||||||
) -> Union[str, List[str]]:
|
) -> Union[str, List[str]]:
|
||||||
"""Generate without streaming."""
|
"""Generate without streaming."""
|
||||||
results = ["" for _ in range(len(prompts))]
|
result = _NonStreamingResult(len(prompts))
|
||||||
done_flags = [False] * len(prompts)
|
|
||||||
lock = threading.Lock()
|
|
||||||
|
|
||||||
def make_callback(idx: int):
|
|
||||||
def cb(token: str):
|
|
||||||
if token == "[DONE]":
|
|
||||||
done_flags[idx] = True
|
|
||||||
else:
|
|
||||||
with lock:
|
|
||||||
results[idx] += token
|
|
||||||
|
|
||||||
return cb
|
|
||||||
|
|
||||||
for i, p in enumerate(prompts):
|
for i, p in enumerate(prompts):
|
||||||
self.scheduler.add_task(
|
self.scheduler.add_task(
|
||||||
|
|
@ -677,12 +225,11 @@ class InferenceEngine:
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
stream_callback=make_callback(i),
|
stream_callback=result.append,
|
||||||
)
|
)
|
||||||
|
|
||||||
while not all(done_flags):
|
result.wait()
|
||||||
time.sleep(0.001)
|
results = result.get_results()
|
||||||
|
|
||||||
return results if is_batch else results[0]
|
return results if is_batch else results[0]
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,387 @@
|
||||||
|
"""Inference scheduler for continuous batching."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.config import ModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus:
    """String constants naming the lifecycle states of a ``Task``.

    The members are plain ``str`` values rather than enum members, so they
    compare equal to their literal spelling (``TaskStatus.PENDING ==
    "pending"``).
    """

    # Status given to a task when it is constructed.
    PENDING = "pending"
    # Task is currently being processed.
    RUNNING = "running"
    # Task completed.
    FINISHED = "finished"
    # Task was cancelled or failed.
    ABORTED = "aborted"
|
||||||
|
|
||||||
|
|
||||||
|
class Task:
    """A single generation request tracked for continuous batching.

    Attributes:
        task_id: Unique task identifier.
        prompt_ids: Input token IDs.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Top-p sampling parameter.
        top_k: Top-k sampling parameter.
        status: Current ``TaskStatus`` value; starts as ``PENDING``.
        output_ids: Generated token IDs (starts empty).
        input_tokens: Number of input tokens (initialised to 0).
        output_tokens: Number of output tokens generated (initialised to 0).
        slot: Batch slot position, ``-1`` while not assigned.
        arrival_time: ``time.time()`` timestamp taken at construction.
        finish_time: Completion timestamp, ``None`` until set.
        stream_callback: Optional callable that receives streamed text.
    """

    def __init__(
        self,
        task_id: str,
        prompt_ids: List[int],
        max_tokens: int = 1024,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
        stream_callback: Optional[Callable[[str], None]] = None,
    ):
        # Request parameters supplied by the caller.
        self.task_id = task_id
        self.prompt_ids = prompt_ids
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k

        # Mutable progress state, updated while the task is processed.
        self.status = TaskStatus.PENDING
        self.output_ids: List[int] = []
        self.input_tokens: int = 0
        self.output_tokens: int = 0
        self.slot: int = -1
        self.arrival_time = time.time()
        self.finish_time: Optional[float] = None

        self.stream_callback = stream_callback

    def __repr__(self) -> str:
        """Return a compact summary of the task's identity and progress."""
        return (
            f"{type(self).__name__}(task_id={self.task_id!r}, "
            f"status={self.status!r}, slot={self.slot}, "
            f"input_tokens={self.input_tokens}, "
            f"output_tokens={self.output_tokens}, "
            f"max_tokens={self.max_tokens})"
        )

    def is_finished(self, stop_ids: List[int]) -> bool:
        """Check if task is finished.

        Args:
            stop_ids: Token IDs that terminate generation.

        Returns:
            ``True`` when the most recent output token is one of
            ``stop_ids`` or when ``output_tokens`` has reached
            ``max_tokens``; ``False`` otherwise.
        """
        if self.output_ids and self.output_ids[-1] in stop_ids:
            return True
        return self.output_tokens >= self.max_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def apply_sampling_strategies(
    logits: Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    filter_value: float = -float("inf"),
) -> Tensor:
    """Apply sampling strategies to the logits tensor.

    Temperature scaling, top-k filtering and top-p (nucleus) filtering are
    applied in that order along the last dimension. Entries that are filtered
    out are set to ``filter_value``. The input tensor is never modified; a new
    tensor of the same shape is returned.

    Args:
        logits: Unnormalised scores with the candidates on the last dimension.
        temperature: Divisor applied to the logits. ``0`` selects greedy
            behaviour: only the highest entry of each row is kept.
        top_k: Keep only the ``top_k`` highest entries per row; values
            ``<= 0`` disable the filter.
        top_p: Cumulative-probability threshold for nucleus filtering; values
            ``>= 1.0`` disable the filter.
        filter_value: Value written into filtered-out positions.

    Returns:
        The filtered logits as a new tensor.
    """
    if temperature == 0:
        # Dividing by zero would produce inf/NaN logits. Treat zero as the
        # greedy limit instead: keep the arg-max, filter everything else.
        best = logits.argmax(dim=-1, keepdim=True)
        greedy = torch.full_like(logits, filter_value)
        return greedy.scatter(-1, best, logits.gather(-1, best))

    # Division already allocates a new tensor; clone in the other case so the
    # in-place masking below never writes into the caller's tensor.
    logits = logits / temperature if temperature != 1.0 else logits.clone()

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        # Smallest value that is still inside the top-k of each row.
        kth_value = torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
        logits[logits < kth_value] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_remove = cumulative_probs > top_p
        # Shift right by one so the entry that crosses the threshold is kept,
        # and always keep the single most likely entry.
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False

        # Map the mask from sorted order back to the original positions.
        # Scatter along the last dimension (not a hard-coded dim=1) so that
        # inputs of any rank are handled like the top-k branch above.
        remove = torch.zeros_like(logits, dtype=torch.bool)
        remove.scatter_(dim=-1, index=sorted_indices, src=sorted_remove)
        logits[remove] = filter_value

    return logits
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceScheduler:
    """Inference scheduler with continuous batching support.

    Tasks are queued with :meth:`add_task`, admitted into a fixed number of
    batch slots, prefilled once and then decoded one token per loop
    iteration by a background thread started with :meth:`start`.

    The tokenizer must provide ``encode``, ``decode``, ``stop_ids`` and
    ``pad_id``. The model is called with ``input_mask``, ``start_pos`` and
    ``persistent_key_values`` keyword arguments and its decode output is
    indexed with ``["logits"]``.
    """

    def __init__(
        self,
        model,
        tokenizer,
        config: "ModelConfig",
        max_batch_size: int = 16,
        max_seq_len: Optional[int] = None,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        """Allocate the KV cache and initialise scheduling state.

        Args:
            model: Callable model used for prefill and decode.
            tokenizer: Tokenizer used to encode prompts and decode tokens.
            config: Model configuration; ``n_kv_heads``, ``n_heads``,
                ``dim``, ``n_layers`` and ``max_len`` are read.
            max_batch_size: Number of batch slots (concurrently running tasks).
            max_seq_len: Cache length per slot; falls back to ``config.max_len``.
            device: Device on which caches and inputs are created.
            dtype: Data type of the KV cache.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len or config.max_len
        self.device = device
        self.dtype = dtype

        # Key and value caches share one layout:
        # (slot, position, layer, kv head, head dim).
        cache_shape = (
            max_batch_size,
            self.max_seq_len,
            config.n_layers,
            config.n_kv_heads,
            config.dim // config.n_heads,
        )
        k_cache = torch.empty(cache_shape, device=device, dtype=dtype)
        v_cache = torch.empty(cache_shape, device=device, dtype=dtype)
        self.kv_cache = (k_cache, v_cache)
        self.seq_mask = torch.ones(
            (max_batch_size, self.max_seq_len), device=device, dtype=torch.bool
        )

        self.waiting_queue: List["Task"] = []
        self.active_tasks: List["Task"] = []

        self._running = False
        self._loop_thread: Optional[threading.Thread] = None
        self._task_event = threading.Event()
        self._lock = threading.Lock()

        self._total_tasks = 0
        self._total_tokens = 0

    def add_task(
        self,
        prompt: str,
        max_tokens: int = 1024,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
        stream_callback: Optional[Callable[[str], None]] = None,
    ) -> str:
        """Add a new task to the waiting queue.

        Args:
            prompt: Text prompt; encoded with the scheduler's tokenizer.
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.
            top_k: Top-k sampling cutoff.
            stream_callback: Optional callable invoked with each decoded
                token string and finally with ``"[DONE]"``.

        Returns:
            The generated task id.
        """
        task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
        prompt_ids = self.tokenizer.encode(prompt)

        task = Task(
            task_id=task_id,
            prompt_ids=prompt_ids,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stream_callback=stream_callback,
        )

        with self._lock:
            self.waiting_queue.append(task)
            self._total_tasks += 1

        # Wake the generation loop if it is idling.
        self._task_event.set()
        return task_id

    def remove_task(self, task_id: str) -> None:
        """Remove a task from the scheduler, whether waiting or active."""
        with self._lock:
            self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
            self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]

    def _remove_finished_tasks(self) -> None:
        """Retire finished tasks from the active batch and free their slots."""
        stop_ids = self.tokenizer.stop_ids
        # Hold the lock because remove_task() rebinds active_tasks from
        # other threads; an unlocked rebind here could lose that update.
        with self._lock:
            still_running = []
            for task in self.active_tasks:
                if not task.is_finished(stop_ids):
                    still_running.append(task)
                    continue

                task.status = TaskStatus.FINISHED
                task.finish_time = time.time()
                self._total_tokens += task.output_tokens

                # Slots index rows of seq_mask, so the valid range is the
                # batch capacity, not the current number of active tasks.
                if 0 <= task.slot < self.max_batch_size:
                    self.seq_mask[task.slot, :] = False
                task.slot = -1

            self.active_tasks = still_running

    def _refill_active_batch(self) -> None:
        """Move waiting tasks into free batch slots (first in, first out)."""
        with self._lock:
            free_count = self.max_batch_size - len(self.active_tasks)
            if free_count <= 0 or not self.waiting_queue:
                return

            admitted = self.waiting_queue[:free_count]
            del self.waiting_queue[:free_count]

            occupied = {t.slot for t in self.active_tasks}
            # There are at least as many free slots as admitted tasks, so
            # zip() never drops a task.
            free_slots = (i for i in range(self.max_batch_size) if i not in occupied)
            for task, slot in zip(admitted, free_slots):
                task.status = TaskStatus.RUNNING
                task.slot = slot
                self.active_tasks.append(task)

    def _execute_prefill(self, tasks: List["Task"]) -> None:
        """Run the prompts of ``tasks`` through the model to fill the KV cache.

        Prompts are right-padded to the longest prompt in the batch. After
        the forward pass each task's ``input_tokens`` is set to its prompt
        length, ``output_tokens`` is reset to 0 and the prompt positions of
        its slot are marked in ``seq_mask``.
        """
        if not tasks:
            return

        tasks = sorted(tasks, key=lambda t: t.slot)

        prompt_lens = [len(task.prompt_ids) for task in tasks]
        max_len = max(prompt_lens)

        input_ids = torch.zeros(
            len(tasks), max_len, dtype=torch.long, device=self.device
        )
        for row, task in enumerate(tasks):
            if prompt_lens[row] > 0:
                input_ids[row, : prompt_lens[row]] = torch.tensor(
                    task.prompt_ids, dtype=torch.long, device=self.device
                )

        # Derive the mask from the real prompt lengths. Comparing the
        # zero-padded ids against pad_id only works when pad_id == 0 and
        # would also hide genuine prompt tokens equal to pad_id.
        lengths = torch.tensor(prompt_lens, device=self.device)
        positions = torch.arange(max_len, device=self.device)
        input_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)

        # NOTE(review): the batch rows are ordered by slot but are not
        # necessarily equal to the slot indices — confirm how the model maps
        # batch rows onto rows of the persistent cache.
        with torch.inference_mode():
            # Called for its side effect on the cache; the output is unused.
            self.model(
                input_ids,
                input_mask=input_mask,
                start_pos=0,
                persistent_key_values=self.kv_cache,
            )

        for task, prompt_len in zip(tasks, prompt_lens):
            task.input_tokens = prompt_len
            task.output_tokens = 0
            if task.slot >= 0:
                self.seq_mask[task.slot, :prompt_len] = True

    def _execute_decode(self, tasks: List["Task"], start_pos: int) -> None:
        """Generate one token for every task in ``tasks``.

        Args:
            tasks: Tasks to advance by one token.
            start_pos: Position passed to the model as ``start_pos``.
        """
        if not tasks:
            return

        tasks = sorted(tasks, key=lambda t: t.slot)

        # Feed each task's latest token; a task without output yet re-feeds
        # the last prompt token.
        last_tokens = [
            task.output_ids[-1] if task.output_ids else task.prompt_ids[-1]
            for task in tasks
        ]
        input_tensor = torch.tensor(
            last_tokens, dtype=torch.long, device=self.device
        ).unsqueeze(1)
        active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device)

        next_token_ids = []
        # Sampling stays inside inference_mode so that any in-place
        # filtering of the logits remains legal on inference tensors.
        with torch.inference_mode():
            outputs = self.model(
                input_tensor,
                input_mask=active_mask,
                persistent_key_values=self.kv_cache,
                start_pos=start_pos,
            )
            logits = outputs["logits"][:, -1, :]

            # Sample per task because every task has its own parameters.
            for row, task in enumerate(tasks):
                row_logits = apply_sampling_strategies(
                    logits[row : row + 1],
                    task.temperature,
                    task.top_k,
                    task.top_p,
                )
                probs = torch.softmax(row_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                next_token_ids.append(next_token.item())

        for task, next_token in zip(tasks, next_token_ids):
            task.output_ids.append(next_token)
            task.output_tokens += 1

            pos = task.input_tokens + task.output_tokens
            if task.slot >= 0 and pos < self.max_seq_len:
                self.seq_mask[task.slot, pos] = True

            if task.stream_callback:
                token_str = self.tokenizer.decode([next_token])
                task.stream_callback(token_str)

        # Signal end of stream using the same stop rule as Task.is_finished.
        stop_ids = self.tokenizer.stop_ids
        for task in tasks:
            if task.stream_callback and task.is_finished(stop_ids):
                task.stream_callback("[DONE]")

    def _run_generation_loop(self) -> None:
        """Main generation loop; runs until ``_running`` is cleared."""
        while self._running:
            self._remove_finished_tasks()
            self._refill_active_batch()

            if not self.active_tasks:
                # Nothing to do: sleep briefly or until a task arrives.
                self._task_event.wait(timeout=0.01)
                self._task_event.clear()
                continue

            new_tasks = [t for t in self.active_tasks if t.output_tokens == 0]
            decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0]

            if decode_tasks:
                start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks)
            else:
                start_pos = 0

            if new_tasks:
                # Newly admitted tasks take this iteration: they are
                # prefilled and then decoded once, while tasks already
                # decoding wait for the next iteration.
                self._execute_prefill(new_tasks)
                decode_tasks = new_tasks
                start_pos = max(t.input_tokens for t in decode_tasks)

            if decode_tasks:
                self._execute_decode(decode_tasks, start_pos)

            if not self.active_tasks and not self.waiting_queue:
                self._task_event.wait(timeout=0.05)
                self._task_event.clear()

    def start(self) -> None:
        """Start the generation loop in a daemon thread (no-op if running)."""
        if self._running:
            return
        self._running = True
        self._loop_thread = threading.Thread(
            target=self._run_generation_loop, daemon=True
        )
        self._loop_thread.start()

    def stop(self) -> None:
        """Stop the generation loop and wait up to one second for it to exit."""
        self._running = False
        loop_thread = getattr(self, "_loop_thread", None)
        if loop_thread is not None:
            loop_thread.join(timeout=1.0)

    def get_stats(self) -> Dict[str, Any]:
        """Get scheduler statistics.

        Returns:
            A dict with the number of tasks ever submitted, the total output
            tokens of finished tasks, and the current sizes of the active
            batch and the waiting queue.
        """
        return {
            "total_tasks": self._total_tasks,
            "total_tokens": self._total_tokens,
            "active_tasks": len(self.active_tasks),
            "waiting_queue": len(self.waiting_queue),
        }
|
||||||
|
|
@ -4,6 +4,7 @@ import torch
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
|
|
||||||
|
|
||||||
def test_model_parameter(test_env):
|
def test_model_parameter(test_env):
|
||||||
save_dir = os.path.join(test_env["test_dir"], "save")
|
save_dir = os.path.join(test_env["test_dir"], "save")
|
||||||
model_param = ModelParameter(
|
model_param = ModelParameter(
|
||||||
|
|
@ -30,4 +31,4 @@ def test_transformer(test_env):
|
||||||
test_env["transformer_config"].max_len,
|
test_env["transformer_config"].max_len,
|
||||||
test_env["transformer_config"].vocab_size,
|
test_env["transformer_config"].vocab_size,
|
||||||
)
|
)
|
||||||
assert output_logits.shape == target_shape
|
assert output_logits.shape == target_shape
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue