diff --git a/khaosz/inference/core.py b/khaosz/inference/core.py
index f019f86..e372cc4 100644
--- a/khaosz/inference/core.py
+++ b/khaosz/inference/core.py
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
-from typing import List, Tuple, Union, Optional, Self
-from khaosz.config.param_config import ModelParameter
+from typing import Any, Callable, List, Tuple, Union, Optional, Self
+from khaosz.config import ModelParameter, TransformerConfig
 
 
 def apply_sampling_strategies(
@@ -86,6 +86,36 @@ class GeneratorCore:
         self.model.to(*args, **kargs)
         return self
 
+    def generate_loop(
+        self,
+        input_ids: Tensor,
+        ids: List[int],
+        temperature: float,
+        top_k: int,
+        top_p: float,
+        attn_mask: Optional[Tensor] = None,
+        kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
+        start_pos: int = 0,
+        callback: Optional[Callable[..., Any]] = None
+    ) -> List[int]:
+        cur_cache_pos = start_pos
+
+        for _ in range(len(ids), self.config.m_len):
+            next_token_id, cache_increase = self.generate_iterator(
+                input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos)
+
+            input_ids = next_token_id
+            ids.append(next_token_id.item())
+            cur_cache_pos += cache_increase
+
+            if callback:
+                callback(next_token_id.item(), ids.copy())
+
+            if next_token_id.item() in self.tokenizer.stop_ids:
+                break
+
+        return ids
+
 
 class EmbeddingEncoderCore:
     def __init__(self, parameter: ModelParameter):
@@ -157,21 +187,18 @@ class EmbeddingEncoderCore:
 class KVCacheManager:
     def __init__(
         self,
-        num_layers: int,
+        config: TransformerConfig,
         batch_size: int,
-        max_len: int,
-        num_heads: int,
-        head_dim: int,
         device: torch.device = "cuda",
         dtype: torch.dtype = torch.bfloat16
     ):
-        self.num_layers = num_layers
         self.batch_size = batch_size
-        self.max_len = max_len
-        self.num_heads = num_heads
-        self.head_dim = head_dim
         self.device = device
         self.dtype = dtype
+        self.num_layers = config.n_layer
+        self.max_len = config.m_len
+        self.num_heads = config.n_kvhead
+        self.head_dim = config.n_dim // config.n_head
 
         self._kv_cache: Tuple[Tensor, Tensor] = None
         self._seq_mask: Tensor = None
diff --git a/khaosz/inference/generator.py b/khaosz/inference/generator.py
index cbd404f..8e19c81 100644
--- a/khaosz/inference/generator.py
+++ b/khaosz/inference/generator.py
@@ -72,14 +72,7 @@ class TextGenerator(GeneratorCore):
         assert top_p >= 0.0 and top_p <= 1.0
 
         device = next(self.model.parameters()).device
-        cache_manager = KVCacheManager(
-            num_layers=self.config.n_layer,
-            batch_size=1,
-            max_len=self.config.m_len,
-            num_heads=self.config.n_kvhead,
-            head_dim=self.config.n_dim // self.config.n_head,
-            device=device,
-        )
+        cache_manager = KVCacheManager(self.config, 1, device=device)
 
         ids = self.tokenizer.encode(query)
         input_ids = torch.tensor([ids], device=device, dtype=torch.long)
@@ -89,16 +82,11 @@ class TextGenerator(GeneratorCore):
         self.model.eval()
         kv_caches = cache_manager.get_kvcache()
 
-        for _ in range(len(ids), self.config.m_len):
-            next_token_id, cache_increase = self.generate_iterator(
-                input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
-
-            input_ids = next_token_id
-            ids.append(next_token_id.item())
-            cur_cache_pos += cache_increase
-
-            if next_token_id.item() in self.tokenizer.stop_ids:
-                break
+        ids = self.generate_loop(
+            input_ids, ids, temperature, top_k, top_p,
+            kv_caches=kv_caches,
+            start_pos=cur_cache_pos
+        )
 
         response = self.tokenizer.decode(ids[start_cache_pos:])
 
@@ -126,14 +114,8 @@ class ChatGenerator(GeneratorCore):
             history = []
 
         device = next(self.model.parameters()).device
-        cache_manager = KVCacheManager(
-            num_layers=self.config.n_layer,
-            batch_size=1,
-            max_len=self.config.m_len,
-            num_heads=self.config.n_kvhead,
-            head_dim=self.config.n_dim // self.config.n_head,
-            device=device,
-        )
+        cache_manager = KVCacheManager(self.config, 1, device=device)
+
         ids = self.tokenizer.encode(build_prompt(query, history))
         input_ids = torch.tensor([ids], device=device, dtype=torch.long)
         cpy_history = history.copy()
@@ -143,17 +125,12 @@ class ChatGenerator(GeneratorCore):
         self.model.eval()
         kv_caches = cache_manager.get_kvcache()
 
-        for _ in range(len(ids), self.config.m_len):
-            next_token_id, cache_increase = self.generate_iterator(
-                input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
-
-            input_ids = next_token_id
-            ids.append(next_token_id.item())
-            cur_cache_pos += cache_increase
-
-            if next_token_id.item() in self.tokenizer.stop_ids:
-                break
-
+        ids = self.generate_loop(
+            input_ids, ids, temperature, top_k, top_p,
+            kv_caches=kv_caches,
+            start_pos=cur_cache_pos
+        )
+
         response = self.tokenizer.decode(ids[start_cache_pos:])
         cpy_history.append((query, response))
 
@@ -181,14 +158,8 @@ class StreamGenerator(GeneratorCore):
             history = []
 
         device = next(self.model.parameters()).device
-        cache_manager = KVCacheManager(
-            num_layers=self.config.n_layer,
-            batch_size=1,
-            max_len=self.config.m_len,
-            num_heads=self.config.n_kvhead,
-            head_dim=self.config.n_dim // self.config.n_head,
-            device=device,
-        )
+        cache_manager = KVCacheManager(self.config, 1, device=device)
+
         ids = self.tokenizer.encode(build_prompt(query, history))
         input_ids = torch.tensor([ids], device=device, dtype=torch.long)
         cpy_history = history.copy()
@@ -241,14 +212,7 @@ class BatchGenerator(GeneratorCore):
         ids_list = pad_sequence(ids_list, max_ids_len, self.tokenizer.pad_id)
 
         device = next(self.model.parameters()).device
-        cache_manager = KVCacheManager(
-            num_layers=self.config.n_layer,
-            batch_size=batch_size,
-            max_len=self.config.m_len,
-            num_heads=self.config.n_kvhead,
-            head_dim=self.config.n_dim // self.config.n_head,
-            device=device,
-        )
+        cache_manager = KVCacheManager(self.config, batch_size, device=device)
 
         input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
         cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
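For reviewers, a minimal usage sketch of the refactored API. Only the `KVCacheManager(config, batch_size, ...)` and `generate_loop(...)` signatures come from this diff; the `generator` object, its setup, and the sampling values below are hypothetical stand-ins:

```python
import torch

# Hypothetical setup: `generator` is any GeneratorCore subclass loaded elsewhere.
device = next(generator.model.parameters()).device

# KVCacheManager now derives num_layers / max_len / num_heads / head_dim from
# the TransformerConfig instead of taking them as five separate arguments.
cache_manager = KVCacheManager(generator.config, 1, device=device)

ids = generator.tokenizer.encode("hello")
input_ids = torch.tensor([ids], device=device, dtype=torch.long)

# The new optional callback receives each sampled token id plus a copy of the
# running id list on every step, which is the hook a streaming caller would use.
ids = generator.generate_loop(
    input_ids, ids,
    temperature=0.8, top_k=50, top_p=0.95,
    kv_caches=cache_manager.get_kvcache(),
    callback=lambda token_id, all_ids: print(token_id),
)
```

Note that the `callback` parameter is unused by the `TextGenerator`/`ChatGenerator` call sites shown in this diff; it appears intended for `StreamGenerator`-style consumers.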