fix: 同步device 和 dtype
This commit is contained in:
parent
26989e54aa
commit
3535de5cc4
|
|
@ -106,8 +106,10 @@ class LoopGenerator(GeneratorCore):
|
|||
super().__init__(parameter)
|
||||
|
||||
def generate(self, request: GenerationRequest) -> str:
|
||||
device = next(self.model.parameters()).device
|
||||
cache_manager = KVCacheManager(self.config, 1, device=device)
|
||||
model_params = next(self.model.parameters())
|
||||
device = model_params.device
|
||||
dtype = model_params.dtype
|
||||
cache_manager = KVCacheManager(self.config, 1, device=device, dtype=dtype)
|
||||
|
||||
prompt = build_prompt(request.query, request.history)
|
||||
ids = self.tokenizer.encode(prompt)
|
||||
|
|
@ -135,8 +137,10 @@ class StreamGenerator(GeneratorCore):
|
|||
super().__init__(parameter)
|
||||
|
||||
def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
|
||||
device = next(self.model.parameters()).device
|
||||
cache_manager = KVCacheManager(self.config, 1, device=device)
|
||||
model_params = next(self.model.parameters())
|
||||
device = model_params.device
|
||||
dtype = model_params.dtype
|
||||
cache_manager = KVCacheManager(self.config, 1, device=device, dtype=dtype)
|
||||
|
||||
prompt = build_prompt(request.query, request.history)
|
||||
ids = self.tokenizer.encode(prompt)
|
||||
|
|
@ -186,8 +190,12 @@ class BatchGenerator(GeneratorCore):
|
|||
ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
|
||||
ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
cache_manager = KVCacheManager(self.config, batch_size, device=device)
|
||||
model_params = next(self.model.parameters())
|
||||
device = model_params.device
|
||||
dtype = model_params.dtype
|
||||
cache_manager = KVCacheManager(
|
||||
self.config, batch_size, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
|
||||
cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
|
||||
|
|
|
|||
Loading…
Reference in New Issue