Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24)
[Feature] Support prefill batch inference for pooling models. (#5436)
* fix multi-inputs
* fix threshold
* fix threshold
* fix
* support multi-batch
* add tests
* fix test
* test
* fix
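For context on what the change has to handle: with batched prefill, several requests are packed into one forward pass, so the pooler sees a single concatenated hidden-state tensor plus one scheduled-token count per request (seq_lens_this_time), and some requests may contribute zero tokens in a given step. The sketch below is a minimal, framework-free illustration of that layout, using NumPy and a mean pool as stand-ins for the real pooler; it is not the FastDeploy code.

```python
import numpy as np

# Toy batched prefill: three requests packed into one forward pass.
# Each request contributes seq_lens_this_time[i] tokens in this step.
hidden_size = 4
seq_lens_this_time = [3, 0, 2]   # request 1 has nothing scheduled this step
num_scheduled_tokens = sum(seq_lens_this_time)

# One concatenated hidden-state matrix for the whole batch, analogous to
# hidden_states[:num_scheduled_tokens] in the diff below.
hidden_states = np.random.rand(num_scheduled_tokens, hidden_size)

# Split the concatenated matrix back into per-request chunks and mean-pool
# each chunk (a stand-in for what build_pooling_cursor + model.pooler do).
pooled = []
offset = 0
for n in seq_lens_this_time:
    if n == 0:
        continue                 # nothing scheduled for this request
    pooled.append(hidden_states[offset:offset + n].mean(axis=0))
    offset += n

print(len(pooled))               # 2 pooled vectors, one per non-empty request
```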
Model runner changes (GPUModelRunner._pool):

@@ -2534,7 +2534,6 @@ class GPUModelRunner(ModelRunnerBase):
             return None
 
-
     def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Optional[ModelRunnerOutput]:
         num_scheduled_tokens = int(self.share_inputs["seq_lens_this_time"][:num_running_requests].sum())
         hidden_states = hidden_states[:num_scheduled_tokens]
@@ -2546,36 +2545,40 @@ class GPUModelRunner(ModelRunnerBase):
             prompt_token_ids=prompt_token_ids,
             pooling_params=self.pooling_params,
         )
 
         num_scheduled_tokens_list = [
             int(self.share_inputs["seq_lens_this_time"][i]) for i in range(num_running_requests)
         ]
 
         device_str = "gpu" if hidden_states.place.is_gpu_place() else "cpu"
         pooling_metadata.build_pooling_cursor(num_scheduled_tokens_list, device=device_str)
 
         raw_pooler_output = self.model.pooler(hidden_states=hidden_states, pooling_metadata=pooling_metadata)
 
-        seq_lens_cpu = self.share_inputs["seq_lens_this_time"][:num_running_requests]
+        seq_lens_decoder = self.share_inputs["seq_lens_decoder"][:num_running_requests]
+        seq_lens_encoder = self.share_inputs["seq_lens_encoder"][:num_running_requests]
 
         pooler_output: list[Optional[paddle.Tensor]] = []
+        pooler_output_idx = 0
 
-        seq_lens_decoder_batch = self.share_inputs["seq_lens_decoder"][:num_running_requests]
-
-        for i, (seq_len, prompt_len) in enumerate(zip(seq_lens_cpu, pooling_metadata.prompt_lens)):
-            if not self.cache_config.enable_prefix_caching:
-                output = raw_pooler_output[0].data if int(seq_len) == int(prompt_len) else None
-                pooler_output.append(output)
+        for i, prompt_len in enumerate(pooling_metadata.prompt_lens):
+            current_seq_len = num_scheduled_tokens_list[i]
+
+            if current_seq_len == 0:
+                pooler_output.append(None)
+                continue
+
+            total_processed = int(seq_lens_decoder[i]) + int(seq_lens_encoder[i])
+
+            if total_processed == int(prompt_len):
+                output = raw_pooler_output[pooler_output_idx]
             else:
-                current_seq_len_decoder = seq_lens_decoder_batch[i]
-                if int(current_seq_len_decoder) + int(seq_len) == int(prompt_len):
-                    output = raw_pooler_output[0].data
-                else:
-                    output = None
-                pooler_output.append(output)
+                output = None
 
-        pooler_output = PoolerOutput(
-            outputs=pooler_output,
-        )
+            pooler_output.append(output)
+            pooler_output_idx += 1
 
-        return pooler_output
+        return PoolerOutput(outputs=pooler_output)
 
     def _execute_empty_input(self, forward_meta) -> None:
         """
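A way to read the rewritten loop: raw_pooler_output now holds one pooled result per request that actually had tokens scheduled this step, and pooler_output_idx walks through it while the loop walks through all num_running_requests slots. A request with nothing scheduled gets None immediately; a request whose processed token count (seq_lens_decoder + seq_lens_encoder) has not yet reached its prompt length is treated as a still-incomplete chunked prefill and also gets None. The snippet below replays that bookkeeping on plain Python lists; the interpretation of the two counters is simplified, and the one-raw-output-per-scheduled-request assumption is mine, not something the diff states.

```python
from typing import Optional

# Toy replay of the new output-mapping loop, assuming raw_pooler_output has
# one entry per request with scheduled tokens this step (an assumption made
# for this sketch, not asserted by the diff itself).
prompt_lens = [5, 7, 4]                       # full prompt length per request
num_scheduled_tokens_list = [5, 3, 0]         # tokens scheduled this step
seq_lens_decoder = [0, 0, 0]                  # tokens processed in earlier steps (simplified)
seq_lens_encoder = [5, 3, 0]                  # tokens processed in this prefill step (simplified)
raw_pooler_output = ["emb_req0", "emb_req1"]  # stand-ins for pooled tensors

pooler_output: list[Optional[str]] = []
pooler_output_idx = 0
for i, prompt_len in enumerate(prompt_lens):
    if num_scheduled_tokens_list[i] == 0:
        pooler_output.append(None)            # nothing ran for this request
        continue
    total_processed = seq_lens_decoder[i] + seq_lens_encoder[i]
    # Emit an embedding only once the whole prompt has been processed;
    # a partially prefilled (chunked) request gets None for now.
    output = raw_pooler_output[pooler_output_idx] if total_processed == prompt_len else None
    pooler_output.append(output)
    pooler_output_idx += 1

print(pooler_output)  # ['emb_req0', None, None]
```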
Embedding e2e test fixture changes (setup_and_run_embedding_server):

@@ -32,6 +32,8 @@ from e2e.utils.serving_utils import (
     is_port_open,
 )
 
+from fastdeploy import envs
+
 
 @pytest.fixture(scope="session", autouse=True)
 def setup_and_run_embedding_server():
@@ -50,6 +52,8 @@ def setup_and_run_embedding_server():
     if not os.path.exists(model_path):
         raise FileNotFoundError(f"Model path not found: {model_path}")
 
+    envs.FD_ENABLE_MAX_PREFILL = 1
+
     log_path = "embedding_server.log"
     cmd = [
         sys.executable,
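Once the fixture enables FD_ENABLE_MAX_PREFILL, an end-to-end check of the feature is essentially a multi-input embedding request against the launched server. A rough sketch of such a request is below; the port, the OpenAI-style /v1/embeddings route, and the response shape are assumptions on my part and should be aligned with whatever the e2e harness actually configures.

```python
import requests

# Hypothetical smoke test against the embedding server the fixture launches.
# BASE_URL, the "default" model name, and the /v1/embeddings route are
# assumptions for this sketch, not values taken from the diff.
BASE_URL = "http://127.0.0.1:8188"

payload = {
    "model": "default",
    # Several prompts in one request, so the server has to prefill a batch.
    "input": ["first document", "second, slightly longer document", "third"],
}

resp = requests.post(f"{BASE_URL}/v1/embeddings", json=payload, timeout=60)
resp.raise_for_status()
data = resp.json()["data"]

# Expect one embedding per input, in the original order.
assert len(data) == len(payload["input"])
print([len(item["embedding"]) for item in data])
```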