[Feature] support pooling model dummy_run (#4345)

* support qwen3-embedding

* fix ci bug

* support pooling dummy_run

* fix

* delete print

* parallel_config.max_model_len

* delete is_pooling_model in dummy_run

* fix

* fd_model

* fix embedding load

* fix

* fix post_process
Author: lizexu123
Date: 2025-10-17 13:30:55 +08:00 (committed by GitHub)
Parent: 15b6b8dc25
Commit: c234b995ab
10 changed files with 291 additions and 126 deletions

View File

@@ -37,8 +37,6 @@ class PoolingParams(
normalize: Whether to normalize the embedding outputs.
dimensions: Reduce the dimensions of embeddings
if the model supports matryoshka representation.
activation: Whether to apply activation function to
the classification outputs.
softmax: Whether to apply softmax to the reward outputs.
step_tag_id: Step tag ID for process reward models to identify
specific steps in multi-step reasoning tasks.
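
For orientation, a minimal sketch of how these fields might be set when requesting embeddings; this assumes PoolingParams accepts them as keyword arguments (as the docstring above suggests) and the task name "embed" is only illustrative:

    from fastdeploy.engine.pooling_params import PoolingParams

    # Hypothetical request: normalized embeddings truncated to 256 dimensions
    # (matryoshka-style), for a model that supports it.
    params = PoolingParams(task="embed", normalize=True, dimensions=256)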

View File

@@ -163,10 +163,8 @@ class VocabParallelEmbedding(nn.Layer):
initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
),
)
if self.world_size > 1:
set_weight_attrs(self.embeddings.weight, {"output_dim": False})
if num_embeddings % self.world_size != 0:
set_weight_attrs(self.embeddings.weight, {"weight_loader", self.weight_loader})
set_weight_attrs(self.embeddings.weight, {"output_dim": False})
set_weight_attrs(self.embeddings.weight, {"weight_loader": self.weight_loader})
else:
# column cut embedding
self.embeddings = nn.Embedding(
@@ -176,8 +174,8 @@ class VocabParallelEmbedding(nn.Layer):
self.embeddings.weight.is_distributed = True
self.embeddings.weight.split_axis = 1
if self.world_size > 1:
set_weight_attrs(self.embeddings.weight, {"output_dim": True})
set_weight_attrs(self.embeddings.weight, {"output_dim": True})
self.prefix = prefix
self.dropout = nn.Dropout(self.hidden_dropout_prob)
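
For readers unfamiliar with the pattern above: set_weight_attrs tags a parameter with loader metadata (such as output_dim or a custom weight_loader) that the checkpoint-loading code reads back later. A minimal sketch of the idea, not the actual FastDeploy implementation:

    def set_weight_attrs_sketch(param, attrs: dict):
        # Attach loader hints to the parameter object itself, so generic loading
        # code can later do: getattr(param, "weight_loader", default_loader).
        for key, value in attrs.items():
            setattr(param, key, value)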

View File

@@ -69,8 +69,9 @@ def build_pooling_cursor(num_scheduled_tokens: list[int], prompt_lens: paddle.Te
n_seq = len(num_scheduled_tokens)
index = list(range(n_seq))
num_scheduled_tokens = paddle.to_tensor(num_scheduled_tokens, device="cpu")
cumsum = paddle.zeros([n_seq + 1], dtype="int64", place=paddle.CPUPlace())
num_scheduled_tokens = paddle.to_tensor(num_scheduled_tokens)
cumsum = paddle.zeros([n_seq + 1], dtype="int64")
paddle.cumsum(num_scheduled_tokens, axis=0, out=cumsum[1:])
if device == "gpu":
cumsum_device = cumsum.cuda()
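
The cursor built above is essentially an exclusive prefix sum over the per-sequence token counts; a small standalone sketch of the same arithmetic:

    import paddle

    num_scheduled_tokens = [3, 5, 2]  # tokens per sequence in this batch
    lens = paddle.to_tensor(num_scheduled_tokens, dtype="int64")
    cumsum = paddle.concat([paddle.zeros([1], dtype="int64"), paddle.cumsum(lens, axis=0)])
    # cumsum == [0, 3, 8, 10]; sequence i owns flattened token rows [cumsum[i], cumsum[i+1])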

View File

@@ -332,6 +332,29 @@ class PoolingMethod(nn.Layer, ABC):
return MeanPool()
raise NotImplementedError(f"Unsupported method: {pooling_type}")
@abstractmethod
def get_supported_tasks(self) -> Set[PoolingTask]:
raise NotImplementedError
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return PoolingParamsUpdate()
@abstractmethod
def forward_all(
self,
hidden_states: paddle.Tensor,
pooling_cursor: PoolingCursor,
) -> Union[list[paddle.Tensor], paddle.Tensor]:
raise NotImplementedError
def forward(
self,
hidden_states: paddle.Tensor,
pooling_metadata: PoolingMetadata,
) -> Union[list[paddle.Tensor], paddle.Tensor]:
pooling_cursor = pooling_metadata.pooling_cursor
return self.forward_all(hidden_states, pooling_cursor)
class LastPool(PoolingMethod):
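
As a rough illustration of what a LastPool-style forward_all does (the real PoolingCursor API may differ): given the flattened [num_tokens, hidden_size] hidden states and boundaries like the cumsum above, it gathers each sequence's final token:

    import paddle

    def last_token_pool(hidden_states: paddle.Tensor, cumsum: paddle.Tensor) -> paddle.Tensor:
        # hidden_states: [num_tokens, hidden_size]; cumsum: [n_seq + 1] sequence boundaries.
        last_indices = cumsum[1:] - 1  # flattened index of each sequence's last token
        return paddle.index_select(hidden_states, last_indices, axis=0)  # [n_seq, hidden_size]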

View File

@@ -180,7 +180,7 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
def as_embedding_model(cls: _T) -> _T:
"""
Subclass an existing vLLM model to support embeddings.
Subclass an existing FastDeploy model to support embeddings.
By default, the embeddings of the whole prompt are extracted from the
normalized hidden state corresponding to the last token.
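
Typical usage is to wrap an existing causal LM. A hedged example, assuming the wrapper sets the is_pooling_model class flag introduced later in this commit; the resulting class name is made up:

    # Hypothetical: derive an embedding variant of an existing FastDeploy causal LM;
    # the wrapped class is then recognized by is_pooling_model() via its class flag.
    Qwen3EmbeddingModel = as_embedding_model(Qwen3ForCausalLM)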

View File

@@ -12,9 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Type
from typing import ClassVar, Literal, Protocol, Type
import paddle
from paddle import nn
from typing_extensions import TypeVar, runtime_checkable
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.pooler import Pooler
T = TypeVar("T", default=paddle.Tensor)
T_co = TypeVar("T_co", default=paddle.Tensor, covariant=True)
def is_text_generation_model(model_cls: Type[nn.Layer]) -> bool:
@@ -24,13 +33,7 @@ def is_text_generation_model(model_cls: Type[nn.Layer]) -> bool:
def is_pooling_model(model_cls: Type[nn.Layer]) -> bool:
class_name = model_cls.__name__
pooling_indicators = ["Embedding", "ForSequenceClassification"]
return (
any(indicator in class_name for indicator in pooling_indicators)
or hasattr(model_cls, "is_embedding_model")
and model_cls.is_embedding_model
)
return getattr(model_cls, "is_pooling_model", False)
def is_multimodal_model(class_name: str) -> bool:
@@ -52,3 +55,48 @@ def get_default_pooling_type(model_cls: Type[nn.Layer] = None) -> str:
if model_cls is not None:
return getattr(model_cls, "default_pooling_type", "LAST")
return "LAST"
@runtime_checkable
class FdModel(Protocol[T_co]):
"""The interface required for all models in FastDeploy."""
def __init__(
self,
fd_config: FDConfig,
prefix: str = "",
) -> None:
pass
def forward(
self,
ids_remove_padding: paddle.Tensor,
forward_metadata: ForwardMeta,
) -> T_co:
pass
class FdModelForPooling(FdModel[T_co], Protocol[T_co]):
"""The interface required for all pooling models in FastDeploy."""
is_pooling_model: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports pooling.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
default_pooling_type: ClassVar[str] = "LAST"
"""
Indicates the
[fastdeploy.config.PoolerConfig.pooling_type][]
to use by default.
You can use the
[fastdeploy.model_executor.models.interfaces_base.default_pooling_type][]
decorator to conveniently set this field.
"""
pooler: Pooler
"""The pooler is only called on TP rank 0."""

View File

@@ -303,7 +303,9 @@ class Qwen3ForCausalLM(ModelForCasualLM):
if model_param_name not in params_dict:
continue
param = params_dict[model_param_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight, shard_id)
break

View File

@@ -157,7 +157,7 @@ def free_tensor(tensor):
del tensor
def default_weight_loader(fd_config: FDConfig) -> None:
def default_weight_loader(fd_config: FDConfig = None) -> None:
"""Default weight loader"""
def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
@@ -169,7 +169,7 @@ def default_weight_loader(fd_config: FDConfig) -> None:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None and fd_config.parallel_config.tensor_parallel_size > 1:
if output_dim is not None and fd_config is not None and fd_config.parallel_config.tensor_parallel_size > 1:
dim = -1 if output_dim else 0
if isinstance(loaded_weight, paddle.Tensor):
size = loaded_weight.shape[dim]
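
The added fd_config is not None guard means the default loader can also be constructed without a config (for example in the embedding loading path above); in that case the tensor-parallel split branch is simply skipped. A hedged sketch of the resulting behavior, with param and loaded_weight as placeholders:

    # Without a config there is no tensor-parallel information available, so the
    # output_dim split shown above is skipped and the full tensor is loaded.
    fn = default_weight_loader()   # fd_config defaults to None after this change
    fn(param, loaded_weight)       # shard_id is optional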

View File

@@ -18,7 +18,7 @@ import os
import queue
import time
from threading import Thread
from typing import List, Optional
from typing import List, Optional, cast
import numpy as np
import paddle
@@ -77,10 +77,15 @@ if not (current_platform.is_dcu() or current_platform.is_iluvatar()):
import zmq
from fastdeploy import envs
from fastdeploy.engine.pooling_params import PoolingParams
from fastdeploy.engine.tasks import PoolingTask
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling
from fastdeploy.output.pooler import PoolerOutput
from fastdeploy.worker.model_runner_base import ModelRunnerBase
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
@@ -103,6 +108,7 @@ class GPUModelRunner(ModelRunnerBase):
self.speculative_decoding = self.speculative_method is not None
self.enable_logprob = fd_config.model_config.enable_logprob
self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling"
# VL model config:
if self.enable_mm:
@@ -753,6 +759,25 @@ class GPUModelRunner(ModelRunnerBase):
return input_length_list, max_dec_len_list, block_num
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
model = self.get_model()
if not self.is_pooling_model:
return []
supported_tasks = list(model.pooler.get_supported_tasks())
if self.cache_config.enable_chunked_prefill and "encode" in supported_tasks:
supported_tasks.remove("encode")
logger.warning(
"Chunked prefill is not supported with the "
"encode task, which uses ALL pooling. "
"Please disable chunked prefill by exporting FD_DISABLE_CHUNKED_PREFILL=1 before using it."
)
# The score task is not supported yet.
return supported_tasks
def _dummy_prefill_inputs(self, input_length_list: List[int], max_dec_len_list: List[int], block_num: int):
"""Set dummy prefill inputs to share_inputs"""
batch_size = len(input_length_list)
@@ -1320,6 +1345,171 @@ class GPUModelRunner(ModelRunnerBase):
self.attn_backends.append(attn_backend)
def _dummy_pooler_run_task(
self,
hidden_states: paddle.Tensor,
task: PoolingTask,
) -> PoolerOutput:
num_tokens = hidden_states.shape[0]
max_num_seqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_seqs)
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
req_num_tokens = num_tokens // num_reqs
dummy_prompt_lens = paddle.to_tensor(num_scheduled_tokens_list, dtype="int64")
dummy_token_ids = paddle.zeros(
[num_reqs, req_num_tokens],
dtype="int64",
)
model = cast(FdModelForPooling, self.get_model())
dummy_pooling_params = PoolingParams(task=task)
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(dummy_pooling_params)
dummy_metadata = PoolingMetadata(
prompt_lens=dummy_prompt_lens,
prompt_token_ids=dummy_token_ids,
pooling_params=[dummy_pooling_params] * num_reqs,
)
dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, device=hidden_states.place)
try:
return model.pooler(hidden_states=hidden_states, pooling_metadata=dummy_metadata)
except RuntimeError as e:
if "out of memory" in str(e):
raise RuntimeError(
"CUDA out of memory occurred when warming up pooler "
f"({task=}) with {num_reqs} dummy requests. Please try "
"lowering `max_num_seqs` or `gpu_memory_utilization` when "
"initializing the engine."
) from e
else:
raise e
def _dummy_pooler_run(
self,
hidden_states: paddle.Tensor,
) -> PoolerOutput:
output_size = dict[PoolingTask, float]()
for task in self.get_supported_pooling_tasks():
output = self._dummy_pooler_run_task(hidden_states, task)
output_size[task] = output.get_data_nbytes()
del output
max_task = max(output_size.items(), key=lambda x: x[1])[0]
final_output = self._dummy_pooler_run_task(hidden_states, max_task)
return final_output
def _dummy_sampler_run(
self,
hidden_states: paddle.Tensor,
model_output: paddle.Tensor,
) -> paddle.Tensor:
logits = self.model.compute_logits(hidden_states)
if not self.speculative_decoding:
set_value_by_flags_and_idx(
self.share_inputs["pre_ids"],
self.share_inputs["input_ids"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_encoder"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["step_idx"],
self.share_inputs["stop_flags"],
)
sampler_output = self.sampler(logits, self.sampling_metadata)
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.broadcast(
sampler_output.sampled_token_ids,
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)
else:
self.sampler(
logits,
self.sampling_metadata,
self.model_config.max_model_len,
self.share_inputs,
)
sampler_output = None
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.broadcast(
self.share_inputs["accept_tokens"],
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)
paddle.distributed.broadcast(
self.share_inputs["accept_num"],
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)
paddle.distributed.broadcast(
self.share_inputs["step_idx"],
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)
paddle.distributed.broadcast(
self.share_inputs["stop_flags"],
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)
# 5. post process
model_output_data = ModelOutputData(
next_tokens=self.share_inputs["next_tokens"],
stop_flags=self.share_inputs["stop_flags"],
step_idx=self.share_inputs["step_idx"],
max_dec_len=self.share_inputs["max_dec_len"],
pre_ids=self.share_inputs["pre_ids"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
eos_token_id=self.share_inputs["eos_token_id"],
not_need_stop=self.share_inputs["not_need_stop"],
input_ids=self.share_inputs["input_ids"],
stop_nums=self.share_inputs["stop_nums"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
is_block_step=self.share_inputs["is_block_step"],
full_hidden_states=model_output,
msg_queue_id=self.parallel_config.msg_queue_id,
mp_rank=self.parallel_config.tensor_parallel_rank,
use_ep=self.parallel_config.use_ep,
draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
actual_draft_token_num=(
self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None
),
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
)
post_process(
sampler_output=sampler_output,
model_output=model_output_data,
share_inputs=self.share_inputs,
block_size=self.cache_config.block_size,
speculative_decoding=self.speculative_decoding,
skip_save_output=True,
async_output_queue=self.async_output_queue,
)
if self.speculative_decoding:
if self.speculative_method == "mtp":
self.proposer.run(full_hidden_states=model_output)
else:
self.proposer.run(share_inputs=self.share_inputs)
return sampler_output
def _dummy_run(
self,
num_tokens: paddle.Tensor,
@@ -1392,110 +1582,11 @@ class GPUModelRunner(ModelRunnerBase):
self.model_config.max_model_len,
)
logits = None
if hasattr(self.model, "is_pooling_model") and self.model.is_pooling_model:
# TODO(lizexu123): Warm-up of the pooling path has not been implemented yet.
pass
if self.is_pooling_model:
self._dummy_pooler_run(hidden_states)
break
else:
# 4. Execute spec decode
logits = self.model.compute_logits(hidden_states)
if not self.speculative_decoding:
set_value_by_flags_and_idx(
self.share_inputs["pre_ids"],
self.share_inputs["input_ids"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_encoder"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["step_idx"],
self.share_inputs["stop_flags"],
)
sampler_output = self.sampler(logits, self.sampling_metadata)
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.broadcast(
sampler_output.sampled_token_ids,
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)
else:
self.sampler(
logits,
self.sampling_metadata,
self.model_config.max_model_len,
self.share_inputs,
accept_all_drafts,
)
sampler_output = None
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.broadcast(
self.share_inputs["accept_tokens"],
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)
paddle.distributed.broadcast(
self.share_inputs["accept_num"],
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)
paddle.distributed.broadcast(
self.share_inputs["step_idx"],
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)
paddle.distributed.broadcast(
self.share_inputs["stop_flags"],
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)
# 5. post process
model_output_data = ModelOutputData(
next_tokens=self.share_inputs["next_tokens"],
stop_flags=self.share_inputs["stop_flags"],
step_idx=self.share_inputs["step_idx"],
max_dec_len=self.share_inputs["max_dec_len"],
pre_ids=self.share_inputs["pre_ids"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
eos_token_id=self.share_inputs["eos_token_id"],
not_need_stop=self.share_inputs["not_need_stop"],
input_ids=self.share_inputs["input_ids"],
stop_nums=self.share_inputs["stop_nums"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
is_block_step=self.share_inputs["is_block_step"],
full_hidden_states=model_output,
msg_queue_id=self.parallel_config.msg_queue_id,
mp_rank=self.parallel_config.tensor_parallel_rank,
use_ep=self.parallel_config.use_ep,
draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
actual_draft_token_num=(
self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None
),
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
enable_thinking=self.share_inputs["enable_thinking"],
think_end_id=self.model_config.think_end_id,
need_think_end=self.share_inputs["need_think_end"],
reasoning_index=self.share_inputs["reasoning_index"],
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
prompt_lens=self.share_inputs["prompt_lens"],
)
post_process(
sampler_output=sampler_output,
model_output=model_output_data,
share_inputs=self.share_inputs,
block_size=self.cache_config.block_size,
speculative_decoding=self.speculative_decoding,
skip_save_output=True,
async_output_queue=self.async_output_queue,
)
if self.speculative_decoding:
if self.speculative_method == "mtp":
self.proposer.run(full_hidden_states=model_output)
else:
self.proposer.run(share_inputs=self.share_inputs)
self._dummy_sampler_run(hidden_states, model_output)
# 7. Update 'infer_seed' and step_cuda()
self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
@@ -1507,7 +1598,6 @@ class GPUModelRunner(ModelRunnerBase):
self.speculative_config,
self.cache_config.enable_prefix_caching,
)
if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
break
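
Putting the new warm-up path together: for a pooling model, _dummy_run now skips the sampler and instead probes every supported pooling task once, keeping the task with the largest output for the final warm-up pass. A simplified sketch of that loop, where runner and hidden_states are placeholders:

    output_size = {}
    for task in runner.get_supported_pooling_tasks():       # e.g. ["encode", "embed"]
        out = runner._dummy_pooler_run_task(hidden_states, task)
        output_size[task] = out.get_data_nbytes()            # bytes produced by this task
        del out                                               # release before probing the next task
    max_task = max(output_size.items(), key=lambda kv: kv[1])[0]
    final_output = runner._dummy_pooler_run_task(hidden_states, max_task)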

View File

@@ -281,3 +281,8 @@ class ModelRunnerOutput:
[num_reqs, num_spec_tokens]
"""
spec_token_ids: Optional[list[list[int]]]
"""
[num_reqs, hidden_size]
"""
pooler_output: list[Optional[paddle.Tensor]]