feat(inference): 添加generate_loop方法并优化KVCacheManager初始化

This commit is contained in:
ViperEkura 2025-10-31 21:15:15 +08:00
parent cdb47a62dc
commit 877669b799
2 changed files with 54 additions and 63 deletions

View File

@ -1,7 +1,7 @@
import torch
from torch import Tensor
from typing import List, Tuple, Union, Optional, Self
from khaosz.config.param_config import ModelParameter
from typing import Any, Callable, List, Tuple, Union, Optional, Self
from khaosz.config import ModelParameter, TransformerConfig
def apply_sampling_strategies(
@ -86,6 +86,36 @@ class GeneratorCore:
self.model.to(*args, **kargs)
return self
def generate_loop(
self,
input_ids: Tensor,
ids: List[int],
temperature: float,
top_k: int,
top_p: float,
attn_mask: Optional[Tensor] = None,
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
start_pos: int = 0,
callback: Optional[Callable[..., Any]] = None
) -> List[int]:
cur_cache_pos = start_pos
for _ in range(len(ids), self.config.m_len):
next_token_id, cache_increase = self.generate_iterator(
input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos)
input_ids = next_token_id
ids.append(next_token_id.item())
cur_cache_pos += cache_increase
if callback:
callback(next_token_id.item(), ids.copy())
if next_token_id.item() in self.tokenizer.stop_ids:
break
return ids
class EmbeddingEncoderCore:
def __init__(self, parameter: ModelParameter):
@ -157,21 +187,18 @@ class EmbeddingEncoderCore:
class KVCacheManager:
def __init__(
self,
num_layers: int,
config: TransformerConfig,
batch_size: int,
max_len: int,
num_heads: int,
head_dim: int,
device: torch.device = "cuda",
dtype: torch.dtype = torch.bfloat16
):
self.num_layers = num_layers
self.batch_size = batch_size
self.max_len = max_len
self.num_heads = num_heads
self.head_dim = head_dim
self.device = device
self.dtype = dtype
self.num_layers = config.n_layer
self.max_len = config.m_len
self.num_heads = config.n_kvhead
self.head_dim = config.n_dim //config.n_head
self._kv_cache: Tuple[Tensor, Tensor] = None
self._seq_mask: Tensor = None

View File

@ -72,14 +72,7 @@ class TextGenerator(GeneratorCore):
assert top_p >= 0.0 and top_p <= 1.0
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(
num_layers=self.config.n_layer,
batch_size=1,
max_len=self.config.m_len,
num_heads=self.config.n_kvhead,
head_dim=self.config.n_dim // self.config.n_head,
device=device,
)
cache_manager = KVCacheManager(self.config, 1, device=device)
ids = self.tokenizer.encode(query)
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
@ -89,16 +82,11 @@ class TextGenerator(GeneratorCore):
self.model.eval()
kv_caches = cache_manager.get_kvcache()
for _ in range(len(ids), self.config.m_len):
next_token_id, cache_increase = self.generate_iterator(
input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
input_ids = next_token_id
ids.append(next_token_id.item())
cur_cache_pos += cache_increase
if next_token_id.item() in self.tokenizer.stop_ids:
break
ids = self.generate_loop(
input_ids, ids, temperature, top_k, top_p,
kv_caches=kv_caches,
start_pos=cur_cache_pos
)
response = self.tokenizer.decode(ids[start_cache_pos:])
@ -126,14 +114,8 @@ class ChatGenerator(GeneratorCore):
history = []
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(
num_layers=self.config.n_layer,
batch_size=1,
max_len=self.config.m_len,
num_heads=self.config.n_kvhead,
head_dim=self.config.n_dim // self.config.n_head,
device=device,
)
cache_manager = KVCacheManager(self.config, 1, device=device)
ids = self.tokenizer.encode(build_prompt(query, history))
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
cpy_history = history.copy()
@ -143,17 +125,12 @@ class ChatGenerator(GeneratorCore):
self.model.eval()
kv_caches = cache_manager.get_kvcache()
for _ in range(len(ids), self.config.m_len):
next_token_id, cache_increase = self.generate_iterator(
input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
input_ids = next_token_id
ids.append(next_token_id.item())
cur_cache_pos += cache_increase
if next_token_id.item() in self.tokenizer.stop_ids:
break
ids = self.generate_loop(
input_ids, ids, temperature, top_k, top_p,
kv_caches=kv_caches,
start_pos=cur_cache_pos
)
response = self.tokenizer.decode(ids[start_cache_pos:])
cpy_history.append((query, response))
@ -181,14 +158,8 @@ class StreamGenerator(GeneratorCore):
history = []
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(
num_layers=self.config.n_layer,
batch_size=1,
max_len=self.config.m_len,
num_heads=self.config.n_kvhead,
head_dim=self.config.n_dim // self.config.n_head,
device=device,
)
cache_manager = KVCacheManager(self.config, 1, device=device)
ids = self.tokenizer.encode(build_prompt(query, history))
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
cpy_history = history.copy()
@ -241,14 +212,7 @@ class BatchGenerator(GeneratorCore):
ids_list = pad_sequence(ids_list, max_ids_len, self.tokenizer.pad_id)
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(
num_layers=self.config.n_layer,
batch_size=batch_size,
max_len=self.config.m_len,
num_heads=self.config.n_kvhead,
head_dim=self.config.n_dim // self.config.n_head,
device=device,
)
cache_manager = KVCacheManager(self.config, batch_size, device=device)
input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)