refactor: 拆分engine.py 文件

This commit is contained in:
ViperEkura 2026-04-05 00:07:21 +08:00
parent 861d33b1a1
commit 2b26f03bd3
4 changed files with 479 additions and 553 deletions

View File

@ -1,21 +1,10 @@
""" """Inference module for continuous batching."""
AstrAI Inference Module
Provides inference components for text generation with continuous batching support.
Main Components:
- InferenceEngine: Unified inference engine for continuous batching
- InferenceScheduler: Task scheduling with dynamic batch composition
- Task, TaskStatus: Task management for continuous batching
- GenerationRequest: Request parameters for generation
- apply_sampling_strategies: Sampling utilities for text generation
Author: AstrAI Team
"""
from astrai.inference.engine import ( from astrai.inference.engine import (
GenerationRequest, GenerationRequest,
InferenceEngine, InferenceEngine,
)
from astrai.inference.scheduler import (
InferenceScheduler, InferenceScheduler,
Task, Task,
TaskStatus, TaskStatus,
@ -25,9 +14,11 @@ from astrai.inference.engine import (
__all__ = [ __all__ = [
# Engine # Engine
"InferenceEngine", "InferenceEngine",
# Scheduler
"InferenceScheduler", "InferenceScheduler",
"Task", "Task",
"TaskStatus", "TaskStatus",
# Request
"GenerationRequest", "GenerationRequest",
# Sampling # Sampling
"apply_sampling_strategies", "apply_sampling_strategies",

View File

@ -1,49 +1,41 @@
""" """Unified inference engine."""
Continuous Batching Inference Engine
This module provides the main continuous batching components:
- Task: Individual generation task with state management
- TaskStatus: Task state enumeration
- InferenceScheduler: Handles request scheduling and KV cache management
- InferenceEngine: Unified inference engine
Author: AstrAI Team
"""
import threading import threading
import time from typing import Any, Dict, Generator, List, Optional, Union
import uuid
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
import torch from astrai.config import ModelParameter
from torch import Tensor from astrai.tokenize.chat_template import build_prompt
from astrai.config import ModelConfig, ModelParameter from astrai.inference.scheduler import InferenceScheduler
from astrai.tokenize.chat_template import HistoryType, build_prompt
# Use print for debugging instead of logging
def _debug(*args, **kwargs):
pass
@dataclass
class GenerationRequest: class GenerationRequest:
"""Request parameters for text generation.""" """Request parameters for text generation."""
top_k: int def __init__(
top_p: float self,
temperature: float query: Union[str, List[str]],
max_len: int top_k: int = 50,
top_p: float = 1.0,
temperature: float = 1.0,
max_len: int = 1024,
history: Optional[Any] = None,
system_prompt: Optional[str] = None,
stream: bool = False,
):
self.query = query
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.max_len = max_len
self.history = history
self.system_prompt = system_prompt
self.stream = stream
query: Union[str, List[str]] self._validate()
history: Optional[Union[HistoryType, List[HistoryType]]] = None
system_prompt: Optional[str] = None
stream: bool = False
def __post_init__(self): def _validate(self):
"""Validate request parameters."""
if not isinstance(self.top_k, int) or self.top_k < 0: if not isinstance(self.top_k, int) or self.top_k < 0:
raise ValueError("top_k must be a non-negative integer") raise ValueError("top_k must be a non-negative integer")
if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0: if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0:
@ -52,482 +44,66 @@ class GenerationRequest:
raise ValueError("temperature must be a non-negative float") raise ValueError("temperature must be a non-negative float")
class TaskStatus(Enum): class _StreamingResult:
"""Task state enumeration for continuous batching. """Streaming result holder with event-based notification."""
States: def __init__(self):
PENDING: Task is waiting to be scheduled self.tokens: List[str] = []
RUNNING: Task is currently being processed self._event = threading.Event()
FINISHED: Task completed successfully
ABORTED: Task was cancelled or failed
"""
PENDING = auto()
RUNNING = auto()
FINISHED = auto()
ABORTED = auto()
@dataclass
class Task:
"""Individual task for continuous batching.
Attributes:
task_id: Unique task identifier
prompt_ids: Input token IDs
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Top-p sampling parameter
top_k: Top-k sampling parameter
status: Current task status
output_ids: Generated token IDs
input_tokens: Number of input tokens
output_tokens: Number of output tokens generated
slot: Batch slot position (-1 if not assigned)
arrival_time: Task arrival timestamp
finish_time: Task completion timestamp
stream_callback: Callback for streaming output
"""
task_id: str
prompt_ids: List[int]
max_tokens: int = 1024
temperature: float = 1.0
top_p: float = 1.0
top_k: int = 50
status: TaskStatus = TaskStatus.PENDING
output_ids: List[int] = field(default_factory=list)
input_tokens: int = 0
output_tokens: int = 0
slot: int = -1
arrival_time: float = field(default_factory=time.time)
finish_time: Optional[float] = None
stream_callback: Optional[Callable[[str], None]] = None
def is_finished(self, stop_ids: List[int]) -> bool:
"""Check if task is finished."""
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
if self.output_tokens >= self.max_tokens:
return True
return False
def apply_sampling_strategies(
logits: Tensor,
temperature: float,
top_k: int,
top_p: float,
filter_value: float = -float("inf"),
) -> Tensor:
"""Apply sampling strategies to the logits tensor."""
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
class InferenceScheduler:
"""Inference scheduler with continuous batching support.
Manages request scheduling, KV cache allocation, and generation loop.
Supports dynamic batch composition where new requests can join at any time
and completed requests are immediately released.
"""
def __init__(
self,
model,
tokenizer,
config: ModelConfig,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
self.model = model
self.tokenizer = tokenizer
self.config = config
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len
self.device = device
self.dtype = dtype
num_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads
n_layers = config.n_layers
k_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_heads,
head_dim,
),
device=device,
dtype=dtype,
)
v_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_heads,
head_dim,
),
device=device,
dtype=dtype,
)
self.kv_cache = (k_cache, v_cache)
self.seq_mask = torch.ones(
(max_batch_size, self.max_seq_len), device=device, dtype=torch.bool
)
self.waiting_queue: List[Task] = []
self.active_tasks: List[Task] = []
self._running = False
self._task_event = threading.Event()
self._lock = threading.Lock() self._lock = threading.Lock()
self._total_tasks = 0 def append(self, token: str):
self._total_tokens = 0
def add_task(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
) -> str:
"""Add a new task to the waiting queue."""
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
_debug(
f"add_task: task_id={task_id}, prompt_len={len(prompt_ids)}, has_callback={stream_callback is not None}"
)
task = Task(
task_id=task_id,
prompt_ids=prompt_ids,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=stream_callback,
)
with self._lock: with self._lock:
self.waiting_queue.append(task) self.tokens.append(token)
self._total_tasks += 1 self._event.set()
self._task_event.set() def pop_all(self) -> List[str]:
return task_id
def remove_task(self, task_id: str) -> None:
"""Remove a task from the scheduler."""
with self._lock: with self._lock:
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id] tokens = self.tokens.copy()
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] self.tokens.clear()
if not tokens:
self._event.clear()
return tokens
def _remove_finished_tasks(self) -> None: def wait(self, timeout: float = None) -> bool:
"""Remove finished tasks from active batch and update caches.""" return self._event.wait(timeout=timeout)
finished = []
for task in self.active_tasks:
if task.is_finished(self.tokenizer.stop_ids):
task.status = TaskStatus.FINISHED
task.finish_time = time.time()
finished.append(task)
self._total_tokens += task.output_tokens
for task in finished:
slot = task.slot
if slot >= 0 and slot < len(self.active_tasks):
self.seq_mask[slot, :] = False
task.slot = -1
self.active_tasks = [ class _NonStreamingResult:
t for t in self.active_tasks if t.status != TaskStatus.FINISHED """Non-streaming result holder with event-based completion notification."""
]
def _refill_active_batch(self) -> None: def __init__(self, count: int):
"""Refill active batch with waiting tasks.""" self.results: List[str] = ["" for _ in range(count)]
available_slots = self.max_batch_size - len(self.active_tasks) self.done_flags: List[bool] = [False] * count
if available_slots <= 0: self._completed_count = 0
return self._event = threading.Event()
self._lock = threading.Lock()
def append(self, idx: int, token: str):
with self._lock: with self._lock:
to_add = [] if token == "[DONE]":
for _ in range(min(available_slots, len(self.waiting_queue))): if not self.done_flags[idx]:
if self.waiting_queue: self.done_flags[idx] = True
task = self.waiting_queue.pop(0) self._completed_count += 1
task.status = TaskStatus.RUNNING if self._completed_count == len(self.results):
to_add.append(task) self._event.set()
for task in to_add:
for i in range(self.max_batch_size):
if all(t.slot != i for t in self.active_tasks):
task.slot = i
break
self.active_tasks.append(task)
def _execute_prefill(self, tasks: List[Task]) -> None:
"""Execute Prefill phase: process entire prompt at once."""
if not tasks:
return
_debug(f"_execute_prefill: processing {len(tasks)} tasks")
# Sort tasks by slot to ensure correct batch indexing with KV cache
tasks = sorted(tasks, key=lambda t: t.slot)
prompt_lens = [len(task.prompt_ids) for task in tasks]
max_len = max(prompt_lens)
input_ids = torch.zeros(
len(tasks), max_len, dtype=torch.long, device=self.device
)
for i, task in enumerate(tasks):
if len(task.prompt_ids) > 0:
input_ids[i, : len(task.prompt_ids)] = torch.tensor(
task.prompt_ids, device=self.device
)
# Create boolean mask for attention
if self.tokenizer.pad_id is not None:
input_mask = torch.ne(input_ids, self.tokenizer.pad_id)
else:
input_mask = torch.ones(
input_ids.shape, dtype=torch.bool, device=self.device
)
_debug(
f"_execute_prefill: input_ids shape={input_ids.shape}, max_len={max_len}"
)
try:
with torch.inference_mode():
outputs = self.model(
input_ids,
input_mask=input_mask,
start_pos=0,
persistent_key_values=self.kv_cache,
)
_debug(
f"_execute_prefill: model forward done, output keys={outputs.keys() if hasattr(outputs, 'keys') else 'no keys'}"
)
except Exception as e:
_debug(f"_execute_prefill: ERROR: {e}")
raise
for i, task in enumerate(tasks):
task.input_tokens = prompt_lens[i]
task.output_tokens = 0
_debug(
f" task {task.task_id}: input_tokens={task.input_tokens}, output_tokens={task.output_tokens}"
)
for task in tasks:
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
_debug(f"_execute_prefill: done, {len(tasks)} tasks marked as prefill complete")
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
"""Execute Decode phase: generate one token at a time."""
if not tasks:
return
_debug(f"_execute_decode: processing {len(tasks)} tasks, start_pos={start_pos}")
# Sort tasks by slot to ensure batch index aligns with slot (KV cache position)
# Task at slot 0 → batch index 0 → KV stored at cache[0]
# Task at slot 1 → batch index 1 → KV stored at cache[1]
tasks = sorted(tasks, key=lambda t: t.slot)
input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device)
for i, task in enumerate(tasks):
if task.output_ids:
input_ids[i] = task.output_ids[-1]
else: else:
input_ids[i] = task.prompt_ids[-1] self.results[idx] += token
input_tensor = input_ids.unsqueeze(1) # shape: (batch, 1) def is_all_done(self) -> bool:
with self._lock:
return all(self.done_flags)
# Create 2D attention mask: (batch, seq_len) def wait(self, timeout: float = None) -> bool:
active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device) return self._event.wait(timeout=timeout)
_debug(
f"_execute_decode: input_tensor shape={input_tensor.shape}, active_mask shape={active_mask.shape}"
)
try: def get_results(self) -> List[str]:
with torch.inference_mode(): with self._lock:
outputs = self.model( return self.results.copy()
input_tensor,
input_mask=active_mask,
persistent_key_values=self.kv_cache,
start_pos=start_pos,
)
_debug(
f"_execute_decode: model forward done, logits shape={outputs['logits'].shape}"
)
logits = outputs["logits"][:, -1, :]
except Exception as e:
_debug(f"_execute_decode: ERROR: {e}")
raise
next_token_ids = []
for i, task in enumerate(tasks):
logit = logits[i : i + 1]
logit = apply_sampling_strategies(
logit,
task.temperature,
task.top_k,
task.top_p,
)
probs = torch.softmax(logit, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
next_token_ids.append(next_token.item())
_debug(f"_execute_decode: next_tokens={next_token_ids}")
for task, next_token in zip(tasks, next_token_ids):
task.output_ids.append(next_token)
task.output_tokens += 1
pos = task.input_tokens + task.output_tokens
if task.slot >= 0 and pos < self.max_seq_len:
self.seq_mask[task.slot, pos] = True
if task.stream_callback:
token_str = self.tokenizer.decode([next_token])
task.stream_callback(token_str)
# Check if any task reached max_tokens or stop token
for task in tasks:
if task.output_tokens >= task.max_tokens or (
task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids
):
_debug(
f"decode: task {task.task_id} finished, output_tokens={task.output_tokens}, max_tokens={task.max_tokens}"
)
if task.stream_callback:
task.stream_callback("[DONE]")
def _run_generation_loop(self) -> None:
"""Main generation loop with continuous batching."""
_debug("generation_loop: started")
while self._running:
self._remove_finished_tasks()
self._refill_active_batch()
if not self.active_tasks:
self._task_event.wait(timeout=0.01)
self._task_event.clear()
continue
_debug(
f"generation_loop: active={len(self.active_tasks)}, waiting={len(self.waiting_queue)}"
)
new_tasks = [t for t in self.active_tasks if t.output_tokens == 0]
decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0]
_debug(
f"generation_loop: new_tasks={len(new_tasks)}, decode_tasks={len(decode_tasks)}"
)
for t in self.active_tasks:
_debug(
f" active task {t.task_id}: output_tokens={t.output_tokens}, input_tokens={t.input_tokens}"
)
if decode_tasks:
start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks)
else:
start_pos = 0
# First run prefill for new tasks
if new_tasks:
_debug(f"generation_loop: running prefill for {len(new_tasks)} tasks")
self._execute_prefill(new_tasks)
_debug(f"generation_loop: prefill done")
# After prefill, convert these tasks to decode tasks in the same iteration
decode_tasks = new_tasks
start_pos = max(t.input_tokens for t in decode_tasks)
_debug(
f"generation_loop: after prefill, decode_tasks={len(decode_tasks)}, start_pos={start_pos}"
)
if decode_tasks:
_debug(
f"generation_loop: running decode for {len(decode_tasks)} tasks, start_pos={start_pos}"
)
self._execute_decode(decode_tasks, start_pos)
_debug(f"generation_loop: decode done")
if not self.active_tasks and not self.waiting_queue:
time.sleep(0.001)
def start(self) -> None:
"""Start the generation loop in a background thread."""
if not self._running:
_debug("InferenceScheduler.start: starting loop thread")
self._running = True
self._loop_thread = threading.Thread(target=self._run_generation_loop)
self._loop_thread.daemon = True
self._loop_thread.start()
_debug("InferenceScheduler.start: loop thread started")
def stop(self) -> None:
"""Stop the generation loop."""
self._running = False
if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=1.0)
def get_stats(self) -> Dict[str, Any]:
"""Get scheduler statistics."""
return {
"total_tasks": self._total_tasks,
"total_tokens": self._total_tokens,
"active_tasks": len(self.active_tasks),
"waiting_queue": len(self.waiting_queue),
}
class InferenceEngine: class InferenceEngine:
"""Unified inference engine for continuous batching. """Unified inference engine for continuous batching."""
Provides a single interface for:
- Single request generation (streaming or non-streaming)
- Batch request generation (streaming or non-streaming)
"""
def __init__( def __init__(
self, self,
@ -556,9 +132,7 @@ class InferenceEngine:
self.kv_cache = self.scheduler.kv_cache self.kv_cache = self.scheduler.kv_cache
self.seq_mask = self.scheduler.seq_mask self.seq_mask = self.scheduler.seq_mask
_debug("InferenceEngine: starting scheduler")
self.scheduler.start() self.scheduler.start()
_debug("InferenceEngine: scheduler started")
def generate( def generate(
self, self,
@ -606,43 +180,29 @@ class InferenceEngine:
top_p: float, top_p: float,
top_k: int, top_k: int,
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]: ) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
"""Generate with streaming output (synchronous).""" """Generate with streaming output."""
results = []
_debug(f"_generate_streaming: prompts={len(prompts)}")
if is_batch: if is_batch:
raise NotImplementedError("Batch streaming is not implemented yet") raise NotImplementedError("Batch streaming is not implemented yet")
def make_callback(idx: int): result = _StreamingResult()
def cb(token: str):
_debug(f"callback[{idx}]: token={token!r}")
results.append(token)
return cb self.scheduler.add_task(
prompt=prompts[0],
for i, p in enumerate(prompts): max_tokens=max_tokens,
_debug(f"_generate_streaming: adding task {i}: {p[:30]}...") temperature=temperature,
self.scheduler.add_task( top_p=top_p,
prompt=p, top_k=top_k,
max_tokens=max_tokens, stream_callback=result.append,
temperature=temperature, )
top_p=top_p,
top_k=top_k,
stream_callback=make_callback(i),
)
def gen(): def gen():
_debug("generator: start yielding")
while True: while True:
# Yield accumulated tokens tokens = result.pop_all()
while results: for token in tokens:
token = results.pop(0)
if token == "[DONE]": if token == "[DONE]":
_debug("generator: got [DONE]")
return return
_debug(f"generator: yielding {token!r}")
yield token yield token
time.sleep(0.01) result.wait(timeout=0.05)
return gen() return gen()
@ -656,19 +216,7 @@ class InferenceEngine:
top_k: int, top_k: int,
) -> Union[str, List[str]]: ) -> Union[str, List[str]]:
"""Generate without streaming.""" """Generate without streaming."""
results = ["" for _ in range(len(prompts))] result = _NonStreamingResult(len(prompts))
done_flags = [False] * len(prompts)
lock = threading.Lock()
def make_callback(idx: int):
def cb(token: str):
if token == "[DONE]":
done_flags[idx] = True
else:
with lock:
results[idx] += token
return cb
for i, p in enumerate(prompts): for i, p in enumerate(prompts):
self.scheduler.add_task( self.scheduler.add_task(
@ -677,12 +225,11 @@ class InferenceEngine:
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
stream_callback=make_callback(i), stream_callback=result.append,
) )
while not all(done_flags): result.wait()
time.sleep(0.001) results = result.get_results()
return results if is_batch else results[0] return results if is_batch else results[0]
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:

View File

@ -0,0 +1,387 @@
"""Inference scheduler for continuous batching."""
import threading
import time
import uuid
from typing import Any, Callable, Dict, List, Optional
import torch
from torch import Tensor
from astrai.config import ModelConfig
class TaskStatus:
"""Task state for continuous batching."""
PENDING = "pending"
RUNNING = "running"
FINISHED = "finished"
ABORTED = "aborted"
class Task:
"""Individual task for continuous batching."""
def __init__(
self,
task_id: str,
prompt_ids: List[int],
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
):
self.task_id = task_id
self.prompt_ids = prompt_ids
self.max_tokens = max_tokens
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.status = TaskStatus.PENDING
self.output_ids: List[int] = []
self.input_tokens: int = 0
self.output_tokens: int = 0
self.slot: int = -1
self.arrival_time = time.time()
self.finish_time: Optional[float] = None
self.stream_callback = stream_callback
def is_finished(self, stop_ids: List[int]) -> bool:
"""Check if task is finished."""
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
if self.output_tokens >= self.max_tokens:
return True
return False
def apply_sampling_strategies(
logits: Tensor,
temperature: float,
top_k: int,
top_p: float,
filter_value: float = -float("inf"),
) -> Tensor:
"""Apply sampling strategies to the logits tensor."""
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
class InferenceScheduler:
"""Inference scheduler with continuous batching support."""
def __init__(
self,
model,
tokenizer,
config: ModelConfig,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
self.model = model
self.tokenizer = tokenizer
self.config = config
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len
self.device = device
self.dtype = dtype
num_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads
n_layers = config.n_layers
k_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_heads,
head_dim,
),
device=device,
dtype=dtype,
)
v_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_heads,
head_dim,
),
device=device,
dtype=dtype,
)
self.kv_cache = (k_cache, v_cache)
self.seq_mask = torch.ones(
(max_batch_size, self.max_seq_len), device=device, dtype=torch.bool
)
self.waiting_queue: List[Task] = []
self.active_tasks: List[Task] = []
self._running = False
self._task_event = threading.Event()
self._lock = threading.Lock()
self._total_tasks = 0
self._total_tokens = 0
def add_task(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
) -> str:
"""Add a new task to the waiting queue."""
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
task = Task(
task_id=task_id,
prompt_ids=prompt_ids,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=stream_callback,
)
with self._lock:
self.waiting_queue.append(task)
self._total_tasks += 1
self._task_event.set()
return task_id
def remove_task(self, task_id: str) -> None:
"""Remove a task from the scheduler."""
with self._lock:
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
def _remove_finished_tasks(self) -> None:
"""Remove finished tasks from active batch."""
finished = []
for task in self.active_tasks:
if task.is_finished(self.tokenizer.stop_ids):
task.status = TaskStatus.FINISHED
task.finish_time = time.time()
finished.append(task)
self._total_tokens += task.output_tokens
for task in finished:
slot = task.slot
if slot >= 0 and slot < len(self.active_tasks):
self.seq_mask[slot, :] = False
task.slot = -1
self.active_tasks = [
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
]
def _refill_active_batch(self) -> None:
"""Refill active batch with waiting tasks."""
available_slots = self.max_batch_size - len(self.active_tasks)
if available_slots <= 0:
return
with self._lock:
to_add = []
for _ in range(min(available_slots, len(self.waiting_queue))):
if self.waiting_queue:
task = self.waiting_queue.pop(0)
task.status = TaskStatus.RUNNING
to_add.append(task)
for task in to_add:
for i in range(self.max_batch_size):
if all(t.slot != i for t in self.active_tasks):
task.slot = i
break
self.active_tasks.append(task)
def _execute_prefill(self, tasks: List[Task]) -> None:
"""Execute Prefill phase."""
if not tasks:
return
tasks = sorted(tasks, key=lambda t: t.slot)
prompt_lens = [len(task.prompt_ids) for task in tasks]
max_len = max(prompt_lens)
input_ids = torch.zeros(
len(tasks), max_len, dtype=torch.long, device=self.device
)
for i, task in enumerate(tasks):
if len(task.prompt_ids) > 0:
input_ids[i, : len(task.prompt_ids)] = torch.tensor(
task.prompt_ids, device=self.device
)
if self.tokenizer.pad_id is not None:
input_mask = torch.ne(input_ids, self.tokenizer.pad_id)
else:
input_mask = torch.ones(
input_ids.shape, dtype=torch.bool, device=self.device
)
with torch.inference_mode():
outputs = self.model(
input_ids,
input_mask=input_mask,
start_pos=0,
persistent_key_values=self.kv_cache,
)
for i, task in enumerate(tasks):
task.input_tokens = prompt_lens[i]
task.output_tokens = 0
for task in tasks:
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
"""Execute Decode phase."""
if not tasks:
return
tasks = sorted(tasks, key=lambda t: t.slot)
input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device)
for i, task in enumerate(tasks):
if task.output_ids:
input_ids[i] = task.output_ids[-1]
else:
input_ids[i] = task.prompt_ids[-1]
input_tensor = input_ids.unsqueeze(1)
active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device)
with torch.inference_mode():
outputs = self.model(
input_tensor,
input_mask=active_mask,
persistent_key_values=self.kv_cache,
start_pos=start_pos,
)
logits = outputs["logits"][:, -1, :]
next_token_ids = []
for i, task in enumerate(tasks):
logit = logits[i : i + 1]
logit = apply_sampling_strategies(
logit,
task.temperature,
task.top_k,
task.top_p,
)
probs = torch.softmax(logit, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
next_token_ids.append(next_token.item())
for task, next_token in zip(tasks, next_token_ids):
task.output_ids.append(next_token)
task.output_tokens += 1
pos = task.input_tokens + task.output_tokens
if task.slot >= 0 and pos < self.max_seq_len:
self.seq_mask[task.slot, pos] = True
if task.stream_callback:
token_str = self.tokenizer.decode([next_token])
task.stream_callback(token_str)
for task in tasks:
if task.output_tokens >= task.max_tokens or (
task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids
):
if task.stream_callback:
task.stream_callback("[DONE]")
def _run_generation_loop(self) -> None:
"""Main generation loop."""
while self._running:
self._remove_finished_tasks()
self._refill_active_batch()
if not self.active_tasks:
self._task_event.wait(timeout=0.01)
self._task_event.clear()
continue
new_tasks = [t for t in self.active_tasks if t.output_tokens == 0]
decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0]
if decode_tasks:
start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks)
else:
start_pos = 0
if new_tasks:
self._execute_prefill(new_tasks)
decode_tasks = new_tasks
start_pos = max(t.input_tokens for t in decode_tasks)
if decode_tasks:
self._execute_decode(decode_tasks, start_pos)
if not self.active_tasks and not self.waiting_queue:
self._task_event.wait(timeout=0.05)
self._task_event.clear()
def start(self) -> None:
"""Start the generation loop."""
if not self._running:
self._running = True
self._loop_thread = threading.Thread(target=self._run_generation_loop)
self._loop_thread.daemon = True
self._loop_thread.start()
def stop(self) -> None:
"""Stop the generation loop."""
self._running = False
if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=1.0)
def get_stats(self) -> Dict[str, Any]:
"""Get scheduler statistics."""
return {
"total_tasks": self._total_tasks,
"total_tokens": self._total_tokens,
"active_tasks": len(self.active_tasks),
"waiting_queue": len(self.waiting_queue),
}

View File

@ -4,6 +4,7 @@ import torch
from astrai.config.param_config import ModelParameter from astrai.config.param_config import ModelParameter
def test_model_parameter(test_env): def test_model_parameter(test_env):
save_dir = os.path.join(test_env["test_dir"], "save") save_dir = os.path.join(test_env["test_dir"], "save")
model_param = ModelParameter( model_param = ModelParameter(
@ -30,4 +31,4 @@ def test_transformer(test_env):
test_env["transformer_config"].max_len, test_env["transformer_config"].max_len,
test_env["transformer_config"].vocab_size, test_env["transformer_config"].vocab_size,
) )
assert output_logits.shape == target_shape assert output_logits.shape == target_shape