Fix Speculative Config bug (#3049)

* fix speculative bug

* fix rl
This commit is contained in:
YuanRisheng
2025-07-29 10:50:48 +08:00
committed by GitHub
parent f2a528f9ae
commit 1a815b7a2a
6 changed files with 21 additions and 58 deletions

View File

@@ -254,7 +254,7 @@ class SpeculativeConfig:
# ngram match
self.max_ngram_size: int = 5
# model for mtp/eagle/draft_model
self.model_name_or_path: Optional[str] = None
self.model: Optional[str] = None
# quantization of model
self.quantization: Optional[str] = None
# allocate more blocks to prevent mtp from finishing the block earlier than the main model
@@ -273,21 +273,11 @@ class SpeculativeConfig:
self.benchmark_mode: bool = False
self.num_extra_cache_layer = 0
# TODO(YuanRisheng): The server argument names differ from the SpeculativeConfig attribute names.
# We temporarily add the name map here and will delete it in the future.
name_map = {
"speculative_method": "method",
"speculative_max_draft_token_num": "num_speculative_tokens",
"speculative_model_name_or_path": "model_name_or_path",
"speculative_model_quantization": "quantization",
"speculative_benchmark_mode": "benchmark_mode",
}
for key, value in args.items():
if key in name_map.keys() and hasattr(self, name_map[key]):
if key == "speculative_benchmark_mode":
value = True if value.lower() == "true" else False
setattr(self, name_map[key], value)
if hasattr(self, key):
setattr(self, key, value)
self.read_model_config()
self.reset()
@@ -299,11 +289,11 @@ class SpeculativeConfig:
if not self.enabled_speculative_decoding():
return
self.is_unified_ckpt = check_unified_ckpt(self.model_name_or_path)
if self.model_name_or_path is None:
self.is_unified_ckpt = check_unified_ckpt(self.model)
if self.model is None:
return
self.config_path = os.path.join(self.model_name_or_path, "config.json")
self.config_path = os.path.join(self.model, "config.json")
if os.path.exists(self.config_path):
self.model_config = json.load(open(self.config_path, "r", encoding="utf-8"))