refactor(generator): 优化生成逻辑

This commit is contained in:
ViperEkura 2025-11-07 07:24:00 +08:00
parent bdc3f4dc63
commit 66a551217e
1 changed file with 2 additions and 4 deletions

View File

@@ -118,7 +118,6 @@ class ChatGenerator(GeneratorCore):
ids = self.tokenizer.encode(build_prompt(query, history))
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
cpy_history = history.copy()
start_cache_pos = len(ids)
cur_cache_pos = 0
@@ -132,9 +131,8 @@ class ChatGenerator(GeneratorCore):
)
response = self.tokenizer.decode(ids[start_cache_pos:])
cpy_history.append((query, response))
return response, cpy_history
return response
class StreamGenerator(GeneratorCore):
@@ -278,7 +276,7 @@ class RetrievalGenerator(GeneratorCore):
history = []
retrieved = "\n".join([f"{idx + 1}. {key}" for idx, key in enumerate(retrieved)]) if retrieved else ""
retrieved_query = f"{retrieved}<eos>\n\n根据以上内容回答: {query}" if retrieved else query
retrieved_query = f"{retrieved}\n\n{query}" if retrieved else query
parameter = ModelParameter(self.model, self.tokenizer, self.config)
return ChatGenerator(parameter).generate(