diff --git a/docs/features/pooling_models.md b/docs/features/pooling_models.md
new file mode 100644
index 000000000..921f38c3a
--- /dev/null
+++ b/docs/features/pooling_models.md
@@ -0,0 +1,175 @@
+[简体中文](../zh/features/pooling_models.md)
+
+# Pooling Models
+
+FastDeploy also supports pooling models, such as embedding models.
+
+In FastDeploy, pooling models implement the `FdModelForPooling` interface.
+These models use a `Pooler` to extract the final hidden states of the input
+before returning them.
+
+## Configuration
+
+### Model Runner
+
+Run a model in pooling mode via the option `--runner pooling`.
+
+!!! tip
+    There is no need to set this option in the vast majority of cases, as FastDeploy can automatically
+    detect the appropriate model runner via `--runner auto` (the default).
+
+### Model Conversion
+
+FastDeploy can adapt models for various pooling tasks via the `--convert` option.
+
+If `--runner pooling` has been set (manually or automatically) but the model does not implement the
+`FdModelForPooling` interface,
+FastDeploy will attempt to automatically convert the model according to the architecture names
+shown in the table below.
+
+| Architecture | `--convert` | Supported pooling tasks |
+|-------------------------------------------------|-------------|---------------------------------------|
+| `*ForTextEncoding`, `*EmbeddingModel`, `*Model`, `*ForProcessRewardModel` | `embed` | `embed` |
+
+!!! tip
+    You can explicitly set `--convert` to specify how to convert the model.
+
+### Pooler Configuration
+
+#### Predefined models
+
+If the `Pooler` defined by the model accepts `pooler_config`,
+you can override some of its attributes via the `--pooler-config` option.
+
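As an illustrative sketch, overriding the pooler at startup might look like the following. The exact JSON keys accepted by `--pooler-config` depend on the `PoolerConfig` fields in your FastDeploy version, so treat the key names below as assumptions:

```shell
# Hypothetical override of the pooler defaults; the JSON key names are illustrative.
python -m fastdeploy.entrypoints.openai.api_server \
    --model Qwen/Qwen3-Embedding-0.6B \
    --runner pooling \
    --pooler-config '{"pooling_type": "MEAN", "normalize": false}'
```
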
+#### Converted models
+
+If the model has been converted via `--convert` (see above),
+the pooler assigned to each task has the following attributes by default:
+
+| Task | Pooling Type | Normalization | Softmax |
+|------------|--------------|---------------|---------|
+| `embed` | `LAST` | ✅︎ | ❌ |
+
+When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models,
+the model's Sentence Transformers configuration file (`modules.json`) takes priority over these defaults.
+The default can also be specified during model network construction via the `@default_pooling_type("LAST")` decorator.
+
+##### Pooling Type
+
+1. `LastPool` (`PoolingType.LAST`): extracts the hidden state of the last token in each sequence.
+
+2. `AllPool` (`PoolingType.ALL`): returns the hidden states of all tokens in each sequence.
+
+3. `CLSPool` (`PoolingType.CLS`): returns the hidden state of the first token (the CLS token) in each sequence.
+
+4. `MeanPool` (`PoolingType.MEAN`): averages the hidden states of all tokens in each sequence.
+
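The four pooling types can be sketched with plain NumPy; this is an illustrative reimplementation of the semantics above, not FastDeploy's actual `PoolingMethod` code:

```python
import numpy as np

# Toy input: one sequence of 4 tokens with hidden size 3.
hidden_states = np.arange(12, dtype=np.float32).reshape(4, 3)

last = hidden_states[-1]           # LAST: hidden state of the final token
all_states = hidden_states         # ALL: hidden states of every token
cls = hidden_states[0]             # CLS: hidden state of the first token
mean = hidden_states.mean(axis=0)  # MEAN: average over all tokens

print(last.tolist())  # [9.0, 10.0, 11.0]
print(mean.tolist())  # [4.5, 5.5, 6.5]
```

With the default configuration in the table above, the `embed` task would take `last` and normalize it before returning.
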
+## Online Serving
+
+FastDeploy's OpenAI-compatible server provides pooling API endpoints, including a custom reward interface:
+
+- `Embeddings API`, which supports text and multi-modal inputs
+- `Reward API`, which scores specified content
+
+### Embedding Model
+
+```bash
+model_path=Qwen/Qwen3-Embedding-0.6B
+
+python -m fastdeploy.entrypoints.openai.api_server --model ${model_path} \
+ --max-num-seqs 256 --max-model-len 32768 \
+ --port 9412 --engine-worker-queue-port 7142 \
+ --metrics-port 7211 --tensor-parallel-size 1 \
+ --gpu-memory-utilization 0.9 \
+ --runner pooling
+```
+
+Request methods:
+
+A. `EmbeddingCompletionRequest` example (standard text input)
+
+```bash
+curl -X POST 'YOUR_SERVICE_URL/v1/embeddings' \
+ -H 'Content-Type: application/json' \
+ -d '{
+ "model": "text-embedding-chat-model",
+ "input": [
+ "This is a sentence for pooling embedding.",
+ "Another input text."
+ ],
+ "user": "test_client"
+ }'
+```
+
+B. `EmbeddingChatRequest` example (message sequence input)
+
+```bash
+curl -X POST 'YOUR_SERVICE_URL/v1/embeddings' \
+ -H 'Content-Type: application/json' \
+ -d '{
+ "model": "text-embedding-chat-model",
+ "messages": [
+ {"role": "user", "content": "Generate embedding for user query."}
+ ]
+ }'
+```
+
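The same endpoint can be called from Python. This is a minimal sketch using `requests`, assuming the server started above on port 9412 and an OpenAI-style response schema; the helper names are ours:

```python
import requests

API_URL = "http://127.0.0.1:9412/v1/embeddings"  # server started above

def build_request(texts):
    # Mirrors the EmbeddingCompletionRequest payload shown above.
    return {"model": "text-embedding-chat-model", "input": list(texts), "user": "test_client"}

def embed(texts):
    resp = requests.post(API_URL, json=build_request(texts), timeout=30)
    resp.raise_for_status()
    # Assumes the OpenAI embeddings schema: {"data": [{"embedding": [...]}, ...]}
    return [item["embedding"] for item in resp.json()["data"]]

# Example (requires a running server):
# vectors = embed(["This is a sentence for pooling embedding."])
```
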
+### Reward Model
+
+```bash
+model_path=RM_v1008
+python -m fastdeploy.entrypoints.openai.api_server \
+ --model ${model_path} \
+ --max-num-seqs 256 \
+ --max-model-len 8192 \
+ --port 13351 \
+ --engine-worker-queue-port 7562 \
+ --metrics-port 7531 \
+ --tensor-parallel-size 8 \
+ --gpu-memory-utilization 0.9 \
+ --runner pooling \
+ --convert embed
+```
+Request method: `ChatRewardRequest`
+
+```bash
+curl --location 'http://xxxx/v1/chat/reward' \
+--header 'Content-Type: application/json' \
+--data '{
+ "model": "",
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": "https://xxx/a.png"
+ }
+ }
+ ]
+ },
+ {
+ "role": "assistant",
+ "content": [
+ {
+ "type": "text",
+ "text": "图里有几个人"
+ }
+ ]
+ }
+ ],
+ "user": "user-123",
+ "chat_template": null,
+ "chat_template_kwargs": {
+ "custom_var": "value"
+ },
+ "mm_processor_kwargs": {
+ "image_size": 224
+ }
+}'
+```
diff --git a/docs/zh/features/pooling_models.md b/docs/zh/features/pooling_models.md
new file mode 100644
index 000000000..e2f4ce7b2
--- /dev/null
+++ b/docs/zh/features/pooling_models.md
@@ -0,0 +1,168 @@
+[English](../../features/pooling_models.md)
+
+# Pooling Models
+
+FastDeploy也支持pooling模型,例如嵌入(embedding)模型。
+
+在FastDeploy中,池化模型实现`FdModelForPooling`接口。这些模型使用一个`Pooler`来提取输入的最终隐藏状态并返回。
+
+## Configuration
+
+### Model Runner
+
+通过`--runner pooling`选项以池化模式运行模型。
+
+!!! tip "提示"
+    在绝大多数情况下无需手动设置该选项,因为FastDeploy可以通过`--runner auto`(默认值)自动检测合适的runner。
+
+### Model Conversion
+
+如果模型未实现`FdModelForPooling`接口但你希望以池化模式运行,FastDeploy可通过`--convert`选项自动转换模型。
+
+当设置了`--runner pooling`(手动或自动)但模型不符合接口时,FastDeploy会根据模型架构名称自动转换:
+
+| Architecture | `--convert` | 支持的池化类型 |
+|-------------------------------------------------|-------------|---------------------------------------|
+| `*ForTextEncoding`, `*EmbeddingModel`, `*Model`, `*ForProcessRewardModel` | `embed` | `embed` |
+
+!!! tip "提示"
+    你可以显式设置`--convert`来指定模型转换方式。
+
+### Pooler Configuration
+
+#### Predefined models
+
+如果模型定义的`Pooler`接受`pooler_config`,你可以通过`--pooler-config`选项覆盖部分属性。
+
+#### Converted models
+
+如果模型通过`--convert`转换(见上文),各任务默认的池化配置如下:
+
+| Task | Pooling Type | Normalization | Softmax |
+|------------|--------------|---------------|---------|
+| `embed` | `LAST` | ✅︎ | ❌ |
+
+加载[Sentence Transformers](https://huggingface.co/sentence-transformers)模型时,其`modules.json`配置优先于默认值;也可以在模型组网时通过`@default_pooling_type("LAST")`装饰器指定。
+
+##### Pooling Type
+
+1. `LastPool`(`PoolingType.LAST`):提取每个序列最后一个token的隐藏状态。
+
+2. `AllPool`(`PoolingType.ALL`):返回每个序列所有token的隐藏状态。
+
+3. `CLSPool`(`PoolingType.CLS`):返回每个序列第一个token(CLS token)的隐藏状态。
+
+4. `MeanPool`(`PoolingType.MEAN`):计算每个序列所有token隐藏状态的平均值。
+
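上述四种池化方式可以用NumPy给出一个示意实现(仅用于说明语义,并非FastDeploy内部`PoolingType`各实现的实际代码):

```python
import numpy as np

# 示例输入:一个包含4个token、隐藏维度为3的序列。
hidden_states = np.arange(12, dtype=np.float32).reshape(4, 3)

last = hidden_states[-1]           # LAST:最后一个token的隐藏状态
all_states = hidden_states         # ALL:所有token的隐藏状态
cls = hidden_states[0]             # CLS:第一个token的隐藏状态
mean = hidden_states.mean(axis=0)  # MEAN:所有token隐藏状态的平均值

print(last.tolist())  # [9.0, 10.0, 11.0]
print(mean.tolist())  # [4.5, 5.5, 6.5]
```

按上表的默认配置,`embed`任务会取`last`并做归一化后返回。
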
+## Online Serving
+
+FastDeploy的OpenAI兼容服务器提供了池化API端点和自定义的reward接口:
+
+- `Embeddings API`,支持文本和多模态输入
+- `Reward API`,给指定的内容打分
+
+### Embedding模型
+
+```bash
+model_path=Qwen/Qwen3-Embedding-0.6B
+
+python -m fastdeploy.entrypoints.openai.api_server --model ${model_path} \
+ --max-num-seqs 256 --max-model-len 32768 \
+ --port 9412 --engine-worker-queue-port 7142 \
+ --metrics-port 7211 --tensor-parallel-size 1 \
+ --gpu-memory-utilization 0.9 \
+ --runner pooling
+
+```
+
+请求方式:
+
+A. `EmbeddingCompletionRequest` 示例(标准文本输入)
+
+```bash
+curl -X POST 'YOUR_SERVICE_URL/v1/embeddings' \
+ -H 'Content-Type: application/json' \
+ -d '{
+ "model": "text-embedding-chat-model",
+ "input": [
+ "This is a sentence for pooling embedding.",
+ "Another input text."
+ ],
+ "user": "test_client"
+ }'
+```
+
+B. `EmbeddingChatRequest` 示例(消息序列输入)
+
+```bash
+curl -X POST 'YOUR_SERVICE_URL/v1/embeddings' \
+ -H 'Content-Type: application/json' \
+ -d '{
+ "model": "text-embedding-chat-model",
+ "messages": [
+ {"role": "user", "content": "Generate embedding for user query."}
+ ]
+ }'
+```
+
+### Pooling模型和打分机制
+
+```bash
+model_path=RM_v1008
+python -m fastdeploy.entrypoints.openai.api_server \
+ --model ${model_path} \
+ --max-num-seqs 256 \
+ --max-model-len 8192 \
+ --port 13351 \
+ --engine-worker-queue-port 7562 \
+ --metrics-port 7531 \
+ --tensor-parallel-size 8 \
+ --gpu-memory-utilization 0.9 \
+ --runner pooling \
+ --convert embed
+```
+
+请求方式:`ChatRewardRequest`
+
+```bash
+curl --location 'http://xxxx/v1/chat/reward' \
+--header 'Content-Type: application/json' \
+--data '{
+ "model": "",
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": "https://xxx/a.png"
+ }
+ }
+ ]
+ },
+ {
+ "role": "assistant",
+ "content": [
+ {
+ "type": "text",
+ "text": "图里有几个人"
+ }
+ ]
+ }
+ ],
+ "user": "user-123",
+ "chat_template": null,
+ "chat_template_kwargs": {
+ "custom_var": "value"
+ },
+ "mm_processor_kwargs": {
+ "image_size": 224
+ }
+}'
+```
diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 37eadd768..e99d8f531 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -299,7 +299,7 @@ class ModelConfig:
self.tensor_parallel_size = self.infer_model_mp_num
del self.infer_model_mp_num
- if hasattr(self, "num_hidden_layers"):
+ if hasattr(self, "num_hidden_layers") and self.runner != "pooling":
if hasattr(self, "remove_tail_layer"):
if self.remove_tail_layer is True:
self.num_hidden_layers -= 1
diff --git a/fastdeploy/engine/pooling_params.py b/fastdeploy/engine/pooling_params.py
index 93192b6ec..7f8314129 100644
--- a/fastdeploy/engine/pooling_params.py
+++ b/fastdeploy/engine/pooling_params.py
@@ -164,7 +164,7 @@ class PoolingParams(
self.softmax = True
elif self.task == "reward":
if self.normalize is None:
- self.normalize = True
+ self.normalize = False
else:
raise ValueError(f"Unknown pooling task: {self.task}")
diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py
index 2fe8e1317..888364bb6 100644
--- a/fastdeploy/engine/request.py
+++ b/fastdeploy/engine/request.py
@@ -187,6 +187,7 @@ class Request:
pooling_params = PoolingParams.from_dict(d["pooling_params"])
else:
sampling_params = SamplingParams.from_dict(d)
+
if (
isinstance(d.get("multimodal_inputs"), dict)
and isinstance(d["multimodal_inputs"].get("mm_positions"), list)
@@ -202,7 +203,6 @@ class Request:
data_processor_logger.error(
f"Convert mm_positions to ImagePosition error: {e}, {str(traceback.format_exc())}"
)
-
return cls(
request_id=d["request_id"],
prompt=d.get("prompt"),
diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py
index 5d242d354..0bb35a284 100644
--- a/fastdeploy/entrypoints/openai/protocol.py
+++ b/fastdeploy/entrypoints/openai/protocol.py
@@ -920,16 +920,6 @@ class EmbeddingChatRequest(BaseModel):
user: Optional[str] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
- # --8<-- [start:chat-embedding-extra-params]
- add_generation_prompt: bool = Field(
- default=False,
- description=(
- "If true, the generation prompt will be added to the chat template. "
- "This is a parameter used by chat template in tokenizer config of the "
- "model."
- ),
- )
-
add_special_tokens: bool = Field(
default=False,
description=(
@@ -1013,9 +1003,9 @@ PoolingChatRequest = EmbeddingChatRequest
class ChatRewardRequest(BaseModel):
- model: Optional[str] = None # 指定模型,例如 "default" 或支持 embedding 的 chat 模型
- messages: Union[List[Any], List[int]] # 聊天消息列表(必选)
- user: Optional[str] = None # 调用方标识符
+ model: Optional[str] = None
+ messages: Union[List[Any], List[int]]
+ user: Optional[str] = None
dimensions: Optional[int] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
@@ -1084,15 +1074,15 @@ class ChatRewardRequest(BaseModel):
class ChatRewardData(BaseModel):
- index: Optional[int] = None # 数据索引(可选)
- object: str = "reward" # 固定为 "reward"
- score: List[float] # reward 分数(浮点数列表)
+ index: Optional[int] = None
+ object: str = "reward"
+ score: List[float]
class ChatRewardResponse(BaseModel):
- id: str # 响应 ID,例如 chat-reward-
- object: str = "object" # 固定为 "object"
- created: int # 创建时间(Unix 时间戳)
- model: str # 使用的模型名
- data: List[ChatRewardData] # reward 结果列表
- usage: Optional[UsageInfo] = None # Token 使用情况
+ id: str
+ object: str = "object"
+ created: int
+ model: str
+ data: List[ChatRewardData]
+ usage: Optional[UsageInfo] = None
diff --git a/fastdeploy/entrypoints/openai/serving_engine.py b/fastdeploy/entrypoints/openai/serving_engine.py
index bd22df566..bc8090d97 100644
--- a/fastdeploy/entrypoints/openai/serving_engine.py
+++ b/fastdeploy/entrypoints/openai/serving_engine.py
@@ -256,7 +256,6 @@ class ZmqOpenAIServing(OpenAIServing):
chat_template_kwargs.update(
{
"chat_template": request_dict.get("chat_template"),
- "add_generation_prompt": request_dict.get("add_generation_prompt"),
"add_stop_sequences": request_dict.get("add_stop_sequences"),
}
)
diff --git a/fastdeploy/model_executor/layers/pool/metadata.py b/fastdeploy/model_executor/layers/pool/metadata.py
index 699800a0f..951c3caea 100644
--- a/fastdeploy/model_executor/layers/pool/metadata.py
+++ b/fastdeploy/model_executor/layers/pool/metadata.py
@@ -15,12 +15,14 @@
"""
from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
import paddle
from fastdeploy.engine.pooling_params import PoolingParams
+Device = Union[paddle.CPUPlace, paddle.CUDAPlace, paddle.XPUPlace]
+
@dataclass
class PoolingCursor:
@@ -60,21 +62,21 @@ class PoolingMetadata:
pooling_cursor=None if self.pooling_cursor is None else self.pooling_cursor[indices],
)
- def build_pooling_cursor(self, num_scheduled_tokens: list[int], device: str):
+ def build_pooling_cursor(self, num_scheduled_tokens: list[int], device: Device):
self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens, self.prompt_lens, device)
-def build_pooling_cursor(num_scheduled_tokens: list[int], prompt_lens: paddle.Tensor, device: str):
+def build_pooling_cursor(num_scheduled_tokens: list[int], prompt_lens: paddle.Tensor, device: Device):
assert len(prompt_lens) == len(num_scheduled_tokens)
n_seq = len(num_scheduled_tokens)
index = list(range(n_seq))
- num_scheduled_tokens = paddle.to_tensor(num_scheduled_tokens)
+ num_scheduled_tokens = paddle.to_tensor(num_scheduled_tokens, dtype="int64")
cumsum = paddle.zeros([n_seq + 1], dtype="int64")
paddle.cumsum(num_scheduled_tokens, axis=0, out=cumsum[1:])
- if device == "gpu":
- cumsum_device = cumsum.cuda()
+ if isinstance(device, paddle.CUDAPlace):
+ cumsum_device = paddle.assign(cumsum).cuda(device.get_device_id())
else:
cumsum_device = cumsum
return PoolingCursor(
diff --git a/fastdeploy/model_executor/layers/pooler.py b/fastdeploy/model_executor/layers/pooler.py
index 0266987a8..78f9a0eab 100644
--- a/fastdeploy/model_executor/layers/pooler.py
+++ b/fastdeploy/model_executor/layers/pooler.py
@@ -78,7 +78,6 @@ def get_pooling_params(pooling_metadata: PoolingMetadata) -> list[PoolingParams]
def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
pooling_params = get_pooling_params(pooling_metadata)
-
tasks: list[PoolingTask] = [task for pooling_param in pooling_params if (task := pooling_param.task) is not None]
assert len(pooling_params) == len(tasks)
@@ -108,7 +107,7 @@ class Pooler(nn.Layer, ABC):
@staticmethod
def for_encode(pooler_config: PoolerConfig, model_config: Optional["ModelConfig"] = None):
if pooler_config.pooling_type == "STEP":
- return StepPooler()
+ return StepPooler(model_config)
resolved_config = ResolvedPoolingConfig(task="encode", pooling_type=PoolingType.ALL)
return SimplePooler.from_config(resolved_config, model_config)
@@ -121,6 +120,14 @@ class Pooler(nn.Layer, ABC):
)
return SimplePooler.from_config(resolved_config, model_config)
+ @staticmethod
+ def for_reward(pooler_config: PoolerConfig, model_config: Optional["ModelConfig"] = None):
+ resolved_config = ResolvedPoolingConfig.from_config(
+ task="reward",
+ pooler_config=pooler_config,
+ )
+ return SimplePooler.from_config(resolved_config, model_config)
+
@staticmethod
def for_classify(
pooler_config: PoolerConfig,
@@ -274,6 +281,7 @@ class EmbeddingPoolerHead(PoolerHead):
pooled_data = [vecs if d is None else vecs[..., :d] for vecs, d in zip(pooled_data, dimensions_list)]
# for normalize
flags = [p.normalize for p in pooling_params]
+
if len(set(flags)) == 1:
if flags[0]:
pooled_data = self.activation(pooled_data)
@@ -293,7 +301,6 @@ class RewardPoolerHead(PoolerHead):
def forward(self, pooled_data: Union[list[paddle.Tensor], paddle.Tensor], pooling_metadata: PoolingMetadata):
pooling_params = get_pooling_params(pooling_metadata)
- # for softmax
flags = [p.softmax for p in pooling_params]
if len(set(flags)) == 1:
if flags[0]:
@@ -345,7 +352,7 @@ class PoolingMethod(nn.Layer, ABC):
class LastPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
- return {"encode", "embed", "classify", "score"}
+ return {"encode", "embed", "classify", "score", "reward"}
def forward_all(
self,
@@ -366,8 +373,8 @@ class AllPool(PoolingMethod):
) -> Union[list[paddle.Tensor], paddle.Tensor]:
assert not pooling_cursor.is_partial_prefill(), "partial prefill not supported with ALL pooling"
-
hidden_states_lst = list(hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist()))
+
return [hidden_states_lst[i] for i in pooling_cursor.index]
@@ -416,11 +423,12 @@ class CLSPool(PoolingMethod):
class StepPooler(Pooler):
def __init__(
self,
+ model_config: ModelConfig,
) -> None:
super().__init__()
self.pooling = AllPool()
- self.head = RewardPoolerHead()
+ self.head = RewardPoolerHead(model_config)
def extract_states(
self,
@@ -455,14 +463,11 @@ class StepPooler(Pooler):
def forward(
self,
- hidden_states: Union[paddle.Tensor, list[paddle.Tensor]],
+        hidden_states: Union[paddle.Tensor, list[paddle.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)
- pooling_params = get_pooling_params(pooling_metadata)
- assert len(pooled_data) == len(pooling_params)
-
- pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
+ pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
@@ -484,7 +489,7 @@ class SimplePooler(Pooler):
pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
if pooler_config.task == "embed":
head = EmbeddingPoolerHead(model_config)
- elif pooler_config.task == "encode":
+ elif pooler_config.task == "encode" or pooler_config.task == "reward":
head = RewardPoolerHead(model_config)
else:
raise NotImplementedError(f"Unknown task: {pooler_config.task}")
diff --git a/fastdeploy/model_executor/models/adapters.py b/fastdeploy/model_executor/models/adapters.py
index 7dcfd2c0c..306bbcd51 100644
--- a/fastdeploy/model_executor/models/adapters.py
+++ b/fastdeploy/model_executor/models/adapters.py
@@ -166,6 +166,7 @@ def as_embedding_model(cls: _T) -> _T:
{
"encode": Pooler.for_encode(pooler_config, fd_config.model_config),
"embed": Pooler.for_embed(pooler_config, fd_config.model_config),
+ "reward": Pooler.for_reward(pooler_config, fd_config.model_config),
},
)
diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
index 7c1ff3eaf..c8b414bcd 100644
--- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
@@ -348,12 +348,17 @@ class Ernie4_5_VLDecoderLayer(nn.Layer):
prefix=f"{prefix}.mlp",
)
+ norm_dtype = None
+ if fd_config.model_config.architectures[0] == "Ernie4_5_VLMoeForProcessRewardModel":
+ norm_dtype = "float32"
+
self.input_layernorm = RMSNorm(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.input_layernorm",
layer_id=layer_id,
+ dtype=norm_dtype,
)
self.post_attention_layernorm = RMSNorm(
@@ -362,6 +367,7 @@ class Ernie4_5_VLDecoderLayer(nn.Layer):
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.post_attention_layernorm",
layer_id=layer_id,
+ dtype=norm_dtype,
)
def load_state_dict(self, state_dict):
@@ -542,7 +548,6 @@ class Ernie4_5_VLModel(nn.Layer):
)
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
-
return out
diff --git a/fastdeploy/model_executor/models/ernie_vl_rm.py b/fastdeploy/model_executor/models/ernie_vl_rm.py
index cfa29c845..c7d255f94 100644
--- a/fastdeploy/model_executor/models/ernie_vl_rm.py
+++ b/fastdeploy/model_executor/models/ernie_vl_rm.py
@@ -57,6 +57,7 @@ class Ernie4_5_VLMoeRewardBaseModel(nn.Layer):
)
self.ernie = Ernie4_5_VLModel(fd_config=fd_config)
self.head_dtype = paddle.bfloat16
+ self.fd_config = fd_config
# Persistent buffers for CUDA graphs.
if fd_config.graph_opt_config.use_cudagraph:
@@ -111,7 +112,7 @@ class Ernie4_5_VLMoeRewardBaseModel(nn.Layer):
input_embeddings = self.get_input_embeddings(
ids_remove_padding=ids_remove_padding,
image_features=image_features,
- image_token_num=vl_moe_meta.image_token_num.item(),
+ image_token_num=vl_moe_meta.num_image_patch_id.item(),
)
if forward_meta.step_use_cudagraph:
@@ -124,18 +125,22 @@ class Ernie4_5_VLMoeRewardBaseModel(nn.Layer):
forward_meta=forward_meta,
vl_moe_meta=vl_moe_meta,
)
+
+ if isinstance(hidden_states, tuple):
+ hidden_states = hidden_states[0]
+
hidden_states = hidden_states.to(self.head_dtype)
logits = self.rm_head(hidden_states)
- return logits
+ return logits.cast("float32")
@ModelRegistry.register_model_class(
architecture="Ernie4_5_VLMoeForProcessRewardModel",
module_name="ernie_vl_rm",
- category=[ModelCategory.REWARD],
- primary_use=ModelCategory.REWARD,
+ category=ModelCategory.REWARD | ModelCategory.MULTIMODAL,
+ primary_use=ModelCategory.REWARD | ModelCategory.MULTIMODAL,
)
-@default_pooling_type("ALL")
+@default_pooling_type("LAST")
class Ernie4_5_VLMoeForProcessRewardModel(Ernie4_5_VLMoeRewardBaseModel):
def __init__(self, fd_config: FDConfig):
@@ -147,7 +152,13 @@ class Ernie4_5_VLMoeForProcessRewardModel(Ernie4_5_VLMoeRewardBaseModel):
pooler_config = fd_config.model_config.pooler_config
assert pooler_config is not None
- self.pooler = DispatchPooler({"encode": Pooler.for_encode(pooler_config)})
+ self.pooler = DispatchPooler(
+ {
+ "encode": Pooler.for_encode(pooler_config, fd_config.model_config),
+ "embed": Pooler.for_embed(pooler_config, fd_config.model_config),
+ "reward": Pooler.for_reward(pooler_config, fd_config.model_config),
+ },
+ )
self.process_weights_before_loading_fn = process_weights_before_loading(skip_prefixes=["lm_head"])
@@ -159,4 +170,5 @@ class Ernie4_5_VLMoeForProcessRewardModel(Ernie4_5_VLMoeRewardBaseModel):
@paddle.no_grad()
def load_weights(self, weights_iterator):
# Filter out lm_head weights of Ernie4_5_VLMoeForConditionalGeneration
+
Ernie4_5_VLMoeForConditionalGeneration.load_weights(self, weights_iterator)
diff --git a/fastdeploy/model_executor/models/model_base.py b/fastdeploy/model_executor/models/model_base.py
index b81606bae..2103c3a0c 100644
--- a/fastdeploy/model_executor/models/model_base.py
+++ b/fastdeploy/model_executor/models/model_base.py
@@ -58,7 +58,7 @@ class ModelInfo:
is_text_generation=ModelCategory.TEXT_GENERATION in category,
is_multimodal=ModelCategory.MULTIMODAL in category,
is_reasoning=ModelCategory.REASONING in category,
- is_pooling=ModelCategory.EMBEDDING in category,
+ is_pooling=(ModelCategory.EMBEDDING in category) or (ModelCategory.REWARD in category),
default_pooling_type=get_default_pooling_type(model_cls),
module_path=module_path,
)
diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py
index 411eec7d1..4a4132597 100644
--- a/fastdeploy/model_executor/pre_and_post_process.py
+++ b/fastdeploy/model_executor/pre_and_post_process.py
@@ -298,6 +298,8 @@ def _build_stream_transfer_data(
stream_transfer_datas.append(stream_transfer_data)
elif pooler_outputs is not None:
for bid, pooler_output in enumerate(pooler_outputs):
+ if pooler_output is None:
+ continue
if pooler_output.dtype == paddle.bfloat16:
pooler_output = pooler_output.astype("float32")
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 7965fcbb8..d5be2801f 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -542,7 +542,11 @@ class GPUModelRunner(ModelRunnerBase):
rope_3d_position_ids["position_ids_offset"].append(
position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
)
- rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
+
+ if self.is_pooling_model:
+ rope_3d_position_ids["max_tokens_lst"].append(0)
+ else:
+ rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None):
"""
@@ -2428,7 +2432,6 @@ class GPUModelRunner(ModelRunnerBase):
def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Optional[ModelRunnerOutput]:
num_scheduled_tokens = int(self.share_inputs["seq_lens_this_time"][:num_running_requests].sum())
-
hidden_states = hidden_states[:num_scheduled_tokens]
prompt_lens = self.share_inputs["prompt_lens"][:num_running_requests]
@@ -2446,11 +2449,23 @@ class GPUModelRunner(ModelRunnerBase):
pooling_metadata.build_pooling_cursor(num_scheduled_tokens_list, device=device_str)
raw_pooler_output = self.model.pooler(hidden_states=hidden_states, pooling_metadata=pooling_metadata)
+
seq_lens_cpu = self.share_inputs["seq_lens_this_time"][:num_running_requests]
pooler_output: list[Optional[paddle.Tensor]] = []
- for raw_output, seq_len, prompt_len in zip(raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
- output = raw_output.data if int(seq_len) == int(prompt_len) else None
- pooler_output.append(output)
+
+ seq_lens_decoder_batch = self.share_inputs["seq_lens_decoder"][:num_running_requests]
+
+ for i, (seq_len, prompt_len) in enumerate(zip(seq_lens_cpu, pooling_metadata.prompt_lens)):
+ if not self.cache_config.enable_prefix_caching:
+ output = raw_pooler_output[i].data if int(seq_len) == int(prompt_len) else None
+ pooler_output.append(output)
+ else:
+ current_seq_len_decoder = seq_lens_decoder_batch[i]
+ if int(current_seq_len_decoder) + int(seq_len) == int(prompt_len):
+ output = raw_pooler_output[i].data
+ else:
+ output = None
+ pooler_output.append(output)
pooler_output = PoolerOutput(
outputs=pooler_output,
diff --git a/tests/pooling/test_Ernie4_5_reward_serving.py b/tests/pooling/test_Ernie4_5_reward_serving.py
new file mode 100644
index 000000000..32a5d4f80
--- /dev/null
+++ b/tests/pooling/test_Ernie4_5_reward_serving.py
@@ -0,0 +1,204 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import signal
+import subprocess
+import sys
+import time
+
+import pytest
+import requests
+from e2e.utils.serving_utils import (
+ FD_API_PORT,
+ FD_CACHE_QUEUE_PORT,
+ FD_ENGINE_QUEUE_PORT,
+ FD_METRICS_PORT,
+ clean_ports,
+ is_port_open,
+)
+
+# ==========================
+# Shared Helper Functions
+# ==========================
+
+
+def _start_server_process(enable_caching: bool, log_filename: str):
+
+ print(f"\n[Server Setup] Cleaning ports before starting (Caching={'ON' if enable_caching else 'OFF'})...")
+ clean_ports()
+
+ base_path = os.getenv("MODEL_PATH")
+ if base_path:
+ model_path = os.path.join(base_path, "RM_v1008_5")
+ else:
+ model_path = "./RM_v1008_5"
+
+ if not os.path.exists(model_path):
+ raise FileNotFoundError(f"Model path not found: {model_path}")
+
+ cmd = [
+ sys.executable,
+ "-m",
+ "fastdeploy.entrypoints.openai.api_server",
+ "--model",
+ model_path,
+ "--port",
+ str(FD_API_PORT),
+ "--tensor-parallel-size",
+ "2",
+ "--engine-worker-queue-port",
+ str(FD_ENGINE_QUEUE_PORT),
+ "--metrics-port",
+ str(FD_METRICS_PORT),
+ "--cache-queue-port",
+ str(FD_CACHE_QUEUE_PORT),
+ "--max-model-len",
+ "8192",
+ "--max-num-seqs",
+ "256",
+ "--runner",
+ "pooling",
+ "--convert",
+ "embed",
+ ]
+
+ if enable_caching:
+ cmd.append("--enable-prefix-caching")
+ else:
+ cmd.append("--no-enable-prefix-caching")
+
+ print(f"[Server Setup] Command: {' '.join(cmd)}")
+
+ with open(log_filename, "w") as logfile:
+ process = subprocess.Popen(
+ cmd,
+ stdout=logfile,
+ stderr=subprocess.STDOUT,
+ start_new_session=True,
+ )
+
+ # Wait for server to start
+ for _ in range(300):
+ if is_port_open("127.0.0.1", FD_API_PORT):
+ print(f"[Server Setup] Server is up on port {FD_API_PORT}")
+ break
+ time.sleep(1)
+ else:
+ print("[Server Setup] Server failed to start. Cleaning up...")
+ try:
+ os.killpg(process.pid, signal.SIGTERM)
+ except Exception:
+ pass
+ if os.path.exists(log_filename):
+ with open(log_filename, "r") as f:
+ print(f"Server Log Tail ({log_filename}):\n{f.read()[-500:]}")
+ raise RuntimeError(f"Server did not start on port {FD_API_PORT}")
+
+ return process
+
+
+@pytest.fixture(scope="function")
+def reward_api_url():
+ """Returns the API endpoint URL for reward."""
+ return f"http://0.0.0.0:{FD_API_PORT}/v1/reward"
+
+
+@pytest.fixture(scope="function")
+def headers():
+ """Returns common HTTP request headers."""
+ return {"Content-Type": "application/json"}
+
+
+@pytest.fixture(scope="function")
+def server_default_caching():
+    process = _start_server_process(enable_caching=True, log_filename="reward_server_caching_on.log")
+    yield process
+    try:
+        os.killpg(process.pid, signal.SIGTERM)
+    except Exception:
+        pass
+
+
+@pytest.fixture(scope="function")
+def server_no_caching():
+    process = _start_server_process(enable_caching=False, log_filename="reward_server_caching_off.log")
+    yield process
+    try:
+        os.killpg(process.pid, signal.SIGTERM)
+    except Exception:
+        pass
+
+
+def save_score_baseline(score: float, baseline_file: str):
+ """Save reward score to baseline file."""
+ baseline_data = {"score": score}
+ with open(baseline_file, "w", encoding="utf-8") as f:
+ json.dump(baseline_data, f, indent=2)
+ print(f"Baseline saved to: {baseline_file}")
+
+
+def check_score_against_baseline(current_score: float, baseline_file: str, threshold: float = 0.01):
+ """Check reward score against baseline file."""
+ try:
+ with open(baseline_file, "r", encoding="utf-8") as f:
+ baseline_data = json.load(f)
+ baseline_score = baseline_data["score"]
+ except FileNotFoundError:
+ print(f"Baseline file not found: {baseline_file}. Saving current as baseline.")
+ save_score_baseline(current_score, baseline_file)
+ return
+
+ diff = abs(current_score - baseline_score)
+ print(f"Score Difference: {diff:.6f} (Current: {current_score}, Baseline: {baseline_score})")
+
+ if diff >= threshold:
+ temp_file = f"{baseline_file}.current"
+ save_score_baseline(current_score, temp_file)
+ raise AssertionError(
+ f"Score differs from baseline by too much (diff={diff:.6f} >= {threshold}):\n"
+ f"Current score saved to: {temp_file}"
+ )
+
+
+def _run_test_logic(reward_api_url, headers, baseline_filename):
+ payload = {
+ "model": "default",
+ "messages": [
+ {"role": "user", "content": [{"type": "text", "text": "北京天安门在哪里?"}]},
+ {"role": "assistant", "content": [{"type": "text", "text": "北京天安门在中国北京故宫的前面。"}]},
+ ],
+ "user": "user-123",
+ "enable_thinking": False,
+ }
+
+ print(f"\n=== Sending request to {reward_api_url} ===")
+ response = requests.post(reward_api_url, headers=headers, json=payload, timeout=30)
+ assert response.status_code == 200, f"API request failed with status {response.status_code}: {response.text}"
+
+ result = response.json()
+ print(f"Response: {json.dumps(result, indent=2, ensure_ascii=False)}")
+
+ assert "data" in result and len(result["data"]) > 0
+ score = float(result["data"][0]["score"][0])
+ print(f"✓ Reward Score: {score}")
+
+ base_path = os.getenv("MODEL_PATH", "")
+ if base_path:
+ baseline_file = os.path.join(base_path, baseline_filename)
+ else:
+ baseline_file = baseline_filename
+
+ check_score_against_baseline(score, baseline_file, threshold=0.01)
+
+
+def test_reward_model_with_caching(server_default_caching, reward_api_url, headers):
+ print("\n>>> Running Test: WITH Prefix Caching")
+ _run_test_logic(reward_api_url, headers, baseline_filename="reward_score_baseline.json")
+
+
+def test_reward_model_without_caching(server_no_caching, reward_api_url, headers):
+ print("\n>>> Running Test: WITHOUT Prefix Caching")
+ _run_test_logic(reward_api_url, headers, baseline_filename="reward_score_baseline_no_caching.json")
diff --git a/tests/pooling/test_Qwen3-Embedding_serving.py b/tests/pooling/test_Qwen3-Embedding_serving.py
index 69e937593..910f41671 100644
--- a/tests/pooling/test_Qwen3-Embedding_serving.py
+++ b/tests/pooling/test_Qwen3-Embedding_serving.py
@@ -79,7 +79,7 @@ def setup_and_run_embedding_server():
model_path = "./Qwen3-Embedding-0.6B"
if not os.path.exists(model_path):
- pytest.skip(f"Model path not found: {model_path}")
+ raise FileNotFoundError(f"Model path not found: {model_path}")
log_path = "embedding_server.log"
cmd = [