[Feature][MTP]Support new mtp (#3656)

* update multi-draft-token strategy

* fix format

* support hybrid mtp with ngram speculative decoding method
This commit is contained in:
freeliuzc
2025-08-27 19:38:26 +08:00
committed by GitHub
parent 62659a7a73
commit c753f1fc9e
20 changed files with 501 additions and 579 deletions

View File

@@ -294,16 +294,24 @@ class SpeculativeConfig:
self,
args,
):
# speculative method, choose in [None, "ngram_match", "mtp"]
self.method_list = ["ngram_match", "mtp"]
self.mtp_strategy_list = ["default", "with_ngram"]
# speculative method, choose in [None, "ngram_match", "mtp", "hybrid_mtp_ngram"]
self.method: Optional[str] = None
# mtp strategy in mtp-method
self.mtp_strategy = "default"
# the max length of speculative tokens
self.num_speculative_tokens: int = 1
# the model runner step of draft model/mtp...
self.num_model_steps: int = 1
# the max length of candidate tokens for speculative method
self.max_candidate_len: int = 5
# the max length of verify window for speculative method
self.verify_window: int = 2
# ngram match
self.max_ngram_size: int = 5
self.min_ngram_size: int = 2
# model for mtp/eagle/draft_model
self.model: Optional[str] = None
# quantization of model
@@ -390,6 +398,33 @@ class SpeculativeConfig:
logger.info("{:<20}:{:<6}{}".format(k, "", v))
logger.info("=============================================================")
def check_legality_parameters(
self,
) -> None:
"""Check the legality of parameters passed in from the command line"""
if self.method is not None:
assert (
self.method in self.method_list
), f"speculative method only support {self.method_list} now, but get {self.method}."
assert (
self.num_speculative_tokens >= 1 and self.num_speculative_tokens <= 5
), f"num_speculative_tokens only support in range[1, 5], but get {self.num_speculative_tokens}."
assert (
self.num_model_steps >= 1 and self.num_model_steps <= 5
), f"num_model_steps only support in range[1, 5], but get {self.num_model_steps}."
if self.method in ["mtp", "hybrid_mtp_ngram"]:
if self.num_speculative_tokens < self.num_model_steps:
logger.warning(
f"Get num_model_steps > num_speculative_tokens. Reset num_speculative_tokens to {self.num_model_steps}"
)
self.num_speculative_tokens = self.num_model_steps
assert (
self.mtp_strategy in self.mtp_strategy_list
), f"mtp_strategy_list only support {self.mtp_strategy_list}, but get {self.mtp_strategy}"
def __str__(self) -> str:
return self.to_json_string()