[Bug fix] fix pooling models (#5358)

* fix

* fix

* fix test

* fix gpu_model_runner

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
lizexu123
2025-12-04 11:06:30 +08:00
committed by GitHub
parent a52aea073c
commit 946025480e
5 changed files with 26 additions and 47 deletions

View File

@@ -2457,12 +2457,12 @@ class GPUModelRunner(ModelRunnerBase):
for i, (seq_len, prompt_len) in enumerate(zip(seq_lens_cpu, pooling_metadata.prompt_lens)):
if not self.cache_config.enable_prefix_caching:
-                    output = raw_pooler_output[i].data if int(seq_len) == int(prompt_len) else None
+                    output = raw_pooler_output[0].data if int(seq_len) == int(prompt_len) else None
pooler_output.append(output)
else:
current_seq_len_decoder = seq_lens_decoder_batch[i]
if int(current_seq_len_decoder) + int(seq_len) == int(prompt_len):
-                        output = raw_pooler_output[i].data
+                        output = raw_pooler_output[0].data
else:
output = None
pooler_output.append(output)