refactor(generator): 优化生成逻辑
This commit is contained in:
parent
bdc3f4dc63
commit
66a551217e
|
|
@@ -118,7 +118,6 @@ class ChatGenerator(GeneratorCore):
|
|||
|
||||
ids = self.tokenizer.encode(build_prompt(query, history))
|
||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||
cpy_history = history.copy()
|
||||
|
||||
start_cache_pos = len(ids)
|
||||
cur_cache_pos = 0
|
||||
|
|
@@ -132,9 +131,8 @@ class ChatGenerator(GeneratorCore):
|
|||
)
|
||||
|
||||
response = self.tokenizer.decode(ids[start_cache_pos:])
|
||||
cpy_history.append((query, response))
|
||||
|
||||
return response, cpy_history
|
||||
return response
|
||||
|
||||
|
||||
class StreamGenerator(GeneratorCore):
|
||||
|
|
@@ -278,7 +276,7 @@ class RetrievalGenerator(GeneratorCore):
|
|||
history = []
|
||||
|
||||
retrieved = "\n".join([f"{idx + 1}. {key}" for idx, key in enumerate(retrieved)]) if retrieved else ""
|
||||
retrieved_query = f"{retrieved}<eos>\n\n根据以上内容回答: {query}" if retrieved else query
|
||||
retrieved_query = f"{retrieved}\n\n{query}" if retrieved else query
|
||||
parameter = ModelParameter(self.model, self.tokenizer, self.config)
|
||||
|
||||
return ChatGenerator(parameter).generate(
|
||||
|
|
|
|||
Loading…
Reference in New Issue