From d5a3c5c933ad0cbbae70cff6bbf9c317c0681ebc Mon Sep 17 00:00:00 2001 From: sunlei1024 Date: Fri, 26 Sep 2025 13:07:48 +0800 Subject: [PATCH] feat: add draft_logprobs for Speculative Decode MTP --- fastdeploy/engine/request.py | 4 + fastdeploy/entrypoints/openai/protocol.py | 2 + fastdeploy/entrypoints/openai/serving_chat.py | 16 ++ .../entrypoints/openai/serving_completion.py | 16 ++ fastdeploy/output/token_processor.py | 112 +++++++++--- tests/output/test_process_batch_output.py | 167 ++++++++++++++++++ 6 files changed, 293 insertions(+), 24 deletions(-) create mode 100644 tests/output/test_process_batch_output.py diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 04a2276af..0cade6973 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -287,6 +287,7 @@ class CompletionOutput: token_ids: list[int] logprob: Optional[float] = None top_logprobs: Optional[LogprobsLists] = None + draft_top_logprobs: Optional[LogprobsLists] = None logprobs: Optional[SampleLogprobs] = None draft_token_ids: list[int] = None text: Optional[str] = None @@ -412,6 +413,7 @@ class RequestOutput: request_id: str, prompt: Optional[str] = None, prompt_token_ids: Optional[list[int]] = None, + output_type: Optional[int] = 3, outputs: CompletionOutput = None, finished: bool = False, metrics: Optional[RequestMetrics] = None, @@ -456,6 +458,7 @@ class RequestOutput: f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " + f"output_type={self.output_type}, " f"outputs={self.outputs}, " f"finished={self.finished}, " f"num_cached_tokens={self.num_cached_tokens}, " @@ -476,6 +479,7 @@ class RequestOutput: "request_id": self.request_id, "prompt": self.prompt, "prompt_token_ids": self.prompt_token_ids, + "output_type": self.output_type, "outputs": None if self.outputs is None else self.outputs.to_dict(), "metrics": None if self.metrics is None else self.metrics.to_dict(), "finished": self.finished, diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index b74e0ffb4..f0805d697 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -405,6 +405,7 @@ class CompletionRequest(BaseModel): echo: Optional[bool] = False frequency_penalty: Optional[float] = None logprobs: Optional[int] = None + include_draft_logprobs: Optional[bool] = False # For logits and logprobs post processing temp_scaled_logprobs: bool = False top_p_normalized_logprobs: bool = False @@ -540,6 +541,7 @@ class ChatCompletionRequest(BaseModel): frequency_penalty: Optional[float] = None logprobs: Optional[bool] = False top_logprobs: Optional[int] = 0 + include_draft_logprobs: Optional[bool] = False # For logits and logprobs post processing temp_scaled_logprobs: bool = False diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 125d785fe..c1e189a36 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -295,10 +295,15 @@ class OpenAIServingChat: output_top_logprobs = output["top_logprobs"] previous_num_tokens += len(output["token_ids"]) logprobs_res: Optional[LogProbs] = None + draft_logprobs_res: Optional[LogProbs] = None if request.logprobs and output_top_logprobs is not None: logprobs_res = self._create_chat_logprobs( output_top_logprobs, request.logprobs, request.top_logprobs ) + if request.include_draft_logprobs: + draft_logprobs_res = 
self._create_chat_logprobs(
+                                output["draft_top_logprobs"], request.logprobs, request.top_logprobs
+                            )
 
                     delta_message = DeltaMessage(
                         reasoning_content="",
@@ -326,6 +331,7 @@ class OpenAIServingChat:
                         index=0,
                         delta=delta_message,
                         logprobs=logprobs_res,
+                        draft_logprobs=draft_logprobs_res,
                         arrival_time=arrival_time,
                     )
                     if res["finished"]:
@@ -461,11 +467,21 @@ class OpenAIServingChat:
                     output = data["outputs"]
                     output_top_logprobs = output["top_logprobs"]
                     if output_top_logprobs is not None:
+                        # logprobs
                         logprobs_res = self._create_chat_logprobs(
                             output_top_logprobs, request.logprobs, request.top_logprobs
                         )
                         if logprobs_res and logprobs_res.content is not None:
                             logprob_contents.extend(logprobs_res.content)
+
+                        # draft_logprobs
+                        if request.include_draft_logprobs and output.get("draft_top_logprobs") is not None:
+                            draft_logprobs_res = self._create_chat_logprobs(
+                                output["draft_top_logprobs"], request.logprobs, request.top_logprobs
+                            )
+                            if draft_logprobs_res and draft_logprobs_res.content is not None:
+                                # draft_logprob_contents is assumed to be initialized alongside logprob_contents
+                                draft_logprob_contents.extend(draft_logprobs_res.content)
+
                     if data["finished"]:
                         final_res = data
                         task_is_finished = True
diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py
index 9b089d073..e0d88d544 100644
--- a/fastdeploy/entrypoints/openai/serving_completion.py
+++ b/fastdeploy/entrypoints/openai/serving_completion.py
@@ -212,6 +212,7 @@ class OpenAIServingCompletion:
         valid_results = [dict()] * num_choices
         output_tokens = [0] * num_choices
         aggregated_top_logprobs = [[[], [], []] for _ in range(num_choices)]
+        aggregated_draft_top_logprobs = [[[], [], []] for _ in range(num_choices)]
         aggregated_token_ids = [[] for _ in range(num_choices)]
         completion_batched_token_ids = [[] for _ in range(num_choices)]
         current_waiting_time = 0
@@ -239,11 +240,18 @@ class OpenAIServingCompletion:
 
                     output = data["outputs"]
                     output_top_logprobs = output["top_logprobs"]
+                    output_draft_top_logprobs = output["draft_top_logprobs"]
                     if output_top_logprobs is not None:
                         aggregated_top_logprobs[rid][0].extend(output_top_logprobs[0])
                         aggregated_top_logprobs[rid][1].extend(output_top_logprobs[1])
                         aggregated_top_logprobs[rid][2].extend(output_top_logprobs[2])
 
+                    # draft logprobs
+                    if request.include_draft_logprobs and output_draft_top_logprobs is not None:
+                        aggregated_draft_top_logprobs[rid][0].extend(output_draft_top_logprobs[0])
+                        aggregated_draft_top_logprobs[rid][1].extend(output_draft_top_logprobs[1])
+                        aggregated_draft_top_logprobs[rid][2].extend(output_draft_top_logprobs[2])
+
                     aggregated_token_ids[rid].extend(data["outputs"]["token_ids"])
 
                     self.engine_client.data_processor.process_response_dict(
@@ -390,10 +398,17 @@ class OpenAIServingCompletion:
                         await self._echo_back_prompt(request, res, idx)
                     output = res["outputs"]
                     output_top_logprobs = output["top_logprobs"]
+                    output_draft_top_logprobs = output["draft_top_logprobs"]
                     logprobs_res: Optional[CompletionLogprobs] = None
+                    draft_logprobs_res: Optional[CompletionLogprobs] = None
                     if request.logprobs and output_top_logprobs is not None:
                         logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
 
+                    # draft logprobs
+                    if request.include_draft_logprobs and output_draft_top_logprobs is not None:
+                        draft_logprobs_res = self._create_completion_logprobs(
+                            output_draft_top_logprobs, request.logprobs, 0
+                        )
                     output_tokens[idx] += 1
                     delta_message = CompletionResponseStreamChoice(
                         index=idx,
@@ -406,6 +421,7 @@ class OpenAIServingCompletion:
                         reasoning_content="",
                         arrival_time=arrival_time,
                         logprobs=logprobs_res,
+                        draft_logprobs=draft_logprobs_res,
                     )
                     if not res["finished"] and "delta_message" in output:
                         delta_message_output = output["delta_message"]
output["delta_message"] diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index e48260fc6..42a906f97 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -109,6 +109,7 @@ class TokenProcessor: self.executor = ThreadPoolExecutor(max_workers=1) self.prefill_result_status = dict() self._finalizer = weakref.finalize(self, self._cleanup_resources) + self._batch_result_buffer = None def _cleanup_resources(self): """Cleaning up shared memory resources""" @@ -165,7 +166,20 @@ class TokenProcessor: try: is_blocking = True if self.speculative_decoding: - speculate_get_output(self.output_tokens, rank_id, is_blocking, False) + if ( + self.cfg.parallel_config.enable_expert_parallel + and self.cfg.parallel_config.data_parallel_size > 1 + ): + if self.use_logprobs: + # TODO speculate_get_output_with_topk + pass + else: + speculate_get_output(self.output_tokens, rank_id, is_blocking, True) + elif self.use_logprobs: + # TODO speculate_get_output_with_topk + pass + else: + speculate_get_output(self.output_tokens, rank_id, is_blocking, False) if self.output_tokens[0] == -2: continue @@ -213,7 +227,7 @@ class TokenProcessor: self.executor.submit(process_metrics) - def postprocess(self, batch_result): + def postprocess(self, batch_result, mtype=3): """ single post-processing function @@ -221,7 +235,21 @@ class TokenProcessor: batch_result (list): batch results """ try: - self.cached_generated_tokens.put_results(batch_result) + if self.cfg.speculative_config.method and self.cfg.use_logprobs: + if mtype == 3: # target + self._batch_result_buffer = batch_result + elif mtype == 4: # draft + target_batch_result = [] + draft_batch_result = batch_result + for target, decode in zip(self._batch_result_buffer, draft_batch_result): + target["outputs"]["draft_top_logprobs"] = decode["outputs"]["draft_top_logprobs"] + target_batch_result.append(target) + self._batch_result_buffer = None + self.cached_generated_tokens.put_results(target_batch_result) + else: + self.cached_generated_tokens.put_results(batch_result) + else: + self.cached_generated_tokens.put_results(batch_result) except Exception as e: llm_logger.error(f"Error in TokenProcessor's postprocess: {e}, {str(traceback.format_exc())}") @@ -302,9 +330,19 @@ class TokenProcessor: tokens = self.output_tokens.numpy() scores = None ranks = None + # target:3, draft:4 + mtype = 3 if self.cfg.speculative_config.method: - batch = self.output_tokens[1] - accept_num = tokens[2 : batch + 2] + if self.use_logprobs: + mtype = self.output_tokens[1, 0] + batch = self.output_tokens[2, 0] + accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]] + tokens = tokens[3 + batch : 3 + batch + batch * (K + 1) * MAX_DRAFT_TOKENS].reshape( + [batch, K + 1, MAX_DRAFT_TOKENS] + ) + else: + batch = self.output_tokens[1] + accept_num = tokens[2 : batch + 2] self._record_speculative_decoding_mertics(accept_num) elif self.use_logprobs: batch = self.output_tokens[1, 0] @@ -332,19 +370,24 @@ class TokenProcessor: task_id = task.request_id if self.cfg.speculative_config.method: - token_ids = tokens[ - 2 - + SPECULATE_MAX_BSZ - + i * MAX_DRAFT_TOKENS : 2 - + SPECULATE_MAX_BSZ - + i * MAX_DRAFT_TOKENS - + accept_num[i] - ].tolist() - if len(token_ids) == 0 or token_ids[-1] <= 0: - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - if task_id in self.resource_manager.to_be_rescheduled_request_id_set: - self.resource_manager.reschedule_preempt_task(task_id) - continue + if accept_num[i] == -3: + recovery_stop = True + if 
recovery_stop: + llm_logger.info(f"recovery stop signal found at task {task_id}") + token_ids = [RECOVERY_STOP_SIGNAL] + elif self.use_logprobs: + token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] + else: + token_ids = tokens[ + 2 + + SPECULATE_MAX_BSZ + + i * MAX_DRAFT_TOKENS : 2 + + SPECULATE_MAX_BSZ + + i * MAX_DRAFT_TOKENS + + accept_num[i] + ].tolist() + if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0): + continue else: token_id = int(tokens[i, 0]) token_ids = [token_id] @@ -387,6 +430,7 @@ class TokenProcessor: self._record_metrics(task, current_time, token_ids) result = RequestOutput( request_id=task_id, + output_type=mtype, outputs=CompletionOutput( index=i, send_idx=self.tokens_counter[task_id], @@ -412,16 +456,36 @@ class TokenProcessor: result.outputs.token_ids.append(token_id) task.output_token_ids.append(token_id) if self.use_logprobs: + # TODO 投机解码场景兼容支持 result.outputs.logprob = float(scores[i, 0]) # Construct top_logprobs topk_token_ids = tokens[i, :].tolist() topk_logprobs = scores[i, :].tolist() sampled_rank = ranks[i].item() - result.outputs.top_logprobs = LogprobsLists( - logprob_token_ids=[topk_token_ids], - logprobs=[topk_logprobs], - sampled_token_ranks=[sampled_rank], - ) + + if mtype == 3: # top_logprobs + if result.outputs.top_logprobs is None: + result.outputs.top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank]) + elif mtype == 4: # draft_top_logprobs + if result.outputs.draft_top_logprobs is None: + result.outputs.draft_top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank]) + if token_id in task.eos_token_ids or is_prefill or recovery_stop: result.finished = True if recovery_stop: @@ -442,7 +506,7 @@ class TokenProcessor: if not is_prefill or self.cfg.scheduler_config.name == "splitwise": batch_result.append(result) - self.postprocess(batch_result) + self.postprocess(batch_result, mtype) def _record_metrics(self, task, current_time, token_ids): """Record all metrics for a task""" diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py new file mode 100644 index 000000000..0d487c00f --- /dev/null +++ b/tests/output/test_process_batch_output.py @@ -0,0 +1,167 @@ +import time +import unittest +from unittest.mock import Mock + +import paddle + +from fastdeploy.output.token_processor import TokenProcessor + +paddle.set_device("cpu") + + +# Mock classes and constants needed for the test +class MockConfig: + class ParallelConfig: + local_data_parallel_id = 0 + + class SpeculativeConfig: + method = None + + class ModelConfig: + enable_logprob = False + + class SchedulerConfig: + name = "default" + + parallel_config = ParallelConfig() + speculative_config = SpeculativeConfig() + model_config = ModelConfig() + scheduler_config = SchedulerConfig() + + +class MockTask: + def __init__(self): + self.request_id = "test_request_1" + self.arrival_time = time.time() + self.inference_start_time = 
time.time() + self.schedule_start_time = time.time() + self.preprocess_end_time = time.time() - 0.1 + self.preprocess_start_time = time.time() - 0.2 + self.eos_token_ids = [2] + self.output_token_ids = [] + self.messages = "Test prompt" + self.num_cached_tokens = 0 + self.disaggregate_info = None + self.prefill_chunk_info = None + self.prefill_chunk_num = 0 + + +class MockResourceManager: + def __init__(self): + self.stop_flags = [False] + self.tasks_list = [MockTask()] + self.to_be_rescheduled_request_id_set = set() + + def info(self): + return "Mock resource manager info" + + def reschedule_preempt_task(self, task_id): + pass + + +# Constants +RECOVERY_STOP_SIGNAL = -3 +MAX_BSZ = 512 +K = 20 +MAX_DRAFT_TOKENS = 6 +SPECULATE_MAX_BSZ = 256 + + +class TestTokenProcessorProcessBatchOutput(unittest.TestCase): + + def setup_token_processor(self, speculative_decoding=False, use_logprobs=False): + """Helper method to setup TokenProcessor with different configurations""" + cfg = MockConfig() + cfg.speculative_config.method = "mtp" if speculative_decoding else None + cfg.model_config.enable_logprob = use_logprobs + + processor = TokenProcessor.__new__(TokenProcessor) + processor.cfg = cfg + processor.cached_generated_tokens = [] + processor.engine_worker_queue = Mock() + processor.split_connector = Mock() + processor.resource_manager = MockResourceManager() + processor.tokens_counter = {} + processor.total_step = 0 + processor.number_of_output_tokens = 0 + processor.prefill_result_status = {} + processor.executor = Mock() + + if speculative_decoding: + if use_logprobs: + processor.output_tokens = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1], + fill_value=2, + dtype="int64", + ) + processor.output_scores = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], + fill_value=0.0, + dtype="float32", + ) + processor.output_ranks = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS], + fill_value=0, + dtype="int64", + ) + else: + processor.output_tokens = paddle.full( + shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], + fill_value=2, + dtype="int64", + ) + elif use_logprobs: + processor.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64") + processor.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32") + processor.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64") + else: + processor.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64") + + return processor + + def test_speculative_decoding_use_logprobs(self): + """Test basic speculative decoding scenario""" + processor = self.setup_token_processor(speculative_decoding=True, use_logprobs=True) + print(f"{processor}") + + # batch_size = 1 + # max_draft_tokens = MAX_DRAFT_TOKENS + + # # Setup speculative decoding output format + # output_tokens_np = np.full( + # SPECULATE_MAX_BSZ * max_draft_tokens + SPECULATE_MAX_BSZ + 10, + # 2, + # dtype=np.int64, + # ) + # output_tokens_np[1] = batch_size # batch size + # output_tokens_np[2:2 + batch_size] = [3] # accept numbers (3 accepted tokens) + + # # Setup draft tokens + # start_idx = 2 + SPECULATE_MAX_BSZ + # for i in range(batch_size): + # draft_tokens = np.arange(100, 100 + max_draft_tokens) + # output_tokens_np[ + # start_idx + i * max_draft_tokens:start_idx + (i + 1) * max_draft_tokens + # ] = draft_tokens + + # processor.output_tokens = paddle.to_tensor(output_tokens_np) + # processor.tokens_counter = {"test_request_1": 0} + 
# processor.postprocess = Mock() + + # # Mock speculative decoding metrics recording + # processor._record_speculative_decoding_mertics = Mock() + # processor._compute_speculative_status = Mock() + + # with patch.object(processor.resource_manager, "stop_flags", [False]): + # with patch.object(processor.resource_manager.tasks_list[0], "eos_token_ids", [2]): + # processor._process_batch_output() + + # self.assertTrue(processor._record_speculative_decoding_mertics.called) + # results = processor.postprocess.call_args[0][0] + # self.assertEqual(len(results), 1) + # # Should have 3 tokens (based on accept_num) + # self.assertEqual(len(results[0].outputs.token_ids), 3) + + +if __name__ == "__main__": + unittest.main(verbosity=2, buffer=False)
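
Usage sketch (illustrative only, not part of the diff above): assuming a FastDeploy OpenAI-compatible server launched with speculative decoding (MTP) and logprobs enabled, a client could opt into draft logprobs through the new include_draft_logprobs request field added to CompletionRequest/ChatCompletionRequest in this patch. The base URL, api_key, and model name below are placeholders.

    # Hypothetical client-side example of the include_draft_logprobs flag.
    # extra_body forwards the non-standard field to FastDeploy's ChatCompletionRequest.
    import openai

    client = openai.OpenAI(base_url="http://localhost:8188/v1", api_key="EMPTY")
    resp = client.chat.completions.create(
        model="default",
        messages=[{"role": "user", "content": "Hello"}],
        logprobs=True,
        top_logprobs=5,
        extra_body={"include_draft_logprobs": True},
    )
    # Target-model logprobs are returned as usual; draft-model logprobs are expected
    # alongside them in the choice when MTP speculative decoding is active on the server.
    print(resp.choices[0].logprobs)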