diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 0270398a9..7580cc584 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -23,6 +23,7 @@ from typing import Any, Dict, Optional, Union import numpy as np +from fastdeploy import envs from fastdeploy.engine.sampling_params import SamplingParams from fastdeploy.entrypoints.openai.protocol import ToolCall from fastdeploy.utils import data_processor_logger @@ -273,7 +274,20 @@ class Request: setattr(self, key, value) def __repr__(self) -> str: - return "" + """Safe string representation that ignores private and None fields.""" + try: + if not envs.FD_DEBUG: + return f"Request(request_id={self.request_id})" + else: + attrs_snapshot = dict(vars(self)) + non_none_fields = [ + f"{attr}={value!r}" + for attr, value in attrs_snapshot.items() + if value is not None and not attr.startswith("_") + ] + return f"Request({', '.join(non_none_fields)})" + except Exception as e: + return f"" @dataclass(slots=True) diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 0861ccb8c..4710438a5 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -338,6 +338,60 @@ class TokenProcessor: self.total_step = 0 self.speculative_stats_step += 1 + def _process_batch_draft_tokens(self, mtype, batch, accept_num, tokens, scores, ranks): + """ + Process batch draft tokens and generate corresponding request outputs + + Args: + mtype (int): Message type (3=target token, 4=draft token) + batch (int): Batch size + accept_num (list): List of accepted token counts per request + tokens (paddle.Tensor): Generated draft token IDs tensor + scores (paddle.Tensor): Token scores tensor + ranks (paddle.Tensor): Token sampling ranks tensor + + Returns: + list[RequestOutput]: List containing processed results for all requests + """ + batch_result = list() + for i in range(batch): + if self.resource_manager.stop_flags[i]: + continue + task = self.resource_manager.tasks_list[i] + task_id = task.request_id + result = RequestOutput( + request_id=task_id, + output_type=mtype, + outputs=CompletionOutput( + index=i, + send_idx=None, + token_ids=[], + draft_token_ids=[], + ), + finished=False, + metrics=None, + ) + + token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] + for batch_token_index in range(len(token_ids)): + result.outputs.logprob = float(scores[i, batch_token_index, 0]) + topk_token_ids = tokens[i, batch_token_index, :].tolist() + topk_logprobs = scores[i, batch_token_index, :].tolist() + sampled_rank = ranks[i, batch_token_index].item() + + if result.outputs.draft_top_logprobs is None: + result.outputs.draft_top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank]) + batch_result.append(result) + return batch_result + def _process_batch_output(self): """ batch post-processing function @@ -362,6 +416,12 @@ class TokenProcessor: .reshape([batch, MAX_DRAFT_TOKENS, K + 1]) ) ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS]) + + # split draft_tokens into standalone post-processing path for MTP + logprobs + if mtype == 4: + batch_result = self._process_batch_draft_tokens(mtype, batch, accept_num, tokens, scores, ranks) + self.postprocess(batch_result, mtype) + return else: batch = self.output_tokens[1] accept_num = tokens[2 : batch + 2] @@ -479,9 +539,11 @@ class TokenProcessor: token_id = token_ids[batch_token_index] self.tokens_counter[task_id] += 1 if token_id != RECOVERY_STOP_SIGNAL: - result.outputs.token_ids.append(token_id) - if mtype == 3: # target_tokens - task.output_token_ids.append(token_id) + if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids): + result.outputs.token_ids.append(token_id) + + task.output_token_ids.append(token_id) + if self.use_logprobs: if self.cfg.speculative_config.method: result.outputs.logprob = float(scores[i, batch_token_index, 0]) @@ -494,29 +556,18 @@ class TokenProcessor: topk_logprobs = scores[i, :].tolist() sampled_rank = ranks[i].item() - if mtype == 3: # top_logprobs - if result.outputs.top_logprobs is None: - result.outputs.top_logprobs = LogprobsLists( - logprob_token_ids=[topk_token_ids], - logprobs=[topk_logprobs], - sampled_token_ranks=[sampled_rank], - ) - else: - result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids]) - result.outputs.top_logprobs.logprobs.extend([topk_logprobs]) - result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank]) - elif mtype == 4: # draft_top_logprobs - if result.outputs.draft_top_logprobs is None: - result.outputs.draft_top_logprobs = LogprobsLists( - logprob_token_ids=[topk_token_ids], - logprobs=[topk_logprobs], - sampled_token_ranks=[sampled_rank], - ) - else: - result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids]) - result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs]) - result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank]) - if mtype == 3 and (token_id in task.eos_token_ids or is_prefill or recovery_stop): + if result.outputs.top_logprobs is None: + result.outputs.top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank]) + + if token_id in task.eos_token_ids or is_prefill or recovery_stop: result.finished = True if recovery_stop: result.error_msg = "Recover is not supported, the result is incomplete!" diff --git a/tests/output/test_process_batch_draft_tokens.py b/tests/output/test_process_batch_draft_tokens.py new file mode 100644 index 000000000..3686dd1b6 --- /dev/null +++ b/tests/output/test_process_batch_draft_tokens.py @@ -0,0 +1,157 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import unittest +from unittest.mock import MagicMock + +import numpy as np +import paddle + +from fastdeploy.engine.request import RequestOutput +from fastdeploy.output.token_processor import TokenProcessor + + +class TestProcessBatchDraftTokens(unittest.TestCase): + + def setUp(self): + # 模拟 cfg + cfg = MagicMock() + cfg.speculative_config = MagicMock() + cfg.speculative_config.method = "mtp" + cfg.speculative_config.num_speculative_tokens = 3 + cfg.model_config = MagicMock() + cfg.model_config.enable_logprob = True + + self.processor = TokenProcessor( + cfg=cfg, cached_generated_tokens=MagicMock(), engine_worker_queue=MagicMock(), split_connector=MagicMock() + ) + + # mock resource_manager + self.processor.resource_manager = MagicMock() + self.processor.resource_manager.stop_flags = [False] * 512 + self.processor.resource_manager.tasks_list = [MagicMock()] * 512 + + for task in self.processor.resource_manager.tasks_list: + task.request_id = "test_request" + task.eos_token_ids = [2] + + def test_process_batch_draft_tokens_normal_case(self): + """测试正常情况下的target处理""" + batch = 2 + accept_num = [3, 2] + K = 20 + MAX_DRAFT_TOKENS = 6 + + tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1)) + scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32) + ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS)) + + results = self.processor._process_batch_draft_tokens( + mtype=4, + batch=batch, + accept_num=accept_num, + tokens=paddle.to_tensor(tokens), + scores=paddle.to_tensor(scores), + ranks=paddle.to_tensor(ranks), + ) + + self.assertEqual(len(results), batch) + for i, result in enumerate(results): + self.assertIsInstance(result, RequestOutput) + self.assertEqual(result.output_type, 4) + self.assertEqual(result.outputs.index, i) + self.assertEqual(len(result.outputs.draft_top_logprobs.logprob_token_ids), accept_num[i]) + self.assertEqual(len(result.outputs.draft_top_logprobs.logprobs), accept_num[i]) + self.assertEqual(len(result.outputs.draft_top_logprobs.sampled_token_ranks), accept_num[i]) + + def test_process_batch_draft_tokens_with_stop_flag(self): + """测试有停止标志的情况""" + batch = 3 + self.processor.resource_manager.stop_flags[1] = True # 第二个 request 停止 + + accept_num = [3, 2, 1] + K = 20 + MAX_DRAFT_TOKENS = 6 + + tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1)) + scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32) + ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS)) + + results = self.processor._process_batch_draft_tokens( + mtype=4, + batch=batch, + accept_num=accept_num, + tokens=paddle.to_tensor(tokens), + scores=paddle.to_tensor(scores), + ranks=paddle.to_tensor(ranks), + ) + + self.assertEqual(len(results), 2) + self.assertEqual(results[0].outputs.index, 0) + self.assertEqual(results[1].outputs.index, 2) + + def test_process_batch_draft_tokens_empty_accept(self): + """测试 accept_num 为 0 的情况""" + batch = 2 + accept_num = [0, 0] + + K = 20 + MAX_DRAFT_TOKENS = 6 + tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1)) + scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32) + ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS)) + + results = self.processor._process_batch_draft_tokens( + mtype=4, + batch=batch, + accept_num=accept_num, + tokens=paddle.to_tensor(tokens), + scores=paddle.to_tensor(scores), + ranks=paddle.to_tensor(ranks), + ) + + self.assertEqual(len(results), batch) + for result in results: + self.assertIsNone(result.outputs.draft_top_logprobs) + + def test_process_batch_draft_tokens_different_k_values(self): + """测试不同 K 值情况""" + batch = 2 + accept_num = [3, 2] + + K = 5 + MAX_DRAFT_TOKENS = 6 + tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1)) + scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32) + ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS)) + + results = self.processor._process_batch_draft_tokens( + mtype=4, + batch=batch, + accept_num=accept_num, + tokens=paddle.to_tensor(tokens), + scores=paddle.to_tensor(scores), + ranks=paddle.to_tensor(ranks), + ) + + self.assertEqual(len(results), batch) + for i, result in enumerate(results): + self.assertEqual(len(result.outputs.draft_top_logprobs.logprob_token_ids[0]), K + 1) + self.assertEqual(len(result.outputs.draft_top_logprobs.logprobs[0]), K + 1) + + +if __name__ == "__main__": + unittest.main()