mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Bug fix] fix pooling models (#5358)
* fix * fix * fix test * fix gpu_model_runner --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -2457,12 +2457,12 @@ class GPUModelRunner(ModelRunnerBase):
     for i, (seq_len, prompt_len) in enumerate(zip(seq_lens_cpu, pooling_metadata.prompt_lens)):
         if not self.cache_config.enable_prefix_caching:
-            output = raw_pooler_output[i].data if int(seq_len) == int(prompt_len) else None
+            output = raw_pooler_output[0].data if int(seq_len) == int(prompt_len) else None
             pooler_output.append(output)
         else:
             current_seq_len_decoder = seq_lens_decoder_batch[i]
             if int(current_seq_len_decoder) + int(seq_len) == int(prompt_len):
-                output = raw_pooler_output[i].data
+                output = raw_pooler_output[0].data
             else:
                 output = None
             pooler_output.append(output)
||||
Reference in New Issue
Block a user