fix: 同步device 和 dtype
This commit is contained in:
parent
26989e54aa
commit
3535de5cc4
|
|
@ -106,8 +106,10 @@ class LoopGenerator(GeneratorCore):
|
||||||
super().__init__(parameter)
|
super().__init__(parameter)
|
||||||
|
|
||||||
def generate(self, request: GenerationRequest) -> str:
|
def generate(self, request: GenerationRequest) -> str:
|
||||||
device = next(self.model.parameters()).device
|
model_params = next(self.model.parameters())
|
||||||
cache_manager = KVCacheManager(self.config, 1, device=device)
|
device = model_params.device
|
||||||
|
dtype = model_params.dtype
|
||||||
|
cache_manager = KVCacheManager(self.config, 1, device=device, dtype=dtype)
|
||||||
|
|
||||||
prompt = build_prompt(request.query, request.history)
|
prompt = build_prompt(request.query, request.history)
|
||||||
ids = self.tokenizer.encode(prompt)
|
ids = self.tokenizer.encode(prompt)
|
||||||
|
|
@ -135,8 +137,10 @@ class StreamGenerator(GeneratorCore):
|
||||||
super().__init__(parameter)
|
super().__init__(parameter)
|
||||||
|
|
||||||
def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
|
def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
|
||||||
device = next(self.model.parameters()).device
|
model_params = next(self.model.parameters())
|
||||||
cache_manager = KVCacheManager(self.config, 1, device=device)
|
device = model_params.device
|
||||||
|
dtype = model_params.dtype
|
||||||
|
cache_manager = KVCacheManager(self.config, 1, device=device, dtype=dtype)
|
||||||
|
|
||||||
prompt = build_prompt(request.query, request.history)
|
prompt = build_prompt(request.query, request.history)
|
||||||
ids = self.tokenizer.encode(prompt)
|
ids = self.tokenizer.encode(prompt)
|
||||||
|
|
@ -186,8 +190,12 @@ class BatchGenerator(GeneratorCore):
|
||||||
ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
|
ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
|
||||||
ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)
|
ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)
|
||||||
|
|
||||||
device = next(self.model.parameters()).device
|
model_params = next(self.model.parameters())
|
||||||
cache_manager = KVCacheManager(self.config, batch_size, device=device)
|
device = model_params.device
|
||||||
|
dtype = model_params.dtype
|
||||||
|
cache_manager = KVCacheManager(
|
||||||
|
self.config, batch_size, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
|
input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
|
||||||
cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
|
cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue