[Feature] Support prefill batch inference for pooling models. (#5436)

* fix multi-inputs

* fix threshold

* fix threshold

* fix

* support multi-batch

* add tests

* fix test

* test

* fix
This commit is contained in:
lizexu123
2025-12-09 16:21:00 +08:00
committed by GitHub
parent 31410415db
commit b0cf2c4b7a
2 changed files with 24 additions and 17 deletions

View File

@@ -2534,7 +2534,6 @@ class GPUModelRunner(ModelRunnerBase):
return None
def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Optional[ModelRunnerOutput]:
num_scheduled_tokens = int(self.share_inputs["seq_lens_this_time"][:num_running_requests].sum())
hidden_states = hidden_states[:num_scheduled_tokens]
@@ -2546,36 +2545,40 @@ class GPUModelRunner(ModelRunnerBase):
prompt_token_ids=prompt_token_ids,
pooling_params=self.pooling_params,
)
num_scheduled_tokens_list = [
int(self.share_inputs["seq_lens_this_time"][i]) for i in range(num_running_requests)
]
device_str = "gpu" if hidden_states.place.is_gpu_place() else "cpu"
pooling_metadata.build_pooling_cursor(num_scheduled_tokens_list, device=device_str)
raw_pooler_output = self.model.pooler(hidden_states=hidden_states, pooling_metadata=pooling_metadata)
seq_lens_cpu = self.share_inputs["seq_lens_this_time"][:num_running_requests]
seq_lens_decoder = self.share_inputs["seq_lens_decoder"][:num_running_requests]
seq_lens_encoder = self.share_inputs["seq_lens_encoder"][:num_running_requests]
pooler_output: list[Optional[paddle.Tensor]] = []
pooler_output_idx = 0
seq_lens_decoder_batch = self.share_inputs["seq_lens_decoder"][:num_running_requests]
for i, prompt_len in enumerate(pooling_metadata.prompt_lens):
current_seq_len = num_scheduled_tokens_list[i]
for i, (seq_len, prompt_len) in enumerate(zip(seq_lens_cpu, pooling_metadata.prompt_lens)):
if not self.cache_config.enable_prefix_caching:
output = raw_pooler_output[0].data if int(seq_len) == int(prompt_len) else None
pooler_output.append(output)
if current_seq_len == 0:
pooler_output.append(None)
continue
total_processed = int(seq_lens_decoder[i]) + int(seq_lens_encoder[i])
if total_processed == int(prompt_len):
output = raw_pooler_output[pooler_output_idx]
else:
current_seq_len_decoder = seq_lens_decoder_batch[i]
if int(current_seq_len_decoder) + int(seq_len) == int(prompt_len):
output = raw_pooler_output[0].data
else:
output = None
pooler_output.append(output)
output = None
pooler_output = PoolerOutput(
outputs=pooler_output,
)
pooler_output.append(output)
pooler_output_idx += 1
return pooler_output
return PoolerOutput(outputs=pooler_output)
def _execute_empty_input(self, forward_meta) -> None:
"""

View File

@@ -32,6 +32,8 @@ from e2e.utils.serving_utils import (
is_port_open,
)
from fastdeploy import envs
@pytest.fixture(scope="session", autouse=True)
def setup_and_run_embedding_server():
@@ -50,6 +52,8 @@ def setup_and_run_embedding_server():
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model path not found: {model_path}")
envs.FD_ENABLE_MAX_PREFILL = 1
log_path = "embedding_server.log"
cmd = [
sys.executable,