support mm mtp (#4013)

This commit is contained in:
xiaoxiaohehe001
2025-09-09 13:55:45 +08:00
committed by GitHub
parent c753f1fc9e
commit 5223065d59
11 changed files with 278 additions and 54 deletions

View File

@@ -33,6 +33,7 @@ from fastdeploy.model_executor.layers.sample.ops import (
min_p_sampling,
top_k_top_p_sampling,
)
from fastdeploy.model_executor.ops.gpu import limit_content_len
from fastdeploy.platforms import current_platform
from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
@@ -304,6 +305,7 @@ class SpeculativeSampler(nn.Layer):
self.speculative_verify_window = fd_config.speculative_config.verify_window
self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
self.fd_config = fd_config
def pre_process(self, skip_idx_list: List[int] = []):
"""pre process before running"""
@@ -382,6 +384,22 @@ class SpeculativeSampler(nn.Layer):
self.speculative_benchmark_mode,
)
if hasattr(self.fd_config.model_config, "think_end_id") and self.fd_config.model_config.think_end_id > 0:
limit_content_len(
share_inputs["accept_tokens"],
self.fd_config.model_config.think_end_id,
share_inputs["max_content_len"],
share_inputs["max_think_len"],
share_inputs["step_idx"],
sampling_metadata.eos_token_ids,
share_inputs["max_dec_len"],
share_inputs["limit_content_status"],
share_inputs["enable_thinking"],
share_inputs["accept_num"],
share_inputs["seq_lens_decoder"],
share_inputs["stop_flags"],
)
return None