From 2b26f03bd33524f8fe78668fa8c34fb1373224eb Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 5 Apr 2026 00:07:21 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=8B=86=E5=88=86engine.py=20?= =?UTF-8?q?=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/inference/__init__.py | 19 +- astrai/inference/engine.py | 623 +++++----------------------------- astrai/inference/scheduler.py | 387 +++++++++++++++++++++ tests/module/test_module.py | 3 +- 4 files changed, 479 insertions(+), 553 deletions(-) create mode 100644 astrai/inference/scheduler.py diff --git a/astrai/inference/__init__.py b/astrai/inference/__init__.py index 9c082c1..b3f3e2e 100644 --- a/astrai/inference/__init__.py +++ b/astrai/inference/__init__.py @@ -1,21 +1,10 @@ -""" -AstrAI Inference Module - -Provides inference components for text generation with continuous batching support. - -Main Components: -- InferenceEngine: Unified inference engine for continuous batching -- InferenceScheduler: Task scheduling with dynamic batch composition -- Task, TaskStatus: Task management for continuous batching -- GenerationRequest: Request parameters for generation -- apply_sampling_strategies: Sampling utilities for text generation - -Author: AstrAI Team -""" +"""Inference module for continuous batching.""" from astrai.inference.engine import ( GenerationRequest, InferenceEngine, +) +from astrai.inference.scheduler import ( InferenceScheduler, Task, TaskStatus, @@ -25,9 +14,11 @@ from astrai.inference.engine import ( __all__ = [ # Engine "InferenceEngine", + # Scheduler "InferenceScheduler", "Task", "TaskStatus", + # Request "GenerationRequest", # Sampling "apply_sampling_strategies", diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 7381738..5c86c53 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -1,49 +1,41 @@ -""" -Continuous Batching Inference Engine - -This module provides the main continuous batching components: -- Task: Individual generation task with state management -- TaskStatus: Task state enumeration -- InferenceScheduler: Handles request scheduling and KV cache management -- InferenceEngine: Unified inference engine - -Author: AstrAI Team -""" +"""Unified inference engine.""" import threading -import time -import uuid -from dataclasses import dataclass, field -from enum import Enum, auto -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Dict, Generator, List, Optional, Union -import torch -from torch import Tensor +from astrai.config import ModelParameter +from astrai.tokenize.chat_template import build_prompt -from astrai.config import ModelConfig, ModelParameter -from astrai.tokenize.chat_template import HistoryType, build_prompt +from astrai.inference.scheduler import InferenceScheduler -# Use print for debugging instead of logging -def _debug(*args, **kwargs): - pass - - -@dataclass class GenerationRequest: """Request parameters for text generation.""" - top_k: int - top_p: float - temperature: float - max_len: int + def __init__( + self, + query: Union[str, List[str]], + top_k: int = 50, + top_p: float = 1.0, + temperature: float = 1.0, + max_len: int = 1024, + history: Optional[Any] = None, + system_prompt: Optional[str] = None, + stream: bool = False, + ): + self.query = query + self.top_k = top_k + self.top_p = top_p + self.temperature = temperature + self.max_len = max_len + self.history = history + self.system_prompt = system_prompt + self.stream = stream - query: Union[str, List[str]] - history: Optional[Union[HistoryType, List[HistoryType]]] = None - system_prompt: Optional[str] = None - stream: bool = False + self._validate() - def __post_init__(self): + def _validate(self): + """Validate request parameters.""" if not isinstance(self.top_k, int) or self.top_k < 0: raise ValueError("top_k must be a non-negative integer") if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0: @@ -52,482 +44,66 @@ class GenerationRequest: raise ValueError("temperature must be a non-negative float") -class TaskStatus(Enum): - """Task state enumeration for continuous batching. +class _StreamingResult: + """Streaming result holder with event-based notification.""" - States: - PENDING: Task is waiting to be scheduled - RUNNING: Task is currently being processed - FINISHED: Task completed successfully - ABORTED: Task was cancelled or failed - """ - - PENDING = auto() - RUNNING = auto() - FINISHED = auto() - ABORTED = auto() - - -@dataclass -class Task: - """Individual task for continuous batching. - - Attributes: - task_id: Unique task identifier - prompt_ids: Input token IDs - max_tokens: Maximum tokens to generate - temperature: Sampling temperature - top_p: Top-p sampling parameter - top_k: Top-k sampling parameter - status: Current task status - output_ids: Generated token IDs - input_tokens: Number of input tokens - output_tokens: Number of output tokens generated - slot: Batch slot position (-1 if not assigned) - arrival_time: Task arrival timestamp - finish_time: Task completion timestamp - stream_callback: Callback for streaming output - """ - - task_id: str - prompt_ids: List[int] - max_tokens: int = 1024 - temperature: float = 1.0 - top_p: float = 1.0 - top_k: int = 50 - - status: TaskStatus = TaskStatus.PENDING - output_ids: List[int] = field(default_factory=list) - input_tokens: int = 0 - output_tokens: int = 0 - slot: int = -1 - arrival_time: float = field(default_factory=time.time) - finish_time: Optional[float] = None - - stream_callback: Optional[Callable[[str], None]] = None - - def is_finished(self, stop_ids: List[int]) -> bool: - """Check if task is finished.""" - if self.output_ids and self.output_ids[-1] in stop_ids: - return True - if self.output_tokens >= self.max_tokens: - return True - return False - - -def apply_sampling_strategies( - logits: Tensor, - temperature: float, - top_k: int, - top_p: float, - filter_value: float = -float("inf"), -) -> Tensor: - """Apply sampling strategies to the logits tensor.""" - if temperature != 1.0: - logits = logits / temperature - - if top_k > 0: - top_k = min(top_k, logits.size(-1)) - indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None] - logits[indices_to_remove] = filter_value - - if top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) - cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) - - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - indices_to_remove = torch.zeros_like(logits, dtype=torch.bool) - indices_to_remove.scatter_( - dim=1, index=sorted_indices, src=sorted_indices_to_remove - ) - - logits[indices_to_remove] = filter_value - - return logits - - -class InferenceScheduler: - """Inference scheduler with continuous batching support. - - Manages request scheduling, KV cache allocation, and generation loop. - Supports dynamic batch composition where new requests can join at any time - and completed requests are immediately released. - """ - - def __init__( - self, - model, - tokenizer, - config: ModelConfig, - max_batch_size: int = 16, - max_seq_len: Optional[int] = None, - device: str = "cuda", - dtype: torch.dtype = torch.bfloat16, - ): - self.model = model - self.tokenizer = tokenizer - self.config = config - self.max_batch_size = max_batch_size - self.max_seq_len = max_seq_len or config.max_len - self.device = device - self.dtype = dtype - - num_heads = config.n_kv_heads - head_dim = config.dim // config.n_heads - n_layers = config.n_layers - - k_cache = torch.empty( - ( - max_batch_size, - self.max_seq_len, - n_layers, - num_heads, - head_dim, - ), - device=device, - dtype=dtype, - ) - v_cache = torch.empty( - ( - max_batch_size, - self.max_seq_len, - n_layers, - num_heads, - head_dim, - ), - device=device, - dtype=dtype, - ) - self.kv_cache = (k_cache, v_cache) - self.seq_mask = torch.ones( - (max_batch_size, self.max_seq_len), device=device, dtype=torch.bool - ) - - self.waiting_queue: List[Task] = [] - self.active_tasks: List[Task] = [] - - self._running = False - self._task_event = threading.Event() + def __init__(self): + self.tokens: List[str] = [] + self._event = threading.Event() self._lock = threading.Lock() - self._total_tasks = 0 - self._total_tokens = 0 - - def add_task( - self, - prompt: str, - max_tokens: int = 1024, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int = 50, - stream_callback: Optional[Callable[[str], None]] = None, - ) -> str: - """Add a new task to the waiting queue.""" - task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}" - prompt_ids = self.tokenizer.encode(prompt) - - _debug( - f"add_task: task_id={task_id}, prompt_len={len(prompt_ids)}, has_callback={stream_callback is not None}" - ) - - task = Task( - task_id=task_id, - prompt_ids=prompt_ids, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - top_k=top_k, - stream_callback=stream_callback, - ) - + def append(self, token: str): with self._lock: - self.waiting_queue.append(task) - self._total_tasks += 1 + self.tokens.append(token) + self._event.set() - self._task_event.set() - return task_id - - def remove_task(self, task_id: str) -> None: - """Remove a task from the scheduler.""" + def pop_all(self) -> List[str]: with self._lock: - self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id] - self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] + tokens = self.tokens.copy() + self.tokens.clear() + if not tokens: + self._event.clear() + return tokens - def _remove_finished_tasks(self) -> None: - """Remove finished tasks from active batch and update caches.""" - finished = [] - for task in self.active_tasks: - if task.is_finished(self.tokenizer.stop_ids): - task.status = TaskStatus.FINISHED - task.finish_time = time.time() - finished.append(task) - self._total_tokens += task.output_tokens + def wait(self, timeout: float = None) -> bool: + return self._event.wait(timeout=timeout) - for task in finished: - slot = task.slot - if slot >= 0 and slot < len(self.active_tasks): - self.seq_mask[slot, :] = False - task.slot = -1 - self.active_tasks = [ - t for t in self.active_tasks if t.status != TaskStatus.FINISHED - ] +class _NonStreamingResult: + """Non-streaming result holder with event-based completion notification.""" - def _refill_active_batch(self) -> None: - """Refill active batch with waiting tasks.""" - available_slots = self.max_batch_size - len(self.active_tasks) - if available_slots <= 0: - return + def __init__(self, count: int): + self.results: List[str] = ["" for _ in range(count)] + self.done_flags: List[bool] = [False] * count + self._completed_count = 0 + self._event = threading.Event() + self._lock = threading.Lock() + def append(self, idx: int, token: str): with self._lock: - to_add = [] - for _ in range(min(available_slots, len(self.waiting_queue))): - if self.waiting_queue: - task = self.waiting_queue.pop(0) - task.status = TaskStatus.RUNNING - to_add.append(task) - - for task in to_add: - for i in range(self.max_batch_size): - if all(t.slot != i for t in self.active_tasks): - task.slot = i - break - self.active_tasks.append(task) - - def _execute_prefill(self, tasks: List[Task]) -> None: - """Execute Prefill phase: process entire prompt at once.""" - if not tasks: - return - - _debug(f"_execute_prefill: processing {len(tasks)} tasks") - - # Sort tasks by slot to ensure correct batch indexing with KV cache - tasks = sorted(tasks, key=lambda t: t.slot) - - prompt_lens = [len(task.prompt_ids) for task in tasks] - max_len = max(prompt_lens) - - input_ids = torch.zeros( - len(tasks), max_len, dtype=torch.long, device=self.device - ) - for i, task in enumerate(tasks): - if len(task.prompt_ids) > 0: - input_ids[i, : len(task.prompt_ids)] = torch.tensor( - task.prompt_ids, device=self.device - ) - - # Create boolean mask for attention - if self.tokenizer.pad_id is not None: - input_mask = torch.ne(input_ids, self.tokenizer.pad_id) - else: - input_mask = torch.ones( - input_ids.shape, dtype=torch.bool, device=self.device - ) - - _debug( - f"_execute_prefill: input_ids shape={input_ids.shape}, max_len={max_len}" - ) - - try: - with torch.inference_mode(): - outputs = self.model( - input_ids, - input_mask=input_mask, - start_pos=0, - persistent_key_values=self.kv_cache, - ) - _debug( - f"_execute_prefill: model forward done, output keys={outputs.keys() if hasattr(outputs, 'keys') else 'no keys'}" - ) - except Exception as e: - _debug(f"_execute_prefill: ERROR: {e}") - raise - - for i, task in enumerate(tasks): - task.input_tokens = prompt_lens[i] - task.output_tokens = 0 - _debug( - f" task {task.task_id}: input_tokens={task.input_tokens}, output_tokens={task.output_tokens}" - ) - - for task in tasks: - if task.slot >= 0: - self.seq_mask[task.slot, : task.input_tokens] = True - - _debug(f"_execute_prefill: done, {len(tasks)} tasks marked as prefill complete") - - def _execute_decode(self, tasks: List[Task], start_pos: int) -> None: - """Execute Decode phase: generate one token at a time.""" - if not tasks: - return - - _debug(f"_execute_decode: processing {len(tasks)} tasks, start_pos={start_pos}") - - # Sort tasks by slot to ensure batch index aligns with slot (KV cache position) - # Task at slot 0 → batch index 0 → KV stored at cache[0] - # Task at slot 1 → batch index 1 → KV stored at cache[1] - tasks = sorted(tasks, key=lambda t: t.slot) - - input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device) - for i, task in enumerate(tasks): - if task.output_ids: - input_ids[i] = task.output_ids[-1] + if token == "[DONE]": + if not self.done_flags[idx]: + self.done_flags[idx] = True + self._completed_count += 1 + if self._completed_count == len(self.results): + self._event.set() else: - input_ids[i] = task.prompt_ids[-1] + self.results[idx] += token - input_tensor = input_ids.unsqueeze(1) # shape: (batch, 1) + def is_all_done(self) -> bool: + with self._lock: + return all(self.done_flags) - # Create 2D attention mask: (batch, seq_len) - active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device) - _debug( - f"_execute_decode: input_tensor shape={input_tensor.shape}, active_mask shape={active_mask.shape}" - ) + def wait(self, timeout: float = None) -> bool: + return self._event.wait(timeout=timeout) - try: - with torch.inference_mode(): - outputs = self.model( - input_tensor, - input_mask=active_mask, - persistent_key_values=self.kv_cache, - start_pos=start_pos, - ) - _debug( - f"_execute_decode: model forward done, logits shape={outputs['logits'].shape}" - ) - logits = outputs["logits"][:, -1, :] - except Exception as e: - _debug(f"_execute_decode: ERROR: {e}") - raise - - next_token_ids = [] - for i, task in enumerate(tasks): - logit = logits[i : i + 1] - logit = apply_sampling_strategies( - logit, - task.temperature, - task.top_k, - task.top_p, - ) - probs = torch.softmax(logit, dim=-1) - next_token = torch.multinomial(probs, num_samples=1) - next_token_ids.append(next_token.item()) - - _debug(f"_execute_decode: next_tokens={next_token_ids}") - - for task, next_token in zip(tasks, next_token_ids): - task.output_ids.append(next_token) - task.output_tokens += 1 - - pos = task.input_tokens + task.output_tokens - if task.slot >= 0 and pos < self.max_seq_len: - self.seq_mask[task.slot, pos] = True - - if task.stream_callback: - token_str = self.tokenizer.decode([next_token]) - task.stream_callback(token_str) - - # Check if any task reached max_tokens or stop token - for task in tasks: - if task.output_tokens >= task.max_tokens or ( - task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids - ): - _debug( - f"decode: task {task.task_id} finished, output_tokens={task.output_tokens}, max_tokens={task.max_tokens}" - ) - if task.stream_callback: - task.stream_callback("[DONE]") - - def _run_generation_loop(self) -> None: - """Main generation loop with continuous batching.""" - _debug("generation_loop: started") - while self._running: - self._remove_finished_tasks() - self._refill_active_batch() - - if not self.active_tasks: - self._task_event.wait(timeout=0.01) - self._task_event.clear() - continue - - _debug( - f"generation_loop: active={len(self.active_tasks)}, waiting={len(self.waiting_queue)}" - ) - - new_tasks = [t for t in self.active_tasks if t.output_tokens == 0] - decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0] - - _debug( - f"generation_loop: new_tasks={len(new_tasks)}, decode_tasks={len(decode_tasks)}" - ) - for t in self.active_tasks: - _debug( - f" active task {t.task_id}: output_tokens={t.output_tokens}, input_tokens={t.input_tokens}" - ) - - if decode_tasks: - start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks) - else: - start_pos = 0 - - # First run prefill for new tasks - if new_tasks: - _debug(f"generation_loop: running prefill for {len(new_tasks)} tasks") - self._execute_prefill(new_tasks) - _debug(f"generation_loop: prefill done") - - # After prefill, convert these tasks to decode tasks in the same iteration - decode_tasks = new_tasks - start_pos = max(t.input_tokens for t in decode_tasks) - _debug( - f"generation_loop: after prefill, decode_tasks={len(decode_tasks)}, start_pos={start_pos}" - ) - - if decode_tasks: - _debug( - f"generation_loop: running decode for {len(decode_tasks)} tasks, start_pos={start_pos}" - ) - self._execute_decode(decode_tasks, start_pos) - _debug(f"generation_loop: decode done") - - if not self.active_tasks and not self.waiting_queue: - time.sleep(0.001) - - def start(self) -> None: - """Start the generation loop in a background thread.""" - if not self._running: - _debug("InferenceScheduler.start: starting loop thread") - self._running = True - self._loop_thread = threading.Thread(target=self._run_generation_loop) - self._loop_thread.daemon = True - self._loop_thread.start() - _debug("InferenceScheduler.start: loop thread started") - - def stop(self) -> None: - """Stop the generation loop.""" - self._running = False - if hasattr(self, "_loop_thread"): - self._loop_thread.join(timeout=1.0) - - def get_stats(self) -> Dict[str, Any]: - """Get scheduler statistics.""" - return { - "total_tasks": self._total_tasks, - "total_tokens": self._total_tokens, - "active_tasks": len(self.active_tasks), - "waiting_queue": len(self.waiting_queue), - } + def get_results(self) -> List[str]: + with self._lock: + return self.results.copy() class InferenceEngine: - """Unified inference engine for continuous batching. - - Provides a single interface for: - - Single request generation (streaming or non-streaming) - - Batch request generation (streaming or non-streaming) - """ + """Unified inference engine for continuous batching.""" def __init__( self, @@ -556,9 +132,7 @@ class InferenceEngine: self.kv_cache = self.scheduler.kv_cache self.seq_mask = self.scheduler.seq_mask - _debug("InferenceEngine: starting scheduler") self.scheduler.start() - _debug("InferenceEngine: scheduler started") def generate( self, @@ -606,43 +180,29 @@ class InferenceEngine: top_p: float, top_k: int, ) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]: - """Generate with streaming output (synchronous).""" - results = [] - _debug(f"_generate_streaming: prompts={len(prompts)}") - + """Generate with streaming output.""" if is_batch: raise NotImplementedError("Batch streaming is not implemented yet") - def make_callback(idx: int): - def cb(token: str): - _debug(f"callback[{idx}]: token={token!r}") - results.append(token) + result = _StreamingResult() - return cb - - for i, p in enumerate(prompts): - _debug(f"_generate_streaming: adding task {i}: {p[:30]}...") - self.scheduler.add_task( - prompt=p, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - top_k=top_k, - stream_callback=make_callback(i), - ) + self.scheduler.add_task( + prompt=prompts[0], + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stream_callback=result.append, + ) def gen(): - _debug("generator: start yielding") while True: - # Yield accumulated tokens - while results: - token = results.pop(0) + tokens = result.pop_all() + for token in tokens: if token == "[DONE]": - _debug("generator: got [DONE]") return - _debug(f"generator: yielding {token!r}") yield token - time.sleep(0.01) + result.wait(timeout=0.05) return gen() @@ -656,19 +216,7 @@ class InferenceEngine: top_k: int, ) -> Union[str, List[str]]: """Generate without streaming.""" - results = ["" for _ in range(len(prompts))] - done_flags = [False] * len(prompts) - lock = threading.Lock() - - def make_callback(idx: int): - def cb(token: str): - if token == "[DONE]": - done_flags[idx] = True - else: - with lock: - results[idx] += token - - return cb + result = _NonStreamingResult(len(prompts)) for i, p in enumerate(prompts): self.scheduler.add_task( @@ -677,12 +225,11 @@ class InferenceEngine: temperature=temperature, top_p=top_p, top_k=top_k, - stream_callback=make_callback(i), + stream_callback=result.append, ) - while not all(done_flags): - time.sleep(0.001) - + result.wait() + results = result.get_results() return results if is_batch else results[0] def get_stats(self) -> Dict[str, Any]: diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py new file mode 100644 index 0000000..58ceaca --- /dev/null +++ b/astrai/inference/scheduler.py @@ -0,0 +1,387 @@ +"""Inference scheduler for continuous batching.""" + +import threading +import time +import uuid +from typing import Any, Callable, Dict, List, Optional + +import torch +from torch import Tensor + +from astrai.config import ModelConfig + + +class TaskStatus: + """Task state for continuous batching.""" + + PENDING = "pending" + RUNNING = "running" + FINISHED = "finished" + ABORTED = "aborted" + + +class Task: + """Individual task for continuous batching.""" + + def __init__( + self, + task_id: str, + prompt_ids: List[int], + max_tokens: int = 1024, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 50, + stream_callback: Optional[Callable[[str], None]] = None, + ): + self.task_id = task_id + self.prompt_ids = prompt_ids + self.max_tokens = max_tokens + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + + self.status = TaskStatus.PENDING + self.output_ids: List[int] = [] + self.input_tokens: int = 0 + self.output_tokens: int = 0 + self.slot: int = -1 + self.arrival_time = time.time() + self.finish_time: Optional[float] = None + + self.stream_callback = stream_callback + + def is_finished(self, stop_ids: List[int]) -> bool: + """Check if task is finished.""" + if self.output_ids and self.output_ids[-1] in stop_ids: + return True + if self.output_tokens >= self.max_tokens: + return True + return False + + +def apply_sampling_strategies( + logits: Tensor, + temperature: float, + top_k: int, + top_p: float, + filter_value: float = -float("inf"), +) -> Tensor: + """Apply sampling strategies to the logits tensor.""" + if temperature != 1.0: + logits = logits / temperature + + if top_k > 0: + top_k = min(top_k, logits.size(-1)) + indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = torch.zeros_like(logits, dtype=torch.bool) + indices_to_remove.scatter_( + dim=1, index=sorted_indices, src=sorted_indices_to_remove + ) + + logits[indices_to_remove] = filter_value + + return logits + + +class InferenceScheduler: + """Inference scheduler with continuous batching support.""" + + def __init__( + self, + model, + tokenizer, + config: ModelConfig, + max_batch_size: int = 16, + max_seq_len: Optional[int] = None, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + ): + self.model = model + self.tokenizer = tokenizer + self.config = config + self.max_batch_size = max_batch_size + self.max_seq_len = max_seq_len or config.max_len + self.device = device + self.dtype = dtype + + num_heads = config.n_kv_heads + head_dim = config.dim // config.n_heads + n_layers = config.n_layers + + k_cache = torch.empty( + ( + max_batch_size, + self.max_seq_len, + n_layers, + num_heads, + head_dim, + ), + device=device, + dtype=dtype, + ) + v_cache = torch.empty( + ( + max_batch_size, + self.max_seq_len, + n_layers, + num_heads, + head_dim, + ), + device=device, + dtype=dtype, + ) + self.kv_cache = (k_cache, v_cache) + self.seq_mask = torch.ones( + (max_batch_size, self.max_seq_len), device=device, dtype=torch.bool + ) + + self.waiting_queue: List[Task] = [] + self.active_tasks: List[Task] = [] + + self._running = False + self._task_event = threading.Event() + self._lock = threading.Lock() + + self._total_tasks = 0 + self._total_tokens = 0 + + def add_task( + self, + prompt: str, + max_tokens: int = 1024, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 50, + stream_callback: Optional[Callable[[str], None]] = None, + ) -> str: + """Add a new task to the waiting queue.""" + task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}" + prompt_ids = self.tokenizer.encode(prompt) + + task = Task( + task_id=task_id, + prompt_ids=prompt_ids, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stream_callback=stream_callback, + ) + + with self._lock: + self.waiting_queue.append(task) + self._total_tasks += 1 + + self._task_event.set() + return task_id + + def remove_task(self, task_id: str) -> None: + """Remove a task from the scheduler.""" + with self._lock: + self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id] + self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] + + def _remove_finished_tasks(self) -> None: + """Remove finished tasks from active batch.""" + finished = [] + for task in self.active_tasks: + if task.is_finished(self.tokenizer.stop_ids): + task.status = TaskStatus.FINISHED + task.finish_time = time.time() + finished.append(task) + self._total_tokens += task.output_tokens + + for task in finished: + slot = task.slot + if slot >= 0 and slot < len(self.active_tasks): + self.seq_mask[slot, :] = False + task.slot = -1 + + self.active_tasks = [ + t for t in self.active_tasks if t.status != TaskStatus.FINISHED + ] + + def _refill_active_batch(self) -> None: + """Refill active batch with waiting tasks.""" + available_slots = self.max_batch_size - len(self.active_tasks) + if available_slots <= 0: + return + + with self._lock: + to_add = [] + for _ in range(min(available_slots, len(self.waiting_queue))): + if self.waiting_queue: + task = self.waiting_queue.pop(0) + task.status = TaskStatus.RUNNING + to_add.append(task) + + for task in to_add: + for i in range(self.max_batch_size): + if all(t.slot != i for t in self.active_tasks): + task.slot = i + break + self.active_tasks.append(task) + + def _execute_prefill(self, tasks: List[Task]) -> None: + """Execute Prefill phase.""" + if not tasks: + return + + tasks = sorted(tasks, key=lambda t: t.slot) + + prompt_lens = [len(task.prompt_ids) for task in tasks] + max_len = max(prompt_lens) + + input_ids = torch.zeros( + len(tasks), max_len, dtype=torch.long, device=self.device + ) + for i, task in enumerate(tasks): + if len(task.prompt_ids) > 0: + input_ids[i, : len(task.prompt_ids)] = torch.tensor( + task.prompt_ids, device=self.device + ) + + if self.tokenizer.pad_id is not None: + input_mask = torch.ne(input_ids, self.tokenizer.pad_id) + else: + input_mask = torch.ones( + input_ids.shape, dtype=torch.bool, device=self.device + ) + + with torch.inference_mode(): + outputs = self.model( + input_ids, + input_mask=input_mask, + start_pos=0, + persistent_key_values=self.kv_cache, + ) + + for i, task in enumerate(tasks): + task.input_tokens = prompt_lens[i] + task.output_tokens = 0 + + for task in tasks: + if task.slot >= 0: + self.seq_mask[task.slot, : task.input_tokens] = True + + def _execute_decode(self, tasks: List[Task], start_pos: int) -> None: + """Execute Decode phase.""" + if not tasks: + return + + tasks = sorted(tasks, key=lambda t: t.slot) + + input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device) + for i, task in enumerate(tasks): + if task.output_ids: + input_ids[i] = task.output_ids[-1] + else: + input_ids[i] = task.prompt_ids[-1] + + input_tensor = input_ids.unsqueeze(1) + active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device) + + with torch.inference_mode(): + outputs = self.model( + input_tensor, + input_mask=active_mask, + persistent_key_values=self.kv_cache, + start_pos=start_pos, + ) + logits = outputs["logits"][:, -1, :] + + next_token_ids = [] + for i, task in enumerate(tasks): + logit = logits[i : i + 1] + logit = apply_sampling_strategies( + logit, + task.temperature, + task.top_k, + task.top_p, + ) + probs = torch.softmax(logit, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + next_token_ids.append(next_token.item()) + + for task, next_token in zip(tasks, next_token_ids): + task.output_ids.append(next_token) + task.output_tokens += 1 + + pos = task.input_tokens + task.output_tokens + if task.slot >= 0 and pos < self.max_seq_len: + self.seq_mask[task.slot, pos] = True + + if task.stream_callback: + token_str = self.tokenizer.decode([next_token]) + task.stream_callback(token_str) + + for task in tasks: + if task.output_tokens >= task.max_tokens or ( + task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids + ): + if task.stream_callback: + task.stream_callback("[DONE]") + + def _run_generation_loop(self) -> None: + """Main generation loop.""" + while self._running: + self._remove_finished_tasks() + self._refill_active_batch() + + if not self.active_tasks: + self._task_event.wait(timeout=0.01) + self._task_event.clear() + continue + + new_tasks = [t for t in self.active_tasks if t.output_tokens == 0] + decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0] + + if decode_tasks: + start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks) + else: + start_pos = 0 + + if new_tasks: + self._execute_prefill(new_tasks) + decode_tasks = new_tasks + start_pos = max(t.input_tokens for t in decode_tasks) + + if decode_tasks: + self._execute_decode(decode_tasks, start_pos) + + if not self.active_tasks and not self.waiting_queue: + self._task_event.wait(timeout=0.05) + self._task_event.clear() + + def start(self) -> None: + """Start the generation loop.""" + if not self._running: + self._running = True + self._loop_thread = threading.Thread(target=self._run_generation_loop) + self._loop_thread.daemon = True + self._loop_thread.start() + + def stop(self) -> None: + """Stop the generation loop.""" + self._running = False + if hasattr(self, "_loop_thread"): + self._loop_thread.join(timeout=1.0) + + def get_stats(self) -> Dict[str, Any]: + """Get scheduler statistics.""" + return { + "total_tasks": self._total_tasks, + "total_tokens": self._total_tokens, + "active_tasks": len(self.active_tasks), + "waiting_queue": len(self.waiting_queue), + } diff --git a/tests/module/test_module.py b/tests/module/test_module.py index 64478f6..76448d9 100644 --- a/tests/module/test_module.py +++ b/tests/module/test_module.py @@ -4,6 +4,7 @@ import torch from astrai.config.param_config import ModelParameter + def test_model_parameter(test_env): save_dir = os.path.join(test_env["test_dir"], "save") model_param = ModelParameter( @@ -30,4 +31,4 @@ def test_transformer(test_env): test_env["transformer_config"].max_len, test_env["transformer_config"].vocab_size, ) - assert output_logits.shape == target_shape \ No newline at end of file + assert output_logits.shape == target_shape