"""
Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import copy
import time
import types
from collections import Counter
from unittest import mock

import numpy as np
import paddle
import pytest

from fastdeploy import envs
from fastdeploy.engine.request import Request, RequestMetrics, RequestOutput
from fastdeploy.output import token_processor
from fastdeploy.output.token_processor import (
    MAX_BSZ,
    MAX_DRAFT_TOKENS,
    SPECULATE_MAX_BSZ,
    K,
    TokenProcessor,
)
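
# Lightweight test doubles for the objects a TokenProcessor is constructed with
# and for the resource manager it drives during the tests below.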
class _DummyCfg:
def __init__(
self,
speculative_method=None,
enable_logprob=False,
max_num_seqs=2,
enable_prefix_caching=False,
enable_output_caching=False,
):
self.parallel_config = types.SimpleNamespace(
local_data_parallel_id=0,
enable_expert_parallel=False,
data_parallel_size=1,
)
self.speculative_config = types.SimpleNamespace(
method=speculative_method,
num_speculative_tokens=2,
)
self.model_config = types.SimpleNamespace(enable_logprob=enable_logprob)
self.scheduler_config = types.SimpleNamespace(name="default", splitwise_role="decode")
self.cache_config = types.SimpleNamespace(
enable_prefix_caching=enable_prefix_caching,
enable_output_caching=enable_output_caching,
block_size=64,
)
self.max_num_seqs = max_num_seqs
self.splitwise_version = "v1"
class _DummyResourceManager:
def __init__(self, max_num_seqs=2):
self.max_num_seqs = max_num_seqs
self.stop_flags = [False] * max_num_seqs
self.tasks_list = [None] * max_num_seqs
self.req_dict = {}
self.requests = {}
self.to_be_rescheduled_request_id_set = set()
self.recycled = []
self.cached_tasks = []
self.cleared = False
def _recycle_block_tables(self, task):
self.recycled.append(task.request_id)
def reschedule_preempt_task(self, request_id):
self.recycled.append(f"reschedule-{request_id}")
def finish_requests_async(self, request_id):
self.recycled.append(f"finish-{request_id}")
def total_block_number(self):
return 8
def available_batch(self):
return self.tasks_list.count(None)
def info(self):
return "rm-info"
def get_finished_req(self):
return []
def cache_output_tokens(self, task):
self.cached_tasks.append(task.request_id)
def clear_data(self):
self.cleared = True
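
# The queue double reports no finished requests; the connector double records
# every send_first_token call so tests can assert on what was sent.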
class _DummyQueue:
def get_finished_req(self):
return []
class _DummyConnector:
def __init__(self):
self.calls = []
def send_first_token(self, info, results):
self.calls.append((info, results))
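
# Keep every test on CPU so the paddle tensors allocated by TokenProcessor do not require a GPU.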
@pytest.fixture(autouse=True)
def _ensure_cpu():
paddle.device.set_device("cpu")
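
# Builds a TokenProcessor wired to the doubles above and returns it together with
# the resource manager, cache mock and connector it was constructed with.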
def _make_processor(
speculative_method=None,
enable_logprob=False,
max_num_seqs=2,
enable_prefix_caching=False,
enable_output_caching=False,
):
cfg = _DummyCfg(
speculative_method=speculative_method,
enable_logprob=enable_logprob,
max_num_seqs=max_num_seqs,
enable_prefix_caching=enable_prefix_caching,
enable_output_caching=enable_output_caching,
)
cache = mock.Mock()
queue = _DummyQueue()
connector = _DummyConnector()
processor = TokenProcessor(cfg, cache, queue, connector)
rm = _DummyResourceManager(max_num_seqs)
processor.set_resource_manager(rm)
return processor, rm, cache, connector
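
# Minimal metric stubs used when patching token_processor.main_process_metrics.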
class _Metric:
def __init__(self):
self.value = None
def set(self, v):
self.value = v
def inc(self, v=1):
self.value = (self.value or 0) + v
def dec(self, v=1):
self.value = (self.value or 0) - v
def observe(self, v):
self.value = v
class _Metrics:
def __init__(self):
self.spec_decode_num_accepted_tokens_total = _Metric()
self.spec_decode_num_emitted_tokens_total = _Metric()
self.spec_decode_draft_acceptance_rate = _Metric()
self.spec_decode_efficiency = _Metric()
self.spec_decode_num_draft_tokens_total = _Metric()
self.spec_decode_draft_single_head_acceptance_rate = [_Metric() for _ in range(MAX_DRAFT_TOKENS)]
self.time_per_output_token = _Metric()
self.generation_tokens_total = _Metric()
self.time_to_first_token = _Metric()
self.request_queue_time = _Metric()
self.request_prefill_time = _Metric()
self.request_decode_time = _Metric()
self.request_inference_time = _Metric()
self.request_generation_tokens = _Metric()
self.num_requests_running = _Metric()
self.request_success_total = _Metric()
self.available_gpu_block_num = _Metric()
self.batch_size = _Metric()
self.available_batch_size = _Metric()
def _init_speculative_metrics(self, method, num_speculative_tokens):
return None
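
# Unit tests: each case exercises one TokenProcessor code path using the doubles above.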
def test_init_allocates_expected_buffers():
processor, _, _, _ = _make_processor()
assert list(processor.output_tokens.shape) == [MAX_BSZ + 2, 1]
processor_logprob, _, _, _ = _make_processor(enable_logprob=True)
assert list(processor_logprob.output_scores.shape) == [MAX_BSZ * (K + 1), 1]
processor_spec, _, _, _ = _make_processor(speculative_method="mtp", enable_logprob=False)
assert processor_spec.output_tokens.shape[0] == SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2
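
# run() should start the ZMQ-based worker when FD_USE_GET_SAVE_OUTPUT_V1 is on,
# and refuse to start when a worker already exists.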
def test_run_uses_correct_worker_based_on_flag():
processor, _, _, _ = _make_processor()
processor.worker = None
with (
mock.patch.object(envs, "FD_USE_GET_SAVE_OUTPUT_V1", True),
mock.patch("fastdeploy.output.token_processor.threading.Thread") as thread_cls,
):
fake_thread = mock.Mock()
thread_cls.return_value = fake_thread
processor.run()
target = thread_cls.call_args.kwargs["target"]
assert target.__func__ is processor.process_sampling_results_use_zmq.__func__
assert fake_thread.daemon is True
processor.worker = object()
with pytest.raises(Exception):
processor.run()
def test_cleanup_resources_shuts_down_executor():
processor, _, _, _ = _make_processor()
processor.executor = mock.Mock()
processor._cleanup_resources()
processor.executor.shutdown.assert_called_once_with(wait=False)
def test_reschedule_preempt_task_use_zmq_reschedules_missing_batch():
processor, rm, _, _ = _make_processor()
rm.to_be_rescheduled_request_id_set = {"req-a"}
rm.requests = {"req-a": types.SimpleNamespace(idx=1)}
with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True):
processor._reschedule_preempt_task_use_zmq([types.SimpleNamespace(batch_id=0)])
assert "reschedule-req-a" in rm.recycled
def test_process_batch_draft_tokens_collects_top_logprobs():
processor, rm, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
rm.tasks_list[0] = types.SimpleNamespace(request_id="task-0", block_tables=[1])
ranks = paddle.to_tensor(np.array([[0, 1, 1]]))
scores = paddle.ones([1, 3, 1], dtype="float32")
tokens = paddle.arange(9, dtype="int64").reshape([1, 3, 3])
results = processor._process_batch_draft_tokens(
4, batch=1, accept_num=[2], tokens=tokens, scores=scores, ranks=ranks
)
assert len(results) == 1
assert results[0].outputs.draft_top_logprobs.logprob_token_ids[0][0] == 0
assert results[0].outputs.draft_top_logprobs.sampled_token_ranks[-1] == 1
def test_process_batch_output_use_zmq_finishes_on_eos():
processor, rm, cache, connector = _make_processor()
base_time = time.time()
task = Request(
request_id="req-zmq",
prompt=["hi"],
prompt_token_ids=[1, 2],
prompt_token_ids_len=2,
messages=[[{"content": "hi", "role": "user"}]],
history=[],
tools=[],
system="system",
eos_token_ids=[6],
metrics=RequestMetrics(
arrival_time=base_time,
preprocess_start_time=base_time - 0.2,
preprocess_end_time=base_time - 0.1,
inference_start_time=base_time,
),
)
task.metrics.decode_inference_start_time = base_time
task.disaggregate_info = None
task.ic_req_data = None
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
rm.requests[task.request_id] = types.SimpleNamespace(idx=0)
tokens = np.array([5, 6], dtype=np.int64)
stream = types.SimpleNamespace(batch_id=0, tokens=tokens, pooler_output=None)
with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
results = processor._process_batch_output_use_zmq([stream])
assert results[0].finished is True
assert task.output_token_ids == [5, 6]
assert rm.stop_flags[0] is True
assert connector.calls == []
def test_process_batch_output_use_zmq_parses_logprobs():
processor, rm, _, _ = _make_processor(enable_logprob=True)
base_time = time.time()
task = Request(
request_id="req-zmq-logprob",
prompt=["hi"],
prompt_token_ids=[1],
prompt_token_ids_len=1,
messages=[[{"content": "hi", "role": "user"}]],
history=[],
tools=[],
system="system",
eos_token_ids=[6],
metrics=RequestMetrics(
arrival_time=base_time,
preprocess_start_time=base_time - 0.2,
preprocess_end_time=base_time - 0.1,
inference_start_time=base_time,
),
)
task.metrics.decode_inference_start_time = base_time
task.disaggregate_info = None
task.ic_req_data = None
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
rm.requests[task.request_id] = types.SimpleNamespace(idx=0)
logprob_list = token_processor.LogprobsLists(
logprob_token_ids=[[1, 2]],
logprobs=[[0.1, 0.2]],
sampled_token_ranks=[0],
)
logprob_holder = types.SimpleNamespace(tolists=lambda: logprob_list)
stream = types.SimpleNamespace(
batch_id=0,
tokens=np.array([5], dtype=np.int64),
pooler_output=None,
logprobs=logprob_holder,
prompt_logprobs={"0": -0.1},
)
with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
results = processor._process_batch_output_use_zmq([stream])
assert results[0].outputs.logprob == 0.1
assert results[0].outputs.top_logprobs is logprob_list
assert results[0].prompt_logprobs == {"0": -0.1}
def test_recycle_resources_updates_metrics_and_state():
processor, rm, _, _ = _make_processor()
task = types.SimpleNamespace(request_id="req-1", block_tables=[1], disaggregate_info=None)
task.trace_carrier = None
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
metrics = RequestMetrics(
arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
)
result = RequestOutput(request_id=task.request_id, outputs=None, finished=False, metrics=metrics)
with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
processor._recycle_resources(task.request_id, 0, task, result, is_prefill=False)
assert rm.stop_flags[0] is True
assert task.request_id not in rm.req_dict
assert rm.recycled[-1] == task.request_id
assert processor.tokens_counter.get(task.request_id) is None
def test_compute_speculative_status_builds_metrics():
processor, rm, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
req_id = "req-spec"
rm.tasks_list[0] = types.SimpleNamespace(request_id=req_id, block_tables=[1])
metrics = RequestMetrics(
arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
)
result = RequestOutput(request_id=req_id, outputs=None, finished=False, metrics=metrics)
processor.total_step = 2
processor.number_of_output_tokens = 4
processor.speculative_stats_step = 0
processor.accept_token_num_per_head = [2, 1] + [0] * (MAX_DRAFT_TOKENS - 2)
processor.accept_token_num_per_head_per_request[req_id] = [2, 1]
processor.total_step_per_request[req_id] = 2
processor._compute_speculative_status(result)
assert hasattr(result.metrics, "speculate_metrics")
assert result.metrics.speculate_metrics.accepted_tokens == 3
def test_process_per_token_handles_recovery_stop_and_cleanup():
processor, rm, _, _ = _make_processor()
metrics = RequestMetrics(
arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
)
metrics.inference_start_time = time.time()
metrics.decode_inference_start_time = metrics.inference_start_time
task = types.SimpleNamespace(
request_id="req-recover",
prompt=["hi"],
prompt_token_ids=[1],
prompt_token_ids_len=1,
messages=[],
history=[],
tools=[],
system="sys",
eos_token_ids=[99],
metrics=metrics,
output_token_ids=[],
block_tables=[1],
)
task.trace_carrier = None
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
rm.requests[task.request_id] = types.SimpleNamespace(idx=0)
result = RequestOutput(
request_id=task.request_id,
outputs=types.SimpleNamespace(token_ids=[]),
finished=False,
metrics=copy.copy(task.metrics),
)
with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
stopped = processor._process_per_token(
task,
batch_id=0,
token_ids=np.array([token_processor.RECOVERY_STOP_SIGNAL]),
result=result,
is_prefill=False,
)
assert stopped.finished is True
assert "incomplete" in stopped.error_msg
assert rm.stop_flags[0] is True
assert rm.tasks_list[0] is None
assert processor.tokens_counter.get(task.request_id) is None
def test_postprocess_buffers_and_merges_speculative_results():
processor, _, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
processor.cached_generated_tokens = mock.Mock()
target_output = RequestOutput(
request_id="req-t",
outputs=types.SimpleNamespace(draft_top_logprobs=None),
finished=False,
metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
)
draft_output = RequestOutput(
request_id="req-t",
outputs=types.SimpleNamespace(draft_top_logprobs="draft-logprobs"),
finished=False,
metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
)
processor.postprocess([target_output], mtype=3)
assert processor._batch_result_buffer == [target_output]
processor.postprocess([draft_output], mtype=4)
processor.cached_generated_tokens.put_results.assert_called_once()
merged = processor.cached_generated_tokens.put_results.call_args.args[0][0]
assert merged.outputs.draft_top_logprobs == "draft-logprobs"
assert processor._batch_result_buffer is None
def test_postprocess_emits_finished_speculative_batch():
processor, _, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
processor.cached_generated_tokens = mock.Mock()
finished_output = RequestOutput(
request_id="req-finished",
outputs=types.SimpleNamespace(draft_top_logprobs=None),
finished=True,
metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
)
processor.postprocess([finished_output], mtype=3)
processor.cached_generated_tokens.put_results.assert_called_once_with([finished_output])
assert processor._batch_result_buffer is None
def test_postprocess_passes_through_unknown_type():
processor, _, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
processor.cached_generated_tokens = mock.Mock()
output = RequestOutput(
request_id="req-direct",
outputs=types.SimpleNamespace(draft_top_logprobs=None),
finished=False,
metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
)
processor.postprocess([output], mtype=99)
processor.cached_generated_tokens.put_results.assert_called_once_with([output])
def test_postprocess_logs_and_swallows_exception():
processor, _, _, _ = _make_processor()
processor.cached_generated_tokens = mock.Mock()
processor.cached_generated_tokens.put_results.side_effect = RuntimeError("boom")
output = RequestOutput(
request_id="req-error",
outputs=None,
finished=False,
metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
)
processor.postprocess([output])
processor.cached_generated_tokens.put_results.assert_called_once()
def test_record_speculative_decoding_metrics_tracks_acceptance():
processor, _, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
with mock.patch.object(token_processor, "main_process_metrics", _Metrics()):
processor.accept_token_num_per_head = [2, 3, 0, 0, 0, 0]
processor.num_draft_tokens = 0
processor.num_emitted_tokens = 0
processor.num_accepted_tokens = 0
processor._record_speculative_decoding_metrics(accept_num=[1, 2])
metrics = token_processor.main_process_metrics
assert metrics.spec_decode_num_accepted_tokens_total.value == 3
assert metrics.spec_decode_num_emitted_tokens_total.value == 5
assert pytest.approx(metrics.spec_decode_draft_acceptance_rate.value) == 0.75
assert pytest.approx(metrics.spec_decode_efficiency.value) == pytest.approx(5 / 6)
assert pytest.approx(metrics.spec_decode_draft_single_head_acceptance_rate[0].value) == 1.5
def test_recycle_resources_prefill_sends_first_token():
processor, rm, _, connector = _make_processor()
task_id = "req-prefill"
metrics = RequestMetrics(
arrival_time=time.time(),
preprocess_start_time=time.time(),
preprocess_end_time=time.time(),
inference_start_time=time.time(),
)
task = types.SimpleNamespace(
request_id=task_id,
metrics=metrics,
block_tables=[1],
disaggregate_info={"role": "prefill"},
eos_token_ids=[1],
)
task.trace_carrier = None
rm.tasks_list[0] = task
rm.req_dict[task_id] = task
result = RequestOutput(request_id=task_id, outputs=None, finished=False, metrics=metrics)
processor.engine_worker_queue = mock.Mock()
processor.engine_worker_queue.get_finished_req.side_effect = [[(task_id, "finished")]]
with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
processor._recycle_resources(task_id, 0, task, result, is_prefill=True)
assert rm.stop_flags[0] is True
assert connector.calls and connector.calls[0][1][0] is result
def test_recycle_resources_prefill_failure_sets_error():
processor, rm, _, connector = _make_processor()
task_id = "req-prefill-failed"
metrics = RequestMetrics(
arrival_time=time.time(),
preprocess_start_time=time.time(),
preprocess_end_time=time.time(),
inference_start_time=time.time(),
)
task = types.SimpleNamespace(
request_id=task_id,
metrics=metrics,
block_tables=[1],
disaggregate_info={"role": "prefill"},
eos_token_ids=[1],
)
task.trace_carrier = None
rm.tasks_list[0] = task
rm.req_dict[task_id] = task
result = RequestOutput(request_id=task_id, outputs=None, finished=False, metrics=metrics)
processor.engine_worker_queue = mock.Mock()
processor.engine_worker_queue.get_finished_req.side_effect = [[(task_id, "failed")]]
with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
processor._recycle_resources(task_id, 0, task, result, is_prefill=True)
assert result.error_code == 400
assert "failed" in result.error_message
assert connector.calls and connector.calls[0][1][0] is result
def test_clear_data_marks_all_tasks_finished():
processor, rm, _, _ = _make_processor()
metrics = RequestMetrics(
arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
)
task_a = types.SimpleNamespace(
request_id="req-a",
eos_token_ids=[0],
metrics=metrics,
disaggregate_info=None,
block_tables=[1],
arrival_time=time.time(),
)
task_b = types.SimpleNamespace(
request_id="req-b",
eos_token_ids=[0],
metrics=metrics,
disaggregate_info=None,
block_tables=[2],
arrival_time=time.time(),
)
rm.tasks_list[0] = task_a
rm.tasks_list[1] = task_b
rm.req_dict = {"req-a": task_a, "req-b": task_b}
processor.tokens_counter = Counter({"req-a": 2, "req-b": 1})
with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
processor.clear_data()
assert rm.tasks_list == [None, None]
assert not processor.tokens_counter
assert task_a.request_id in rm.recycled and task_b.request_id in rm.recycled
def test_record_speculative_decoding_accept_num_per_request_updates_maps():
processor, _, _, _ = _make_processor(speculative_method="mtp")
processor._record_speculative_decoding_accept_num_per_request("req-acc", 3)
assert processor.total_step_per_request["req-acc"] == 1
assert processor.accept_token_num_per_head_per_request["req-acc"][0] == 1
assert processor.accept_token_num_per_head[2] == 1
def test_reschedule_preempt_task_triggers_for_pending_requests():
processor, rm, _, _ = _make_processor()
rm.to_be_rescheduled_request_id_set = {"req-z"}
rm.requests = {"req-z": types.SimpleNamespace(idx=2)}
with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True):
processor._reschedule_preempt_task(batch_size=1)
assert "reschedule-req-z" in rm.recycled
def test_process_batch_output_consumes_tokens_and_finishes_task():
processor, rm, _, _ = _make_processor()
metrics = RequestMetrics(
arrival_time=time.time(),
preprocess_start_time=time.time(),
preprocess_end_time=time.time(),
inference_start_time=time.time(),
)
metrics.decode_inference_start_time = metrics.inference_start_time
task = types.SimpleNamespace(
request_id="req-out",
disaggregate_info=None,
eos_token_ids=[7],
metrics=metrics,
output_token_ids=[],
messages=[{"role": "user", "content": "hi"}],
num_cached_tokens=0,
ic_req_data=None,
prompt_token_ids_len=0,
num_total_tokens=1,
block_tables=[1],
)
task.trace_carrier = None
task.get = lambda key, default=None: getattr(task, key, default)
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
processor.output_tokens[1, 0] = 1
processor.output_tokens[2, 0] = 7
with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
processor._process_batch_output()
assert rm.stop_flags[0] is True
assert task.output_token_ids == [7]
def test_process_batch_output_logprob_records_topk_and_caching():
processor, rm, _, _ = _make_processor(enable_logprob=True, enable_prefix_caching=True, enable_output_caching=True)
metrics = RequestMetrics(
arrival_time=time.time(),
preprocess_start_time=time.time(),
preprocess_end_time=time.time(),
inference_start_time=time.time(),
)
metrics.decode_inference_start_time = metrics.inference_start_time
task = types.SimpleNamespace(
request_id="req-logprob",
disaggregate_info=None,
eos_token_ids=[3],
metrics=metrics,
output_token_ids=[],
messages=[],
num_cached_tokens=0,
ic_req_data=None,
prompt_token_ids_len=0,
num_total_tokens=1,
block_tables=[1],
get=lambda key, default=None: None,
)
task.trace_carrier = None
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
processor.output_tokens[1, 0] = 1
token_block = np.arange(K + 1, dtype=np.int64) + 3
processor.output_tokens[2 : 2 + K + 1] = paddle.to_tensor(token_block.reshape([-1, 1]))
processor.output_scores[: K + 1] = paddle.ones([K + 1, 1], dtype="float32")
processor.output_ranks[0] = paddle.to_tensor(0, dtype="int64")
processor.cached_generated_tokens.put_results = mock.Mock()
with (
mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True),
mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
):
processor._process_batch_output()
assert rm.cached_tasks[-1] == "req-logprob"
sent = processor.cached_generated_tokens.put_results.call_args.args[0][0]
assert sent.outputs.top_logprobs is not None
def test_process_batch_output_speculative_logprob_handles_draft_batch():
processor, rm, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
rm.tasks_list[0] = types.SimpleNamespace(request_id="req-draft", block_tables=[1], disaggregate_info=None)
target = RequestOutput(
request_id="req-draft",
outputs=types.SimpleNamespace(draft_top_logprobs=None),
finished=False,
metrics=None,
)
processor._batch_result_buffer = [target]
processor.cached_generated_tokens = mock.Mock()
processor.output_tokens[1, 0] = 4
processor.output_tokens[2, 0] = 1
processor.output_tokens[3, 0] = 1
draft_tokens = np.arange(MAX_DRAFT_TOKENS * (K + 1), dtype=np.int64).reshape([-1, 1]) + 5
processor.output_tokens[3 + MAX_BSZ : 3 + MAX_BSZ + len(draft_tokens)] = paddle.to_tensor(draft_tokens)
processor.output_scores[: MAX_DRAFT_TOKENS * (K + 1)] = paddle.ones(
[MAX_DRAFT_TOKENS * (K + 1), 1], dtype="float32"
)
processor.output_ranks[:MAX_DRAFT_TOKENS] = paddle.arange(MAX_DRAFT_TOKENS, dtype="int64")
with (
mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
):
processor._process_batch_output()
sent_batch = processor.cached_generated_tokens.put_results.call_args.args[0]
assert sent_batch and sent_batch[0].outputs.draft_top_logprobs is not None
def test_process_batch_output_speculative_recovery_stop_finishes():
processor, rm, _, _ = _make_processor(speculative_method="mtp")
metrics = RequestMetrics(
arrival_time=time.time(),
preprocess_start_time=time.time(),
preprocess_end_time=time.time(),
inference_start_time=time.time(),
)
metrics.decode_inference_start_time = metrics.inference_start_time
task = types.SimpleNamespace(
request_id="req-recover-spec",
disaggregate_info=None,
eos_token_ids=[2],
metrics=metrics,
output_token_ids=[],
messages=[],
num_cached_tokens=0,
ic_req_data=None,
prompt_token_ids_len=0,
num_total_tokens=1,
block_tables=[1],
get=lambda key, default=None: None,
)
task.trace_carrier = None
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
processor.output_tokens[1] = 1
processor.output_tokens[2] = -3
processor.number_of_output_tokens = 1
processor.total_step = 1
processor.accept_token_num_per_head_per_request[task.request_id] = [1] + [0] * (MAX_DRAFT_TOKENS - 1)
processor.total_step_per_request[task.request_id] = 1
processor.cached_generated_tokens.put_results = mock.Mock()
with (
mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
):
processor._process_batch_output()
assert rm.stop_flags[0] is True
sent = processor.cached_generated_tokens.put_results.call_args.args[0][0]
assert sent.finished is True
assert "incomplete" in sent.error_msg
def test_process_batch_output_prefill_chunk_and_adapter_skip():
processor, rm, _, _ = _make_processor(enable_logprob=True)
metrics = RequestMetrics(
arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
)
metrics.inference_start_time = time.time()
metrics.decode_inference_start_time = metrics.inference_start_time
processor.cfg.scheduler_config.splitwise_role = "prefill"
task = types.SimpleNamespace(
request_id="req-prefill-chunk",
disaggregate_info={"role": "prefill"},
eos_token_ids=[1],
metrics=metrics,
output_token_ids=[],
messages=[{"role": "user", "content": "hi"}],
num_cached_tokens=0,
ic_req_data=None,
prompt_token_ids_len=1,
num_total_tokens=2,
block_tables=[1],
prefill_chunk_info=[{"idx": 0}, {"idx": 1}],
)
task.trace_carrier = None
task.get = lambda key, default=None: getattr(task, key, default)
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
processor.output_tokens[1, 0] = 1
processor.output_tokens[2 : 2 + K + 1] = paddle.to_tensor(np.ones([K + 1, 1], dtype=np.int64))
processor.output_scores[: K + 1] = paddle.ones([K + 1, 1], dtype="float32")
processor.output_ranks[0] = paddle.to_tensor(0, dtype="int64")
processor.cached_generated_tokens.put_results = mock.Mock()
with (
mock.patch.object(envs, "FD_ENABLE_INTERNAL_ADAPTER", True),
mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
):
processor._process_batch_output()
assert getattr(task, "prefill_chunk_num") == 1
assert processor.cached_generated_tokens.put_results.call_args.args[0] == []
def test_process_batch_output_handles_multimodal_and_negative_token():
processor, rm, _, _ = _make_processor()
metrics = RequestMetrics(
arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
)
metrics.inference_start_time = time.time()
metrics.decode_inference_start_time = metrics.inference_start_time
processor.cfg.scheduler_config.splitwise_role = "prefill"
task = types.SimpleNamespace(
request_id="req-negative",
disaggregate_info=None,
eos_token_ids=[5],
metrics=metrics,
output_token_ids=[],
messages=None,
num_cached_tokens=0,
ic_req_data=None,
prompt_token_ids_len=0,
num_total_tokens=1,
block_tables=[1],
prefill_chunk_info=None,
multimodal_inputs={"num_input_image_tokens": 2, "num_input_video_tokens": 3},
get=lambda key, default=None: None,
)
task.trace_carrier = None
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
rm.to_be_rescheduled_request_id_set = {task.request_id}
rm.requests = {task.request_id: types.SimpleNamespace(idx=0)}
processor.output_tokens[1, 0] = 1
processor.output_tokens[2, 0] = -1
with (
mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True),
mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
):
processor._process_batch_output()
assert rm.recycled[-1] == f"reschedule-{task.request_id}"
def test_process_batch_output_speculative_logprob_targets_topk_scores():
processor, rm, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
metrics = RequestMetrics(
arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
)
metrics.inference_start_time = time.time()
metrics.decode_inference_start_time = metrics.inference_start_time
task = types.SimpleNamespace(
request_id="req-spec-logprob",
disaggregate_info=None,
eos_token_ids=[9],
metrics=metrics,
output_token_ids=[],
messages=None,
num_cached_tokens=0,
ic_req_data=None,
prompt_token_ids_len=0,
num_total_tokens=1,
block_tables=[1],
get=lambda key, default=None: None,
)
task.trace_carrier = None
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
processor.output_tokens[1, 0] = 3
processor.output_tokens[2, 0] = 1
processor.output_tokens[3, 0] = 2
token_block = np.arange(MAX_DRAFT_TOKENS * (K + 1), dtype=np.int64).reshape([-1, 1]) + 3
processor.output_tokens[3 + MAX_BSZ : 3 + MAX_BSZ + len(token_block)] = paddle.to_tensor(token_block)
score_block = paddle.arange(MAX_DRAFT_TOKENS * (K + 1), dtype="float32").reshape([-1, 1])
processor.output_scores[: MAX_DRAFT_TOKENS * (K + 1)] = score_block
processor.output_ranks[:MAX_DRAFT_TOKENS] = paddle.arange(MAX_DRAFT_TOKENS, dtype="int64")
processor.cached_generated_tokens.put_results = mock.Mock()
with (
mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
):
processor._process_batch_output()
assert processor.tokens_counter[task.request_id] == 2
def test_record_metrics_and_speculative_ngram_metrics():
processor, _, _, _ = _make_processor(speculative_method="ngram", enable_logprob=True)
metrics = _Metrics()
task = types.SimpleNamespace(
request_id="req-metrics",
metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
last_token_time=time.time(),
)
with mock.patch.object(token_processor, "main_process_metrics", metrics):
processor._record_metrics(task, current_time=time.time(), token_ids=[1, 2])
processor.accept_token_num_per_head = [0, 2] + [0] * (MAX_DRAFT_TOKENS - 2)
processor.num_accepted_tokens = 3
processor.num_emitted_tokens = 3
processor._record_speculative_decoding_metrics(accept_num=[1, 1])
assert metrics.generation_tokens_total.value == 2
assert metrics.spec_decode_draft_acceptance_rate.value == 1
def test_clear_data_invokes_scheduler_cleanup():
processor, rm, _, _ = _make_processor()
metrics = RequestMetrics(
arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
)
task = types.SimpleNamespace(
request_id="req-clear",
arrival_time=time.time(),
disaggregate_info=None,
eos_token_ids=[0],
metrics=metrics,
output_token_ids=[],
messages=None,
num_cached_tokens=0,
ic_req_data=None,
prompt_token_ids_len=0,
num_total_tokens=1,
block_tables=[1],
get=lambda key, default=None: getattr(task, key, default),
)
task.trace_carrier = None
rm.tasks_list[0] = task
rm.stop_flags = [True] * rm.max_num_seqs
processor.tokens_counter[task.request_id] = 0
with (
mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True),
mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
):
processor.clear_data()
assert rm.cleared is True
def test_process_batch_output_skips_already_stopped_slot():
processor, rm, _, _ = _make_processor()
rm.stop_flags[0] = True
processor.output_tokens[1, 0] = 1
processor.output_tokens[2, 0] = 5
with (
mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
):
processor._process_batch_output()
assert processor.cached_generated_tokens.put_results.called
def test_process_batch_output_speculative_negative_token_reschedules():
processor, rm, _, _ = _make_processor(speculative_method="mtp")
task_id = "req-spec-neg"
metrics = RequestMetrics(
arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
)
metrics.inference_start_time = time.time()
metrics.decode_inference_start_time = metrics.inference_start_time
task = types.SimpleNamespace(
request_id=task_id,
disaggregate_info=None,
eos_token_ids=[1],
metrics=metrics,
output_token_ids=[],
messages=None,
num_cached_tokens=0,
ic_req_data=None,
prompt_token_ids_len=0,
num_total_tokens=1,
block_tables=[1],
get=lambda key, default=None: None,
)
task.trace_carrier = None
rm.tasks_list[0] = task
rm.req_dict[task_id] = task
rm.to_be_rescheduled_request_id_set = {task_id}
rm.requests = {task_id: types.SimpleNamespace(idx=0)}
processor.output_tokens[1] = 1
processor.output_tokens[2] = 1
processor.output_tokens[3] = 1
processor.output_tokens[2 + SPECULATE_MAX_BSZ] = -1
with (
mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True),
mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
):
processor._process_batch_output()
assert rm.recycled[-1] == f"reschedule-{task_id}"
def test_process_batch_output_use_zmq_reschedules_negative_token():
processor, rm, _, _ = _make_processor()
task = types.SimpleNamespace(request_id="req-zmq-neg")
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
rm.to_be_rescheduled_request_id_set = {task.request_id}
stream = types.SimpleNamespace(batch_id=0, tokens=np.array([-1], dtype=np.int64), pooler_output=None)
with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True):
results = processor._process_batch_output_use_zmq([stream])
assert results == []
assert rm.recycled[-1] == f"reschedule-{task.request_id}"
def test_process_batch_output_records_second_decode_token():
processor, rm, _, _ = _make_processor()
processor.cfg.scheduler_config.splitwise_role = "decode"
task = types.SimpleNamespace(
request_id="req-second",
disaggregate_info=None,
eos_token_ids=[2],
metrics=RequestMetrics(
arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
),
output_token_ids=[],
messages=None,
num_cached_tokens=0,
ic_req_data=None,
prompt_token_ids_len=0,
num_total_tokens=1,
block_tables=[1],
get=lambda key, default=None: None,
)
task.trace_carrier = None
task.metrics.inference_start_time = time.time()
task.metrics.decode_inference_start_time = task.metrics.inference_start_time
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
processor.tokens_counter[task.request_id] = 1
processor.output_tokens[1, 0] = 1
processor.output_tokens[2, 0] = 2
with (
mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
):
processor._process_batch_output()
assert task.metrics.decode_recv_second_token_time is not None
def test_record_speculative_metrics_calls_init_when_missing():
processor, _, _, _ = _make_processor(speculative_method="mtp")
class _MinimalMetrics:
def __init__(self):
self.init_called = False
def _init_speculative_metrics(self, method, num_speculative_tokens):
self.spec_decode_num_accepted_tokens_total = _Metric()
self.spec_decode_num_emitted_tokens_total = _Metric()
self.spec_decode_draft_acceptance_rate = _Metric()
self.spec_decode_efficiency = _Metric()
self.spec_decode_num_draft_tokens_total = _Metric()
self.spec_decode_draft_single_head_acceptance_rate = [_Metric() for _ in range(MAX_DRAFT_TOKENS)]
self.init_called = True
processor.accept_token_num_per_head = [1, 1] + [0] * (MAX_DRAFT_TOKENS - 2)
processor.num_accepted_tokens = 2
processor.num_emitted_tokens = 2
metrics = _MinimalMetrics()
with mock.patch.object(token_processor, "main_process_metrics", metrics):
processor._record_speculative_decoding_metrics(accept_num=[1])
assert metrics.init_called is True
def test_process_batch_output_prefill_sets_draft_tokens():
processor, rm, _, connector = _make_processor(speculative_method="mtp")
processor.cfg.scheduler_config.splitwise_role = "prefill"
metrics = RequestMetrics(
arrival_time=time.time(),
preprocess_start_time=time.time(),
preprocess_end_time=time.time(),
inference_start_time=time.time(),
)
metrics.decode_inference_start_time = metrics.inference_start_time
task = types.SimpleNamespace(
request_id="req-prefill-draft",
disaggregate_info={"role": "prefill"},
eos_token_ids=[99],
metrics=metrics,
output_token_ids=[],
messages=None,
num_cached_tokens=0,
ic_req_data=None,
prompt_token_ids_len=0,
num_total_tokens=1,
block_tables=[1],
get=lambda key, default=None: None,
)
task.trace_carrier = None
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
processor.engine_worker_queue = mock.Mock()
processor.engine_worker_queue.get_finished_req.side_effect = [[(task.request_id, "finished")]]
processor.output_tokens[1] = 1
processor.output_tokens[2] = 2
processor.output_tokens[2 + SPECULATE_MAX_BSZ] = 11
processor.output_tokens[2 + SPECULATE_MAX_BSZ + 1] = 12
with (
mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
):
processor._process_batch_output()
assert connector.calls
sent = connector.calls[0][1][0]
assert sent.outputs.draft_token_ids == [11, 12]
def test_process_batch_output_logs_recovery_stop_for_non_speculative():
processor, rm, _, _ = _make_processor()
metrics = RequestMetrics(
arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
)
metrics.inference_start_time = time.time()
metrics.decode_inference_start_time = metrics.inference_start_time
task = types.SimpleNamespace(
request_id="req-recovery",
disaggregate_info=None,
eos_token_ids=[1],
metrics=metrics,
output_token_ids=[],
messages=None,
num_cached_tokens=0,
ic_req_data=None,
prompt_token_ids_len=0,
num_total_tokens=1,
block_tables=[1],
)
task.trace_carrier = None
task.get = lambda k, d=None: getattr(task, k, d)
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
processor.output_tokens[1, 0] = 1
processor.output_tokens[2, 0] = token_processor.RECOVERY_STOP_SIGNAL
with (
mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
):
processor._process_batch_output()
assert rm.stop_flags[0] is True
def test_process_batch_output_sets_multimodal_token_counts():
processor, rm, _, _ = _make_processor()
metrics = RequestMetrics(
arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
)
metrics.inference_start_time = time.time()
metrics.decode_inference_start_time = metrics.inference_start_time
task = types.SimpleNamespace(
request_id="req-mm",
disaggregate_info=None,
eos_token_ids=[7],
metrics=metrics,
output_token_ids=[],
messages=None,
num_cached_tokens=0,
ic_req_data=None,
prompt_token_ids_len=0,
num_total_tokens=1,
block_tables=[1],
multimodal_inputs={"num_input_image_tokens": 4, "num_input_video_tokens": 5},
)
task.trace_carrier = None
task.get = lambda key, default=None: getattr(task, key, default)
rm.tasks_list[0] = task
rm.req_dict[task.request_id] = task
processor.output_tokens[1, 0] = 1
processor.output_tokens[2, 0] = 7
with mock.patch.object(token_processor, "main_process_metrics", _Metrics()):
processor._process_batch_output()
sent = processor.cached_generated_tokens.put_results.call_args.args[0][0]
assert sent.num_input_image_tokens == 4 and sent.num_input_video_tokens == 5
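
# WarmUpTokenProcessor construction and shutdown.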
def test_warmup_token_processor_initialization():
cfg = _DummyCfg()
with mock.patch.object(token_processor.TokenProcessor, "__init__", lambda self, _cfg: None):
warm = token_processor.WarmUpTokenProcessor(cfg)
assert warm._is_running is True and warm._is_blocking is True
warm.postprocess([])
def test_warmup_processor_stop_joins_worker():
warm = token_processor.WarmUpTokenProcessor.__new__(token_processor.WarmUpTokenProcessor)
warm._is_running = True
worker = mock.Mock()
warm.worker = worker
warm.stop()
worker.join.assert_called_once()
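
# Health-check behaviour of TokenProcessor.healthy().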
def test_healthy_behaviour_respects_timeout(monkeypatch):
processor, _, _, _ = _make_processor()
processor.timestamp_for_alive_before_handle_batch = time.time() - 1
processor.timestamp_for_alive_after_handle_batch = None
monkeypatch.setattr(envs, "FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT", 0.1)
assert processor.healthy() is False
def test_healthy_detects_engine_hang():
processor, _, _, _ = _make_processor()
processor.timestamp_for_alive_before_handle_batch = None
processor.timestamp_for_alive_after_handle_batch = time.time()
processor.engine_output_token_hang = True
assert processor.healthy() is False
def test_healthy_recent_prehandle_activity_is_ok(monkeypatch):
processor, _, _, _ = _make_processor()
processor.timestamp_for_alive_before_handle_batch = time.time()
processor.timestamp_for_alive_after_handle_batch = None
monkeypatch.setattr(envs, "FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT", 5)
assert processor.healthy() is True
def test_record_completion_metrics_updates_counters():
processor, _, _, _ = _make_processor()
task_id = "req-complete"
metrics = RequestMetrics(
arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
)
metrics.inference_start_time = time.time() - 0.2
metrics.engine_recv_first_token_time = time.time() - 0.1
task = types.SimpleNamespace(request_id=task_id, metrics=metrics, user="user-a")
processor.tokens_counter[task_id] = 4
with (
mock.patch.object(token_processor, "main_process_metrics", _Metrics()) as metrics_obj,
mock.patch.object(token_processor, "trace_print"),
):
processor._record_completion_metrics(task, current_time=time.time())
assert metrics_obj.request_decode_time.value is not None
assert metrics_obj.request_success_total.value == 1
assert metrics_obj.request_generation_tokens.value == 4
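
# The ZMQ sampling-results path is not supported together with speculative decoding.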
def test_process_sampling_results_use_zmq_rejects_speculative():
processor, _, _, _ = _make_processor(speculative_method="mtp")
with pytest.raises(NotImplementedError):
processor.process_sampling_results_use_zmq()