mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 01:22:59 +08:00
[Feature][MTP]Support new mtp (#3656)
* update multi-draft-token strategy * fix format * support hybrid mtp with ngram speculative decoding method
This commit is contained in:
@@ -294,16 +294,24 @@ class SpeculativeConfig:
|
||||
self,
|
||||
args,
|
||||
):
|
||||
# speculative method, choose in [None, "ngram_match", "mtp"]
|
||||
self.method_list = ["ngram_match", "mtp"]
|
||||
self.mtp_strategy_list = ["default", "with_ngram"]
|
||||
|
||||
# speculative method, choose in [None, "ngram_match", "mtp", "hybrid_mtp_ngram"]
|
||||
self.method: Optional[str] = None
|
||||
# mtp strategy in mtp-method
|
||||
self.mtp_strategy = "default"
|
||||
# the max length of speculative tokens
|
||||
self.num_speculative_tokens: int = 1
|
||||
# the model runner step of draft model/mtp...
|
||||
self.num_model_steps: int = 1
|
||||
# the max length of candidate tokens for speculative method
|
||||
self.max_candidate_len: int = 5
|
||||
# the max length of verify window for speculative method
|
||||
self.verify_window: int = 2
|
||||
# ngram match
|
||||
self.max_ngram_size: int = 5
|
||||
self.min_ngram_size: int = 2
|
||||
# model for mtp/eagle/draft_model
|
||||
self.model: Optional[str] = None
|
||||
# quantization of model
|
||||
@@ -390,6 +398,33 @@ class SpeculativeConfig:
|
||||
logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
logger.info("=============================================================")
|
||||
|
||||
def check_legality_parameters(
    self,
) -> None:
    """Check the legality of parameters passed in from the command line.

    Nothing is validated when ``self.method`` is ``None``. Otherwise the
    method name, the two step/token counters and the mtp strategy are
    checked, and ``num_speculative_tokens`` is raised to
    ``num_model_steps`` for mtp-style methods when it is smaller.

    Raises:
        AssertionError: if ``method`` is not in ``method_list``, if
            ``num_speculative_tokens`` or ``num_model_steps`` is outside
            ``[1, 5]``, or if ``mtp_strategy`` is not in
            ``mtp_strategy_list``.
    """
    # NOTE(review): the view this block was recovered from had lost its
    # indentation; every check below is assumed to live under the
    # "method is not None" guard -- confirm against the real file.
    if self.method is None:
        return

    # Explicit raises instead of `assert` statements: asserts are removed
    # when Python runs with -O, which would silently skip this validation.
    # AssertionError is kept so existing callers see the same exception type.
    if self.method not in self.method_list:
        raise AssertionError(
            f"speculative method only support {self.method_list} now, but get {self.method}."
        )

    if not 1 <= self.num_speculative_tokens <= 5:
        raise AssertionError(
            f"num_speculative_tokens only support in range[1, 5], but get {self.num_speculative_tokens}."
        )
    if not 1 <= self.num_model_steps <= 5:
        raise AssertionError(
            f"num_model_steps only support in range[1, 5], but get {self.num_model_steps}."
        )

    # NOTE(review): "hybrid_mtp_ngram" is accepted here, but it only gets
    # past the method check above if method_list contains it -- verify.
    if self.method in ["mtp", "hybrid_mtp_ngram"]:
        if self.num_speculative_tokens < self.num_model_steps:
            logger.warning(
                f"Get num_model_steps > num_speculative_tokens. Reset num_speculative_tokens to {self.num_model_steps}"
            )
            # Keep the speculative token budget at least as large as the
            # number of model steps.
            self.num_speculative_tokens = self.num_model_steps

    if self.mtp_strategy not in self.mtp_strategy_list:
        raise AssertionError(
            f"mtp_strategy_list only support {self.mtp_strategy_list}, but get {self.mtp_strategy}"
        )
|
||||
|
||||
def __str__(self) -> str:
    """Return the string produced by ``self.to_json_string()``."""
    json_text = self.to_json_string()
    return json_text
|
||||
|
||||
|
Reference in New Issue
Block a user