"""
|
|
Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import copy
|
|
import time
|
|
import types
|
|
from collections import Counter
|
|
from unittest import mock
|
|
|
|
import numpy as np
|
|
import paddle
|
|
import pytest
|
|
|
|
from fastdeploy import envs
|
|
from fastdeploy.engine.request import Request, RequestMetrics, RequestOutput
|
|
from fastdeploy.output import token_processor
|
|
from fastdeploy.output.token_processor import (
|
|
MAX_BSZ,
|
|
MAX_DRAFT_TOKENS,
|
|
SPECULATE_MAX_BSZ,
|
|
K,
|
|
TokenProcessor,
|
|
)
|
|
|
|
|
|
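# Minimal config double: exposes only the attributes TokenProcessor reads in these tests.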
class _DummyCfg:
    def __init__(
        self,
        speculative_method=None,
        enable_logprob=False,
        max_num_seqs=2,
        enable_prefix_caching=False,
        enable_output_caching=False,
    ):
        self.parallel_config = types.SimpleNamespace(
            local_data_parallel_id=0,
            enable_expert_parallel=False,
            data_parallel_size=1,
        )
        self.speculative_config = types.SimpleNamespace(
            method=speculative_method,
            num_speculative_tokens=2,
        )
        self.model_config = types.SimpleNamespace(enable_logprob=enable_logprob)
        self.scheduler_config = types.SimpleNamespace(name="default", splitwise_role="decode")
        self.cache_config = types.SimpleNamespace(
            enable_prefix_caching=enable_prefix_caching,
            enable_output_caching=enable_output_caching,
            block_size=64,
        )
        self.max_num_seqs = max_num_seqs
        self.splitwise_version = "v1"

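# Resource-manager double that records recycle/reschedule/finish calls so tests can assert on them.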
class _DummyResourceManager:
    def __init__(self, max_num_seqs=2):
        self.max_num_seqs = max_num_seqs
        self.stop_flags = [False] * max_num_seqs
        self.tasks_list = [None] * max_num_seqs
        self.req_dict = {}
        self.requests = {}
        self.to_be_rescheduled_request_id_set = set()
        self.recycled = []
        self.cached_tasks = []
        self.cleared = False

    def _recycle_block_tables(self, task):
        self.recycled.append(task.request_id)

    def reschedule_preempt_task(self, request_id):
        self.recycled.append(f"reschedule-{request_id}")

    def finish_requests_async(self, request_id):
        self.recycled.append(f"finish-{request_id}")

    def total_block_number(self):
        return 8

    def available_batch(self):
        return self.tasks_list.count(None)

    def info(self):
        return "rm-info"

    def get_finished_req(self):
        return []

    def cache_output_tokens(self, task):
        self.cached_tasks.append(task.request_id)

    def clear_data(self):
        self.cleared = True

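# Queue and connector doubles; the connector keeps every send_first_token call for later checks.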
class _DummyQueue:
    def get_finished_req(self):
        return []


class _DummyConnector:
    def __init__(self):
        self.calls = []

    def send_first_token(self, info, results):
        self.calls.append((info, results))


@pytest.fixture(autouse=True)
def _ensure_cpu():
    paddle.device.set_device("cpu")

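# Builds a TokenProcessor wired to the doubles above and returns (processor, resource manager, cache, connector).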
def _make_processor(
    speculative_method=None,
    enable_logprob=False,
    max_num_seqs=2,
    enable_prefix_caching=False,
    enable_output_caching=False,
):
    cfg = _DummyCfg(
        speculative_method=speculative_method,
        enable_logprob=enable_logprob,
        max_num_seqs=max_num_seqs,
        enable_prefix_caching=enable_prefix_caching,
        enable_output_caching=enable_output_caching,
    )
    cache = mock.Mock()
    queue = _DummyQueue()
    connector = _DummyConnector()
    processor = TokenProcessor(cfg, cache, queue, connector)
    rm = _DummyResourceManager(max_num_seqs)
    processor.set_resource_manager(rm)
    return processor, rm, cache, connector

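# Lightweight metric stubs standing in for main_process_metrics so tests can read recorded values.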
class _Metric:
    def __init__(self):
        self.value = None

    def set(self, v):
        self.value = v

    def inc(self, v=1):
        self.value = (self.value or 0) + v

    def dec(self, v=1):
        self.value = (self.value or 0) - v

    def observe(self, v):
        self.value = v


class _Metrics:
    def __init__(self):
        self.spec_decode_num_accepted_tokens_total = _Metric()
        self.spec_decode_num_emitted_tokens_total = _Metric()
        self.spec_decode_draft_acceptance_rate = _Metric()
        self.spec_decode_efficiency = _Metric()
        self.spec_decode_num_draft_tokens_total = _Metric()
        self.spec_decode_draft_single_head_acceptance_rate = [_Metric() for _ in range(MAX_DRAFT_TOKENS)]
        self.time_per_output_token = _Metric()
        self.generation_tokens_total = _Metric()
        self.time_to_first_token = _Metric()
        self.request_queue_time = _Metric()
        self.request_prefill_time = _Metric()
        self.request_decode_time = _Metric()
        self.request_inference_time = _Metric()
        self.request_generation_tokens = _Metric()
        self.num_requests_running = _Metric()
        self.request_success_total = _Metric()
        self.available_gpu_block_num = _Metric()
        self.batch_size = _Metric()
        self.available_batch_size = _Metric()

    def _init_speculative_metrics(self, method, num_speculative_tokens):
        return None

def test_init_allocates_expected_buffers():
    processor, _, _, _ = _make_processor()
    assert list(processor.output_tokens.shape) == [MAX_BSZ + 2, 1]

    processor_logprob, _, _, _ = _make_processor(enable_logprob=True)
    assert list(processor_logprob.output_scores.shape) == [MAX_BSZ * (K + 1), 1]

    processor_spec, _, _, _ = _make_processor(speculative_method="mtp", enable_logprob=False)
    assert processor_spec.output_tokens.shape[0] == SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2


def test_run_uses_correct_worker_based_on_flag():
    processor, _, _, _ = _make_processor()
    processor.worker = None
    with (
        mock.patch.object(envs, "FD_USE_GET_SAVE_OUTPUT_V1", True),
        mock.patch("fastdeploy.output.token_processor.threading.Thread") as thread_cls,
    ):
        fake_thread = mock.Mock()
        thread_cls.return_value = fake_thread
        processor.run()
        target = thread_cls.call_args.kwargs["target"]
        assert target.__func__ is processor.process_sampling_results_use_zmq.__func__
        assert fake_thread.daemon is True

    processor.worker = object()
    with pytest.raises(Exception):
        processor.run()


def test_cleanup_resources_shuts_down_executor():
    processor, _, _, _ = _make_processor()
    processor.executor = mock.Mock()
    processor._cleanup_resources()
    processor.executor.shutdown.assert_called_once_with(wait=False)


def test_reschedule_preempt_task_use_zmq_reschedules_missing_batch():
    processor, rm, _, _ = _make_processor()
    rm.to_be_rescheduled_request_id_set = {"req-a"}
    rm.requests = {"req-a": types.SimpleNamespace(idx=1)}
    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True):
        processor._reschedule_preempt_task_use_zmq([types.SimpleNamespace(batch_id=0)])
    assert "reschedule-req-a" in rm.recycled

def test_process_batch_draft_tokens_collects_top_logprobs():
    processor, rm, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
    rm.tasks_list[0] = types.SimpleNamespace(request_id="task-0", block_tables=[1])
    ranks = paddle.to_tensor(np.array([[0, 1, 1]]))
    scores = paddle.ones([1, 3, 1], dtype="float32")
    tokens = paddle.arange(9, dtype="int64").reshape([1, 3, 3])

    results = processor._process_batch_draft_tokens(
        4, batch=1, accept_num=[2], tokens=tokens, scores=scores, ranks=ranks
    )

    assert len(results) == 1
    assert results[0].outputs.draft_top_logprobs.logprob_token_ids[0][0] == 0
    assert results[0].outputs.draft_top_logprobs.sampled_token_ranks[-1] == 1


def test_process_batch_output_use_zmq_finishes_on_eos():
    processor, rm, cache, connector = _make_processor()
    base_time = time.time()
    task = Request(
        request_id="req-zmq",
        prompt=["hi"],
        prompt_token_ids=[1, 2],
        prompt_token_ids_len=2,
        messages=[[{"content": "hi", "role": "user"}]],
        history=[],
        tools=[],
        system="system",
        eos_token_ids=[6],
        metrics=RequestMetrics(
            arrival_time=base_time,
            preprocess_start_time=base_time - 0.2,
            preprocess_end_time=base_time - 0.1,
            inference_start_time=base_time,
        ),
    )
    task.metrics.decode_inference_start_time = base_time
    task.disaggregate_info = None
    task.ic_req_data = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    rm.requests[task.request_id] = types.SimpleNamespace(idx=0)

    tokens = np.array([5, 6], dtype=np.int64)
    stream = types.SimpleNamespace(batch_id=0, tokens=tokens, pooler_output=None)
    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
        results = processor._process_batch_output_use_zmq([stream])

    assert results[0].finished is True
    assert task.output_token_ids == [5, 6]
    assert rm.stop_flags[0] is True
    assert connector.calls == []

def test_process_batch_output_use_zmq_parses_logprobs():
    processor, rm, _, _ = _make_processor(enable_logprob=True)
    base_time = time.time()
    task = Request(
        request_id="req-zmq-logprob",
        prompt=["hi"],
        prompt_token_ids=[1],
        prompt_token_ids_len=1,
        messages=[[{"content": "hi", "role": "user"}]],
        history=[],
        tools=[],
        system="system",
        eos_token_ids=[6],
        metrics=RequestMetrics(
            arrival_time=base_time,
            preprocess_start_time=base_time - 0.2,
            preprocess_end_time=base_time - 0.1,
            inference_start_time=base_time,
        ),
    )
    task.metrics.decode_inference_start_time = base_time
    task.disaggregate_info = None
    task.ic_req_data = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    rm.requests[task.request_id] = types.SimpleNamespace(idx=0)

    logprob_list = token_processor.LogprobsLists(
        logprob_token_ids=[[1, 2]],
        logprobs=[[0.1, 0.2]],
        sampled_token_ranks=[0],
    )
    logprob_holder = types.SimpleNamespace(tolists=lambda: logprob_list)
    stream = types.SimpleNamespace(
        batch_id=0,
        tokens=np.array([5], dtype=np.int64),
        pooler_output=None,
        logprobs=logprob_holder,
        prompt_logprobs={"0": -0.1},
    )
    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
        results = processor._process_batch_output_use_zmq([stream])

    assert results[0].outputs.logprob == 0.1
    assert results[0].outputs.top_logprobs is logprob_list
    assert results[0].prompt_logprobs == {"0": -0.1}


def test_recycle_resources_updates_metrics_and_state():
    processor, rm, _, _ = _make_processor()
    task = types.SimpleNamespace(request_id="req-1", block_tables=[1], disaggregate_info=None)
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task

    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    result = RequestOutput(request_id=task.request_id, outputs=None, finished=False, metrics=metrics)
    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
        processor._recycle_resources(task.request_id, 0, task, result, is_prefill=False)

    assert rm.stop_flags[0] is True
    assert task.request_id not in rm.req_dict
    assert rm.recycled[-1] == task.request_id
    assert processor.tokens_counter.get(task.request_id) is None

def test_compute_speculative_status_builds_metrics():
    processor, rm, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
    req_id = "req-spec"
    rm.tasks_list[0] = types.SimpleNamespace(request_id=req_id, block_tables=[1])
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    result = RequestOutput(request_id=req_id, outputs=None, finished=False, metrics=metrics)

    processor.total_step = 2
    processor.number_of_output_tokens = 4
    processor.speculative_stats_step = 0
    processor.accept_token_num_per_head = [2, 1] + [0] * (MAX_DRAFT_TOKENS - 2)
    processor.accept_token_num_per_head_per_request[req_id] = [2, 1]
    processor.total_step_per_request[req_id] = 2
    processor._compute_speculative_status(result)

    assert hasattr(result.metrics, "speculate_metrics")
    assert result.metrics.speculate_metrics.accepted_tokens == 3


def test_process_per_token_handles_recovery_stop_and_cleanup():
    processor, rm, _, _ = _make_processor()
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    metrics.inference_start_time = time.time()
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id="req-recover",
        prompt=["hi"],
        prompt_token_ids=[1],
        prompt_token_ids_len=1,
        messages=[],
        history=[],
        tools=[],
        system="sys",
        eos_token_ids=[99],
        metrics=metrics,
        output_token_ids=[],
        block_tables=[1],
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    rm.requests[task.request_id] = types.SimpleNamespace(idx=0)

    result = RequestOutput(
        request_id=task.request_id,
        outputs=types.SimpleNamespace(token_ids=[]),
        finished=False,
        metrics=copy.copy(task.metrics),
    )

    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
        stopped = processor._process_per_token(
            task,
            batch_id=0,
            token_ids=np.array([token_processor.RECOVERY_STOP_SIGNAL]),
            result=result,
            is_prefill=False,
        )

    assert stopped.finished is True
    assert "incomplete" in stopped.error_msg
    assert rm.stop_flags[0] is True
    assert rm.tasks_list[0] is None
    assert processor.tokens_counter.get(task.request_id) is None

def test_postprocess_buffers_and_merges_speculative_results():
    processor, _, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
    processor.cached_generated_tokens = mock.Mock()

    target_output = RequestOutput(
        request_id="req-t",
        outputs=types.SimpleNamespace(draft_top_logprobs=None),
        finished=False,
        metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
    )
    draft_output = RequestOutput(
        request_id="req-t",
        outputs=types.SimpleNamespace(draft_top_logprobs="draft-logprobs"),
        finished=False,
        metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
    )

    processor.postprocess([target_output], mtype=3)
    assert processor._batch_result_buffer == [target_output]

    processor.postprocess([draft_output], mtype=4)
    processor.cached_generated_tokens.put_results.assert_called_once()
    merged = processor.cached_generated_tokens.put_results.call_args.args[0][0]
    assert merged.outputs.draft_top_logprobs == "draft-logprobs"
    assert processor._batch_result_buffer is None


def test_postprocess_emits_finished_speculative_batch():
    processor, _, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
    processor.cached_generated_tokens = mock.Mock()

    finished_output = RequestOutput(
        request_id="req-finished",
        outputs=types.SimpleNamespace(draft_top_logprobs=None),
        finished=True,
        metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
    )

    processor.postprocess([finished_output], mtype=3)

    processor.cached_generated_tokens.put_results.assert_called_once_with([finished_output])
    assert processor._batch_result_buffer is None


def test_postprocess_passes_through_unknown_type():
    processor, _, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
    processor.cached_generated_tokens = mock.Mock()

    output = RequestOutput(
        request_id="req-direct",
        outputs=types.SimpleNamespace(draft_top_logprobs=None),
        finished=False,
        metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
    )

    processor.postprocess([output], mtype=99)

    processor.cached_generated_tokens.put_results.assert_called_once_with([output])


def test_postprocess_logs_and_swallows_exception():
    processor, _, _, _ = _make_processor()
    processor.cached_generated_tokens = mock.Mock()
    processor.cached_generated_tokens.put_results.side_effect = RuntimeError("boom")

    output = RequestOutput(
        request_id="req-error",
        outputs=None,
        finished=False,
        metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
    )

    processor.postprocess([output])

    processor.cached_generated_tokens.put_results.assert_called_once()

def test_record_speculative_decoding_metrics_tracks_acceptance():
    processor, _, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
    with mock.patch.object(token_processor, "main_process_metrics", _Metrics()):
        processor.accept_token_num_per_head = [2, 3, 0, 0, 0, 0]
        processor.num_draft_tokens = 0
        processor.num_emitted_tokens = 0
        processor.num_accepted_tokens = 0

        processor._record_speculative_decoding_metrics(accept_num=[1, 2])

        metrics = token_processor.main_process_metrics
        assert metrics.spec_decode_num_accepted_tokens_total.value == 3
        assert metrics.spec_decode_num_emitted_tokens_total.value == 5
        assert pytest.approx(metrics.spec_decode_draft_acceptance_rate.value) == 0.75
        assert pytest.approx(metrics.spec_decode_efficiency.value) == pytest.approx(5 / 6)
        assert pytest.approx(metrics.spec_decode_draft_single_head_acceptance_rate[0].value) == 1.5

def test_recycle_resources_prefill_sends_first_token():
    processor, rm, _, connector = _make_processor()
    task_id = "req-prefill"
    metrics = RequestMetrics(
        arrival_time=time.time(),
        preprocess_start_time=time.time(),
        preprocess_end_time=time.time(),
        inference_start_time=time.time(),
    )
    task = types.SimpleNamespace(
        request_id=task_id,
        metrics=metrics,
        block_tables=[1],
        disaggregate_info={"role": "prefill"},
        eos_token_ids=[1],
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task_id] = task
    result = RequestOutput(request_id=task_id, outputs=None, finished=False, metrics=metrics)
    processor.engine_worker_queue = mock.Mock()
    processor.engine_worker_queue.get_finished_req.side_effect = [[(task_id, "finished")]]

    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
        processor._recycle_resources(task_id, 0, task, result, is_prefill=True)

    assert rm.stop_flags[0] is True
    assert connector.calls and connector.calls[0][1][0] is result


def test_recycle_resources_prefill_failure_sets_error():
    processor, rm, _, connector = _make_processor()
    task_id = "req-prefill-failed"
    metrics = RequestMetrics(
        arrival_time=time.time(),
        preprocess_start_time=time.time(),
        preprocess_end_time=time.time(),
        inference_start_time=time.time(),
    )
    task = types.SimpleNamespace(
        request_id=task_id,
        metrics=metrics,
        block_tables=[1],
        disaggregate_info={"role": "prefill"},
        eos_token_ids=[1],
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task_id] = task
    result = RequestOutput(request_id=task_id, outputs=None, finished=False, metrics=metrics)
    processor.engine_worker_queue = mock.Mock()
    processor.engine_worker_queue.get_finished_req.side_effect = [[(task_id, "failed")]]

    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
        processor._recycle_resources(task_id, 0, task, result, is_prefill=True)

    assert result.error_code == 400
    assert "failed" in result.error_message
    assert connector.calls and connector.calls[0][1][0] is result

def test_clear_data_marks_all_tasks_finished():
    processor, rm, _, _ = _make_processor()
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    task_a = types.SimpleNamespace(
        request_id="req-a",
        eos_token_ids=[0],
        metrics=metrics,
        disaggregate_info=None,
        block_tables=[1],
        arrival_time=time.time(),
    )
    task_b = types.SimpleNamespace(
        request_id="req-b",
        eos_token_ids=[0],
        metrics=metrics,
        disaggregate_info=None,
        block_tables=[2],
        arrival_time=time.time(),
    )
    rm.tasks_list[0] = task_a
    rm.tasks_list[1] = task_b
    rm.req_dict = {"req-a": task_a, "req-b": task_b}
    processor.tokens_counter = Counter({"req-a": 2, "req-b": 1})

    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
        processor.clear_data()

    assert rm.tasks_list == [None, None]
    assert not processor.tokens_counter
    assert task_a.request_id in rm.recycled and task_b.request_id in rm.recycled


def test_record_speculative_decoding_accept_num_per_request_updates_maps():
    processor, _, _, _ = _make_processor(speculative_method="mtp")
    processor._record_speculative_decoding_accept_num_per_request("req-acc", 3)

    assert processor.total_step_per_request["req-acc"] == 1
    assert processor.accept_token_num_per_head_per_request["req-acc"][0] == 1
    assert processor.accept_token_num_per_head[2] == 1


def test_reschedule_preempt_task_triggers_for_pending_requests():
    processor, rm, _, _ = _make_processor()
    rm.to_be_rescheduled_request_id_set = {"req-z"}
    rm.requests = {"req-z": types.SimpleNamespace(idx=2)}
    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True):
        processor._reschedule_preempt_task(batch_size=1)

    assert "reschedule-req-z" in rm.recycled

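# The batch-output tests below drive _process_batch_output by writing directly into
# processor.output_tokens / output_scores / output_ranks, mimicking the buffers the worker would fill.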
def test_process_batch_output_consumes_tokens_and_finishes_task():
    processor, rm, _, _ = _make_processor()
    metrics = RequestMetrics(
        arrival_time=time.time(),
        preprocess_start_time=time.time(),
        preprocess_end_time=time.time(),
        inference_start_time=time.time(),
    )
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id="req-out",
        disaggregate_info=None,
        eos_token_ids=[7],
        metrics=metrics,
        output_token_ids=[],
        messages=[{"role": "user", "content": "hi"}],
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
    )
    task.trace_carrier = None
    task.get = lambda key, default=None: getattr(task, key, default)
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.output_tokens[1, 0] = 1
    processor.output_tokens[2, 0] = 7

    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False):
        processor._process_batch_output()

    assert rm.stop_flags[0] is True
    assert task.output_token_ids == [7]


def test_process_batch_output_logprob_records_topk_and_caching():
    processor, rm, _, _ = _make_processor(enable_logprob=True, enable_prefix_caching=True, enable_output_caching=True)
    metrics = RequestMetrics(
        arrival_time=time.time(),
        preprocess_start_time=time.time(),
        preprocess_end_time=time.time(),
        inference_start_time=time.time(),
    )
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id="req-logprob",
        disaggregate_info=None,
        eos_token_ids=[3],
        metrics=metrics,
        output_token_ids=[],
        messages=[],
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        get=lambda key, default=None: None,
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.output_tokens[1, 0] = 1
    token_block = np.arange(K + 1, dtype=np.int64) + 3
    processor.output_tokens[2 : 2 + K + 1] = paddle.to_tensor(token_block.reshape([-1, 1]))
    processor.output_scores[: K + 1] = paddle.ones([K + 1, 1], dtype="float32")
    processor.output_ranks[0] = paddle.to_tensor(0, dtype="int64")
    processor.cached_generated_tokens.put_results = mock.Mock()

    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()

    assert rm.cached_tasks[-1] == "req-logprob"
    sent = processor.cached_generated_tokens.put_results.call_args.args[0][0]
    assert sent.outputs.top_logprobs is not None

def test_process_batch_output_speculative_logprob_handles_draft_batch():
    processor, rm, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
    rm.tasks_list[0] = types.SimpleNamespace(request_id="req-draft", block_tables=[1], disaggregate_info=None)
    target = RequestOutput(
        request_id="req-draft",
        outputs=types.SimpleNamespace(draft_top_logprobs=None),
        finished=False,
        metrics=None,
    )
    processor._batch_result_buffer = [target]
    processor.cached_generated_tokens = mock.Mock()
    processor.output_tokens[1, 0] = 4
    processor.output_tokens[2, 0] = 1
    processor.output_tokens[3, 0] = 1

    draft_tokens = np.arange(MAX_DRAFT_TOKENS * (K + 1), dtype=np.int64).reshape([-1, 1]) + 5
    processor.output_tokens[3 + MAX_BSZ : 3 + MAX_BSZ + len(draft_tokens)] = paddle.to_tensor(draft_tokens)
    processor.output_scores[: MAX_DRAFT_TOKENS * (K + 1)] = paddle.ones(
        [MAX_DRAFT_TOKENS * (K + 1), 1], dtype="float32"
    )
    processor.output_ranks[:MAX_DRAFT_TOKENS] = paddle.arange(MAX_DRAFT_TOKENS, dtype="int64")

    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()

    sent_batch = processor.cached_generated_tokens.put_results.call_args.args[0]
    assert sent_batch and sent_batch[0].outputs.draft_top_logprobs is not None


def test_process_batch_output_speculative_recovery_stop_finishes():
    processor, rm, _, _ = _make_processor(speculative_method="mtp")
    metrics = RequestMetrics(
        arrival_time=time.time(),
        preprocess_start_time=time.time(),
        preprocess_end_time=time.time(),
        inference_start_time=time.time(),
    )
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id="req-recover-spec",
        disaggregate_info=None,
        eos_token_ids=[2],
        metrics=metrics,
        output_token_ids=[],
        messages=[],
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        get=lambda key, default=None: None,
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.output_tokens[1] = 1
    processor.output_tokens[2] = -3
    processor.number_of_output_tokens = 1
    processor.total_step = 1
    processor.accept_token_num_per_head_per_request[task.request_id] = [1] + [0] * (MAX_DRAFT_TOKENS - 1)
    processor.total_step_per_request[task.request_id] = 1
    processor.cached_generated_tokens.put_results = mock.Mock()

    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()

    assert rm.stop_flags[0] is True
    sent = processor.cached_generated_tokens.put_results.call_args.args[0][0]
    assert sent.finished is True
    assert "incomplete" in sent.error_msg

def test_process_batch_output_prefill_chunk_and_adapter_skip():
    processor, rm, _, _ = _make_processor(enable_logprob=True)
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    metrics.inference_start_time = time.time()
    metrics.decode_inference_start_time = metrics.inference_start_time
    processor.cfg.scheduler_config.splitwise_role = "prefill"
    task = types.SimpleNamespace(
        request_id="req-prefill-chunk",
        disaggregate_info={"role": "prefill"},
        eos_token_ids=[1],
        metrics=metrics,
        output_token_ids=[],
        messages=[{"role": "user", "content": "hi"}],
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=1,
        num_total_tokens=2,
        block_tables=[1],
        prefill_chunk_info=[{"idx": 0}, {"idx": 1}],
    )
    task.trace_carrier = None
    task.get = lambda key, default=None: getattr(task, key, default)
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.output_tokens[1, 0] = 1
    processor.output_tokens[2 : 2 + K + 1] = paddle.to_tensor(np.ones([K + 1, 1], dtype=np.int64))
    processor.output_scores[: K + 1] = paddle.ones([K + 1, 1], dtype="float32")
    processor.output_ranks[0] = paddle.to_tensor(0, dtype="int64")
    processor.cached_generated_tokens.put_results = mock.Mock()

    with (
        mock.patch.object(envs, "FD_ENABLE_INTERNAL_ADAPTER", True),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()

    assert getattr(task, "prefill_chunk_num") == 1
    assert processor.cached_generated_tokens.put_results.call_args.args[0] == []


def test_process_batch_output_handles_multimodal_and_negative_token():
    processor, rm, _, _ = _make_processor()
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    metrics.inference_start_time = time.time()
    metrics.decode_inference_start_time = metrics.inference_start_time
    processor.cfg.scheduler_config.splitwise_role = "prefill"
    task = types.SimpleNamespace(
        request_id="req-negative",
        disaggregate_info=None,
        eos_token_ids=[5],
        metrics=metrics,
        output_token_ids=[],
        messages=None,
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        prefill_chunk_info=None,
        multimodal_inputs={"num_input_image_tokens": 2, "num_input_video_tokens": 3},
        get=lambda key, default=None: None,
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    rm.to_be_rescheduled_request_id_set = {task.request_id}
    rm.requests = {task.request_id: types.SimpleNamespace(idx=0)}
    processor.output_tokens[1, 0] = 1
    processor.output_tokens[2, 0] = -1

    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()

    assert rm.recycled[-1] == f"reschedule-{task.request_id}"

def test_process_batch_output_speculative_logprob_targets_topk_scores():
    processor, rm, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    metrics.inference_start_time = time.time()
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id="req-spec-logprob",
        disaggregate_info=None,
        eos_token_ids=[9],
        metrics=metrics,
        output_token_ids=[],
        messages=None,
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        get=lambda key, default=None: None,
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.output_tokens[1, 0] = 3
    processor.output_tokens[2, 0] = 1
    processor.output_tokens[3, 0] = 2
    token_block = np.arange(MAX_DRAFT_TOKENS * (K + 1), dtype=np.int64).reshape([-1, 1]) + 3
    processor.output_tokens[3 + MAX_BSZ : 3 + MAX_BSZ + len(token_block)] = paddle.to_tensor(token_block)
    score_block = paddle.arange(MAX_DRAFT_TOKENS * (K + 1), dtype="float32").reshape([-1, 1])
    processor.output_scores[: MAX_DRAFT_TOKENS * (K + 1)] = score_block
    processor.output_ranks[:MAX_DRAFT_TOKENS] = paddle.arange(MAX_DRAFT_TOKENS, dtype="int64")
    processor.cached_generated_tokens.put_results = mock.Mock()

    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()

    assert processor.tokens_counter[task.request_id] == 2


def test_record_metrics_and_speculative_ngram_metrics():
    processor, _, _, _ = _make_processor(speculative_method="ngram", enable_logprob=True)
    metrics = _Metrics()
    task = types.SimpleNamespace(
        request_id="req-metrics",
        metrics=RequestMetrics(arrival_time=time.time(), preprocess_start_time=0, preprocess_end_time=0),
        last_token_time=time.time(),
    )
    with mock.patch.object(token_processor, "main_process_metrics", metrics):
        processor._record_metrics(task, current_time=time.time(), token_ids=[1, 2])
        processor.accept_token_num_per_head = [0, 2] + [0] * (MAX_DRAFT_TOKENS - 2)
        processor.num_accepted_tokens = 3
        processor.num_emitted_tokens = 3
        processor._record_speculative_decoding_metrics(accept_num=[1, 1])

    assert metrics.generation_tokens_total.value == 2
    assert metrics.spec_decode_draft_acceptance_rate.value == 1

def test_clear_data_invokes_scheduler_cleanup():
    processor, rm, _, _ = _make_processor()
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    task = types.SimpleNamespace(
        request_id="req-clear",
        arrival_time=time.time(),
        disaggregate_info=None,
        eos_token_ids=[0],
        metrics=metrics,
        output_token_ids=[],
        messages=None,
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        get=lambda key, default=None: getattr(task, key, default),
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.stop_flags = [True] * rm.max_num_seqs
    processor.tokens_counter[task.request_id] = 0

    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor.clear_data()

    assert rm.cleared is True


def test_process_batch_output_skips_already_stopped_slot():
    processor, rm, _, _ = _make_processor()
    rm.stop_flags[0] = True
    processor.output_tokens[1, 0] = 1
    processor.output_tokens[2, 0] = 5

    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()

    assert processor.cached_generated_tokens.put_results.called

def test_process_batch_output_speculative_negative_token_reschedules():
    processor, rm, _, _ = _make_processor(speculative_method="mtp")
    task_id = "req-spec-neg"
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    metrics.inference_start_time = time.time()
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id=task_id,
        disaggregate_info=None,
        eos_token_ids=[1],
        metrics=metrics,
        output_token_ids=[],
        messages=None,
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        get=lambda key, default=None: None,
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task_id] = task
    rm.to_be_rescheduled_request_id_set = {task_id}
    rm.requests = {task_id: types.SimpleNamespace(idx=0)}
    processor.output_tokens[1] = 1
    processor.output_tokens[2] = 1
    processor.output_tokens[3] = 1
    processor.output_tokens[2 + SPECULATE_MAX_BSZ] = -1

    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()

    assert rm.recycled[-1] == f"reschedule-{task_id}"


def test_process_batch_output_use_zmq_reschedules_negative_token():
    processor, rm, _, _ = _make_processor()
    task = types.SimpleNamespace(request_id="req-zmq-neg")
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    rm.to_be_rescheduled_request_id_set = {task.request_id}

    stream = types.SimpleNamespace(batch_id=0, tokens=np.array([-1], dtype=np.int64), pooler_output=None)
    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True):
        results = processor._process_batch_output_use_zmq([stream])

    assert results == []
    assert rm.recycled[-1] == f"reschedule-{task.request_id}"

def test_process_batch_output_records_second_decode_token():
    processor, rm, _, _ = _make_processor()
    processor.cfg.scheduler_config.splitwise_role = "decode"
    task = types.SimpleNamespace(
        request_id="req-second",
        disaggregate_info=None,
        eos_token_ids=[2],
        metrics=RequestMetrics(
            arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
        ),
        output_token_ids=[],
        messages=None,
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        get=lambda key, default=None: None,
    )
    task.trace_carrier = None
    task.metrics.inference_start_time = time.time()
    task.metrics.decode_inference_start_time = task.metrics.inference_start_time
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.tokens_counter[task.request_id] = 1
    processor.output_tokens[1, 0] = 1
    processor.output_tokens[2, 0] = 2

    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()

    assert task.metrics.decode_recv_second_token_time is not None


def test_record_speculative_metrics_calls_init_when_missing():
    processor, _, _, _ = _make_processor(speculative_method="mtp")

    class _MinimalMetrics:
        def __init__(self):
            self.init_called = False

        def _init_speculative_metrics(self, method, num_speculative_tokens):
            self.spec_decode_num_accepted_tokens_total = _Metric()
            self.spec_decode_num_emitted_tokens_total = _Metric()
            self.spec_decode_draft_acceptance_rate = _Metric()
            self.spec_decode_efficiency = _Metric()
            self.spec_decode_num_draft_tokens_total = _Metric()
            self.spec_decode_draft_single_head_acceptance_rate = [_Metric() for _ in range(MAX_DRAFT_TOKENS)]
            self.init_called = True

    processor.accept_token_num_per_head = [1, 1] + [0] * (MAX_DRAFT_TOKENS - 2)
    processor.num_accepted_tokens = 2
    processor.num_emitted_tokens = 2

    metrics = _MinimalMetrics()
    with mock.patch.object(token_processor, "main_process_metrics", metrics):
        processor._record_speculative_decoding_metrics(accept_num=[1])

    assert metrics.init_called is True

def test_process_batch_output_prefill_sets_draft_tokens():
    processor, rm, _, connector = _make_processor(speculative_method="mtp")
    processor.cfg.scheduler_config.splitwise_role = "prefill"
    metrics = RequestMetrics(
        arrival_time=time.time(),
        preprocess_start_time=time.time(),
        preprocess_end_time=time.time(),
        inference_start_time=time.time(),
    )
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id="req-prefill-draft",
        disaggregate_info={"role": "prefill"},
        eos_token_ids=[99],
        metrics=metrics,
        output_token_ids=[],
        messages=None,
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        get=lambda key, default=None: None,
    )
    task.trace_carrier = None
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.engine_worker_queue = mock.Mock()
    processor.engine_worker_queue.get_finished_req.side_effect = [[(task.request_id, "finished")]]
    processor.output_tokens[1] = 1
    processor.output_tokens[2] = 2
    processor.output_tokens[2 + SPECULATE_MAX_BSZ] = 11
    processor.output_tokens[2 + SPECULATE_MAX_BSZ + 1] = 12

    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()

    assert connector.calls
    sent = connector.calls[0][1][0]
    assert sent.outputs.draft_token_ids == [11, 12]


def test_process_batch_output_logs_recovery_stop_for_non_speculative():
    processor, rm, _, _ = _make_processor()
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    metrics.inference_start_time = time.time()
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id="req-recovery",
        disaggregate_info=None,
        eos_token_ids=[1],
        metrics=metrics,
        output_token_ids=[],
        messages=None,
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
    )
    task.trace_carrier = None
    task.get = lambda k, d=None: getattr(task, k, d)
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.output_tokens[1, 0] = 1
    processor.output_tokens[2, 0] = token_processor.RECOVERY_STOP_SIGNAL

    with (
        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False),
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
    ):
        processor._process_batch_output()

    assert rm.stop_flags[0] is True

def test_process_batch_output_sets_multimodal_token_counts():
    processor, rm, _, _ = _make_processor()
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    metrics.inference_start_time = time.time()
    metrics.decode_inference_start_time = metrics.inference_start_time
    task = types.SimpleNamespace(
        request_id="req-mm",
        disaggregate_info=None,
        eos_token_ids=[7],
        metrics=metrics,
        output_token_ids=[],
        messages=None,
        num_cached_tokens=0,
        ic_req_data=None,
        prompt_token_ids_len=0,
        num_total_tokens=1,
        block_tables=[1],
        multimodal_inputs={"num_input_image_tokens": 4, "num_input_video_tokens": 5},
    )
    task.trace_carrier = None
    task.get = lambda key, default=None: getattr(task, key, default)
    rm.tasks_list[0] = task
    rm.req_dict[task.request_id] = task
    processor.output_tokens[1, 0] = 1
    processor.output_tokens[2, 0] = 7

    with mock.patch.object(token_processor, "main_process_metrics", _Metrics()):
        processor._process_batch_output()

    sent = processor.cached_generated_tokens.put_results.call_args.args[0][0]
    assert sent.num_input_image_tokens == 4 and sent.num_input_video_tokens == 5


def test_warmup_token_processor_initialization():
    cfg = _DummyCfg()
    with mock.patch.object(token_processor.TokenProcessor, "__init__", lambda self, _cfg: None):
        warm = token_processor.WarmUpTokenProcessor(cfg)
        assert warm._is_running is True and warm._is_blocking is True
        warm.postprocess([])


def test_warmup_processor_stop_joins_worker():
    warm = token_processor.WarmUpTokenProcessor.__new__(token_processor.WarmUpTokenProcessor)
    warm._is_running = True
    worker = mock.Mock()
    warm.worker = worker
    warm.stop()
    worker.join.assert_called_once()

def test_healthy_behaviour_respects_timeout(monkeypatch):
    processor, _, _, _ = _make_processor()
    processor.timestamp_for_alive_before_handle_batch = time.time() - 1
    processor.timestamp_for_alive_after_handle_batch = None
    monkeypatch.setattr(envs, "FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT", 0.1)

    assert processor.healthy() is False


def test_healthy_detects_engine_hang():
    processor, _, _, _ = _make_processor()
    processor.timestamp_for_alive_before_handle_batch = None
    processor.timestamp_for_alive_after_handle_batch = time.time()
    processor.engine_output_token_hang = True

    assert processor.healthy() is False


def test_healthy_recent_prehandle_activity_is_ok(monkeypatch):
    processor, _, _, _ = _make_processor()
    processor.timestamp_for_alive_before_handle_batch = time.time()
    processor.timestamp_for_alive_after_handle_batch = None
    monkeypatch.setattr(envs, "FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT", 5)

    assert processor.healthy() is True


def test_record_completion_metrics_updates_counters():
    processor, _, _, _ = _make_processor()
    task_id = "req-complete"
    metrics = RequestMetrics(
        arrival_time=time.time(), preprocess_start_time=time.time(), preprocess_end_time=time.time()
    )
    metrics.inference_start_time = time.time() - 0.2
    metrics.engine_recv_first_token_time = time.time() - 0.1
    task = types.SimpleNamespace(request_id=task_id, metrics=metrics, user="user-a")
    processor.tokens_counter[task_id] = 4

    with (
        mock.patch.object(token_processor, "main_process_metrics", _Metrics()) as metrics_obj,
        mock.patch.object(token_processor, "trace_print"),
    ):
        processor._record_completion_metrics(task, current_time=time.time())

    assert metrics_obj.request_decode_time.value is not None
    assert metrics_obj.request_success_total.value == 1
    assert metrics_obj.request_generation_tokens.value == 4


def test_process_sampling_results_use_zmq_rejects_speculative():
    processor, _, _, _ = _make_processor(speculative_method="mtp")
    with pytest.raises(NotImplementedError):
        processor.process_sampling_results_use_zmq()