refactor: 更新inference 部分的实现
This commit is contained in:
parent
99b821ebf5
commit
861d33b1a1
|
|
@ -8,13 +8,9 @@ from astrai.config import (
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
from astrai.dataset import DatasetFactory
|
from astrai.dataset import DatasetFactory
|
||||||
from astrai.tokenize import BpeTokenizer
|
from astrai.tokenize import BpeTokenizer
|
||||||
from astrai.inference.generator import (
|
from astrai.inference import (
|
||||||
BatchGenerator,
|
|
||||||
EmbeddingEncoder,
|
|
||||||
GenerationRequest,
|
GenerationRequest,
|
||||||
GeneratorFactory,
|
InferenceEngine,
|
||||||
LoopGenerator,
|
|
||||||
StreamGenerator,
|
|
||||||
)
|
)
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.transformer import Transformer
|
||||||
from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
|
from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
|
||||||
|
|
@ -26,11 +22,7 @@ __all__ = [
|
||||||
"DatasetFactory",
|
"DatasetFactory",
|
||||||
"BpeTokenizer",
|
"BpeTokenizer",
|
||||||
"GenerationRequest",
|
"GenerationRequest",
|
||||||
"LoopGenerator",
|
"InferenceEngine",
|
||||||
"StreamGenerator",
|
|
||||||
"BatchGenerator",
|
|
||||||
"EmbeddingEncoder",
|
|
||||||
"GeneratorFactory",
|
|
||||||
"Trainer",
|
"Trainer",
|
||||||
"StrategyFactory",
|
"StrategyFactory",
|
||||||
"SchedulerFactory",
|
"SchedulerFactory",
|
||||||
|
|
|
||||||
|
|
@ -1,25 +1,34 @@
|
||||||
from astrai.inference.core import (
|
"""
|
||||||
EmbeddingEncoderCore,
|
AstrAI Inference Module
|
||||||
GeneratorCore,
|
|
||||||
KVCacheManager,
|
Provides inference components for text generation with continuous batching support.
|
||||||
)
|
|
||||||
from astrai.inference.generator import (
|
Main Components:
|
||||||
BatchGenerator,
|
- InferenceEngine: Unified inference engine for continuous batching
|
||||||
EmbeddingEncoder,
|
- InferenceScheduler: Task scheduling with dynamic batch composition
|
||||||
|
- Task, TaskStatus: Task management for continuous batching
|
||||||
|
- GenerationRequest: Request parameters for generation
|
||||||
|
- apply_sampling_strategies: Sampling utilities for text generation
|
||||||
|
|
||||||
|
Author: AstrAI Team
|
||||||
|
"""
|
||||||
|
|
||||||
|
from astrai.inference.engine import (
|
||||||
GenerationRequest,
|
GenerationRequest,
|
||||||
GeneratorFactory,
|
InferenceEngine,
|
||||||
LoopGenerator,
|
InferenceScheduler,
|
||||||
StreamGenerator,
|
Task,
|
||||||
|
TaskStatus,
|
||||||
|
apply_sampling_strategies,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"GeneratorCore",
|
# Engine
|
||||||
"EmbeddingEncoderCore",
|
"InferenceEngine",
|
||||||
"KVCacheManager",
|
"InferenceScheduler",
|
||||||
|
"Task",
|
||||||
|
"TaskStatus",
|
||||||
"GenerationRequest",
|
"GenerationRequest",
|
||||||
"LoopGenerator",
|
# Sampling
|
||||||
"StreamGenerator",
|
"apply_sampling_strategies",
|
||||||
"BatchGenerator",
|
|
||||||
"EmbeddingEncoder",
|
|
||||||
"GeneratorFactory",
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,272 +0,0 @@
|
||||||
from typing import Any, Callable, List, Optional, Self, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from astrai.config import ModelConfig, ModelParameter
|
|
||||||
|
|
||||||
|
|
||||||
def apply_sampling_strategies(
    logits: Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    filter_value: float = -float("inf"),
) -> Tensor:
    """Filter logits with temperature scaling, top-k and top-p (nucleus) rules.

    The input tensor is never modified: every filtering step produces a new
    tensor. This matters for two reasons: callers keep their original logits,
    and logits produced under ``torch.inference_mode()`` can be filtered after
    the context has exited (in-place writes to such tensors are rejected by
    PyTorch outside inference mode).

    Args:
        logits: Scores whose last dimension is the vocabulary. Any number of
            leading dimensions (including none) is accepted.
        temperature: Divisor applied to the logits. ``0`` selects greedy
            decoding: only the arg-max entry of each row is kept.
        top_k: Keep the ``top_k`` highest entries per row; ``0`` disables it.
        top_p: Keep the smallest prefix of the sorted distribution whose
            cumulative probability exceeds ``top_p``; ``1.0`` disables it.
        filter_value: Value written to every removed position.

    Returns:
        Tensor: Filtered logits with the same shape as ``logits``. When no
        rule is active the input tensor itself is returned.
    """
    if temperature == 0:
        # Dividing by zero would yield inf/nan; treat it as greedy decoding.
        best = logits.argmax(dim=-1, keepdim=True)
        return torch.full_like(logits, filter_value).scatter(
            -1, best, logits.gather(-1, best)
        )

    if temperature != 1.0:
        logits = logits / temperature

    if top_k > 0:
        k = min(top_k, logits.size(-1))
        # Smallest score that is still inside the top-k, one per row.
        kth_best = torch.topk(logits, k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth_best, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

        # Shift right by one so the first token crossing the threshold is kept,
        # and always keep the most likely token.
        drop_sorted = (cumulative_probs > top_p).roll(shifts=1, dims=-1)
        drop_sorted[..., 0] = False

        # Map the decision from sorted order back to vocabulary order.
        drop = torch.zeros_like(drop_sorted).scatter(-1, sorted_indices, drop_sorted)
        logits = logits.masked_fill(drop, filter_value)

    return logits
|
|
||||||
|
|
||||||
|
|
||||||
class GeneratorCore:
    """Token-by-token sampler built on a model, a tokenizer and a config.

    The three collaborators are taken from a ``ModelParameter`` bundle and
    stored as ``model``, ``tokenizer`` and ``config``.
    """

    def __init__(self, parameter: ModelParameter):
        self.model, self.tokenizer, self.config = (
            parameter.model,
            parameter.tokenizer,
            parameter.config,
        )

    def generate_iterator(
        self,
        input_ids: Tensor,
        temperature: float,
        top_k: int,
        top_p: float,
        attn_mask: Optional[Tensor] = None,
        kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
        start_pos: int = 0,
    ) -> Tuple[Tensor, int]:
        """Run one forward pass and sample the next token.

        Args:
            input_ids: Token ids fed to the model for this step.
            temperature: Sampling temperature.
            top_k: Top-k filter size.
            top_p: Nucleus filter threshold.
            attn_mask: Optional mask forwarded to the model unchanged.
            kv_caches: Optional cache object forwarded to the model unchanged.
            start_pos: Cache position forwarded to the model unchanged.

        Returns:
            The sampled token ids (one per row, from ``torch.multinomial``) and
            the size of the last dimension of ``input_ids``, i.e. how far the
            cache position advances.
        """
        with torch.inference_mode():
            model_out = self.model(input_ids, attn_mask, kv_caches, start_pos)
            # Only the scores of the final position are needed for sampling.
            last_logits = model_out["logits"][:, -1, :]
            consumed = input_ids.size(-1)

        filtered = apply_sampling_strategies(last_logits, temperature, top_k, top_p)
        sampled = torch.multinomial(torch.softmax(filtered, dim=-1), num_samples=1)
        return sampled, consumed

    def generate_loop(
        self,
        input_ids: Tensor,
        ids: List[int],
        temperature: float,
        top_k: int,
        top_p: float,
        attn_mask: Optional[Tensor] = None,
        kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
        start_pos: int = 0,
        callback: Optional[Callable[..., Any]] = None,
    ) -> List[int]:
        """Generate tokens until a stop id appears or ``config.max_len`` is hit.

        ``ids`` is extended in place with every sampled token and is also the
        return value. After each token, ``callback`` (if given) is invoked with
        the new token id and a copy of ``ids``.
        """
        position = start_pos
        step_input = input_ids

        # The budget is fixed up front from the length ``ids`` has on entry.
        for _ in range(self.config.max_len - len(ids)):
            sampled, advance = self.generate_iterator(
                step_input,
                temperature,
                top_k,
                top_p,
                attn_mask,
                kv_caches,
                position,
            )
            token = sampled.item()

            # The next step is fed only the token that was just produced.
            step_input = sampled
            ids.append(token)
            position += advance

            if callback:
                callback(token, ids.copy())

            if token in self.tokenizer.stop_ids:
                break

        return ids

    def to(self, *args, **kargs) -> Self:
        """Forward ``*args`` / ``**kargs`` to ``model.to`` and return ``self``."""
        self.model.to(*args, **kargs)
        return self
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingEncoderCore:
    """Turn sentences into fixed-size vectors by mean-pooling hidden states.

    The model, tokenizer and config are taken from a ``ModelParameter`` bundle.
    """

    def __init__(self, parameter: ModelParameter):
        self.model = parameter.model
        self.tokenizer = parameter.tokenizer
        self.config = parameter.config

    def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
        """Embed one sentence or a list of sentences.

        Token sequences longer than ``config.max_len`` are cut into consecutive
        fragments of at most that length. Every fragment is mean-pooled over its
        non-padding positions, and the fragment means belonging to one sentence
        are summed into the sentence vector.

        Args:
            sentence: A single string, or a list of strings for batch mode.

        Returns:
            Batch mode: a list with one 1-D tensor per input sentence (an empty
            list for an empty batch). Single mode: one 1-D tensor, or an empty
            tensor when the sentence encodes to no tokens. A sentence without
            tokens inside a non-empty batch yields an all-zero vector instead
            of NaN.
        """
        with_batch = isinstance(sentence, list)
        ids = self.tokenizer.encode(sentence)
        batch_ids = ids if with_batch else [ids]

        # Nothing to embed: empty batch, or a single sentence with no tokens.
        if not ids:
            return [] if with_batch else torch.tensor([])

        max_model_len = self.config.max_len

        all_fragments: List[List[int]] = []
        # fragment_rows[i] lists the rows of ``all_fragments`` owned by sentence i,
        # recorded while splitting so no later search is required.
        fragment_rows: List[List[int]] = []
        for seq in batch_ids:
            if len(seq) > max_model_len:
                pieces = [
                    seq[j : j + max_model_len]
                    for j in range(0, len(seq), max_model_len)
                ]
            else:
                pieces = [seq]
            first_row = len(all_fragments)
            fragment_rows.append(list(range(first_row, first_row + len(pieces))))
            all_fragments.extend(pieces)

        device = next(self.model.parameters()).device
        pad_id = self.tokenizer.pad_id
        max_len = min(max(len(seq) for seq in all_fragments), max_model_len)

        padded_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in all_fragments]
        # A position counts as real whenever its token differs from the pad id.
        masks = [[token_id != pad_id for token_id in row] for row in padded_ids]

        input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
        seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)

        with torch.inference_mode():
            outputs = self.model(input_tensor, seq_mask)["hidden_states"]
            # [num_fragments, seq_len, hidden_size], padding rows zeroed out
            fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))

        frag_sums = fragment_embs.sum(dim=1)  # [num_fragments, hidden_size]
        # Clamp so a fragment with no real tokens divides by 1, not by 0.
        frag_lens = seq_mask.sum(dim=1, keepdim=True).clamp(min=1)
        frag_means = frag_sums / frag_lens  # [num_fragments, hidden_size]

        sentence_embs: List[Tensor] = [
            frag_means[rows].sum(dim=0).flatten()  # [hidden_size]
            for rows in fragment_rows
        ]

        if with_batch:
            return sentence_embs
        return sentence_embs[0]

    def to(self, *args, **kargs) -> Self:
        """Forward ``*args`` / ``**kargs`` to ``model.to`` and return ``self``."""
        self.model.to(*args, **kargs)
        return self
|
|
||||||
|
|
||||||
|
|
||||||
class KVCacheManager:
    """Owns a pre-allocated key/value cache pair and its sequence mask.

    Both caches have shape
    ``(batch_size, max_len, n_layers, n_kv_heads, dim // n_heads)`` and are
    allocated uninitialised; the mask has shape ``(batch_size, max_len)`` and
    starts out all ``True``.
    """

    def __init__(
        self,
        config: ModelConfig,
        batch_size: int,
        device: torch.device = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.batch_size = batch_size
        self.device = device
        self.dtype = dtype
        self.num_layers = config.n_layers
        self.max_len = config.max_len
        self.num_heads = config.n_kv_heads
        self.head_dim = config.dim // config.n_heads

        self._kv_cache: Tuple[Tensor, Tensor] = None
        self._seq_mask: Tensor = None
        self._initialize()

    def _initialize(self):
        """Allocate fresh cache tensors and an all-True sequence mask."""
        cache_shape = (
            self.batch_size,
            self.max_len,
            self.num_layers,
            self.num_heads,
            self.head_dim,
        )
        # Keys first, values second; both share shape, device and dtype.
        self._kv_cache = tuple(
            torch.empty(cache_shape, device=self.device, dtype=self.dtype)
            for _ in range(2)
        )
        self._seq_mask = torch.ones(
            (self.batch_size, self.max_len), device=self.device, dtype=torch.bool
        )

    def update(self, active_mask: Tensor):
        """Keep only the batch rows selected by ``active_mask``."""
        keys, values = self._kv_cache
        self._kv_cache = (keys[active_mask], values[active_mask])
        self._seq_mask = self._seq_mask[active_mask]

    def reset(self, full_reset=False):
        """Re-allocate the buffers, or drop them entirely when ``full_reset``."""
        if not full_reset:
            self._initialize()
            return
        self._kv_cache = None
        self._seq_mask = None

    def set_seq_mask(self, input_ids: Tensor, pad_id: int):
        """Mark non-padding positions of ``input_ids`` in the mask's top-left corner."""
        rows, cols = input_ids.shape
        self._seq_mask[:rows, :cols] = input_ids != pad_id

    def get_kvcache(self) -> Tuple[Tensor, Tensor]:
        """Return the current ``(key_cache, value_cache)`` pair."""
        return self._kv_cache

    def get_seq_mask(self) -> Tensor:
        """Return the current sequence mask."""
        return self._seq_mask
|
|
||||||
|
|
@ -0,0 +1,694 @@
|
||||||
|
"""
|
||||||
|
Continuous Batching Inference Engine
|
||||||
|
|
||||||
|
This module provides the main continuous batching components:
|
||||||
|
- Task: Individual generation task with state management
|
||||||
|
- TaskStatus: Task state enumeration
|
||||||
|
- InferenceScheduler: Handles request scheduling and KV cache management
|
||||||
|
- InferenceEngine: Unified inference engine
|
||||||
|
|
||||||
|
Author: AstrAI Team
|
||||||
|
"""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum, auto
|
||||||
|
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.config import ModelConfig, ModelParameter
|
||||||
|
from astrai.tokenize.chat_template import HistoryType, build_prompt
|
||||||
|
|
||||||
|
|
||||||
|
# Use print for debugging instead of logging
|
||||||
|
def _debug(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class GenerationRequest:
    """Request parameters for text generation.

    ``top_k``, ``top_p`` and ``temperature`` are validated on construction.
    Integer values such as ``temperature=1`` are accepted for the two
    real-valued fields and stored as ``float``; NaN is rejected.

    Raises:
        ValueError: If ``top_k`` is not a non-negative integer, ``top_p`` is
            not a real number in ``[0.0, 1.0]``, or ``temperature`` is not a
            non-negative real number.
    """

    top_k: int
    top_p: float
    temperature: float
    max_len: int

    query: Union[str, List[str]]
    history: Optional[Union[HistoryType, List[HistoryType]]] = None
    system_prompt: Optional[str] = None
    stream: bool = False

    @staticmethod
    def _is_real(value) -> bool:
        """Return True for int/float values, excluding booleans."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def __post_init__(self):
        if not isinstance(self.top_k, int) or self.top_k < 0:
            raise ValueError("top_k must be a non-negative integer")
        # The chained / positive comparisons are False for NaN, so NaN is rejected.
        if not self._is_real(self.top_p) or not 0.0 <= self.top_p <= 1.0:
            raise ValueError("top_p must be a float between 0.0 and 1.0")
        if not self._is_real(self.temperature) or not self.temperature >= 0.0:
            raise ValueError("temperature must be a non-negative float")

        # Normalise accepted integers so downstream code always sees floats.
        self.top_p = float(self.top_p)
        self.temperature = float(self.temperature)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(Enum):
    """Lifecycle states of a task under continuous batching.

    Members, in declaration order:
        PENDING: waiting to be scheduled.
        RUNNING: currently being processed.
        FINISHED: completed successfully.
        ABORTED: cancelled or failed.
    """

    # Values are written out explicitly; they match what auto() assigns (1..4).
    PENDING = 1
    RUNNING = 2
    FINISHED = 3
    ABORTED = 4
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Task:
    """One generation request tracked by the continuous-batching scheduler.

    Attributes:
        task_id: Unique task identifier.
        prompt_ids: Input token IDs.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Top-p sampling parameter.
        top_k: Top-k sampling parameter.
        status: Current task status.
        output_ids: Generated token IDs.
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens generated so far.
        slot: Batch slot position (-1 while unassigned).
        arrival_time: Timestamp taken from ``time.time`` at creation.
        finish_time: Completion timestamp, ``None`` until set.
        stream_callback: Optional callable receiving streamed text.
    """

    # Request description
    task_id: str
    prompt_ids: List[int]
    max_tokens: int = 1024
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = 50

    # Mutable progress state
    status: TaskStatus = TaskStatus.PENDING
    output_ids: List[int] = field(default_factory=list)
    input_tokens: int = 0
    output_tokens: int = 0
    slot: int = -1
    arrival_time: float = field(default_factory=time.time)
    finish_time: Optional[float] = None

    stream_callback: Optional[Callable[[str], None]] = None

    def is_finished(self, stop_ids: List[int]) -> bool:
        """Return True once a stop id was emitted or ``max_tokens`` is reached."""
        emitted_stop = bool(self.output_ids) and self.output_ids[-1] in stop_ids
        return emitted_stop or self.output_tokens >= self.max_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def apply_sampling_strategies(
    logits: Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    filter_value: float = -float("inf"),
) -> Tensor:
    """Apply temperature, top-k and top-p filtering to a logits tensor.

    All filtering is done out of place, so the caller's tensor is left intact
    and logits created under ``torch.inference_mode()`` can be processed after
    that context has been left (PyTorch forbids in-place updates to inference
    tensors outside inference mode).

    Args:
        logits: Scores with the vocabulary on the last dimension; leading
            dimensions are arbitrary.
        temperature: Divisor for the logits. ``0`` means greedy decoding: only
            the arg-max entry per row survives.
        top_k: Number of highest-scoring entries to keep; ``0`` disables.
        top_p: Nucleus threshold on cumulative probability; ``1.0`` disables.
        filter_value: Replacement for removed entries.

    Returns:
        Tensor: Filtered logits shaped like ``logits``. The input object is
        returned unchanged when no rule applies.
    """
    if temperature == 0:
        # Avoid a division by zero; keep only the best entry of each row.
        winner = logits.argmax(dim=-1, keepdim=True)
        return torch.full_like(logits, filter_value).scatter(
            -1, winner, logits.gather(-1, winner)
        )

    if temperature != 1.0:
        logits = logits / temperature

    if top_k > 0:
        k = min(top_k, logits.size(-1))
        threshold = torch.topk(logits, k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < threshold, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

        # Shifting by one keeps the token that crosses the threshold; the most
        # likely token is always kept.
        remove_sorted = (cumulative_probs > top_p).roll(shifts=1, dims=-1)
        remove_sorted[..., 0] = False

        # Undo the sort so the mask lines up with vocabulary positions.
        remove = torch.zeros_like(remove_sorted).scatter(
            -1, sorted_indices, remove_sorted
        )
        logits = logits.masked_fill(remove, filter_value)

    return logits
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceScheduler:
|
||||||
|
"""Inference scheduler with continuous batching support.
|
||||||
|
|
||||||
|
Manages request scheduling, KV cache allocation, and generation loop.
|
||||||
|
Supports dynamic batch composition where new requests can join at any time
|
||||||
|
and completed requests are immediately released.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
config: ModelConfig,
|
||||||
|
max_batch_size: int = 16,
|
||||||
|
max_seq_len: Optional[int] = None,
|
||||||
|
device: str = "cuda",
|
||||||
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.config = config
|
||||||
|
self.max_batch_size = max_batch_size
|
||||||
|
self.max_seq_len = max_seq_len or config.max_len
|
||||||
|
self.device = device
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
num_heads = config.n_kv_heads
|
||||||
|
head_dim = config.dim // config.n_heads
|
||||||
|
n_layers = config.n_layers
|
||||||
|
|
||||||
|
k_cache = torch.empty(
|
||||||
|
(
|
||||||
|
max_batch_size,
|
||||||
|
self.max_seq_len,
|
||||||
|
n_layers,
|
||||||
|
num_heads,
|
||||||
|
head_dim,
|
||||||
|
),
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
v_cache = torch.empty(
|
||||||
|
(
|
||||||
|
max_batch_size,
|
||||||
|
self.max_seq_len,
|
||||||
|
n_layers,
|
||||||
|
num_heads,
|
||||||
|
head_dim,
|
||||||
|
),
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
self.kv_cache = (k_cache, v_cache)
|
||||||
|
self.seq_mask = torch.ones(
|
||||||
|
(max_batch_size, self.max_seq_len), device=device, dtype=torch.bool
|
||||||
|
)
|
||||||
|
|
||||||
|
self.waiting_queue: List[Task] = []
|
||||||
|
self.active_tasks: List[Task] = []
|
||||||
|
|
||||||
|
self._running = False
|
||||||
|
self._task_event = threading.Event()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
self._total_tasks = 0
|
||||||
|
self._total_tokens = 0
|
||||||
|
|
||||||
|
def add_task(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
max_tokens: int = 1024,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
top_k: int = 50,
|
||||||
|
stream_callback: Optional[Callable[[str], None]] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Add a new task to the waiting queue."""
|
||||||
|
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||||
|
prompt_ids = self.tokenizer.encode(prompt)
|
||||||
|
|
||||||
|
_debug(
|
||||||
|
f"add_task: task_id={task_id}, prompt_len={len(prompt_ids)}, has_callback={stream_callback is not None}"
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
task_id=task_id,
|
||||||
|
prompt_ids=prompt_ids,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
stream_callback=stream_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
self.waiting_queue.append(task)
|
||||||
|
self._total_tasks += 1
|
||||||
|
|
||||||
|
self._task_event.set()
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
def remove_task(self, task_id: str) -> None:
|
||||||
|
"""Remove a task from the scheduler."""
|
||||||
|
with self._lock:
|
||||||
|
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
|
||||||
|
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
|
||||||
|
|
||||||
|
def _remove_finished_tasks(self) -> None:
|
||||||
|
"""Remove finished tasks from active batch and update caches."""
|
||||||
|
finished = []
|
||||||
|
for task in self.active_tasks:
|
||||||
|
if task.is_finished(self.tokenizer.stop_ids):
|
||||||
|
task.status = TaskStatus.FINISHED
|
||||||
|
task.finish_time = time.time()
|
||||||
|
finished.append(task)
|
||||||
|
self._total_tokens += task.output_tokens
|
||||||
|
|
||||||
|
for task in finished:
|
||||||
|
slot = task.slot
|
||||||
|
if slot >= 0 and slot < len(self.active_tasks):
|
||||||
|
self.seq_mask[slot, :] = False
|
||||||
|
task.slot = -1
|
||||||
|
|
||||||
|
self.active_tasks = [
|
||||||
|
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
|
||||||
|
]
|
||||||
|
|
||||||
|
def _refill_active_batch(self) -> None:
|
||||||
|
"""Refill active batch with waiting tasks."""
|
||||||
|
available_slots = self.max_batch_size - len(self.active_tasks)
|
||||||
|
if available_slots <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
to_add = []
|
||||||
|
for _ in range(min(available_slots, len(self.waiting_queue))):
|
||||||
|
if self.waiting_queue:
|
||||||
|
task = self.waiting_queue.pop(0)
|
||||||
|
task.status = TaskStatus.RUNNING
|
||||||
|
to_add.append(task)
|
||||||
|
|
||||||
|
for task in to_add:
|
||||||
|
for i in range(self.max_batch_size):
|
||||||
|
if all(t.slot != i for t in self.active_tasks):
|
||||||
|
task.slot = i
|
||||||
|
break
|
||||||
|
self.active_tasks.append(task)
|
||||||
|
|
||||||
|
def _execute_prefill(self, tasks: List[Task]) -> None:
|
||||||
|
"""Execute Prefill phase: process entire prompt at once."""
|
||||||
|
if not tasks:
|
||||||
|
return
|
||||||
|
|
||||||
|
_debug(f"_execute_prefill: processing {len(tasks)} tasks")
|
||||||
|
|
||||||
|
# Sort tasks by slot to ensure correct batch indexing with KV cache
|
||||||
|
tasks = sorted(tasks, key=lambda t: t.slot)
|
||||||
|
|
||||||
|
prompt_lens = [len(task.prompt_ids) for task in tasks]
|
||||||
|
max_len = max(prompt_lens)
|
||||||
|
|
||||||
|
input_ids = torch.zeros(
|
||||||
|
len(tasks), max_len, dtype=torch.long, device=self.device
|
||||||
|
)
|
||||||
|
for i, task in enumerate(tasks):
|
||||||
|
if len(task.prompt_ids) > 0:
|
||||||
|
input_ids[i, : len(task.prompt_ids)] = torch.tensor(
|
||||||
|
task.prompt_ids, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create boolean mask for attention
|
||||||
|
if self.tokenizer.pad_id is not None:
|
||||||
|
input_mask = torch.ne(input_ids, self.tokenizer.pad_id)
|
||||||
|
else:
|
||||||
|
input_mask = torch.ones(
|
||||||
|
input_ids.shape, dtype=torch.bool, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
_debug(
|
||||||
|
f"_execute_prefill: input_ids shape={input_ids.shape}, max_len={max_len}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with torch.inference_mode():
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids,
|
||||||
|
input_mask=input_mask,
|
||||||
|
start_pos=0,
|
||||||
|
persistent_key_values=self.kv_cache,
|
||||||
|
)
|
||||||
|
_debug(
|
||||||
|
f"_execute_prefill: model forward done, output keys={outputs.keys() if hasattr(outputs, 'keys') else 'no keys'}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
_debug(f"_execute_prefill: ERROR: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
for i, task in enumerate(tasks):
|
||||||
|
task.input_tokens = prompt_lens[i]
|
||||||
|
task.output_tokens = 0
|
||||||
|
_debug(
|
||||||
|
f" task {task.task_id}: input_tokens={task.input_tokens}, output_tokens={task.output_tokens}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for task in tasks:
|
||||||
|
if task.slot >= 0:
|
||||||
|
self.seq_mask[task.slot, : task.input_tokens] = True
|
||||||
|
|
||||||
|
_debug(f"_execute_prefill: done, {len(tasks)} tasks marked as prefill complete")
|
||||||
|
|
||||||
|
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
||||||
|
"""Execute Decode phase: generate one token at a time."""
|
||||||
|
if not tasks:
|
||||||
|
return
|
||||||
|
|
||||||
|
_debug(f"_execute_decode: processing {len(tasks)} tasks, start_pos={start_pos}")
|
||||||
|
|
||||||
|
# Sort tasks by slot to ensure batch index aligns with slot (KV cache position)
|
||||||
|
# Task at slot 0 → batch index 0 → KV stored at cache[0]
|
||||||
|
# Task at slot 1 → batch index 1 → KV stored at cache[1]
|
||||||
|
tasks = sorted(tasks, key=lambda t: t.slot)
|
||||||
|
|
||||||
|
input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device)
|
||||||
|
for i, task in enumerate(tasks):
|
||||||
|
if task.output_ids:
|
||||||
|
input_ids[i] = task.output_ids[-1]
|
||||||
|
else:
|
||||||
|
input_ids[i] = task.prompt_ids[-1]
|
||||||
|
|
||||||
|
input_tensor = input_ids.unsqueeze(1) # shape: (batch, 1)
|
||||||
|
|
||||||
|
# Create 2D attention mask: (batch, seq_len)
|
||||||
|
active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device)
|
||||||
|
_debug(
|
||||||
|
f"_execute_decode: input_tensor shape={input_tensor.shape}, active_mask shape={active_mask.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with torch.inference_mode():
|
||||||
|
outputs = self.model(
|
||||||
|
input_tensor,
|
||||||
|
input_mask=active_mask,
|
||||||
|
persistent_key_values=self.kv_cache,
|
||||||
|
start_pos=start_pos,
|
||||||
|
)
|
||||||
|
_debug(
|
||||||
|
f"_execute_decode: model forward done, logits shape={outputs['logits'].shape}"
|
||||||
|
)
|
||||||
|
logits = outputs["logits"][:, -1, :]
|
||||||
|
except Exception as e:
|
||||||
|
_debug(f"_execute_decode: ERROR: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
next_token_ids = []
|
||||||
|
for i, task in enumerate(tasks):
|
||||||
|
logit = logits[i : i + 1]
|
||||||
|
logit = apply_sampling_strategies(
|
||||||
|
logit,
|
||||||
|
task.temperature,
|
||||||
|
task.top_k,
|
||||||
|
task.top_p,
|
||||||
|
)
|
||||||
|
probs = torch.softmax(logit, dim=-1)
|
||||||
|
next_token = torch.multinomial(probs, num_samples=1)
|
||||||
|
next_token_ids.append(next_token.item())
|
||||||
|
|
||||||
|
_debug(f"_execute_decode: next_tokens={next_token_ids}")
|
||||||
|
|
||||||
|
for task, next_token in zip(tasks, next_token_ids):
|
||||||
|
task.output_ids.append(next_token)
|
||||||
|
task.output_tokens += 1
|
||||||
|
|
||||||
|
pos = task.input_tokens + task.output_tokens
|
||||||
|
if task.slot >= 0 and pos < self.max_seq_len:
|
||||||
|
self.seq_mask[task.slot, pos] = True
|
||||||
|
|
||||||
|
if task.stream_callback:
|
||||||
|
token_str = self.tokenizer.decode([next_token])
|
||||||
|
task.stream_callback(token_str)
|
||||||
|
|
||||||
|
# Check if any task reached max_tokens or stop token
|
||||||
|
for task in tasks:
|
||||||
|
if task.output_tokens >= task.max_tokens or (
|
||||||
|
task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids
|
||||||
|
):
|
||||||
|
_debug(
|
||||||
|
f"decode: task {task.task_id} finished, output_tokens={task.output_tokens}, max_tokens={task.max_tokens}"
|
||||||
|
)
|
||||||
|
if task.stream_callback:
|
||||||
|
task.stream_callback("[DONE]")
|
||||||
|
|
||||||
|
def _run_generation_loop(self) -> None:
|
||||||
|
"""Main generation loop with continuous batching."""
|
||||||
|
_debug("generation_loop: started")
|
||||||
|
while self._running:
|
||||||
|
self._remove_finished_tasks()
|
||||||
|
self._refill_active_batch()
|
||||||
|
|
||||||
|
if not self.active_tasks:
|
||||||
|
self._task_event.wait(timeout=0.01)
|
||||||
|
self._task_event.clear()
|
||||||
|
continue
|
||||||
|
|
||||||
|
_debug(
|
||||||
|
f"generation_loop: active={len(self.active_tasks)}, waiting={len(self.waiting_queue)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
new_tasks = [t for t in self.active_tasks if t.output_tokens == 0]
|
||||||
|
decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0]
|
||||||
|
|
||||||
|
_debug(
|
||||||
|
f"generation_loop: new_tasks={len(new_tasks)}, decode_tasks={len(decode_tasks)}"
|
||||||
|
)
|
||||||
|
for t in self.active_tasks:
|
||||||
|
_debug(
|
||||||
|
f" active task {t.task_id}: output_tokens={t.output_tokens}, input_tokens={t.input_tokens}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if decode_tasks:
|
||||||
|
start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks)
|
||||||
|
else:
|
||||||
|
start_pos = 0
|
||||||
|
|
||||||
|
# First run prefill for new tasks
|
||||||
|
if new_tasks:
|
||||||
|
_debug(f"generation_loop: running prefill for {len(new_tasks)} tasks")
|
||||||
|
self._execute_prefill(new_tasks)
|
||||||
|
_debug(f"generation_loop: prefill done")
|
||||||
|
|
||||||
|
# After prefill, convert these tasks to decode tasks in the same iteration
|
||||||
|
decode_tasks = new_tasks
|
||||||
|
start_pos = max(t.input_tokens for t in decode_tasks)
|
||||||
|
_debug(
|
||||||
|
f"generation_loop: after prefill, decode_tasks={len(decode_tasks)}, start_pos={start_pos}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if decode_tasks:
|
||||||
|
_debug(
|
||||||
|
f"generation_loop: running decode for {len(decode_tasks)} tasks, start_pos={start_pos}"
|
||||||
|
)
|
||||||
|
self._execute_decode(decode_tasks, start_pos)
|
||||||
|
_debug(f"generation_loop: decode done")
|
||||||
|
|
||||||
|
if not self.active_tasks and not self.waiting_queue:
|
||||||
|
time.sleep(0.001)
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
"""Start the generation loop in a background thread."""
|
||||||
|
if not self._running:
|
||||||
|
_debug("InferenceScheduler.start: starting loop thread")
|
||||||
|
self._running = True
|
||||||
|
self._loop_thread = threading.Thread(target=self._run_generation_loop)
|
||||||
|
self._loop_thread.daemon = True
|
||||||
|
self._loop_thread.start()
|
||||||
|
_debug("InferenceScheduler.start: loop thread started")
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""Stop the generation loop."""
|
||||||
|
self._running = False
|
||||||
|
if hasattr(self, "_loop_thread"):
|
||||||
|
self._loop_thread.join(timeout=1.0)
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
    """Return the scheduler counters and current queue sizes as a dict.

    Keys: "total_tasks", "total_tokens", "active_tasks" (length of
    active_tasks) and "waiting_queue" (length of waiting_queue).
    """
    stats: Dict[str, Any] = {}
    stats["total_tasks"] = self._total_tasks
    stats["total_tokens"] = self._total_tokens
    stats["active_tasks"] = len(self.active_tasks)
    stats["waiting_queue"] = len(self.waiting_queue)
    return stats
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceEngine:
    """Unified inference engine for continuous batching.

    Provides a single interface for:
    - Single request generation (streaming or non-streaming)
    - Batch request generation (non-streaming only; batch streaming raises
      NotImplementedError)

    Constructing an engine builds an InferenceScheduler and immediately calls
    its start(); call shutdown() to stop it again.

    Tokens are received through per-task callbacks registered with the
    scheduler. The engine treats the literal string "[DONE]" passed to a
    callback as the end-of-generation marker.
    """

    def __init__(
        self,
        parameter: ModelParameter,
        max_batch_size: int = 16,
        max_seq_len: Optional[int] = None,
    ):
        """Create the scheduler for the given model and start it.

        Args:
            parameter: Bundle providing the model, tokenizer and config.
            max_batch_size: Forwarded to the scheduler.
            max_seq_len: Forwarded to the scheduler.
        """
        self.model = parameter.model
        self.tokenizer = parameter.tokenizer
        self.config = parameter.config

        # Device and dtype are taken from the model's first parameter.
        model_params = next(self.model.parameters())
        self.device = model_params.device
        self.dtype = model_params.dtype

        self.scheduler = InferenceScheduler(
            model=self.model,
            tokenizer=self.tokenizer,
            config=self.config,
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_len,
            device=self.device,
            dtype=self.dtype,
        )

        # Expose the scheduler's cache objects on the engine as well.
        self.kv_cache = self.scheduler.kv_cache
        self.seq_mask = self.scheduler.seq_mask

        _debug("InferenceEngine: starting scheduler")
        self.scheduler.start()
        _debug("InferenceEngine: scheduler started")

    def generate(
        self,
        prompt: Union[str, List[str]],
        stream: bool = False,
        max_tokens: int = 1024,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
    ) -> Union[Generator[str, None, None], str, List[str]]:
        """Unified generation interface.

        Args:
            prompt: One prompt string, or a list of prompts for a batch.
            stream: When True, return a generator of tokens instead of text.
            max_tokens: Forwarded to the scheduler task.
            temperature: Forwarded to the scheduler task.
            top_p: Forwarded to the scheduler task.
            top_k: Forwarded to the scheduler task.

        Returns:
            A token generator when streaming; otherwise the generated text
            (a list of texts when ``prompt`` was a list).

        Raises:
            NotImplementedError: If ``stream`` is True and ``prompt`` is a list.
        """
        is_batch = isinstance(prompt, list)
        prompts = prompt if is_batch else [prompt]

        if stream:
            return self._generate_streaming(
                prompts, is_batch, max_tokens, temperature, top_p, top_k
            )
        else:
            return self._generate_non_streaming(
                prompts, is_batch, max_tokens, temperature, top_p, top_k
            )

    def generate_with_request(
        self, request: GenerationRequest
    ) -> Union[Generator[str, None, None], str, List[str]]:
        """Generate with GenerationRequest object.

        The prompt is built from ``request.query`` and ``request.history``;
        ``request.max_len`` is passed on as ``max_tokens``. Note that
        ``request.system_prompt`` is not forwarded here.
        """
        prompt = build_prompt(request.query, request.history)

        return self.generate(
            prompt=prompt,
            stream=request.stream,
            max_tokens=request.max_len,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
        )

    def _generate_streaming(
        self,
        prompts: List[str],
        is_batch: bool,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
        """Generate with streaming output (synchronous).

        Returns a generator that yields each token handed to the task callback
        and finishes when the "[DONE]" marker arrives.

        Raises:
            NotImplementedError: If ``is_batch`` is True.
        """
        # Imported locally so this block does not depend on the file-level
        # import list.
        import queue

        _debug(f"_generate_streaming: prompts={len(prompts)}")

        if is_batch:
            raise NotImplementedError("Batch streaming is not implemented yet")

        # A blocking FIFO replaces the previous list + pop(0) + sleep polling:
        # the consumer wakes as soon as a token is available and each dequeue
        # is O(1).
        pending = queue.Queue()

        def make_callback(idx: int):
            def cb(token: str):
                _debug(f"callback[{idx}]: token={token!r}")
                pending.put(token)

            return cb

        for i, p in enumerate(prompts):
            _debug(f"_generate_streaming: adding task {i}: {p[:30]}...")
            self.scheduler.add_task(
                prompt=p,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stream_callback=make_callback(i),
            )

        def gen():
            _debug("generator: start yielding")
            while True:
                token = pending.get()  # blocks until the callback delivers
                if token == "[DONE]":
                    _debug("generator: got [DONE]")
                    return
                _debug(f"generator: yielding {token!r}")
                yield token

        return gen()

    def _generate_non_streaming(
        self,
        prompts: List[str],
        is_batch: bool,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Union[str, List[str]]:
        """Generate without streaming.

        Submits one task per prompt, blocks until every task has signalled
        "[DONE]", and returns the concatenated tokens (a list when
        ``is_batch`` is True, otherwise the single string).
        """
        # Tokens are collected per prompt and joined once at the end instead
        # of growing a string with repeated "+=".
        pieces: List[List[str]] = [[] for _ in prompts]
        # One event per prompt lets the caller sleep until completion rather
        # than polling a flag list in a sleep loop.
        finished = [threading.Event() for _ in prompts]

        def make_callback(idx: int):
            def cb(token: str):
                if token == "[DONE]":
                    finished[idx].set()
                else:
                    # list.append is atomic, so no extra lock is needed.
                    pieces[idx].append(token)

            return cb

        for i, p in enumerate(prompts):
            self.scheduler.add_task(
                prompt=p,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stream_callback=make_callback(i),
            )

        for event in finished:
            event.wait()

        results = ["".join(tokens) for tokens in pieces]
        return results if is_batch else results[0]

    def get_stats(self) -> Dict[str, Any]:
        """Get engine statistics (delegates to the scheduler)."""
        return self.scheduler.get_stats()

    def shutdown(self) -> None:
        """Shutdown the engine by stopping the scheduler."""
        self.scheduler.stop()
|
||||||
|
|
@ -1,269 +0,0 @@
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Generator, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
|
||||||
from astrai.factory import BaseFactory
|
|
||||||
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
|
|
||||||
from astrai.tokenize.chat_template import HistoryType, build_prompt
|
|
||||||
|
|
||||||
|
|
||||||
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
    """
    Left-pad every sequence to the length of the longest one.

    Padding ids are inserted in front of each sequence, so the original
    tokens stay right-aligned. The input lists are not modified; new lists
    are returned.

    Args:
        ids_list (List[List[int]]): A list of sequences.
        pad_id (int): The id to pad sequences.

    Returns:
        Tuple[List[List[int]], int]: The padded sequences and the common
        (maximum) sequence length. An empty ``ids_list`` yields ``([], 0)``.
    """
    # max() raises on an empty iterable, so handle "no sequences" explicitly.
    if not ids_list:
        return [], 0

    max_ids_len = max(len(ids) for ids in ids_list)
    padded = [[pad_id] * (max_ids_len - len(ids)) + ids for ids in ids_list]
    return padded, max_ids_len
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class GenerationRequest:
    """
    Request parameters for text generation.

    Attributes:
        top_k: Top-k sampling parameter (non-negative integer).
        top_p: Top-p (nucleus) sampling parameter, between 0.0 and 1.0.
        temperature: Sampling temperature (non-negative).
        max_len: Maximum generation length.
        query: Input query (string or list of strings for batch).
        history: Conversation history.
        system_prompt: System prompt for the conversation.
        stream: Whether to use streaming generation.

    Raises:
        ValueError: From ``__post_init__`` when ``top_k``, ``top_p`` or
            ``temperature`` has the wrong type or is out of range.
    """

    top_k: int
    top_p: float
    temperature: float
    max_len: int

    query: Union[str, List[str]]
    history: Optional[Union[HistoryType, List[HistoryType]]] = None
    system_prompt: Optional[str] = None
    stream: bool = False

    def __post_init__(self):
        # Promote plain integers (e.g. temperature=1, top_p=1) to float so
        # whole-number values are not rejected by the strict float checks
        # below. bool is excluded: it is an int subclass but not a number
        # a caller would mean here.
        for name in ("top_p", "temperature"):
            value = getattr(self, name)
            if isinstance(value, int) and not isinstance(value, bool):
                setattr(self, name, float(value))

        if not isinstance(self.top_k, int) or self.top_k < 0:
            raise ValueError("top_k must be a non-negative integer")
        if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0:
            raise ValueError("top_p must be a float between 0.0 and 1.0")
        if not isinstance(self.temperature, float) or self.temperature < 0.0:
            raise ValueError("temperature must be a non-negative float")
|
|
||||||
|
|
||||||
|
|
||||||
class LoopGenerator(GeneratorCore):
    """Non-streaming generator for a single query."""

    def __init__(self, parameter: ModelParameter):
        super().__init__(parameter)

    def generate(self, request: GenerationRequest) -> str:
        """Generate the complete response text for one request.

        The prompt is built from ``request.query`` and ``request.history``,
        encoded, and handed to ``generate_loop`` together with a freshly
        created KV cache. Only the ids from the prompt length onward are
        decoded and returned.
        """
        first_param = next(self.model.parameters())
        device, dtype = first_param.device, first_param.dtype
        # Second argument is 1 here; BatchGenerator passes its batch size in
        # the same position.
        cache_manager = KVCacheManager(self.config, 1, device=device, dtype=dtype)

        prompt_ids = self.tokenizer.encode(build_prompt(request.query, request.history))
        input_ids = torch.tensor([prompt_ids], device=device, dtype=torch.long)
        prompt_len = len(prompt_ids)

        self.model.eval()
        kv_caches = cache_manager.get_kvcache()

        output_ids = self.generate_loop(
            input_ids,
            prompt_ids,
            request.temperature,
            request.top_k,
            request.top_p,
            kv_caches=kv_caches,
        )

        # Skip the prompt prefix so that only the continuation is decoded.
        return self.tokenizer.decode(output_ids[prompt_len:])
|
|
||||||
|
|
||||||
|
|
||||||
class StreamGenerator(GeneratorCore):
    """Streaming generator for a single query."""

    def __init__(self, parameter: ModelParameter):
        super().__init__(parameter)

    def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
        """Yield the cumulative response text after every generated token.

        Each yielded string is the decoding of all ids produced so far (not
        just the newest token). When a token from ``tokenizer.stop_ids`` is
        produced, the same text plus a trailing newline is yielded once more
        and generation ends. The loop otherwise runs until the total id count
        reaches ``config.max_len``.
        """
        first_param = next(self.model.parameters())
        device, dtype = first_param.device, first_param.dtype
        cache_manager = KVCacheManager(self.config, 1, device=device, dtype=dtype)

        ids = self.tokenizer.encode(build_prompt(request.query, request.history))
        input_ids = torch.tensor([ids], device=device, dtype=torch.long)

        prompt_len = len(ids)
        cache_pos = 0
        self.model.eval()
        kv_caches = cache_manager.get_kvcache()

        for _ in range(prompt_len, self.config.max_len):
            next_token_id, cache_increase = self.generate_iterator(
                input_ids,
                request.temperature,
                request.top_k,
                request.top_p,
                kv_caches=kv_caches,
                start_pos=cache_pos,
            )

            # The freshly produced token becomes the input of the next step.
            token = next_token_id.item()
            input_ids = next_token_id
            ids.append(token)
            cache_pos += cache_increase

            response = self.tokenizer.decode(ids[prompt_len:])
            yield response

            if token in self.tokenizer.stop_ids:
                yield response + "\n"
                break
|
|
||||||
|
|
||||||
|
|
||||||
class BatchGenerator(GeneratorCore):
    """Non-streaming generator for a list of queries processed together."""

    def __init__(self, parameter: ModelParameter):
        super().__init__(parameter)

    def generate(self, request: GenerationRequest) -> List[str]:
        """Generate one response per query in ``request.query``.

        Prompts are encoded and padded to a common length with
        ``pad_sequence``. Each step feeds only the still-active rows back to
        the model; a row becomes inactive once it produces a token from
        ``tokenizer.stop_ids``.

        Side effect: ``request.history`` is mutated. It is replaced by a list
        of empty histories when it is None, and each ``(query, response)``
        pair is appended to the matching history.
        """
        batch_size = len(request.query)
        if request.history is None:
            request.history = [[] for _ in range(batch_size)]

        prompts = [build_prompt(q, h) for q, h in zip(request.query, request.history)]
        encoded = [self.tokenizer.encode(text) for text in prompts]
        ids_list, max_ids_len = pad_sequence(encoded, self.tokenizer.pad_id)

        first_param = next(self.model.parameters())
        device, dtype = first_param.device, first_param.dtype
        cache_manager = KVCacheManager(
            self.config, batch_size, device=device, dtype=dtype
        )

        input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
        cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
        still_running = [True] * batch_size

        prompt_len = max_ids_len  # responses are decoded from this offset
        cache_pos = 0

        while max_ids_len < self.config.max_len and any(still_running):
            kv_caches = cache_manager.get_kvcache()
            attn_mask = cache_manager.get_seq_mask()

            next_token_id, cache_increase = self.generate_iterator(
                input_tensor,
                request.temperature,
                request.top_k,
                request.top_p,
                attn_mask=attn_mask,
                kv_caches=kv_caches,
                start_pos=cache_pos,
            )
            cache_pos += cache_increase

            # Row r of next_token_id is read for the r-th still-active task.
            active_indices = [i for i, alive in enumerate(still_running) if alive]
            keep_flags = []
            for row, i in enumerate(active_indices):
                token = next_token_id[row, :].item()
                ids_list[i].append(token)

                alive = token not in self.tokenizer.stop_ids
                still_running[i] = alive
                keep_flags.append(alive)

            keep_mask = torch.tensor(keep_flags, device=device, dtype=torch.bool)
            cache_manager.update(keep_mask)
            # Only rows that did not just finish are fed into the next step.
            input_tensor = next_token_id[keep_mask, :]

            max_ids_len += 1

        responses = []
        for i in range(batch_size):
            text = self.tokenizer.decode(ids_list[i][prompt_len:])
            responses.append(text)
            request.history[i].append((request.query[i], text))

        return responses
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingEncoder(EmbeddingEncoderCore):
    """Public embedding encoder; a thin subclass of EmbeddingEncoderCore."""

    def __init__(self, parameter: ModelParameter):
        super().__init__(parameter)

    def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
        """Encode a sentence or list of sentences by delegating to the base class."""
        embeddings = super().encode(sentence)
        return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
class GeneratorFactory(BaseFactory[GeneratorCore]):
    """Factory that picks a generator class from the shape of a request.

    Selection order:
        1. ``request.stream`` is truthy   -> StreamGenerator
        2. ``request.query`` is a list    -> BatchGenerator
        3. anything else                  -> LoopGenerator

    The stream flag is checked first, so a list query with ``stream=True``
    still receives a StreamGenerator.

    Example usage:
        generator = GeneratorFactory.create(parameter, request)
        result = generator.generate(request)
    """

    @staticmethod
    def create(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
        """Create a generator suited to ``request``.

        Args:
            parameter: Model parameters containing model, tokenizer, config
            request: Generation request with query, options, etc.

        Returns:
            An instance of the selected GeneratorCore subclass.
        """
        if request.stream:
            generator_cls = StreamGenerator
        elif isinstance(request.query, list):
            generator_cls = BatchGenerator
        else:
            generator_cls = LoopGenerator
        return generator_cls(parameter)

    @staticmethod
    def create_encoder(parameter: ModelParameter) -> EmbeddingEncoderCore:
        """Create an embedding encoder instance.

        Args:
            parameter: Model parameters

        Returns:
            A new EmbeddingEncoder built from ``parameter``.
        """
        encoder = EmbeddingEncoder(parameter)
        return encoder
|
|
||||||
|
|
@ -1,3 +1,13 @@
|
||||||
|
"""
|
||||||
|
Inference Server with Continuous Batching Support
|
||||||
|
|
||||||
|
FastAPI server for inference with continuous batching.
|
||||||
|
Provides OpenAI-compatible chat completion endpoints.
|
||||||
|
|
||||||
|
Author: AstrAI Team
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@ -10,12 +20,13 @@ from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
from astrai.inference.generator import GenerationRequest, GeneratorFactory
|
from astrai.inference.engine import GenerationRequest, InferenceEngine
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Global model parameter (loaded once)
|
# Global model parameter and engine (loaded once)
|
||||||
_model_param: Optional[ModelParameter] = None
|
_model_param: Optional[ModelParameter] = None
|
||||||
|
_engine: Optional[InferenceEngine] = None
|
||||||
_project_root = Path(__file__).parent.parent.parent
|
_project_root = Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
# Server configuration (set before running server)
|
# Server configuration (set before running server)
|
||||||
|
|
@ -23,6 +34,7 @@ _server_config: Dict[str, Any] = {
|
||||||
"device": "cuda",
|
"device": "cuda",
|
||||||
"dtype": torch.bfloat16,
|
"dtype": torch.bfloat16,
|
||||||
"param_path": None,
|
"param_path": None,
|
||||||
|
"max_batch_size": 16,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -30,6 +42,7 @@ def configure_server(
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
param_path: Optional[Path] = None,
|
param_path: Optional[Path] = None,
|
||||||
|
max_batch_size: int = 16,
|
||||||
):
|
):
|
||||||
"""Configure server settings before starting.
|
"""Configure server settings before starting.
|
||||||
|
|
||||||
|
|
@ -37,40 +50,47 @@ def configure_server(
|
||||||
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
|
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
|
||||||
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
|
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
|
||||||
param_path: Path to model parameters directory
|
param_path: Path to model parameters directory
|
||||||
|
max_batch_size: Maximum batch size for continuous batching
|
||||||
"""
|
"""
|
||||||
_server_config["device"] = device
|
_server_config["device"] = device
|
||||||
_server_config["dtype"] = dtype
|
_server_config["dtype"] = dtype
|
||||||
_server_config["param_path"] = param_path
|
_server_config["param_path"] = param_path
|
||||||
|
_server_config["max_batch_size"] = max_batch_size
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""Lifespan context manager for startup and shutdown events."""
|
"""Lifespan context manager for startup and shutdown events."""
|
||||||
|
global _model_param, _engine
|
||||||
# Startup: Load model with configured settings
|
# Startup: Load model with configured settings
|
||||||
try:
|
try:
|
||||||
load_model(
|
load_model(
|
||||||
param_path=_server_config["param_path"],
|
param_path=_server_config["param_path"],
|
||||||
device=_server_config["device"],
|
device=_server_config["device"],
|
||||||
dtype=_server_config["dtype"],
|
dtype=_server_config["dtype"],
|
||||||
|
max_batch_size=_server_config["max_batch_size"],
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load model: {e}")
|
logger.error(f"Failed to load model: {e}")
|
||||||
raise
|
raise
|
||||||
yield
|
yield
|
||||||
# Shutdown: Cleanup if needed
|
# Shutdown: Cleanup engine
|
||||||
pass
|
if _engine:
|
||||||
|
_engine.shutdown()
|
||||||
|
logger.info("Inference engine shutdown complete")
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="AstrAI Inference Server", version="0.1.0", lifespan=lifespan)
|
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
param_path: Optional[Path] = None,
|
param_path: Optional[Path] = None,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
max_batch_size: int = 16,
|
||||||
):
|
):
|
||||||
"""Load model parameters into global variable."""
|
"""Load model parameters and initialize inference engine."""
|
||||||
global _model_param
|
global _model_param, _engine
|
||||||
if param_path is None:
|
if param_path is None:
|
||||||
param_path = _project_root / "params"
|
param_path = _project_root / "params"
|
||||||
if not param_path.exists():
|
if not param_path.exists():
|
||||||
|
|
@ -79,6 +99,13 @@ def load_model(
|
||||||
_model_param.to(device=device, dtype=dtype)
|
_model_param.to(device=device, dtype=dtype)
|
||||||
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
||||||
|
|
||||||
|
# Initialize inference engine with continuous batching
|
||||||
|
_engine = InferenceEngine(
|
||||||
|
parameter=_model_param,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
)
|
||||||
|
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
||||||
|
|
||||||
|
|
||||||
# Pydantic models for API request/response
|
# Pydantic models for API request/response
|
||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
|
|
@ -134,54 +161,77 @@ def convert_messages_to_history(
|
||||||
assistant_buffer.append(msg.content)
|
assistant_buffer.append(msg.content)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Unknown role {msg.role}")
|
logger.warning(f"Unknown role {msg.role}")
|
||||||
# If there is a pending user message without assistant, treat as current query
|
|
||||||
# We'll handle this later
|
|
||||||
return system_prompt, history if history else None
|
return system_prompt, history if history else None
|
||||||
|
|
||||||
|
|
||||||
|
def convert_messages_to_prompt(messages: List[ChatMessage]) -> str:
    """Build a single prompt string from a list of chat messages.

    The content of the last message whose role is "user" is used as the
    query; the history comes from ``convert_messages_to_history``. The system
    prompt returned by that helper is not used here.

    Args:
        messages: List of ChatMessage objects

    Returns:
        str: Formatted prompt string

    Raises:
        ValueError: If no message has the role "user".
    """
    _, history = convert_messages_to_history(messages)

    user_contents = [msg.content for msg in messages if msg.role == "user"]
    if len(user_contents) == 0:
        raise ValueError("No user message found")
    latest_query = user_contents[-1]

    # Imported at call time, as in the original code.
    from astrai.tokenize.chat_template import build_prompt

    prompt = build_prompt(latest_query, history)
    return prompt
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health():
|
async def health():
|
||||||
return {"status": "ok", "model_loaded": _model_param is not None}
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"model_loaded": _model_param is not None,
|
||||||
|
"engine_ready": _engine is not None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/stats")
async def get_stats():
    """Return the inference engine's statistics.

    Responds with HTTP 503 while the global engine has not been created.
    """
    engine = _engine
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    return engine.get_stats()
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions", response_model=CompletionResponse)
|
@app.post("/v1/chat/completions", response_model=CompletionResponse)
|
||||||
async def chat_completion(request: ChatCompletionRequest):
|
async def chat_completion(request: ChatCompletionRequest):
|
||||||
"""OpenAI‑compatible chat completion endpoint.
|
"""OpenAI-compatible chat completion endpoint.
|
||||||
|
|
||||||
Supports both streaming and non‑streaming modes.
|
Supports both streaming and non-streaming modes with continuous batching.
|
||||||
"""
|
"""
|
||||||
if _model_param is None:
|
if _engine is None:
|
||||||
raise HTTPException(status_code=503, detail="Model not loaded")
|
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||||
# Convert messages to query/history
|
|
||||||
# For simplicity, assume the last user message is the query, previous messages are history
|
|
||||||
system_prompt, history = convert_messages_to_history(request.messages)
|
|
||||||
# Extract last user message as query
|
|
||||||
user_messages = [m.content for m in request.messages if m.role == "user"]
|
|
||||||
if not user_messages:
|
|
||||||
raise HTTPException(status_code=400, detail="No user message found")
|
|
||||||
query = user_messages[-1]
|
|
||||||
# If there are multiple user messages, we could merge them, but for demo we keep simple
|
|
||||||
|
|
||||||
gen_request = GenerationRequest(
|
# Convert messages to prompt
|
||||||
query=query,
|
prompt = convert_messages_to_prompt(request.messages)
|
||||||
temperature=request.temperature,
|
|
||||||
top_p=request.top_p,
|
|
||||||
top_k=request.top_k,
|
|
||||||
max_len=request.max_tokens,
|
|
||||||
history=history,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
stream=request.stream,
|
|
||||||
)
|
|
||||||
|
|
||||||
if request.stream:
|
if request.stream:
|
||||||
# Return streaming response
|
# Streaming response (use synchronous generator)
|
||||||
|
generator = _engine.generate(
|
||||||
|
prompt=prompt,
|
||||||
|
stream=True,
|
||||||
|
max_tokens=request.max_tokens,
|
||||||
|
temperature=request.temperature,
|
||||||
|
top_p=request.top_p,
|
||||||
|
top_k=request.top_k,
|
||||||
|
)
|
||||||
|
|
||||||
def generate_stream():
|
def generate_stream():
|
||||||
generator = GeneratorFactory.create(_model_param, gen_request)
|
for token in generator:
|
||||||
for chunk in generator.generate(gen_request):
|
if token == "[DONE]":
|
||||||
# chunk is the cumulative response string
|
break
|
||||||
# For OpenAI compatibility, we send incremental delta
|
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
|
||||||
# For simplicity, we send the whole chunk each time
|
|
||||||
yield f"data: {chunk}\n\n"
|
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
|
|
@ -190,13 +240,17 @@ async def chat_completion(request: ChatCompletionRequest):
|
||||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Non‑streaming
|
# Non-streaming response
|
||||||
generator = GeneratorFactory.create(_model_param, gen_request)
|
result = _engine.generate(
|
||||||
if gen_request.stream:
|
prompt=prompt,
|
||||||
# Should not happen because we set stream=False
|
stream=False,
|
||||||
pass
|
max_tokens=request.max_tokens,
|
||||||
response_text = generator.generate(gen_request)
|
temperature=request.temperature,
|
||||||
# Build OpenAI‑style response
|
top_p=request.top_p,
|
||||||
|
top_k=request.top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build OpenAI-style response
|
||||||
import time
|
import time
|
||||||
|
|
||||||
resp = CompletionResponse(
|
resp = CompletionResponse(
|
||||||
|
|
@ -205,7 +259,7 @@ async def chat_completion(request: ChatCompletionRequest):
|
||||||
choices=[
|
choices=[
|
||||||
{
|
{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"message": {"role": "assistant", "content": response_text},
|
"message": {"role": "assistant", "content": result},
|
||||||
"finish_reason": "stop",
|
"finish_reason": "stop",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|
@ -223,35 +277,58 @@ async def generate(
|
||||||
max_len: int = 2048,
|
max_len: int = 2048,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
"""Simple generation endpoint compatible with existing GenerationRequest."""
|
"""Simple generation endpoint.
|
||||||
if _model_param is None:
|
|
||||||
raise HTTPException(status_code=503, detail="Model not loaded")
|
Args:
|
||||||
|
query: Input query string
|
||||||
|
history: Conversation history as list of [user, assistant] pairs
|
||||||
|
temperature: Sampling temperature
|
||||||
|
top_p: Top-p sampling parameter
|
||||||
|
top_k: Top-k sampling parameter
|
||||||
|
max_len: Maximum tokens to generate
|
||||||
|
stream: Enable streaming output
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Generation result with response field
|
||||||
|
"""
|
||||||
|
if _engine is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||||
|
|
||||||
# Convert history format
|
# Convert history format
|
||||||
hist: Optional[List[Tuple[str, str]]] = None
|
hist: Optional[List[Tuple[str, str]]] = None
|
||||||
if history:
|
if history:
|
||||||
hist = [
|
hist = [(h[0], h[1]) for h in history]
|
||||||
(h[0], h[1]) for h in history
|
|
||||||
] # assuming each item is [user, assistant]
|
# Build prompt
|
||||||
gen_request = GenerationRequest(
|
from astrai.tokenize.chat_template import build_prompt
|
||||||
query=query,
|
|
||||||
temperature=temperature,
|
prompt = build_prompt(query, hist)
|
||||||
top_p=top_p,
|
|
||||||
top_k=top_k,
|
|
||||||
max_len=max_len,
|
|
||||||
history=hist,
|
|
||||||
stream=stream,
|
|
||||||
)
|
|
||||||
if stream:
|
if stream:
|
||||||
|
# Synchronous streaming
|
||||||
|
result = _engine.generate(
|
||||||
|
prompt=prompt,
|
||||||
|
stream=True,
|
||||||
|
max_tokens=max_len,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
)
|
||||||
|
|
||||||
def stream_generator():
|
def stream_generator():
|
||||||
generator = GeneratorFactory.create(_model_param, gen_request)
|
for token in result:
|
||||||
for chunk in generator.generate(gen_request):
|
yield token + "\n"
|
||||||
yield chunk + "\n"
|
|
||||||
|
|
||||||
return StreamingResponse(stream_generator(), media_type="text/plain")
|
return StreamingResponse(stream_generator(), media_type="text/plain")
|
||||||
else:
|
else:
|
||||||
generator = GeneratorFactory.create(_model_param, gen_request)
|
result = _engine.generate(
|
||||||
result = generator.generate(gen_request)
|
prompt=prompt,
|
||||||
|
stream=False,
|
||||||
|
max_tokens=max_len,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
)
|
||||||
return {"response": result}
|
return {"response": result}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -262,6 +339,7 @@ def run_server(
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
param_path: Optional[Path] = None,
|
param_path: Optional[Path] = None,
|
||||||
|
max_batch_size: int = 16,
|
||||||
):
|
):
|
||||||
"""Run the FastAPI server with uvicorn.
|
"""Run the FastAPI server with uvicorn.
|
||||||
|
|
||||||
|
|
@ -272,6 +350,17 @@ def run_server(
|
||||||
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
|
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
|
||||||
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
|
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
|
||||||
param_path: Path to model parameters directory
|
param_path: Path to model parameters directory
|
||||||
|
max_batch_size: Maximum batch size for continuous batching
|
||||||
"""
|
"""
|
||||||
configure_server(device=device, dtype=dtype, param_path=param_path)
|
configure_server(
|
||||||
uvicorn.run("astrai.inference.server:app", host=host, port=port, reload=reload)
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
param_path=param_path,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
)
|
||||||
|
uvicorn.run(
|
||||||
|
"astrai.inference.server:app",
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
reload=reload,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
from astrai.inference.generator import GenerationRequest, GeneratorFactory
|
from astrai.inference import InferenceEngine
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
@ -15,17 +15,15 @@ def generate_text():
|
||||||
|
|
||||||
query = input(">> ")
|
query = input(">> ")
|
||||||
|
|
||||||
request = GenerationRequest(
|
engine = InferenceEngine(param)
|
||||||
query=query,
|
response = engine.generate(
|
||||||
|
prompt=query,
|
||||||
|
stream=False,
|
||||||
|
max_tokens=param.config.max_len,
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
max_len=param.config.max_len,
|
|
||||||
history=None,
|
|
||||||
system_prompt=None,
|
|
||||||
)
|
)
|
||||||
generator = GeneratorFactory.create(param, request)
|
|
||||||
response = generator.generate(request)
|
|
||||||
|
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
from astrai.inference.generator import GenerationRequest, GeneratorFactory
|
from astrai.inference import InferenceEngine
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
@ -21,17 +21,15 @@ def batch_generate():
|
||||||
"请问什么是显卡",
|
"请问什么是显卡",
|
||||||
]
|
]
|
||||||
|
|
||||||
request = GenerationRequest(
|
engine = InferenceEngine(param)
|
||||||
query=inputs,
|
responses = engine.generate(
|
||||||
|
prompt=inputs,
|
||||||
|
stream=False,
|
||||||
|
max_tokens=param.config.max_len,
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
max_len=param.config.max_len,
|
|
||||||
history=None,
|
|
||||||
system_prompt=None,
|
|
||||||
)
|
)
|
||||||
generator = GeneratorFactory.create(param, request)
|
|
||||||
responses = generator.generate(request)
|
|
||||||
|
|
||||||
for q, r in zip(inputs, responses):
|
for q, r in zip(inputs, responses):
|
||||||
print((q, r))
|
print((q, r))
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
from astrai.inference.generator import GenerationRequest, GeneratorFactory
|
from astrai.inference import InferenceEngine
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
@ -14,32 +14,27 @@ def chat():
|
||||||
param.to(device="cuda", dtype=torch.bfloat16)
|
param.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
history = []
|
history = []
|
||||||
|
engine = InferenceEngine(param)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
query = input(">> ")
|
query = input(">> ")
|
||||||
if query == "!exit":
|
if query == "!exit":
|
||||||
break
|
break
|
||||||
|
|
||||||
request = GenerationRequest(
|
full_response = ""
|
||||||
query=query,
|
|
||||||
|
for token in engine.generate(
|
||||||
|
prompt=query,
|
||||||
|
stream=True,
|
||||||
|
max_tokens=param.config.max_len,
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
max_len=param.config.max_len,
|
):
|
||||||
history=history,
|
print(token, end="", flush=True)
|
||||||
system_prompt=None,
|
full_response += token
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
generator = GeneratorFactory.create(param, request)
|
|
||||||
|
|
||||||
response_size = 0
|
print()
|
||||||
full_response = ""
|
|
||||||
for response in generator.generate(request):
|
|
||||||
# response is the cumulative response up to current token
|
|
||||||
print(response[response_size:], end="", flush=True)
|
|
||||||
response_size = len(response)
|
|
||||||
full_response = response
|
|
||||||
|
|
||||||
# After generation, update history
|
|
||||||
history.append((query, full_response.strip()))
|
history.append((query, full_response.strip()))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import json
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
from astrai.inference.generator import BatchGenerator, GenerationRequest
|
from astrai.inference import InferenceEngine
|
||||||
|
|
||||||
|
|
||||||
def processor(
|
def processor(
|
||||||
|
|
@ -19,25 +19,22 @@ def processor(
|
||||||
):
|
):
|
||||||
param = ModelParameter.load(model_dir, disable_init=True)
|
param = ModelParameter.load(model_dir, disable_init=True)
|
||||||
param.to(device="cuda", dtype=torch.bfloat16)
|
param.to(device="cuda", dtype=torch.bfloat16)
|
||||||
generator = BatchGenerator(param)
|
engine = InferenceEngine(param)
|
||||||
|
|
||||||
with open(input_json_file, "r", encoding="utf-8") as f:
|
with open(input_json_file, "r", encoding="utf-8") as f:
|
||||||
input_data = [json.loads(line) for line in f]
|
input_data = [json.loads(line) for line in f]
|
||||||
|
|
||||||
queries = [item[question_key] for item in input_data]
|
queries = [item[question_key] for item in input_data]
|
||||||
|
|
||||||
request = GenerationRequest(
|
responses = engine.generate(
|
||||||
query=queries,
|
prompt=queries,
|
||||||
|
stream=False,
|
||||||
|
max_tokens=param.config.max_len,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
max_len=param.config.max_len,
|
|
||||||
history=None,
|
|
||||||
system_prompt=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
responses = generator.generate(request)
|
|
||||||
|
|
||||||
with open(output_json_file, "w", encoding="utf-8") as f:
|
with open(output_json_file, "w", encoding="utf-8") as f:
|
||||||
for query, response in zip(queries, responses):
|
for query, response in zip(queries, responses):
|
||||||
output_item = {question_key: query, response_key: response}
|
output_item = {question_key: query, response_key: response}
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
"""Shared fixtures for inference tests."""
|
"""Shared fixtures for inference tests."""
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from astrai.inference.server import app
|
from astrai.inference.server import app, _engine
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -30,13 +30,17 @@ def mock_model_param():
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_generator(mock_model_param):
|
def mock_engine():
|
||||||
"""Mock the GeneratorFactory and its generators."""
|
"""Create a mock InferenceEngine."""
|
||||||
with patch("astrai.inference.server.GeneratorFactory") as MockFactory:
|
mock = MagicMock()
|
||||||
mock_gen = MagicMock()
|
mock.generate.return_value = "mock response"
|
||||||
mock_gen.generate.return_value = "mock response"
|
mock.get_stats.return_value = {
|
||||||
MockFactory.create.return_value = mock_gen
|
"total_tasks": 0,
|
||||||
yield MockFactory, mock_gen
|
"total_tokens": 0,
|
||||||
|
"active_tasks": 0,
|
||||||
|
"waiting_queue": 0,
|
||||||
|
}
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
||||||
|
|
@ -6,24 +6,29 @@ import pytest
|
||||||
def test_health_no_model(client, monkeypatch):
|
def test_health_no_model(client, monkeypatch):
|
||||||
"""GET /health should return 200 even when model not loaded."""
|
"""GET /health should return 200 even when model not loaded."""
|
||||||
monkeypatch.setattr("astrai.inference.server._model_param", None)
|
monkeypatch.setattr("astrai.inference.server._model_param", None)
|
||||||
|
monkeypatch.setattr("astrai.inference.server._engine", None)
|
||||||
response = client.get("/health")
|
response = client.get("/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["status"] == "ok"
|
assert data["status"] == "ok"
|
||||||
assert not data["model_loaded"]
|
assert not data["model_loaded"]
|
||||||
|
assert not data["engine_ready"]
|
||||||
|
|
||||||
|
|
||||||
def test_health_with_model(client, loaded_model):
|
def test_health_with_model(client, loaded_model, mock_engine, monkeypatch):
|
||||||
"""GET /health should return 200 when model is loaded."""
|
"""GET /health should return 200 when model is loaded."""
|
||||||
|
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
||||||
response = client.get("/health")
|
response = client.get("/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {"status": "ok", "model_loaded": True}
|
data = response.json()
|
||||||
|
assert data["status"] == "ok"
|
||||||
|
assert data["model_loaded"] is True
|
||||||
|
assert data["engine_ready"] is True
|
||||||
|
|
||||||
|
|
||||||
def test_generate_non_stream(client, loaded_model, mock_generator):
|
def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch):
|
||||||
"""POST /generate with stream=false should return JSON response."""
|
"""POST /generate with stream=false should return JSON response."""
|
||||||
MockFactory, mock_gen = mock_generator
|
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
||||||
mock_gen.generate.return_value = "Test response"
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/generate",
|
"/generate",
|
||||||
params={
|
params={
|
||||||
|
|
@ -37,15 +42,19 @@ def test_generate_non_stream(client, loaded_model, mock_generator):
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["response"] == "Test response"
|
assert data["response"] == "mock response"
|
||||||
MockFactory.create.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
def test_generate_stream(client, loaded_model, mock_generator):
|
def test_generate_stream(client, loaded_model, mock_engine, monkeypatch):
|
||||||
"""POST /generate with stream=true should return plain text stream."""
|
"""POST /generate with stream=true should return plain text stream."""
|
||||||
MockFactory, mock_gen = mock_generator
|
|
||||||
# Simulate a streaming generator that yields two chunks
|
# Create a streaming mock
|
||||||
mock_gen.generate.return_value = ["chunk1", "chunk2"]
|
def stream_gen():
|
||||||
|
yield "chunk1"
|
||||||
|
yield "chunk2"
|
||||||
|
|
||||||
|
mock_engine.generate.return_value = stream_gen()
|
||||||
|
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/generate",
|
"/generate",
|
||||||
params={
|
params={
|
||||||
|
|
@ -66,10 +75,10 @@ def test_generate_stream(client, loaded_model, mock_generator):
|
||||||
assert "chunk2" in content
|
assert "chunk2" in content
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completions_non_stream(client, loaded_model, mock_generator):
|
def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypatch):
|
||||||
"""POST /v1/chat/completions with stream=false returns OpenAI‑style JSON."""
|
"""POST /v1/chat/completions with stream=false returns OpenAI‑style JSON."""
|
||||||
MockFactory, mock_gen = mock_generator
|
mock_engine.generate.return_value = "Assistant reply"
|
||||||
mock_gen.generate.return_value = "Assistant reply"
|
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
json={
|
json={
|
||||||
|
|
@ -88,11 +97,17 @@ def test_chat_completions_non_stream(client, loaded_model, mock_generator):
|
||||||
assert data["choices"][0]["message"]["content"] == "Assistant reply"
|
assert data["choices"][0]["message"]["content"] == "Assistant reply"
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completions_stream(client, loaded_model, mock_generator):
|
def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch):
|
||||||
"""POST /v1/chat/completions with stream=true returns SSE stream."""
|
"""POST /v1/chat/completions with stream=true returns SSE stream."""
|
||||||
MockFactory, mock_gen = mock_generator
|
|
||||||
# Simulate a streaming generator that yields cumulative responses
|
# Simulate a streaming generator that yields cumulative responses
|
||||||
mock_gen.generate.return_value = ["cumulative1", "cumulative2"]
|
def stream_gen():
|
||||||
|
yield "cumulative1"
|
||||||
|
yield "cumulative2"
|
||||||
|
yield "[DONE]"
|
||||||
|
|
||||||
|
mock_engine.generate.return_value = stream_gen()
|
||||||
|
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
json={
|
json={
|
||||||
|
|
@ -116,10 +131,9 @@ def test_chat_completions_stream(client, loaded_model, mock_generator):
|
||||||
assert any("cumulative2" in line for line in lines)
|
assert any("cumulative2" in line for line in lines)
|
||||||
|
|
||||||
|
|
||||||
def test_generate_with_history(client, loaded_model, mock_generator):
|
def test_generate_with_history(client, loaded_model, mock_engine, monkeypatch):
|
||||||
"""POST /generate with history parameter."""
|
"""POST /generate with history parameter."""
|
||||||
MockFactory, mock_gen = mock_generator
|
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
||||||
mock_gen.generate.return_value = "Response with history"
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/generate",
|
"/generate",
|
||||||
params={
|
params={
|
||||||
|
|
@ -129,12 +143,8 @@ def test_generate_with_history(client, loaded_model, mock_generator):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
MockFactory.create.assert_called_once()
|
# Verify the engine.generate was called
|
||||||
# Check that history was passed correctly (currently history is not parsed due to FastAPI limitation)
|
mock_engine.generate.assert_called_once()
|
||||||
call_args = MockFactory.create.call_args
|
|
||||||
req = call_args[0][1] # second argument is GenerationRequest
|
|
||||||
# Because history cannot be passed via query params, it will be None
|
|
||||||
assert req.history is None
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,6 @@ import os
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_parameter(test_env):
|
def test_model_parameter(test_env):
|
||||||
save_dir = os.path.join(test_env["test_dir"], "save")
|
save_dir = os.path.join(test_env["test_dir"], "save")
|
||||||
|
|
@ -33,39 +31,3 @@ def test_transformer(test_env):
|
||||||
test_env["transformer_config"].vocab_size,
|
test_env["transformer_config"].vocab_size,
|
||||||
)
|
)
|
||||||
assert output_logits.shape == target_shape
|
assert output_logits.shape == target_shape
|
||||||
|
|
||||||
|
|
||||||
# generator
|
|
||||||
def test_embedding_encoder_core(test_env):
|
|
||||||
parameter = ModelParameter(
|
|
||||||
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
|
|
||||||
)
|
|
||||||
encoder = EmbeddingEncoderCore(parameter)
|
|
||||||
|
|
||||||
single_emb = encoder.encode("测试文本")
|
|
||||||
assert isinstance(single_emb, torch.Tensor)
|
|
||||||
assert single_emb.shape[-1] == test_env["transformer_config"].dim
|
|
||||||
|
|
||||||
batch_emb = encoder.encode(["测试1", "测试2"])
|
|
||||||
assert isinstance(batch_emb, list)
|
|
||||||
assert len(batch_emb) == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_generator_core(test_env):
|
|
||||||
parameter = ModelParameter(
|
|
||||||
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
|
|
||||||
)
|
|
||||||
generator = GeneratorCore(parameter)
|
|
||||||
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10))
|
|
||||||
next_token_id, cache_increase = generator.generate_iterator(
|
|
||||||
input_ids=input_ids,
|
|
||||||
temperature=0.8,
|
|
||||||
top_k=50,
|
|
||||||
top_p=0.95,
|
|
||||||
attn_mask=None,
|
|
||||||
kv_caches=None,
|
|
||||||
start_pos=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert next_token_id.shape == (4, 1)
|
|
||||||
assert cache_increase == 10
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue