From c424e08dc51674d9fde1f136ddfa0228a324e1cb Mon Sep 17 00:00:00 2001
From: SunLei
Date: Thu, 27 Nov 2025 11:22:41 +0800
Subject: [PATCH] [Speculative Decoding] split draft_tokens into standalone post-processing path (#5205)

* refactor(mtp): split draft_tokens into standalone post-processing path for MTP + logprobs

* Restore Request.__repr__ implementation

* ci

* add envs

* fix unittest
---
 fastdeploy/engine/request.py                       |  17 +-
 fastdeploy/entrypoints/openai/serving_chat.py      |   6 +
 fastdeploy/output/token_processor.py               |  98 ++++++++--
 .../openai/test_max_streaming_tokens.py            |   2 +
 .../output/test_process_batch_draft_tokens.py      | 157 ++++++++++++++++++
 .../test_process_batch_output_use_zmq.py           |  16 ++
 6 files changed, 269 insertions(+), 27 deletions(-)
 create mode 100644 tests/output/test_process_batch_draft_tokens.py

diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py
index ecafaec7f..e790455d3 100644
--- a/fastdeploy/engine/request.py
+++ b/fastdeploy/engine/request.py
@@ -25,6 +25,7 @@ from typing import Any, Dict, Generic, Optional, Union
 import numpy as np
 from typing_extensions import TypeVar
 
+from fastdeploy import envs
 from fastdeploy.engine.pooling_params import PoolingParams
 from fastdeploy.engine.sampling_params import SamplingParams
 from fastdeploy.entrypoints.openai.protocol import ToolCall
@@ -331,8 +332,20 @@ class Request:
             setattr(self, key, value)
 
     def __repr__(self) -> str:
-        """Safe string representation that ignores private and None fields."""
-        return ""
+        """Sanitized repr without private or None fields."""
+        try:
+            if not envs.FD_DEBUG:
+                return f"Request(request_id={self.request_id})"
+            else:
+                attrs_snapshot = dict(vars(self))
+                non_none_fields = [
+                    f"{attr}={value!r}"
+                    for attr, value in attrs_snapshot.items()
+                    if value is not None and not attr.startswith("_")
+                ]
+                return f"Request({', '.join(non_none_fields)})"
+        except Exception as e:
+            return f"<Request repr failed: {e}>"
 
 
 @dataclass(slots=True)
diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py
index 208461312..bedfdc5eb 100644
--- a/fastdeploy/entrypoints/openai/serving_chat.py
+++ b/fastdeploy/entrypoints/openai/serving_chat.py
@@ -571,6 +571,7 @@ class OpenAIServingChat:
                 num_input_video_tokens=num_input_video_tokens,
                 num_image_tokens=num_image_tokens,
                 logprob_contents=logprob_contents,
+                draft_logprob_contents=draft_logprob_contents,
                 response_processor=response_processor,
                 max_tokens=max_tokens,
             )
@@ -622,6 +623,7 @@ class OpenAIServingChat:
         num_input_video_tokens: list,
         num_image_tokens: list,
         logprob_contents: list,
+        draft_logprob_contents: list,
         response_processor: ChatResponseProcessor,
         max_tokens: int,
     ) -> ChatCompletionResponseChoice:
@@ -649,6 +651,9 @@ class OpenAIServingChat:
         logprobs_full_res = None
         if logprob_contents[idx]:
             logprobs_full_res = LogProbs(content=logprob_contents[idx])
+        draft_logprobs_full_res = None
+        if draft_logprob_contents[idx]:
+            draft_logprobs_full_res = LogProbs(content=draft_logprob_contents[idx])
 
         num_cached_tokens[idx] = data.get("num_cached_tokens", 0)
         num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0)
@@ -669,6 +674,7 @@ class OpenAIServingChat:
             index=idx,
             message=message,
             logprobs=logprobs_full_res,
+            draft_logprobs=draft_logprobs_full_res,
             finish_reason=finish_reason,
         )
 
diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py
index 58a64401e..4aa3ab307 100644
--- a/fastdeploy/output/token_processor.py
+++ b/fastdeploy/output/token_processor.py
@@ -527,6 +527,60 @@ class 
TokenProcessor: self.total_step = 0 self.speculative_stats_step += 1 + def _process_batch_draft_tokens(self, mtype, batch, accept_num, tokens, scores, ranks): + """ + Process batch draft tokens and generate corresponding request outputs + + Args: + mtype (int): Message type (3=target token, 4=draft token) + batch (int): Batch size + accept_num (list): List of accepted token counts per request + tokens (paddle.Tensor): Generated draft token IDs tensor + scores (paddle.Tensor): Token scores tensor + ranks (paddle.Tensor): Token sampling ranks tensor + + Returns: + list[RequestOutput]: List containing processed results for all requests + """ + batch_result = list() + for i in range(batch): + if self.resource_manager.stop_flags[i]: + continue + task = self.resource_manager.tasks_list[i] + task_id = task.request_id + result = RequestOutput( + request_id=task_id, + output_type=mtype, + outputs=CompletionOutput( + index=i, + send_idx=None, + token_ids=[], + draft_token_ids=[], + ), + finished=False, + metrics=None, + ) + + token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] + for batch_token_index in range(len(token_ids)): + result.outputs.logprob = float(scores[i, batch_token_index, 0]) + topk_token_ids = tokens[i, batch_token_index, :].tolist() + topk_logprobs = scores[i, batch_token_index, :].tolist() + sampled_rank = ranks[i, batch_token_index].item() + + if result.outputs.draft_top_logprobs is None: + result.outputs.draft_top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank]) + batch_result.append(result) + return batch_result + def _process_batch_output(self): """ batch post-processing function @@ -551,6 +605,12 @@ class TokenProcessor: .reshape([batch, MAX_DRAFT_TOKENS, K + 1]) ) ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS]) + + # split draft_tokens into standalone post-processing path for MTP + logprobs + if mtype == 4: + batch_result = self._process_batch_draft_tokens(mtype, batch, accept_num, tokens, scores, ranks) + self.postprocess(batch_result, mtype) + return else: batch = self.output_tokens[1] accept_num = tokens[2 : batch + 2] @@ -678,8 +738,7 @@ class TokenProcessor: if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids): result.outputs.token_ids.append(token_id) - if mtype == 3: - task.output_token_ids.append(token_id) + task.output_token_ids.append(token_id) if self.use_logprobs: if self.cfg.speculative_config.method: @@ -693,29 +752,18 @@ class TokenProcessor: topk_logprobs = scores[i, :].tolist() sampled_rank = ranks[i].item() - if mtype == 3: # top_logprobs - if result.outputs.top_logprobs is None: - result.outputs.top_logprobs = LogprobsLists( - logprob_token_ids=[topk_token_ids], - logprobs=[topk_logprobs], - sampled_token_ranks=[sampled_rank], - ) - else: - result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids]) - result.outputs.top_logprobs.logprobs.extend([topk_logprobs]) - result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank]) - elif mtype == 4: # draft_top_logprobs - if result.outputs.draft_top_logprobs is None: - result.outputs.draft_top_logprobs = LogprobsLists( - logprob_token_ids=[topk_token_ids], - logprobs=[topk_logprobs], - 
sampled_token_ranks=[sampled_rank], - ) - else: - result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids]) - result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs]) - result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank]) - if mtype == 3 and (token_id in task.eos_token_ids or is_prefill or recovery_stop): + if result.outputs.top_logprobs is None: + result.outputs.top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank]) + + if token_id in task.eos_token_ids or is_prefill or recovery_stop: result.finished = True if recovery_stop: result.error_msg = "Recover is not supported, the result is incomplete!" diff --git a/tests/entrypoints/openai/test_max_streaming_tokens.py b/tests/entrypoints/openai/test_max_streaming_tokens.py index 1e728aa3b..7ce6df13e 100644 --- a/tests/entrypoints/openai/test_max_streaming_tokens.py +++ b/tests/entrypoints/openai/test_max_streaming_tokens.py @@ -445,6 +445,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase): prompt_token_ids = [1, 2] prompt_tokens = "test_prompt" logprob_contents = [[{"token": "hello", "logprob": 0.1}], [{"token": "hello", "logprob": 0.1}]] + draft_logprob_contents = [[{"token": "hello", "logprob": 0.1}], [{"token": "hello", "logprob": 0.1}]] mock_response_processor = Mock() mock_response_processor.enable_multimodal_content.return_value = False completion_token_ids = [[], []] @@ -467,6 +468,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase): num_input_video_tokens=num_input_video_tokens, num_image_tokens=num_image_tokens, logprob_contents=logprob_contents, + draft_logprob_contents=draft_logprob_contents, response_processor=mock_response_processor, max_tokens=max_tokens_list[idx], ) diff --git a/tests/output/test_process_batch_draft_tokens.py b/tests/output/test_process_batch_draft_tokens.py new file mode 100644 index 000000000..3686dd1b6 --- /dev/null +++ b/tests/output/test_process_batch_draft_tokens.py @@ -0,0 +1,157 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import unittest +from unittest.mock import MagicMock + +import numpy as np +import paddle + +from fastdeploy.engine.request import RequestOutput +from fastdeploy.output.token_processor import TokenProcessor + + +class TestProcessBatchDraftTokens(unittest.TestCase): + + def setUp(self): + # 模拟 cfg + cfg = MagicMock() + cfg.speculative_config = MagicMock() + cfg.speculative_config.method = "mtp" + cfg.speculative_config.num_speculative_tokens = 3 + cfg.model_config = MagicMock() + cfg.model_config.enable_logprob = True + + self.processor = TokenProcessor( + cfg=cfg, cached_generated_tokens=MagicMock(), engine_worker_queue=MagicMock(), split_connector=MagicMock() + ) + + # mock resource_manager + self.processor.resource_manager = MagicMock() + self.processor.resource_manager.stop_flags = [False] * 512 + self.processor.resource_manager.tasks_list = [MagicMock()] * 512 + + for task in self.processor.resource_manager.tasks_list: + task.request_id = "test_request" + task.eos_token_ids = [2] + + def test_process_batch_draft_tokens_normal_case(self): + """测试正常情况下的target处理""" + batch = 2 + accept_num = [3, 2] + K = 20 + MAX_DRAFT_TOKENS = 6 + + tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1)) + scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32) + ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS)) + + results = self.processor._process_batch_draft_tokens( + mtype=4, + batch=batch, + accept_num=accept_num, + tokens=paddle.to_tensor(tokens), + scores=paddle.to_tensor(scores), + ranks=paddle.to_tensor(ranks), + ) + + self.assertEqual(len(results), batch) + for i, result in enumerate(results): + self.assertIsInstance(result, RequestOutput) + self.assertEqual(result.output_type, 4) + self.assertEqual(result.outputs.index, i) + self.assertEqual(len(result.outputs.draft_top_logprobs.logprob_token_ids), accept_num[i]) + self.assertEqual(len(result.outputs.draft_top_logprobs.logprobs), accept_num[i]) + self.assertEqual(len(result.outputs.draft_top_logprobs.sampled_token_ranks), accept_num[i]) + + def test_process_batch_draft_tokens_with_stop_flag(self): + """测试有停止标志的情况""" + batch = 3 + self.processor.resource_manager.stop_flags[1] = True # 第二个 request 停止 + + accept_num = [3, 2, 1] + K = 20 + MAX_DRAFT_TOKENS = 6 + + tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1)) + scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32) + ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS)) + + results = self.processor._process_batch_draft_tokens( + mtype=4, + batch=batch, + accept_num=accept_num, + tokens=paddle.to_tensor(tokens), + scores=paddle.to_tensor(scores), + ranks=paddle.to_tensor(ranks), + ) + + self.assertEqual(len(results), 2) + self.assertEqual(results[0].outputs.index, 0) + self.assertEqual(results[1].outputs.index, 2) + + def test_process_batch_draft_tokens_empty_accept(self): + """测试 accept_num 为 0 的情况""" + batch = 2 + accept_num = [0, 0] + + K = 20 + MAX_DRAFT_TOKENS = 6 + tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1)) + scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32) + ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS)) + + results = self.processor._process_batch_draft_tokens( + mtype=4, + batch=batch, + accept_num=accept_num, + tokens=paddle.to_tensor(tokens), + scores=paddle.to_tensor(scores), + ranks=paddle.to_tensor(ranks), + ) + + self.assertEqual(len(results), batch) + for result in results: + 
+            self.assertIsNone(result.outputs.draft_top_logprobs)
+
+    def test_process_batch_draft_tokens_different_k_values(self):
+        """Test behavior with a different K value"""
+        batch = 2
+        accept_num = [3, 2]
+
+        K = 5
+        MAX_DRAFT_TOKENS = 6
+        tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1))
+        scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32)
+        ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS))
+
+        results = self.processor._process_batch_draft_tokens(
+            mtype=4,
+            batch=batch,
+            accept_num=accept_num,
+            tokens=paddle.to_tensor(tokens),
+            scores=paddle.to_tensor(scores),
+            ranks=paddle.to_tensor(ranks),
+        )
+
+        self.assertEqual(len(results), batch)
+        for i, result in enumerate(results):
+            self.assertEqual(len(result.outputs.draft_top_logprobs.logprob_token_ids[0]), K + 1)
+            self.assertEqual(len(result.outputs.draft_top_logprobs.logprobs[0]), K + 1)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/output/test_process_batch_output_use_zmq.py b/tests/output/test_process_batch_output_use_zmq.py
index 8be8c5b77..e58f613d1 100644
--- a/tests/output/test_process_batch_output_use_zmq.py
+++ b/tests/output/test_process_batch_output_use_zmq.py
@@ -1,3 +1,19 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
 import unittest
 from unittest.mock import MagicMock, patch