refactor: 更新inference 部分的实现

This commit is contained in:
ViperEkura 2026-04-04 23:49:18 +08:00
parent 99b821ebf5
commit 861d33b1a1
13 changed files with 965 additions and 758 deletions

View File

@ -8,13 +8,9 @@ from astrai.config import (
from astrai.factory import BaseFactory
from astrai.dataset import DatasetFactory
from astrai.tokenize import BpeTokenizer
from astrai.inference.generator import (
BatchGenerator,
EmbeddingEncoder,
from astrai.inference import (
GenerationRequest,
GeneratorFactory,
LoopGenerator,
StreamGenerator,
InferenceEngine,
)
from astrai.model.transformer import Transformer
from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
@ -26,11 +22,7 @@ __all__ = [
"DatasetFactory",
"BpeTokenizer",
"GenerationRequest",
"LoopGenerator",
"StreamGenerator",
"BatchGenerator",
"EmbeddingEncoder",
"GeneratorFactory",
"InferenceEngine",
"Trainer",
"StrategyFactory",
"SchedulerFactory",

View File

@ -1,25 +1,34 @@
from astrai.inference.core import (
EmbeddingEncoderCore,
GeneratorCore,
KVCacheManager,
)
from astrai.inference.generator import (
BatchGenerator,
EmbeddingEncoder,
"""
AstrAI Inference Module
Provides inference components for text generation with continuous batching support.
Main Components:
- InferenceEngine: Unified inference engine for continuous batching
- InferenceScheduler: Task scheduling with dynamic batch composition
- Task, TaskStatus: Task management for continuous batching
- GenerationRequest: Request parameters for generation
- apply_sampling_strategies: Sampling utilities for text generation
Author: AstrAI Team
"""
from astrai.inference.engine import (
GenerationRequest,
GeneratorFactory,
LoopGenerator,
StreamGenerator,
InferenceEngine,
InferenceScheduler,
Task,
TaskStatus,
apply_sampling_strategies,
)
__all__ = [
"GeneratorCore",
"EmbeddingEncoderCore",
"KVCacheManager",
# Engine
"InferenceEngine",
"InferenceScheduler",
"Task",
"TaskStatus",
"GenerationRequest",
"LoopGenerator",
"StreamGenerator",
"BatchGenerator",
"EmbeddingEncoder",
"GeneratorFactory",
# Sampling
"apply_sampling_strategies",
]

View File

@ -1,272 +0,0 @@
from typing import Any, Callable, List, Optional, Self, Tuple, Union
import torch
from torch import Tensor
from astrai.config import ModelConfig, ModelParameter
def apply_sampling_strategies(
logits: Tensor,
temperature: float,
top_k: int,
top_p: float,
filter_value: float = -float("inf"),
) -> Tensor:
"""
Apply sampling strategies to the logits tensor.
Args:
logits (Tensor): The logits tensor.
temperature (float): The temperature parameter.
top_k (int): The top-k parameter.
top_p (float): The top-p parameter.
filter_value (float, optional): The filter value. Defaults to -float("inf").
Returns:
Tensor: The sampled logits tensor.
"""
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
class GeneratorCore:
def __init__(self, parameter: ModelParameter):
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
def generate_iterator(
self,
input_ids: Tensor,
temperature: float,
top_k: int,
top_p: float,
attn_mask: Optional[Tensor] = None,
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
start_pos: int = 0,
) -> Tuple[Tensor, int]:
with torch.inference_mode():
outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
logits = outputs["logits"][:, -1, :]
cache_increase = input_ids.size(-1)
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
probs = torch.softmax(logits, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)
return next_token_id, cache_increase
def generate_loop(
self,
input_ids: Tensor,
ids: List[int],
temperature: float,
top_k: int,
top_p: float,
attn_mask: Optional[Tensor] = None,
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
start_pos: int = 0,
callback: Optional[Callable[..., Any]] = None,
) -> List[int]:
cur_cache_pos = start_pos
for _ in range(len(ids), self.config.max_len):
next_token_id, cache_increase = self.generate_iterator(
input_ids,
temperature,
top_k,
top_p,
attn_mask,
kv_caches,
cur_cache_pos,
)
input_ids = next_token_id
ids.append(next_token_id.item())
cur_cache_pos += cache_increase
if callback:
callback(next_token_id.item(), ids.copy())
if next_token_id.item() in self.tokenizer.stop_ids:
break
return ids
def to(self, *args, **kargs) -> Self:
self.model.to(*args, **kargs)
return self
class EmbeddingEncoderCore:
def __init__(self, parameter: ModelParameter):
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
with_batch = isinstance(sentence, list)
ids = self.tokenizer.encode(sentence)
batch_ids = ids if with_batch else [ids]
max_model_len = self.config.max_len
all_fragments = []
fragment_origin_idx = []
for i, seq in enumerate(batch_ids):
if len(seq) > max_model_len:
fragments = [
seq[j : j + max_model_len]
for j in range(0, len(seq), max_model_len)
]
all_fragments.extend(fragments)
fragment_origin_idx.extend([i] * len(fragments))
else:
all_fragments.append(seq)
fragment_origin_idx.append(i)
# if empty fragments
if not all_fragments or not ids:
return [] if with_batch else torch.tensor([])
device = next(self.model.parameters()).device
max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
padded_ids = []
masks = []
for seq in all_fragments:
pad_len = max_len - len(seq)
padded_seq = seq + [self.tokenizer.pad_id] * pad_len
mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
padded_ids.append(padded_seq)
masks.append(mask)
input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
with torch.inference_mode():
outputs = self.model(input_tensor, seq_mask)["hidden_states"]
# [num_fragments, seq_len, hidden_size]
fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
sentence_embs: List[Tensor] = []
for i in range(len(batch_ids)):
indices = [
idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i
]
if indices:
sum_frags = torch.sum(
fragment_embs[indices, :, :], dim=1
) # [frags, hidden_size]
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(
1
) # [frags, 1]
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
sentence_embs.append(emb.flatten())
if with_batch:
return [emb.flatten() for emb in sentence_embs]
else:
return sentence_embs[0].flatten()
def to(self, *args, **kargs) -> Self:
self.model.to(*args, **kargs)
return self
class KVCacheManager:
def __init__(
self,
config: ModelConfig,
batch_size: int,
device: torch.device = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
self.batch_size = batch_size
self.device = device
self.dtype = dtype
self.num_layers = config.n_layers
self.max_len = config.max_len
self.num_heads = config.n_kv_heads
self.head_dim = config.dim // config.n_heads
self._kv_cache: Tuple[Tensor, Tensor] = None
self._seq_mask: Tensor = None
self._initialize()
def _initialize(self):
k_cache = torch.empty(
(
self.batch_size,
self.max_len,
self.num_layers,
self.num_heads,
self.head_dim,
),
device=self.device,
dtype=self.dtype,
)
v_cache = torch.empty(
(
self.batch_size,
self.max_len,
self.num_layers,
self.num_heads,
self.head_dim,
),
device=self.device,
dtype=self.dtype,
)
self._kv_cache = (k_cache, v_cache)
self._seq_mask = torch.ones(
(self.batch_size, self.max_len), device=self.device, dtype=torch.bool
)
def update(self, active_mask: Tensor):
k_cache, v_cache = self._kv_cache
self._kv_cache = (k_cache[active_mask], v_cache[active_mask])
self._seq_mask = self._seq_mask[active_mask]
def reset(self, full_reset=False):
if full_reset:
self._kv_cache = None
self._seq_mask = None
else:
self._initialize()
def set_seq_mask(self, input_ids: Tensor, pad_id: int):
batch_size, seq_len = input_ids.shape
bool_mask = input_ids != pad_id
self._seq_mask[:batch_size, :seq_len] = bool_mask
def get_kvcache(self) -> Tuple[Tensor, Tensor]:
return self._kv_cache
def get_seq_mask(self) -> Tensor:
return self._seq_mask

694
astrai/inference/engine.py Normal file
View File

@ -0,0 +1,694 @@
"""
Continuous Batching Inference Engine
This module provides the main continuous batching components:
- Task: Individual generation task with state management
- TaskStatus: Task state enumeration
- InferenceScheduler: Handles request scheduling and KV cache management
- InferenceEngine: Unified inference engine
Author: AstrAI Team
"""
import threading
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
import torch
from torch import Tensor
from astrai.config import ModelConfig, ModelParameter
from astrai.tokenize.chat_template import HistoryType, build_prompt
# Use print for debugging instead of logging
def _debug(*args, **kwargs):
pass
@dataclass
class GenerationRequest:
"""Request parameters for text generation."""
top_k: int
top_p: float
temperature: float
max_len: int
query: Union[str, List[str]]
history: Optional[Union[HistoryType, List[HistoryType]]] = None
system_prompt: Optional[str] = None
stream: bool = False
def __post_init__(self):
if not isinstance(self.top_k, int) or self.top_k < 0:
raise ValueError("top_k must be a non-negative integer")
if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0:
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not isinstance(self.temperature, float) or self.temperature < 0.0:
raise ValueError("temperature must be a non-negative float")
class TaskStatus(Enum):
"""Task state enumeration for continuous batching.
States:
PENDING: Task is waiting to be scheduled
RUNNING: Task is currently being processed
FINISHED: Task completed successfully
ABORTED: Task was cancelled or failed
"""
PENDING = auto()
RUNNING = auto()
FINISHED = auto()
ABORTED = auto()
@dataclass
class Task:
"""Individual task for continuous batching.
Attributes:
task_id: Unique task identifier
prompt_ids: Input token IDs
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Top-p sampling parameter
top_k: Top-k sampling parameter
status: Current task status
output_ids: Generated token IDs
input_tokens: Number of input tokens
output_tokens: Number of output tokens generated
slot: Batch slot position (-1 if not assigned)
arrival_time: Task arrival timestamp
finish_time: Task completion timestamp
stream_callback: Callback for streaming output
"""
task_id: str
prompt_ids: List[int]
max_tokens: int = 1024
temperature: float = 1.0
top_p: float = 1.0
top_k: int = 50
status: TaskStatus = TaskStatus.PENDING
output_ids: List[int] = field(default_factory=list)
input_tokens: int = 0
output_tokens: int = 0
slot: int = -1
arrival_time: float = field(default_factory=time.time)
finish_time: Optional[float] = None
stream_callback: Optional[Callable[[str], None]] = None
def is_finished(self, stop_ids: List[int]) -> bool:
"""Check if task is finished."""
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
if self.output_tokens >= self.max_tokens:
return True
return False
def apply_sampling_strategies(
logits: Tensor,
temperature: float,
top_k: int,
top_p: float,
filter_value: float = -float("inf"),
) -> Tensor:
"""Apply sampling strategies to the logits tensor."""
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
class InferenceScheduler:
"""Inference scheduler with continuous batching support.
Manages request scheduling, KV cache allocation, and generation loop.
Supports dynamic batch composition where new requests can join at any time
and completed requests are immediately released.
"""
def __init__(
self,
model,
tokenizer,
config: ModelConfig,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
self.model = model
self.tokenizer = tokenizer
self.config = config
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len
self.device = device
self.dtype = dtype
num_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads
n_layers = config.n_layers
k_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_heads,
head_dim,
),
device=device,
dtype=dtype,
)
v_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_heads,
head_dim,
),
device=device,
dtype=dtype,
)
self.kv_cache = (k_cache, v_cache)
self.seq_mask = torch.ones(
(max_batch_size, self.max_seq_len), device=device, dtype=torch.bool
)
self.waiting_queue: List[Task] = []
self.active_tasks: List[Task] = []
self._running = False
self._task_event = threading.Event()
self._lock = threading.Lock()
self._total_tasks = 0
self._total_tokens = 0
def add_task(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
) -> str:
"""Add a new task to the waiting queue."""
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
_debug(
f"add_task: task_id={task_id}, prompt_len={len(prompt_ids)}, has_callback={stream_callback is not None}"
)
task = Task(
task_id=task_id,
prompt_ids=prompt_ids,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=stream_callback,
)
with self._lock:
self.waiting_queue.append(task)
self._total_tasks += 1
self._task_event.set()
return task_id
def remove_task(self, task_id: str) -> None:
"""Remove a task from the scheduler."""
with self._lock:
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
def _remove_finished_tasks(self) -> None:
"""Remove finished tasks from active batch and update caches."""
finished = []
for task in self.active_tasks:
if task.is_finished(self.tokenizer.stop_ids):
task.status = TaskStatus.FINISHED
task.finish_time = time.time()
finished.append(task)
self._total_tokens += task.output_tokens
for task in finished:
slot = task.slot
if slot >= 0 and slot < len(self.active_tasks):
self.seq_mask[slot, :] = False
task.slot = -1
self.active_tasks = [
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
]
def _refill_active_batch(self) -> None:
"""Refill active batch with waiting tasks."""
available_slots = self.max_batch_size - len(self.active_tasks)
if available_slots <= 0:
return
with self._lock:
to_add = []
for _ in range(min(available_slots, len(self.waiting_queue))):
if self.waiting_queue:
task = self.waiting_queue.pop(0)
task.status = TaskStatus.RUNNING
to_add.append(task)
for task in to_add:
for i in range(self.max_batch_size):
if all(t.slot != i for t in self.active_tasks):
task.slot = i
break
self.active_tasks.append(task)
def _execute_prefill(self, tasks: List[Task]) -> None:
"""Execute Prefill phase: process entire prompt at once."""
if not tasks:
return
_debug(f"_execute_prefill: processing {len(tasks)} tasks")
# Sort tasks by slot to ensure correct batch indexing with KV cache
tasks = sorted(tasks, key=lambda t: t.slot)
prompt_lens = [len(task.prompt_ids) for task in tasks]
max_len = max(prompt_lens)
input_ids = torch.zeros(
len(tasks), max_len, dtype=torch.long, device=self.device
)
for i, task in enumerate(tasks):
if len(task.prompt_ids) > 0:
input_ids[i, : len(task.prompt_ids)] = torch.tensor(
task.prompt_ids, device=self.device
)
# Create boolean mask for attention
if self.tokenizer.pad_id is not None:
input_mask = torch.ne(input_ids, self.tokenizer.pad_id)
else:
input_mask = torch.ones(
input_ids.shape, dtype=torch.bool, device=self.device
)
_debug(
f"_execute_prefill: input_ids shape={input_ids.shape}, max_len={max_len}"
)
try:
with torch.inference_mode():
outputs = self.model(
input_ids,
input_mask=input_mask,
start_pos=0,
persistent_key_values=self.kv_cache,
)
_debug(
f"_execute_prefill: model forward done, output keys={outputs.keys() if hasattr(outputs, 'keys') else 'no keys'}"
)
except Exception as e:
_debug(f"_execute_prefill: ERROR: {e}")
raise
for i, task in enumerate(tasks):
task.input_tokens = prompt_lens[i]
task.output_tokens = 0
_debug(
f" task {task.task_id}: input_tokens={task.input_tokens}, output_tokens={task.output_tokens}"
)
for task in tasks:
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
_debug(f"_execute_prefill: done, {len(tasks)} tasks marked as prefill complete")
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
"""Execute Decode phase: generate one token at a time."""
if not tasks:
return
_debug(f"_execute_decode: processing {len(tasks)} tasks, start_pos={start_pos}")
# Sort tasks by slot to ensure batch index aligns with slot (KV cache position)
# Task at slot 0 → batch index 0 → KV stored at cache[0]
# Task at slot 1 → batch index 1 → KV stored at cache[1]
tasks = sorted(tasks, key=lambda t: t.slot)
input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device)
for i, task in enumerate(tasks):
if task.output_ids:
input_ids[i] = task.output_ids[-1]
else:
input_ids[i] = task.prompt_ids[-1]
input_tensor = input_ids.unsqueeze(1) # shape: (batch, 1)
# Create 2D attention mask: (batch, seq_len)
active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device)
_debug(
f"_execute_decode: input_tensor shape={input_tensor.shape}, active_mask shape={active_mask.shape}"
)
try:
with torch.inference_mode():
outputs = self.model(
input_tensor,
input_mask=active_mask,
persistent_key_values=self.kv_cache,
start_pos=start_pos,
)
_debug(
f"_execute_decode: model forward done, logits shape={outputs['logits'].shape}"
)
logits = outputs["logits"][:, -1, :]
except Exception as e:
_debug(f"_execute_decode: ERROR: {e}")
raise
next_token_ids = []
for i, task in enumerate(tasks):
logit = logits[i : i + 1]
logit = apply_sampling_strategies(
logit,
task.temperature,
task.top_k,
task.top_p,
)
probs = torch.softmax(logit, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
next_token_ids.append(next_token.item())
_debug(f"_execute_decode: next_tokens={next_token_ids}")
for task, next_token in zip(tasks, next_token_ids):
task.output_ids.append(next_token)
task.output_tokens += 1
pos = task.input_tokens + task.output_tokens
if task.slot >= 0 and pos < self.max_seq_len:
self.seq_mask[task.slot, pos] = True
if task.stream_callback:
token_str = self.tokenizer.decode([next_token])
task.stream_callback(token_str)
# Check if any task reached max_tokens or stop token
for task in tasks:
if task.output_tokens >= task.max_tokens or (
task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids
):
_debug(
f"decode: task {task.task_id} finished, output_tokens={task.output_tokens}, max_tokens={task.max_tokens}"
)
if task.stream_callback:
task.stream_callback("[DONE]")
def _run_generation_loop(self) -> None:
"""Main generation loop with continuous batching."""
_debug("generation_loop: started")
while self._running:
self._remove_finished_tasks()
self._refill_active_batch()
if not self.active_tasks:
self._task_event.wait(timeout=0.01)
self._task_event.clear()
continue
_debug(
f"generation_loop: active={len(self.active_tasks)}, waiting={len(self.waiting_queue)}"
)
new_tasks = [t for t in self.active_tasks if t.output_tokens == 0]
decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0]
_debug(
f"generation_loop: new_tasks={len(new_tasks)}, decode_tasks={len(decode_tasks)}"
)
for t in self.active_tasks:
_debug(
f" active task {t.task_id}: output_tokens={t.output_tokens}, input_tokens={t.input_tokens}"
)
if decode_tasks:
start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks)
else:
start_pos = 0
# First run prefill for new tasks
if new_tasks:
_debug(f"generation_loop: running prefill for {len(new_tasks)} tasks")
self._execute_prefill(new_tasks)
_debug(f"generation_loop: prefill done")
# After prefill, convert these tasks to decode tasks in the same iteration
decode_tasks = new_tasks
start_pos = max(t.input_tokens for t in decode_tasks)
_debug(
f"generation_loop: after prefill, decode_tasks={len(decode_tasks)}, start_pos={start_pos}"
)
if decode_tasks:
_debug(
f"generation_loop: running decode for {len(decode_tasks)} tasks, start_pos={start_pos}"
)
self._execute_decode(decode_tasks, start_pos)
_debug(f"generation_loop: decode done")
if not self.active_tasks and not self.waiting_queue:
time.sleep(0.001)
def start(self) -> None:
"""Start the generation loop in a background thread."""
if not self._running:
_debug("InferenceScheduler.start: starting loop thread")
self._running = True
self._loop_thread = threading.Thread(target=self._run_generation_loop)
self._loop_thread.daemon = True
self._loop_thread.start()
_debug("InferenceScheduler.start: loop thread started")
def stop(self) -> None:
"""Stop the generation loop."""
self._running = False
if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=1.0)
def get_stats(self) -> Dict[str, Any]:
"""Get scheduler statistics."""
return {
"total_tasks": self._total_tasks,
"total_tokens": self._total_tokens,
"active_tasks": len(self.active_tasks),
"waiting_queue": len(self.waiting_queue),
}
class InferenceEngine:
"""Unified inference engine for continuous batching.
Provides a single interface for:
- Single request generation (streaming or non-streaming)
- Batch request generation (streaming or non-streaming)
"""
def __init__(
self,
parameter: ModelParameter,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
):
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
model_params = next(self.model.parameters())
self.device = model_params.device
self.dtype = model_params.dtype
self.scheduler = InferenceScheduler(
model=self.model,
tokenizer=self.tokenizer,
config=self.config,
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
device=self.device,
dtype=self.dtype,
)
self.kv_cache = self.scheduler.kv_cache
self.seq_mask = self.scheduler.seq_mask
_debug("InferenceEngine: starting scheduler")
self.scheduler.start()
_debug("InferenceEngine: scheduler started")
def generate(
self,
prompt: Union[str, List[str]],
stream: bool = False,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
) -> Union[Generator[str, None, None], str, List[str]]:
"""Unified generation interface."""
is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt]
if stream:
return self._generate_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
else:
return self._generate_non_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
def generate_with_request(
self, request: GenerationRequest
) -> Union[Generator[str, None, None], str, List[str]]:
"""Generate with GenerationRequest object."""
prompt = build_prompt(request.query, request.history)
return self.generate(
prompt=prompt,
stream=request.stream,
max_tokens=request.max_len,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
def _generate_streaming(
self,
prompts: List[str],
is_batch: bool,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
"""Generate with streaming output (synchronous)."""
results = []
_debug(f"_generate_streaming: prompts={len(prompts)}")
if is_batch:
raise NotImplementedError("Batch streaming is not implemented yet")
def make_callback(idx: int):
def cb(token: str):
_debug(f"callback[{idx}]: token={token!r}")
results.append(token)
return cb
for i, p in enumerate(prompts):
_debug(f"_generate_streaming: adding task {i}: {p[:30]}...")
self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=make_callback(i),
)
def gen():
_debug("generator: start yielding")
while True:
# Yield accumulated tokens
while results:
token = results.pop(0)
if token == "[DONE]":
_debug("generator: got [DONE]")
return
_debug(f"generator: yielding {token!r}")
yield token
time.sleep(0.01)
return gen()
def _generate_non_streaming(
self,
prompts: List[str],
is_batch: bool,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
) -> Union[str, List[str]]:
"""Generate without streaming."""
results = ["" for _ in range(len(prompts))]
done_flags = [False] * len(prompts)
lock = threading.Lock()
def make_callback(idx: int):
def cb(token: str):
if token == "[DONE]":
done_flags[idx] = True
else:
with lock:
results[idx] += token
return cb
for i, p in enumerate(prompts):
self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=make_callback(i),
)
while not all(done_flags):
time.sleep(0.001)
return results if is_batch else results[0]
def get_stats(self) -> Dict[str, Any]:
"""Get engine statistics."""
return self.scheduler.get_stats()
def shutdown(self) -> None:
"""Shutdown the engine."""
self.scheduler.stop()

View File

@ -1,269 +0,0 @@
from dataclasses import dataclass
from typing import Generator, List, Optional, Tuple, Union
import torch
from torch import Tensor
from astrai.config.param_config import ModelParameter
from astrai.factory import BaseFactory
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
from astrai.tokenize.chat_template import HistoryType, build_prompt
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
"""
Pad a list of sequences to a fixed length.
Args:
ids_list (List[List[int]]): A list of sequences.
max_ids_len (int): The maximum length of sequences.
pad_id (int): The id to pad sequences.
Returns:
List[List[int]]: A list of padded sequences.
"""
max_ids_len = max(len(ids) for ids in ids_list)
new_ids_list = []
for ids in ids_list:
pad_len = max_ids_len - len(ids)
padded_seq = [pad_id] * pad_len + ids
new_ids_list.append(padded_seq)
return new_ids_list, max_ids_len
@dataclass
class GenerationRequest:
"""
Request parameters for text generation.
Attributes:
top_k: Top-k sampling parameter.
top_p: Top-p (nucleus) sampling parameter.
temperature: Sampling temperature.
max_len: Maximum generation length.
query: Input query (string or list of strings for batch).
history: Conversation history.
system_prompt: System prompt for the conversation.
stream: Whether to use streaming generation.
"""
top_k: int
top_p: float
temperature: float
max_len: int
query: Union[str, List[str]]
history: Optional[Union[HistoryType, List[HistoryType]]] = None
system_prompt: Optional[str] = None
stream: bool = False
def __post_init__(self):
if not isinstance(self.top_k, int) or self.top_k < 0:
raise ValueError("top_k must be a non-negative integer")
if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0:
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not isinstance(self.temperature, float) or self.temperature < 0.0:
raise ValueError("temperature must be a non-negative float")
class LoopGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(self, request: GenerationRequest) -> str:
model_params = next(self.model.parameters())
device = model_params.device
dtype = model_params.dtype
cache_manager = KVCacheManager(self.config, 1, device=device, dtype=dtype)
prompt = build_prompt(request.query, request.history)
ids = self.tokenizer.encode(prompt)
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
start_cache_pos = len(ids)
self.model.eval()
kv_caches = cache_manager.get_kvcache()
ids = self.generate_loop(
input_ids,
ids,
request.temperature,
request.top_k,
request.top_p,
kv_caches=kv_caches,
)
response = self.tokenizer.decode(ids[start_cache_pos:])
return response
class StreamGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
model_params = next(self.model.parameters())
device = model_params.device
dtype = model_params.dtype
cache_manager = KVCacheManager(self.config, 1, device=device, dtype=dtype)
prompt = build_prompt(request.query, request.history)
ids = self.tokenizer.encode(prompt)
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
start_cache_pos = len(ids)
cur_cache_pos = 0
self.model.eval()
kv_caches = cache_manager.get_kvcache()
for _ in range(len(ids), self.config.max_len):
next_token_id, cache_increase = self.generate_iterator(
input_ids,
request.temperature,
request.top_k,
request.top_p,
kv_caches=kv_caches,
start_pos=cur_cache_pos,
)
input_ids = next_token_id
ids.append(next_token_id.item())
cur_cache_pos += cache_increase
response = self.tokenizer.decode(ids[start_cache_pos:])
yield response
if next_token_id.item() in self.tokenizer.stop_ids:
yield response + "\n"
break
class BatchGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(self, request: GenerationRequest) -> List[str]:
batch_size = len(request.query)
if request.history is None:
request.history = [[] for _ in range(batch_size)]
prompts = [
build_prompt(query, history)
for query, history in zip(request.query, request.history)
]
ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)
model_params = next(self.model.parameters())
device = model_params.device
dtype = model_params.dtype
cache_manager = KVCacheManager(
self.config, batch_size, device=device, dtype=dtype
)
input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
activate_task_mask = [True] * batch_size
start_cache_pos = max_ids_len
cur_cache_pos = 0
while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0:
kv_caches = cache_manager.get_kvcache()
attn_mask = cache_manager.get_seq_mask()
next_token_id, cache_increase = self.generate_iterator(
input_tensor,
request.temperature,
request.top_k,
request.top_p,
attn_mask=attn_mask,
kv_caches=kv_caches,
start_pos=cur_cache_pos,
)
cur_cache_pos += cache_increase
active_mask = []
c_ids = 0
for i in range(batch_size):
if activate_task_mask[i]:
token = next_token_id[c_ids, :].item()
ids_list[i].append(token)
c_ids += 1
is_active = token not in self.tokenizer.stop_ids
activate_task_mask[i] = is_active
active_mask.append(is_active)
active_mask = torch.tensor(active_mask, device=device, dtype=torch.bool)
cache_manager.update(active_mask)
input_tensor = next_token_id[active_mask, :]
max_ids_len += 1
responses = [str()] * batch_size
for i in range(batch_size):
responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:])
request.history[i].append((request.query[i], responses[i]))
return responses
class EmbeddingEncoder(EmbeddingEncoderCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
return super().encode(sentence)
class GeneratorFactory(BaseFactory[GeneratorCore]):
"""Factory class for creating generator instances.
Provides smart generator selection based on request characteristics:
- Streaming: Use StreamGenerator for streaming output
- Batch: Use BatchGenerator when query is a list
- Single: Use LoopGenerator for single query non-streaming
Example usage:
generator = GeneratorFactory.create(parameter, request)
result = generator.generate(request)
"""
@staticmethod
def create(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
"""Create a generator based on request characteristics.
Args:
parameter: Model parameters containing model, tokenizer, config
request: Generation request with query, options, etc.
Returns:
Appropriate GeneratorCore subclass instance
"""
# Streaming generation: check stream field first
if request.stream:
return StreamGenerator(parameter)
# Batch generation: query is a list of strings
if isinstance(request.query, list):
return BatchGenerator(parameter)
# Default: single query non-streaming
return LoopGenerator(parameter)
@staticmethod
def create_encoder(parameter: ModelParameter) -> EmbeddingEncoderCore:
"""Create an embedding encoder instance.
Args:
parameter: Model parameters
Returns:
EmbeddingEncoderCore instance
"""
return EmbeddingEncoder(parameter)

View File

@ -1,3 +1,13 @@
"""
Inference Server with Continuous Batching Support
FastAPI server for inference with continuous batching.
Provides OpenAI-compatible chat completion endpoints.
Author: AstrAI Team
"""
import json
import logging
from contextlib import asynccontextmanager
from pathlib import Path
@ -10,12 +20,13 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GenerationRequest, GeneratorFactory
from astrai.inference.engine import GenerationRequest, InferenceEngine
logger = logging.getLogger(__name__)
# Global model parameter (loaded once)
# Global model parameter and engine (loaded once)
_model_param: Optional[ModelParameter] = None
_engine: Optional[InferenceEngine] = None
_project_root = Path(__file__).parent.parent.parent
# Server configuration (set before running server)
@ -23,6 +34,7 @@ _server_config: Dict[str, Any] = {
"device": "cuda",
"dtype": torch.bfloat16,
"param_path": None,
"max_batch_size": 16,
}
@ -30,6 +42,7 @@ def configure_server(
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
"""Configure server settings before starting.
@ -37,40 +50,47 @@ def configure_server(
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
param_path: Path to model parameters directory
max_batch_size: Maximum batch size for continuous batching
"""
_server_config["device"] = device
_server_config["dtype"] = dtype
_server_config["param_path"] = param_path
_server_config["max_batch_size"] = max_batch_size
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events."""
global _model_param, _engine
# Startup: Load model with configured settings
try:
load_model(
param_path=_server_config["param_path"],
device=_server_config["device"],
dtype=_server_config["dtype"],
max_batch_size=_server_config["max_batch_size"],
)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
# Shutdown: Cleanup if needed
pass
# Shutdown: Cleanup engine
if _engine:
_engine.shutdown()
logger.info("Inference engine shutdown complete")
app = FastAPI(title="AstrAI Inference Server", version="0.1.0", lifespan=lifespan)
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def load_model(
param_path: Optional[Path] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16,
):
"""Load model parameters into global variable."""
global _model_param
"""Load model parameters and initialize inference engine."""
global _model_param, _engine
if param_path is None:
param_path = _project_root / "params"
if not param_path.exists():
@ -79,6 +99,13 @@ def load_model(
_model_param.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}")
# Initialize inference engine with continuous batching
_engine = InferenceEngine(
parameter=_model_param,
max_batch_size=max_batch_size,
)
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
# Pydantic models for API request/response
class ChatMessage(BaseModel):
@ -134,54 +161,77 @@ def convert_messages_to_history(
assistant_buffer.append(msg.content)
else:
logger.warning(f"Unknown role {msg.role}")
# If there is a pending user message without assistant, treat as current query
# We'll handle this later
return system_prompt, history if history else None
def convert_messages_to_prompt(messages: List[ChatMessage]) -> str:
"""Convert messages to prompt string.
Args:
messages: List of ChatMessage objects
Returns:
str: Formatted prompt string
"""
system_prompt, history = convert_messages_to_history(messages)
# Get the last user message as query
user_messages = [m.content for m in messages if m.role == "user"]
if not user_messages:
raise ValueError("No user message found")
query = user_messages[-1]
# Build prompt using chat template
from astrai.tokenize.chat_template import build_prompt
return build_prompt(query, history)
@app.get("/health")
async def health():
return {"status": "ok", "model_loaded": _model_param is not None}
return {
"status": "ok",
"model_loaded": _model_param is not None,
"engine_ready": _engine is not None,
}
@app.get("/stats")
async def get_stats():
"""Get inference engine statistics."""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return _engine.get_stats()
@app.post("/v1/chat/completions", response_model=CompletionResponse)
async def chat_completion(request: ChatCompletionRequest):
"""OpenAIcompatible chat completion endpoint.
"""OpenAI-compatible chat completion endpoint.
Supports both streaming and nonstreaming modes.
Supports both streaming and non-streaming modes with continuous batching.
"""
if _model_param is None:
raise HTTPException(status_code=503, detail="Model not loaded")
# Convert messages to query/history
# For simplicity, assume the last user message is the query, previous messages are history
system_prompt, history = convert_messages_to_history(request.messages)
# Extract last user message as query
user_messages = [m.content for m in request.messages if m.role == "user"]
if not user_messages:
raise HTTPException(status_code=400, detail="No user message found")
query = user_messages[-1]
# If there are multiple user messages, we could merge them, but for demo we keep simple
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
gen_request = GenerationRequest(
query=query,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
max_len=request.max_tokens,
history=history,
system_prompt=system_prompt,
stream=request.stream,
)
# Convert messages to prompt
prompt = convert_messages_to_prompt(request.messages)
if request.stream:
# Return streaming response
# Streaming response (use synchronous generator)
generator = _engine.generate(
prompt=prompt,
stream=True,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
def generate_stream():
generator = GeneratorFactory.create(_model_param, gen_request)
for chunk in generator.generate(gen_request):
# chunk is the cumulative response string
# For OpenAI compatibility, we send incremental delta
# For simplicity, we send the whole chunk each time
yield f"data: {chunk}\n\n"
for token in generator:
if token == "[DONE]":
break
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
@ -190,13 +240,17 @@ async def chat_completion(request: ChatCompletionRequest):
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
else:
# Nonstreaming
generator = GeneratorFactory.create(_model_param, gen_request)
if gen_request.stream:
# Should not happen because we set stream=False
pass
response_text = generator.generate(gen_request)
# Build OpenAIstyle response
# Non-streaming response
result = _engine.generate(
prompt=prompt,
stream=False,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
# Build OpenAI-style response
import time
resp = CompletionResponse(
@ -205,7 +259,7 @@ async def chat_completion(request: ChatCompletionRequest):
choices=[
{
"index": 0,
"message": {"role": "assistant", "content": response_text},
"message": {"role": "assistant", "content": result},
"finish_reason": "stop",
}
],
@ -223,35 +277,58 @@ async def generate(
max_len: int = 2048,
stream: bool = False,
):
"""Simple generation endpoint compatible with existing GenerationRequest."""
if _model_param is None:
raise HTTPException(status_code=503, detail="Model not loaded")
"""Simple generation endpoint.
Args:
query: Input query string
history: Conversation history as list of [user, assistant] pairs
temperature: Sampling temperature
top_p: Top-p sampling parameter
top_k: Top-k sampling parameter
max_len: Maximum tokens to generate
stream: Enable streaming output
Returns:
dict: Generation result with response field
"""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
# Convert history format
hist: Optional[List[Tuple[str, str]]] = None
if history:
hist = [
(h[0], h[1]) for h in history
] # assuming each item is [user, assistant]
gen_request = GenerationRequest(
query=query,
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_len=max_len,
history=hist,
stream=stream,
)
hist = [(h[0], h[1]) for h in history]
# Build prompt
from astrai.tokenize.chat_template import build_prompt
prompt = build_prompt(query, hist)
if stream:
# Synchronous streaming
result = _engine.generate(
prompt=prompt,
stream=True,
max_tokens=max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
def stream_generator():
generator = GeneratorFactory.create(_model_param, gen_request)
for chunk in generator.generate(gen_request):
yield chunk + "\n"
for token in result:
yield token + "\n"
return StreamingResponse(stream_generator(), media_type="text/plain")
else:
generator = GeneratorFactory.create(_model_param, gen_request)
result = generator.generate(gen_request)
result = _engine.generate(
prompt=prompt,
stream=False,
max_tokens=max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
return {"response": result}
@ -262,6 +339,7 @@ def run_server(
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
"""Run the FastAPI server with uvicorn.
@ -272,6 +350,17 @@ def run_server(
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
param_path: Path to model parameters directory
max_batch_size: Maximum batch size for continuous batching
"""
configure_server(device=device, dtype=dtype, param_path=param_path)
uvicorn.run("astrai.inference.server:app", host=host, port=port, reload=reload)
configure_server(
device=device,
dtype=dtype,
param_path=param_path,
max_batch_size=max_batch_size,
)
uvicorn.run(
"astrai.inference.server:app",
host=host,
port=port,
reload=reload,
)

View File

@ -3,7 +3,7 @@ from pathlib import Path
import torch
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GenerationRequest, GeneratorFactory
from astrai.inference import InferenceEngine
PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
@ -15,17 +15,15 @@ def generate_text():
query = input(">> ")
request = GenerationRequest(
query=query,
engine = InferenceEngine(param)
response = engine.generate(
prompt=query,
stream=False,
max_tokens=param.config.max_len,
temperature=0.8,
top_p=0.95,
top_k=50,
max_len=param.config.max_len,
history=None,
system_prompt=None,
)
generator = GeneratorFactory.create(param, request)
response = generator.generate(request)
print(response)

View File

@ -3,7 +3,7 @@ from pathlib import Path
import torch
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GenerationRequest, GeneratorFactory
from astrai.inference import InferenceEngine
PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
@ -21,17 +21,15 @@ def batch_generate():
"请问什么是显卡",
]
request = GenerationRequest(
query=inputs,
engine = InferenceEngine(param)
responses = engine.generate(
prompt=inputs,
stream=False,
max_tokens=param.config.max_len,
temperature=0.8,
top_p=0.95,
top_k=50,
max_len=param.config.max_len,
history=None,
system_prompt=None,
)
generator = GeneratorFactory.create(param, request)
responses = generator.generate(request)
for q, r in zip(inputs, responses):
print((q, r))

View File

@ -3,7 +3,7 @@ from pathlib import Path
import torch
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GenerationRequest, GeneratorFactory
from astrai.inference import InferenceEngine
PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
@ -14,32 +14,27 @@ def chat():
param.to(device="cuda", dtype=torch.bfloat16)
history = []
engine = InferenceEngine(param)
while True:
query = input(">> ")
if query == "!exit":
break
request = GenerationRequest(
query=query,
full_response = ""
for token in engine.generate(
prompt=query,
stream=True,
max_tokens=param.config.max_len,
temperature=0.8,
top_p=0.95,
top_k=50,
max_len=param.config.max_len,
history=history,
system_prompt=None,
stream=True,
)
generator = GeneratorFactory.create(param, request)
):
print(token, end="", flush=True)
full_response += token
response_size = 0
full_response = ""
for response in generator.generate(request):
# response is the cumulative response up to current token
print(response[response_size:], end="", flush=True)
response_size = len(response)
full_response = response
# After generation, update history
print()
history.append((query, full_response.strip()))

View File

@ -4,7 +4,7 @@ import json
import torch
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import BatchGenerator, GenerationRequest
from astrai.inference import InferenceEngine
def processor(
@ -19,25 +19,22 @@ def processor(
):
param = ModelParameter.load(model_dir, disable_init=True)
param.to(device="cuda", dtype=torch.bfloat16)
generator = BatchGenerator(param)
engine = InferenceEngine(param)
with open(input_json_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f]
queries = [item[question_key] for item in input_data]
request = GenerationRequest(
query=queries,
responses = engine.generate(
prompt=queries,
stream=False,
max_tokens=param.config.max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_len=param.config.max_len,
history=None,
system_prompt=None,
)
responses = generator.generate(request)
with open(output_json_file, "w", encoding="utf-8") as f:
for query, response in zip(queries, responses):
output_item = {question_key: query, response_key: response}

View File

@ -1,11 +1,11 @@
"""Shared fixtures for inference tests."""
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from fastapi.testclient import TestClient
from astrai.inference.server import app
from astrai.inference.server import app, _engine
@pytest.fixture
@ -30,13 +30,17 @@ def mock_model_param():
@pytest.fixture
def mock_generator(mock_model_param):
"""Mock the GeneratorFactory and its generators."""
with patch("astrai.inference.server.GeneratorFactory") as MockFactory:
mock_gen = MagicMock()
mock_gen.generate.return_value = "mock response"
MockFactory.create.return_value = mock_gen
yield MockFactory, mock_gen
def mock_engine():
"""Create a mock InferenceEngine."""
mock = MagicMock()
mock.generate.return_value = "mock response"
mock.get_stats.return_value = {
"total_tasks": 0,
"total_tokens": 0,
"active_tasks": 0,
"waiting_queue": 0,
}
return mock
@pytest.fixture

View File

@ -6,24 +6,29 @@ import pytest
def test_health_no_model(client, monkeypatch):
"""GET /health should return 200 even when model not loaded."""
monkeypatch.setattr("astrai.inference.server._model_param", None)
monkeypatch.setattr("astrai.inference.server._engine", None)
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "ok"
assert not data["model_loaded"]
assert not data["engine_ready"]
def test_health_with_model(client, loaded_model):
def test_health_with_model(client, loaded_model, mock_engine, monkeypatch):
"""GET /health should return 200 when model is loaded."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok", "model_loaded": True}
data = response.json()
assert data["status"] == "ok"
assert data["model_loaded"] is True
assert data["engine_ready"] is True
def test_generate_non_stream(client, loaded_model, mock_generator):
def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /generate with stream=false should return JSON response."""
MockFactory, mock_gen = mock_generator
mock_gen.generate.return_value = "Test response"
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post(
"/generate",
params={
@ -37,15 +42,19 @@ def test_generate_non_stream(client, loaded_model, mock_generator):
)
assert response.status_code == 200
data = response.json()
assert data["response"] == "Test response"
MockFactory.create.assert_called_once()
assert data["response"] == "mock response"
def test_generate_stream(client, loaded_model, mock_generator):
def test_generate_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /generate with stream=true should return plain text stream."""
MockFactory, mock_gen = mock_generator
# Simulate a streaming generator that yields two chunks
mock_gen.generate.return_value = ["chunk1", "chunk2"]
# Create a streaming mock
def stream_gen():
yield "chunk1"
yield "chunk2"
mock_engine.generate.return_value = stream_gen()
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post(
"/generate",
params={
@ -66,10 +75,10 @@ def test_generate_stream(client, loaded_model, mock_generator):
assert "chunk2" in content
def test_chat_completions_non_stream(client, loaded_model, mock_generator):
def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /v1/chat/completions with stream=false returns OpenAIstyle JSON."""
MockFactory, mock_gen = mock_generator
mock_gen.generate.return_value = "Assistant reply"
mock_engine.generate.return_value = "Assistant reply"
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post(
"/v1/chat/completions",
json={
@ -88,11 +97,17 @@ def test_chat_completions_non_stream(client, loaded_model, mock_generator):
assert data["choices"][0]["message"]["content"] == "Assistant reply"
def test_chat_completions_stream(client, loaded_model, mock_generator):
def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /v1/chat/completions with stream=true returns SSE stream."""
MockFactory, mock_gen = mock_generator
# Simulate a streaming generator that yields cumulative responses
mock_gen.generate.return_value = ["cumulative1", "cumulative2"]
def stream_gen():
yield "cumulative1"
yield "cumulative2"
yield "[DONE]"
mock_engine.generate.return_value = stream_gen()
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post(
"/v1/chat/completions",
json={
@ -116,10 +131,9 @@ def test_chat_completions_stream(client, loaded_model, mock_generator):
assert any("cumulative2" in line for line in lines)
def test_generate_with_history(client, loaded_model, mock_generator):
def test_generate_with_history(client, loaded_model, mock_engine, monkeypatch):
"""POST /generate with history parameter."""
MockFactory, mock_gen = mock_generator
mock_gen.generate.return_value = "Response with history"
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post(
"/generate",
params={
@ -129,12 +143,8 @@ def test_generate_with_history(client, loaded_model, mock_generator):
},
)
assert response.status_code == 200
MockFactory.create.assert_called_once()
# Check that history was passed correctly (currently history is not parsed due to FastAPI limitation)
call_args = MockFactory.create.call_args
req = call_args[0][1] # second argument is GenerationRequest
# Because history cannot be passed via query params, it will be None
assert req.history is None
# Verify the engine.generate was called
mock_engine.generate.assert_called_once()
if __name__ == "__main__":

View File

@ -3,8 +3,6 @@ import os
import torch
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore
def test_model_parameter(test_env):
save_dir = os.path.join(test_env["test_dir"], "save")
@ -33,39 +31,3 @@ def test_transformer(test_env):
test_env["transformer_config"].vocab_size,
)
assert output_logits.shape == target_shape
# generator
def test_embedding_encoder_core(test_env):
parameter = ModelParameter(
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
)
encoder = EmbeddingEncoderCore(parameter)
single_emb = encoder.encode("测试文本")
assert isinstance(single_emb, torch.Tensor)
assert single_emb.shape[-1] == test_env["transformer_config"].dim
batch_emb = encoder.encode(["测试1", "测试2"])
assert isinstance(batch_emb, list)
assert len(batch_emb) == 2
def test_generator_core(test_env):
parameter = ModelParameter(
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
)
generator = GeneratorCore(parameter)
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10))
next_token_id, cache_increase = generator.generate_iterator(
input_ids=input_ids,
temperature=0.8,
top_k=50,
top_p=0.95,
attn_mask=None,
kv_caches=None,
start_pos=0,
)
assert next_token_id.shape == (4, 1)
assert cache_increase == 10