fix: 同步 device 和 dtype (sync device and dtype)

This commit is contained in:
ViperEkura 2026-04-04 10:25:39 +08:00
parent 26989e54aa
commit 3535de5cc4
1 changed files with 14 additions and 6 deletions

View File

@@ -106,8 +106,10 @@ class LoopGenerator(GeneratorCore):
         super().__init__(parameter)

     def generate(self, request: GenerationRequest) -> str:
-        device = next(self.model.parameters()).device
-        cache_manager = KVCacheManager(self.config, 1, device=device)
+        model_params = next(self.model.parameters())
+        device = model_params.device
+        dtype = model_params.dtype
+        cache_manager = KVCacheManager(self.config, 1, device=device, dtype=dtype)
         prompt = build_prompt(request.query, request.history)
         ids = self.tokenizer.encode(prompt)
@@ -135,8 +137,10 @@ class StreamGenerator(GeneratorCore):
         super().__init__(parameter)

     def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
-        device = next(self.model.parameters()).device
-        cache_manager = KVCacheManager(self.config, 1, device=device)
+        model_params = next(self.model.parameters())
+        device = model_params.device
+        dtype = model_params.dtype
+        cache_manager = KVCacheManager(self.config, 1, device=device, dtype=dtype)
         prompt = build_prompt(request.query, request.history)
         ids = self.tokenizer.encode(prompt)
@@ -186,8 +190,12 @@ class BatchGenerator(GeneratorCore):
         ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
         ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)
-        device = next(self.model.parameters()).device
-        cache_manager = KVCacheManager(self.config, batch_size, device=device)
+        model_params = next(self.model.parameters())
+        device = model_params.device
+        dtype = model_params.dtype
+        cache_manager = KVCacheManager(
+            self.config, batch_size, device=device, dtype=dtype
+        )
         input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
         cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)