""" Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import copy import time import types from collections import Counter from unittest import mock import numpy as np import paddle import pytest from fastdeploy import envs from fastdeploy.engine.request import Request, RequestMetrics, RequestOutput from fastdeploy.output import token_processor from fastdeploy.output.token_processor import ( MAX_BSZ, MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ, K, TokenProcessor, ) class _DummyCfg: def __init__( self, speculative_method=None, enable_logprob=False, max_num_seqs=2, enable_prefix_caching=False, enable_output_caching=False, ): self.parallel_config = types.SimpleNamespace( local_data_parallel_id=0, enable_expert_parallel=False, data_parallel_size=1, ) self.speculative_config = types.SimpleNamespace( method=speculative_method, num_speculative_tokens=2, ) self.model_config = types.SimpleNamespace(enable_logprob=enable_logprob) self.scheduler_config = types.SimpleNamespace(name="default", splitwise_role="decode") self.cache_config = types.SimpleNamespace( enable_prefix_caching=enable_prefix_caching, enable_output_caching=enable_output_caching, block_size=64, ) self.max_num_seqs = max_num_seqs self.splitwise_version = "v1" class _DummyResourceManager: def __init__(self, max_num_seqs=2): self.max_num_seqs = max_num_seqs self.stop_flags = [False] * max_num_seqs self.tasks_list = [None] * max_num_seqs self.req_dict = {} self.requests = {} self.to_be_rescheduled_request_id_set = set() self.recycled = [] self.cached_tasks = [] self.cleared = False def _recycle_block_tables(self, task): self.recycled.append(task.request_id) def reschedule_preempt_task(self, request_id): self.recycled.append(f"reschedule-{request_id}") def finish_requests_async(self, request_id): self.recycled.append(f"finish-{request_id}") def total_block_number(self): return 8 def available_batch(self): return self.tasks_list.count(None) def info(self): return "rm-info" def get_finished_req(self): return [] def cache_output_tokens(self, task): self.cached_tasks.append(task.request_id) def clear_data(self): self.cleared = True class _DummyQueue: def get_finished_req(self): return [] class _DummyConnector: def __init__(self): self.calls = [] def send_first_token(self, info, results): self.calls.append((info, results)) @pytest.fixture(autouse=True) def _ensure_cpu(): paddle.device.set_device("cpu") def _make_processor( speculative_method=None, enable_logprob=False, max_num_seqs=2, enable_prefix_caching=False, enable_output_caching=False, ): cfg = _DummyCfg( speculative_method=speculative_method, enable_logprob=enable_logprob, max_num_seqs=max_num_seqs, enable_prefix_caching=enable_prefix_caching, enable_output_caching=enable_output_caching, ) cache = mock.Mock() queue = _DummyQueue() connector = _DummyConnector() processor = TokenProcessor(cfg, cache, queue, connector) rm = _DummyResourceManager(max_num_seqs) processor.set_resource_manager(rm) return processor, rm, cache, connector class _Metric: def __init__(self): 
    def __init__(self):
        self.value = None

    def set(self, v):
        self.value = v

    def inc(self, v=1):
        self.value = (self.value or 0) + v

    def dec(self, v=1):
        self.value = (self.value or 0) - v

    def observe(self, v):
        self.value = v


class _Metrics:
    def __init__(self):
        self.spec_decode_num_accepted_tokens_total = _Metric()
        self.spec_decode_num_emitted_tokens_total = _Metric()
        self.spec_decode_draft_acceptance_rate = _Metric()
        self.spec_decode_efficiency = _Metric()
        self.spec_decode_num_draft_tokens_total = _Metric()
        self.spec_decode_draft_single_head_acceptance_rate = [_Metric() for _ in range(MAX_DRAFT_TOKENS)]
        self.time_per_output_token = _Metric()
        self.generation_tokens_total = _Metric()
        self.time_to_first_token = _Metric()
        self.request_queue_time = _Metric()
        self.request_prefill_time = _Metric()
        self.request_decode_time = _Metric()
        self.request_inference_time = _Metric()
        self.request_generation_tokens = _Metric()
        self.num_requests_running = _Metric()
        self.request_success_total = _Metric()
        self.available_gpu_block_num = _Metric()
        self.batch_size = _Metric()
        self.available_batch_size = _Metric()

    def _init_speculative_metrics(self, method, num_speculative_tokens):
        return None


def test_init_allocates_expected_buffers():
    processor, _, _, _ = _make_processor()
    assert list(processor.output_tokens.shape) == [MAX_BSZ + 2, 1]
    processor_logprob, _, _, _ = _make_processor(enable_logprob=True)
    assert list(processor_logprob.output_scores.shape) == [MAX_BSZ * (K + 1), 1]
    processor_spec, _, _, _ = _make_processor(speculative_method="mtp", enable_logprob=False)
    assert processor_spec.output_tokens.shape[0] == SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2


def test_run_uses_correct_worker_based_on_flag():
    processor, _, _, _ = _make_processor()
    processor.worker = None
    with (
        mock.patch.object(envs, "FD_USE_GET_SAVE_OUTPUT_V1", True),
        mock.patch("fastdeploy.output.token_processor.threading.Thread") as thread_cls,
    ):
        fake_thread = mock.Mock()
        thread_cls.return_value = fake_thread
        processor.run()
        target = thread_cls.call_args.kwargs["target"]
        assert target.__func__ is processor.process_sampling_results_use_zmq.__func__
        assert fake_thread.daemon is True
    processor.worker = object()
    with pytest.raises(Exception):
        processor.run()


def test_cleanup_resources_shuts_down_executor():
    processor, _, _, _ = _make_processor()
    processor.executor = mock.Mock()
    processor._cleanup_resources()
    processor.executor.shutdown.assert_called_once_with(wait=False)


def test_reschedule_preempt_task_use_zmq_reschedules_missing_batch():
    processor, rm, _, _ = _make_processor()
    rm.to_be_rescheduled_request_id_set = {"req-a"}
    rm.requests = {"req-a": types.SimpleNamespace(idx=1)}
    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True):
        processor._reschedule_preempt_task_use_zmq([types.SimpleNamespace(batch_id=0)])
    assert "reschedule-req-a" in rm.recycled


def test_process_batch_draft_tokens_collects_top_logprobs():
    processor, rm, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
    rm.tasks_list[0] = types.SimpleNamespace(request_id="task-0", block_tables=[1])
    ranks = paddle.to_tensor(np.array([[0, 1, 1]]))
    scores = paddle.ones([1, 3, 1], dtype="float32")
    tokens = paddle.arange(9, dtype="int64").reshape([1, 3, 3])
    results = processor._process_batch_draft_tokens(
        4, batch=1, accept_num=[2], tokens=tokens, scores=scores, ranks=ranks
    )
    assert len(results) == 1
    assert results[0].outputs.draft_top_logprobs.logprob_token_ids[0][0] == 0
    assert results[0].outputs.draft_top_logprobs.sampled_token_ranks[-1] == 1


def test_process_batch_output_use_zmq_finishes_on_eos():
    processor, rm, cache, connector = _make_processor()
    base_time = time.time()
    task = Request(
        request_id="req-zmq",
        prompt=["hi"],
        prompt_token_ids=[1, 2],
        prompt_token_ids_len=2,
        messages=[[{"content": "hi", "role": "user"}]],
        history=[],
        tools=[],
        system="system",
        eos_token_ids=[6],
        metrics=RequestMetrics(
            arrival_time=base_time,
            preprocess_start_time=base_time - 0.2,
            preprocess_end_time=base_time - 0.1,
            inference_start_time=base_time,
        ),
    )
    task.metrics.decode_inference_start_time = base_time
    task.disaggregate_info = None
    task.ic_req_data = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    rm.requests[task.request_id] = types.SimpleNamespace(idx=0)
    tokens = np.array([5, 6], dtype=np.int64)
    stream = types.SimpleNamespace(batch_id=0, tokens=tokens, pooler_output=None)
    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
        results = processor._process_batch_output_use_zmq([stream])
    assert results[0].finished is True
    assert task.output_token_ids == [5, 6]
    assert rm.stop_flags[0] is True
    assert connector.calls == []


def test_process_batch_output_use_zmq_parses_logprobs():
    processor, rm, _, _ = _make_processor(enable_logprob=True)
    base_time = time.time()
    task = Request(
        request_id="req-zmq-logprob",
        prompt=["hi"],
        prompt_token_ids=[1],
        prompt_token_ids_len=1,
        messages=[[{"content": "hi", "role": "user"}]],
        history=[],
        tools=[],
        system="system",
        eos_token_ids=[6],
        metrics=RequestMetrics(
            arrival_time=base_time,
            preprocess_start_time=base_time - 0.2,
            preprocess_end_time=base_time - 0.1,
            inference_start_time=base_time,
        ),
    )
    task.metrics.decode_inference_start_time = base_time
    task.disaggregate_info = None
    task.ic_req_data = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    rm.requests[task.request_id] = types.SimpleNamespace(idx=0)
    logprob_list = token_processor.LogprobsLists(
        logprob_token_ids=[[1, 2]],
        logprobs=[[0.1, 0.2]],
        sampled_token_ranks=[0],
    )
    logprob_holder = types.SimpleNamespace(tolists=lambda: logprob_list)
    stream = types.SimpleNamespace(
        batch_id=0,
        tokens=np.array([5], dtype=np.int64),
        pooler_output=None,
        logprobs=logprob_holder,
        prompt_logprobs={"0": -0.1},
    )
    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
        results = processor._process_batch_output_use_zmq([stream])
    assert results[0].outputs.logprob == 0.1
    assert results[0].outputs.top_logprobs is logprob_list
    assert results[0].prompt_logprobs == {"0": -0.1}


def test_recycle_resources_updates_metrics_and_state():
    processor, rm, _, _ = _make_processor()
    task = types.SimpleNamespace(request_id="req-1", block_tables=[1], disaggregate_info=None)
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    result = RequestOutput(request_id=task.request_id, outputs=None, finished=False, metrics=metrics)
    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
        processor._recycle_resources(task.request_id, 0, task, result, is_prefill=False)
    assert rm.stop_flags[0] is True
    assert task.request_id not in rm.req_dict
    assert rm.recycled[-1] == task.request_id
    assert processor.tokens_counter.get(task.request_id) is None


def test_compute_speculative_status_builds_metrics():
    processor, rm, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
    req_id = "req-spec"
    rm.tasks_list[0] = types.SimpleNamespace(request_id=req_id, block_tables=[1])
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    result = RequestOutput(request_id=req_id, outputs=None, finished=False, metrics=metrics)
    processor.total_step = 2
    processor.number_of_output_tokens = 4
    processor.speculative_stats_step = 0
    processor.accept_token_num_per_head = [2, 1] + [0] * (MAX_DRAFT_TOKENS - 2)
    processor.accept_token_num_per_head_per_request[req_id] = [2, 1]
    processor.total_step_per_request[req_id] = 2
    processor._compute_speculative_status(result)
    assert hasattr(result.metrics, "speculate_metrics")
    assert result.metrics.speculate_metrics.accepted_tokens == 3


def test_process_per_token_handles_recovery_stop_and_cleanup():
    processor, rm, _, _ = _make_processor()
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    metrics.inference_start_time = time.time()
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id="req-recover",
        prompt=["hi"],
        prompt_token_ids=[1],
        prompt_token_ids_len=1,
        messages=[],
        history=[],
        tools=[],
        system="sys",
        eos_token_ids=[99],
        metrics=metrics,
        output_token_ids=[],
        block_tables=[1],
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    rm.requests[task.request_id] = types.SimpleNamespace(idx=0)
    result = RequestOutput(
        request_id=task.request_id,
        outputs=types.SimpleNamespace(token_ids=[]),
        finished=False,
        metrics=copy.copy(task.metrics),
    )
    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
        stopped = processor._process_per_token(
            task,
            batch_id=0,
            token_ids=np.array([token_processor.RECOVERY_STOP_SIGNAL]),
            result=result,
            is_prefill=False,
        )
    assert stopped.finished is True
    assert "incomplete" in stopped.error_msg
    assert rm.stop_flags[0] is True
    assert rm.tasks_list[0] is None
    assert processor.tokens_counter.get(task.request_id) is None


def test_postprocess_buffers_and_merges_speculative_results():
    processor, _, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
    processor.cached_generated_tokens = mock.Mock()
    target_output = RequestOutput(
        request_id="req-t",
        outputs=types.SimpleNamespace(draft_top_logprobs=None),
        finished=False,
        metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
    )
    draft_output = RequestOutput(
        request_id="req-t",
        outputs=types.SimpleNamespace(draft_top_logprobs="draft-logprobs"),
        finished=False,
        metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
    )
    processor.postprocess([target_output], mtype=3)
    assert processor._batch_result_buffer == [target_output]
    processor.postprocess([draft_output], mtype=4)
    processor.cached_generated_tokens.put_results.assert_called_once()
    merged = processor.cached_generated_tokens.put_results.call_args.args[0][0]
    assert merged.outputs.draft_top_logprobs == "draft-logprobs"
    assert processor._batch_result_buffer is None


def test_postprocess_emits_finished_speculative_batch():
    processor, _, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
    processor.cached_generated_tokens = mock.Mock()
    finished_output = RequestOutput(
        request_id="req-finished",
        outputs=types.SimpleNamespace(draft_top_logprobs=None),
        finished=True,
        metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
    )
    processor.postprocess([finished_output], mtype=3)
    processor.cached_generated_tokens.put_results.assert_called_once_with([finished_output])
    assert processor._batch_result_buffer is None


def test_postprocess_passes_through_unknown_type():
    processor, _, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
    processor.cached_generated_tokens = mock.Mock()
    output = RequestOutput(
        request_id="req-direct",
        outputs=types.SimpleNamespace(draft_top_logprobs=None),
        finished=False,
        metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
    )
    processor.postprocess([output], mtype=99)
    processor.cached_generated_tokens.put_results.assert_called_once_with([output])


def test_postprocess_logs_and_swallows_exception():
    processor, _, _, _ = _make_processor()
    processor.cached_generated_tokens = mock.Mock()
    processor.cached_generated_tokens.put_results.side_effect = RuntimeError("boom")
    output = RequestOutput(
        request_id="req-error",
        outputs=None,
        finished=False,
        metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
    )
    processor.postprocess([output])
    processor.cached_generated_tokens.put_results.assert_called_once()


def test_record_speculative_decoding_metrics_tracks_acceptance():
    processor, _, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
    with mock.patch.object(token_processor, "main_process_metrics", _Metrics()):
        processor.accept_token_num_per_head = [2, 3, 0, 0, 0, 0]
        processor.num_draft_tokens = 0
        processor.num_emitted_tokens = 0
        processor.num_accepted_tokens = 0
        processor._record_speculative_decoding_metrics(accept_num=[1, 2])
        metrics = token_processor.main_process_metrics
        assert metrics.spec_decode_num_accepted_tokens_total.value == 3
        assert metrics.spec_decode_num_emitted_tokens_total.value == 5
        assert pytest.approx(metrics.spec_decode_draft_acceptance_rate.value) == 0.75
        assert pytest.approx(metrics.spec_decode_efficiency.value) == pytest.approx(5 / 6)
        assert pytest.approx(metrics.spec_decode_draft_single_head_acceptance_rate[0].value) == 1.5


def test_recycle_resources_prefill_sends_first_token():
    processor, rm, _, connector = _make_processor()
    task_id = "req-prefill"
    metrics = RequestMetrics(
        arrival_time=time.time(),
        preprocess_start_time=time.time(),
        preprocess_end_time=time.time(),
        inference_start_time=time.time(),
    )
    task = types.SimpleNamespace(
        request_id=task_id,
        metrics=metrics,
        block_tables=[1],
        disaggregate_info={"role": "prefill"},
        eos_token_ids=[1],
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task_id] = task
    result = RequestOutput(request_id=task_id, outputs=None, finished=False, metrics=metrics)
    processor.engine_worker_queue = mock.Mock()
    processor.engine_worker_queue.get_finished_req.side_effect = [[(task_id, "finished")]]
    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
        processor._recycle_resources(task_id, 0, task, result, is_prefill=True)
    assert rm.stop_flags[0] is True
    assert connector.calls and connector.calls[0][1][0] is result


def test_recycle_resources_prefill_failure_sets_error():
    processor, rm, _, connector = _make_processor()
    task_id = "req-prefill-failed"
    metrics = RequestMetrics(
        arrival_time=time.time(),
        preprocess_start_time=time.time(),
        preprocess_end_time=time.time(),
        inference_start_time=time.time(),
    )
    task = types.SimpleNamespace(
        request_id=task_id,
        metrics=metrics,
        block_tables=[1],
        disaggregate_info={"role": "prefill"},
        eos_token_ids=[1],
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task_id] = task
    result = RequestOutput(request_id=task_id, outputs=None, finished=False, metrics=metrics)
    processor.engine_worker_queue = mock.Mock()
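    # Report the request back from the worker queue as "failed" so the recycle path flags the result as an error.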
    processor.engine_worker_queue.get_finished_req.side_effect = [[(task_id, "failed")]]
    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
        processor._recycle_resources(task_id, 0, task, result, is_prefill=True)
    assert result.error_code == 400
    assert "failed" in result.error_message
    assert connector.calls and connector.calls[0][1][0] is result


def test_clear_data_marks_all_tasks_finished():
    processor, rm, _, _ = _make_processor()
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    task_a = types.SimpleNamespace(
        request_id="req-a",
        eos_token_ids=[0],
        metrics=metrics,
        disaggregate_info=None,
        block_tables=[1],
        arrival_time=time.time(),
    )
    task_b = types.SimpleNamespace(
        request_id="req-b",
        eos_token_ids=[0],
        metrics=metrics,
        disaggregate_info=None,
        block_tables=[2],
        arrival_time=time.time(),
    )
    rm.tasks_list[0] = task_a
    rm.tasks_list[1] = task_b
    rm.req_dict = {"req-a": task_a, "req-b": task_b}
    processor.tokens_counter = Counter({"req-a": 2, "req-b": 1})
    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
        processor.clear_data()
    assert rm.tasks_list == [None, None]
    assert not processor.tokens_counter
    assert task_a.request_id in rm.recycled and task_b.request_id in rm.recycled


def test_record_speculative_decoding_accept_num_per_request_updates_maps():
    processor, _, _, _ = _make_processor(speculative_method="mtp")
    processor._record_speculative_decoding_accept_num_per_request("req-acc", 3)
    assert processor.total_step_per_request["req-acc"] == 1
    assert processor.accept_token_num_per_head_per_request["req-acc"][0] == 1
    assert processor.accept_token_num_per_head[2] == 1


def test_reschedule_preempt_task_triggers_for_pending_requests():
    processor, rm, _, _ = _make_processor()
    rm.to_be_rescheduled_request_id_set = {"req-z"}
    rm.requests = {"req-z": types.SimpleNamespace(idx=2)}
    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True):
        processor._reschedule_preempt_task(batch_size=1)
    assert "reschedule-req-z" in rm.recycled


def test_process_batch_output_consumes_tokens_and_finishes_task():
    processor, rm, _, _ = _make_processor()
    metrics = RequestMetrics(
        arrival_time=time.time(),
        preprocess_start_time=time.time(),
        preprocess_end_time=time.time(),
        inference_start_time=time.time(),
    )
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id="req-out",
        disaggregate_info=None,
        eos_token_ids=[7],
        metrics=metrics,
        output_token_ids=[],
        messages=[{"role": "user", "content": "hi"}],
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
    )
    task.trace_carrier = None
    task.get = lambda key, default=None: getattr(task, key, default)
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.output_tokens[1, 0] = 1
    processor.output_tokens[2, 0] = 7
    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
        processor._process_batch_output()
    assert rm.stop_flags[0] is True
    assert task.output_token_ids == [7]


def test_process_batch_output_logprob_records_topk_and_caching():
    processor, rm, _, _ = _make_processor(enable_logprob=True, enable_prefix_caching=True, enable_output_caching=True)
    metrics = RequestMetrics(
        arrival_time=time.time(),
        preprocess_start_time=time.time(),
        preprocess_end_time=time.time(),
        inference_start_time=time.time(),
    )
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id="req-logprob",
        disaggregate_info=None,
        eos_token_ids=[3],
        metrics=metrics,
        output_token_ids=[],
        messages=[],
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        get=lambda key, default=None: None,
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.output_tokens[1, 0] = 1
    token_block = np.arange(K + 1, dtype=np.int64) + 3
    processor.output_tokens[2 : 2 + K + 1] = paddle.to_tensor(token_block.reshape([-1, 1]))
    processor.output_scores[: K + 1] = paddle.ones([K + 1, 1], dtype="float32")
    processor.output_ranks[0] = paddle.to_tensor(0, dtype="int64")
    processor.cached_generated_tokens.put_results = mock.Mock()
    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()
    assert rm.cached_tasks[-1] == "req-logprob"
    sent = processor.cached_generated_tokens.put_results.call_args.args[0][0]
    assert sent.outputs.top_logprobs is not None


def test_process_batch_output_speculative_logprob_handles_draft_batch():
    processor, rm, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
    rm.tasks_list[0] = types.SimpleNamespace(request_id="req-draft", block_tables=[1], disaggregate_info=None)
    target = RequestOutput(
        request_id="req-draft",
        outputs=types.SimpleNamespace(draft_top_logprobs=None),
        finished=False,
        metrics=None,
    )
    processor._batch_result_buffer = [target]
    processor.cached_generated_tokens = mock.Mock()
    processor.output_tokens[1, 0] = 4
    processor.output_tokens[2, 0] = 1
    processor.output_tokens[3, 0] = 1
    draft_tokens = np.arange(MAX_DRAFT_TOKENS * (K + 1), dtype=np.int64).reshape([-1, 1]) + 5
    processor.output_tokens[3 + MAX_BSZ : 3 + MAX_BSZ + len(draft_tokens)] = paddle.to_tensor(draft_tokens)
    processor.output_scores[: MAX_DRAFT_TOKENS * (K + 1)] = paddle.ones(
        [MAX_DRAFT_TOKENS * (K + 1), 1], dtype="float32"
    )
    processor.output_ranks[:MAX_DRAFT_TOKENS] = paddle.arange(MAX_DRAFT_TOKENS, dtype="int64")
    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()
    sent_batch = processor.cached_generated_tokens.put_results.call_args.args[0]
    assert sent_batch and sent_batch[0].outputs.draft_top_logprobs is not None


def test_process_batch_output_speculative_recovery_stop_finishes():
    processor, rm, _, _ = _make_processor(speculative_method="mtp")
    metrics = RequestMetrics(
        arrival_time=time.time(),
        preprocess_start_time=time.time(),
        preprocess_end_time=time.time(),
        inference_start_time=time.time(),
    )
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id="req-recover-spec",
        disaggregate_info=None,
        eos_token_ids=[2],
        metrics=metrics,
        output_token_ids=[],
        messages=[],
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        get=lambda key, default=None: None,
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.output_tokens[1] = 1
    processor.output_tokens[2] = -3
    processor.number_of_output_tokens = 1
    processor.total_step = 1
    processor.accept_token_num_per_head_per_request[task.request_id] = [1] + [0] * (MAX_DRAFT_TOKENS - 1)
    processor.total_step_per_request[task.request_id] = 1
    processor.cached_generated_tokens.put_results = mock.Mock()
    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
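        # Slot 0 carries a -3 recovery-stop token in the speculative layout, so this step
        # is expected to finish the request with an "incomplete" error message.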
        processor._process_batch_output()
    assert rm.stop_flags[0] is True
    sent = processor.cached_generated_tokens.put_results.call_args.args[0][0]
    assert sent.finished is True
    assert "incomplete" in sent.error_msg


def test_process_batch_output_prefill_chunk_and_adapter_skip():
    processor, rm, _, _ = _make_processor(enable_logprob=True)
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    metrics.inference_start_time = time.time()
    metrics.decode_inference_start_time = metrics.inference_start_time
    processor.cfg.scheduler_config.splitwise_role = "prefill"
    task = types.SimpleNamespace(
        request_id="req-prefill-chunk",
        disaggregate_info={"role": "prefill"},
        eos_token_ids=[1],
        metrics=metrics,
        output_token_ids=[],
        messages=[{"role": "user", "content": "hi"}],
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=1,
        num_total_tokens=2,
        block_tables=[1],
        prefill_chunk_info=[{"idx": 0}, {"idx": 1}],
    )
    task.trace_carrier = None
    task.get = lambda key, default=None: getattr(task, key, default)
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.output_tokens[1, 0] = 1
    processor.output_tokens[2 : 2 + K + 1] = paddle.to_tensor(np.ones([K + 1, 1], dtype=np.int64))
    processor.output_scores[: K + 1] = paddle.ones([K + 1, 1], dtype="float32")
    processor.output_ranks[0] = paddle.to_tensor(0, dtype="int64")
    processor.cached_generated_tokens.put_results = mock.Mock()
    with (
        mock.patch.object(envs, "FD_ENABLE_INTERNAL_ADAPTER", True),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()
    assert getattr(task, "prefill_chunk_num") == 1
    assert processor.cached_generated_tokens.put_results.call_args.args[0] == []


def test_process_batch_output_handles_multimodal_and_negative_token():
    processor, rm, _, _ = _make_processor()
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    metrics.inference_start_time = time.time()
    metrics.decode_inference_start_time = metrics.inference_start_time
    processor.cfg.scheduler_config.splitwise_role = "prefill"
    task = types.SimpleNamespace(
        request_id="req-negative",
        disaggregate_info=None,
        eos_token_ids=[5],
        metrics=metrics,
        output_token_ids=[],
        messages=None,
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        prefill_chunk_info=None,
        multimodal_inputs={"num_input_image_tokens": 2, "num_input_video_tokens": 3},
        get=lambda key, default=None: None,
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    rm.to_be_rescheduled_request_id_set = {task.request_id}
    rm.requests = {task.request_id: types.SimpleNamespace(idx=0)}
    processor.output_tokens[1, 0] = 1
    processor.output_tokens[2, 0] = -1
    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()
    assert rm.recycled[-1] == f"reschedule-{task.request_id}"


def test_process_batch_output_speculative_logprob_targets_topk_scores():
    processor, rm, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    metrics.inference_start_time = time.time()
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id="req-spec-logprob",
        disaggregate_info=None,
        eos_token_ids=[9],
        metrics=metrics,
        output_token_ids=[],
        messages=None,
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        get=lambda key, default=None: None,
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.output_tokens[1, 0] = 3
    processor.output_tokens[2, 0] = 1
    processor.output_tokens[3, 0] = 2
    token_block = np.arange(MAX_DRAFT_TOKENS * (K + 1), dtype=np.int64).reshape([-1, 1]) + 3
    processor.output_tokens[3 + MAX_BSZ : 3 + MAX_BSZ + len(token_block)] = paddle.to_tensor(token_block)
    score_block = paddle.arange(MAX_DRAFT_TOKENS * (K + 1), dtype="float32").reshape([-1, 1])
    processor.output_scores[: MAX_DRAFT_TOKENS * (K + 1)] = score_block
    processor.output_ranks[:MAX_DRAFT_TOKENS] = paddle.arange(MAX_DRAFT_TOKENS, dtype="int64")
    processor.cached_generated_tokens.put_results = mock.Mock()
    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()
    assert processor.tokens_counter[task.request_id] == 2


def test_record_metrics_and_speculative_ngram_metrics():
    processor, _, _, _ = _make_processor(speculative_method="ngram", enable_logprob=True)
    metrics = _Metrics()
    task = types.SimpleNamespace(
        request_id="req-metrics",
        metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
        last_token_time=time.time(),
    )
    with mock.patch.object(token_processor, "main_process_metrics", metrics):
        processor._record_metrics(task, current_time=time.time(), token_ids=[1, 2])
        processor.accept_token_num_per_head = [0, 2] + [0] * (MAX_DRAFT_TOKENS - 2)
        processor.num_accepted_tokens = 3
        processor.num_emitted_tokens = 3
        processor._record_speculative_decoding_metrics(accept_num=[1, 1])
    assert metrics.generation_tokens_total.value == 2
    assert metrics.spec_decode_draft_acceptance_rate.value == 1


def test_clear_data_invokes_scheduler_cleanup():
    processor, rm, _, _ = _make_processor()
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    task = types.SimpleNamespace(
        request_id="req-clear",
        arrival_time=time.time(),
        disaggregate_info=None,
        eos_token_ids=[0],
        metrics=metrics,
        output_token_ids=[],
        messages=None,
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        get=lambda key, default=None: getattr(task, key, default),
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.stop_flags = [True] * rm.max_num_seqs
    processor.tokens_counter[task.request_id] = 0
    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor.clear_data()
    assert rm.cleared is True


def test_process_batch_output_skips_already_stopped_slot():
    processor, rm, _, _ = _make_processor()
    rm.stop_flags[0] = True
    processor.output_tokens[1, 0] = 1
    processor.output_tokens[2, 0] = 5
    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()
    assert processor.cached_generated_tokens.put_results.called


def test_process_batch_output_speculative_negative_token_reschedules():
    processor, rm, _, _ = _make_processor(speculative_method="mtp")
    task_id = "req-spec-neg"
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    metrics.inference_start_time = time.time()
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id=task_id,
        disaggregate_info=None,
        eos_token_ids=[1],
        metrics=metrics,
        output_token_ids=[],
        messages=None,
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        get=lambda key, default=None: None,
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task_id] = task
    rm.to_be_rescheduled_request_id_set = {task_id}
    rm.requests = {task_id: types.SimpleNamespace(idx=0)}
    processor.output_tokens[1] = 1
    processor.output_tokens[2] = 1
    processor.output_tokens[3] = 1
    processor.output_tokens[2 + SPECULATE_MAX_BSZ] = -1
    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()
    assert rm.recycled[-1] == f"reschedule-{task_id}"


def test_process_batch_output_use_zmq_reschedules_negative_token():
    processor, rm, _, _ = _make_processor()
    task = types.SimpleNamespace(request_id="req-zmq-neg")
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    rm.to_be_rescheduled_request_id_set = {task.request_id}
    stream = types.SimpleNamespace(batch_id=0, tokens=np.array([-1], dtype=np.int64), pooler_output=None)
    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True):
        results = processor._process_batch_output_use_zmq([stream])
    assert results == []
    assert rm.recycled[-1] == f"reschedule-{task.request_id}"


def test_process_batch_output_records_second_decode_token():
    processor, rm, _, _ = _make_processor()
    processor.cfg.scheduler_config.splitwise_role = "decode"
    task = types.SimpleNamespace(
        request_id="req-second",
        disaggregate_info=None,
        eos_token_ids=[2],
        metrics=RequestMetrics(
            arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
        ),
        output_token_ids=[],
        messages=None,
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        get=lambda key, default=None: None,
    )
    task.trace_carrier = None
    task.metrics.inference_start_time = time.time()
    task.metrics.decode_inference_start_time = task.metrics.inference_start_time
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.tokens_counter[task.request_id] = 1
    processor.output_tokens[1, 0] = 1
    processor.output_tokens[2, 0] = 2
    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()
    assert task.metrics.decode_recv_second_token_time is not None


def test_record_speculative_metrics_calls_init_when_missing():
    processor, _, _, _ = _make_processor(speculative_method="mtp")

    class _MinimalMetrics:
        def __init__(self):
            self.init_called = False

        def _init_speculative_metrics(self, method, num_speculative_tokens):
            self.spec_decode_num_accepted_tokens_total = _Metric()
            self.spec_decode_num_emitted_tokens_total = _Metric()
            self.spec_decode_draft_acceptance_rate = _Metric()
            self.spec_decode_efficiency = _Metric()
            self.spec_decode_num_draft_tokens_total = _Metric()
            self.spec_decode_draft_single_head_acceptance_rate = [_Metric() for _ in range(MAX_DRAFT_TOKENS)]
            self.init_called = True

    processor.accept_token_num_per_head = [1, 1] + [0] * (MAX_DRAFT_TOKENS - 2)
    processor.num_accepted_tokens = 2
    processor.num_emitted_tokens = 2
    metrics = _MinimalMetrics()
    with mock.patch.object(token_processor, "main_process_metrics", metrics):
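        # main_process_metrics has no speculative counters yet, so recording should
        # trigger _init_speculative_metrics before updating them.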
        processor._record_speculative_decoding_metrics(accept_num=[1])
    assert metrics.init_called is True


def test_process_batch_output_prefill_sets_draft_tokens():
    processor, rm, _, connector = _make_processor(speculative_method="mtp")
    processor.cfg.scheduler_config.splitwise_role = "prefill"
    metrics = RequestMetrics(
        arrival_time=time.time(),
        preprocess_start_time=time.time(),
        preprocess_end_time=time.time(),
        inference_start_time=time.time(),
    )
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id="req-prefill-draft",
        disaggregate_info={"role": "prefill"},
        eos_token_ids=[99],
        metrics=metrics,
        output_token_ids=[],
        messages=None,
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        get=lambda key, default=None: None,
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.engine_worker_queue = mock.Mock()
    processor.engine_worker_queue.get_finished_req.side_effect = [[(task.request_id, "finished")]]
    processor.output_tokens[1] = 1
    processor.output_tokens[2] = 2
    processor.output_tokens[2 + SPECULATE_MAX_BSZ] = 11
    processor.output_tokens[2 + SPECULATE_MAX_BSZ + 1] = 12
    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()
    assert connector.calls
    sent = connector.calls[0][1][0]
    assert sent.outputs.draft_token_ids == [11, 12]


def test_process_batch_output_logs_recovery_stop_for_non_speculative():
    processor, rm, _, _ = _make_processor()
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    metrics.inference_start_time = time.time()
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id="req-recovery",
        disaggregate_info=None,
        eos_token_ids=[1],
        metrics=metrics,
        output_token_ids=[],
        messages=None,
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
    )
    task.trace_carrier = None
    task.get = lambda k, d=None: getattr(task, k, d)
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.output_tokens[1, 0] = 1
    processor.output_tokens[2, 0] = token_processor.RECOVERY_STOP_SIGNAL
    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()
    assert rm.stop_flags[0] is True


def test_process_batch_output_sets_multimodal_token_counts():
    processor, rm, _, _ = _make_processor()
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    metrics.inference_start_time = time.time()
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id="req-mm",
        disaggregate_info=None,
        eos_token_ids=[7],
        metrics=metrics,
        output_token_ids=[],
        messages=None,
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        multimodal_inputs={"num_input_image_tokens": 4, "num_input_video_tokens": 5},
    )
    task.trace_carrier = None
    task.get = lambda key, default=None: getattr(task, key, default)
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.output_tokens[1, 0] = 1
    processor.output_tokens[2, 0] = 7
    with mock.patch.object(token_processor, "main_process_metrics", _Metrics()):
        processor._process_batch_output()
    sent = processor.cached_generated_tokens.put_results.call_args.args[0][0]
    assert sent.num_input_image_tokens == 4 and sent.num_input_video_tokens == 5


def test_warmup_token_processor_initialization():
    cfg = _DummyCfg()
    with mock.patch.object(token_processor.TokenProcessor, "__init__", lambda self, _cfg: None):
        warm = token_processor.WarmUpTokenProcessor(cfg)
    assert warm._is_running is True and warm._is_blocking is True
    warm.postprocess([])


def test_warmup_processor_stop_joins_worker():
    warm = token_processor.WarmUpTokenProcessor.__new__(token_processor.WarmUpTokenProcessor)
    warm._is_running = True
    worker = mock.Mock()
    warm.worker = worker
    warm.stop()
    worker.join.assert_called_once()


def test_healthy_behaviour_respects_timeout(monkeypatch):
    processor, _, _, _ = _make_processor()
    processor.timestamp_for_alive_before_handle_batch = time.time() - 1
    processor.timestamp_for_alive_after_handle_batch = None
    monkeypatch.setattr(envs, "FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT", 0.1)
    assert processor.healthy() is False


def test_healthy_detects_engine_hang():
    processor, _, _, _ = _make_processor()
    processor.timestamp_for_alive_before_handle_batch = None
    processor.timestamp_for_alive_after_handle_batch = time.time()
    processor.engine_output_token_hang = True
    assert processor.healthy() is False


def test_healthy_recent_prehandle_activity_is_ok(monkeypatch):
    processor, _, _, _ = _make_processor()
    processor.timestamp_for_alive_before_handle_batch = time.time()
    processor.timestamp_for_alive_after_handle_batch = None
    monkeypatch.setattr(envs, "FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT", 5)
    assert processor.healthy() is True


def test_record_completion_metrics_updates_counters():
    processor, _, _, _ = _make_processor()
    task_id = "req-complete"
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    metrics.inference_start_time = time.time() - 0.2
    metrics.engine_recv_first_token_time = time.time() - 0.1
    task = types.SimpleNamespace(request_id=task_id, metrics=metrics, user="user-a")
    processor.tokens_counter[task_id] = 4
    with (
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()) as metrics_obj,
        mock.patch.object(token_processor, "trace_print"),
    ):
        processor._record_completion_metrics(task, current_time=time.time())
    assert metrics_obj.request_decode_time.value is not None
    assert metrics_obj.request_success_total.value == 1
    assert metrics_obj.request_generation_tokens.value == 4


def test_process_sampling_results_use_zmq_rejects_speculative():
    processor, _, _, _ = _make_processor(speculative_method="mtp")
    with pytest.raises(NotImplementedError):
        processor.process_sampling_results_use_zmq()