feat(inference): add generate_loop method and optimize KVCacheManager initialization

author ViperEkura 2025-10-31 21:15:15 +08:00
parent cdb47a62dc
commit 877669b799
2 changed files with 54 additions and 63 deletions

File 1 of 2:

@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
-from typing import List, Tuple, Union, Optional, Self
-from khaosz.config.param_config import ModelParameter
+from typing import Any, Callable, List, Tuple, Union, Optional, Self
+from khaosz.config import ModelParameter, TransformerConfig


 def apply_sampling_strategies(
@@ -86,6 +86,36 @@ class GeneratorCore:
         self.model.to(*args, **kargs)
         return self

+    def generate_loop(
+        self,
+        input_ids: Tensor,
+        ids: List[int],
+        temperature: float,
+        top_k: int,
+        top_p: float,
+        attn_mask: Optional[Tensor] = None,
+        kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
+        start_pos: int = 0,
+        callback: Optional[Callable[..., Any]] = None
+    ) -> List[int]:
+        cur_cache_pos = start_pos
+        for _ in range(len(ids), self.config.m_len):
+            next_token_id, cache_increase = self.generate_iterator(
+                input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos)
+
+            input_ids = next_token_id
+            ids.append(next_token_id.item())
+            cur_cache_pos += cache_increase
+
+            if callback:
+                callback(next_token_id.item(), ids.copy())
+
+            if next_token_id.item() in self.tokenizer.stop_ids:
+                break
+
+        return ids
+

 class EmbeddingEncoderCore:
     def __init__(self, parameter: ModelParameter):
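For context, a minimal usage sketch of the new hook (not part of the diff). Names like gen and prompt are illustrative, and the setup assumes a GeneratorCore subclass wired up with a tokenizer, config, and KVCacheManager the way the call sites in the second file set them up; the callback receives each new token id plus a snapshot copy of the id list, which is the shape a streaming consumer needs.

# Illustrative only: `gen`, `prompt`, and `cache_manager` are assumed to be
# prepared as in the TextGenerator call sites later in this commit.
import torch

def on_token(token_id: int, ids_so_far: list) -> None:
    # Called once per generated token; ids_so_far is a copy, so holding a
    # reference to it is safe.
    print(token_id, len(ids_so_far))

device = next(gen.model.parameters()).device
ids = gen.tokenizer.encode(prompt)
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
kv_caches = cache_manager.get_kvcache()

ids = gen.generate_loop(
    input_ids, ids,
    temperature=0.8, top_k=50, top_p=0.95,
    kv_caches=kv_caches,
    start_pos=0,
    callback=on_token,  # optional; omit for plain blocking generation
)
print(gen.tokenizer.decode(ids))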
@@ -157,21 +187,18 @@ class EmbeddingEncoderCore:

 class KVCacheManager:
     def __init__(
         self,
-        num_layers: int,
+        config: TransformerConfig,
         batch_size: int,
-        max_len: int,
-        num_heads: int,
-        head_dim: int,
         device: torch.device = "cuda",
         dtype: torch.dtype = torch.bfloat16
     ):
-        self.num_layers = num_layers
         self.batch_size = batch_size
-        self.max_len = max_len
-        self.num_heads = num_heads
-        self.head_dim = head_dim
         self.device = device
         self.dtype = dtype
+        self.num_layers = config.n_layer
+        self.max_len = config.m_len
+        self.num_heads = config.n_kvhead
+        self.head_dim = config.n_dim // config.n_head

         self._kv_cache: Tuple[Tensor, Tensor] = None
         self._seq_mask: Tensor = None
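For comparison, a sketch of old versus new construction. The derived fields follow directly from the assignments above; the TransformerConfig values are invented for illustration, and keyword construction of TransformerConfig is an assumption, not something this diff shows.

import torch
from khaosz.config import TransformerConfig

# Invented values; in practice the config comes from the loaded model.
config = TransformerConfig(n_layer=24, m_len=2048, n_head=16, n_kvhead=8, n_dim=2048)
device = torch.device("cuda")

# Before: five geometry arguments threaded through by every caller.
# cache_manager = KVCacheManager(num_layers=24, batch_size=1, max_len=2048,
#                                num_heads=8, head_dim=2048 // 16, device=device)

# After: one config object; the geometry is derived inside __init__.
cache_manager = KVCacheManager(config, 1, device=device)
assert cache_manager.head_dim == config.n_dim // config.n_head  # 2048 // 16 == 128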

File 2 of 2:

@@ -72,14 +72,7 @@ class TextGenerator(GeneratorCore):
         assert top_p >= 0.0 and top_p <= 1.0

         device = next(self.model.parameters()).device
-        cache_manager = KVCacheManager(
-            num_layers=self.config.n_layer,
-            batch_size=1,
-            max_len=self.config.m_len,
-            num_heads=self.config.n_kvhead,
-            head_dim=self.config.n_dim // self.config.n_head,
-            device=device,
-        )
+        cache_manager = KVCacheManager(self.config, 1, device=device)

         ids = self.tokenizer.encode(query)
         input_ids = torch.tensor([ids], device=device, dtype=torch.long)
@@ -89,16 +82,11 @@ class TextGenerator(GeneratorCore):
         self.model.eval()
         kv_caches = cache_manager.get_kvcache()
-        for _ in range(len(ids), self.config.m_len):
-            next_token_id, cache_increase = self.generate_iterator(
-                input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
-
-            input_ids = next_token_id
-            ids.append(next_token_id.item())
-            cur_cache_pos += cache_increase
-
-            if next_token_id.item() in self.tokenizer.stop_ids:
-                break
+        ids = self.generate_loop(
+            input_ids, ids, temperature, top_k, top_p,
+            kv_caches=kv_caches,
+            start_pos=cur_cache_pos
+        )

         response = self.tokenizer.decode(ids[start_cache_pos:])
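One semantic carries over unchanged from the removed loop: range(len(ids), self.config.m_len) caps the total sequence length (prompt plus generated tokens) at m_len; it is not a budget of m_len new tokens. A toy check:

# With a 3-token prompt and m_len = 8, at most 5 new tokens are produced.
m_len = 8
ids = [101, 102, 103]  # pretend prompt token ids
assert len(range(len(ids), m_len)) == m_len - len(ids) == 5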
@@ -126,14 +114,8 @@ class ChatGenerator(GeneratorCore):
             history = []

         device = next(self.model.parameters()).device
-        cache_manager = KVCacheManager(
-            num_layers=self.config.n_layer,
-            batch_size=1,
-            max_len=self.config.m_len,
-            num_heads=self.config.n_kvhead,
-            head_dim=self.config.n_dim // self.config.n_head,
-            device=device,
-        )
+        cache_manager = KVCacheManager(self.config, 1, device=device)

         ids = self.tokenizer.encode(build_prompt(query, history))
         input_ids = torch.tensor([ids], device=device, dtype=torch.long)
         cpy_history = history.copy()
@@ -143,17 +125,12 @@ class ChatGenerator(GeneratorCore):
         self.model.eval()
         kv_caches = cache_manager.get_kvcache()
-        for _ in range(len(ids), self.config.m_len):
-            next_token_id, cache_increase = self.generate_iterator(
-                input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
-
-            input_ids = next_token_id
-            ids.append(next_token_id.item())
-            cur_cache_pos += cache_increase
-
-            if next_token_id.item() in self.tokenizer.stop_ids:
-                break
+        ids = self.generate_loop(
+            input_ids, ids, temperature, top_k, top_p,
+            kv_caches=kv_caches,
+            start_pos=cur_cache_pos
+        )

         response = self.tokenizer.decode(ids[start_cache_pos:])
         cpy_history.append((query, response))
@@ -181,14 +158,8 @@ class StreamGenerator(GeneratorCore):
             history = []

         device = next(self.model.parameters()).device
-        cache_manager = KVCacheManager(
-            num_layers=self.config.n_layer,
-            batch_size=1,
-            max_len=self.config.m_len,
-            num_heads=self.config.n_kvhead,
-            head_dim=self.config.n_dim // self.config.n_head,
-            device=device,
-        )
+        cache_manager = KVCacheManager(self.config, 1, device=device)

         ids = self.tokenizer.encode(build_prompt(query, history))
         input_ids = torch.tensor([ids], device=device, dtype=torch.long)
         cpy_history = history.copy()
@@ -241,14 +212,7 @@ class BatchGenerator(GeneratorCore):
         ids_list = pad_sequence(ids_list, max_ids_len, self.tokenizer.pad_id)

         device = next(self.model.parameters()).device
-        cache_manager = KVCacheManager(
-            num_layers=self.config.n_layer,
-            batch_size=batch_size,
-            max_len=self.config.m_len,
-            num_heads=self.config.n_kvhead,
-            head_dim=self.config.n_dim // self.config.n_head,
-            device=device,
-        )
+        cache_manager = KVCacheManager(self.config, batch_size, device=device)

         input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
         cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
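End to end, the batch setup above reduces to the following sketch. The local pad_sequence stand-in and the mask expression are assumptions made so the snippet runs standalone: the real module supplies its own pad_sequence with the same (ids_list, max_len, pad_id) signature, and set_seq_mask presumably derives a mask like this one so attention skips the pad positions.

import torch

# Stand-in for the module's helper; right-padding is an assumption.
def pad_sequence(ids_list, max_len, pad_id):
    return [ids + [pad_id] * (max_len - len(ids)) for ids in ids_list]

pad_id = 0  # illustrative; the real value is tokenizer.pad_id
ids_list = [[5, 6, 7], [8, 9]]
max_ids_len = max(len(ids) for ids in ids_list)
ids_list = pad_sequence(ids_list, max_ids_len, pad_id)

input_tensor = torch.tensor(ids_list, dtype=torch.long)
seq_mask = input_tensor != pad_id  # assumed shape of the pad mask
print(seq_mask)  # tensor([[ True,  True,  True], [ True,  True, False]])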