From 861d33b1a1c0c0f50d4921ee4d1e0fa1f8ffc458 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 4 Apr 2026 23:49:18 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=9B=B4=E6=96=B0inference=20?= =?UTF-8?q?=E9=83=A8=E5=88=86=E7=9A=84=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/__init__.py | 14 +- astrai/inference/__init__.py | 47 ++- astrai/inference/core.py | 272 ------------- astrai/inference/engine.py | 694 +++++++++++++++++++++++++++++++++ astrai/inference/generator.py | 269 ------------- astrai/inference/server.py | 229 +++++++---- scripts/demo/generate_ar.py | 14 +- scripts/demo/generate_batch.py | 14 +- scripts/demo/stream_chat.py | 31 +- scripts/tools/generate.py | 15 +- tests/inference/conftest.py | 22 +- tests/inference/test_server.py | 62 +-- tests/module/test_module.py | 40 +- 13 files changed, 965 insertions(+), 758 deletions(-) delete mode 100644 astrai/inference/core.py create mode 100644 astrai/inference/engine.py delete mode 100644 astrai/inference/generator.py diff --git a/astrai/__init__.py b/astrai/__init__.py index 3f5bd84..423ab84 100644 --- a/astrai/__init__.py +++ b/astrai/__init__.py @@ -8,13 +8,9 @@ from astrai.config import ( from astrai.factory import BaseFactory from astrai.dataset import DatasetFactory from astrai.tokenize import BpeTokenizer -from astrai.inference.generator import ( - BatchGenerator, - EmbeddingEncoder, +from astrai.inference import ( GenerationRequest, - GeneratorFactory, - LoopGenerator, - StreamGenerator, + InferenceEngine, ) from astrai.model.transformer import Transformer from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer @@ -26,11 +22,7 @@ __all__ = [ "DatasetFactory", "BpeTokenizer", "GenerationRequest", - "LoopGenerator", - "StreamGenerator", - "BatchGenerator", - "EmbeddingEncoder", - "GeneratorFactory", + "InferenceEngine", "Trainer", "StrategyFactory", "SchedulerFactory", diff --git a/astrai/inference/__init__.py b/astrai/inference/__init__.py index 6675be6..9c082c1 100644 --- a/astrai/inference/__init__.py +++ b/astrai/inference/__init__.py @@ -1,25 +1,34 @@ -from astrai.inference.core import ( - EmbeddingEncoderCore, - GeneratorCore, - KVCacheManager, -) -from astrai.inference.generator import ( - BatchGenerator, - EmbeddingEncoder, +""" +AstrAI Inference Module + +Provides inference components for text generation with continuous batching support. 
+ +Main Components: +- InferenceEngine: Unified inference engine for continuous batching +- InferenceScheduler: Task scheduling with dynamic batch composition +- Task, TaskStatus: Task management for continuous batching +- GenerationRequest: Request parameters for generation +- apply_sampling_strategies: Sampling utilities for text generation + +Author: AstrAI Team +""" + +from astrai.inference.engine import ( GenerationRequest, - GeneratorFactory, - LoopGenerator, - StreamGenerator, + InferenceEngine, + InferenceScheduler, + Task, + TaskStatus, + apply_sampling_strategies, ) __all__ = [ - "GeneratorCore", - "EmbeddingEncoderCore", - "KVCacheManager", + # Engine + "InferenceEngine", + "InferenceScheduler", + "Task", + "TaskStatus", "GenerationRequest", - "LoopGenerator", - "StreamGenerator", - "BatchGenerator", - "EmbeddingEncoder", - "GeneratorFactory", + # Sampling + "apply_sampling_strategies", ] diff --git a/astrai/inference/core.py b/astrai/inference/core.py deleted file mode 100644 index b0952f7..0000000 --- a/astrai/inference/core.py +++ /dev/null @@ -1,272 +0,0 @@ -from typing import Any, Callable, List, Optional, Self, Tuple, Union - -import torch -from torch import Tensor - -from astrai.config import ModelConfig, ModelParameter - - -def apply_sampling_strategies( - logits: Tensor, - temperature: float, - top_k: int, - top_p: float, - filter_value: float = -float("inf"), -) -> Tensor: - """ - Apply sampling strategies to the logits tensor. - - Args: - logits (Tensor): The logits tensor. - temperature (float): The temperature parameter. - top_k (int): The top-k parameter. - top_p (float): The top-p parameter. - filter_value (float, optional): The filter value. Defaults to -float("inf"). - - Returns: - Tensor: The sampled logits tensor. - - """ - - if temperature != 1.0: - logits = logits / temperature - - if top_k > 0: - top_k = min(top_k, logits.size(-1)) - indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None] - logits[indices_to_remove] = filter_value - - if top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) - cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) - - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - indices_to_remove = torch.zeros_like(logits, dtype=torch.bool) - indices_to_remove.scatter_( - dim=1, index=sorted_indices, src=sorted_indices_to_remove - ) - - logits[indices_to_remove] = filter_value - - return logits - - -class GeneratorCore: - def __init__(self, parameter: ModelParameter): - self.model = parameter.model - self.tokenizer = parameter.tokenizer - self.config = parameter.config - - def generate_iterator( - self, - input_ids: Tensor, - temperature: float, - top_k: int, - top_p: float, - attn_mask: Optional[Tensor] = None, - kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None, - start_pos: int = 0, - ) -> Tuple[Tensor, int]: - - with torch.inference_mode(): - outputs = self.model(input_ids, attn_mask, kv_caches, start_pos) - logits = outputs["logits"][:, -1, :] - cache_increase = input_ids.size(-1) - - logits = apply_sampling_strategies(logits, temperature, top_k, top_p) - probs = torch.softmax(logits, dim=-1) - next_token_id = torch.multinomial(probs, num_samples=1) - - return next_token_id, cache_increase - - def generate_loop( - self, - input_ids: Tensor, - ids: List[int], - temperature: float, - top_k: int, - top_p: float, - 
attn_mask: Optional[Tensor] = None, - kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None, - start_pos: int = 0, - callback: Optional[Callable[..., Any]] = None, - ) -> List[int]: - cur_cache_pos = start_pos - - for _ in range(len(ids), self.config.max_len): - next_token_id, cache_increase = self.generate_iterator( - input_ids, - temperature, - top_k, - top_p, - attn_mask, - kv_caches, - cur_cache_pos, - ) - - input_ids = next_token_id - ids.append(next_token_id.item()) - cur_cache_pos += cache_increase - - if callback: - callback(next_token_id.item(), ids.copy()) - - if next_token_id.item() in self.tokenizer.stop_ids: - break - - return ids - - def to(self, *args, **kargs) -> Self: - self.model.to(*args, **kargs) - return self - - -class EmbeddingEncoderCore: - def __init__(self, parameter: ModelParameter): - self.model = parameter.model - self.tokenizer = parameter.tokenizer - self.config = parameter.config - - def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]: - with_batch = isinstance(sentence, list) - ids = self.tokenizer.encode(sentence) - batch_ids = ids if with_batch else [ids] - max_model_len = self.config.max_len - - all_fragments = [] - fragment_origin_idx = [] - - for i, seq in enumerate(batch_ids): - if len(seq) > max_model_len: - fragments = [ - seq[j : j + max_model_len] - for j in range(0, len(seq), max_model_len) - ] - all_fragments.extend(fragments) - fragment_origin_idx.extend([i] * len(fragments)) - else: - all_fragments.append(seq) - fragment_origin_idx.append(i) - - # if empty fragments - if not all_fragments or not ids: - return [] if with_batch else torch.tensor([]) - - device = next(self.model.parameters()).device - max_len = min(max(len(seq) for seq in all_fragments), max_model_len) - - padded_ids = [] - masks = [] - for seq in all_fragments: - pad_len = max_len - len(seq) - padded_seq = seq + [self.tokenizer.pad_id] * pad_len - mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq] - padded_ids.append(padded_seq) - masks.append(mask) - - input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long) - seq_mask = torch.tensor(masks, device=device, dtype=torch.bool) - - with torch.inference_mode(): - outputs = self.model(input_tensor, seq_mask)["hidden_states"] - # [num_fragments, seq_len, hidden_size] - fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1)) - - sentence_embs: List[Tensor] = [] - for i in range(len(batch_ids)): - indices = [ - idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i - ] - if indices: - sum_frags = torch.sum( - fragment_embs[indices, :, :], dim=1 - ) # [frags, hidden_size] - length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze( - 1 - ) # [frags, 1] - emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size] - sentence_embs.append(emb.flatten()) - - if with_batch: - return [emb.flatten() for emb in sentence_embs] - else: - return sentence_embs[0].flatten() - - def to(self, *args, **kargs) -> Self: - self.model.to(*args, **kargs) - return self - - -class KVCacheManager: - def __init__( - self, - config: ModelConfig, - batch_size: int, - device: torch.device = "cuda", - dtype: torch.dtype = torch.bfloat16, - ): - self.batch_size = batch_size - self.device = device - self.dtype = dtype - self.num_layers = config.n_layers - self.max_len = config.max_len - self.num_heads = config.n_kv_heads - self.head_dim = config.dim // config.n_heads - - self._kv_cache: Tuple[Tensor, Tensor] = None - self._seq_mask: Tensor = None - self._initialize() - - def 
_initialize(self): - k_cache = torch.empty( - ( - self.batch_size, - self.max_len, - self.num_layers, - self.num_heads, - self.head_dim, - ), - device=self.device, - dtype=self.dtype, - ) - v_cache = torch.empty( - ( - self.batch_size, - self.max_len, - self.num_layers, - self.num_heads, - self.head_dim, - ), - device=self.device, - dtype=self.dtype, - ) - self._kv_cache = (k_cache, v_cache) - self._seq_mask = torch.ones( - (self.batch_size, self.max_len), device=self.device, dtype=torch.bool - ) - - def update(self, active_mask: Tensor): - k_cache, v_cache = self._kv_cache - self._kv_cache = (k_cache[active_mask], v_cache[active_mask]) - self._seq_mask = self._seq_mask[active_mask] - - def reset(self, full_reset=False): - if full_reset: - self._kv_cache = None - self._seq_mask = None - else: - self._initialize() - - def set_seq_mask(self, input_ids: Tensor, pad_id: int): - batch_size, seq_len = input_ids.shape - bool_mask = input_ids != pad_id - self._seq_mask[:batch_size, :seq_len] = bool_mask - - def get_kvcache(self) -> Tuple[Tensor, Tensor]: - return self._kv_cache - - def get_seq_mask(self) -> Tensor: - return self._seq_mask diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py new file mode 100644 index 0000000..7381738 --- /dev/null +++ b/astrai/inference/engine.py @@ -0,0 +1,694 @@ +""" +Continuous Batching Inference Engine + +This module provides the main continuous batching components: +- Task: Individual generation task with state management +- TaskStatus: Task state enumeration +- InferenceScheduler: Handles request scheduling and KV cache management +- InferenceEngine: Unified inference engine + +Author: AstrAI Team +""" + +import threading +import time +import uuid +from dataclasses import dataclass, field +from enum import Enum, auto +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from astrai.config import ModelConfig, ModelParameter +from astrai.tokenize.chat_template import HistoryType, build_prompt + + +# Use print for debugging instead of logging +def _debug(*args, **kwargs): + pass + + +@dataclass +class GenerationRequest: + """Request parameters for text generation.""" + + top_k: int + top_p: float + temperature: float + max_len: int + + query: Union[str, List[str]] + history: Optional[Union[HistoryType, List[HistoryType]]] = None + system_prompt: Optional[str] = None + stream: bool = False + + def __post_init__(self): + if not isinstance(self.top_k, int) or self.top_k < 0: + raise ValueError("top_k must be a non-negative integer") + if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0: + raise ValueError("top_p must be a float between 0.0 and 1.0") + if not isinstance(self.temperature, float) or self.temperature < 0.0: + raise ValueError("temperature must be a non-negative float") + + +class TaskStatus(Enum): + """Task state enumeration for continuous batching. + + States: + PENDING: Task is waiting to be scheduled + RUNNING: Task is currently being processed + FINISHED: Task completed successfully + ABORTED: Task was cancelled or failed + """ + + PENDING = auto() + RUNNING = auto() + FINISHED = auto() + ABORTED = auto() + + +@dataclass +class Task: + """Individual task for continuous batching. 
+ + Attributes: + task_id: Unique task identifier + prompt_ids: Input token IDs + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_p: Top-p sampling parameter + top_k: Top-k sampling parameter + status: Current task status + output_ids: Generated token IDs + input_tokens: Number of input tokens + output_tokens: Number of output tokens generated + slot: Batch slot position (-1 if not assigned) + arrival_time: Task arrival timestamp + finish_time: Task completion timestamp + stream_callback: Callback for streaming output + """ + + task_id: str + prompt_ids: List[int] + max_tokens: int = 1024 + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = 50 + + status: TaskStatus = TaskStatus.PENDING + output_ids: List[int] = field(default_factory=list) + input_tokens: int = 0 + output_tokens: int = 0 + slot: int = -1 + arrival_time: float = field(default_factory=time.time) + finish_time: Optional[float] = None + + stream_callback: Optional[Callable[[str], None]] = None + + def is_finished(self, stop_ids: List[int]) -> bool: + """Check if task is finished.""" + if self.output_ids and self.output_ids[-1] in stop_ids: + return True + if self.output_tokens >= self.max_tokens: + return True + return False + + +def apply_sampling_strategies( + logits: Tensor, + temperature: float, + top_k: int, + top_p: float, + filter_value: float = -float("inf"), +) -> Tensor: + """Apply sampling strategies to the logits tensor.""" + if temperature != 1.0: + logits = logits / temperature + + if top_k > 0: + top_k = min(top_k, logits.size(-1)) + indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = torch.zeros_like(logits, dtype=torch.bool) + indices_to_remove.scatter_( + dim=1, index=sorted_indices, src=sorted_indices_to_remove + ) + + logits[indices_to_remove] = filter_value + + return logits + + +class InferenceScheduler: + """Inference scheduler with continuous batching support. + + Manages request scheduling, KV cache allocation, and generation loop. + Supports dynamic batch composition where new requests can join at any time + and completed requests are immediately released. 
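+
+    Example (minimal sketch; assumes a loaded model, tokenizer and config):
+        scheduler = InferenceScheduler(model, tokenizer, config, max_batch_size=8)
+        scheduler.start()
+        task_id = scheduler.add_task(
+            "Hello", max_tokens=64, stream_callback=lambda tok: print(tok, end="")
+        )
+        # tokens arrive via stream_callback; a final "[DONE]" marks completion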
+ """ + + def __init__( + self, + model, + tokenizer, + config: ModelConfig, + max_batch_size: int = 16, + max_seq_len: Optional[int] = None, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + ): + self.model = model + self.tokenizer = tokenizer + self.config = config + self.max_batch_size = max_batch_size + self.max_seq_len = max_seq_len or config.max_len + self.device = device + self.dtype = dtype + + num_heads = config.n_kv_heads + head_dim = config.dim // config.n_heads + n_layers = config.n_layers + + k_cache = torch.empty( + ( + max_batch_size, + self.max_seq_len, + n_layers, + num_heads, + head_dim, + ), + device=device, + dtype=dtype, + ) + v_cache = torch.empty( + ( + max_batch_size, + self.max_seq_len, + n_layers, + num_heads, + head_dim, + ), + device=device, + dtype=dtype, + ) + self.kv_cache = (k_cache, v_cache) + self.seq_mask = torch.ones( + (max_batch_size, self.max_seq_len), device=device, dtype=torch.bool + ) + + self.waiting_queue: List[Task] = [] + self.active_tasks: List[Task] = [] + + self._running = False + self._task_event = threading.Event() + self._lock = threading.Lock() + + self._total_tasks = 0 + self._total_tokens = 0 + + def add_task( + self, + prompt: str, + max_tokens: int = 1024, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 50, + stream_callback: Optional[Callable[[str], None]] = None, + ) -> str: + """Add a new task to the waiting queue.""" + task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}" + prompt_ids = self.tokenizer.encode(prompt) + + _debug( + f"add_task: task_id={task_id}, prompt_len={len(prompt_ids)}, has_callback={stream_callback is not None}" + ) + + task = Task( + task_id=task_id, + prompt_ids=prompt_ids, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stream_callback=stream_callback, + ) + + with self._lock: + self.waiting_queue.append(task) + self._total_tasks += 1 + + self._task_event.set() + return task_id + + def remove_task(self, task_id: str) -> None: + """Remove a task from the scheduler.""" + with self._lock: + self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id] + self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] + + def _remove_finished_tasks(self) -> None: + """Remove finished tasks from active batch and update caches.""" + finished = [] + for task in self.active_tasks: + if task.is_finished(self.tokenizer.stop_ids): + task.status = TaskStatus.FINISHED + task.finish_time = time.time() + finished.append(task) + self._total_tokens += task.output_tokens + + for task in finished: + slot = task.slot + if slot >= 0 and slot < len(self.active_tasks): + self.seq_mask[slot, :] = False + task.slot = -1 + + self.active_tasks = [ + t for t in self.active_tasks if t.status != TaskStatus.FINISHED + ] + + def _refill_active_batch(self) -> None: + """Refill active batch with waiting tasks.""" + available_slots = self.max_batch_size - len(self.active_tasks) + if available_slots <= 0: + return + + with self._lock: + to_add = [] + for _ in range(min(available_slots, len(self.waiting_queue))): + if self.waiting_queue: + task = self.waiting_queue.pop(0) + task.status = TaskStatus.RUNNING + to_add.append(task) + + for task in to_add: + for i in range(self.max_batch_size): + if all(t.slot != i for t in self.active_tasks): + task.slot = i + break + self.active_tasks.append(task) + + def _execute_prefill(self, tasks: List[Task]) -> None: + """Execute Prefill phase: process entire prompt at once.""" + if not tasks: + return + + 
_debug(f"_execute_prefill: processing {len(tasks)} tasks")
+
+        # Sort tasks by slot to ensure correct batch indexing with KV cache
+        tasks = sorted(tasks, key=lambda t: t.slot)
+
+        prompt_lens = [len(task.prompt_ids) for task in tasks]
+        max_len = max(prompt_lens)
+
+        input_ids = torch.zeros(
+            len(tasks), max_len, dtype=torch.long, device=self.device
+        )
+        for i, task in enumerate(tasks):
+            if len(task.prompt_ids) > 0:
+                input_ids[i, : len(task.prompt_ids)] = torch.tensor(
+                    task.prompt_ids, device=self.device
+                )
+
+        # Create boolean mask for attention
+        if self.tokenizer.pad_id is not None:
+            input_mask = torch.ne(input_ids, self.tokenizer.pad_id)
+        else:
+            input_mask = torch.ones(
+                input_ids.shape, dtype=torch.bool, device=self.device
+            )
+
+        _debug(
+            f"_execute_prefill: input_ids shape={input_ids.shape}, max_len={max_len}"
+        )
+
+        try:
+            with torch.inference_mode():
+                outputs = self.model(
+                    input_ids,
+                    input_mask=input_mask,
+                    start_pos=0,
+                    persistent_key_values=self.kv_cache,
+                )
+                _debug(
+                    f"_execute_prefill: model forward done, output keys={outputs.keys() if hasattr(outputs, 'keys') else 'no keys'}"
+                )
+        except Exception as e:
+            _debug(f"_execute_prefill: ERROR: {e}")
+            raise
+
+        for i, task in enumerate(tasks):
+            task.input_tokens = prompt_lens[i]
+            task.output_tokens = 0
+            _debug(
+                f"  task {task.task_id}: input_tokens={task.input_tokens}, output_tokens={task.output_tokens}"
+            )
+
+        for task in tasks:
+            if task.slot >= 0:
+                self.seq_mask[task.slot, : task.input_tokens] = True
+
+        _debug(f"_execute_prefill: done, {len(tasks)} tasks marked as prefill complete")
+
+    def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
+        """Execute Decode phase: generate one token at a time."""
+        if not tasks:
+            return
+
+        _debug(f"_execute_decode: processing {len(tasks)} tasks, start_pos={start_pos}")
+
+        # Sort tasks by slot to ensure batch index aligns with slot (KV cache position)
+        # Task at slot 0 → batch index 0 → KV stored at cache[0]
+        # Task at slot 1 → batch index 1 → KV stored at cache[1]
+        tasks = sorted(tasks, key=lambda t: t.slot)
+
+        input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device)
+        for i, task in enumerate(tasks):
+            if task.output_ids:
+                input_ids[i] = task.output_ids[-1]
+            else:
+                input_ids[i] = task.prompt_ids[-1]
+
+        input_tensor = input_ids.unsqueeze(1)  # shape: (batch, 1)
+
+        # Create 2D attention mask: (batch, seq_len)
+        active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device)
+        _debug(
+            f"_execute_decode: input_tensor shape={input_tensor.shape}, active_mask shape={active_mask.shape}"
+        )
+
+        try:
+            with torch.inference_mode():
+                outputs = self.model(
+                    input_tensor,
+                    input_mask=active_mask,
+                    persistent_key_values=self.kv_cache,
+                    start_pos=start_pos,
+                )
+                _debug(
+                    f"_execute_decode: model forward done, logits shape={outputs['logits'].shape}"
+                )
+                logits = outputs["logits"][:, -1, :]
+        except Exception as e:
+            _debug(f"_execute_decode: ERROR: {e}")
+            raise
+
+        next_token_ids = []
+        for i, task in enumerate(tasks):
+            logit = logits[i : i + 1]
+            logit = apply_sampling_strategies(
+                logit,
+                task.temperature,
+                task.top_k,
+                task.top_p,
+            )
+            probs = torch.softmax(logit, dim=-1)
+            next_token = torch.multinomial(probs, num_samples=1)
+            next_token_ids.append(next_token.item())
+
+        _debug(f"_execute_decode: next_tokens={next_token_ids}")
+
+        for task, next_token in zip(tasks, next_token_ids):
+            task.output_ids.append(next_token)
+            task.output_tokens += 1
+
+            pos = task.input_tokens + task.output_tokens
if task.slot >= 0 and pos < self.max_seq_len: + self.seq_mask[task.slot, pos] = True + + if task.stream_callback: + token_str = self.tokenizer.decode([next_token]) + task.stream_callback(token_str) + + # Check if any task reached max_tokens or stop token + for task in tasks: + if task.output_tokens >= task.max_tokens or ( + task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids + ): + _debug( + f"decode: task {task.task_id} finished, output_tokens={task.output_tokens}, max_tokens={task.max_tokens}" + ) + if task.stream_callback: + task.stream_callback("[DONE]") + + def _run_generation_loop(self) -> None: + """Main generation loop with continuous batching.""" + _debug("generation_loop: started") + while self._running: + self._remove_finished_tasks() + self._refill_active_batch() + + if not self.active_tasks: + self._task_event.wait(timeout=0.01) + self._task_event.clear() + continue + + _debug( + f"generation_loop: active={len(self.active_tasks)}, waiting={len(self.waiting_queue)}" + ) + + new_tasks = [t for t in self.active_tasks if t.output_tokens == 0] + decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0] + + _debug( + f"generation_loop: new_tasks={len(new_tasks)}, decode_tasks={len(decode_tasks)}" + ) + for t in self.active_tasks: + _debug( + f" active task {t.task_id}: output_tokens={t.output_tokens}, input_tokens={t.input_tokens}" + ) + + if decode_tasks: + start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks) + else: + start_pos = 0 + + # First run prefill for new tasks + if new_tasks: + _debug(f"generation_loop: running prefill for {len(new_tasks)} tasks") + self._execute_prefill(new_tasks) + _debug(f"generation_loop: prefill done") + + # After prefill, convert these tasks to decode tasks in the same iteration + decode_tasks = new_tasks + start_pos = max(t.input_tokens for t in decode_tasks) + _debug( + f"generation_loop: after prefill, decode_tasks={len(decode_tasks)}, start_pos={start_pos}" + ) + + if decode_tasks: + _debug( + f"generation_loop: running decode for {len(decode_tasks)} tasks, start_pos={start_pos}" + ) + self._execute_decode(decode_tasks, start_pos) + _debug(f"generation_loop: decode done") + + if not self.active_tasks and not self.waiting_queue: + time.sleep(0.001) + + def start(self) -> None: + """Start the generation loop in a background thread.""" + if not self._running: + _debug("InferenceScheduler.start: starting loop thread") + self._running = True + self._loop_thread = threading.Thread(target=self._run_generation_loop) + self._loop_thread.daemon = True + self._loop_thread.start() + _debug("InferenceScheduler.start: loop thread started") + + def stop(self) -> None: + """Stop the generation loop.""" + self._running = False + if hasattr(self, "_loop_thread"): + self._loop_thread.join(timeout=1.0) + + def get_stats(self) -> Dict[str, Any]: + """Get scheduler statistics.""" + return { + "total_tasks": self._total_tasks, + "total_tokens": self._total_tokens, + "active_tasks": len(self.active_tasks), + "waiting_queue": len(self.waiting_queue), + } + + +class InferenceEngine: + """Unified inference engine for continuous batching. 
+ + Provides a single interface for: + - Single request generation (streaming or non-streaming) + - Batch request generation (streaming or non-streaming) + """ + + def __init__( + self, + parameter: ModelParameter, + max_batch_size: int = 16, + max_seq_len: Optional[int] = None, + ): + self.model = parameter.model + self.tokenizer = parameter.tokenizer + self.config = parameter.config + + model_params = next(self.model.parameters()) + self.device = model_params.device + self.dtype = model_params.dtype + + self.scheduler = InferenceScheduler( + model=self.model, + tokenizer=self.tokenizer, + config=self.config, + max_batch_size=max_batch_size, + max_seq_len=max_seq_len, + device=self.device, + dtype=self.dtype, + ) + + self.kv_cache = self.scheduler.kv_cache + self.seq_mask = self.scheduler.seq_mask + + _debug("InferenceEngine: starting scheduler") + self.scheduler.start() + _debug("InferenceEngine: scheduler started") + + def generate( + self, + prompt: Union[str, List[str]], + stream: bool = False, + max_tokens: int = 1024, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 50, + ) -> Union[Generator[str, None, None], str, List[str]]: + """Unified generation interface.""" + is_batch = isinstance(prompt, list) + prompts = prompt if is_batch else [prompt] + + if stream: + return self._generate_streaming( + prompts, is_batch, max_tokens, temperature, top_p, top_k + ) + else: + return self._generate_non_streaming( + prompts, is_batch, max_tokens, temperature, top_p, top_k + ) + + def generate_with_request( + self, request: GenerationRequest + ) -> Union[Generator[str, None, None], str, List[str]]: + """Generate with GenerationRequest object.""" + prompt = build_prompt(request.query, request.history) + + return self.generate( + prompt=prompt, + stream=request.stream, + max_tokens=request.max_len, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + ) + + def _generate_streaming( + self, + prompts: List[str], + is_batch: bool, + max_tokens: int, + temperature: float, + top_p: float, + top_k: int, + ) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]: + """Generate with streaming output (synchronous).""" + results = [] + _debug(f"_generate_streaming: prompts={len(prompts)}") + + if is_batch: + raise NotImplementedError("Batch streaming is not implemented yet") + + def make_callback(idx: int): + def cb(token: str): + _debug(f"callback[{idx}]: token={token!r}") + results.append(token) + + return cb + + for i, p in enumerate(prompts): + _debug(f"_generate_streaming: adding task {i}: {p[:30]}...") + self.scheduler.add_task( + prompt=p, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stream_callback=make_callback(i), + ) + + def gen(): + _debug("generator: start yielding") + while True: + # Yield accumulated tokens + while results: + token = results.pop(0) + if token == "[DONE]": + _debug("generator: got [DONE]") + return + _debug(f"generator: yielding {token!r}") + yield token + time.sleep(0.01) + + return gen() + + def _generate_non_streaming( + self, + prompts: List[str], + is_batch: bool, + max_tokens: int, + temperature: float, + top_p: float, + top_k: int, + ) -> Union[str, List[str]]: + """Generate without streaming.""" + results = ["" for _ in range(len(prompts))] + done_flags = [False] * len(prompts) + lock = threading.Lock() + + def make_callback(idx: int): + def cb(token: str): + if token == "[DONE]": + done_flags[idx] = True + else: + with lock: + results[idx] += token + + return cb + + 
for i, p in enumerate(prompts): + self.scheduler.add_task( + prompt=p, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stream_callback=make_callback(i), + ) + + while not all(done_flags): + time.sleep(0.001) + + return results if is_batch else results[0] + + def get_stats(self) -> Dict[str, Any]: + """Get engine statistics.""" + return self.scheduler.get_stats() + + def shutdown(self) -> None: + """Shutdown the engine.""" + self.scheduler.stop() diff --git a/astrai/inference/generator.py b/astrai/inference/generator.py deleted file mode 100644 index 46331dd..0000000 --- a/astrai/inference/generator.py +++ /dev/null @@ -1,269 +0,0 @@ -from dataclasses import dataclass -from typing import Generator, List, Optional, Tuple, Union - -import torch -from torch import Tensor - -from astrai.config.param_config import ModelParameter -from astrai.factory import BaseFactory -from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager -from astrai.tokenize.chat_template import HistoryType, build_prompt - - -def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]: - """ - Pad a list of sequences to a fixed length. - - Args: - ids_list (List[List[int]]): A list of sequences. - max_ids_len (int): The maximum length of sequences. - pad_id (int): The id to pad sequences. - - Returns: - List[List[int]]: A list of padded sequences. - - """ - max_ids_len = max(len(ids) for ids in ids_list) - new_ids_list = [] - for ids in ids_list: - pad_len = max_ids_len - len(ids) - padded_seq = [pad_id] * pad_len + ids - new_ids_list.append(padded_seq) - - return new_ids_list, max_ids_len - - -@dataclass -class GenerationRequest: - """ - Request parameters for text generation. - - Attributes: - top_k: Top-k sampling parameter. - top_p: Top-p (nucleus) sampling parameter. - temperature: Sampling temperature. - max_len: Maximum generation length. - query: Input query (string or list of strings for batch). - history: Conversation history. - system_prompt: System prompt for the conversation. - stream: Whether to use streaming generation. 
- """ - - top_k: int - top_p: float - temperature: float - max_len: int - - query: Union[str, List[str]] - history: Optional[Union[HistoryType, List[HistoryType]]] = None - system_prompt: Optional[str] = None - stream: bool = False - - def __post_init__(self): - if not isinstance(self.top_k, int) or self.top_k < 0: - raise ValueError("top_k must be a non-negative integer") - if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0: - raise ValueError("top_p must be a float between 0.0 and 1.0") - if not isinstance(self.temperature, float) or self.temperature < 0.0: - raise ValueError("temperature must be a non-negative float") - - -class LoopGenerator(GeneratorCore): - def __init__(self, parameter: ModelParameter): - super().__init__(parameter) - - def generate(self, request: GenerationRequest) -> str: - model_params = next(self.model.parameters()) - device = model_params.device - dtype = model_params.dtype - cache_manager = KVCacheManager(self.config, 1, device=device, dtype=dtype) - - prompt = build_prompt(request.query, request.history) - ids = self.tokenizer.encode(prompt) - input_ids = torch.tensor([ids], device=device, dtype=torch.long) - - start_cache_pos = len(ids) - self.model.eval() - kv_caches = cache_manager.get_kvcache() - - ids = self.generate_loop( - input_ids, - ids, - request.temperature, - request.top_k, - request.top_p, - kv_caches=kv_caches, - ) - response = self.tokenizer.decode(ids[start_cache_pos:]) - - return response - - -class StreamGenerator(GeneratorCore): - def __init__(self, parameter: ModelParameter): - super().__init__(parameter) - - def generate(self, request: GenerationRequest) -> Generator[str, None, None]: - model_params = next(self.model.parameters()) - device = model_params.device - dtype = model_params.dtype - cache_manager = KVCacheManager(self.config, 1, device=device, dtype=dtype) - - prompt = build_prompt(request.query, request.history) - ids = self.tokenizer.encode(prompt) - input_ids = torch.tensor([ids], device=device, dtype=torch.long) - - start_cache_pos = len(ids) - cur_cache_pos = 0 - self.model.eval() - kv_caches = cache_manager.get_kvcache() - - for _ in range(len(ids), self.config.max_len): - next_token_id, cache_increase = self.generate_iterator( - input_ids, - request.temperature, - request.top_k, - request.top_p, - kv_caches=kv_caches, - start_pos=cur_cache_pos, - ) - - input_ids = next_token_id - ids.append(next_token_id.item()) - cur_cache_pos += cache_increase - - response = self.tokenizer.decode(ids[start_cache_pos:]) - yield response - - if next_token_id.item() in self.tokenizer.stop_ids: - yield response + "\n" - break - - -class BatchGenerator(GeneratorCore): - def __init__(self, parameter: ModelParameter): - super().__init__(parameter) - - def generate(self, request: GenerationRequest) -> List[str]: - batch_size = len(request.query) - if request.history is None: - request.history = [[] for _ in range(batch_size)] - - prompts = [ - build_prompt(query, history) - for query, history in zip(request.query, request.history) - ] - - ids_list = [self.tokenizer.encode(prompt) for prompt in prompts] - ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id) - - model_params = next(self.model.parameters()) - device = model_params.device - dtype = model_params.dtype - cache_manager = KVCacheManager( - self.config, batch_size, device=device, dtype=dtype - ) - - input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long) - cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id) - 
activate_task_mask = [True] * batch_size - - start_cache_pos = max_ids_len - cur_cache_pos = 0 - - while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0: - kv_caches = cache_manager.get_kvcache() - attn_mask = cache_manager.get_seq_mask() - - next_token_id, cache_increase = self.generate_iterator( - input_tensor, - request.temperature, - request.top_k, - request.top_p, - attn_mask=attn_mask, - kv_caches=kv_caches, - start_pos=cur_cache_pos, - ) - - cur_cache_pos += cache_increase - active_mask = [] - c_ids = 0 - - for i in range(batch_size): - if activate_task_mask[i]: - token = next_token_id[c_ids, :].item() - ids_list[i].append(token) - c_ids += 1 - - is_active = token not in self.tokenizer.stop_ids - activate_task_mask[i] = is_active - active_mask.append(is_active) - - active_mask = torch.tensor(active_mask, device=device, dtype=torch.bool) - cache_manager.update(active_mask) - input_tensor = next_token_id[active_mask, :] - - max_ids_len += 1 - - responses = [str()] * batch_size - for i in range(batch_size): - responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:]) - request.history[i].append((request.query[i], responses[i])) - - return responses - - -class EmbeddingEncoder(EmbeddingEncoderCore): - def __init__(self, parameter: ModelParameter): - super().__init__(parameter) - - def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]: - return super().encode(sentence) - - -class GeneratorFactory(BaseFactory[GeneratorCore]): - """Factory class for creating generator instances. - - Provides smart generator selection based on request characteristics: - - Streaming: Use StreamGenerator for streaming output - - Batch: Use BatchGenerator when query is a list - - Single: Use LoopGenerator for single query non-streaming - - Example usage: - generator = GeneratorFactory.create(parameter, request) - result = generator.generate(request) - """ - - @staticmethod - def create(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore: - """Create a generator based on request characteristics. - - Args: - parameter: Model parameters containing model, tokenizer, config - request: Generation request with query, options, etc. - - Returns: - Appropriate GeneratorCore subclass instance - """ - # Streaming generation: check stream field first - if request.stream: - return StreamGenerator(parameter) - - # Batch generation: query is a list of strings - if isinstance(request.query, list): - return BatchGenerator(parameter) - - # Default: single query non-streaming - return LoopGenerator(parameter) - - @staticmethod - def create_encoder(parameter: ModelParameter) -> EmbeddingEncoderCore: - """Create an embedding encoder instance. - - Args: - parameter: Model parameters - - Returns: - EmbeddingEncoderCore instance - """ - return EmbeddingEncoder(parameter) diff --git a/astrai/inference/server.py b/astrai/inference/server.py index 129392a..c780add 100644 --- a/astrai/inference/server.py +++ b/astrai/inference/server.py @@ -1,3 +1,13 @@ +""" +Inference Server with Continuous Batching Support + +FastAPI server for inference with continuous batching. +Provides OpenAI-compatible chat completion endpoints. 
+ +Author: AstrAI Team +""" + +import json import logging from contextlib import asynccontextmanager from pathlib import Path @@ -10,12 +20,13 @@ from fastapi.responses import StreamingResponse from pydantic import BaseModel, Field from astrai.config.param_config import ModelParameter -from astrai.inference.generator import GenerationRequest, GeneratorFactory +from astrai.inference.engine import GenerationRequest, InferenceEngine logger = logging.getLogger(__name__) -# Global model parameter (loaded once) +# Global model parameter and engine (loaded once) _model_param: Optional[ModelParameter] = None +_engine: Optional[InferenceEngine] = None _project_root = Path(__file__).parent.parent.parent # Server configuration (set before running server) @@ -23,6 +34,7 @@ _server_config: Dict[str, Any] = { "device": "cuda", "dtype": torch.bfloat16, "param_path": None, + "max_batch_size": 16, } @@ -30,6 +42,7 @@ def configure_server( device: str = "cuda", dtype: torch.dtype = torch.bfloat16, param_path: Optional[Path] = None, + max_batch_size: int = 16, ): """Configure server settings before starting. @@ -37,40 +50,47 @@ def configure_server( device: Device to load model on (e.g., "cuda", "cpu", "cuda:0") dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16) param_path: Path to model parameters directory + max_batch_size: Maximum batch size for continuous batching """ _server_config["device"] = device _server_config["dtype"] = dtype _server_config["param_path"] = param_path + _server_config["max_batch_size"] = max_batch_size @asynccontextmanager async def lifespan(app: FastAPI): """Lifespan context manager for startup and shutdown events.""" + global _model_param, _engine # Startup: Load model with configured settings try: load_model( param_path=_server_config["param_path"], device=_server_config["device"], dtype=_server_config["dtype"], + max_batch_size=_server_config["max_batch_size"], ) except Exception as e: logger.error(f"Failed to load model: {e}") raise yield - # Shutdown: Cleanup if needed - pass + # Shutdown: Cleanup engine + if _engine: + _engine.shutdown() + logger.info("Inference engine shutdown complete") -app = FastAPI(title="AstrAI Inference Server", version="0.1.0", lifespan=lifespan) +app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan) def load_model( param_path: Optional[Path] = None, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, + max_batch_size: int = 16, ): - """Load model parameters into global variable.""" - global _model_param + """Load model parameters and initialize inference engine.""" + global _model_param, _engine if param_path is None: param_path = _project_root / "params" if not param_path.exists(): @@ -79,6 +99,13 @@ def load_model( _model_param.to(device=device, dtype=dtype) logger.info(f"Model loaded on {device} with dtype {dtype}") + # Initialize inference engine with continuous batching + _engine = InferenceEngine( + parameter=_model_param, + max_batch_size=max_batch_size, + ) + logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}") + # Pydantic models for API request/response class ChatMessage(BaseModel): @@ -134,54 +161,77 @@ def convert_messages_to_history( assistant_buffer.append(msg.content) else: logger.warning(f"Unknown role {msg.role}") - # If there is a pending user message without assistant, treat as current query - # We'll handle this later return system_prompt, history if history else None +def convert_messages_to_prompt(messages: List[ChatMessage]) -> str: + 
"""Convert messages to prompt string. + + Args: + messages: List of ChatMessage objects + + Returns: + str: Formatted prompt string + """ + system_prompt, history = convert_messages_to_history(messages) + + # Get the last user message as query + user_messages = [m.content for m in messages if m.role == "user"] + if not user_messages: + raise ValueError("No user message found") + query = user_messages[-1] + + # Build prompt using chat template + from astrai.tokenize.chat_template import build_prompt + + return build_prompt(query, history) + + @app.get("/health") async def health(): - return {"status": "ok", "model_loaded": _model_param is not None} + return { + "status": "ok", + "model_loaded": _model_param is not None, + "engine_ready": _engine is not None, + } + + +@app.get("/stats") +async def get_stats(): + """Get inference engine statistics.""" + if _engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + return _engine.get_stats() @app.post("/v1/chat/completions", response_model=CompletionResponse) async def chat_completion(request: ChatCompletionRequest): - """OpenAI‑compatible chat completion endpoint. + """OpenAI-compatible chat completion endpoint. - Supports both streaming and non‑streaming modes. + Supports both streaming and non-streaming modes with continuous batching. """ - if _model_param is None: - raise HTTPException(status_code=503, detail="Model not loaded") - # Convert messages to query/history - # For simplicity, assume the last user message is the query, previous messages are history - system_prompt, history = convert_messages_to_history(request.messages) - # Extract last user message as query - user_messages = [m.content for m in request.messages if m.role == "user"] - if not user_messages: - raise HTTPException(status_code=400, detail="No user message found") - query = user_messages[-1] - # If there are multiple user messages, we could merge them, but for demo we keep simple + if _engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") - gen_request = GenerationRequest( - query=query, - temperature=request.temperature, - top_p=request.top_p, - top_k=request.top_k, - max_len=request.max_tokens, - history=history, - system_prompt=system_prompt, - stream=request.stream, - ) + # Convert messages to prompt + prompt = convert_messages_to_prompt(request.messages) if request.stream: - # Return streaming response + # Streaming response (use synchronous generator) + generator = _engine.generate( + prompt=prompt, + stream=True, + max_tokens=request.max_tokens, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + ) + def generate_stream(): - generator = GeneratorFactory.create(_model_param, gen_request) - for chunk in generator.generate(gen_request): - # chunk is the cumulative response string - # For OpenAI compatibility, we send incremental delta - # For simplicity, we send the whole chunk each time - yield f"data: {chunk}\n\n" + for token in generator: + if token == "[DONE]": + break + yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n" yield "data: [DONE]\n\n" return StreamingResponse( @@ -190,13 +240,17 @@ async def chat_completion(request: ChatCompletionRequest): headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, ) else: - # Non‑streaming - generator = GeneratorFactory.create(_model_param, gen_request) - if gen_request.stream: - # Should not happen because we set stream=False - pass - response_text = generator.generate(gen_request) - # 
Build OpenAI‑style response + # Non-streaming response + result = _engine.generate( + prompt=prompt, + stream=False, + max_tokens=request.max_tokens, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + ) + + # Build OpenAI-style response import time resp = CompletionResponse( @@ -205,7 +259,7 @@ async def chat_completion(request: ChatCompletionRequest): choices=[ { "index": 0, - "message": {"role": "assistant", "content": response_text}, + "message": {"role": "assistant", "content": result}, "finish_reason": "stop", } ], @@ -223,35 +277,58 @@ async def generate( max_len: int = 2048, stream: bool = False, ): - """Simple generation endpoint compatible with existing GenerationRequest.""" - if _model_param is None: - raise HTTPException(status_code=503, detail="Model not loaded") + """Simple generation endpoint. + + Args: + query: Input query string + history: Conversation history as list of [user, assistant] pairs + temperature: Sampling temperature + top_p: Top-p sampling parameter + top_k: Top-k sampling parameter + max_len: Maximum tokens to generate + stream: Enable streaming output + + Returns: + dict: Generation result with response field + """ + if _engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + # Convert history format hist: Optional[List[Tuple[str, str]]] = None if history: - hist = [ - (h[0], h[1]) for h in history - ] # assuming each item is [user, assistant] - gen_request = GenerationRequest( - query=query, - temperature=temperature, - top_p=top_p, - top_k=top_k, - max_len=max_len, - history=hist, - stream=stream, - ) + hist = [(h[0], h[1]) for h in history] + + # Build prompt + from astrai.tokenize.chat_template import build_prompt + + prompt = build_prompt(query, hist) + if stream: + # Synchronous streaming + result = _engine.generate( + prompt=prompt, + stream=True, + max_tokens=max_len, + temperature=temperature, + top_p=top_p, + top_k=top_k, + ) def stream_generator(): - generator = GeneratorFactory.create(_model_param, gen_request) - for chunk in generator.generate(gen_request): - yield chunk + "\n" + for token in result: + yield token + "\n" return StreamingResponse(stream_generator(), media_type="text/plain") else: - generator = GeneratorFactory.create(_model_param, gen_request) - result = generator.generate(gen_request) + result = _engine.generate( + prompt=prompt, + stream=False, + max_tokens=max_len, + temperature=temperature, + top_p=top_p, + top_k=top_k, + ) return {"response": result} @@ -262,6 +339,7 @@ def run_server( device: str = "cuda", dtype: torch.dtype = torch.bfloat16, param_path: Optional[Path] = None, + max_batch_size: int = 16, ): """Run the FastAPI server with uvicorn. 
@@ -272,6 +350,17 @@ def run_server( device: Device to load model on (e.g., "cuda", "cpu", "cuda:0") dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16) param_path: Path to model parameters directory + max_batch_size: Maximum batch size for continuous batching """ - configure_server(device=device, dtype=dtype, param_path=param_path) - uvicorn.run("astrai.inference.server:app", host=host, port=port, reload=reload) + configure_server( + device=device, + dtype=dtype, + param_path=param_path, + max_batch_size=max_batch_size, + ) + uvicorn.run( + "astrai.inference.server:app", + host=host, + port=port, + reload=reload, + ) diff --git a/scripts/demo/generate_ar.py b/scripts/demo/generate_ar.py index 441a695..a72a94c 100644 --- a/scripts/demo/generate_ar.py +++ b/scripts/demo/generate_ar.py @@ -3,7 +3,7 @@ from pathlib import Path import torch from astrai.config.param_config import ModelParameter -from astrai.inference.generator import GenerationRequest, GeneratorFactory +from astrai.inference import InferenceEngine PROJECT_ROOT = Path(__file__).resolve().parents[2] PARAMETER_ROOT = Path(PROJECT_ROOT, "params") @@ -15,17 +15,15 @@ def generate_text(): query = input(">> ") - request = GenerationRequest( - query=query, + engine = InferenceEngine(param) + response = engine.generate( + prompt=query, + stream=False, + max_tokens=param.config.max_len, temperature=0.8, top_p=0.95, top_k=50, - max_len=param.config.max_len, - history=None, - system_prompt=None, ) - generator = GeneratorFactory.create(param, request) - response = generator.generate(request) print(response) diff --git a/scripts/demo/generate_batch.py b/scripts/demo/generate_batch.py index 7b87a1f..754ab4c 100644 --- a/scripts/demo/generate_batch.py +++ b/scripts/demo/generate_batch.py @@ -3,7 +3,7 @@ from pathlib import Path import torch from astrai.config.param_config import ModelParameter -from astrai.inference.generator import GenerationRequest, GeneratorFactory +from astrai.inference import InferenceEngine PROJECT_ROOT = Path(__file__).resolve().parents[2] PARAMETER_ROOT = Path(PROJECT_ROOT, "params") @@ -21,17 +21,15 @@ def batch_generate(): "请问什么是显卡", ] - request = GenerationRequest( - query=inputs, + engine = InferenceEngine(param) + responses = engine.generate( + prompt=inputs, + stream=False, + max_tokens=param.config.max_len, temperature=0.8, top_p=0.95, top_k=50, - max_len=param.config.max_len, - history=None, - system_prompt=None, ) - generator = GeneratorFactory.create(param, request) - responses = generator.generate(request) for q, r in zip(inputs, responses): print((q, r)) diff --git a/scripts/demo/stream_chat.py b/scripts/demo/stream_chat.py index 9a84708..f06e4fd 100644 --- a/scripts/demo/stream_chat.py +++ b/scripts/demo/stream_chat.py @@ -3,7 +3,7 @@ from pathlib import Path import torch from astrai.config.param_config import ModelParameter -from astrai.inference.generator import GenerationRequest, GeneratorFactory +from astrai.inference import InferenceEngine PROJECT_ROOT = Path(__file__).resolve().parents[2] PARAMETER_ROOT = Path(PROJECT_ROOT, "params") @@ -14,32 +14,27 @@ def chat(): param.to(device="cuda", dtype=torch.bfloat16) history = [] + engine = InferenceEngine(param) + while True: query = input(">> ") if query == "!exit": break - request = GenerationRequest( - query=query, + full_response = "" + + for token in engine.generate( + prompt=query, + stream=True, + max_tokens=param.config.max_len, temperature=0.8, top_p=0.95, top_k=50, - max_len=param.config.max_len, - history=history, - 
system_prompt=None, - stream=True, - ) - generator = GeneratorFactory.create(param, request) + ): + print(token, end="", flush=True) + full_response += token - response_size = 0 - full_response = "" - for response in generator.generate(request): - # response is the cumulative response up to current token - print(response[response_size:], end="", flush=True) - response_size = len(response) - full_response = response - - # After generation, update history + print() history.append((query, full_response.strip())) diff --git a/scripts/tools/generate.py b/scripts/tools/generate.py index 26f3db3..7011ec5 100644 --- a/scripts/tools/generate.py +++ b/scripts/tools/generate.py @@ -4,7 +4,7 @@ import json import torch from astrai.config.param_config import ModelParameter -from astrai.inference.generator import BatchGenerator, GenerationRequest +from astrai.inference import InferenceEngine def processor( @@ -19,25 +19,22 @@ def processor( ): param = ModelParameter.load(model_dir, disable_init=True) param.to(device="cuda", dtype=torch.bfloat16) - generator = BatchGenerator(param) + engine = InferenceEngine(param) with open(input_json_file, "r", encoding="utf-8") as f: input_data = [json.loads(line) for line in f] queries = [item[question_key] for item in input_data] - request = GenerationRequest( - query=queries, + responses = engine.generate( + prompt=queries, + stream=False, + max_tokens=param.config.max_len, temperature=temperature, top_p=top_p, top_k=top_k, - max_len=param.config.max_len, - history=None, - system_prompt=None, ) - responses = generator.generate(request) - with open(output_json_file, "w", encoding="utf-8") as f: for query, response in zip(queries, responses): output_item = {question_key: query, response_key: response} diff --git a/tests/inference/conftest.py b/tests/inference/conftest.py index 5b0144c..6c55b4d 100644 --- a/tests/inference/conftest.py +++ b/tests/inference/conftest.py @@ -1,11 +1,11 @@ """Shared fixtures for inference tests.""" -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest from fastapi.testclient import TestClient -from astrai.inference.server import app +from astrai.inference.server import app, _engine @pytest.fixture @@ -30,13 +30,17 @@ def mock_model_param(): @pytest.fixture -def mock_generator(mock_model_param): - """Mock the GeneratorFactory and its generators.""" - with patch("astrai.inference.server.GeneratorFactory") as MockFactory: - mock_gen = MagicMock() - mock_gen.generate.return_value = "mock response" - MockFactory.create.return_value = mock_gen - yield MockFactory, mock_gen +def mock_engine(): + """Create a mock InferenceEngine.""" + mock = MagicMock() + mock.generate.return_value = "mock response" + mock.get_stats.return_value = { + "total_tasks": 0, + "total_tokens": 0, + "active_tasks": 0, + "waiting_queue": 0, + } + return mock @pytest.fixture diff --git a/tests/inference/test_server.py b/tests/inference/test_server.py index 00bc61f..c1a0a7e 100644 --- a/tests/inference/test_server.py +++ b/tests/inference/test_server.py @@ -6,24 +6,29 @@ import pytest def test_health_no_model(client, monkeypatch): """GET /health should return 200 even when model not loaded.""" monkeypatch.setattr("astrai.inference.server._model_param", None) + monkeypatch.setattr("astrai.inference.server._engine", None) response = client.get("/health") assert response.status_code == 200 data = response.json() assert data["status"] == "ok" assert not data["model_loaded"] + assert not data["engine_ready"] -def 
test_health_with_model(client, loaded_model): +def test_health_with_model(client, loaded_model, mock_engine, monkeypatch): """GET /health should return 200 when model is loaded.""" + monkeypatch.setattr("astrai.inference.server._engine", mock_engine) response = client.get("/health") assert response.status_code == 200 - assert response.json() == {"status": "ok", "model_loaded": True} + data = response.json() + assert data["status"] == "ok" + assert data["model_loaded"] is True + assert data["engine_ready"] is True -def test_generate_non_stream(client, loaded_model, mock_generator): +def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch): """POST /generate with stream=false should return JSON response.""" - MockFactory, mock_gen = mock_generator - mock_gen.generate.return_value = "Test response" + monkeypatch.setattr("astrai.inference.server._engine", mock_engine) response = client.post( "/generate", params={ @@ -37,15 +42,19 @@ def test_generate_non_stream(client, loaded_model, mock_generator): ) assert response.status_code == 200 data = response.json() - assert data["response"] == "Test response" - MockFactory.create.assert_called_once() + assert data["response"] == "mock response" -def test_generate_stream(client, loaded_model, mock_generator): +def test_generate_stream(client, loaded_model, mock_engine, monkeypatch): """POST /generate with stream=true should return plain text stream.""" - MockFactory, mock_gen = mock_generator - # Simulate a streaming generator that yields two chunks - mock_gen.generate.return_value = ["chunk1", "chunk2"] + + # Create a streaming mock + def stream_gen(): + yield "chunk1" + yield "chunk2" + + mock_engine.generate.return_value = stream_gen() + monkeypatch.setattr("astrai.inference.server._engine", mock_engine) response = client.post( "/generate", params={ @@ -66,10 +75,10 @@ def test_generate_stream(client, loaded_model, mock_generator): assert "chunk2" in content -def test_chat_completions_non_stream(client, loaded_model, mock_generator): +def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypatch): """POST /v1/chat/completions with stream=false returns OpenAI‑style JSON.""" - MockFactory, mock_gen = mock_generator - mock_gen.generate.return_value = "Assistant reply" + mock_engine.generate.return_value = "Assistant reply" + monkeypatch.setattr("astrai.inference.server._engine", mock_engine) response = client.post( "/v1/chat/completions", json={ @@ -88,11 +97,17 @@ def test_chat_completions_non_stream(client, loaded_model, mock_generator): assert data["choices"][0]["message"]["content"] == "Assistant reply" -def test_chat_completions_stream(client, loaded_model, mock_generator): +def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch): """POST /v1/chat/completions with stream=true returns SSE stream.""" - MockFactory, mock_gen = mock_generator + # Simulate a streaming generator that yields cumulative responses - mock_gen.generate.return_value = ["cumulative1", "cumulative2"] + def stream_gen(): + yield "cumulative1" + yield "cumulative2" + yield "[DONE]" + + mock_engine.generate.return_value = stream_gen() + monkeypatch.setattr("astrai.inference.server._engine", mock_engine) response = client.post( "/v1/chat/completions", json={ @@ -116,10 +131,9 @@ def test_chat_completions_stream(client, loaded_model, mock_generator): assert any("cumulative2" in line for line in lines) -def test_generate_with_history(client, loaded_model, mock_generator): +def test_generate_with_history(client, 
loaded_model, mock_engine, monkeypatch): """POST /generate with history parameter.""" - MockFactory, mock_gen = mock_generator - mock_gen.generate.return_value = "Response with history" + monkeypatch.setattr("astrai.inference.server._engine", mock_engine) response = client.post( "/generate", params={ @@ -129,12 +143,8 @@ def test_generate_with_history(client, loaded_model, mock_generator): }, ) assert response.status_code == 200 - MockFactory.create.assert_called_once() - # Check that history was passed correctly (currently history is not parsed due to FastAPI limitation) - call_args = MockFactory.create.call_args - req = call_args[0][1] # second argument is GenerationRequest - # Because history cannot be passed via query params, it will be None - assert req.history is None + # Verify the engine.generate was called + mock_engine.generate.assert_called_once() if __name__ == "__main__": diff --git a/tests/module/test_module.py b/tests/module/test_module.py index 06bf989..64478f6 100644 --- a/tests/module/test_module.py +++ b/tests/module/test_module.py @@ -3,8 +3,6 @@ import os import torch from astrai.config.param_config import ModelParameter -from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore - def test_model_parameter(test_env): save_dir = os.path.join(test_env["test_dir"], "save") @@ -32,40 +30,4 @@ def test_transformer(test_env): test_env["transformer_config"].max_len, test_env["transformer_config"].vocab_size, ) - assert output_logits.shape == target_shape - - -# generator -def test_embedding_encoder_core(test_env): - parameter = ModelParameter( - test_env["model"], test_env["tokenizer"], test_env["transformer_config"] - ) - encoder = EmbeddingEncoderCore(parameter) - - single_emb = encoder.encode("测试文本") - assert isinstance(single_emb, torch.Tensor) - assert single_emb.shape[-1] == test_env["transformer_config"].dim - - batch_emb = encoder.encode(["测试1", "测试2"]) - assert isinstance(batch_emb, list) - assert len(batch_emb) == 2 - - -def test_generator_core(test_env): - parameter = ModelParameter( - test_env["model"], test_env["tokenizer"], test_env["transformer_config"] - ) - generator = GeneratorCore(parameter) - input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10)) - next_token_id, cache_increase = generator.generate_iterator( - input_ids=input_ids, - temperature=0.8, - top_k=50, - top_p=0.95, - attn_mask=None, - kv_caches=None, - start_pos=0, - ) - - assert next_token_id.shape == (4, 1) - assert cache_increase == 10 + assert output_logits.shape == target_shape \ No newline at end of file
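
Usage sketch (hypothetical, not part of the patch): a minimal client for the new
OpenAI-compatible streaming endpoint added in server.py. The host/port and any
request fields not visible in the diff (for example a "model" identifier) are
assumptions; the endpoint path, the "data: ..." SSE framing and the
choices[0].delta.content payload shape come from the server code above.

import json

import requests


def stream_chat(prompt: str, base_url: str = "http://127.0.0.1:8000") -> str:
    # Field names mirror ChatCompletionRequest as used in server.py.
    payload = {
        "messages": [{"role": "user", "content": prompt}],
        "stream": True,
        "max_tokens": 256,
        "temperature": 0.8,
        "top_p": 0.95,
        "top_k": 50,
    }
    tokens = []
    with requests.post(
        f"{base_url}/v1/chat/completions", json=payload, stream=True
    ) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            # Server emits lines of the form: data: {"choices": [{"delta": {"content": "..."}}]}
            if not line or not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if data == "[DONE]":
                break
            token = json.loads(data)["choices"][0]["delta"].get("content", "")
            print(token, end="", flush=True)
            tokens.append(token)
    print()
    return "".join(tokens)


if __name__ == "__main__":
    stream_chat("What is continuous batching?")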