mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-27 04:46:16 +08:00
@@ -254,7 +254,7 @@ class SpeculativeConfig:
|
||||
# ngram match
|
||||
self.max_ngram_size: int = 5
|
||||
# model for mtp/eagle/draft_model
|
||||
self.model_name_or_path: Optional[str] = None
|
||||
self.model: Optional[str] = None
|
||||
# quantization of model
|
||||
self.quantization: Optional[str] = None
|
||||
# allocate more blocks to prevent mtp from finishing the block earlier than the main model
|
||||
@@ -273,21 +273,11 @@ class SpeculativeConfig:
|
||||
self.benchmark_mode: bool = False
|
||||
|
||||
self.num_extra_cache_layer = 0
|
||||
# TODO(YuanRisheng): The server argument names differ from the SpeculativeConfig attribute names.
|
||||
# We temporarily add the name map here and will delete it in the future.
|
||||
name_map = {
|
||||
"speculative_method": "method",
|
||||
"speculative_max_draft_token_num": "num_speculative_tokens",
|
||||
"speculative_model_name_or_path": "model_name_or_path",
|
||||
"speculative_model_quantization": "quantization",
|
||||
"speculative_benchmark_mode": "benchmark_mode",
|
||||
}
|
||||
|
||||
for key, value in args.items():
|
||||
if key in name_map.keys() and hasattr(self, name_map[key]):
|
||||
if key == "speculative_benchmark_mode":
|
||||
value = True if value.lower() == "true" else False
|
||||
setattr(self, name_map[key], value)
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
self.read_model_config()
|
||||
self.reset()
|
||||
|
||||
@@ -299,11 +289,11 @@ class SpeculativeConfig:
|
||||
if not self.enabled_speculative_decoding():
|
||||
return
|
||||
|
||||
self.is_unified_ckpt = check_unified_ckpt(self.model_name_or_path)
|
||||
if self.model_name_or_path is None:
|
||||
self.is_unified_ckpt = check_unified_ckpt(self.model)
|
||||
if self.model is None:
|
||||
return
|
||||
|
||||
self.config_path = os.path.join(self.model_name_or_path, "config.json")
|
||||
self.config_path = os.path.join(self.model, "config.json")
|
||||
if os.path.exists(self.config_path):
|
||||
self.model_config = json.load(open(self.config_path, "r", encoding="utf-8"))
|
||||
|
||||
|
Reference in New Issue
Block a user