feat(inference): 添加generate_loop方法并优化KVCacheManager初始化
This commit is contained in:
parent
cdb47a62dc
commit
877669b799
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import List, Tuple, Union, Optional, Self
|
from typing import Any, Callable, List, Tuple, Union, Optional, Self
|
||||||
from khaosz.config.param_config import ModelParameter
|
from khaosz.config import ModelParameter, TransformerConfig
|
||||||
|
|
||||||
|
|
||||||
def apply_sampling_strategies(
|
def apply_sampling_strategies(
|
||||||
|
|
@ -86,6 +86,36 @@ class GeneratorCore:
|
||||||
self.model.to(*args, **kargs)
|
self.model.to(*args, **kargs)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def generate_loop(
|
||||||
|
self,
|
||||||
|
input_ids: Tensor,
|
||||||
|
ids: List[int],
|
||||||
|
temperature: float,
|
||||||
|
top_k: int,
|
||||||
|
top_p: float,
|
||||||
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
|
||||||
|
start_pos: int = 0,
|
||||||
|
callback: Optional[Callable[..., Any]] = None
|
||||||
|
) -> List[int]:
|
||||||
|
cur_cache_pos = start_pos
|
||||||
|
|
||||||
|
for _ in range(len(ids), self.config.m_len):
|
||||||
|
next_token_id, cache_increase = self.generate_iterator(
|
||||||
|
input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos)
|
||||||
|
|
||||||
|
input_ids = next_token_id
|
||||||
|
ids.append(next_token_id.item())
|
||||||
|
cur_cache_pos += cache_increase
|
||||||
|
|
||||||
|
if callback:
|
||||||
|
callback(next_token_id.item(), ids.copy())
|
||||||
|
|
||||||
|
if next_token_id.item() in self.tokenizer.stop_ids:
|
||||||
|
break
|
||||||
|
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingEncoderCore:
|
class EmbeddingEncoderCore:
|
||||||
def __init__(self, parameter: ModelParameter):
|
def __init__(self, parameter: ModelParameter):
|
||||||
|
|
@ -157,21 +187,18 @@ class EmbeddingEncoderCore:
|
||||||
class KVCacheManager:
|
class KVCacheManager:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_layers: int,
|
config: TransformerConfig,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
max_len: int,
|
|
||||||
num_heads: int,
|
|
||||||
head_dim: int,
|
|
||||||
device: torch.device = "cuda",
|
device: torch.device = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16
|
dtype: torch.dtype = torch.bfloat16
|
||||||
):
|
):
|
||||||
self.num_layers = num_layers
|
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.max_len = max_len
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.head_dim = head_dim
|
|
||||||
self.device = device
|
self.device = device
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
self.num_layers = config.n_layer
|
||||||
|
self.max_len = config.m_len
|
||||||
|
self.num_heads = config.n_kvhead
|
||||||
|
self.head_dim = config.n_dim //config.n_head
|
||||||
|
|
||||||
self._kv_cache: Tuple[Tensor, Tensor] = None
|
self._kv_cache: Tuple[Tensor, Tensor] = None
|
||||||
self._seq_mask: Tensor = None
|
self._seq_mask: Tensor = None
|
||||||
|
|
|
||||||
|
|
@ -72,14 +72,7 @@ class TextGenerator(GeneratorCore):
|
||||||
assert top_p >= 0.0 and top_p <= 1.0
|
assert top_p >= 0.0 and top_p <= 1.0
|
||||||
|
|
||||||
device = next(self.model.parameters()).device
|
device = next(self.model.parameters()).device
|
||||||
cache_manager = KVCacheManager(
|
cache_manager = KVCacheManager(self.config, 1, device=device)
|
||||||
num_layers=self.config.n_layer,
|
|
||||||
batch_size=1,
|
|
||||||
max_len=self.config.m_len,
|
|
||||||
num_heads=self.config.n_kvhead,
|
|
||||||
head_dim=self.config.n_dim // self.config.n_head,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
ids = self.tokenizer.encode(query)
|
ids = self.tokenizer.encode(query)
|
||||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||||
|
|
@ -89,16 +82,11 @@ class TextGenerator(GeneratorCore):
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
kv_caches = cache_manager.get_kvcache()
|
kv_caches = cache_manager.get_kvcache()
|
||||||
|
|
||||||
for _ in range(len(ids), self.config.m_len):
|
ids = self.generate_loop(
|
||||||
next_token_id, cache_increase = self.generate_iterator(
|
input_ids, ids, temperature, top_k, top_p,
|
||||||
input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
|
kv_caches=kv_caches,
|
||||||
|
start_pos=cur_cache_pos
|
||||||
input_ids = next_token_id
|
)
|
||||||
ids.append(next_token_id.item())
|
|
||||||
cur_cache_pos += cache_increase
|
|
||||||
|
|
||||||
if next_token_id.item() in self.tokenizer.stop_ids:
|
|
||||||
break
|
|
||||||
|
|
||||||
response = self.tokenizer.decode(ids[start_cache_pos:])
|
response = self.tokenizer.decode(ids[start_cache_pos:])
|
||||||
|
|
||||||
|
|
@ -126,14 +114,8 @@ class ChatGenerator(GeneratorCore):
|
||||||
history = []
|
history = []
|
||||||
|
|
||||||
device = next(self.model.parameters()).device
|
device = next(self.model.parameters()).device
|
||||||
cache_manager = KVCacheManager(
|
cache_manager = KVCacheManager(self.config, 1, device=device)
|
||||||
num_layers=self.config.n_layer,
|
|
||||||
batch_size=1,
|
|
||||||
max_len=self.config.m_len,
|
|
||||||
num_heads=self.config.n_kvhead,
|
|
||||||
head_dim=self.config.n_dim // self.config.n_head,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
ids = self.tokenizer.encode(build_prompt(query, history))
|
ids = self.tokenizer.encode(build_prompt(query, history))
|
||||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||||
cpy_history = history.copy()
|
cpy_history = history.copy()
|
||||||
|
|
@ -143,17 +125,12 @@ class ChatGenerator(GeneratorCore):
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
kv_caches = cache_manager.get_kvcache()
|
kv_caches = cache_manager.get_kvcache()
|
||||||
|
|
||||||
for _ in range(len(ids), self.config.m_len):
|
ids = self.generate_loop(
|
||||||
next_token_id, cache_increase = self.generate_iterator(
|
input_ids, ids, temperature, top_k, top_p,
|
||||||
input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
|
kv_caches=kv_caches,
|
||||||
|
start_pos=cur_cache_pos
|
||||||
input_ids = next_token_id
|
)
|
||||||
ids.append(next_token_id.item())
|
|
||||||
cur_cache_pos += cache_increase
|
|
||||||
|
|
||||||
if next_token_id.item() in self.tokenizer.stop_ids:
|
|
||||||
break
|
|
||||||
|
|
||||||
response = self.tokenizer.decode(ids[start_cache_pos:])
|
response = self.tokenizer.decode(ids[start_cache_pos:])
|
||||||
cpy_history.append((query, response))
|
cpy_history.append((query, response))
|
||||||
|
|
||||||
|
|
@ -181,14 +158,8 @@ class StreamGenerator(GeneratorCore):
|
||||||
history = []
|
history = []
|
||||||
|
|
||||||
device = next(self.model.parameters()).device
|
device = next(self.model.parameters()).device
|
||||||
cache_manager = KVCacheManager(
|
cache_manager = KVCacheManager(self.config, 1, device=device)
|
||||||
num_layers=self.config.n_layer,
|
|
||||||
batch_size=1,
|
|
||||||
max_len=self.config.m_len,
|
|
||||||
num_heads=self.config.n_kvhead,
|
|
||||||
head_dim=self.config.n_dim // self.config.n_head,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
ids = self.tokenizer.encode(build_prompt(query, history))
|
ids = self.tokenizer.encode(build_prompt(query, history))
|
||||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||||
cpy_history = history.copy()
|
cpy_history = history.copy()
|
||||||
|
|
@ -241,14 +212,7 @@ class BatchGenerator(GeneratorCore):
|
||||||
ids_list = pad_sequence(ids_list, max_ids_len, self.tokenizer.pad_id)
|
ids_list = pad_sequence(ids_list, max_ids_len, self.tokenizer.pad_id)
|
||||||
|
|
||||||
device = next(self.model.parameters()).device
|
device = next(self.model.parameters()).device
|
||||||
cache_manager = KVCacheManager(
|
cache_manager = KVCacheManager(self.config, batch_size, device=device)
|
||||||
num_layers=self.config.n_layer,
|
|
||||||
batch_size=batch_size,
|
|
||||||
max_len=self.config.m_len,
|
|
||||||
num_heads=self.config.n_kvhead,
|
|
||||||
head_dim=self.config.n_dim // self.config.n_head,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
|
input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
|
||||||
cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
|
cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue