From 9cab3f47ffe3bb57510a67bc9072dca7b3658e7a Mon Sep 17 00:00:00 2001
From: chen <103103266+ckl117@users.noreply.github.com>
Date: Mon, 25 Aug 2025 14:11:49 +0800
Subject: [PATCH] [Feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post processing (#3552)

* [feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post processing
* infer engine support temp_scaled_logprobs and top_p_normalized_logprobs
* delete some code
* code check
* code check and add doc
* fix tokenizer.decoder(-1), return 'Invalid Token'
* add ci for temp_scaled and top_p logprobs
* check test
* check seq len time shape
* logprob clip inf

---------

Co-authored-by: sunlei1024
---
 docs/online_serving/README.md                 | 11 ++-
 docs/zh/online_serving/README.md              | 11 ++-
 fastdeploy/engine/sampling_params.py          |  3 +
 fastdeploy/entrypoints/openai/protocol.py     | 10 ++
 .../model_executor/layers/sample/meta_data.py |  5 +-
 .../model_executor/layers/sample/sampler.py   | 55 ++++++++++-
 fastdeploy/worker/gpu_model_runner.py         | 15 +++
 tests/ce/server/test_logprobs.py              | 93 +++++++++++++++++++
 8 files changed, 195 insertions(+), 8 deletions(-)

diff --git a/docs/online_serving/README.md b/docs/online_serving/README.md
index 6cdf1be92..53a8b60c8 100644
--- a/docs/online_serving/README.md
+++ b/docs/online_serving/README.md
@@ -45,8 +45,9 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
   -H "Content-Type: application/json" \
   -d '{
     "messages": [
-      {"role": "user", "content": "Hello!"}, "logprobs": true, "top_logprobs": 5
-    ]
+      {"role": "user", "content": "Hello!"}
+    ],
+    "logprobs": true, "top_logprobs": 0
   }'
 ```

@@ -193,6 +194,12 @@ max_streaming_response_tokens: Optional[int] = None

 disable_chat_template: Optional[bool] = False
 # Whether to disable chat template rendering, using raw input directly (default False means template is enabled).
+
+temp_scaled_logprobs: Optional[bool] = False
+# Whether to divide the logits by the temperature when computing logprobs (default False: logits are not temperature-scaled).
+
+top_p_normalized_logprobs: Optional[bool] = False
+# Whether to apply top-p normalization when computing logprobs (default False: no top-p normalization).
 ```

 ### Differences in Return Fields
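The two new request-level switches are sent alongside the existing `logprobs`/`top_logprobs` fields. A minimal Python sketch of such a request follows; it is not part of the patch, the endpoint comes from the curl example above, and the model name and parameter values are placeholder assumptions:

```python
# Minimal sketch of a chat request that enables the new logprobs post-processing
# flags; host, port, and model name are placeholders for a concrete deployment.
import requests

payload = {
    "model": "default",                 # placeholder model name
    "messages": [{"role": "user", "content": "Hello!"}],
    "logprobs": True,
    "top_logprobs": 5,
    "temperature": 0.8,
    "top_p": 0.9,
    "temp_scaled_logprobs": True,       # divide logits by temperature before log_softmax
    "top_p_normalized_logprobs": True,  # renormalize probabilities over the top-p nucleus
}

resp = requests.post("http://0.0.0.0:8188/v1/chat/completions", json=payload, timeout=60)
print(resp.json()["choices"][0]["logprobs"])
```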
diff --git a/docs/zh/online_serving/README.md b/docs/zh/online_serving/README.md
index d55daffc3..45e2168d2 100644
--- a/docs/zh/online_serving/README.md
+++ b/docs/zh/online_serving/README.md
@@ -45,8 +45,9 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
   -H "Content-Type: application/json" \
   -d '{
     "messages": [
-      {"role": "user", "content": "Hello!"}, "logprobs": true, "top_logprobs": 5
-    ]
+      {"role": "user", "content": "Hello!"}
+    ],
+    "logprobs": true, "top_logprobs": 0
   }'
 ```

@@ -192,6 +193,12 @@ max_streaming_response_tokens: Optional[int] = None

 disable_chat_template: Optional[bool] = False
 # 是否禁用聊天模板渲染,直接使用原始输入(默认 False 表示启用模板)。
+
+temp_scaled_logprobs: Optional[bool] = False
+# 计算 logprob 时是否对 logits 除以温度系数(默认 False 表示不除以温度系数)。
+
+top_p_normalized_logprobs: Optional[bool] = False
+# 计算 logprob 时是否进行 top_p 归一化(默认 False 表示不进行 top_p 归一化)。
 ```

 ### 返回字段差异

diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py
index 1cd77d2b1..f95f09bd5 100644
--- a/fastdeploy/engine/sampling_params.py
+++ b/fastdeploy/engine/sampling_params.py
@@ -98,6 +98,9 @@ class SamplingParams:
     reasoning_max_tokens: Optional[int] = None
     min_tokens: int = 1
     logprobs: Optional[int] = None
+    # For logits and logprobs post processing
+    temp_scaled_logprobs: bool = False
+    top_p_normalized_logprobs: bool = False
     bad_words: Optional[List[str]] = None
     _bad_words_token_ids: Optional[List[int]] = None

diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py
index aae948485..733701eea 100644
--- a/fastdeploy/entrypoints/openai/protocol.py
+++ b/fastdeploy/entrypoints/openai/protocol.py
@@ -403,6 +403,9 @@ class CompletionRequest(BaseModel):
     echo: Optional[bool] = False
     frequency_penalty: Optional[float] = None
     logprobs: Optional[int] = None
+    # For logits and logprobs post processing
+    temp_scaled_logprobs: bool = False
+    top_p_normalized_logprobs: bool = False
     max_tokens: Optional[int] = None
     n: int = 1
     presence_penalty: Optional[float] = None
@@ -534,6 +537,11 @@ class ChatCompletionRequest(BaseModel):
     frequency_penalty: Optional[float] = None
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = 0
+
+    # For logits and logprobs post processing
+    temp_scaled_logprobs: bool = False
+    top_p_normalized_logprobs: bool = False
+
     # remove max_tokens when field is removed from OpenAI API
     max_tokens: Optional[int] = Field(
         default=None,
@@ -591,6 +599,8 @@ class ChatCompletionRequest(BaseModel):
         req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
         req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
+        req_dict["temp_scaled_logprobs"] = self.temp_scaled_logprobs
+        req_dict["top_p_normalized_logprobs"] = self.top_p_normalized_logprobs

         # parse request model into dict, priority: request params > metadata params
         if self.metadata is not None:
diff --git a/fastdeploy/model_executor/layers/sample/meta_data.py b/fastdeploy/model_executor/layers/sample/meta_data.py
index 2f79dc48b..03cdf24c2 100644
--- a/fastdeploy/model_executor/layers/sample/meta_data.py
+++ b/fastdeploy/model_executor/layers/sample/meta_data.py
@@ -15,7 +15,7 @@
 """

 from dataclasses import dataclass
-from typing import Optional
+from typing import Dict, Optional

 import paddle

@@ -51,3 +51,6 @@ class SamplingMetadata:
     stop_flags: Optional[paddle.Tensor] = None
     prompt_ids: Optional[paddle.Tensor] = None
     prompt_lens: Optional[paddle.Tensor] = None
+    temp_scaled_logprobs: Optional[paddle.Tensor] = None
+    top_p_normalized_logprobs: Optional[paddle.Tensor] = None
+    share_inputs: Optional[Dict[str, paddle.Tensor]] = None
diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py
index 5f7a7d157..5aecfa1f9 100644
--- a/fastdeploy/model_executor/layers/sample/sampler.py
+++ b/fastdeploy/model_executor/layers/sample/sampler.py
@@ -40,6 +40,18 @@ from fastdeploy.platforms import current_platform
 from fastdeploy.worker.output import LogprobsTensors, SamplerOutput


+def top_p_normalize_probs_paddle(
+    probs: paddle.Tensor,
+    top_ps: paddle.Tensor,
+):
+    probs_idx = probs.argsort(axis=-1, descending=True)
+    probs_sort = paddle.take_along_axis(probs, probs_idx, axis=-1)
+    probs_sum = paddle.cumsum(probs_sort, axis=-1)
+    probs_sort = paddle.where((probs_sum - probs_sort) > top_ps, paddle.zeros_like(probs_sort), probs_sort)
+    probs_sort.divide_(probs_sort.sum(axis=-1, keepdim=True))
+    return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)
+
+
 class SamplerProcessor:
     """
     SamplingProcessor for guided decoding.
@@ -207,9 +219,45 @@ class Sampler(nn.Layer):
         """pre process before running"""
         self.processor.pre_process(skip_idx_list)

-    def compute_logprobs(self, logits: paddle.Tensor) -> paddle.Tensor:
+    def compute_logprobs(
+        self,
+        logits: paddle.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> paddle.Tensor:
         """ """
-        return F.log_softmax(logits, axis=-1)
+        last_logits = logits
+        real_bsz = last_logits.shape[0]
+        temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
+        top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
+        share_inputs = sampling_metadata.share_inputs
+        if temp_scaled_logprobs is not None:
+            real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
+            temperature = sampling_metadata.temperature[:real_bsz]
+            temp_temperature = paddle.where(real_bsz_temp_scaled, temperature, paddle.ones_like(temperature))
+            last_logits = last_logits / temp_temperature
+
+        last_logprobs = F.log_softmax(last_logits, axis=-1)
+        top_p_logprob = None
+        top_p_req_mask = None
+
+        if top_p_normalized_logprobs is not None and share_inputs is not None:
+            seq_lens_this_time = share_inputs["seq_lens_this_time"].reshape([-1, 1])[:real_bsz]
+            seq_lens_encoder = share_inputs["seq_lens_encoder"].reshape([-1, 1])[:real_bsz]
+            seq_lens_decoder = share_inputs["seq_lens_decoder"].reshape([-1, 1])[:real_bsz]
+            seq_lens_time_sum = seq_lens_this_time + seq_lens_encoder + seq_lens_decoder
+            real_req_mask = seq_lens_time_sum > 0
+            top_p_req_mask = paddle.logical_and(top_p_normalized_logprobs[:real_bsz], real_req_mask)
+            real_req_top_p = sampling_metadata.top_p[:real_bsz]
+            # Normalize logprobs if top_p normalization is enabled
+            # NOTE: only normalize logprobs when top_p is set and not equal to 1.0
+            top_p_req_mask = paddle.logical_and(top_p_req_mask, real_req_top_p != 1.0)
+            if top_p_req_mask.any():
+                probs = F.softmax(last_logits, axis=-1)
+                probs = top_p_normalize_probs_paddle(probs, real_req_top_p)
+                top_p_logprob = paddle.log(probs)
+        if top_p_logprob is not None:
+            last_logprobs = paddle.where(top_p_req_mask, top_p_logprob, last_logprobs)
+        return last_logprobs

     def gather_logprobs(
         self,
@@ -234,6 +282,7 @@ class Sampler(nn.Layer):
             Sampled token rank tensor, (num tokens)
         """
         assert token_ids.dtype == paddle.int64
+        logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
         # Get with the logprob of the prompt or sampled token.
         token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
@@ -260,7 +309,7 @@ class Sampler(nn.Layer):
         """ """
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
-            raw_logprobs = self.compute_logprobs(logits)
+            raw_logprobs = self.compute_logprobs(logits, sampling_metadata)

         logits = self.processor.apply_token_mask(logits, skip_idx_list)
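For intuition, the following standalone sketch mirrors what `compute_logprobs` and `top_p_normalize_probs_paddle` do, using NumPy instead of paddle. It is illustrative only and not part of the patch: logits are divided by the temperature before `log_softmax` when `temp_scaled_logprobs` is set, and probabilities are renormalized over the top-p nucleus when `top_p_normalized_logprobs` is set.

```python
# Illustrative NumPy re-implementation of the logprob post-processing above.
import numpy as np

def top_p_normalize_probs_np(probs: np.ndarray, top_p: float) -> np.ndarray:
    """Zero out tokens outside the top-p nucleus and renormalize the rest."""
    order = np.argsort(-probs)                  # descending order of probability
    sorted_probs = probs[order]
    cumsum = np.cumsum(sorted_probs)
    # A token is kept while the mass accumulated *before* it is <= top_p,
    # mirroring the (probs_sum - probs_sort) > top_ps test in the patch.
    keep = (cumsum - sorted_probs) <= top_p
    sorted_probs = np.where(keep, sorted_probs, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum()
    out = np.zeros_like(probs)
    out[order] = sorted_probs                   # scatter back to vocab order
    return out

logits = np.array([2.0, 1.0, 0.5, -1.0])
temperature = 0.8

# temp_scaled_logprobs: divide logits by the temperature before softmax
scaled = logits / temperature
probs = np.exp(scaled - scaled.max())
probs /= probs.sum()

# top_p_normalized_logprobs: logprobs are taken over the renormalized nucleus
with np.errstate(divide="ignore"):
    logprobs = np.log(top_p_normalize_probs_np(probs, top_p=0.9))

# Clip -inf to the dtype minimum, as gather_logprobs does with logprobs.clip_
logprobs = np.maximum(logprobs, np.finfo(probs.dtype).min)
print(logprobs)
```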
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index af567cba1..7e7165c74 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -323,6 +323,10 @@ class GPUModelRunner(ModelRunnerBase):
             self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
             self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
             self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0)
+            self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = request.get("temp_scaled_logprobs", False)
+            self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = request.get(
+                "top_p_normalized_logprobs", False
+            )

             self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
             self.share_inputs["max_dec_len"][idx : idx + 1] = request.get(
@@ -496,6 +500,12 @@ class GPUModelRunner(ModelRunnerBase):
                 self.share_inputs["presence_score"][idx : idx + 1] = get_attr_from_request(
                     request, "presence_penalty", 0.0
                 )
+                self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = get_attr_from_request(
+                    request, "temp_scaled_logprobs", False
+                )
+                self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = get_attr_from_request(
+                    request, "top_p_normalized_logprobs", False
+                )

                 self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
                 self.share_inputs["max_dec_len"][idx : idx + 1] = request.get(
@@ -634,6 +644,8 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["presence_score"] = paddle.full(
             [max_num_seqs, 1], self.model_config.presence_score, dtype="float32"
         )
+        self.share_inputs["temp_scaled_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype="bool")
+        self.share_inputs["top_p_normalized_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype="bool")
         self.share_inputs["min_dec_len"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64")
         self.share_inputs["max_dec_len"] = paddle.full(
@@ -853,6 +865,9 @@ class GPUModelRunner(ModelRunnerBase):
             max_num_logprobs=20 if self.enable_logprob else None,
             enable_early_stop=self.enable_early_stop,
             stop_flags=self.share_inputs["stop_flags"],
+            temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"],
+            top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"],
+            share_inputs=self.share_inputs,
        )

     def load_model(self) -> None:
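The runner-side wiring follows a simple pattern: one preallocated `[max_num_seqs, 1]` boolean buffer per flag in `share_inputs`, a per-slot write when a request is inserted, and a `[:real_bsz]` slice in the sampler. A rough NumPy stand-in of that pattern is sketched below; the buffer names mirror the patch, but the buffers themselves are illustrative, not the real paddle tensors:

```python
# Rough sketch of the share_inputs flag buffers and per-slot writes.
import numpy as np

max_num_seqs = 4
share_inputs = {
    "temp_scaled_logprobs": np.full((max_num_seqs, 1), False),
    "top_p_normalized_logprobs": np.full((max_num_seqs, 1), False),
}

def insert_request(idx: int, request: dict) -> None:
    # Same pattern as share_inputs["..."][idx : idx + 1] = request.get(..., False)
    share_inputs["temp_scaled_logprobs"][idx : idx + 1] = request.get("temp_scaled_logprobs", False)
    share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = request.get("top_p_normalized_logprobs", False)

insert_request(0, {"temp_scaled_logprobs": True})
real_bsz = 1
# The sampler later reads only the rows of the running batch, e.g. flags[:real_bsz]
print(share_inputs["temp_scaled_logprobs"][:real_bsz])
```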
diff --git a/tests/ce/server/test_logprobs.py b/tests/ce/server/test_logprobs.py
index 4f3214b55..5f4cb0c45 100644
--- a/tests/ce/server/test_logprobs.py
+++ b/tests/ce/server/test_logprobs.py
@@ -154,8 +154,101 @@ def test_stream_without_logprobs():
         assert result_chunk["choices"][0]["logprobs"] is None


+def test_stream_with_temp_scaled_logprobs():
+    """
+    Verify the logprob of the first token in a streaming response when temp_scaled_logprobs is enabled.
+    """
+    data = {
+        "stream": True,
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
+        ],
+        "max_tokens": 3,
+        "temperature": 0.8,
+        "top_p": 0,
+        "temp_scaled_logprobs": True,
+    }
+
+    payload = build_request_payload(TEMPLATE, data)
+    response = send_request(URL, payload)
+
+    # Parse the first streamed chunk that contains content
+    result_chunk = {}
+    for line in response.iter_lines():
+        if not line:
+            continue
+        decoded = line.decode("utf-8").removeprefix("data: ")
+        if decoded == "[DONE]":
+            break
+
+        chunk = json.loads(decoded)
+        content = chunk["choices"][0]["delta"].get("content")
+        if content:
+            result_chunk = chunk
+            print(json.dumps(result_chunk, indent=2, ensure_ascii=False))
+            break
+
+    # Check the logprob fields
+    assert result_chunk["choices"][0]["delta"]["content"] == "牛顿"
+    assert result_chunk["choices"][0]["logprobs"]["content"][0]["token"] == "牛顿"
+    assert result_chunk["choices"][0]["logprobs"]["content"][0]["logprob"] == -0.006811376195400953
+    assert result_chunk["choices"][0]["logprobs"]["content"][0]["top_logprobs"][0] == {
+        "token": "牛顿",
+        "logprob": -0.006811376195400953,
+        "bytes": [231, 137, 155, 233, 161, 191],
+    }
+
+
+def test_stream_with_top_p_normalized_logprobs():
+    """
+    Verify the logprob of the first token in a streaming response when top_p_normalized_logprobs is enabled.
+    """
+    data = {
+        "stream": True,
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
+        ],
+        "max_tokens": 3,
+        "top_p": 0,
+        "top_p_normalized_logprobs": True,
+    }
+
+    payload = build_request_payload(TEMPLATE, data)
+    response = send_request(URL, payload)
+
+    # Parse the first streamed chunk that contains content
+    result_chunk = {}
+    for line in response.iter_lines():
+        if not line:
+            continue
+        decoded = line.decode("utf-8").removeprefix("data: ")
+        if decoded == "[DONE]":
+            break
+
+        chunk = json.loads(decoded)
+        content = chunk["choices"][0]["delta"].get("content")
+        if content:
+            result_chunk = chunk
+            print(json.dumps(result_chunk, indent=2, ensure_ascii=False))
+            break
+
+    # Check the logprob fields
+    assert result_chunk["choices"][0]["delta"]["content"] == "牛顿"
+    assert result_chunk["choices"][0]["logprobs"]["content"][0]["token"] == "牛顿"
+    assert result_chunk["choices"][0]["logprobs"]["content"][0]["logprob"] == 0.0
+    assert result_chunk["choices"][0]["logprobs"]["content"][0]["top_logprobs"][0] == {
+        "token": "牛顿",
+        "logprob": 0.0,
+        "bytes": [231, 137, 155, 233, 161, 191],
+    }
+
+
 if __name__ == "__main__":
     test_unstream_with_logprobs()
     test_unstream_without_logprobs()
     test_stream_with_logprobs()
     test_stream_without_logprobs()
+    test_stream_with_temp_scaled_logprobs()
+    test_stream_with_top_p_normalized_logprobs()
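The expected value of `0.0` in `test_stream_with_top_p_normalized_logprobs` follows from the mask in `top_p_normalize_probs_paddle`: with `top_p = 0`, only the single most likely token has zero probability mass accumulated before it, so it alone survives the nucleus and renormalization drives its probability to 1.0. A self-contained check of that arithmetic follows; the example distribution is made up:

```python
# Sanity check: with top_p = 0 the surviving top token gets logprob exactly 0.0.
import math

probs = [0.97, 0.02, 0.01]            # illustrative next-token distribution
top_p = 0.0

# Keep a token while the probability mass accumulated before it is <= top_p
# (so the first, most likely token is always kept), then renormalize survivors.
kept, cumulative = [], 0.0
for p in sorted(probs, reverse=True):
    kept.append(p if cumulative <= top_p else 0.0)
    cumulative += p

total = sum(kept)
normalized = [p / total for p in kept]
print(math.log(normalized[0]))        # 0.0 for the surviving top token
```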