refactor: optimize the interface design, remove redundant code
parent e23a5ca426
commit 62fba9a298
@@ -12,7 +12,7 @@ def batch_generate():
     inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"]
 
     responses = model.batch_generate(
-        queries=inputs,
+        query=inputs,
         temperature=0.8,
         top_p=0.95,
         top_k=50
@@ -26,9 +26,10 @@ if __name__ == "__main__":
     query = "作者设计了一个怎样的模型"
     emb_query = model.encode(query)
    retrieved = retriever.retrieve(emb_query, retrive_top_k)
+    retrieved_content = "\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)])
 
     retrive_response = model.retrieve_generate(
-        retrieved=retrieved,
+        retrieved=retrieved_content,
         query=query,
         temperature=0.8,
         top_p=0.95,
@@ -36,7 +37,7 @@ if __name__ == "__main__":
     )
 
     print("retrieve content:")
-    print("\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)]))
+    print(retrieved_content)
 
     print("\n\nretrive generate:")
     print(retrive_response)
@@ -17,12 +17,11 @@ from khaosz.data import (
     BpeTokenizer
 )
 from khaosz.inference.generator import (
-    TextGenerator,
-    ChatGenerator,
-    StreamGenerator,
-    BatchGenerator,
-    RetrievalGenerator,
-    EmbeddingEncoder
+    GenerationRequest,
+    LoopGenerator,
+    StreamGenerator,
+    BatchGenerator,
+    EmbeddingEncoder,
 )
 
 from khaosz.trainer import (
@@ -46,11 +45,10 @@ __all__ = [
     "DatasetLoader",
     "BpeTokenizer",
 
-    "TextGenerator",
-    "ChatGenerator",
+    "GenerationRequest",
+    "LoopGenerator",
     "StreamGenerator",
     "BatchGenerator",
-    "RetrievalGenerator",
     "EmbeddingEncoder",
 
     "Trainer",
khaosz/api.py
@@ -1,21 +1,42 @@
+from torch import nn
 from torch import Tensor
+from contextlib import contextmanager
 from typing import List, Tuple, Generator, Union
 
 from khaosz.inference.generator import (
-    TextGenerator,
-    ChatGenerator,
+    GenerationRequest,
+    LoopGenerator,
     StreamGenerator,
     BatchGenerator,
-    RetrievalGenerator,
     EmbeddingEncoder
 )
 from khaosz.config.param_config import ModelParameter
 
+@contextmanager
+def disable_random_init():
+    init_functions = [
+        'xavier_normal_', 'xavier_uniform_',
+        'kaiming_normal_', 'kaiming_uniform_',
+        'zeros_', 'ones_', 'constant_',
+        'normal_', 'uniform_'
+    ]
+    original_funcs = {}
+    for name in init_functions:
+        if hasattr(nn.init, name):
+            original_funcs[name] = getattr(nn.init, name)
+            setattr(nn.init, name, lambda *args, **kwargs: None)
+    try:
+        yield
+    finally:
+        for name, orig_func in original_funcs.items():
+            setattr(nn.init, name, orig_func)
+
 
 class Khaosz:
     def __init__(self, model_dir: str):
-        self.parameter = ModelParameter()
-        self.parameter.load(model_dir)
+        with disable_random_init():
+            self.parameter = ModelParameter()
+            self.parameter.load(model_dir)
 
     def to(self, *args, **kwargs):
         self.parameter.to(*args, **kwargs)
@@ -29,32 +50,33 @@ class Khaosz:
         top_k: int=50,
         top_p: float=0.95,
     ) -> str:
-        generator = ChatGenerator(self.parameter)
+        generator = LoopGenerator(self.parameter)
         return generator.generate(
-            query,
-            history=history,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
+            GenerationRequest(
+                top_k, top_p, temperature,
+                self.parameter.config.max_len,
+                query=query,
+                history=history,
+                build_prompt=True
+            ))
 
     def batch_generate(
         self,
-        queries: List[str],
-        histories: List[Tuple[str, str]]=None,
+        query: List[str],
+        history: List[Tuple[str, str]]=None,
         temperature: float=0.8,
         top_k: int=50,
         top_p: float=0.95,
     ) -> List[str]:
         generator = BatchGenerator(self.parameter)
         return generator.generate(
-            queries,
-            histories=histories,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
+            GenerationRequest(
+                top_k, top_p, temperature,
+                self.parameter.config.max_len,
+                query=query,
+                history=history,
+                build_prompt=True
+            ))
 
     def stream_generate(
         self,
@@ -66,12 +88,13 @@ class Khaosz:
     ) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
         stream_generator = StreamGenerator(self.parameter)
         return stream_generator.generate(
-            query,
-            history=history,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
+            GenerationRequest(
+                top_k, top_p, temperature,
+                self.parameter.config.max_len,
+                query=query,
+                history=history,
+                build_prompt=True
+            ))
 
     def retrieve_generate(
         self,
@@ -82,15 +105,16 @@ class Khaosz:
         top_k: int=50,
         top_p: float=0.95,
     ) -> str:
-        generator = RetrievalGenerator(self.parameter)
+        generator = LoopGenerator(self.parameter)
         return generator.generate(
-            retrieved,
-            query,
-            history=history,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
+            GenerationRequest(
+                top_k, top_p, temperature,
+                self.parameter.config.max_len,
+                query=query,
+                history=history,
+                system_prompt=retrieved,
+                build_prompt=True
+            ))
 
     def text_generate(
         self,
@@ -99,14 +123,15 @@ class Khaosz:
         top_k: int=50,
         top_p: float=0.95,
     ) -> str:
-        generator = TextGenerator(self.parameter)
-
-        return generator.generate(
-            query,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
+        generator = LoopGenerator(self.parameter)
+        return generator.generate(
+            GenerationRequest(
+                top_k, top_p, temperature,
+                self.parameter.config.max_len,
+                query=query,
+                build_prompt=False
+            ))
+
 
     def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
         encoder = EmbeddingEncoder(self.parameter)
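Note: a minimal usage sketch of the refactored `khaosz/api.py` facade after this commit. The checkpoint path is a placeholder assumption; the keyword names follow the new `query`/`history` convention shown above.

```python
from khaosz.api import Khaosz

# Load a model directory (placeholder path); weights are loaded with
# random initialization disabled via disable_random_init().
model = Khaosz("checkpoints/khaosz")
model.to("cuda")

# Batched generation now takes `query` (a list) instead of the old `queries`;
# sampling settings are bundled into a GenerationRequest internally.
replies = model.batch_generate(
    query=["你好", "请问什么是人工智能"],
    temperature=0.8,
    top_k=50,
    top_p=0.95,
)
```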
@@ -1 +1,25 @@
-# init file
+from khaosz.inference.core import (
+    GeneratorCore,
+    EmbeddingEncoderCore,
+    KVCacheManager,
+)
+
+from khaosz.inference.generator import (
+    GenerationRequest,
+    LoopGenerator,
+    StreamGenerator,
+    BatchGenerator,
+    EmbeddingEncoder,
+)
+
+__all__ = [
+    "GeneratorCore",
+    "EmbeddingEncoderCore",
+    "KVCacheManager",
+
+    "GenerationRequest",
+    "LoopGenerator",
+    "StreamGenerator",
+    "BatchGenerator",
+    "EmbeddingEncoder",
+]
@@ -1,98 +0,0 @@
-import torch
-from torch import Tensor
-from functools import wraps
-from inspect import signature
-
-
-class CudaGraphWrapper:
-    def __init__(self, function, device="cuda", cast=False):
-        self.function = function
-        self.cast = cast
-        self.device = device
-        self.static_input = None
-        self.static_output = None
-        self.graph = None
-        self.signature = signature(function)
-
-    def _update_inplace(self, lhs, rhs):
-        if isinstance(lhs, Tensor) and isinstance(rhs, Tensor):
-            if lhs.shape != rhs.shape:
-                raise ValueError(
-                    f"Tensor shape mismatch! "
-                    f"Expected: {lhs.shape}, Got: {rhs.shape}. "
-                    f"Function: {self.function}"
-                )
-            if self.cast:
-                if lhs.device != rhs.device:
-                    rhs = rhs.to(device=lhs.device)
-
-                if lhs.dtype != rhs.dtype:
-                    rhs = rhs.to(dtype=lhs.dtype)
-            else:
-                if lhs.device != rhs.device:
-                    raise ValueError(
-                        f"Tensor device mismatch! "
-                        f"Expected: {lhs.device}, Got: {rhs.device}. "
-                        f"Function: {self.function}"
-                    )
-                if lhs.dtype != rhs.dtype:
-                    raise ValueError(
-                        f"Tensor dtype mismatch! "
-                        f"Expected: {lhs.dtype}, Got: {rhs.dtype}. "
-                        f"Function: {self.function}"
-                    )
-            lhs.copy_(rhs)
-        elif isinstance(lhs, dict):
-            for k in lhs:
-                if k in rhs:
-                    self._update_inplace(lhs[k], rhs[k])
-        elif isinstance(lhs, (list, tuple)):
-            for i in range(len(lhs)):
-                if i < len(rhs):
-                    self._update_inplace(lhs[i], rhs[i])
-        elif isinstance(lhs, (int, float, bool, str, type(None))):
-            if lhs != rhs:
-                raise ValueError("Does not support changing control parameters.")
-
-    def _update_args(self, input_args, input_kwargs):
-        bound_args = self.signature.bind(*input_args, **input_kwargs)
-        bound_args.apply_defaults()
-        args_dict = bound_args.arguments
-
-        if self.static_input is None:
-            self.static_input = args_dict
-        else:
-            self._update_inplace(self.static_input, args_dict)
-
-    def run(self, *args, **kwargs):
-        self._update_args(args, kwargs)
-
-        if self.graph is None:
-            # warmup
-            _ = torch.matmul(
-                torch.randn(100, 100, device=self.device),
-                torch.randn(100, 100, device=self.device)
-            )
-            torch.cuda.synchronize()
-
-            # capture graph
-            self.graph = torch.cuda.CUDAGraph()
-            with torch.cuda.graph(self.graph):
-                self.static_output = self.function(**self.static_input)
-
-        self.graph.replay()
-
-        return self.static_output
-
-
-def cuda_graph(device="cuda", cast=False):
-    def decorator(func):
-        wrapper = CudaGraphWrapper(func, device, cast)
-
-        @wraps(func)
-        def wrapped(*args, **kwargs):
-            return wrapper.run(*args, **kwargs)
-
-        return wrapped
-
-    return decorator
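Note: the file deleted above provided a `CudaGraphWrapper` decorator. For reference, a minimal sketch of the underlying PyTorch CUDA-graph capture/replay pattern it automated (the module and tensor shapes are illustrative assumptions, not from this repo):

```python
import torch

lin = torch.nn.Linear(16, 16, device="cuda")
static_x = torch.randn(8, 16, device="cuda")

# Warm up on a side stream before capture, as CUDA graphs require.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        lin(static_x)
torch.cuda.current_stream().wait_stream(s)

# Capture the forward pass once, then replay it with inputs updated in place.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y = lin(static_x)

static_x.copy_(torch.randn(8, 16, device="cuda"))  # overwrite the static input
g.replay()  # static_y now holds the output for the new input
```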
@@ -1,10 +1,13 @@
 import torch
+from dataclasses import dataclass
 from torch import Tensor
 from typing import List, Tuple, Union, Optional, Generator
 from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
 from khaosz.config.param_config import ModelParameter
 
 
+HistoryType = List[Tuple[str, str]]
+
 def build_prompt(
     query: str,
     init_prompt: Optional[str] = None,
@@ -34,7 +37,7 @@ def build_prompt(
 
     return prompt
 
-def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]:
+def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
     """
     Pad a list of sequences to a fixed length.
 
@@ -47,34 +50,47 @@ def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]:
         List[List[int]]: A list of padded sequences.
 
     """
+    max_ids_len = max(len(ids) for ids in ids_list)
     new_ids_list = []
     for ids in ids_list:
         pad_len = max_ids_len - len(ids)
         padded_seq = [pad_id] * pad_len + ids
         new_ids_list.append(padded_seq)
 
-    return new_ids_list
+    return new_ids_list, max_ids_len
+
+@dataclass
+class GenerationRequest:
+    top_k: int
+    top_p: float
+    temperature: float
+    max_len: int
+
+    query: Union[str, List[str]]
+    history: Optional[Union[HistoryType, List[HistoryType]]] = None
+    system_prompt: Optional[str] = None
+
+    build_prompt: bool = True
+
+    def __post_init__(self):
+        if not isinstance(self.top_k, int) or self.top_k < 0:
+            raise ValueError("top_k must be a non-negative integer")
+        if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0:
+            raise ValueError("top_p must be a float between 0.0 and 1.0")
+        if not isinstance(self.temperature, float) or self.temperature < 0.0:
+            raise ValueError("temperature must be a non-negative float")
 
 
-class TextGenerator(GeneratorCore):
+class LoopGenerator(GeneratorCore):
     def __init__(self, parameter: ModelParameter):
         super().__init__(parameter)
 
-    def generate(
-        self,
-        query: str,
-        temperature: float,
-        top_k: int,
-        top_p: float,
-    ) -> str:
-        assert temperature >= 0.0
-        assert top_k >= 0
-        assert top_p >= 0.0 and top_p <= 1.0
-
+    def generate(self, request: GenerationRequest) -> str:
         device = next(self.model.parameters()).device
         cache_manager = KVCacheManager(self.config, 1, device=device)
 
-        ids = self.tokenizer.encode(query)
+        input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query
+        ids = self.tokenizer.encode(input_args)
         input_ids = torch.tensor([ids], device=device, dtype=torch.long)
 
         start_cache_pos = len(ids)
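Note: a quick illustration of the new `pad_sequence` contract and `GenerationRequest` validation introduced above (all values are arbitrary examples; `max_len=1024` is an assumed config value):

```python
# pad_sequence now computes the batch maximum itself and returns it alongside
# the left-padded sequences.
padded, max_len = pad_sequence([[5, 6, 7], [9]], pad_id=0)
# padded  == [[5, 6, 7], [0, 0, 9]]
# max_len == 3

request = GenerationRequest(
    top_k=50, top_p=0.95, temperature=0.8, max_len=1024,
    query="你好",
)
# Out-of-range sampling settings fail fast in __post_init__, e.g. top_p=1.5
# raises: ValueError: top_p must be a float between 0.0 and 1.0
```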
@@ -83,84 +99,30 @@ class TextGenerator(GeneratorCore):
         kv_caches = cache_manager.get_kvcache()
 
         ids = self.generate_loop(
-            input_ids, ids, temperature, top_k, top_p,
+            input_ids, ids, request.temperature, request.top_k, request.top_p,
             kv_caches=kv_caches,
             start_pos=cur_cache_pos
         )
 
         response = self.tokenizer.decode(ids[start_cache_pos:])
 
         return response
 
 
-class ChatGenerator(GeneratorCore):
-    def __init__(self, parameter: ModelParameter):
-        super().__init__(parameter)
-
-    def generate(
-        self,
-        query: str,
-        history: List[Tuple[str, str]],
-        temperature: float,
-        top_k: int,
-        top_p: float,
-    ) -> str:
-
-        assert temperature >= 0.0
-        assert top_k >= 0
-        assert top_p >= 0.0 and top_p <= 1.0
-
-        if history is None:
-            history = []
-
-        device = next(self.model.parameters()).device
-        cache_manager = KVCacheManager(self.config, 1, device=device)
-
-        ids = self.tokenizer.encode(build_prompt(query, history))
-        input_ids = torch.tensor([ids], device=device, dtype=torch.long)
-
-        start_cache_pos = len(ids)
-        cur_cache_pos = 0
-        self.model.eval()
-        kv_caches = cache_manager.get_kvcache()
-
-        ids = self.generate_loop(
-            input_ids, ids, temperature, top_k, top_p,
-            kv_caches=kv_caches,
-            start_pos=cur_cache_pos
-        )
-
-        response = self.tokenizer.decode(ids[start_cache_pos:])
-
-        return response
-
-
 class StreamGenerator(GeneratorCore):
     def __init__(self, parameter: ModelParameter):
         super().__init__(parameter)
 
-    def generate(
-        self,
-        query: str,
-        history: List[Tuple[str, str]],
-        temperature: float,
-        top_k: int,
-        top_p: float,
-    ) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
-
-        assert temperature >= 0.0
-        assert top_k >= 0
-        assert top_p >= 0.0 and top_p <= 1.0
+    def generate(self, request: GenerationRequest) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
 
-        if history is None:
-            history = []
+        if request.history is None:
+            request.history = []
 
         device = next(self.model.parameters()).device
         cache_manager = KVCacheManager(self.config, 1, device=device)
 
-        ids = self.tokenizer.encode(build_prompt(query, history))
+        input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query
+        ids = self.tokenizer.encode(input_args)
         input_ids = torch.tensor([ids], device=device, dtype=torch.long)
-        cpy_history = history.copy()
 
         start_cache_pos = len(ids)
         cur_cache_pos = 0
@@ -169,45 +131,36 @@ class StreamGenerator(GeneratorCore):
 
         for _ in range(len(ids), self.config.max_len):
             next_token_id, cache_increase = self.generate_iterator(
-                input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
+                input_ids, request.temperature, request.top_k, request.top_p,
+                kv_caches=kv_caches,
+                start_pos=cur_cache_pos
+            )
 
             input_ids = next_token_id
             ids.append(next_token_id.item())
             cur_cache_pos += cache_increase
 
             response = self.tokenizer.decode(ids[start_cache_pos:])
-            yield response, cpy_history + [(query, response)]
+            yield response, request.history + [(request.query, response)]
 
             if next_token_id.item() in self.tokenizer.stop_ids:
-                yield response + "\n", cpy_history + [(query, response)]
+                yield response + "\n", request.history + [(request.query, response)]
                 break
 
 
 class BatchGenerator(GeneratorCore):
     def __init__(self, parameter: ModelParameter):
         super().__init__(parameter)
 
-    def generate(
-        self,
-        queries: List[str],
-        histories: List[List[Tuple[str, str]]],
-        temperature: float,
-        top_k: int,
-        top_p: float
-    ) -> List[str]:
-
-        assert temperature >= 0.0
-        assert top_k >= 0
-        assert top_p >= 0.0 and top_p <= 1.0
-
-        batch_size = len(queries)
-        if histories is None:
-            histories = [[] for _ in range(batch_size)]
-
-        prompts = [build_prompt(query, history) for query, history in zip(queries, histories)]
+    def generate(self, request: GenerationRequest) -> List[str]:
+        batch_size = len(request.query)
+        if request.history is None:
+            request.history = [[] for _ in range(batch_size)]
+
+        prompts = [build_prompt(query, history) for query, history in zip(request.query, request.history)]
+
         ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
-        max_ids_len = max(len(ids) for ids in ids_list)
-        ids_list = pad_sequence(ids_list, max_ids_len, self.tokenizer.pad_id)
+        ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)
 
         device = next(self.model.parameters()).device
         cache_manager = KVCacheManager(self.config, batch_size, device=device)
@@ -224,7 +177,11 @@ class BatchGenerator(GeneratorCore):
             attn_mask = cache_manager.get_seq_mask()
 
             next_token_id, cache_increase = self.generate_iterator(
-                input_tensor, temperature, top_k, top_p, attn_mask=attn_mask, kv_caches=kv_caches, start_pos=cur_cache_pos)
+                input_tensor, request.temperature, request.top_k, request.top_p,
+                attn_mask=attn_mask,
+                kv_caches=kv_caches,
+                start_pos=cur_cache_pos
+            )
 
             cur_cache_pos += cache_increase
             active_mask = []
@@ -246,47 +203,14 @@ class BatchGenerator(GeneratorCore):
 
             max_ids_len += 1
 
 
         responses = [str()] * batch_size
         for i in range(batch_size):
             responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:])
-            histories[i].append((queries[i], responses[i]))
+            request.history[i].append((request.query[i], responses[i]))
 
         return responses
 
 
-class RetrievalGenerator(GeneratorCore):
-    def __init__(self, retriever_parameter: ModelParameter):
-        super().__init__(retriever_parameter)
-
-    def generate(
-        self,
-        retrieved: List[str],
-        query: str,
-        history: List[Tuple[str, str]],
-        temperature: float,
-        top_k: int,
-        top_p: float,
-    ) -> str:
-        assert temperature >= 0.0
-        assert top_k >= 0
-        assert top_p >= 0.0 and top_p <= 1.0
-
-        if history is None:
-            history = []
-
-        retrieved = "\n".join([f"{idx + 1}. {key}" for idx, key in enumerate(retrieved)]) if retrieved else ""
-        retrieved_query = f"{retrieved}\n\n{query}" if retrieved else query
-        parameter = ModelParameter(self.model, self.tokenizer, self.config)
-
-        return ChatGenerator(parameter).generate(
-            retrieved_query,
-            history,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
-
 class EmbeddingEncoder(EmbeddingEncoderCore):
     def __init__(self, parameter: ModelParameter):
         super().__init__(parameter)
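Note: with the hunks above applied, every generator shares one request-driven entry point. A minimal sketch of calling a generator directly, mirroring what the `Khaosz` facade does internally (the checkpoint path is a placeholder assumption):

```python
parameter = ModelParameter()
parameter.load("checkpoints/khaosz")  # placeholder checkpoint directory

generator = BatchGenerator(parameter)
responses = generator.generate(GenerationRequest(
    top_k=50,
    top_p=0.95,
    temperature=0.8,
    max_len=parameter.config.max_len,
    query=["你好", "今天天气如何"],
    build_prompt=True,
))
```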
@@ -13,35 +13,35 @@ PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
 
 def batch_generate(
     model: Khaosz,
-    queries: List[str],
+    query: List[str],
     temperature: float,
     top_k: int,
     top_p: float,
     batch_size: int,
 ) -> List:
     assert batch_size > 0
-    sorted_queries = sorted(queries, key=lambda x: len(x), reverse=True)
-    original_indices = {query: idx for idx, query in enumerate(queries)}
+    sorted_query = sorted(query, key=lambda x: len(x), reverse=True)
+    original_indices = {query: idx for idx, query in enumerate(query)}
 
-    responses = [None] * len(queries)
-    total_batches = (len(sorted_queries) + batch_size - 1) // batch_size
+    responses = [None] * len(query)
+    total_batches = (len(sorted_query) + batch_size - 1) // batch_size
 
     for i in tqdm(range(0, total_batches * batch_size, batch_size), desc="Generating responses"):
-        batch_queries = sorted_queries[i: min(i + batch_size, len(queries))]
-        if not isinstance(batch_queries, list):
-            batch_queries = [batch_queries]
+        batch_query = sorted_query[i: min(i + batch_size, len(query))]
+        if not isinstance(batch_query, list):
+            batch_query = [batch_query]
 
         batch_responses = model.batch_generate(
-            queries=batch_queries,
+            query=batch_query,
             temperature=temperature,
             top_k=top_k,
             top_p=top_p
         )
 
-        for batch_query, batch_response in zip(batch_queries, batch_responses):
+        for batch_query, batch_response in zip(batch_query, batch_responses):
             print(f"Q: {batch_query[:50]} \nR: {batch_response[:50]})")
 
-        for query, response in zip(batch_queries, batch_responses):
+        for query, response in zip(batch_query, batch_responses):
             original_idx = original_indices[query]
             responses[original_idx] = response
 
@@ -60,11 +60,11 @@ def processor(
 ):
     with open(input_json_file, "r", encoding='utf-8') as f:
         input_dict = [json.loads(line) for line in f]
-    queries = [item[question_key] for item in input_dict]
+    query = [item[question_key] for item in input_dict]
 
     output_dict = batch_generate(
         model=model,
-        queries=queries,
+        query=query,
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,