refactor(generator): 优化生成逻辑
This commit is contained in:
parent
bdc3f4dc63
commit
66a551217e
|
|
@@ -118,7 +118,6 @@ class ChatGenerator(GeneratorCore):
|
||||||
|
|
||||||
ids = self.tokenizer.encode(build_prompt(query, history))
|
ids = self.tokenizer.encode(build_prompt(query, history))
|
||||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||||
cpy_history = history.copy()
|
|
||||||
|
|
||||||
start_cache_pos = len(ids)
|
start_cache_pos = len(ids)
|
||||||
cur_cache_pos = 0
|
cur_cache_pos = 0
|
||||||
|
|
@@ -132,9 +131,8 @@ class ChatGenerator(GeneratorCore):
|
||||||
)
|
)
|
||||||
|
|
||||||
response = self.tokenizer.decode(ids[start_cache_pos:])
|
response = self.tokenizer.decode(ids[start_cache_pos:])
|
||||||
cpy_history.append((query, response))
|
|
||||||
|
|
||||||
return response, cpy_history
|
return response
|
||||||
|
|
||||||
|
|
||||||
class StreamGenerator(GeneratorCore):
|
class StreamGenerator(GeneratorCore):
|
||||||
|
|
@@ -278,7 +276,7 @@ class RetrievalGenerator(GeneratorCore):
|
||||||
history = []
|
history = []
|
||||||
|
|
||||||
retrieved = "\n".join([f"{idx + 1}. {key}" for idx, key in enumerate(retrieved)]) if retrieved else ""
|
retrieved = "\n".join([f"{idx + 1}. {key}" for idx, key in enumerate(retrieved)]) if retrieved else ""
|
||||||
retrieved_query = f"{retrieved}<eos>\n\n根据以上内容回答: {query}" if retrieved else query
|
retrieved_query = f"{retrieved}\n\n{query}" if retrieved else query
|
||||||
parameter = ModelParameter(self.model, self.tokenizer, self.config)
|
parameter = ModelParameter(self.model, self.tokenizer, self.config)
|
||||||
|
|
||||||
return ChatGenerator(parameter).generate(
|
return ChatGenerator(parameter).generate(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue