refactor: optimize interface setup and remove redundant code

commit 62fba9a298
parent e23a5ca426
Author: ViperEkura
Date: 2026-03-18 15:07:35 +08:00

8 changed files with 172 additions and 298 deletions

View File

@@ -12,7 +12,7 @@ def batch_generate():
     inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"]
     responses = model.batch_generate(
-        queries=inputs,
+        query=inputs,
         temperature=0.8,
         top_p=0.95,
         top_k=50

View File

@@ -26,9 +26,10 @@ if __name__ == "__main__":
     query = "作者设计了一个怎样的模型"
     emb_query = model.encode(query)
     retrieved = retriever.retrieve(emb_query, retrive_top_k)
+    retrieved_content = "\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)])
     retrive_response = model.retrieve_generate(
-        retrieved=retrieved,
+        retrieved=retrieved_content,
         query=query,
         temperature=0.8,
         top_p=0.95,
@@ -36,7 +37,7 @@ if __name__ == "__main__":
     )
     print("retrieve content:")
-    print("\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)]))
+    print(retrieved_content)
     print("\n\nretrive generate:")
     print(retrive_response)

View File

@@ -17,12 +17,11 @@ from khaosz.data import (
     BpeTokenizer
 )
 from khaosz.inference.generator import (
-    TextGenerator,
-    ChatGenerator,
+    GenerationRequest,
+    LoopGenerator,
     StreamGenerator,
     BatchGenerator,
-    RetrievalGenerator,
-    EmbeddingEncoder
+    EmbeddingEncoder,
 )
 from khaosz.trainer import (
@@ -46,11 +45,10 @@ __all__ = [
     "DatasetLoader",
     "BpeTokenizer",
-    "TextGenerator",
-    "ChatGenerator",
+    "GenerationRequest",
+    "LoopGenerator",
     "StreamGenerator",
     "BatchGenerator",
-    "RetrievalGenerator",
     "EmbeddingEncoder",
     "Trainer",

View File

@@ -1,21 +1,42 @@
+from torch import nn
 from torch import Tensor
+from contextlib import contextmanager
 from typing import List, Tuple, Generator, Union
 from khaosz.inference.generator import (
-    TextGenerator,
-    ChatGenerator,
+    GenerationRequest,
+    LoopGenerator,
     StreamGenerator,
     BatchGenerator,
-    RetrievalGenerator,
     EmbeddingEncoder
 )
 from khaosz.config.param_config import ModelParameter
+
+
+@contextmanager
+def disable_random_init():
+    init_functions = [
+        'xavier_normal_', 'xavier_uniform_',
+        'kaiming_normal_', 'kaiming_uniform_',
+        'zeros_', 'ones_', 'constant_',
+        'normal_', 'uniform_'
+    ]
+    original_funcs = {}
+    for name in init_functions:
+        if hasattr(nn.init, name):
+            original_funcs[name] = getattr(nn.init, name)
+            setattr(nn.init, name, lambda *args, **kwargs: None)
+    try:
+        yield
+    finally:
+        for name, orig_func in original_funcs.items():
+            setattr(nn.init, name, orig_func)
+
+
 class Khaosz:
     def __init__(self, model_dir: str):
-        self.parameter = ModelParameter()
-        self.parameter.load(model_dir)
+        with disable_random_init():
+            self.parameter = ModelParameter()
+            self.parameter.load(model_dir)

     def to(self, *args, **kwargs):
         self.parameter.to(*args, **kwargs)
@@ -29,32 +50,33 @@ class Khaosz:
         top_k: int=50,
         top_p: float=0.95,
     ) -> str:
-        generator = ChatGenerator(self.parameter)
+        generator = LoopGenerator(self.parameter)
         return generator.generate(
-            query,
-            history=history,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
+            GenerationRequest(
+                top_k, top_p, temperature,
+                self.parameter.config.max_len,
+                query=query,
+                history=history,
+                build_prompt=True
+        ))

     def batch_generate(
         self,
-        queries: List[str],
-        histories: List[Tuple[str, str]]=None,
+        query: List[str],
+        history: List[Tuple[str, str]]=None,
         temperature: float=0.8,
         top_k: int=50,
         top_p: float=0.95,
     ) -> List[str]:
         generator = BatchGenerator(self.parameter)
         return generator.generate(
-            queries,
-            histories=histories,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
+            GenerationRequest(
+                top_k, top_p, temperature,
+                self.parameter.config.max_len,
+                query=query,
+                history=history,
+                build_prompt=True
+        ))

     def stream_generate(
         self,
@@ -66,12 +88,13 @@ class Khaosz:
     ) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
         stream_generator = StreamGenerator(self.parameter)
         return stream_generator.generate(
-            query,
-            history=history,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
+            GenerationRequest(
+                top_k, top_p, temperature,
+                self.parameter.config.max_len,
+                query=query,
+                history=history,
+                build_prompt=True
+        ))

     def retrieve_generate(
         self,
@@ -82,15 +105,16 @@ class Khaosz:
         top_k: int=50,
         top_p: float=0.95,
     ) -> str:
-        generator = RetrievalGenerator(self.parameter)
+        generator = LoopGenerator(self.parameter)
         return generator.generate(
-            retrieved,
-            query,
-            history=history,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
+            GenerationRequest(
+                top_k, top_p, temperature,
+                self.parameter.config.max_len,
+                query=query,
+                history=history,
+                system_prompt=retrieved,
+                build_prompt=True
+        ))

     def text_generate(
         self,
@@ -99,14 +123,15 @@ class Khaosz:
         top_k: int=50,
         top_p: float=0.95,
     ) -> str:
-        generator = TextGenerator(self.parameter)
-
-        return generator.generate(
-            query,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
+        generator = LoopGenerator(self.parameter)
+        return generator.generate(
+            GenerationRequest(
+                top_k, top_p, temperature,
+                self.parameter.config.max_len,
+                query=query,
+                build_prompt=False
+        ))

     def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
         encoder = EmbeddingEncoder(self.parameter)
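
Taken together, these hunks change how the public Khaosz facade is called. Below is a minimal usage sketch under the new interface; the method names and keyword arguments are taken from this diff, while the model directory path and device are placeholders that are not part of the commit:

from khaosz import Khaosz

model = Khaosz("model_files/")   # placeholder path; weights now load under disable_random_init()
model.to("cuda")                 # forwarded to ModelParameter.to(*args, **kwargs)

# batched generation: the keyword is now `query` (previously `queries`)
responses = model.batch_generate(
    query=["你好", "请问什么是人工智能"],
    temperature=0.8,
    top_k=50,
    top_p=0.95,
)

# retrieval-augmented generation: `retrieved` is now passed through as the system prompt
answer = model.retrieve_generate(
    retrieved="1. 检索到的文档内容",
    query="作者设计了一个怎样的模型",
    temperature=0.8,
    top_k=50,
    top_p=0.95,
)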

View File

@@ -1 +1,25 @@
-# init file
+from khaosz.inference.core import (
+    GeneratorCore,
+    EmbeddingEncoderCore,
+    KVCacheManager,
+)
+from khaosz.inference.generator import (
+    GenerationRequest,
+    LoopGenerator,
+    StreamGenerator,
+    BatchGenerator,
+    EmbeddingEncoder,
+)
+
+
+__all__ = [
+    "GeneratorCore",
+    "EmbeddingEncoderCore",
+    "KVCacheManager",
+    "GenerationRequest",
+    "LoopGenerator",
+    "StreamGenerator",
+    "BatchGenerator",
+    "EmbeddingEncoder",
+]

View File

@@ -1,98 +0,0 @@
-import torch
-from torch import Tensor
-from functools import wraps
-from inspect import signature
-
-
-class CudaGraphWrapper:
-    def __init__(self, function, device="cuda", cast=False):
-        self.function = function
-        self.cast = cast
-        self.device = device
-        self.static_input = None
-        self.static_output = None
-        self.graph = None
-        self.signature = signature(function)
-
-    def _update_inplace(self, lhs, rhs):
-        if isinstance(lhs, Tensor) and isinstance(rhs, Tensor):
-            if lhs.shape != rhs.shape:
-                raise ValueError(
-                    f"Tensor shape mismatch! "
-                    f"Expected: {lhs.shape}, Got: {rhs.shape}. "
-                    f"Function: {self.function}"
-                )
-            if self.cast:
-                if lhs.device != rhs.device:
-                    rhs = rhs.to(device=lhs.device)
-                if lhs.dtype != rhs.dtype:
-                    rhs = rhs.to(dtype=lhs.dtype)
-            else:
-                if lhs.device != rhs.device:
-                    raise ValueError(
-                        f"Tensor device mismatch! "
-                        f"Expected: {lhs.device}, Got: {rhs.device}. "
-                        f"Function: {self.function}"
-                    )
-                if lhs.dtype != rhs.dtype:
-                    raise ValueError(
-                        f"Tensor dtype mismatch! "
-                        f"Expected: {lhs.dtype}, Got: {rhs.dtype}. "
-                        f"Function: {self.function}"
-                    )
-            lhs.copy_(rhs)
-        elif isinstance(lhs, dict):
-            for k in lhs:
-                if k in rhs:
-                    self._update_inplace(lhs[k], rhs[k])
-        elif isinstance(lhs, (list, tuple)):
-            for i in range(len(lhs)):
-                if i < len(rhs):
-                    self._update_inplace(lhs[i], rhs[i])
-        elif isinstance(lhs, (int, float, bool, str, type(None))):
-            if lhs != rhs:
-                raise ValueError("Does not support changing control parameters.")
-
-    def _update_args(self, input_args, input_kwargs):
-        bound_args = self.signature.bind(*input_args, **input_kwargs)
-        bound_args.apply_defaults()
-        args_dict = bound_args.arguments
-        if self.static_input is None:
-            self.static_input = args_dict
-        else:
-            self._update_inplace(self.static_input, args_dict)
-
-    def run(self, *args, **kwargs):
-        self._update_args(args, kwargs)
-        if self.graph is None:
-            # warmup
-            _ = torch.matmul(
-                torch.randn(100, 100, device=self.device),
-                torch.randn(100, 100, device=self.device)
-            )
-            torch.cuda.synchronize()
-            # capture graph
-            self.graph = torch.cuda.CUDAGraph()
-            with torch.cuda.graph(self.graph):
-                self.static_output = self.function(**self.static_input)
-        self.graph.replay()
-        return self.static_output
-
-
-def cuda_graph(device="cuda", cast=False):
-    def decorator(func):
-        wrapper = CudaGraphWrapper(func, device, cast)
-        @wraps(func)
-        def wrapped(*args, **kwargs):
-            return wrapper.run(*args, **kwargs)
-        return wrapped
-    return decorator

View File

@@ -1,10 +1,13 @@
 import torch
+from dataclasses import dataclass
 from torch import Tensor
 from typing import List, Tuple, Union, Optional, Generator
 from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
 from khaosz.config.param_config import ModelParameter

+HistoryType = List[Tuple[str, str]]
+
 def build_prompt(
     query: str,
     init_prompt: Optional[str] = None,
@@ -34,7 +37,7 @@ def build_prompt(
     return prompt

-def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]:
+def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
     """
     Pad a list of sequences to a fixed length.
@@ -47,34 +50,47 @@ def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]:
         List[List[int]]: A list of padded sequences.
     """
+    max_ids_len = max(len(ids) for ids in ids_list)
     new_ids_list = []
     for ids in ids_list:
         pad_len = max_ids_len - len(ids)
         padded_seq = [pad_id] * pad_len + ids
         new_ids_list.append(padded_seq)
-    return new_ids_list
+    return new_ids_list, max_ids_len
+
+
+@dataclass
+class GenerationRequest:
+    top_k: int
+    top_p: float
+    temperature: float
+    max_len: int
+    query: Union[str, List[str]]
+    history: Optional[Union[HistoryType, List[HistoryType]]] = None
+    system_prompt: Optional[str] = None
+    build_prompt: bool = True
+
+    def __post_init__(self):
+        if not isinstance(self.top_k, int) or self.top_k < 0:
+            raise ValueError("top_k must be a non-negative integer")
+        if not isinstance(self.top_p, float) or self.top_p < 0.0 or self.top_p > 1.0:
+            raise ValueError("top_p must be a float between 0.0 and 1.0")
+        if not isinstance(self.temperature, float) or self.temperature < 0.0:
+            raise ValueError("temperature must be a non-negative float")
+
+
-class TextGenerator(GeneratorCore):
+class LoopGenerator(GeneratorCore):
     def __init__(self, parameter: ModelParameter):
         super().__init__(parameter)

-    def generate(
-        self,
-        query: str,
-        temperature: float,
-        top_k: int,
-        top_p: float,
-    ) -> str:
-        assert temperature >= 0.0
-        assert top_k >= 0
-        assert top_p >= 0.0 and top_p <= 1.0
-
+    def generate(self, request: GenerationRequest) -> str:
         device = next(self.model.parameters()).device
         cache_manager = KVCacheManager(self.config, 1, device=device)
-        ids = self.tokenizer.encode(query)
+        input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query
+        ids = self.tokenizer.encode(input_args)
         input_ids = torch.tensor([ids], device=device, dtype=torch.long)

         start_cache_pos = len(ids)
@@ -83,84 +99,30 @@ class TextGenerator(GeneratorCore):
         kv_caches = cache_manager.get_kvcache()
         ids = self.generate_loop(
-            input_ids, ids, temperature, top_k, top_p,
+            input_ids, ids, request.temperature, request.top_k, request.top_p,
             kv_caches=kv_caches,
             start_pos=cur_cache_pos
         )
         response = self.tokenizer.decode(ids[start_cache_pos:])

         return response

-class ChatGenerator(GeneratorCore):
-    def __init__(self, parameter: ModelParameter):
-        super().__init__(parameter)
-
-    def generate(
-        self,
-        query: str,
-        history: List[Tuple[str, str]],
-        temperature: float,
-        top_k: int,
-        top_p: float,
-    ) -> str:
-        assert temperature >= 0.0
-        assert top_k >= 0
-        assert top_p >= 0.0 and top_p <= 1.0
-
-        if history is None:
-            history = []
-
-        device = next(self.model.parameters()).device
-        cache_manager = KVCacheManager(self.config, 1, device=device)
-        ids = self.tokenizer.encode(build_prompt(query, history))
-        input_ids = torch.tensor([ids], device=device, dtype=torch.long)
-
-        start_cache_pos = len(ids)
-        cur_cache_pos = 0
-
-        self.model.eval()
-        kv_caches = cache_manager.get_kvcache()
-        ids = self.generate_loop(
-            input_ids, ids, temperature, top_k, top_p,
-            kv_caches=kv_caches,
-            start_pos=cur_cache_pos
-        )
-        response = self.tokenizer.decode(ids[start_cache_pos:])
-
-        return response
-
 class StreamGenerator(GeneratorCore):
     def __init__(self, parameter: ModelParameter):
         super().__init__(parameter)

-    def generate(
-        self,
-        query: str,
-        history: List[Tuple[str, str]],
-        temperature: float,
-        top_k: int,
-        top_p: float,
-    ) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
-        assert temperature >= 0.0
-        assert top_k >= 0
-        assert top_p >= 0.0 and top_p <= 1.0
-
-        if history is None:
-            history = []
+    def generate(self, request: GenerationRequest) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
+        if request.history is None:
+            request.history = []

         device = next(self.model.parameters()).device
         cache_manager = KVCacheManager(self.config, 1, device=device)
-        ids = self.tokenizer.encode(build_prompt(query, history))
+        input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query
+        ids = self.tokenizer.encode(input_args)
         input_ids = torch.tensor([ids], device=device, dtype=torch.long)
-        cpy_history = history.copy()

         start_cache_pos = len(ids)
         cur_cache_pos = 0
@@ -169,45 +131,36 @@ class StreamGenerator(GeneratorCore):
         for _ in range(len(ids), self.config.max_len):
             next_token_id, cache_increase = self.generate_iterator(
-                input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
+                input_ids, request.temperature, request.top_k, request.top_p,
+                kv_caches=kv_caches,
+                start_pos=cur_cache_pos
+            )
             input_ids = next_token_id
             ids.append(next_token_id.item())
             cur_cache_pos += cache_increase

             response = self.tokenizer.decode(ids[start_cache_pos:])
-            yield response, cpy_history + [(query, response)]
+            yield response, request.history + [(request.query, response)]

             if next_token_id.item() in self.tokenizer.stop_ids:
-                yield response + "\n", cpy_history + [(query, response)]
+                yield response + "\n", request.history + [(request.query, response)]
                 break

 class BatchGenerator(GeneratorCore):
     def __init__(self, parameter: ModelParameter):
         super().__init__(parameter)

-    def generate(
-        self,
-        queries: List[str],
-        histories: List[List[Tuple[str, str]]],
-        temperature: float,
-        top_k: int,
-        top_p: float
-    ) -> List[str]:
-        assert temperature >= 0.0
-        assert top_k >= 0
-        assert top_p >= 0.0 and top_p <= 1.0
-
-        batch_size = len(queries)
-        if histories is None:
-            histories = [[] for _ in range(batch_size)]
-
-        prompts = [build_prompt(query, history) for query, history in zip(queries, histories)]
+    def generate(self, request: GenerationRequest) -> List[str]:
+        batch_size = len(request.query)
+        if request.history is None:
+            request.history = [[] for _ in range(batch_size)]
+
+        prompts = [build_prompt(query, history) for query, history in zip(request.query, request.history)]
         ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
-        max_ids_len = max(len(ids) for ids in ids_list)
-        ids_list = pad_sequence(ids_list, max_ids_len, self.tokenizer.pad_id)
+        ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)

         device = next(self.model.parameters()).device
         cache_manager = KVCacheManager(self.config, batch_size, device=device)
@@ -224,7 +177,11 @@ class BatchGenerator(GeneratorCore):
             attn_mask =cache_manager.get_seq_mask()
             next_token_id, cache_increase = self.generate_iterator(
-                input_tensor, temperature, top_k, top_p, attn_mask=attn_mask, kv_caches=kv_caches, start_pos=cur_cache_pos)
+                input_tensor, request.temperature, request.top_k, request.top_p,
+                attn_mask=attn_mask,
+                kv_caches=kv_caches,
+                start_pos=cur_cache_pos
+            )
             cur_cache_pos += cache_increase

             active_mask = []
@@ -246,47 +203,14 @@ class BatchGenerator(GeneratorCore):
             max_ids_len += 1

         responses = [str()] * batch_size
         for i in range(batch_size):
             responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:])
-            histories[i].append((queries[i], responses[i]))
+            request.history[i].append((request.query[i], responses[i]))

         return responses

-class RetrievalGenerator(GeneratorCore):
-    def __init__(self, retriever_parameter: ModelParameter):
-        super().__init__(retriever_parameter)
-
-    def generate(
-        self,
-        retrieved: List[str],
-        query: str,
-        history: List[Tuple[str, str]],
-        temperature: float,
-        top_k: int,
-        top_p: float,
-    ) -> str:
-        assert temperature >= 0.0
-        assert top_k >= 0
-        assert top_p >= 0.0 and top_p <= 1.0
-
-        if history is None:
-            history = []
-
-        retrieved = "\n".join([f"{idx + 1}. {key}" for idx, key in enumerate(retrieved)]) if retrieved else ""
-        retrieved_query = f"{retrieved}\n\n{query}" if retrieved else query
-
-        parameter = ModelParameter(self.model, self.tokenizer, self.config)
-        return ChatGenerator(parameter).generate(
-            retrieved_query,
-            history,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
-
 class EmbeddingEncoder(EmbeddingEncoderCore):
     def __init__(self, parameter: ModelParameter):
         super().__init__(parameter)
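
For completeness, here is a sketch of driving the new generator layer directly, without the Khaosz facade. The GenerationRequest fields and the LoopGenerator.generate(request) call are as defined in the diff above; loading a ModelParameter from a directory is assumed to work the same way Khaosz.__init__ does, and the directory path is a placeholder:

from khaosz.config.param_config import ModelParameter
from khaosz.inference.generator import GenerationRequest, LoopGenerator

parameter = ModelParameter()
parameter.load("model_files/")   # placeholder path, not from the commit

request = GenerationRequest(
    top_k=50,
    top_p=0.95,
    temperature=0.8,
    max_len=parameter.config.max_len,
    query="请问什么是人工智能",
    history=None,
    build_prompt=True,           # wrap the query with build_prompt() before encoding
)
print(LoopGenerator(parameter).generate(request))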

View File

@@ -13,35 +13,35 @@ PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
 def batch_generate(
     model: Khaosz,
-    queries: List[str],
+    query: List[str],
     temperature: float,
     top_k: int,
     top_p: float,
     batch_size: int,
 ) -> List:
     assert batch_size > 0
-    sorted_queries = sorted(queries, key=lambda x: len(x), reverse=True)
-    original_indices = {query: idx for idx, query in enumerate(queries)}
-    responses = [None] * len(queries)
-    total_batches = (len(sorted_queries) + batch_size - 1) // batch_size
+    sorted_query = sorted(query, key=lambda x: len(x), reverse=True)
+    original_indices = {query: idx for idx, query in enumerate(query)}
+    responses = [None] * len(query)
+    total_batches = (len(sorted_query) + batch_size - 1) // batch_size

     for i in tqdm(range(0, total_batches * batch_size, batch_size), desc="Generating responses"):
-        batch_queries = sorted_queries[i: min(i + batch_size, len(queries))]
-        if not isinstance(batch_queries, list):
-            batch_queries = [batch_queries]
+        batch_query = sorted_query[i: min(i + batch_size, len(query))]
+        if not isinstance(batch_query, list):
+            batch_query = [batch_query]

         batch_responses = model.batch_generate(
-            queries=batch_queries,
+            query=batch_query,
             temperature=temperature,
             top_k=top_k,
             top_p=top_p
         )

-        for batch_query, batch_response in zip(batch_queries, batch_responses):
+        for batch_query, batch_response in zip(batch_query, batch_responses):
             print(f"Q: {batch_query[:50]} \nR: {batch_response[:50]})")

-        for query, response in zip(batch_queries, batch_responses):
+        for query, response in zip(batch_query, batch_responses):
             original_idx = original_indices[query]
             responses[original_idx] = response
@@ -60,11 +60,11 @@ def processor(
 ):
     with open(input_json_file, "r", encoding='utf-8') as f:
         input_dict = [json.loads(line) for line in f]
-    queries = [item[question_key] for item in input_dict]
+    query = [item[question_key] for item in input_dict]

     output_dict = batch_generate(
         model=model,
-        queries=queries,
+        query=query,
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,