mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[CI]【Hackathon 9th Sprint No.12】功能模块 fastdeploy/spec_decode/mtp.py 单测补充 (#5533)
* Add unit tests for MTPProposer class in spec_decode/mtp.py * fix: remove non-existent QuantizationConfig import in test_mtp_proposer * fix: add logprobs_mode attribute to FakeModelConfig * fix: fix test failures in test_mtp_proposer - fix Mock setup, remove arrival_time, add missing keys * fix: add seq_lens_this_time initialization and kv_cache init before insert_tasks_v1 * fix: check pos_emb_type attribute existence before assertion * test: add minimal coverage for mtp cache type, mm init, preempted * test: fix cache_type_branches unsupported platform on 12 * test: refine MTPProposer tests for cache type, requests and chunked prefill * chore: remove stray spec_decode copy
This commit is contained in:
625
tests/spec_decode/test_mtp_proposer.py
Normal file
625
tests/spec_decode/test_mtp_proposer.py
Normal file
@@ -0,0 +1,625 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.config import SpeculativeConfig
|
||||
from fastdeploy.engine.request import Request, RequestType
|
||||
from fastdeploy.spec_decode.mtp import MTPProposer
|
||||
from tests.utils import FakeModelConfig, get_default_test_fd_config
|
||||
|
||||
|
||||
class TestMTPProposer(unittest.TestCase):
    """Unit tests for ``MTPProposer`` (fastdeploy/spec_decode/mtp.py).

    Covers config rewriting, KV-cache initialization branches, request
    insertion (v1 scheduler and prefill paths), forward-meta setup, the
    draft/verify helper methods and platform-dependent cache typing.

    Every test patches the heavy external dependencies at the
    ``fastdeploy.spec_decode.mtp`` module level (model loader, attention
    backend factory, rope table); the shared wiring lives in
    ``_setup_common_mocks`` and proposer construction in ``_build_proposer``.
    """

    def setUp(self):
        # Minimal FDConfig fixture for a 2-slot proposer with one draft token.
        self.fd_config = get_default_test_fd_config()
        self.fd_config.model_config = FakeModelConfig()
        self.fd_config.model_config.architectures = ["ErnieMoeForCausalLM"]
        self.fd_config.model_config.dtype = "bfloat16"
        self.fd_config.model_config.rope_theta = 10000.0
        self.fd_config.model_config.enable_logprob = False
        self.fd_config.model_config.max_model_len = 2048
        self.fd_config.speculative_config = SpeculativeConfig({})
        self.fd_config.speculative_config.method = "mtp"
        self.fd_config.speculative_config.num_speculative_tokens = 1
        self.fd_config.speculative_config.model = "test_mtp_model"
        self.fd_config.speculative_config.quantization = ""
        self.fd_config.speculative_config.num_gpu_block_expand_ratio = 1.0
        self.fd_config.speculative_config.mtp_strategy = "default"
        self.fd_config.scheduler_config.splitwise_role = "mixed"
        self.fd_config.cache_config.enable_prefix_caching = False
        self.fd_config.cache_config.block_size = 16
        self.fd_config.cache_config.enc_dec_block_num = 0
        self.fd_config.cache_config.kv_cache_ratio = 0.9
        self.fd_config.cache_config.total_block_num = 100
        self.fd_config.cache_config.enable_chunked_prefill = False
        self.fd_config.graph_opt_config.draft_model_use_cudagraph = False
        self.fd_config.parallel_config.enable_expert_parallel = False
        self.fd_config.parallel_config.tensor_parallel_size = 1
        self.fd_config.parallel_config.tensor_parallel_rank = 0
        self.fd_config.parallel_config.data_parallel_rank = 0
        self.fd_config.parallel_config.use_ep = False

        # The target (main) model only needs a layer count here.
        self.main_model = Mock()
        self.main_model.num_hidden_layers = 12
        self.local_rank = 0
        self.device_id = 0

        self.target_model_inputs = self._build_target_model_inputs()

    @staticmethod
    def _build_target_model_inputs():
        """Return the buffer dict the target model shares with the proposer.

        Shapes follow batch size 2 / max_model_len 2048 from the fixture
        config; values are zeros/ones placeholders — only shapes and dtypes
        matter to the code under test.
        """
        return {
            "block_tables": paddle.zeros([2, 10], dtype="int32"),
            "input_ids": paddle.zeros([2, 2048], dtype="int64"),
            "seq_lens_this_time": paddle.zeros([2, 1], dtype="int32"),
            "seq_lens_encoder": paddle.zeros([2, 1], dtype="int32"),
            "seq_lens_decoder": paddle.zeros([2, 1], dtype="int32"),
            "prompt_lens": paddle.zeros([2, 1], dtype="int64"),
            "step_idx": paddle.zeros([2, 1], dtype="int64"),
            "stop_flags": paddle.zeros([2, 1], dtype="bool"),
            "stop_nums": paddle.zeros([2, 1], dtype="int32"),
            "pre_ids": paddle.zeros([2, 2048], dtype="int64"),
            "output_cum_offsets": paddle.zeros([2], dtype="int32"),
            "output_padding_offset": paddle.zeros([2], dtype="int32"),
            "ids_remove_padding": paddle.zeros([2], dtype="int64"),
            "batch_id_per_token": paddle.zeros([2], dtype="int32"),
            "cu_seqlens_q": paddle.zeros([3], dtype="int32"),
            "cu_seqlens_k": paddle.zeros([3], dtype="int32"),
            "decoder_batch_ids": paddle.zeros([2], dtype="int32"),
            "decoder_tile_ids_per_batch": paddle.zeros([2], dtype="int32"),
            "decoder_num_blocks_cpu": paddle.zeros([2], dtype="int32").cpu(),
            "decoder_num_blocks_device": paddle.zeros([2], dtype="int32"),
            "decoder_chunk_size_device": paddle.zeros([2], dtype="int32"),
            "max_len_tensor_cpu": paddle.zeros([2], dtype="int32").cpu(),
            "encoder_batch_ids": paddle.zeros([2], dtype="int32"),
            "encoder_tile_ids_per_batch": paddle.zeros([2], dtype="int32"),
            "encoder_num_blocks_x_cpu": paddle.zeros([2], dtype="int32").cpu(),
            "kv_batch_ids": paddle.zeros([2], dtype="int32"),
            "kv_tile_ids_per_batch": paddle.zeros([2], dtype="int32"),
            "kv_num_blocks_x_cpu": paddle.zeros([2], dtype="int32").cpu(),
            "top_p": paddle.ones([2, 1], dtype="float32") * 0.9,
            "top_k": paddle.zeros([2, 1], dtype="int32"),
            "temperature": paddle.ones([2, 1], dtype="float32"),
            "eos_token_id": paddle.ones([2], dtype="int64") * 2,
            "penalty_score": paddle.ones([2, 1], dtype="float32"),
            "frequency_score": paddle.zeros([2, 1], dtype="float32"),
            "presence_score": paddle.zeros([2, 1], dtype="float32"),
            "infer_seed": paddle.zeros([2, 1], dtype="int64"),
            "max_dec_len": paddle.ones([2, 1], dtype="int64") * 512,
            "min_dec_len": paddle.zeros([2, 1], dtype="int64"),
            "bad_tokens": paddle.zeros([2], dtype="int64"),
            "draft_tokens": paddle.zeros([2, 2], dtype="int64"),
            "accept_tokens": paddle.zeros([2, 2], dtype="int64"),
            "accept_num": paddle.ones([2], dtype="int32"),
            "draft_logits": paddle.zeros([4, 32000], dtype="float32"),
            "temp_scaled_logprobs": paddle.zeros([2], dtype="float32"),
            "top_p_normalized_logprobs": paddle.zeros([2], dtype="float32"),
            "encoder_block_lens": paddle.zeros([2, 1], dtype="int32"),
            "cu_batch_token_offset": paddle.zeros([3], dtype="int32"),
            "is_block_step": paddle.zeros([2], dtype="bool"),
            "actual_draft_token_num": paddle.zeros([2], dtype="int32"),
        }

    @staticmethod
    def _setup_common_mocks(mock_rope, mock_attn_backend, mock_model_loader):
        """Wire the three module-level patches every test installs.

        - model loader: returns a Mock model whose ``compute_logits`` emits a
          [2, 32000] zero tensor,
        - attention backend: a factory returning a Mock with a fixed
          kv-cache shape,
        - rope: a [1, 2048, 64] zero table.
        """
        mock_model = Mock()
        mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
        mock_model_loader.return_value.load_model.return_value = mock_model
        mock_attn = Mock()
        mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
        mock_attn_backend.return_value = lambda *args, **kwargs: mock_attn
        mock_rope.return_value = paddle.zeros([1, 2048, 64])

    def _build_proposer(self):
        """Construct an MTPProposer from the fixture; mocks must already be
        wired via ``_setup_common_mocks``."""
        return MTPProposer(
            self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
        )

    @staticmethod
    def _make_request(request_id, prompt_token_ids, prompt="test"):
        """Build a minimal Request carrying only the fields insert_* reads."""
        return Request(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_token_ids_len=len(prompt_token_ids),
            messages=None,
            history=None,
            tools=None,
            system=None,
            eos_token_ids=[2],
        )

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    def test_init_and_config_methods(self, mock_rope, mock_attn_backend, mock_model_loader):
        """Test initialization and config update methods."""
        self._setup_common_mocks(mock_rope, mock_attn_backend, mock_model_loader)
        proposer = self._build_proposer()

        # _update_mtp_config rewrites the MoE arch to its MTP draft variant
        # and collapses the draft model to a single layer.
        self.assertEqual(proposer.model_config.architectures[0], "ErnieMTPForCausalLM")
        self.assertEqual(proposer.model_config.num_hidden_layers, 1)
        self.assertEqual(proposer.speculative_config.model_type, "mtp")

        # _get_cache_type is platform dependent (uint8 on CUDA, int8 on XPU).
        self.assertIn(proposer._get_cache_type(), ["uint8", "int8"])

        # Reports chunked-prefill support even though the fixture has
        # enable_chunked_prefill=False — presumably MTP always supports it.
        self.assertTrue(proposer.is_chunk_prefill_enabled())

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    def test_dummy_prefill_inputs_and_kv_cache(self, mock_rope, mock_attn_backend, mock_model_loader):
        """Test dummy_prefill_inputs and initialize_kv_cache with different branches."""
        self._setup_common_mocks(mock_rope, mock_attn_backend, mock_model_loader)
        proposer = self._build_proposer()

        # Expert-parallel branch of dummy_prefill_inputs.
        self.fd_config.parallel_config.enable_expert_parallel = True
        proposer.dummy_prefill_inputs(num_tokens=100, batch_size=2, expected_decode_len=10)
        self.assertGreater(proposer.model_inputs["seq_lens_encoder"][0].item(), 0)

        # Prefix-caching branch of initialize_kv_cache.
        self.fd_config.cache_config.enable_prefix_caching = True
        proposer.initialize_kv_cache(main_model_num_blocks=10, profile=False)
        self.assertIn("caches", proposer.model_inputs)

        # block_wise_fp8 kv-cache quantization branch.
        self.fd_config.quant_config = Mock()
        self.fd_config.quant_config.kv_cache_quant_type = "block_wise_fp8"
        proposer.initialize_kv_cache(main_model_num_blocks=10, profile=False)

        # Profiling branch.
        proposer.initialize_kv_cache(main_model_num_blocks=10, profile=True)

        # clear_mtp_cache drops the cache buffers again.
        proposer.clear_mtp_cache()
        self.assertNotIn("caches", proposer.model_inputs)

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    def test_update_mtp_block_num(self, mock_rope, mock_attn_backend, mock_model_loader):
        """Test update_mtp_block_num."""
        self._setup_common_mocks(mock_rope, mock_attn_backend, mock_model_loader)
        proposer = self._build_proposer()

        proposer.update_mtp_block_num(num_gpu_blocks=20)
        self.assertEqual(proposer.main_model_num_gpu_blocks, 20)
        self.assertIn("free_list", proposer.model_inputs)

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    def test_insert_tasks_v1(self, mock_rope, mock_attn_backend, mock_model_loader):
        """Test insert_tasks_v1 with different request types."""
        self._setup_common_mocks(mock_rope, mock_attn_backend, mock_model_loader)
        proposer = self._build_proposer()

        # PREFILL request — the one actually inserted below.
        request1 = self._make_request("test1", [1, 2, 3, 4, 5])
        request1.idx = 0
        request1.task_type = RequestType.PREFILL
        request1.prefill_start_index = 0
        request1.prefill_end_index = 5
        request1.output_token_ids = []
        request1.block_tables = [0, 1]

        # DECODE / PREEMPTED requests constructed to mirror the production
        # mix; the PREEMPTED branch itself is exercised in
        # test_insert_tasks_v1_preempted.
        request2 = self._make_request("test2", [1, 2])
        request2.idx = 1
        request2.task_type = RequestType.DECODE
        request2.block_tables = [2, 3]

        request3 = self._make_request("test3", [1])
        request3.idx = 0
        request3.task_type = RequestType.PREEMPTED

        # splitwise_role == "decode" path.
        self.fd_config.scheduler_config.splitwise_role = "decode"
        proposer.initialize_kv_cache(main_model_num_blocks=10)
        proposer.insert_tasks_v1([request1], 1)

        # Multimodal path — the attn-mask offset buffers must pre-exist.
        proposer.enable_mm = True
        request1.multimodal_inputs = {"attention_mask_offset": [0, 1, 2, 3, 4]}
        proposer.model_inputs["attn_mask_offsets_full"] = paddle.zeros([2, 2048], dtype="int32")
        proposer.model_inputs["attn_mask_offsets_decoder"] = paddle.zeros([2, 1], dtype="int32")
        proposer.insert_tasks_v1([request1], 1)

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    def test_insert_prefill_inputs(self, mock_rope, mock_attn_backend, mock_model_loader):
        """Test insert_prefill_inputs with different roles and chunked prefill."""
        self._setup_common_mocks(mock_rope, mock_attn_backend, mock_model_loader)
        proposer = self._build_proposer()

        request = self._make_request("test", [1, 2, 3, 4, 5])
        request.idx = 0
        request.block_tables = [0, 1]
        request.draft_token_ids = [10, 11]

        # Disaggregated prefill role.
        request.disaggregate_info = {"role": "prefill"}
        proposer.insert_prefill_inputs([request], 1)
        self.assertEqual(proposer.role, "prefill")

        # Disaggregated decode role.
        request.disaggregate_info = {"role": "decode"}
        proposer.insert_prefill_inputs([request], 1)
        self.assertEqual(proposer.role, "decode")

        # Chunked-prefill path (no disaggregation info).
        self.fd_config.cache_config.enable_chunked_prefill = True
        request.prefill_chunk_info = [3, 2]
        request.disaggregate_info = None
        proposer.insert_prefill_inputs([request], 1)

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    def test_forward_meta_and_exist_prefill(self, mock_rope, mock_attn_backend, mock_model_loader):
        """Test _initialize_forward_meta, _initialize_forward_meta_xpu, and exist_prefill."""
        self._setup_common_mocks(mock_rope, mock_attn_backend, mock_model_loader)
        proposer = self._build_proposer()
        proposer.initialize_kv_cache(main_model_num_blocks=10)
        proposer.model_inputs["seq_lens_this_time"] = proposer.seq_lens_this_time_buffer

        proposer._initialize_forward_meta(step_use_cudagraph=False)
        self.assertIsNotNone(proposer.forward_meta)

        # pos_emb_type may not exist on every ForwardMeta variant.
        proposer._initialize_forward_meta_xpu()
        if hasattr(proposer.forward_meta, "pos_emb_type"):
            self.assertEqual(proposer.forward_meta.pos_emb_type, "NORMAL")

        # exist_prefill keys off non-zero encoder sequence lengths.
        proposer.share_inputs = {"seq_lens_encoder": paddle.ones([2, 1], dtype="int32")}
        self.assertEqual(proposer.exist_prefill(), 1)
        proposer.share_inputs = {"seq_lens_encoder": paddle.zeros([2, 1], dtype="int32")}
        self.assertEqual(proposer.exist_prefill(), 0)

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    @patch("fastdeploy.spec_decode.mtp.draft_model_preprocess")
    @patch("fastdeploy.spec_decode.mtp.eagle_get_hidden_states")
    def test_prepare_inputs_and_post_process(
        self, mock_eagle, mock_preprocess, mock_rope, mock_attn_backend, mock_model_loader
    ):
        """Test _prepare_inputs and _post_process."""
        self._setup_common_mocks(mock_rope, mock_attn_backend, mock_model_loader)
        mock_eagle.return_value = paddle.zeros([2, 768], dtype="bfloat16")
        mock_preprocess.return_value = None

        proposer = self._build_proposer()
        proposer.model_inputs["seq_lens_this_time"] = proposer.seq_lens_this_time_buffer

        full_hidden_states = paddle.zeros([2, 768], dtype="bfloat16")
        proposer._prepare_inputs(full_hidden_states)
        mock_preprocess.assert_called()
        mock_eagle.assert_called()

        # _post_process in the disaggregated-prefill role.
        proposer.role = "prefill"
        proposer._post_process(paddle.ones([2, 1], dtype="int64"))

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    def test_update_task_chunk_prefill(self, mock_rope, mock_attn_backend, mock_model_loader):
        """Test update_task_chunk_prefill."""
        self._setup_common_mocks(mock_rope, mock_attn_backend, mock_model_loader)
        proposer = self._build_proposer()
        proposer.model_inputs["seq_lens_this_time"] = proposer.seq_lens_this_time_buffer

        task = Mock()
        task.idx = 0
        task.prefill_chunk_info = [3, 2, 1]
        task.prompt_token_ids = [1, 2, 3, 4, 5, 6]
        task.get = Mock(return_value=0)

        # chunk_idx == len(prefill_chunk_info): every chunk already consumed.
        task.chunk_idx = 3
        proposer.update_task_chunk_prefill(task)

        # chunk_idx < len - 1: a middle chunk.
        task.chunk_idx = 0
        proposer.update_task_chunk_prefill(task)

        # chunk_idx == len - 1: the final prefill chunk.
        task.chunk_idx = 2
        proposer.update_task_chunk_prefill(task)

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    @patch("fastdeploy.spec_decode.mtp.draft_model_postprocess")
    @patch("fastdeploy.spec_decode.mtp.mtp_step_paddle")
    def test_update_status(self, mock_mtp_step, mock_postprocess, mock_rope, mock_attn_backend, mock_model_loader):
        """Test _update_status."""
        self._setup_common_mocks(mock_rope, mock_attn_backend, mock_model_loader)
        mock_postprocess.return_value = None
        mock_mtp_step.return_value = None

        proposer = self._build_proposer()
        proposer.model_inputs["seq_lens_this_time"] = proposer.seq_lens_this_time_buffer

        # The legacy (non-v1) KV-cache scheduler path drives mtp_step_paddle.
        with patch("fastdeploy.spec_decode.mtp.envs.ENABLE_V1_KVCACHE_SCHEDULER", False):
            proposer._update_status()
        mock_mtp_step.assert_called()

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    @patch("fastdeploy.spec_decode.mtp.hybrid_mtp_ngram")
    def test_extend_draft_token_and_run_impl(self, mock_ngram, mock_rope, mock_attn_backend, mock_model_loader):
        """Test _extend_draft_token_with_ngram_match and _run_impl."""
        self._setup_common_mocks(mock_rope, mock_attn_backend, mock_model_loader)
        mock_ngram.return_value = None

        proposer = self._build_proposer()
        proposer.hybrid_mode = True
        proposer.max_ngram_size = 5
        proposer.min_ngram_size = 2

        proposer._extend_draft_token_with_ngram_match()
        mock_ngram.assert_called()

        # _run_impl with hybrid_mode on; internal phases stubbed out so only
        # the orchestration is exercised.
        full_hidden_states = paddle.zeros([2, 768], dtype="bfloat16")
        with (
            patch.object(proposer, "_prepare_inputs"),
            patch.object(proposer, "_propose"),
            patch.object(proposer, "_update_status"),
        ):
            proposer._run_impl(full_hidden_states)

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    def test_padding_cudagraph_inputs_and_empty_cache(self, mock_rope, mock_attn_backend, mock_model_loader):
        """Test padding_cudagraph_inputs and _empty_cache."""
        self._setup_common_mocks(mock_rope, mock_attn_backend, mock_model_loader)
        proposer = self._build_proposer()
        proposer.initialize_kv_cache(main_model_num_blocks=10)
        proposer.model_inputs["seq_lens_this_time"] = proposer.seq_lens_this_time_buffer
        proposer._initialize_forward_meta()

        # CUDA-graph padding only kicks in when step_use_cudagraph is set.
        proposer.forward_meta.step_use_cudagraph = True
        proposer.padding_cudagraph_inputs()
        self.assertIsNotNone(proposer.real_token_num)

        # _empty_cache should delegate to paddle's CUDA cache release.
        with patch("paddle.device.cuda.empty_cache") as mock_empty:
            proposer._empty_cache()
            mock_empty.assert_called()

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    @patch("fastdeploy.spec_decode.mtp.current_platform")
    def test_cache_type_branches(self, mock_platform, mock_rope, mock_attn_backend, mock_model_loader):
        """Cover _get_cache_type CUDA/XPU/unsupported branches."""
        self._setup_common_mocks(mock_rope, mock_attn_backend, mock_model_loader)

        # CUDA -> uint8.
        mock_platform.is_cuda.return_value = True
        mock_platform.is_xpu.return_value = False
        proposer = self._build_proposer()
        self.assertEqual(proposer._get_cache_type(), "uint8")

        # XPU -> int8.
        mock_platform.is_cuda.return_value = False
        mock_platform.is_xpu.return_value = True
        proposer = self._build_proposer()
        self.assertEqual(proposer._get_cache_type(), "int8")

        # Unsupported platform: reuse the existing proposer so __init__ does
        # not raise first; only _get_cache_type should fail.
        mock_platform.is_xpu.return_value = False
        with self.assertRaises(NotImplementedError):
            proposer._get_cache_type()

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    def test_init_model_inputs_with_mm(self, mock_rope, mock_attn_backend, mock_model_loader):
        """Init model inputs with enable_mm=True to cover attn_mask buffers."""
        self._setup_common_mocks(mock_rope, mock_attn_backend, mock_model_loader)

        self.fd_config.model_config.enable_mm = True
        proposer = self._build_proposer()

        self.assertIn("attn_mask_offsets", proposer.model_inputs)
        self.assertIn("attn_mask_offsets_full", proposer.model_inputs)
        self.assertIn("attn_mask_offsets_decoder", proposer.model_inputs)

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    def test_insert_tasks_v1_preempted(self, mock_rope, mock_attn_backend, mock_model_loader):
        """Cover RequestType.PREEMPTED branch in insert_tasks_v1."""
        self._setup_common_mocks(mock_rope, mock_attn_backend, mock_model_loader)
        proposer = self._build_proposer()
        proposer.model_inputs["caches"] = []  # avoid re-init of the kv cache

        request = self._make_request("preempt", [1], prompt="t")
        request.idx = 0
        request.task_type = RequestType.PREEMPTED
        proposer.insert_tasks_v1([request], 1)

        # Preemption stops the slot and zeroes its per-step token count.
        self.assertTrue(proposer.model_inputs["stop_flags"][0].item())
        self.assertEqual(proposer.seq_lens_this_time_buffer[0].item(), 0)
||||
|
||||
|
||||
# Standard unittest entry point so this module can be run directly.
if __name__ == "__main__":
    unittest.main()
|
||||
@@ -60,6 +60,7 @@ class FakeModelConfig:
|
||||
self.model_format = "auto"
|
||||
self.enable_mm = False
|
||||
self.max_model_len = 512
|
||||
self.logprobs_mode = "raw_logprobs"
|
||||
|
||||
|
||||
def get_default_test_fd_config():
|
||||
|
||||
Reference in New Issue
Block a user