From 3535de5cc46d8181c4687bf2bba842b22cd80cba Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 4 Apr 2026 10:25:39 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E5=90=8C=E6=AD=A5device=20=E5=92=8C=20d?= =?UTF-8?q?type?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/inference/generator.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/astrai/inference/generator.py b/astrai/inference/generator.py index 1c8b043..a3600e3 100644 --- a/astrai/inference/generator.py +++ b/astrai/inference/generator.py @@ -106,8 +106,10 @@ class LoopGenerator(GeneratorCore): super().__init__(parameter) def generate(self, request: GenerationRequest) -> str: - device = next(self.model.parameters()).device - cache_manager = KVCacheManager(self.config, 1, device=device) + model_params = next(self.model.parameters()) + device = model_params.device + dtype = model_params.dtype + cache_manager = KVCacheManager(self.config, 1, device=device, dtype=dtype) prompt = build_prompt(request.query, request.history) ids = self.tokenizer.encode(prompt) @@ -135,8 +137,10 @@ class StreamGenerator(GeneratorCore): super().__init__(parameter) def generate(self, request: GenerationRequest) -> Generator[str, None, None]: - device = next(self.model.parameters()).device - cache_manager = KVCacheManager(self.config, 1, device=device) + model_params = next(self.model.parameters()) + device = model_params.device + dtype = model_params.dtype + cache_manager = KVCacheManager(self.config, 1, device=device, dtype=dtype) prompt = build_prompt(request.query, request.history) ids = self.tokenizer.encode(prompt) @@ -186,8 +190,12 @@ class BatchGenerator(GeneratorCore): ids_list = [self.tokenizer.encode(prompt) for prompt in prompts] ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id) - device = next(self.model.parameters()).device - cache_manager = KVCacheManager(self.config, batch_size, device=device) + model_params = next(self.model.parameters()) + device = model_params.device + dtype = model_params.dtype + cache_manager = KVCacheManager( + self.config, batch_size, device=device, dtype=dtype + ) input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long) cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)