support mm mtp (#4013)
@@ -33,6 +33,7 @@ from fastdeploy.model_executor.layers.sample.ops import (
     min_p_sampling,
     top_k_top_p_sampling,
 )
+from fastdeploy.model_executor.ops.gpu import limit_content_len
 from fastdeploy.platforms import current_platform
 from fastdeploy.worker.output import LogprobsTensors, SamplerOutput

@@ -304,6 +305,7 @@ class SpeculativeSampler(nn.Layer):
         self.speculative_verify_window = fd_config.speculative_config.verify_window
         self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
         self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
+        self.fd_config = fd_config

     def pre_process(self, skip_idx_list: List[int] = []):
         """pre process before running"""
@@ -382,6 +384,22 @@ class SpeculativeSampler(nn.Layer):
             self.speculative_benchmark_mode,
         )

+        if hasattr(self.fd_config.model_config, "think_end_id") and self.fd_config.model_config.think_end_id > 0:
+            limit_content_len(
+                share_inputs["accept_tokens"],
+                self.fd_config.model_config.think_end_id,
+                share_inputs["max_content_len"],
+                share_inputs["max_think_len"],
+                share_inputs["step_idx"],
+                sampling_metadata.eos_token_ids,
+                share_inputs["max_dec_len"],
+                share_inputs["limit_content_status"],
+                share_inputs["enable_thinking"],
+                share_inputs["accept_num"],
+                share_inputs["seq_lens_decoder"],
+                share_inputs["stop_flags"],
+            )
+
         return None
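For context, below is a minimal host-side sketch of the per-request bookkeeping one might expect behind the limit_content_len GPU op, judging purely from its argument list in this diff. This is an assumption, not the real CUDA kernel: the diff does not show the op's implementation, the function name limit_content_len_sketch is hypothetical, and the exact semantics (how accept_tokens, max_think_len, and enable_thinking interact) are guessed. The idea it illustrates: once a request's thinking phase exceeds its budget without emitting think_end_id, inject the closing token so the remaining decode budget goes to the answer.

# Hypothetical sketch only; mirrors the share_inputs keys from the diff,
# but the real GPU op's behavior is not shown there.
import numpy as np

def limit_content_len_sketch(
    accept_tokens: np.ndarray,   # [batch, max_draft+1] tokens accepted this step
    think_end_id: int,           # token id that closes the thinking phase
    max_think_len: np.ndarray,   # [batch] per-request thinking-token budget
    step_idx: np.ndarray,        # [batch] tokens generated so far per request
    enable_thinking: np.ndarray, # [batch] bool: request still in thinking phase
    accept_num: np.ndarray,      # [batch] how many drafted tokens were accepted
    stop_flags: np.ndarray,      # [batch] bool: request already finished
) -> None:
    """Force the thinking phase closed once its budget is exhausted (in place)."""
    for i in range(accept_tokens.shape[0]):
        if stop_flags[i] or not enable_thinking[i]:
            continue
        # Scan only the tokens actually accepted by speculative verification.
        for j in range(int(accept_num[i])):
            if accept_tokens[i, j] == think_end_id:
                enable_thinking[i] = False          # thinking closed naturally
                break
            if step_idx[i] + j >= max_think_len[i]:
                accept_tokens[i, j] = think_end_id  # budget exhausted: inject closer
                accept_num[i] = j + 1               # drop any later draft tokens
                enable_thinking[i] = False
                break

Note that the real op also receives max_content_len, max_dec_len, limit_content_status, eos_token_ids, and seq_lens_decoder, so it evidently re-budgets the answer phase and updates step bookkeeping as well; those parts are omitted from this sketch.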