[Cherry-Pick] MTP split draft_tokens into standalone post-processing path (#5205) (#5232)

* merge code

* fix merge conflict in Request

* remove unused unittest

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
SunLei authored on 2025-11-27 15:30:00 +08:00, committed by GitHub
parent bbcd92c8a0, commit f637ba708c
3 changed files with 249 additions and 27 deletions

View File

@@ -23,6 +23,7 @@ from typing import Any, Dict, Optional, Union
import numpy as np
from fastdeploy import envs
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.openai.protocol import ToolCall
from fastdeploy.utils import data_processor_logger
@@ -273,7 +274,20 @@ class Request:
setattr(self, key, value)
def __repr__(self) -> str:
return ""
"""Safe string representation that ignores private and None fields."""
try:
if not envs.FD_DEBUG:
return f"Request(request_id={self.request_id})"
else:
attrs_snapshot = dict(vars(self))
non_none_fields = [
f"{attr}={value!r}"
for attr, value in attrs_snapshot.items()
if value is not None and not attr.startswith("_")
]
return f"Request({', '.join(non_none_fields)})"
except Exception as e:
return f"<Request repr failed: {e}>"
@dataclass(slots=True)
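
For reference, the debug-gated __repr__ pattern introduced above can be exercised standalone. The sketch below is illustrative only: MiniRequest and the FD_DEBUG environment read are stand-ins for FastDeploy's Request class and fastdeploy.envs.FD_DEBUG, not the actual implementation.

import os
from dataclasses import dataclass, field
from typing import Optional

FD_DEBUG = os.getenv("FD_DEBUG", "0") == "1"  # stand-in for fastdeploy.envs.FD_DEBUG

@dataclass
class MiniRequest:
    request_id: str
    prompt: Optional[str] = None
    _internal_state: dict = field(default_factory=dict)  # private field, never shown

    def __repr__(self) -> str:
        try:
            if not FD_DEBUG:
                # Terse form by default: only the request id.
                return f"MiniRequest(request_id={self.request_id})"
            # In debug mode, dump every public, non-None attribute.
            shown = [
                f"{k}={v!r}"
                for k, v in vars(self).items()
                if v is not None and not k.startswith("_")
            ]
            return f"MiniRequest({', '.join(shown)})"
        except Exception as e:  # repr should never raise
            return f"<MiniRequest repr failed: {e}>"

print(MiniRequest("req-1", prompt="hello"))  # terse unless FD_DEBUG=1 is set
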

View File

@@ -338,6 +338,60 @@ class TokenProcessor:
self.total_step = 0
self.speculative_stats_step += 1
def _process_batch_draft_tokens(self, mtype, batch, accept_num, tokens, scores, ranks):
"""
Process batch draft tokens and generate corresponding request outputs
Args:
mtype (int): Message type (3=target token, 4=draft token)
batch (int): Batch size
accept_num (list): List of accepted token counts per request
tokens (paddle.Tensor): Generated draft token IDs tensor
scores (paddle.Tensor): Token scores tensor
ranks (paddle.Tensor): Token sampling ranks tensor
Returns:
list[RequestOutput]: List containing processed results for all requests
"""
batch_result = list()
for i in range(batch):
if self.resource_manager.stop_flags[i]:
continue
task = self.resource_manager.tasks_list[i]
task_id = task.request_id
result = RequestOutput(
request_id=task_id,
output_type=mtype,
outputs=CompletionOutput(
index=i,
send_idx=None,
token_ids=[],
draft_token_ids=[],
),
finished=False,
metrics=None,
)
token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
for batch_token_index in range(len(token_ids)):
result.outputs.logprob = float(scores[i, batch_token_index, 0])
topk_token_ids = tokens[i, batch_token_index, :].tolist()
topk_logprobs = scores[i, batch_token_index, :].tolist()
sampled_rank = ranks[i, batch_token_index].item()
if result.outputs.draft_top_logprobs is None:
result.outputs.draft_top_logprobs = LogprobsLists(
logprob_token_ids=[topk_token_ids],
logprobs=[topk_logprobs],
sampled_token_ranks=[sampled_rank],
)
else:
result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids])
result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs])
result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank])
batch_result.append(result)
return batch_result
def _process_batch_output(self):
"""
batch post-processing function
@@ -362,6 +416,12 @@ class TokenProcessor:
.reshape([batch, MAX_DRAFT_TOKENS, K + 1])
)
ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS])
# split draft_tokens into standalone post-processing path for MTP + logprobs
if mtype == 4:
batch_result = self._process_batch_draft_tokens(mtype, batch, accept_num, tokens, scores, ranks)
self.postprocess(batch_result, mtype)
return
else:
batch = self.output_tokens[1]
accept_num = tokens[2 : batch + 2]
@@ -479,9 +539,11 @@ class TokenProcessor:
token_id = token_ids[batch_token_index]
self.tokens_counter[task_id] += 1
if token_id != RECOVERY_STOP_SIGNAL:
result.outputs.token_ids.append(token_id)
if mtype == 3: # target_tokens
task.output_token_ids.append(token_id)
if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids):
result.outputs.token_ids.append(token_id)
task.output_token_ids.append(token_id)
if self.use_logprobs:
if self.cfg.speculative_config.method:
result.outputs.logprob = float(scores[i, batch_token_index, 0])
@@ -494,29 +556,18 @@ class TokenProcessor:
topk_logprobs = scores[i, :].tolist()
sampled_rank = ranks[i].item()
if mtype == 3: # top_logprobs
if result.outputs.top_logprobs is None:
result.outputs.top_logprobs = LogprobsLists(
logprob_token_ids=[topk_token_ids],
logprobs=[topk_logprobs],
sampled_token_ranks=[sampled_rank],
)
else:
result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids])
result.outputs.top_logprobs.logprobs.extend([topk_logprobs])
result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank])
elif mtype == 4: # draft_top_logprobs
if result.outputs.draft_top_logprobs is None:
result.outputs.draft_top_logprobs = LogprobsLists(
logprob_token_ids=[topk_token_ids],
logprobs=[topk_logprobs],
sampled_token_ranks=[sampled_rank],
)
else:
result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids])
result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs])
result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank])
if mtype == 3 and (token_id in task.eos_token_ids or is_prefill or recovery_stop):
if result.outputs.top_logprobs is None:
result.outputs.top_logprobs = LogprobsLists(
logprob_token_ids=[topk_token_ids],
logprobs=[topk_logprobs],
sampled_token_ranks=[sampled_rank],
)
else:
result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids])
result.outputs.top_logprobs.logprobs.extend([topk_logprobs])
result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank])
if token_id in task.eos_token_ids or is_prefill or recovery_stop:
result.finished = True
if recovery_stop:
result.error_msg = "Recover is not supported, the result is incomplete!"
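
The split path builds one RequestOutput per active request and, for each of the first accept_num[i] draft positions, appends the top-(K+1) token ids, their scores, and the sampled rank. A minimal NumPy-only sketch of that accumulation follows; the function name, dict layout, and shapes are assumptions mirroring the diff, not FastDeploy's LogprobsLists API.

import numpy as np

def collect_draft_logprobs(tokens, scores, ranks, accept_num):
    """tokens/scores: [batch, MAX_DRAFT_TOKENS, K+1], ranks: [batch, MAX_DRAFT_TOKENS]."""
    per_request = []
    for i, n_accept in enumerate(accept_num):
        entry = {"logprob_token_ids": [], "logprobs": [], "sampled_token_ranks": []}
        # Only the first accept_num[i] draft positions carry valid data.
        for pos in range(n_accept):
            entry["logprob_token_ids"].append(tokens[i, pos, :].tolist())  # top-(K+1) ids
            entry["logprobs"].append(scores[i, pos, :].tolist())           # top-(K+1) scores
            entry["sampled_token_ranks"].append(int(ranks[i, pos]))
        per_request.append(entry)
    return per_request

batch, max_draft, k = 2, 6, 20
out = collect_draft_logprobs(
    np.random.randint(0, 1000, size=(batch, max_draft, k + 1)),
    np.random.rand(batch, max_draft, k + 1),
    np.random.randint(0, k, size=(batch, max_draft)),
    accept_num=[3, 1],
)
assert len(out[0]["logprobs"]) == 3 and len(out[1]["logprobs"]) == 1
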

View File

@@ -0,0 +1,157 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from unittest.mock import MagicMock
import numpy as np
import paddle
from fastdeploy.engine.request import RequestOutput
from fastdeploy.output.token_processor import TokenProcessor
class TestProcessBatchDraftTokens(unittest.TestCase):
def setUp(self):
# Mock cfg
cfg = MagicMock()
cfg.speculative_config = MagicMock()
cfg.speculative_config.method = "mtp"
cfg.speculative_config.num_speculative_tokens = 3
cfg.model_config = MagicMock()
cfg.model_config.enable_logprob = True
self.processor = TokenProcessor(
cfg=cfg, cached_generated_tokens=MagicMock(), engine_worker_queue=MagicMock(), split_connector=MagicMock()
)
# mock resource_manager
self.processor.resource_manager = MagicMock()
self.processor.resource_manager.stop_flags = [False] * 512
self.processor.resource_manager.tasks_list = [MagicMock()] * 512
for task in self.processor.resource_manager.tasks_list:
task.request_id = "test_request"
task.eos_token_ids = [2]
def test_process_batch_draft_tokens_normal_case(self):
"""测试正常情况下的target处理"""
batch = 2
accept_num = [3, 2]
K = 20
MAX_DRAFT_TOKENS = 6
tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1))
scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32)
ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS))
results = self.processor._process_batch_draft_tokens(
mtype=4,
batch=batch,
accept_num=accept_num,
tokens=paddle.to_tensor(tokens),
scores=paddle.to_tensor(scores),
ranks=paddle.to_tensor(ranks),
)
self.assertEqual(len(results), batch)
for i, result in enumerate(results):
self.assertIsInstance(result, RequestOutput)
self.assertEqual(result.output_type, 4)
self.assertEqual(result.outputs.index, i)
self.assertEqual(len(result.outputs.draft_top_logprobs.logprob_token_ids), accept_num[i])
self.assertEqual(len(result.outputs.draft_top_logprobs.logprobs), accept_num[i])
self.assertEqual(len(result.outputs.draft_top_logprobs.sampled_token_ranks), accept_num[i])
def test_process_batch_draft_tokens_with_stop_flag(self):
"""测试有停止标志的情况"""
batch = 3
self.processor.resource_manager.stop_flags[1] = True  # the second request is stopped
accept_num = [3, 2, 1]
K = 20
MAX_DRAFT_TOKENS = 6
tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1))
scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32)
ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS))
results = self.processor._process_batch_draft_tokens(
mtype=4,
batch=batch,
accept_num=accept_num,
tokens=paddle.to_tensor(tokens),
scores=paddle.to_tensor(scores),
ranks=paddle.to_tensor(ranks),
)
self.assertEqual(len(results), 2)
self.assertEqual(results[0].outputs.index, 0)
self.assertEqual(results[1].outputs.index, 2)
def test_process_batch_draft_tokens_empty_accept(self):
"""测试 accept_num 为 0 的情况"""
batch = 2
accept_num = [0, 0]
K = 20
MAX_DRAFT_TOKENS = 6
tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1))
scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32)
ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS))
results = self.processor._process_batch_draft_tokens(
mtype=4,
batch=batch,
accept_num=accept_num,
tokens=paddle.to_tensor(tokens),
scores=paddle.to_tensor(scores),
ranks=paddle.to_tensor(ranks),
)
self.assertEqual(len(results), batch)
for result in results:
self.assertIsNone(result.outputs.draft_top_logprobs)
def test_process_batch_draft_tokens_different_k_values(self):
"""测试不同 K 值情况"""
batch = 2
accept_num = [3, 2]
K = 5
MAX_DRAFT_TOKENS = 6
tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1))
scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32)
ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS))
results = self.processor._process_batch_draft_tokens(
mtype=4,
batch=batch,
accept_num=accept_num,
tokens=paddle.to_tensor(tokens),
scores=paddle.to_tensor(scores),
ranks=paddle.to_tensor(ranks),
)
self.assertEqual(len(results), batch)
for i, result in enumerate(results):
self.assertEqual(len(result.outputs.draft_top_logprobs.logprob_token_ids[0]), K + 1)
self.assertEqual(len(result.outputs.draft_top_logprobs.logprobs[0]), K + 1)
if __name__ == "__main__":
unittest.main()