diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu index aa6235687..0e6e66d00 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu @@ -38,14 +38,20 @@ __device__ int64_t topp_sampling_kernel(const int64_t *candidate_ids, const int tid = threadIdx.x; float sum_scores = 0.0f; - float rand_top_p = curand_uniform(dev_curand_states + tid) * topp; + for (int i = 0; i < candidate_len; i++) { + sum_scores += candidate_scores[i]; + } + float tgt_topp = sum_scores < topp ? sum_scores : topp; + + sum_scores = 0.0f; + float rand_top_p = curand_uniform(dev_curand_states + tid) * tgt_topp; for (int i = 0; i < candidate_len; i++) { sum_scores += candidate_scores[i]; if (rand_top_p <= sum_scores) { - return candidate_ids[i]; + return candidate_ids[i]; } } - return candidate_ids[0]; + return candidate_ids[0]; } __global__ void setup_kernel(curandState_t *state, const uint64_t seed, diff --git a/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu b/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu index 687041e48..a9e66862f 100644 --- a/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu +++ b/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu @@ -467,6 +467,9 @@ __global__ void KeMatrixTopPBeamTopKFt( break; } } + if (top_p_value == 1.0 && actual_candidates_lens[token_id] == 0){ + actual_candidates_lens[token_id] = max_cadidate_len; + } } } diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 5621d0132..e65b888f3 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -95,6 +95,13 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))), # force disable default chunked prefill "FD_DISABLE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_DISABLE_CHUNKED_PREFILL", "0"))), + # For separate setting of sampling parameters for speculative decoding + "FD_SPECULATE_SAMPLING_TOP_P": lambda: ( + None if "FD_SPECULATE_SAMPLING_TOP_P" not in os.environ else float(os.environ["FD_SPECULATE_SAMPLING_TOP_P"]) + ), + "FD_SPECULATE_SAMPLING_TOP_K": lambda: ( + None if "FD_SPECULATE_SAMPLING_TOP_K" not in os.environ else float(os.environ["FD_SPECULATE_SAMPLING_TOP_K"]) + ), "FD_ENABLE_INTERNAL_ADAPTER": lambda: int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")), # LLMEngine recieve requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1 "FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"), diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 2614c4596..cc75ed96c 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -303,8 +303,16 @@ class MTPProposer(Proposer): ) # self.model_inputs["caches"] = self.cache_kvs # Inherit generation hyperparameters from the main model for consistency - self.model_inputs["top_p"] = self.target_model_inputs["top_p"] - self.model_inputs["top_k"] = self.target_model_inputs["top_k"] + self.model_inputs["top_p"] = ( + self.target_model_inputs["top_p"] + if envs.FD_SPECULATE_SAMPLING_TOP_P is None + else paddle.full_like(self.target_model_inputs["top_p"], envs.FD_SPECULATE_SAMPLING_TOP_P) + ) + self.model_inputs["top_k"] = ( + self.target_model_inputs["top_k"] + if envs.FD_SPECULATE_SAMPLING_TOP_K is None + else paddle.full_like(self.target_model_inputs["top_k"], envs.FD_SPECULATE_SAMPLING_TOP_K) + ) self.model_inputs["temperature"] = self.target_model_inputs["temperature"] self.model_inputs["eos_token_id"] = self.target_model_inputs["eos_token_id"] self.model_inputs["penalty_score"] = self.target_model_inputs["penalty_score"]