mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] Support prefill batch inference for pooling models. (#5436)
* fix multi-inputs * fix threshold * fix threshold * fix * support multi-batch * add tests * fix test * test * fix
This commit is contained in:
@@ -2534,7 +2534,6 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
return None
|
||||
|
||||
def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Optional[ModelRunnerOutput]:
|
||||
|
||||
num_scheduled_tokens = int(self.share_inputs["seq_lens_this_time"][:num_running_requests].sum())
|
||||
hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
|
||||
@@ -2546,36 +2545,40 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
pooling_params=self.pooling_params,
|
||||
)
|
||||
|
||||
num_scheduled_tokens_list = [
|
||||
int(self.share_inputs["seq_lens_this_time"][i]) for i in range(num_running_requests)
|
||||
]
|
||||
|
||||
device_str = "gpu" if hidden_states.place.is_gpu_place() else "cpu"
|
||||
pooling_metadata.build_pooling_cursor(num_scheduled_tokens_list, device=device_str)
|
||||
|
||||
raw_pooler_output = self.model.pooler(hidden_states=hidden_states, pooling_metadata=pooling_metadata)
|
||||
|
||||
seq_lens_cpu = self.share_inputs["seq_lens_this_time"][:num_running_requests]
|
||||
seq_lens_decoder = self.share_inputs["seq_lens_decoder"][:num_running_requests]
|
||||
seq_lens_encoder = self.share_inputs["seq_lens_encoder"][:num_running_requests]
|
||||
|
||||
pooler_output: list[Optional[paddle.Tensor]] = []
|
||||
pooler_output_idx = 0
|
||||
|
||||
seq_lens_decoder_batch = self.share_inputs["seq_lens_decoder"][:num_running_requests]
|
||||
for i, prompt_len in enumerate(pooling_metadata.prompt_lens):
|
||||
current_seq_len = num_scheduled_tokens_list[i]
|
||||
|
||||
for i, (seq_len, prompt_len) in enumerate(zip(seq_lens_cpu, pooling_metadata.prompt_lens)):
|
||||
if not self.cache_config.enable_prefix_caching:
|
||||
output = raw_pooler_output[0].data if int(seq_len) == int(prompt_len) else None
|
||||
pooler_output.append(output)
|
||||
if current_seq_len == 0:
|
||||
pooler_output.append(None)
|
||||
continue
|
||||
|
||||
total_processed = int(seq_lens_decoder[i]) + int(seq_lens_encoder[i])
|
||||
|
||||
if total_processed == int(prompt_len):
|
||||
output = raw_pooler_output[pooler_output_idx]
|
||||
else:
|
||||
current_seq_len_decoder = seq_lens_decoder_batch[i]
|
||||
if int(current_seq_len_decoder) + int(seq_len) == int(prompt_len):
|
||||
output = raw_pooler_output[0].data
|
||||
else:
|
||||
output = None
|
||||
pooler_output.append(output)
|
||||
output = None
|
||||
|
||||
pooler_output = PoolerOutput(
|
||||
outputs=pooler_output,
|
||||
)
|
||||
pooler_output.append(output)
|
||||
pooler_output_idx += 1
|
||||
|
||||
return pooler_output
|
||||
return PoolerOutput(outputs=pooler_output)
|
||||
|
||||
def _execute_empty_input(self, forward_meta) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user