diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index ec7a1c072..9c4b8c9dc 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -275,6 +275,7 @@ class MTPProposer(Proposer): # self.model_inputs["caches"] = self.cache_kvs # Inherit generation hyperparameters from the main model for consistency self.model_inputs["top_p"] = self.main_model_inputs["top_p"] + self.model_inputs["top_k"] = self.main_model_inputs["top_k"] self.model_inputs["temperature"] = self.main_model_inputs["temperature"] self.model_inputs["eos_token_id"] = self.main_model_inputs["eos_token_id"] self.model_inputs["penalty_score"] = self.main_model_inputs["penalty_score"] @@ -528,6 +529,7 @@ class MTPProposer(Proposer): self.sampling_metadata = SamplingMetadata( temperature=self.model_inputs["temperature"], top_p=self.model_inputs["top_p"], + top_k=self.model_inputs["top_k"], step_idx=self.model_inputs["step_idx"], pre_token_ids=self.model_inputs["pre_ids"], frequency_penalties=self.model_inputs["frequency_score"],