diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc index 59df0f0f2..896ebb95a 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc @@ -26,7 +26,8 @@ namespace api = baidu::xpu::api; -void SpeculateVerify(const paddle::Tensor &accept_tokens, +void SpeculateVerify(const paddle::Tensor &sampled_token_ids, + const paddle::Tensor &accept_tokens, const paddle::Tensor &accept_num, const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags, @@ -48,17 +49,12 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens, bool enable_topp, bool benchmark_mode, bool accept_all_drafts) { - // TODO(chenhuan09):support accept_all_drafts auto bsz = accept_tokens.shape()[0]; int real_bsz = seq_lens_this_time.shape()[0]; auto max_draft_tokens = draft_tokens.shape()[1]; auto end_length = end_tokens.shape()[0]; auto max_candidate_len = verify_tokens.shape()[1]; - constexpr int BlockSize = 512; - // set topp_seed if needed - const paddle::optional &topp_seed = nullptr; - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); api::Context *ctx = @@ -69,16 +65,16 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens, xpu_ctx_flag = false; } - // phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - // auto dev_ctx = - // paddle::experimental::DeviceContextPool::Instance().Get(place); auto - // xpu_ctx = static_cast(dev_ctx); - bool use_topk = false; char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK"); if (env_var) { use_topk = static_cast(std::stoi(env_var)); } + bool use_target_sampling = false; + char *env_var_1 = getenv("SPECULATE_VERIFY_USE_TARGET_SAMPLING"); + if (env_var_1) { + use_target_sampling = static_cast(std::stoi(env_var_1)); + } bool prefill_one_step_stop = false; if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) { // std::cout << "Your PATH is: " << env_p << '\n'; @@ -108,10 +104,12 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens, auto dev_curand_states = !xpu_ctx_flag ? dev_curand_states_cpu.data() : dev_curand_states_xpu; + int ret; if (use_topk) { if (enable_topp) { - baidu::xpu::api::plugin::speculate_verify( + ret = baidu::xpu::api::plugin::speculate_verify( ctx, + sampled_token_ids.data(), const_cast(accept_tokens.data()), const_cast(accept_num.data()), const_cast(step_idx.data()), @@ -137,10 +135,14 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens, max_candidate_len, verify_window, prefill_one_step_stop, - benchmark_mode); + benchmark_mode, + accept_all_drafts, + use_target_sampling); + PD_CHECK(ret == 0, "speculate_verify failed."); } else { - baidu::xpu::api::plugin::speculate_verify( + ret = baidu::xpu::api::plugin::speculate_verify( ctx, + sampled_token_ids.data(), const_cast(accept_tokens.data()), const_cast(accept_num.data()), const_cast(step_idx.data()), @@ -166,12 +168,16 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens, max_candidate_len, verify_window, prefill_one_step_stop, - benchmark_mode); + benchmark_mode, + accept_all_drafts, + use_target_sampling); } + PD_CHECK(ret == 0, "speculate_verify failed."); } else { if (enable_topp) { - baidu::xpu::api::plugin::speculate_verify( + ret = baidu::xpu::api::plugin::speculate_verify( ctx, + sampled_token_ids.data(), const_cast(accept_tokens.data()), const_cast(accept_num.data()), const_cast(step_idx.data()), @@ -197,10 +203,14 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens, max_candidate_len, verify_window, prefill_one_step_stop, - benchmark_mode); + benchmark_mode, + accept_all_drafts, + use_target_sampling); + PD_CHECK(ret == 0, "speculate_verify failed."); } else { - baidu::xpu::api::plugin::speculate_verify( + ret = baidu::xpu::api::plugin::speculate_verify( ctx, + sampled_token_ids.data(), const_cast(accept_tokens.data()), const_cast(accept_num.data()), const_cast(step_idx.data()), @@ -226,18 +236,25 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens, max_candidate_len, verify_window, prefill_one_step_stop, - benchmark_mode); + benchmark_mode, + accept_all_drafts, + use_target_sampling); } + PD_CHECK(ret == 0, "speculate_verify failed."); + } + if (draft_tokens.is_cpu()) { + delete ctx; } } PD_BUILD_STATIC_OP(speculate_verify) - .Inputs({"accept_tokens", + .Inputs({"sampled_token_ids", + "accept_tokens", "accept_num", "step_idx", - "stop_flags", "seq_lens_encoder", "seq_lens_decoder", + "stop_flags", "draft_tokens", "seq_lens_this_time", "verify_tokens", diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc index 75bc176a0..353e686da 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc @@ -248,7 +248,8 @@ std::vector TopPCandidates( int candidates_len, int max_seq_len); -void SpeculateVerify(const paddle::Tensor& accept_tokens, +void SpeculateVerify(const paddle::Tensor& sampled_token_ids, + const paddle::Tensor& accept_tokens, const paddle::Tensor& accept_num, const paddle::Tensor& step_idx, const paddle::Tensor& stop_flags, @@ -1013,6 +1014,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("speculate_verify", &SpeculateVerify, + py::arg("sampled_token_ids"), py::arg("accept_tokens"), py::arg("accept_num"), py::arg("step_idx"), diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h index a00fa46ec..ef774095f 100644 --- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h +++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h @@ -384,6 +384,7 @@ DLL_EXPORT int mtp_free_and_dispatch_block(Context* ctx, template DLL_EXPORT int speculate_verify(Context* ctx, + const int64_t* sampled_token_ids, int64_t* accept_tokens, int* accept_num, int64_t* step_idx, @@ -409,7 +410,9 @@ DLL_EXPORT int speculate_verify(Context* ctx, const int max_candidate_len, const int verify_window, const bool prefill_one_step_stop, - const bool benchmark_mode); + const bool benchmark_mode, + const bool accept_all_drafts, + const bool use_target_sampling); DLL_EXPORT int speculate_clear_accept_nums(Context* ctx, int* accept_num, diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu index 26ad38c9f..350741924 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu @@ -72,7 +72,7 @@ static __device__ inline unsigned int xorwow(unsigned int &state) { state ^= state >> 13; return state; } -typedef uint32_t curandStatePhilox4_32_10_t; + __device__ int64_t topp_sampling_kernel(__global_ptr__ const int64_t *candidate_ids, __global_ptr__ const float *candidate_scores, @@ -91,9 +91,10 @@ topp_sampling_kernel(__global_ptr__ const int64_t *candidate_ids, } return candidate_ids[0]; } -#define sm_size 1024 + template __global__ void speculate_verify( + const int64_t *sampled_token_ids, int64_t *accept_tokens, // out [real_bsz, max_draft_tokens], 输出最终接收的 // token(通过验证或采样) int *accept_num, // out [real_bsz], 每个序列最终接受的 token @@ -139,9 +140,10 @@ __global__ void speculate_verify( // 的最大候选数(用于验证或采样) const int verify_window, // scalar, TopK 验证窗口(允许连续 top1 匹配次数) const bool prefill_one_step_stop, - const bool benchmark_mode) { - const int cid = core_id(); - const int64_t tid = cluster_id() * core_num() + core_id(); + const bool benchmark_mode, + const bool accept_all_drafts, + const bool use_target_sampling) { + const int64_t tid = core_id() * cluster_num() + cluster_id(); const int64_t nthreads = cluster_num() * core_num(); for (int64_t bid = tid; bid < real_bsz; bid += nthreads) { int stop_flag_now_int = 0; @@ -158,6 +160,7 @@ __global__ void speculate_verify( verify_tokens + start_token_id * max_candidate_len; auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens; auto *actual_candidate_len_now = actual_candidate_len + start_token_id; + auto *sampled_token_id_now = sampled_token_ids + start_token_id; int i = 0; // printf("seq_lens_this_time[%d]-1: %d \n",bid, // seq_lens_this_time[bid]-1); @@ -168,7 +171,43 @@ __global__ void speculate_verify( if (seq_lens_encoder[bid] != 0) { break; } - if (USE_TOPK) { + if (accept_all_drafts) { + // accept all draft tokens + step_idx[bid]++; + auto accept_token = draft_tokens_now[i + 1]; + accept_tokens[bid * max_draft_tokens + i] = accept_token; + + if (is_in_end(accept_token, end_tokens, end_length) || + step_idx[bid] >= max_dec_len[bid]) { + stop_flags[bid] = true; + stop_flag_now_int = 1; + if (step_idx[bid] >= max_dec_len[bid]) + accept_tokens[bid * max_draft_tokens + i] = end_tokens[0]; + break; + } else { + accept_num_now++; + } + continue; + } + if (use_target_sampling) { + if (sampled_token_id_now[i] == draft_tokens_now[i + 1]) { + step_idx[bid]++; + auto accept_token = draft_tokens_now[i + 1]; + accept_tokens[bid * max_draft_tokens + i] = accept_token; + if (is_in_end(accept_token, end_tokens, end_length) || + step_idx[bid] >= max_dec_len[bid]) { + stop_flags[bid] = true; + stop_flag_now_int = 1; + if (step_idx[bid] >= max_dec_len[bid]) + accept_tokens[bid * max_draft_tokens + i] = end_tokens[0]; + break; + } else { + accept_num_now++; + } + } else { + break; + } + } else if (USE_TOPK) { if (verify_tokens_now[i * max_candidate_len] == draft_tokens_now[i + 1]) { // accept_num_now++; @@ -274,7 +313,9 @@ __global__ void speculate_verify( __global_ptr__ const float *verify_scores_now = verify_scores + start_token_id * max_candidate_len; step_idx[bid]++; - if (ENABLE_TOPP) { + if (use_target_sampling) { + accept_token = sampled_token_id_now[i]; + } else if (ENABLE_TOPP) { auto actual_candidate_len_value = actual_candidate_len_now[i] > max_candidate_len ? max_candidate_len @@ -306,7 +347,8 @@ __global__ void speculate_verify( } #define SPECULATE_VERIFY_INSTANTIATE(ENABLE_TOPP, USE_TOPK) \ template __global__ void speculate_verify( \ - int64_t * accept_tokens, \ + const int64_t *sampled_token_ids, \ + int64_t *accept_tokens, \ int *accept_num, \ int64_t *step_idx, \ bool *stop_flags, \ @@ -331,7 +373,9 @@ __global__ void speculate_verify( int max_candidate_len, \ int verify_window, \ bool prefill_one_step_stop, \ - bool benchmark_mode); + bool benchmark_mode, \ + bool accept_all_drafts, \ + bool use_target_sampling); SPECULATE_VERIFY_INSTANTIATE(true, true) SPECULATE_VERIFY_INSTANTIATE(true, false) SPECULATE_VERIFY_INSTANTIATE(false, true) diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp index 3989ce8de..20457f0f1 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp @@ -24,6 +24,7 @@ typedef uint32_t curandStatePhilox4_32_10_t; template __attribute__((global)) void speculate_verify( + const int64_t *sampled_token_ids, int64_t *accept_tokens, int *accept_num, int64_t *step_idx, @@ -49,7 +50,9 @@ __attribute__((global)) void speculate_verify( const int max_candidate_len, const int verify_window, const bool prefill_one_step_stop, - const bool benchmark_mode); + const bool benchmark_mode, + const bool accept_all_drafts, + const bool use_target_sampling); } // namespace plugin } // namespace xpu3 @@ -113,6 +116,7 @@ static int64_t topp_sampling_kernel(const int64_t *candidate_ids, template static int cpu_wrapper(Context *ctx, + const int64_t *sampled_token_ids, int64_t *accept_tokens, int *accept_num, int64_t *step_idx, @@ -138,7 +142,9 @@ static int cpu_wrapper(Context *ctx, const int max_candidate_len, const int verify_window, const bool prefill_one_step_stop, - const bool benchmark_mode) { + const bool benchmark_mode, + const bool accept_all_drafts, + const bool use_target_sampling) { for (int bid = 0; bid < real_bsz; ++bid) { // verify and set stop flags int accept_num_now = 1; @@ -157,6 +163,7 @@ static int cpu_wrapper(Context *ctx, verify_tokens + start_token_id * max_candidate_len; auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens; auto *actual_candidate_len_now = actual_candidate_len + start_token_id; + auto *sampled_token_id_now = sampled_token_ids + start_token_id; int i = 0; // printf("seq_lens_this_time[%d]-1: %d \n",bid, @@ -168,7 +175,43 @@ static int cpu_wrapper(Context *ctx, if (seq_lens_encoder[bid] != 0) { break; } - if (USE_TOPK) { + if (accept_all_drafts) { + // accept all draft tokens + step_idx[bid]++; + auto accept_token = draft_tokens_now[i + 1]; + accept_tokens[bid * max_draft_tokens + i] = accept_token; + + if (is_in_end(accept_token, end_tokens, end_length) || + step_idx[bid] >= max_dec_len[bid]) { + stop_flags[bid] = true; + stop_flag_now_int = 1; + if (step_idx[bid] >= max_dec_len[bid]) + accept_tokens[bid * max_draft_tokens + i] = end_tokens[0]; + break; + } else { + accept_num_now++; + } + continue; + } + if (use_target_sampling) { + if (sampled_token_id_now[i] == draft_tokens_now[i + 1]) { + step_idx[bid]++; + auto accept_token = draft_tokens_now[i + 1]; + accept_tokens[bid * max_draft_tokens + i] = accept_token; + if (is_in_end(accept_token, end_tokens, end_length) || + step_idx[bid] >= max_dec_len[bid]) { + stop_flags[bid] = true; + stop_flag_now_int = 1; + if (step_idx[bid] >= max_dec_len[bid]) + accept_tokens[bid * max_draft_tokens + i] = end_tokens[0]; + break; + } else { + accept_num_now++; + } + } else { + break; + } + } else if (USE_TOPK) { if (verify_tokens_now[i * max_candidate_len] == draft_tokens_now[i + 1]) { // accept_num_now++; @@ -270,7 +313,9 @@ static int cpu_wrapper(Context *ctx, const float *verify_scores_now = verify_scores + start_token_id * max_candidate_len; step_idx[bid]++; - if (ENABLE_TOPP) { + if (use_target_sampling) { + accept_token = sampled_token_id_now[i]; + } else if (ENABLE_TOPP) { auto actual_candidate_len_value = actual_candidate_len_now[i] > max_candidate_len ? max_candidate_len @@ -307,6 +352,7 @@ static int cpu_wrapper(Context *ctx, template static int xpu3_wrapper(Context *ctx, + const int64_t *sampled_token_ids, int64_t *accept_tokens, int *accept_num, int64_t *step_idx, @@ -332,10 +378,13 @@ static int xpu3_wrapper(Context *ctx, const int max_candidate_len, const int verify_window, const bool prefill_one_step_stop, - const bool benchmark_mode) { + const bool benchmark_mode, + const bool accept_all_drafts, + const bool use_target_sampling) { using XPU_INT64 = typename XPUIndexType::type; xpu3::plugin::speculate_verify - <<<1, 64, ctx->xpu_stream>>>( + <<ncluster(), 64, ctx->xpu_stream>>>( + reinterpret_cast(sampled_token_ids), reinterpret_cast(accept_tokens), accept_num, reinterpret_cast(step_idx), @@ -361,11 +410,14 @@ static int xpu3_wrapper(Context *ctx, max_candidate_len, verify_window, prefill_one_step_stop, - benchmark_mode); + benchmark_mode, + accept_all_drafts, + use_target_sampling); return api::SUCCESS; } template int speculate_verify(Context *ctx, + const int64_t *sampled_token_ids, int64_t *accept_tokens, int *accept_num, int64_t *step_idx, @@ -391,10 +443,13 @@ int speculate_verify(Context *ctx, const int max_candidate_len, const int verify_window, const bool prefill_one_step_stop, - const bool benchmark_mode) { + const bool benchmark_mode, + const bool accept_all_drafts, + const bool use_target_sampling) { WRAPPER_CHECK_CTX(ctx); WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_verify", int64_t); - WRAPPER_DUMP_PARAM3(ctx, accept_tokens, accept_num, step_idx); + WRAPPER_DUMP_PARAM4( + ctx, sampled_token_ids, accept_tokens, accept_num, step_idx); WRAPPER_DUMP_PARAM6(ctx, stop_flags, seq_lens_encoder, @@ -421,6 +476,7 @@ int speculate_verify(Context *ctx, verify_window, prefill_one_step_stop, benchmark_mode); + WRAPPER_DUMP_PARAM2(ctx, accept_all_drafts, use_target_sampling); WRAPPER_DUMP(ctx); WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_draft_tokens, accept_tokens); WRAPPER_CHECK_PTR(ctx, int, real_bsz, accept_num); @@ -454,6 +510,7 @@ int speculate_verify(Context *ctx, if (ctx->dev().type() == api::kCPU) { return cpu_wrapper(ctx, + sampled_token_ids, accept_tokens, accept_num, step_idx, @@ -479,10 +536,13 @@ int speculate_verify(Context *ctx, max_candidate_len, verify_window, prefill_one_step_stop, - benchmark_mode); + benchmark_mode, + accept_all_drafts, + use_target_sampling); } if (ctx->dev().type() == api::kXPU3) { return xpu3_wrapper(ctx, + sampled_token_ids, accept_tokens, accept_num, step_idx, @@ -508,7 +568,9 @@ int speculate_verify(Context *ctx, max_candidate_len, verify_window, prefill_one_step_stop, - benchmark_mode); + benchmark_mode, + accept_all_drafts, + use_target_sampling); } WRAPPER_UNIMPLEMENTED(ctx); } @@ -517,6 +579,7 @@ int speculate_verify(Context *ctx, template int \ baidu::xpu::api::plugin::speculate_verify( \ baidu::xpu::api::Context *, /* xpu_ctx */ \ + const int64_t *, /* sampled_token_ids */ \ int64_t *, /* accept_tokens */ \ int *, /* accept_num */ \ int64_t *, /* step_idx */ \ @@ -541,8 +604,11 @@ int speculate_verify(Context *ctx, int, /* max_seq_len */ \ int, /* max_candidate_len */ \ int, /* verify_window */ \ - bool, \ - bool); /* prefill_one_step_stop */ + bool, /* prefill_one_step_stop */ \ + bool, /* benchmark_mode */ \ + bool, /* accept_all_drafts */ \ + bool /* use_target_sampling */ \ + ); INSTANTIATE_SPECULATE_VERIFY(false, false) INSTANTIATE_SPECULATE_VERIFY(false, true) diff --git a/custom_ops/xpu_ops/test/test_speculate_verify.py b/custom_ops/xpu_ops/test/test_speculate_verify.py index 17733a6e2..d08a8e187 100644 --- a/custom_ops/xpu_ops/test/test_speculate_verify.py +++ b/custom_ops/xpu_ops/test/test_speculate_verify.py @@ -13,11 +13,10 @@ # limitations under the License. import random +import unittest from typing import List import numpy as np - -# tests/speculate_verify.py import paddle from fastdeploy.model_executor.ops.xpu import speculate_verify @@ -25,53 +24,31 @@ from fastdeploy.model_executor.ops.xpu import speculate_verify def topp_sampling_kernel(candidate_ids, candidate_scores, curand_value, candidate_len, topp, tid=0): """ - Python 仿真版 Top-p 样本选择函数。 + Python simulation version of the Top-p sampling function. - 参数: - - candidate_ids: [candidate_len] int64 array,候选 token - - candidate_scores: [candidate_len] float32 array,对应概率 - - curand_value: float,范围在 [0,1),模拟 GPU 中的 curand_uniform - - candidate_len: int,候选个数 - - topp: float,TopP 截断阈值 - - tid: 模拟线程 ID,仅用于调试(非必须) + Parameters: + - candidate_ids: [candidate_len] int64 array, candidate tokens + - candidate_scores: [candidate_len] float32 array, corresponding probabilities + - curand_value: float, in the range [0, 1), simulating the GPU's curand_uniform + - candidate_len: int, number of candidates + - topp: float, Top-P truncation threshold + - tid: simulated thread ID, for debugging purposes only (optional) - 返回: - - 采样得到的 token(int64) + Returns: + - The sampled token (int64) """ rand_top_p = curand_value * topp sum_scores = 0.0 for i in range(candidate_len): - print( - f"debug sample i:{i} scores:{candidate_scores[i]},ids:{candidate_ids[i]},curand_value{curand_value},topp{topp}, value*topp{rand_top_p}" - ) sum_scores += candidate_scores[i] sum_scores += candidate_scores[i] if rand_top_p <= sum_scores: return candidate_ids[i] - return candidate_ids[0] # fallback(理论上不会走到这) + return candidate_ids[0] -# def is_in_end(id: int, end_ids: np.ndarray, length: int) -> bool: -# """ -# 判断 id 是否存在于 end_ids 前 length 个元素中。 -# """ -# for i in range(length): -# if id == end_ids[i]: -# return True -# return False - -# def is_in(candidates: np.ndarray, draft: int, candidate_len: int) -> bool: -# """ -# 判断 draft 是否在 candidates 的前 candidate_len 个元素中。 -# """ -# for i in range(candidate_len): -# if draft == candidates[i]: -# return True -# return False - - -# ---------------- NumPy 参考实现 ---------------- -def speculate_verify_np( +def speculate_verify_ref( + sampled_token_ids, accept_tokens, accept_num, step_idx, @@ -92,6 +69,8 @@ def speculate_verify_np( max_seq_len, verify_window, enable_topp, + benchmark_mode, + accept_all_drafts, ): def is_in_end(token, end_tokens, end_length): return token in end_tokens[:end_length] @@ -112,84 +91,43 @@ def speculate_verify_np( infer_seed: List[int] = [initial_seed] * bsz dev_curand_states: List[float] = [] - # 循环生成随机数 for i in range(bsz): - current_seed = infer_seed[i] # 这里 current_seed 总是等于 initial_seed + current_seed = infer_seed[i] - # 使用当前的种子创建一个独立的随机数生成器实例 - # 这对应于 C++ 的 std::mt19937_64 engine(infer_seed[i]); + # std::mt19937_64 engine(infer_seed[i]); rng = random.Random(current_seed) - # 从独立的生成器中获取一个 [0.0, 1.0) 范围内的浮点数 - # 这对应于 C++ 的 dist(engine); dev_curand_states.append(rng.random()) - # --- 在函数内部进行扁平化操作 --- - # 只有那些在 C++ 中通过指针算术访问的多维数组需要扁平化 - accept_tokens_flat = accept_tokens.reshape(-1) - draft_tokens_flat = draft_tokens.reshape(-1) - verify_tokens_flat = verify_tokens.reshape(-1) - verify_scores_flat = verify_scores.reshape(-1) - print(f"DEBUG: accept_tokens_flat shape: {accept_tokens_flat.shape}") - print(f"DEBUG: draft_tokens_flat shape: {draft_tokens_flat.shape}") - print(f"DEBUG: verify_tokens_flat shape: {verify_tokens_flat.shape}") - print(f"DEBUG: verify_scores_flat shape: {verify_scores_flat.shape}") - # 其他数组 (如 accept_num, step_idx, stop_flags, end_tokens, dev_curand_states, actual_candidate_len, - # seq_lens_encoder, seq_lens_decoder, actual_draft_token_nums, topp_values, - # seq_lens_this_time, max_dec_len, is_block_step, output_cum_offsets) - # 根据其 C++ 原始定义,如果本身就是一维的,则不需要额外的 reshape。 - # 这里直接使用其原始引用,或者如果其维度不确定,也可以做 flatten()。 - # 为了明确,我们假设这些参数如果不是 (N, K) 形式,就已经是 (N,) 形式。 - print() - # 遍历批次中的每个样本 + + # flatten + accept_tokens_flat = accept_tokens.reshape([-1]) + draft_tokens_flat = draft_tokens.reshape([-1]) + verify_tokens_flat = verify_tokens.reshape([-1]) + verify_scores_flat = verify_scores.reshape([-1]) for bid in range(real_bsz): - # C++: const int start_token_id = bid * max_seq_len - output_cum_offsets[bid]; start_token_id = bid * max_seq_len - output_cum_offsets[bid] accept_num_now = 1 stop_flag_now_int = 0 - print( - f"DEBUG: start_token_id: {start_token_id}, max_seq_len: {max_seq_len}, output_cum_offsets[{bid}]: {output_cum_offsets[bid]}" - ) - # C++: if (!(is_block_step[bid] || bid >= real_bsz)) - if not ( - is_block_step[bid] or bid >= real_bsz - ): # bid >= real_bsz 在 Python for 循环中天然满足,但为保持一致保留 + if not (is_block_step[bid] or bid >= real_bsz): # bid >= real_bsz reserved for consistency with gpu if stop_flags[bid]: stop_flag_now_int = 1 else: - # C++: auto *verify_tokens_now = verify_tokens + start_token_id * max_candidate_len; - # Python: verify_tokens_now 是一个指向当前批次 verify_tokens 起始的扁平视图 - # 模拟了 C++ 中指针偏移后的“基地址” - verify_tokens_now = verify_tokens_flat[start_token_id * max_candidate_len :] # 从基址到末尾 - - # C++: auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens; - # Python: draft_tokens_now 是当前批次 draft_tokens 起始的扁平视图 - draft_tokens_now = draft_tokens_flat[bid * max_draft_tokens :] # 从基址到末尾 - - # C++: auto *actual_candidate_len_now = actual_candidate_len + start_token_id; - # Python: actual_candidate_len_now 是当前批次 actual_candidate_len 起始的扁平视图 - actual_candidate_len_now = actual_candidate_len[start_token_id:] # actual_candidate_len 已经是 1D - - # C++: int i = 0; + verify_tokens_now = verify_tokens_flat[start_token_id * max_candidate_len :] + draft_tokens_now = draft_tokens_flat[bid * max_draft_tokens :] + actual_candidate_len_now = actual_candidate_len[start_token_id:] i = 0 + for loop_i in range(seq_lens_this_time[bid] - 1): + i = loop_i - # C++: for (; i < seq_lens_this_time[bid] - 1; i++) - for loop_i in range(seq_lens_this_time[bid] - 1): # 使用 loop_i 作为 Python 的循环变量 - i = loop_i # 保持 C++ 的 i 在每次迭代中更新为当前索引 - - # C++: if (seq_lens_encoder[bid] != 0) if seq_lens_encoder[bid] != 0: break if use_topk: - # C++: if (verify_tokens_now[i * max_candidate_len] == draft_tokens_now[i + 1]) if verify_tokens_now[i * max_candidate_len] == draft_tokens_now[i + 1]: step_idx[bid] += 1 accept_token = draft_tokens_now[i + 1] - # C++: accept_tokens[bid * max_draft_tokens + i] = accept_token; accept_tokens_flat[bid * max_draft_tokens + i] = accept_token - - # C++: if (is_in_end(accept_token, end_tokens, end_length) || step_idx[bid] >= max_dec_len[bid]) if is_in_end(accept_token, end_tokens, end_length) or step_idx[bid] >= max_dec_len[bid]: stop_flags[bid] = True stop_flag_now_int = 1 @@ -200,12 +138,8 @@ def speculate_verify_np( accept_num_now += 1 else: break - else: # C++: else (Top P verify) - # C++: auto actual_candidate_len_value = actual_candidate_len_now[i] > max_candidate_len ? max_candidate_len : actual_candidate_len_now[i]; + else: actual_candidate_len_value = min(actual_candidate_len_now[i], max_candidate_len) - - # C++: if (is_in(verify_tokens_now + i * max_candidate_len, draft_tokens_now[i + 1], actual_candidate_len_value)) - # 传入当前候选的扁平视图 verify_tokens_current_candidate_view = verify_tokens_now[ i * max_candidate_len : (i + 1) * max_candidate_len ] @@ -229,27 +163,23 @@ def speculate_verify_np( accept_num_now += 1 else: # TopK verify - ii = i # C++ 中 ii 从 i 开始 - # C++: if (max_candidate_len >= 2 && verify_tokens_now[ii * max_candidate_len + 1] == draft_tokens_now[ii + 1]) + ii = i # Start from i if ( max_candidate_len >= 2 and verify_tokens_now[ii * max_candidate_len + 1] == draft_tokens_now[ii + 1] ): # top-2 j = 0 - ii += 1 # C++ 中 ii 从下一个位置开始检查 - # C++: for (; j < verify_window && ii < seq_lens_this_time[bid] - 1; j++, ii++) + ii += 1 # Start from ii next position while j < verify_window and ii < seq_lens_this_time[bid] - 1: if verify_tokens_now[ii * max_candidate_len] != draft_tokens_now[ii + 1]: break j += 1 ii += 1 - # C++: if (j >= verify_window) if j >= verify_window: # accept all accept_num_now += verify_window + 1 step_idx[bid] += verify_window + 1 - # C++: for (; i < ii; i++) - for k_accepted_idx in range(i, ii): # i 会被更新 + for k_accepted_idx in range(i, ii): accept_token = draft_tokens_now[k_accepted_idx + 1] accept_tokens_flat[bid * max_draft_tokens + k_accepted_idx] = accept_token @@ -269,29 +199,18 @@ def speculate_verify_np( ) accept_num_now -= 1 step_idx[bid] -= 1 - break # 跳出内层接受循环 - break # 跳出主验证循环 (TopK 逻辑结束,无论成功与否) - # else 的 break 对应 is_in(Top P 验证失败,也不是 TopK 匹配) - break # 跳出主验证循环 - - # 采样阶段 (Sampling Phase) - # C++ 中 i 变量在循环结束后会保留其最终值,直接用于采样 - # Python 同样,loop_i 的最终值赋值给了 i + break + break # TopK finish + break # Jump main loop if not stop_flag_now_int: accept_token: int - - # C++: const float *verify_scores_now = verify_scores + start_token_id * max_candidate_len; - # Python: verify_scores_now 对应 C++ 中从 start_token_id 开始的 verify_scores 视图 verify_scores_now = verify_scores_flat[start_token_id * max_candidate_len :] step_idx[bid] += 1 if enable_topp: - # C++: auto actual_candidate_len_value = actual_candidate_len_now[i] > max_candidate_len ? max_candidate_len : actual_candidate_len_now[i]; actual_candidate_len_value = min(actual_candidate_len_now[i], max_candidate_len) - - # 传入当前候选的扁平视图 verify_tokens_sampling_view = verify_tokens_now[ i * max_candidate_len : (i + 1) * max_candidate_len ] @@ -299,24 +218,17 @@ def speculate_verify_np( i * max_candidate_len : (i + 1) * max_candidate_len ] - # C++: accept_token = topp_sampling_kernel(...) accept_token = topp_sampling_kernel( verify_tokens_sampling_view, verify_scores_sampling_view, - dev_curand_states[i], # C++: dev_curand_states + i + dev_curand_states[i], actual_candidate_len_value, - topp[bid], # C++: topp[bid] - bid, # C++: bid + topp[bid], + bid, ) else: accept_token = int(verify_tokens_now[i * max_candidate_len]) - print( - "debug python last accept_token", - accept_token, - "prefill_one_step_stop", - prefill_one_step_stop, - ) - # C++: accept_tokens[bid * max_draft_tokens + i] = accept_token; + accept_tokens_flat[bid * max_draft_tokens + i] = accept_token if prefill_one_step_stop: @@ -333,7 +245,6 @@ def speculate_verify_np( return accept_tokens, accept_num, step_idx, stop_flags -# ---------------- 生成随机输入 ---------------- def gen_speculate_verify_inputs( real_bsz=123, max_draft_tokens=16, @@ -341,12 +252,11 @@ def gen_speculate_verify_inputs( max_candidate_len=8, verify_window=2, end_length=4, - enable_topp=True, + enable_topp=False, seed=2025, ): rng = np.random.default_rng(seed) - # 基础输入 seq_lens_encoder = rng.integers(0, 3, size=real_bsz, dtype=np.int32) seq_lens_decoder = rng.integers(1, max_draft_tokens, size=real_bsz, dtype=np.int32) draft_tokens = rng.integers(0, 1000, size=(real_bsz, max_draft_tokens), dtype=np.int64) @@ -354,10 +264,8 @@ def gen_speculate_verify_inputs( seq_lens_this_time = rng.integers(1, max_seq_len + 1, size=real_bsz, dtype=np.int32) sum_seq_this_time = int(np.sum(seq_lens_this_time)) - # print("debug param set sum_seq_this_time",sum_seq_this_time) - # print("debug param real_bsz * max_draft_tokens < 2k",real_bsz * max_draft_tokens) - # print("debug sum_seq_this_time * max_candidate_len < 2k",sum_seq_this_time * max_candidate_len) + sampled_token_ids = rng.integers(0, 1000, size=(sum_seq_this_time, 1), dtype=np.int64) verify_tokens = rng.integers(0, 1000, size=(sum_seq_this_time, max_candidate_len), dtype=np.int64) verify_scores = rng.random(size=(sum_seq_this_time, max_candidate_len)).astype(np.float32) @@ -378,13 +286,14 @@ def gen_speculate_verify_inputs( else np.zeros(real_bsz, dtype=np.float32) ) - # 输出(占位) + # Output(inplace) accept_tokens = np.zeros((real_bsz, max_draft_tokens), dtype=np.int64) accept_num = np.zeros(real_bsz, dtype=np.int32) step_idx = np.zeros(real_bsz, dtype=np.int64) stop_flags = np.zeros(real_bsz, dtype=bool) return { + "sampled_token_ids": sampled_token_ids, "accept_tokens": accept_tokens, "accept_num": accept_num, "step_idx": step_idx, @@ -405,178 +314,12 @@ def gen_speculate_verify_inputs( "max_seq_len": max_seq_len, "verify_window": verify_window, "enable_topp": enable_topp, + "benchmark_mode": False, + "accept_all_drafts": False, } -# ------------------- 单测主体 ------------------- -# # ---- Paddle 端 ---- -def run_speculate_verify_test( - real_bsz, - max_draft_tokens, - max_seq_len, - max_candidate_len, - verify_window, - end_length, - enable_topp, - seed, -): - inputs = gen_speculate_verify_inputs( - real_bsz=real_bsz, - max_draft_tokens=max_draft_tokens, - max_seq_len=max_seq_len, - max_candidate_len=max_candidate_len, - verify_window=verify_window, - end_length=end_length, - enable_topp=enable_topp, - seed=seed, - ) - - paddle_inputs = {} - - print("========= 1 xpu process==========") - - for k, v in inputs.items(): - if isinstance(v, (int, bool)): - paddle_inputs[k] = v - # print(f"{k:<25} type: {type(v).__name__}, value: {v}") - else: - # paddle_inputs[k] = paddle.to_tensor(v, place=paddle.CPUPlace()) - paddle_inputs[k] = paddle.to_tensor(v, place=paddle.XPUPlace(0)) - # print(f"{k:<25} type: Tensor, dtype: {paddle_inputs[k].dtype}, shape: {paddle_inputs[k].shape}") - - out_pd = speculate_verify(**paddle_inputs) - (accept_tokens_pd, accept_num_pd, step_idx_pd, stop_flags_pd) = out_pd - pd_tensors = [accept_tokens_pd, accept_num_pd, step_idx_pd, stop_flags_pd] - - print("========= 1 end==========") - print("========= 2 python process==========") - - # np_inputs = {k: (paddle_inputs[k].numpy().copy() if isinstance(paddle_inputs[k], paddle.Tensor) - # else paddle_inputs[k]) - # for k in paddle_inputs} - - # out_np = speculate_verify_np(**np_inputs) - # (accept_tokens_np, accept_num_np, step_idx_np, stop_flags_np) = out_np - # np_tensors = [accept_tokens_np, accept_num_np, step_idx_np, stop_flags_np] - - print("=========2 end =======") - - print("========= 3 (CPU)==========") - paddle_inputs_cpu = {} - - for k, v in inputs.items(): # 重新使用原始的 inputs 字典,确保数据原始状态 - if isinstance(v, (int, bool)): - paddle_inputs_cpu[k] = v - # print(f"{k:<25} type: {type(v).__name__}, value: {v}") - else: - # 核心修改:使用 paddle.CPUPlace() - paddle_inputs_cpu[k] = paddle.to_tensor(v, place=paddle.CPUPlace()) - # print(f"{k:<25} type: Tensor, dtype: {paddle_inputs_cpu[k].dtype}, shape: {paddle_inputs_cpu[k].shape}") - - out_cpu = speculate_verify(**paddle_inputs_cpu) - (accept_tokens_cpu, accept_num_cpu, step_idx_cpu, stop_flags_cpu) = out_cpu - - cpu_tensors = [ - accept_tokens_cpu, - accept_num_cpu, - step_idx_cpu, - stop_flags_cpu, - ] - print("========= 3 (CPU) end==========") - - # ---------------- 校对 ---------------- - # print("========= python/cpu vs xpu verify ==========") - - # names = ["accept_tokens", "accept_num", "step_idx", "stop_flags"] - # for name, pd_val, np_val in zip(names, pd_tensors, np_tensors): - # pd_arr = pd_val.numpy() - # ok = np.array_equal(pd_arr, np_val) - # print(f"{name:20s} equal: {ok}") - # if not ok: - # print(f"{name} mismatch!\nPaddle:\n{pd_arr}\n\nNumPy:\n{np_val}") - - print("========= cpu vs xpu verify ==========") - - names = ["accept_tokens", "accept_num", "step_idx", "stop_flags"] - # for name, pd_val, np_val in zip(names, pd_tensors, cpu_tensors): - # pd_arr = pd_val.numpy() - # ok = np.array_equal(pd_arr, np_val) - # print(f"{name:20s} equal: {ok}") - # if not ok: - # print(f"{name} mismatch!\nPaddle:\n{pd_arr}\n\nNumPy:\n{np_val}") - - for name, pd_val, np_val in zip(names, pd_tensors, cpu_tensors): - pd_arr = pd_val.numpy() - ok = np.array_equal(pd_arr, np_val) - print(f"{name:20s} equal: {ok}") - if not ok: - print(f"{name} mismatch!") - - # 输出不同位置的索引和对应值 - print(f"{name} mismatch!\nPaddle:\n{pd_arr}\n\nNumPy:\n{np_val}") - mismatches = np.where(pd_arr != np_val) - for idx in zip(*mismatches): - print(f" idx {idx}: Paddle = {pd_arr[idx]}, NumPy = {np_val[idx]}") - - # 如果差异太多可限制输出数量 - if len(mismatches[0]) > 20: - print(" ... (truncated)") - - -# ------------------------------------- -# 测试用例 -# ------------------------------------- test_configs = [ - { - "real_bsz": 4, - "max_draft_tokens": 3, - "max_seq_len": 30, - "max_candidate_len": 4, - "verify_window": 2, - "end_length": 2, - "enable_topp": True, - "seed": 2025, - }, - { - "real_bsz": 77, - "max_draft_tokens": 10, - "max_seq_len": 12000, - "max_candidate_len": 8, - "verify_window": 2, - "end_length": 4, - "enable_topp": True, - "seed": 2025, - }, - { - "real_bsz": 1, - "max_draft_tokens": 2, - "max_seq_len": 10, - "max_candidate_len": 1, - "verify_window": 1, - "end_length": 1, - "enable_topp": True, - "seed": 42, - }, - { - "real_bsz": 128, - "max_draft_tokens": 7, - "max_seq_len": 999, - "max_candidate_len": 5, - "verify_window": 3, - "end_length": 3, - "enable_topp": True, - "seed": 422, - }, - { - "real_bsz": 99, - "max_draft_tokens": 5, - "max_seq_len": 10, - "max_candidate_len": 3, - "verify_window": 4, - "end_length": 4, - "enable_topp": True, - "seed": 42, - }, { "real_bsz": 1, "max_draft_tokens": 9, @@ -629,6 +372,45 @@ test_configs = [ }, ] -for i, cfg in enumerate(test_configs): - print(f"\n\n======== Running Test Case {i} ========") - run_speculate_verify_test(**cfg) + +class TestSpeculateVerify(unittest.TestCase): + def run_speculate_verify( + self, + real_bsz, + max_draft_tokens, + max_seq_len, + max_candidate_len, + verify_window, + end_length, + enable_topp, + seed, + ): + inputs = gen_speculate_verify_inputs( + real_bsz=real_bsz, + max_draft_tokens=max_draft_tokens, + max_seq_len=max_seq_len, + max_candidate_len=max_candidate_len, + verify_window=verify_window, + end_length=end_length, + enable_topp=enable_topp, + seed=seed, + ) + paddle_inputs = {k: v if isinstance(v, (int, bool)) else paddle.to_tensor(v) for k, v in inputs.items()} + inputs_xpu = list(paddle_inputs.values()) + speculate_verify(*inputs_xpu) + out_xpu = [inputs_xpu[1], inputs_xpu[2], inputs_xpu[3], inputs_xpu[4]] + + paddle_inputs_ref = {k: v if isinstance(v, (int, bool)) else paddle.to_tensor(v) for k, v in inputs.items()} + out_ref = speculate_verify_ref(**paddle_inputs_ref) + + names = ["accept_tokens", "accept_num", "step_idx", "stop_flags"] + for _, pd_val, np_val in zip(names, out_xpu, out_ref): + np.testing.assert_allclose(pd_val.numpy(), np_val.numpy()) + + def test_speculate_verify(self): + for config in test_configs: + self.run_speculate_verify(**config) + + +if __name__ == "__main__": + unittest.main() diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 28687ea53..c3a426488 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -71,7 +71,7 @@ def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_le paddle.ones_like(seq_lens_this_time), ) - batch_start = (paddle.cumsum(token_lens, axis=0) - token_lens.astype("int64")).reshape(-1) # [B] + batch_start = (paddle.cumsum(token_lens, axis=0, dtype="int64") - token_lens.astype("int64")).reshape([-1]) # [B] token_batch_ids = paddle.repeat_interleave( paddle.arange(token_lens.shape[0], dtype="int64"), token_lens, @@ -79,7 +79,7 @@ def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_le token_pos = paddle.arange(topp_seed.shape[0], dtype="int64") local_pos = token_pos - paddle.gather(batch_start, token_batch_ids) - is_decoder = paddle.gather(seq_lens_encoder[:real_bsz] == 0, token_batch_ids).reshape(-1) + is_decoder = paddle.gather(seq_lens_encoder[:real_bsz] == 0, token_batch_ids).reshape([-1]) offsets = paddle.where( is_decoder, @@ -879,6 +879,15 @@ class SpeculativeSampler(nn.Layer): probs = F.softmax(logits) + top_p, top_k, topp_seed = padding_sampling_params( + sampling_metadata.top_p, + sampling_metadata.top_k, + sampling_metadata.seed, + share_inputs["seq_lens_this_time"], + paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]), + ) + _, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed) + verify_scores, verify_tokens, actual_candidate_len = top_p_candidates( probs, sampling_metadata.top_p, @@ -888,6 +897,7 @@ class SpeculativeSampler(nn.Layer): ) speculate_verify( + sampled_token_ids, share_inputs["accept_tokens"], share_inputs["accept_num"], share_inputs["step_idx"],