feat(inference): 添加generate_loop方法并优化KVCacheManager初始化
This commit is contained in:
parent
cdb47a62dc
commit
877669b799
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
from typing import List, Tuple, Union, Optional, Self
|
||||
from khaosz.config.param_config import ModelParameter
|
||||
from typing import Any, Callable, List, Tuple, Union, Optional, Self
|
||||
from khaosz.config import ModelParameter, TransformerConfig
|
||||
|
||||
|
||||
def apply_sampling_strategies(
|
||||
|
|
@ -86,6 +86,36 @@ class GeneratorCore:
|
|||
self.model.to(*args, **kargs)
|
||||
return self
|
||||
|
||||
def generate_loop(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
ids: List[int],
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
|
||||
start_pos: int = 0,
|
||||
callback: Optional[Callable[..., Any]] = None
|
||||
) -> List[int]:
|
||||
cur_cache_pos = start_pos
|
||||
|
||||
for _ in range(len(ids), self.config.m_len):
|
||||
next_token_id, cache_increase = self.generate_iterator(
|
||||
input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos)
|
||||
|
||||
input_ids = next_token_id
|
||||
ids.append(next_token_id.item())
|
||||
cur_cache_pos += cache_increase
|
||||
|
||||
if callback:
|
||||
callback(next_token_id.item(), ids.copy())
|
||||
|
||||
if next_token_id.item() in self.tokenizer.stop_ids:
|
||||
break
|
||||
|
||||
return ids
|
||||
|
||||
|
||||
class EmbeddingEncoderCore:
|
||||
def __init__(self, parameter: ModelParameter):
|
||||
|
|
@ -157,21 +187,18 @@ class EmbeddingEncoderCore:
|
|||
class KVCacheManager:
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int,
|
||||
config: TransformerConfig,
|
||||
batch_size: int,
|
||||
max_len: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
device: torch.device = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16
|
||||
):
|
||||
self.num_layers = num_layers
|
||||
self.batch_size = batch_size
|
||||
self.max_len = max_len
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.num_layers = config.n_layer
|
||||
self.max_len = config.m_len
|
||||
self.num_heads = config.n_kvhead
|
||||
self.head_dim = config.n_dim //config.n_head
|
||||
|
||||
self._kv_cache: Tuple[Tensor, Tensor] = None
|
||||
self._seq_mask: Tensor = None
|
||||
|
|
|
|||
|
|
@ -72,14 +72,7 @@ class TextGenerator(GeneratorCore):
|
|||
assert top_p >= 0.0 and top_p <= 1.0
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
cache_manager = KVCacheManager(
|
||||
num_layers=self.config.n_layer,
|
||||
batch_size=1,
|
||||
max_len=self.config.m_len,
|
||||
num_heads=self.config.n_kvhead,
|
||||
head_dim=self.config.n_dim // self.config.n_head,
|
||||
device=device,
|
||||
)
|
||||
cache_manager = KVCacheManager(self.config, 1, device=device)
|
||||
|
||||
ids = self.tokenizer.encode(query)
|
||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||
|
|
@ -89,16 +82,11 @@ class TextGenerator(GeneratorCore):
|
|||
self.model.eval()
|
||||
kv_caches = cache_manager.get_kvcache()
|
||||
|
||||
for _ in range(len(ids), self.config.m_len):
|
||||
next_token_id, cache_increase = self.generate_iterator(
|
||||
input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
|
||||
|
||||
input_ids = next_token_id
|
||||
ids.append(next_token_id.item())
|
||||
cur_cache_pos += cache_increase
|
||||
|
||||
if next_token_id.item() in self.tokenizer.stop_ids:
|
||||
break
|
||||
ids = self.generate_loop(
|
||||
input_ids, ids, temperature, top_k, top_p,
|
||||
kv_caches=kv_caches,
|
||||
start_pos=cur_cache_pos
|
||||
)
|
||||
|
||||
response = self.tokenizer.decode(ids[start_cache_pos:])
|
||||
|
||||
|
|
@ -126,14 +114,8 @@ class ChatGenerator(GeneratorCore):
|
|||
history = []
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
cache_manager = KVCacheManager(
|
||||
num_layers=self.config.n_layer,
|
||||
batch_size=1,
|
||||
max_len=self.config.m_len,
|
||||
num_heads=self.config.n_kvhead,
|
||||
head_dim=self.config.n_dim // self.config.n_head,
|
||||
device=device,
|
||||
)
|
||||
cache_manager = KVCacheManager(self.config, 1, device=device)
|
||||
|
||||
ids = self.tokenizer.encode(build_prompt(query, history))
|
||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||
cpy_history = history.copy()
|
||||
|
|
@ -143,17 +125,12 @@ class ChatGenerator(GeneratorCore):
|
|||
self.model.eval()
|
||||
kv_caches = cache_manager.get_kvcache()
|
||||
|
||||
for _ in range(len(ids), self.config.m_len):
|
||||
next_token_id, cache_increase = self.generate_iterator(
|
||||
input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
|
||||
|
||||
input_ids = next_token_id
|
||||
ids.append(next_token_id.item())
|
||||
cur_cache_pos += cache_increase
|
||||
|
||||
if next_token_id.item() in self.tokenizer.stop_ids:
|
||||
break
|
||||
|
||||
ids = self.generate_loop(
|
||||
input_ids, ids, temperature, top_k, top_p,
|
||||
kv_caches=kv_caches,
|
||||
start_pos=cur_cache_pos
|
||||
)
|
||||
|
||||
response = self.tokenizer.decode(ids[start_cache_pos:])
|
||||
cpy_history.append((query, response))
|
||||
|
||||
|
|
@ -181,14 +158,8 @@ class StreamGenerator(GeneratorCore):
|
|||
history = []
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
cache_manager = KVCacheManager(
|
||||
num_layers=self.config.n_layer,
|
||||
batch_size=1,
|
||||
max_len=self.config.m_len,
|
||||
num_heads=self.config.n_kvhead,
|
||||
head_dim=self.config.n_dim // self.config.n_head,
|
||||
device=device,
|
||||
)
|
||||
cache_manager = KVCacheManager(self.config, 1, device=device)
|
||||
|
||||
ids = self.tokenizer.encode(build_prompt(query, history))
|
||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||
cpy_history = history.copy()
|
||||
|
|
@ -241,14 +212,7 @@ class BatchGenerator(GeneratorCore):
|
|||
ids_list = pad_sequence(ids_list, max_ids_len, self.tokenizer.pad_id)
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
cache_manager = KVCacheManager(
|
||||
num_layers=self.config.n_layer,
|
||||
batch_size=batch_size,
|
||||
max_len=self.config.m_len,
|
||||
num_heads=self.config.n_kvhead,
|
||||
head_dim=self.config.n_dim // self.config.n_head,
|
||||
device=device,
|
||||
)
|
||||
cache_manager = KVCacheManager(self.config, batch_size, device=device)
|
||||
|
||||
input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
|
||||
cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
|
||||
|
|
|
|||
Loading…
Reference in New Issue