refactor: 拆分engine.py 文件

This commit is contained in:
ViperEkura 2026-04-05 00:07:21 +08:00
parent 861d33b1a1
commit 2b26f03bd3
4 changed files with 479 additions and 553 deletions

View File

@ -1,21 +1,10 @@
"""
AstrAI Inference Module
Provides inference components for text generation with continuous batching support.
Main Components:
- InferenceEngine: Unified inference engine for continuous batching
- InferenceScheduler: Task scheduling with dynamic batch composition
- Task, TaskStatus: Task management for continuous batching
- GenerationRequest: Request parameters for generation
- apply_sampling_strategies: Sampling utilities for text generation
Author: AstrAI Team
"""
"""Inference module for continuous batching."""
from astrai.inference.engine import (
GenerationRequest,
InferenceEngine,
)
from astrai.inference.scheduler import (
InferenceScheduler,
Task,
TaskStatus,
@ -25,9 +14,11 @@ from astrai.inference.engine import (
__all__ = [
# Engine
"InferenceEngine",
# Scheduler
"InferenceScheduler",
"Task",
"TaskStatus",
# Request
"GenerationRequest",
# Sampling
"apply_sampling_strategies",

View File

@ -1,49 +1,41 @@
"""
Continuous Batching Inference Engine
This module provides the main continuous batching components:
- Task: Individual generation task with state management
- TaskStatus: Task state enumeration
- InferenceScheduler: Handles request scheduling and KV cache management
- InferenceEngine: Unified inference engine
Author: AstrAI Team
"""
"""Unified inference engine."""
import threading
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Dict, Generator, List, Optional, Union
import torch
from torch import Tensor
from astrai.config import ModelParameter
from astrai.tokenize.chat_template import build_prompt
from astrai.config import ModelConfig, ModelParameter
from astrai.tokenize.chat_template import HistoryType, build_prompt
from astrai.inference.scheduler import InferenceScheduler
# Use print for debugging instead of logging
def _debug(*args, **kwargs):
pass
@dataclass
class GenerationRequest:
"""Request parameters for text generation."""
top_k: int
top_p: float
temperature: float
max_len: int
def __init__(
self,
query: Union[str, List[str]],
top_k: int = 50,
top_p: float = 1.0,
temperature: float = 1.0,
max_len: int = 1024,
history: Optional[Any] = None,
system_prompt: Optional[str] = None,
stream: bool = False,
):
self.query = query
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.max_len = max_len
self.history = history
self.system_prompt = system_prompt
self.stream = stream
query: Union[str, List[str]]
history: Optional[Union[HistoryType, List[HistoryType]]] = None
system_prompt: Optional[str] = None
stream: bool = False
self._validate()
def __post_init__(self):
def _validate(self):
"""Validate request parameters."""
if not isinstance(self.top_k, int) or self.top_k < 0:
raise ValueError("top_k must be a non-negative integer")
if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0:
@ -52,482 +44,66 @@ class GenerationRequest:
raise ValueError("temperature must be a non-negative float")
class TaskStatus(Enum):
"""Task state enumeration for continuous batching.
class _StreamingResult:
"""Streaming result holder with event-based notification."""
States:
PENDING: Task is waiting to be scheduled
RUNNING: Task is currently being processed
FINISHED: Task completed successfully
ABORTED: Task was cancelled or failed
"""
PENDING = auto()
RUNNING = auto()
FINISHED = auto()
ABORTED = auto()
@dataclass
class Task:
"""Individual task for continuous batching.
Attributes:
task_id: Unique task identifier
prompt_ids: Input token IDs
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Top-p sampling parameter
top_k: Top-k sampling parameter
status: Current task status
output_ids: Generated token IDs
input_tokens: Number of input tokens
output_tokens: Number of output tokens generated
slot: Batch slot position (-1 if not assigned)
arrival_time: Task arrival timestamp
finish_time: Task completion timestamp
stream_callback: Callback for streaming output
"""
task_id: str
prompt_ids: List[int]
max_tokens: int = 1024
temperature: float = 1.0
top_p: float = 1.0
top_k: int = 50
status: TaskStatus = TaskStatus.PENDING
output_ids: List[int] = field(default_factory=list)
input_tokens: int = 0
output_tokens: int = 0
slot: int = -1
arrival_time: float = field(default_factory=time.time)
finish_time: Optional[float] = None
stream_callback: Optional[Callable[[str], None]] = None
def is_finished(self, stop_ids: List[int]) -> bool:
"""Check if task is finished."""
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
if self.output_tokens >= self.max_tokens:
return True
return False
def apply_sampling_strategies(
logits: Tensor,
temperature: float,
top_k: int,
top_p: float,
filter_value: float = -float("inf"),
) -> Tensor:
"""Apply sampling strategies to the logits tensor."""
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
class InferenceScheduler:
"""Inference scheduler with continuous batching support.
Manages request scheduling, KV cache allocation, and generation loop.
Supports dynamic batch composition where new requests can join at any time
and completed requests are immediately released.
"""
def __init__(
self,
model,
tokenizer,
config: ModelConfig,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
self.model = model
self.tokenizer = tokenizer
self.config = config
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len
self.device = device
self.dtype = dtype
num_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads
n_layers = config.n_layers
k_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_heads,
head_dim,
),
device=device,
dtype=dtype,
)
v_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_heads,
head_dim,
),
device=device,
dtype=dtype,
)
self.kv_cache = (k_cache, v_cache)
self.seq_mask = torch.ones(
(max_batch_size, self.max_seq_len), device=device, dtype=torch.bool
)
self.waiting_queue: List[Task] = []
self.active_tasks: List[Task] = []
self._running = False
self._task_event = threading.Event()
def __init__(self):
self.tokens: List[str] = []
self._event = threading.Event()
self._lock = threading.Lock()
self._total_tasks = 0
self._total_tokens = 0
def add_task(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
) -> str:
"""Add a new task to the waiting queue."""
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
_debug(
f"add_task: task_id={task_id}, prompt_len={len(prompt_ids)}, has_callback={stream_callback is not None}"
)
task = Task(
task_id=task_id,
prompt_ids=prompt_ids,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=stream_callback,
)
def append(self, token: str):
with self._lock:
self.waiting_queue.append(task)
self._total_tasks += 1
self.tokens.append(token)
self._event.set()
self._task_event.set()
return task_id
def remove_task(self, task_id: str) -> None:
"""Remove a task from the scheduler."""
def pop_all(self) -> List[str]:
with self._lock:
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
tokens = self.tokens.copy()
self.tokens.clear()
if not tokens:
self._event.clear()
return tokens
def _remove_finished_tasks(self) -> None:
"""Remove finished tasks from active batch and update caches."""
finished = []
for task in self.active_tasks:
if task.is_finished(self.tokenizer.stop_ids):
task.status = TaskStatus.FINISHED
task.finish_time = time.time()
finished.append(task)
self._total_tokens += task.output_tokens
def wait(self, timeout: float = None) -> bool:
return self._event.wait(timeout=timeout)
for task in finished:
slot = task.slot
if slot >= 0 and slot < len(self.active_tasks):
self.seq_mask[slot, :] = False
task.slot = -1
self.active_tasks = [
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
]
class _NonStreamingResult:
"""Non-streaming result holder with event-based completion notification."""
def _refill_active_batch(self) -> None:
"""Refill active batch with waiting tasks."""
available_slots = self.max_batch_size - len(self.active_tasks)
if available_slots <= 0:
return
def __init__(self, count: int):
self.results: List[str] = ["" for _ in range(count)]
self.done_flags: List[bool] = [False] * count
self._completed_count = 0
self._event = threading.Event()
self._lock = threading.Lock()
def append(self, idx: int, token: str):
with self._lock:
to_add = []
for _ in range(min(available_slots, len(self.waiting_queue))):
if self.waiting_queue:
task = self.waiting_queue.pop(0)
task.status = TaskStatus.RUNNING
to_add.append(task)
for task in to_add:
for i in range(self.max_batch_size):
if all(t.slot != i for t in self.active_tasks):
task.slot = i
break
self.active_tasks.append(task)
def _execute_prefill(self, tasks: List[Task]) -> None:
"""Execute Prefill phase: process entire prompt at once."""
if not tasks:
return
_debug(f"_execute_prefill: processing {len(tasks)} tasks")
# Sort tasks by slot to ensure correct batch indexing with KV cache
tasks = sorted(tasks, key=lambda t: t.slot)
prompt_lens = [len(task.prompt_ids) for task in tasks]
max_len = max(prompt_lens)
input_ids = torch.zeros(
len(tasks), max_len, dtype=torch.long, device=self.device
)
for i, task in enumerate(tasks):
if len(task.prompt_ids) > 0:
input_ids[i, : len(task.prompt_ids)] = torch.tensor(
task.prompt_ids, device=self.device
)
# Create boolean mask for attention
if self.tokenizer.pad_id is not None:
input_mask = torch.ne(input_ids, self.tokenizer.pad_id)
if token == "[DONE]":
if not self.done_flags[idx]:
self.done_flags[idx] = True
self._completed_count += 1
if self._completed_count == len(self.results):
self._event.set()
else:
input_mask = torch.ones(
input_ids.shape, dtype=torch.bool, device=self.device
)
self.results[idx] += token
_debug(
f"_execute_prefill: input_ids shape={input_ids.shape}, max_len={max_len}"
)
def is_all_done(self) -> bool:
with self._lock:
return all(self.done_flags)
try:
with torch.inference_mode():
outputs = self.model(
input_ids,
input_mask=input_mask,
start_pos=0,
persistent_key_values=self.kv_cache,
)
_debug(
f"_execute_prefill: model forward done, output keys={outputs.keys() if hasattr(outputs, 'keys') else 'no keys'}"
)
except Exception as e:
_debug(f"_execute_prefill: ERROR: {e}")
raise
def wait(self, timeout: float = None) -> bool:
return self._event.wait(timeout=timeout)
for i, task in enumerate(tasks):
task.input_tokens = prompt_lens[i]
task.output_tokens = 0
_debug(
f" task {task.task_id}: input_tokens={task.input_tokens}, output_tokens={task.output_tokens}"
)
for task in tasks:
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
_debug(f"_execute_prefill: done, {len(tasks)} tasks marked as prefill complete")
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
"""Execute Decode phase: generate one token at a time."""
if not tasks:
return
_debug(f"_execute_decode: processing {len(tasks)} tasks, start_pos={start_pos}")
# Sort tasks by slot to ensure batch index aligns with slot (KV cache position)
# Task at slot 0 → batch index 0 → KV stored at cache[0]
# Task at slot 1 → batch index 1 → KV stored at cache[1]
tasks = sorted(tasks, key=lambda t: t.slot)
input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device)
for i, task in enumerate(tasks):
if task.output_ids:
input_ids[i] = task.output_ids[-1]
else:
input_ids[i] = task.prompt_ids[-1]
input_tensor = input_ids.unsqueeze(1) # shape: (batch, 1)
# Create 2D attention mask: (batch, seq_len)
active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device)
_debug(
f"_execute_decode: input_tensor shape={input_tensor.shape}, active_mask shape={active_mask.shape}"
)
try:
with torch.inference_mode():
outputs = self.model(
input_tensor,
input_mask=active_mask,
persistent_key_values=self.kv_cache,
start_pos=start_pos,
)
_debug(
f"_execute_decode: model forward done, logits shape={outputs['logits'].shape}"
)
logits = outputs["logits"][:, -1, :]
except Exception as e:
_debug(f"_execute_decode: ERROR: {e}")
raise
next_token_ids = []
for i, task in enumerate(tasks):
logit = logits[i : i + 1]
logit = apply_sampling_strategies(
logit,
task.temperature,
task.top_k,
task.top_p,
)
probs = torch.softmax(logit, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
next_token_ids.append(next_token.item())
_debug(f"_execute_decode: next_tokens={next_token_ids}")
for task, next_token in zip(tasks, next_token_ids):
task.output_ids.append(next_token)
task.output_tokens += 1
pos = task.input_tokens + task.output_tokens
if task.slot >= 0 and pos < self.max_seq_len:
self.seq_mask[task.slot, pos] = True
if task.stream_callback:
token_str = self.tokenizer.decode([next_token])
task.stream_callback(token_str)
# Check if any task reached max_tokens or stop token
for task in tasks:
if task.output_tokens >= task.max_tokens or (
task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids
):
_debug(
f"decode: task {task.task_id} finished, output_tokens={task.output_tokens}, max_tokens={task.max_tokens}"
)
if task.stream_callback:
task.stream_callback("[DONE]")
def _run_generation_loop(self) -> None:
"""Main generation loop with continuous batching."""
_debug("generation_loop: started")
while self._running:
self._remove_finished_tasks()
self._refill_active_batch()
if not self.active_tasks:
self._task_event.wait(timeout=0.01)
self._task_event.clear()
continue
_debug(
f"generation_loop: active={len(self.active_tasks)}, waiting={len(self.waiting_queue)}"
)
new_tasks = [t for t in self.active_tasks if t.output_tokens == 0]
decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0]
_debug(
f"generation_loop: new_tasks={len(new_tasks)}, decode_tasks={len(decode_tasks)}"
)
for t in self.active_tasks:
_debug(
f" active task {t.task_id}: output_tokens={t.output_tokens}, input_tokens={t.input_tokens}"
)
if decode_tasks:
start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks)
else:
start_pos = 0
# First run prefill for new tasks
if new_tasks:
_debug(f"generation_loop: running prefill for {len(new_tasks)} tasks")
self._execute_prefill(new_tasks)
_debug(f"generation_loop: prefill done")
# After prefill, convert these tasks to decode tasks in the same iteration
decode_tasks = new_tasks
start_pos = max(t.input_tokens for t in decode_tasks)
_debug(
f"generation_loop: after prefill, decode_tasks={len(decode_tasks)}, start_pos={start_pos}"
)
if decode_tasks:
_debug(
f"generation_loop: running decode for {len(decode_tasks)} tasks, start_pos={start_pos}"
)
self._execute_decode(decode_tasks, start_pos)
_debug(f"generation_loop: decode done")
if not self.active_tasks and not self.waiting_queue:
time.sleep(0.001)
def start(self) -> None:
"""Start the generation loop in a background thread."""
if not self._running:
_debug("InferenceScheduler.start: starting loop thread")
self._running = True
self._loop_thread = threading.Thread(target=self._run_generation_loop)
self._loop_thread.daemon = True
self._loop_thread.start()
_debug("InferenceScheduler.start: loop thread started")
def stop(self) -> None:
"""Stop the generation loop."""
self._running = False
if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=1.0)
def get_stats(self) -> Dict[str, Any]:
"""Get scheduler statistics."""
return {
"total_tasks": self._total_tasks,
"total_tokens": self._total_tokens,
"active_tasks": len(self.active_tasks),
"waiting_queue": len(self.waiting_queue),
}
def get_results(self) -> List[str]:
with self._lock:
return self.results.copy()
class InferenceEngine:
"""Unified inference engine for continuous batching.
Provides a single interface for:
- Single request generation (streaming or non-streaming)
- Batch request generation (streaming or non-streaming)
"""
"""Unified inference engine for continuous batching."""
def __init__(
self,
@ -556,9 +132,7 @@ class InferenceEngine:
self.kv_cache = self.scheduler.kv_cache
self.seq_mask = self.scheduler.seq_mask
_debug("InferenceEngine: starting scheduler")
self.scheduler.start()
_debug("InferenceEngine: scheduler started")
def generate(
self,
@ -606,43 +180,29 @@ class InferenceEngine:
top_p: float,
top_k: int,
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
"""Generate with streaming output (synchronous)."""
results = []
_debug(f"_generate_streaming: prompts={len(prompts)}")
"""Generate with streaming output."""
if is_batch:
raise NotImplementedError("Batch streaming is not implemented yet")
def make_callback(idx: int):
def cb(token: str):
_debug(f"callback[{idx}]: token={token!r}")
results.append(token)
result = _StreamingResult()
return cb
for i, p in enumerate(prompts):
_debug(f"_generate_streaming: adding task {i}: {p[:30]}...")
self.scheduler.add_task(
prompt=p,
prompt=prompts[0],
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=make_callback(i),
stream_callback=result.append,
)
def gen():
_debug("generator: start yielding")
while True:
# Yield accumulated tokens
while results:
token = results.pop(0)
tokens = result.pop_all()
for token in tokens:
if token == "[DONE]":
_debug("generator: got [DONE]")
return
_debug(f"generator: yielding {token!r}")
yield token
time.sleep(0.01)
result.wait(timeout=0.05)
return gen()
@ -656,19 +216,7 @@ class InferenceEngine:
top_k: int,
) -> Union[str, List[str]]:
"""Generate without streaming."""
results = ["" for _ in range(len(prompts))]
done_flags = [False] * len(prompts)
lock = threading.Lock()
def make_callback(idx: int):
def cb(token: str):
if token == "[DONE]":
done_flags[idx] = True
else:
with lock:
results[idx] += token
return cb
result = _NonStreamingResult(len(prompts))
for i, p in enumerate(prompts):
self.scheduler.add_task(
@ -677,12 +225,11 @@ class InferenceEngine:
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=make_callback(i),
stream_callback=result.append,
)
while not all(done_flags):
time.sleep(0.001)
result.wait()
results = result.get_results()
return results if is_batch else results[0]
def get_stats(self) -> Dict[str, Any]:

View File

@ -0,0 +1,387 @@
"""Inference scheduler for continuous batching."""
import threading
import time
import uuid
from typing import Any, Callable, Dict, List, Optional
import torch
from torch import Tensor
from astrai.config import ModelConfig
class TaskStatus:
"""Task state for continuous batching."""
PENDING = "pending"
RUNNING = "running"
FINISHED = "finished"
ABORTED = "aborted"
class Task:
"""Individual task for continuous batching."""
def __init__(
self,
task_id: str,
prompt_ids: List[int],
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
):
self.task_id = task_id
self.prompt_ids = prompt_ids
self.max_tokens = max_tokens
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.status = TaskStatus.PENDING
self.output_ids: List[int] = []
self.input_tokens: int = 0
self.output_tokens: int = 0
self.slot: int = -1
self.arrival_time = time.time()
self.finish_time: Optional[float] = None
self.stream_callback = stream_callback
def is_finished(self, stop_ids: List[int]) -> bool:
"""Check if task is finished."""
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
if self.output_tokens >= self.max_tokens:
return True
return False
def apply_sampling_strategies(
logits: Tensor,
temperature: float,
top_k: int,
top_p: float,
filter_value: float = -float("inf"),
) -> Tensor:
"""Apply sampling strategies to the logits tensor."""
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
class InferenceScheduler:
"""Inference scheduler with continuous batching support."""
def __init__(
self,
model,
tokenizer,
config: ModelConfig,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
self.model = model
self.tokenizer = tokenizer
self.config = config
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len
self.device = device
self.dtype = dtype
num_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads
n_layers = config.n_layers
k_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_heads,
head_dim,
),
device=device,
dtype=dtype,
)
v_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_heads,
head_dim,
),
device=device,
dtype=dtype,
)
self.kv_cache = (k_cache, v_cache)
self.seq_mask = torch.ones(
(max_batch_size, self.max_seq_len), device=device, dtype=torch.bool
)
self.waiting_queue: List[Task] = []
self.active_tasks: List[Task] = []
self._running = False
self._task_event = threading.Event()
self._lock = threading.Lock()
self._total_tasks = 0
self._total_tokens = 0
def add_task(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
) -> str:
"""Add a new task to the waiting queue."""
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
task = Task(
task_id=task_id,
prompt_ids=prompt_ids,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=stream_callback,
)
with self._lock:
self.waiting_queue.append(task)
self._total_tasks += 1
self._task_event.set()
return task_id
def remove_task(self, task_id: str) -> None:
"""Remove a task from the scheduler."""
with self._lock:
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
def _remove_finished_tasks(self) -> None:
"""Remove finished tasks from active batch."""
finished = []
for task in self.active_tasks:
if task.is_finished(self.tokenizer.stop_ids):
task.status = TaskStatus.FINISHED
task.finish_time = time.time()
finished.append(task)
self._total_tokens += task.output_tokens
for task in finished:
slot = task.slot
if slot >= 0 and slot < len(self.active_tasks):
self.seq_mask[slot, :] = False
task.slot = -1
self.active_tasks = [
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
]
def _refill_active_batch(self) -> None:
"""Refill active batch with waiting tasks."""
available_slots = self.max_batch_size - len(self.active_tasks)
if available_slots <= 0:
return
with self._lock:
to_add = []
for _ in range(min(available_slots, len(self.waiting_queue))):
if self.waiting_queue:
task = self.waiting_queue.pop(0)
task.status = TaskStatus.RUNNING
to_add.append(task)
for task in to_add:
for i in range(self.max_batch_size):
if all(t.slot != i for t in self.active_tasks):
task.slot = i
break
self.active_tasks.append(task)
def _execute_prefill(self, tasks: List[Task]) -> None:
"""Execute Prefill phase."""
if not tasks:
return
tasks = sorted(tasks, key=lambda t: t.slot)
prompt_lens = [len(task.prompt_ids) for task in tasks]
max_len = max(prompt_lens)
input_ids = torch.zeros(
len(tasks), max_len, dtype=torch.long, device=self.device
)
for i, task in enumerate(tasks):
if len(task.prompt_ids) > 0:
input_ids[i, : len(task.prompt_ids)] = torch.tensor(
task.prompt_ids, device=self.device
)
if self.tokenizer.pad_id is not None:
input_mask = torch.ne(input_ids, self.tokenizer.pad_id)
else:
input_mask = torch.ones(
input_ids.shape, dtype=torch.bool, device=self.device
)
with torch.inference_mode():
outputs = self.model(
input_ids,
input_mask=input_mask,
start_pos=0,
persistent_key_values=self.kv_cache,
)
for i, task in enumerate(tasks):
task.input_tokens = prompt_lens[i]
task.output_tokens = 0
for task in tasks:
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
"""Execute Decode phase."""
if not tasks:
return
tasks = sorted(tasks, key=lambda t: t.slot)
input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device)
for i, task in enumerate(tasks):
if task.output_ids:
input_ids[i] = task.output_ids[-1]
else:
input_ids[i] = task.prompt_ids[-1]
input_tensor = input_ids.unsqueeze(1)
active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device)
with torch.inference_mode():
outputs = self.model(
input_tensor,
input_mask=active_mask,
persistent_key_values=self.kv_cache,
start_pos=start_pos,
)
logits = outputs["logits"][:, -1, :]
next_token_ids = []
for i, task in enumerate(tasks):
logit = logits[i : i + 1]
logit = apply_sampling_strategies(
logit,
task.temperature,
task.top_k,
task.top_p,
)
probs = torch.softmax(logit, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
next_token_ids.append(next_token.item())
for task, next_token in zip(tasks, next_token_ids):
task.output_ids.append(next_token)
task.output_tokens += 1
pos = task.input_tokens + task.output_tokens
if task.slot >= 0 and pos < self.max_seq_len:
self.seq_mask[task.slot, pos] = True
if task.stream_callback:
token_str = self.tokenizer.decode([next_token])
task.stream_callback(token_str)
for task in tasks:
if task.output_tokens >= task.max_tokens or (
task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids
):
if task.stream_callback:
task.stream_callback("[DONE]")
def _run_generation_loop(self) -> None:
"""Main generation loop."""
while self._running:
self._remove_finished_tasks()
self._refill_active_batch()
if not self.active_tasks:
self._task_event.wait(timeout=0.01)
self._task_event.clear()
continue
new_tasks = [t for t in self.active_tasks if t.output_tokens == 0]
decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0]
if decode_tasks:
start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks)
else:
start_pos = 0
if new_tasks:
self._execute_prefill(new_tasks)
decode_tasks = new_tasks
start_pos = max(t.input_tokens for t in decode_tasks)
if decode_tasks:
self._execute_decode(decode_tasks, start_pos)
if not self.active_tasks and not self.waiting_queue:
self._task_event.wait(timeout=0.05)
self._task_event.clear()
def start(self) -> None:
"""Start the generation loop."""
if not self._running:
self._running = True
self._loop_thread = threading.Thread(target=self._run_generation_loop)
self._loop_thread.daemon = True
self._loop_thread.start()
def stop(self) -> None:
"""Stop the generation loop."""
self._running = False
if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=1.0)
def get_stats(self) -> Dict[str, Any]:
"""Get scheduler statistics."""
return {
"total_tasks": self._total_tasks,
"total_tokens": self._total_tokens,
"active_tasks": len(self.active_tasks),
"waiting_queue": len(self.waiting_queue),
}

View File

@ -4,6 +4,7 @@ import torch
from astrai.config.param_config import ModelParameter
def test_model_parameter(test_env):
save_dir = os.path.join(test_env["test_dir"], "save")
model_param = ModelParameter(