Fix Speculative Config bug (#3049)

* fix speculative bug

* fix rl
Author: YuanRisheng
Date: 2025-07-29 10:50:48 +08:00
Committed by: GitHub
Parent: f2a528f9ae
Commit: 1a815b7a2a

6 changed files with 21 additions and 58 deletions

View File

@@ -24,7 +24,7 @@ import numpy as np
 import paddle
 from fastdeploy.cache_manager.cache_data import CacheStatus
-from fastdeploy.engine.config import SpeculativeConfig
+from fastdeploy.config import SpeculativeConfig
 from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
 from fastdeploy.model_executor.ops.gpu import (
     cuda_host_alloc,
@@ -114,7 +114,7 @@ class CacheTransferManager:
         self.cpu_cache_kvs = {}
         self.gpu_cache_k_tensors = []
         self.gpu_cache_v_tensors = []
-        self.speculative_config = SpeculativeConfig(**args.speculative_config)
+        self.speculative_config = SpeculativeConfig(args.speculative_config)
         self.num_extra_layers = self.speculative_config.num_extra_cache_layer
         self.num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
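The call-site change above pairs with the import change: the replacement SpeculativeConfig from fastdeploy.config takes the argument dict itself, so the old **-splat would now raise a TypeError against its constructor. A minimal sketch of the convention (DictConfig is a hypothetical stand-in, not the FastDeploy class):

class DictConfig:
    def __init__(self, args: dict):
        self.num_extra_cache_layer = 0   # defaults first
        for key, value in args.items():
            if hasattr(self, key):       # known fields are overridden, the rest ignored
                setattr(self, key, value)

DictConfig({"num_extra_cache_layer": 1})      # new convention: pass the dict itself
# DictConfig(**{"num_extra_cache_layer": 1})  # the old splat form would raise
#                                             # TypeError: unexpected keyword argument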

View File

@@ -254,7 +254,7 @@ class SpeculativeConfig:
         # ngram match
         self.max_ngram_size: int = 5
         # model for mtp/eagle/draft_model
-        self.model_name_or_path: Optional[str] = None
+        self.model: Optional[str] = None
         # quantization of model
         self.quantization: Optional[str] = None
         # allocate more blocks to prevent mtp from finishing the block earlier than the main model
@@ -273,21 +273,11 @@ class SpeculativeConfig:
         self.benchmark_mode: bool = False
         self.num_extra_cache_layer = 0
-        # TODO(YuanRisheng): The name of the server args is different from the name of the SpeculativeConfig.
-        # We temporarily add the name map here and will delete it in the future.
-        name_map = {
-            "speculative_method": "method",
-            "speculative_max_draft_token_num": "num_speculative_tokens",
-            "speculative_model_name_or_path": "model_name_or_path",
-            "speculative_model_quantization": "quantization",
-            "speculative_benchmark_mode": "benchmark_mode",
-        }
         for key, value in args.items():
-            if key in name_map.keys() and hasattr(self, name_map[key]):
-                if key == "speculative_benchmark_mode":
-                    value = True if value.lower() == "true" else False
-                setattr(self, name_map[key], value)
+            if hasattr(self, key):
+                setattr(self, key, value)
         self.read_model_config()
         self.reset()
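With the worker now receiving a JSON dict whose keys already use the config's own attribute names, the temporary name map and the string-to-boolean special case both become dead weight: json.loads has already produced real Python booleans by the time this loop runs. A short, self-contained illustration (Cfg is a hypothetical stand-in):

import json

args = json.loads('{"method": "ngram", "benchmark_mode": true}')
assert args["benchmark_mode"] is True   # already a bool; no .lower() check needed

class Cfg:
    def __init__(self, args: dict):
        self.method = None               # defaults
        self.benchmark_mode = False
        for key, value in args.items():
            if hasattr(self, key):
                setattr(self, key, value)

assert Cfg(args).benchmark_mode is True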
@@ -299,11 +289,11 @@ class SpeculativeConfig:
         if not self.enabled_speculative_decoding():
             return
-        self.is_unified_ckpt = check_unified_ckpt(self.model_name_or_path)
-        if self.model_name_or_path is None:
+        self.is_unified_ckpt = check_unified_ckpt(self.model)
+        if self.model is None:
             return
-        self.config_path = os.path.join(self.model_name_or_path, "config.json")
+        self.config_path = os.path.join(self.model, "config.json")
         if os.path.exists(self.config_path):
             self.model_config = json.load(open(self.config_path, "r", encoding="utf-8"))

View File

@@ -1081,11 +1081,7 @@ class LLMEngine:
             f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
             f" --quantization {self.cfg.model_config.quantization}"
             f" --ori_vocab_size {ori_vocab_size}"
-            f" --speculative_method {self.cfg.speculative_config.method}"
-            f" --speculative_max_draft_token_num {self.cfg.speculative_config.num_speculative_tokens}"
-            f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}"
-            f" --speculative_model_quantization {self.cfg.speculative_config.quantization}"
-            f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}"
+            f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'"
             f" --graph_optimization_config '{self.cfg.graph_optimization_config.to_json_string()}'"
             f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
             f" --load_strategy {self.cfg.model_config.load_strategy}"
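On the producer side, five flags collapse into one: the engine serializes the whole speculative config and wraps it in single quotes so the JSON (with its double quotes and spaces) survives as one shell argument. A rough sketch of what to_json_string is assumed to emit (SpecCfgSketch and its fields are illustrative only):

import json

class SpecCfgSketch:
    def __init__(self):
        self.method = "mtp"                  # illustrative values
        self.num_speculative_tokens = 1

    def to_json_string(self) -> str:
        # assumption: the real method dumps the config's fields as JSON
        return json.dumps(self.__dict__)

arg = f" --speculative_config '{SpecCfgSketch().to_json_string()}'"
print(arg)  # --speculative_config '{"method": "mtp", "num_speculative_tokens": 1}'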

View File

@@ -83,10 +83,11 @@ class RolloutModelConfig:
         self.pad_token_id = pad_token_id
         self.eos_tokens_lens = eos_tokens_lens
         self.enable_chunked_prefill = enable_chunked_prefill
-        self.speculative_method = speculative_method
-        self.speculative_max_draft_token_num = speculative_max_draft_token_num
-        self.speculative_model_name_or_path = speculative_model_name_or_path
-        self.speculative_model_quantization = speculative_model_quantization
+        self.speculative_config = {}
+        self.speculative_config["method"] = speculative_method
+        self.speculative_config["max_draft_token_num"] = speculative_max_draft_token_num
+        self.speculative_config["model"] = speculative_model_name_or_path
+        self.speculative_config["quantization"] = speculative_model_quantization
         self.max_num_batched_tokens = max_num_batched_tokens
         self.enable_prefix_caching = enable_prefix_caching
         self.splitwise_role = splitwise_role
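One point worth double-checking in this RL path: the consumer's loop shown earlier applies a key only when hasattr(self, key) holds and drops unknown keys without warning, so each key built here ("method", "max_draft_token_num", "model", "quantization") must exactly match a SpeculativeConfig attribute name. A tiny demonstration of the silent drop (the attribute set below is hypothetical):

class Cfg:
    def __init__(self, args: dict):
        self.method = None
        self.num_speculative_tokens = 1   # hypothetical: no max_draft_token_num here
        for key, value in args.items():
            if hasattr(self, key):
                setattr(self, key, value)

c = Cfg({"method": "mtp", "max_draft_token_num": 3})
assert c.method == "mtp"
assert c.num_speculative_tokens == 1   # mismatched key was ignored, not applied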

View File

@@ -73,7 +73,7 @@ class MTPProposer(Proposer):
         self.model_config.architectures[0] = "Ernie4_5_MTPForCausalLM"
         self.speculative_config.sharing_model = main_model
         self.model_config.num_hidden_layers = 1
-        self.parallel_config.model_name_or_path = self.speculative_config.model_name_or_path
+        self.parallel_config.model_name_or_path = self.speculative_config.model
         self.model_config.pretrained_config.prefix_name = "ernie.mtp_block"
         if self.speculative_config.quantization != "":
             self.model_config.quantization = self.speculative_config.quantization

View File

@@ -41,7 +41,7 @@ from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
 from fastdeploy.inter_communicator import IPCSignal
 from fastdeploy.model_executor.layers.quantization import get_quantization_config
 from fastdeploy.platforms import current_platform
-from fastdeploy.utils import get_logger, none_or_str
+from fastdeploy.utils import get_logger
 from fastdeploy.worker.worker_base import WorkerBase

 logger = get_logger("worker_process", "worker_process.log")
@@ -476,34 +476,10 @@ def parse_args():
         help="enable chunked prefill",
     )
     parser.add_argument(
-        "--speculative_method",
+        "--speculative_config",
+        type=json.loads,
         default=None,
-        type=none_or_str,
-        choices=[
-            None,
-            "ngram",
-            "mtp",
-        ],
-    )
-    parser.add_argument(
-        "--speculative_max_draft_token_num",
-        default=1,
-        type=int,
-    )
-    parser.add_argument(
-        "--speculative_model_name_or_path",
-        default="",
-        type=str,
-    )
-    parser.add_argument(
-        "--speculative_model_quantization",
-        default="WINT8",
-        type=str,
-    )
-    parser.add_argument(
-        "--speculative_benchmark_mode",
-        default="False",
-        type=str,
+        help="Configuration of SpeculativeConfig.",
     )
     parser.add_argument(
         "--max_num_batched_tokens",
@@ -607,7 +583,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     model_config = ModelConfig(vars(args))
     device_config = DeviceConfig(vars(args))
     decoding_config = DecodingConfig(vars(args))
-    speculative_config = SpeculativeConfig(vars(args))
+    speculative_config = SpeculativeConfig(args.speculative_config)
     parallel_config = ParallelConfig(vars(args))
     cache_config = CacheConfig(vars(args))
     parallel_config.tensor_parallel_size = args.tensor_parallel_size
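Taken together, the last two hunks close the loop: argparse turns the quoted JSON into a dict via json.loads, and initialize_fd_config now hands SpeculativeConfig only that nested dict rather than the whole flattened namespace. Under the new generic setattr loop, passing vars(args) would have let an unrelated top-level flag such as --quantization overwrite the speculative config's field of the same name. A hedged end-to-end sketch (flag names are from the diff; values are illustrative):

import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("--quantization", default="wint8")  # top-level model flag
parser.add_argument("--speculative_config", type=json.loads, default=None)

args = parser.parse_args(
    ["--quantization", "wint8", "--speculative_config", '{"quantization": "wint4"}']
)

# vars(args) mixes both levels; the nested dict keeps the draft-model value separate.
assert vars(args)["quantization"] == "wint8"
assert args.speculative_config["quantization"] == "wint4"

spec_args = args.speculative_config or {}   # what SpeculativeConfig now receives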