mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[XPU] modify speculate_verify (#5522)
This commit is contained in:
@@ -26,7 +26,8 @@
|
|||||||
|
|
||||||
namespace api = baidu::xpu::api;
|
namespace api = baidu::xpu::api;
|
||||||
|
|
||||||
void SpeculateVerify(const paddle::Tensor &accept_tokens,
|
void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
|
||||||
|
const paddle::Tensor &accept_tokens,
|
||||||
const paddle::Tensor &accept_num,
|
const paddle::Tensor &accept_num,
|
||||||
const paddle::Tensor &step_idx,
|
const paddle::Tensor &step_idx,
|
||||||
const paddle::Tensor &stop_flags,
|
const paddle::Tensor &stop_flags,
|
||||||
@@ -48,17 +49,12 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
|
|||||||
bool enable_topp,
|
bool enable_topp,
|
||||||
bool benchmark_mode,
|
bool benchmark_mode,
|
||||||
bool accept_all_drafts) {
|
bool accept_all_drafts) {
|
||||||
// TODO(chenhuan09):support accept_all_drafts
|
|
||||||
auto bsz = accept_tokens.shape()[0];
|
auto bsz = accept_tokens.shape()[0];
|
||||||
int real_bsz = seq_lens_this_time.shape()[0];
|
int real_bsz = seq_lens_this_time.shape()[0];
|
||||||
auto max_draft_tokens = draft_tokens.shape()[1];
|
auto max_draft_tokens = draft_tokens.shape()[1];
|
||||||
auto end_length = end_tokens.shape()[0];
|
auto end_length = end_tokens.shape()[0];
|
||||||
auto max_candidate_len = verify_tokens.shape()[1];
|
auto max_candidate_len = verify_tokens.shape()[1];
|
||||||
|
|
||||||
constexpr int BlockSize = 512;
|
|
||||||
// set topp_seed if needed
|
|
||||||
const paddle::optional<paddle::Tensor> &topp_seed = nullptr;
|
|
||||||
|
|
||||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||||
api::Context *ctx =
|
api::Context *ctx =
|
||||||
@@ -69,16 +65,16 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
|
|||||||
xpu_ctx_flag = false;
|
xpu_ctx_flag = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
|
||||||
// auto dev_ctx =
|
|
||||||
// paddle::experimental::DeviceContextPool::Instance().Get(place); auto
|
|
||||||
// xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
|
||||||
|
|
||||||
bool use_topk = false;
|
bool use_topk = false;
|
||||||
char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
|
char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
|
||||||
if (env_var) {
|
if (env_var) {
|
||||||
use_topk = static_cast<bool>(std::stoi(env_var));
|
use_topk = static_cast<bool>(std::stoi(env_var));
|
||||||
}
|
}
|
||||||
|
bool use_target_sampling = false;
|
||||||
|
char *env_var_1 = getenv("SPECULATE_VERIFY_USE_TARGET_SAMPLING");
|
||||||
|
if (env_var_1) {
|
||||||
|
use_target_sampling = static_cast<bool>(std::stoi(env_var_1));
|
||||||
|
}
|
||||||
bool prefill_one_step_stop = false;
|
bool prefill_one_step_stop = false;
|
||||||
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
|
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
|
||||||
// std::cout << "Your PATH is: " << env_p << '\n';
|
// std::cout << "Your PATH is: " << env_p << '\n';
|
||||||
@@ -108,10 +104,12 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
|
|||||||
|
|
||||||
auto dev_curand_states =
|
auto dev_curand_states =
|
||||||
!xpu_ctx_flag ? dev_curand_states_cpu.data() : dev_curand_states_xpu;
|
!xpu_ctx_flag ? dev_curand_states_cpu.data() : dev_curand_states_xpu;
|
||||||
|
int ret;
|
||||||
if (use_topk) {
|
if (use_topk) {
|
||||||
if (enable_topp) {
|
if (enable_topp) {
|
||||||
baidu::xpu::api::plugin::speculate_verify<true, true>(
|
ret = baidu::xpu::api::plugin::speculate_verify<true, true>(
|
||||||
ctx,
|
ctx,
|
||||||
|
sampled_token_ids.data<int64_t>(),
|
||||||
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||||
const_cast<int *>(accept_num.data<int>()),
|
const_cast<int *>(accept_num.data<int>()),
|
||||||
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
||||||
@@ -137,10 +135,14 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
|
|||||||
max_candidate_len,
|
max_candidate_len,
|
||||||
verify_window,
|
verify_window,
|
||||||
prefill_one_step_stop,
|
prefill_one_step_stop,
|
||||||
benchmark_mode);
|
benchmark_mode,
|
||||||
|
accept_all_drafts,
|
||||||
|
use_target_sampling);
|
||||||
|
PD_CHECK(ret == 0, "speculate_verify failed.");
|
||||||
} else {
|
} else {
|
||||||
baidu::xpu::api::plugin::speculate_verify<false, true>(
|
ret = baidu::xpu::api::plugin::speculate_verify<false, true>(
|
||||||
ctx,
|
ctx,
|
||||||
|
sampled_token_ids.data<int64_t>(),
|
||||||
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||||
const_cast<int *>(accept_num.data<int>()),
|
const_cast<int *>(accept_num.data<int>()),
|
||||||
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
||||||
@@ -166,12 +168,16 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
|
|||||||
max_candidate_len,
|
max_candidate_len,
|
||||||
verify_window,
|
verify_window,
|
||||||
prefill_one_step_stop,
|
prefill_one_step_stop,
|
||||||
benchmark_mode);
|
benchmark_mode,
|
||||||
|
accept_all_drafts,
|
||||||
|
use_target_sampling);
|
||||||
}
|
}
|
||||||
|
PD_CHECK(ret == 0, "speculate_verify failed.");
|
||||||
} else {
|
} else {
|
||||||
if (enable_topp) {
|
if (enable_topp) {
|
||||||
baidu::xpu::api::plugin::speculate_verify<true, false>(
|
ret = baidu::xpu::api::plugin::speculate_verify<true, false>(
|
||||||
ctx,
|
ctx,
|
||||||
|
sampled_token_ids.data<int64_t>(),
|
||||||
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||||
const_cast<int *>(accept_num.data<int>()),
|
const_cast<int *>(accept_num.data<int>()),
|
||||||
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
||||||
@@ -197,10 +203,14 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
|
|||||||
max_candidate_len,
|
max_candidate_len,
|
||||||
verify_window,
|
verify_window,
|
||||||
prefill_one_step_stop,
|
prefill_one_step_stop,
|
||||||
benchmark_mode);
|
benchmark_mode,
|
||||||
|
accept_all_drafts,
|
||||||
|
use_target_sampling);
|
||||||
|
PD_CHECK(ret == 0, "speculate_verify failed.");
|
||||||
} else {
|
} else {
|
||||||
baidu::xpu::api::plugin::speculate_verify<false, false>(
|
ret = baidu::xpu::api::plugin::speculate_verify<false, false>(
|
||||||
ctx,
|
ctx,
|
||||||
|
sampled_token_ids.data<int64_t>(),
|
||||||
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||||
const_cast<int *>(accept_num.data<int>()),
|
const_cast<int *>(accept_num.data<int>()),
|
||||||
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
||||||
@@ -226,18 +236,25 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
|
|||||||
max_candidate_len,
|
max_candidate_len,
|
||||||
verify_window,
|
verify_window,
|
||||||
prefill_one_step_stop,
|
prefill_one_step_stop,
|
||||||
benchmark_mode);
|
benchmark_mode,
|
||||||
|
accept_all_drafts,
|
||||||
|
use_target_sampling);
|
||||||
}
|
}
|
||||||
|
PD_CHECK(ret == 0, "speculate_verify failed.");
|
||||||
|
}
|
||||||
|
if (draft_tokens.is_cpu()) {
|
||||||
|
delete ctx;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(speculate_verify)
|
PD_BUILD_STATIC_OP(speculate_verify)
|
||||||
.Inputs({"accept_tokens",
|
.Inputs({"sampled_token_ids",
|
||||||
|
"accept_tokens",
|
||||||
"accept_num",
|
"accept_num",
|
||||||
"step_idx",
|
"step_idx",
|
||||||
"stop_flags",
|
|
||||||
"seq_lens_encoder",
|
"seq_lens_encoder",
|
||||||
"seq_lens_decoder",
|
"seq_lens_decoder",
|
||||||
|
"stop_flags",
|
||||||
"draft_tokens",
|
"draft_tokens",
|
||||||
"seq_lens_this_time",
|
"seq_lens_this_time",
|
||||||
"verify_tokens",
|
"verify_tokens",
|
||||||
|
|||||||
@@ -248,7 +248,8 @@ std::vector<paddle::Tensor> TopPCandidates(
|
|||||||
int candidates_len,
|
int candidates_len,
|
||||||
int max_seq_len);
|
int max_seq_len);
|
||||||
|
|
||||||
void SpeculateVerify(const paddle::Tensor& accept_tokens,
|
void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
|
||||||
|
const paddle::Tensor& accept_tokens,
|
||||||
const paddle::Tensor& accept_num,
|
const paddle::Tensor& accept_num,
|
||||||
const paddle::Tensor& step_idx,
|
const paddle::Tensor& step_idx,
|
||||||
const paddle::Tensor& stop_flags,
|
const paddle::Tensor& stop_flags,
|
||||||
@@ -1013,6 +1014,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
|
|
||||||
m.def("speculate_verify",
|
m.def("speculate_verify",
|
||||||
&SpeculateVerify,
|
&SpeculateVerify,
|
||||||
|
py::arg("sampled_token_ids"),
|
||||||
py::arg("accept_tokens"),
|
py::arg("accept_tokens"),
|
||||||
py::arg("accept_num"),
|
py::arg("accept_num"),
|
||||||
py::arg("step_idx"),
|
py::arg("step_idx"),
|
||||||
|
|||||||
@@ -384,6 +384,7 @@ DLL_EXPORT int mtp_free_and_dispatch_block(Context* ctx,
|
|||||||
|
|
||||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||||
DLL_EXPORT int speculate_verify(Context* ctx,
|
DLL_EXPORT int speculate_verify(Context* ctx,
|
||||||
|
const int64_t* sampled_token_ids,
|
||||||
int64_t* accept_tokens,
|
int64_t* accept_tokens,
|
||||||
int* accept_num,
|
int* accept_num,
|
||||||
int64_t* step_idx,
|
int64_t* step_idx,
|
||||||
@@ -409,7 +410,9 @@ DLL_EXPORT int speculate_verify(Context* ctx,
|
|||||||
const int max_candidate_len,
|
const int max_candidate_len,
|
||||||
const int verify_window,
|
const int verify_window,
|
||||||
const bool prefill_one_step_stop,
|
const bool prefill_one_step_stop,
|
||||||
const bool benchmark_mode);
|
const bool benchmark_mode,
|
||||||
|
const bool accept_all_drafts,
|
||||||
|
const bool use_target_sampling);
|
||||||
|
|
||||||
DLL_EXPORT int speculate_clear_accept_nums(Context* ctx,
|
DLL_EXPORT int speculate_clear_accept_nums(Context* ctx,
|
||||||
int* accept_num,
|
int* accept_num,
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ static __device__ inline unsigned int xorwow(unsigned int &state) {
|
|||||||
state ^= state >> 13;
|
state ^= state >> 13;
|
||||||
return state;
|
return state;
|
||||||
}
|
}
|
||||||
typedef uint32_t curandStatePhilox4_32_10_t;
|
|
||||||
__device__ int64_t
|
__device__ int64_t
|
||||||
topp_sampling_kernel(__global_ptr__ const int64_t *candidate_ids,
|
topp_sampling_kernel(__global_ptr__ const int64_t *candidate_ids,
|
||||||
__global_ptr__ const float *candidate_scores,
|
__global_ptr__ const float *candidate_scores,
|
||||||
@@ -91,9 +91,10 @@ topp_sampling_kernel(__global_ptr__ const int64_t *candidate_ids,
|
|||||||
}
|
}
|
||||||
return candidate_ids[0];
|
return candidate_ids[0];
|
||||||
}
|
}
|
||||||
#define sm_size 1024
|
|
||||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||||
__global__ void speculate_verify(
|
__global__ void speculate_verify(
|
||||||
|
const int64_t *sampled_token_ids,
|
||||||
int64_t *accept_tokens, // out [real_bsz, max_draft_tokens], 输出最终接收的
|
int64_t *accept_tokens, // out [real_bsz, max_draft_tokens], 输出最终接收的
|
||||||
// token(通过验证或采样)
|
// token(通过验证或采样)
|
||||||
int *accept_num, // out [real_bsz], 每个序列最终接受的 token
|
int *accept_num, // out [real_bsz], 每个序列最终接受的 token
|
||||||
@@ -139,9 +140,10 @@ __global__ void speculate_verify(
|
|||||||
// 的最大候选数(用于验证或采样)
|
// 的最大候选数(用于验证或采样)
|
||||||
const int verify_window, // scalar, TopK 验证窗口(允许连续 top1 匹配次数)
|
const int verify_window, // scalar, TopK 验证窗口(允许连续 top1 匹配次数)
|
||||||
const bool prefill_one_step_stop,
|
const bool prefill_one_step_stop,
|
||||||
const bool benchmark_mode) {
|
const bool benchmark_mode,
|
||||||
const int cid = core_id();
|
const bool accept_all_drafts,
|
||||||
const int64_t tid = cluster_id() * core_num() + core_id();
|
const bool use_target_sampling) {
|
||||||
|
const int64_t tid = core_id() * cluster_num() + cluster_id();
|
||||||
const int64_t nthreads = cluster_num() * core_num();
|
const int64_t nthreads = cluster_num() * core_num();
|
||||||
for (int64_t bid = tid; bid < real_bsz; bid += nthreads) {
|
for (int64_t bid = tid; bid < real_bsz; bid += nthreads) {
|
||||||
int stop_flag_now_int = 0;
|
int stop_flag_now_int = 0;
|
||||||
@@ -158,6 +160,7 @@ __global__ void speculate_verify(
|
|||||||
verify_tokens + start_token_id * max_candidate_len;
|
verify_tokens + start_token_id * max_candidate_len;
|
||||||
auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
|
auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
|
||||||
auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
|
auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
|
||||||
|
auto *sampled_token_id_now = sampled_token_ids + start_token_id;
|
||||||
int i = 0;
|
int i = 0;
|
||||||
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
|
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
|
||||||
// seq_lens_this_time[bid]-1);
|
// seq_lens_this_time[bid]-1);
|
||||||
@@ -168,7 +171,43 @@ __global__ void speculate_verify(
|
|||||||
if (seq_lens_encoder[bid] != 0) {
|
if (seq_lens_encoder[bid] != 0) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (USE_TOPK) {
|
if (accept_all_drafts) {
|
||||||
|
// accept all draft tokens
|
||||||
|
step_idx[bid]++;
|
||||||
|
auto accept_token = draft_tokens_now[i + 1];
|
||||||
|
accept_tokens[bid * max_draft_tokens + i] = accept_token;
|
||||||
|
|
||||||
|
if (is_in_end(accept_token, end_tokens, end_length) ||
|
||||||
|
step_idx[bid] >= max_dec_len[bid]) {
|
||||||
|
stop_flags[bid] = true;
|
||||||
|
stop_flag_now_int = 1;
|
||||||
|
if (step_idx[bid] >= max_dec_len[bid])
|
||||||
|
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
accept_num_now++;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (use_target_sampling) {
|
||||||
|
if (sampled_token_id_now[i] == draft_tokens_now[i + 1]) {
|
||||||
|
step_idx[bid]++;
|
||||||
|
auto accept_token = draft_tokens_now[i + 1];
|
||||||
|
accept_tokens[bid * max_draft_tokens + i] = accept_token;
|
||||||
|
if (is_in_end(accept_token, end_tokens, end_length) ||
|
||||||
|
step_idx[bid] >= max_dec_len[bid]) {
|
||||||
|
stop_flags[bid] = true;
|
||||||
|
stop_flag_now_int = 1;
|
||||||
|
if (step_idx[bid] >= max_dec_len[bid])
|
||||||
|
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
accept_num_now++;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else if (USE_TOPK) {
|
||||||
if (verify_tokens_now[i * max_candidate_len] ==
|
if (verify_tokens_now[i * max_candidate_len] ==
|
||||||
draft_tokens_now[i + 1]) {
|
draft_tokens_now[i + 1]) {
|
||||||
// accept_num_now++;
|
// accept_num_now++;
|
||||||
@@ -274,7 +313,9 @@ __global__ void speculate_verify(
|
|||||||
__global_ptr__ const float *verify_scores_now =
|
__global_ptr__ const float *verify_scores_now =
|
||||||
verify_scores + start_token_id * max_candidate_len;
|
verify_scores + start_token_id * max_candidate_len;
|
||||||
step_idx[bid]++;
|
step_idx[bid]++;
|
||||||
if (ENABLE_TOPP) {
|
if (use_target_sampling) {
|
||||||
|
accept_token = sampled_token_id_now[i];
|
||||||
|
} else if (ENABLE_TOPP) {
|
||||||
auto actual_candidate_len_value =
|
auto actual_candidate_len_value =
|
||||||
actual_candidate_len_now[i] > max_candidate_len
|
actual_candidate_len_now[i] > max_candidate_len
|
||||||
? max_candidate_len
|
? max_candidate_len
|
||||||
@@ -306,7 +347,8 @@ __global__ void speculate_verify(
|
|||||||
}
|
}
|
||||||
#define SPECULATE_VERIFY_INSTANTIATE(ENABLE_TOPP, USE_TOPK) \
|
#define SPECULATE_VERIFY_INSTANTIATE(ENABLE_TOPP, USE_TOPK) \
|
||||||
template __global__ void speculate_verify<ENABLE_TOPP, USE_TOPK>( \
|
template __global__ void speculate_verify<ENABLE_TOPP, USE_TOPK>( \
|
||||||
int64_t * accept_tokens, \
|
const int64_t *sampled_token_ids, \
|
||||||
|
int64_t *accept_tokens, \
|
||||||
int *accept_num, \
|
int *accept_num, \
|
||||||
int64_t *step_idx, \
|
int64_t *step_idx, \
|
||||||
bool *stop_flags, \
|
bool *stop_flags, \
|
||||||
@@ -331,7 +373,9 @@ __global__ void speculate_verify(
|
|||||||
int max_candidate_len, \
|
int max_candidate_len, \
|
||||||
int verify_window, \
|
int verify_window, \
|
||||||
bool prefill_one_step_stop, \
|
bool prefill_one_step_stop, \
|
||||||
bool benchmark_mode);
|
bool benchmark_mode, \
|
||||||
|
bool accept_all_drafts, \
|
||||||
|
bool use_target_sampling);
|
||||||
SPECULATE_VERIFY_INSTANTIATE(true, true)
|
SPECULATE_VERIFY_INSTANTIATE(true, true)
|
||||||
SPECULATE_VERIFY_INSTANTIATE(true, false)
|
SPECULATE_VERIFY_INSTANTIATE(true, false)
|
||||||
SPECULATE_VERIFY_INSTANTIATE(false, true)
|
SPECULATE_VERIFY_INSTANTIATE(false, true)
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ typedef uint32_t curandStatePhilox4_32_10_t;
|
|||||||
|
|
||||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||||
__attribute__((global)) void speculate_verify(
|
__attribute__((global)) void speculate_verify(
|
||||||
|
const int64_t *sampled_token_ids,
|
||||||
int64_t *accept_tokens,
|
int64_t *accept_tokens,
|
||||||
int *accept_num,
|
int *accept_num,
|
||||||
int64_t *step_idx,
|
int64_t *step_idx,
|
||||||
@@ -49,7 +50,9 @@ __attribute__((global)) void speculate_verify(
|
|||||||
const int max_candidate_len,
|
const int max_candidate_len,
|
||||||
const int verify_window,
|
const int verify_window,
|
||||||
const bool prefill_one_step_stop,
|
const bool prefill_one_step_stop,
|
||||||
const bool benchmark_mode);
|
const bool benchmark_mode,
|
||||||
|
const bool accept_all_drafts,
|
||||||
|
const bool use_target_sampling);
|
||||||
} // namespace plugin
|
} // namespace plugin
|
||||||
} // namespace xpu3
|
} // namespace xpu3
|
||||||
|
|
||||||
@@ -113,6 +116,7 @@ static int64_t topp_sampling_kernel(const int64_t *candidate_ids,
|
|||||||
|
|
||||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||||
static int cpu_wrapper(Context *ctx,
|
static int cpu_wrapper(Context *ctx,
|
||||||
|
const int64_t *sampled_token_ids,
|
||||||
int64_t *accept_tokens,
|
int64_t *accept_tokens,
|
||||||
int *accept_num,
|
int *accept_num,
|
||||||
int64_t *step_idx,
|
int64_t *step_idx,
|
||||||
@@ -138,7 +142,9 @@ static int cpu_wrapper(Context *ctx,
|
|||||||
const int max_candidate_len,
|
const int max_candidate_len,
|
||||||
const int verify_window,
|
const int verify_window,
|
||||||
const bool prefill_one_step_stop,
|
const bool prefill_one_step_stop,
|
||||||
const bool benchmark_mode) {
|
const bool benchmark_mode,
|
||||||
|
const bool accept_all_drafts,
|
||||||
|
const bool use_target_sampling) {
|
||||||
for (int bid = 0; bid < real_bsz; ++bid) {
|
for (int bid = 0; bid < real_bsz; ++bid) {
|
||||||
// verify and set stop flags
|
// verify and set stop flags
|
||||||
int accept_num_now = 1;
|
int accept_num_now = 1;
|
||||||
@@ -157,6 +163,7 @@ static int cpu_wrapper(Context *ctx,
|
|||||||
verify_tokens + start_token_id * max_candidate_len;
|
verify_tokens + start_token_id * max_candidate_len;
|
||||||
auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
|
auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
|
||||||
auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
|
auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
|
||||||
|
auto *sampled_token_id_now = sampled_token_ids + start_token_id;
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
|
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
|
||||||
@@ -168,7 +175,43 @@ static int cpu_wrapper(Context *ctx,
|
|||||||
if (seq_lens_encoder[bid] != 0) {
|
if (seq_lens_encoder[bid] != 0) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (USE_TOPK) {
|
if (accept_all_drafts) {
|
||||||
|
// accept all draft tokens
|
||||||
|
step_idx[bid]++;
|
||||||
|
auto accept_token = draft_tokens_now[i + 1];
|
||||||
|
accept_tokens[bid * max_draft_tokens + i] = accept_token;
|
||||||
|
|
||||||
|
if (is_in_end(accept_token, end_tokens, end_length) ||
|
||||||
|
step_idx[bid] >= max_dec_len[bid]) {
|
||||||
|
stop_flags[bid] = true;
|
||||||
|
stop_flag_now_int = 1;
|
||||||
|
if (step_idx[bid] >= max_dec_len[bid])
|
||||||
|
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
accept_num_now++;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (use_target_sampling) {
|
||||||
|
if (sampled_token_id_now[i] == draft_tokens_now[i + 1]) {
|
||||||
|
step_idx[bid]++;
|
||||||
|
auto accept_token = draft_tokens_now[i + 1];
|
||||||
|
accept_tokens[bid * max_draft_tokens + i] = accept_token;
|
||||||
|
if (is_in_end(accept_token, end_tokens, end_length) ||
|
||||||
|
step_idx[bid] >= max_dec_len[bid]) {
|
||||||
|
stop_flags[bid] = true;
|
||||||
|
stop_flag_now_int = 1;
|
||||||
|
if (step_idx[bid] >= max_dec_len[bid])
|
||||||
|
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
accept_num_now++;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else if (USE_TOPK) {
|
||||||
if (verify_tokens_now[i * max_candidate_len] ==
|
if (verify_tokens_now[i * max_candidate_len] ==
|
||||||
draft_tokens_now[i + 1]) {
|
draft_tokens_now[i + 1]) {
|
||||||
// accept_num_now++;
|
// accept_num_now++;
|
||||||
@@ -270,7 +313,9 @@ static int cpu_wrapper(Context *ctx,
|
|||||||
const float *verify_scores_now =
|
const float *verify_scores_now =
|
||||||
verify_scores + start_token_id * max_candidate_len;
|
verify_scores + start_token_id * max_candidate_len;
|
||||||
step_idx[bid]++;
|
step_idx[bid]++;
|
||||||
if (ENABLE_TOPP) {
|
if (use_target_sampling) {
|
||||||
|
accept_token = sampled_token_id_now[i];
|
||||||
|
} else if (ENABLE_TOPP) {
|
||||||
auto actual_candidate_len_value =
|
auto actual_candidate_len_value =
|
||||||
actual_candidate_len_now[i] > max_candidate_len
|
actual_candidate_len_now[i] > max_candidate_len
|
||||||
? max_candidate_len
|
? max_candidate_len
|
||||||
@@ -307,6 +352,7 @@ static int cpu_wrapper(Context *ctx,
|
|||||||
|
|
||||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||||
static int xpu3_wrapper(Context *ctx,
|
static int xpu3_wrapper(Context *ctx,
|
||||||
|
const int64_t *sampled_token_ids,
|
||||||
int64_t *accept_tokens,
|
int64_t *accept_tokens,
|
||||||
int *accept_num,
|
int *accept_num,
|
||||||
int64_t *step_idx,
|
int64_t *step_idx,
|
||||||
@@ -332,10 +378,13 @@ static int xpu3_wrapper(Context *ctx,
|
|||||||
const int max_candidate_len,
|
const int max_candidate_len,
|
||||||
const int verify_window,
|
const int verify_window,
|
||||||
const bool prefill_one_step_stop,
|
const bool prefill_one_step_stop,
|
||||||
const bool benchmark_mode) {
|
const bool benchmark_mode,
|
||||||
|
const bool accept_all_drafts,
|
||||||
|
const bool use_target_sampling) {
|
||||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||||
xpu3::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>
|
xpu3::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>
|
||||||
<<<1, 64, ctx->xpu_stream>>>(
|
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||||
|
reinterpret_cast<const XPU_INT64 *>(sampled_token_ids),
|
||||||
reinterpret_cast<XPU_INT64 *>(accept_tokens),
|
reinterpret_cast<XPU_INT64 *>(accept_tokens),
|
||||||
accept_num,
|
accept_num,
|
||||||
reinterpret_cast<XPU_INT64 *>(step_idx),
|
reinterpret_cast<XPU_INT64 *>(step_idx),
|
||||||
@@ -361,11 +410,14 @@ static int xpu3_wrapper(Context *ctx,
|
|||||||
max_candidate_len,
|
max_candidate_len,
|
||||||
verify_window,
|
verify_window,
|
||||||
prefill_one_step_stop,
|
prefill_one_step_stop,
|
||||||
benchmark_mode);
|
benchmark_mode,
|
||||||
|
accept_all_drafts,
|
||||||
|
use_target_sampling);
|
||||||
return api::SUCCESS;
|
return api::SUCCESS;
|
||||||
}
|
}
|
||||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||||
int speculate_verify(Context *ctx,
|
int speculate_verify(Context *ctx,
|
||||||
|
const int64_t *sampled_token_ids,
|
||||||
int64_t *accept_tokens,
|
int64_t *accept_tokens,
|
||||||
int *accept_num,
|
int *accept_num,
|
||||||
int64_t *step_idx,
|
int64_t *step_idx,
|
||||||
@@ -391,10 +443,13 @@ int speculate_verify(Context *ctx,
|
|||||||
const int max_candidate_len,
|
const int max_candidate_len,
|
||||||
const int verify_window,
|
const int verify_window,
|
||||||
const bool prefill_one_step_stop,
|
const bool prefill_one_step_stop,
|
||||||
const bool benchmark_mode) {
|
const bool benchmark_mode,
|
||||||
|
const bool accept_all_drafts,
|
||||||
|
const bool use_target_sampling) {
|
||||||
WRAPPER_CHECK_CTX(ctx);
|
WRAPPER_CHECK_CTX(ctx);
|
||||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_verify", int64_t);
|
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_verify", int64_t);
|
||||||
WRAPPER_DUMP_PARAM3(ctx, accept_tokens, accept_num, step_idx);
|
WRAPPER_DUMP_PARAM4(
|
||||||
|
ctx, sampled_token_ids, accept_tokens, accept_num, step_idx);
|
||||||
WRAPPER_DUMP_PARAM6(ctx,
|
WRAPPER_DUMP_PARAM6(ctx,
|
||||||
stop_flags,
|
stop_flags,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -421,6 +476,7 @@ int speculate_verify(Context *ctx,
|
|||||||
verify_window,
|
verify_window,
|
||||||
prefill_one_step_stop,
|
prefill_one_step_stop,
|
||||||
benchmark_mode);
|
benchmark_mode);
|
||||||
|
WRAPPER_DUMP_PARAM2(ctx, accept_all_drafts, use_target_sampling);
|
||||||
WRAPPER_DUMP(ctx);
|
WRAPPER_DUMP(ctx);
|
||||||
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_draft_tokens, accept_tokens);
|
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_draft_tokens, accept_tokens);
|
||||||
WRAPPER_CHECK_PTR(ctx, int, real_bsz, accept_num);
|
WRAPPER_CHECK_PTR(ctx, int, real_bsz, accept_num);
|
||||||
@@ -454,6 +510,7 @@ int speculate_verify(Context *ctx,
|
|||||||
|
|
||||||
if (ctx->dev().type() == api::kCPU) {
|
if (ctx->dev().type() == api::kCPU) {
|
||||||
return cpu_wrapper<ENABLE_TOPP, USE_TOPK>(ctx,
|
return cpu_wrapper<ENABLE_TOPP, USE_TOPK>(ctx,
|
||||||
|
sampled_token_ids,
|
||||||
accept_tokens,
|
accept_tokens,
|
||||||
accept_num,
|
accept_num,
|
||||||
step_idx,
|
step_idx,
|
||||||
@@ -479,10 +536,13 @@ int speculate_verify(Context *ctx,
|
|||||||
max_candidate_len,
|
max_candidate_len,
|
||||||
verify_window,
|
verify_window,
|
||||||
prefill_one_step_stop,
|
prefill_one_step_stop,
|
||||||
benchmark_mode);
|
benchmark_mode,
|
||||||
|
accept_all_drafts,
|
||||||
|
use_target_sampling);
|
||||||
}
|
}
|
||||||
if (ctx->dev().type() == api::kXPU3) {
|
if (ctx->dev().type() == api::kXPU3) {
|
||||||
return xpu3_wrapper<ENABLE_TOPP, USE_TOPK>(ctx,
|
return xpu3_wrapper<ENABLE_TOPP, USE_TOPK>(ctx,
|
||||||
|
sampled_token_ids,
|
||||||
accept_tokens,
|
accept_tokens,
|
||||||
accept_num,
|
accept_num,
|
||||||
step_idx,
|
step_idx,
|
||||||
@@ -508,7 +568,9 @@ int speculate_verify(Context *ctx,
|
|||||||
max_candidate_len,
|
max_candidate_len,
|
||||||
verify_window,
|
verify_window,
|
||||||
prefill_one_step_stop,
|
prefill_one_step_stop,
|
||||||
benchmark_mode);
|
benchmark_mode,
|
||||||
|
accept_all_drafts,
|
||||||
|
use_target_sampling);
|
||||||
}
|
}
|
||||||
WRAPPER_UNIMPLEMENTED(ctx);
|
WRAPPER_UNIMPLEMENTED(ctx);
|
||||||
}
|
}
|
||||||
@@ -517,6 +579,7 @@ int speculate_verify(Context *ctx,
|
|||||||
template int \
|
template int \
|
||||||
baidu::xpu::api::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>( \
|
baidu::xpu::api::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>( \
|
||||||
baidu::xpu::api::Context *, /* xpu_ctx */ \
|
baidu::xpu::api::Context *, /* xpu_ctx */ \
|
||||||
|
const int64_t *, /* sampled_token_ids */ \
|
||||||
int64_t *, /* accept_tokens */ \
|
int64_t *, /* accept_tokens */ \
|
||||||
int *, /* accept_num */ \
|
int *, /* accept_num */ \
|
||||||
int64_t *, /* step_idx */ \
|
int64_t *, /* step_idx */ \
|
||||||
@@ -541,8 +604,11 @@ int speculate_verify(Context *ctx,
|
|||||||
int, /* max_seq_len */ \
|
int, /* max_seq_len */ \
|
||||||
int, /* max_candidate_len */ \
|
int, /* max_candidate_len */ \
|
||||||
int, /* verify_window */ \
|
int, /* verify_window */ \
|
||||||
bool, \
|
bool, /* prefill_one_step_stop */ \
|
||||||
bool); /* prefill_one_step_stop */
|
bool, /* benchmark_mode */ \
|
||||||
|
bool, /* accept_all_drafts */ \
|
||||||
|
bool /* use_target_sampling */ \
|
||||||
|
);
|
||||||
|
|
||||||
INSTANTIATE_SPECULATE_VERIFY(false, false)
|
INSTANTIATE_SPECULATE_VERIFY(false, false)
|
||||||
INSTANTIATE_SPECULATE_VERIFY(false, true)
|
INSTANTIATE_SPECULATE_VERIFY(false, true)
|
||||||
|
|||||||
@@ -13,11 +13,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import random
|
import random
|
||||||
|
import unittest
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# tests/speculate_verify.py
|
|
||||||
import paddle
|
import paddle
|
||||||
|
|
||||||
from fastdeploy.model_executor.ops.xpu import speculate_verify
|
from fastdeploy.model_executor.ops.xpu import speculate_verify
|
||||||
@@ -25,53 +24,31 @@ from fastdeploy.model_executor.ops.xpu import speculate_verify
|
|||||||
|
|
||||||
def topp_sampling_kernel(candidate_ids, candidate_scores, curand_value, candidate_len, topp, tid=0):
|
def topp_sampling_kernel(candidate_ids, candidate_scores, curand_value, candidate_len, topp, tid=0):
|
||||||
"""
|
"""
|
||||||
Python 仿真版 Top-p 样本选择函数。
|
Python simulation version of the Top-p sampling function.
|
||||||
|
|
||||||
参数:
|
Parameters:
|
||||||
- candidate_ids: [candidate_len] int64 array,候选 token
|
- candidate_ids: [candidate_len] int64 array, candidate tokens
|
||||||
- candidate_scores: [candidate_len] float32 array,对应概率
|
- candidate_scores: [candidate_len] float32 array, corresponding probabilities
|
||||||
- curand_value: float,范围在 [0,1),模拟 GPU 中的 curand_uniform
|
- curand_value: float, in the range [0, 1), simulating the GPU's curand_uniform
|
||||||
- candidate_len: int,候选个数
|
- candidate_len: int, number of candidates
|
||||||
- topp: float,TopP 截断阈值
|
- topp: float, Top-P truncation threshold
|
||||||
- tid: 模拟线程 ID,仅用于调试(非必须)
|
- tid: simulated thread ID, for debugging purposes only (optional)
|
||||||
|
|
||||||
返回:
|
Returns:
|
||||||
- 采样得到的 token(int64)
|
- The sampled token (int64)
|
||||||
"""
|
"""
|
||||||
rand_top_p = curand_value * topp
|
rand_top_p = curand_value * topp
|
||||||
sum_scores = 0.0
|
sum_scores = 0.0
|
||||||
for i in range(candidate_len):
|
for i in range(candidate_len):
|
||||||
print(
|
|
||||||
f"debug sample i:{i} scores:{candidate_scores[i]},ids:{candidate_ids[i]},curand_value{curand_value},topp{topp}, value*topp{rand_top_p}"
|
|
||||||
)
|
|
||||||
sum_scores += candidate_scores[i]
|
sum_scores += candidate_scores[i]
|
||||||
sum_scores += candidate_scores[i]
|
sum_scores += candidate_scores[i]
|
||||||
if rand_top_p <= sum_scores:
|
if rand_top_p <= sum_scores:
|
||||||
return candidate_ids[i]
|
return candidate_ids[i]
|
||||||
return candidate_ids[0] # fallback(理论上不会走到这)
|
return candidate_ids[0]
|
||||||
|
|
||||||
|
|
||||||
# def is_in_end(id: int, end_ids: np.ndarray, length: int) -> bool:
|
def speculate_verify_ref(
|
||||||
# """
|
sampled_token_ids,
|
||||||
# 判断 id 是否存在于 end_ids 前 length 个元素中。
|
|
||||||
# """
|
|
||||||
# for i in range(length):
|
|
||||||
# if id == end_ids[i]:
|
|
||||||
# return True
|
|
||||||
# return False
|
|
||||||
|
|
||||||
# def is_in(candidates: np.ndarray, draft: int, candidate_len: int) -> bool:
|
|
||||||
# """
|
|
||||||
# 判断 draft 是否在 candidates 的前 candidate_len 个元素中。
|
|
||||||
# """
|
|
||||||
# for i in range(candidate_len):
|
|
||||||
# if draft == candidates[i]:
|
|
||||||
# return True
|
|
||||||
# return False
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------- NumPy 参考实现 ----------------
|
|
||||||
def speculate_verify_np(
|
|
||||||
accept_tokens,
|
accept_tokens,
|
||||||
accept_num,
|
accept_num,
|
||||||
step_idx,
|
step_idx,
|
||||||
@@ -92,6 +69,8 @@ def speculate_verify_np(
|
|||||||
max_seq_len,
|
max_seq_len,
|
||||||
verify_window,
|
verify_window,
|
||||||
enable_topp,
|
enable_topp,
|
||||||
|
benchmark_mode,
|
||||||
|
accept_all_drafts,
|
||||||
):
|
):
|
||||||
def is_in_end(token, end_tokens, end_length):
|
def is_in_end(token, end_tokens, end_length):
|
||||||
return token in end_tokens[:end_length]
|
return token in end_tokens[:end_length]
|
||||||
@@ -112,84 +91,43 @@ def speculate_verify_np(
|
|||||||
infer_seed: List[int] = [initial_seed] * bsz
|
infer_seed: List[int] = [initial_seed] * bsz
|
||||||
dev_curand_states: List[float] = []
|
dev_curand_states: List[float] = []
|
||||||
|
|
||||||
# 循环生成随机数
|
|
||||||
for i in range(bsz):
|
for i in range(bsz):
|
||||||
current_seed = infer_seed[i] # 这里 current_seed 总是等于 initial_seed
|
current_seed = infer_seed[i]
|
||||||
|
|
||||||
# 使用当前的种子创建一个独立的随机数生成器实例
|
# std::mt19937_64 engine(infer_seed[i]);
|
||||||
# 这对应于 C++ 的 std::mt19937_64 engine(infer_seed[i]);
|
|
||||||
rng = random.Random(current_seed)
|
rng = random.Random(current_seed)
|
||||||
|
|
||||||
# 从独立的生成器中获取一个 [0.0, 1.0) 范围内的浮点数
|
|
||||||
# 这对应于 C++ 的 dist(engine);
|
|
||||||
dev_curand_states.append(rng.random())
|
dev_curand_states.append(rng.random())
|
||||||
# --- 在函数内部进行扁平化操作 ---
|
|
||||||
# 只有那些在 C++ 中通过指针算术访问的多维数组需要扁平化
|
# flatten
|
||||||
accept_tokens_flat = accept_tokens.reshape(-1)
|
accept_tokens_flat = accept_tokens.reshape([-1])
|
||||||
draft_tokens_flat = draft_tokens.reshape(-1)
|
draft_tokens_flat = draft_tokens.reshape([-1])
|
||||||
verify_tokens_flat = verify_tokens.reshape(-1)
|
verify_tokens_flat = verify_tokens.reshape([-1])
|
||||||
verify_scores_flat = verify_scores.reshape(-1)
|
verify_scores_flat = verify_scores.reshape([-1])
|
||||||
print(f"DEBUG: accept_tokens_flat shape: {accept_tokens_flat.shape}")
|
|
||||||
print(f"DEBUG: draft_tokens_flat shape: {draft_tokens_flat.shape}")
|
|
||||||
print(f"DEBUG: verify_tokens_flat shape: {verify_tokens_flat.shape}")
|
|
||||||
print(f"DEBUG: verify_scores_flat shape: {verify_scores_flat.shape}")
|
|
||||||
# 其他数组 (如 accept_num, step_idx, stop_flags, end_tokens, dev_curand_states, actual_candidate_len,
|
|
||||||
# seq_lens_encoder, seq_lens_decoder, actual_draft_token_nums, topp_values,
|
|
||||||
# seq_lens_this_time, max_dec_len, is_block_step, output_cum_offsets)
|
|
||||||
# 根据其 C++ 原始定义,如果本身就是一维的,则不需要额外的 reshape。
|
|
||||||
# 这里直接使用其原始引用,或者如果其维度不确定,也可以做 flatten()。
|
|
||||||
# 为了明确,我们假设这些参数如果不是 (N, K) 形式,就已经是 (N,) 形式。
|
|
||||||
print()
|
|
||||||
# 遍历批次中的每个样本
|
|
||||||
for bid in range(real_bsz):
|
for bid in range(real_bsz):
|
||||||
# C++: const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
|
|
||||||
start_token_id = bid * max_seq_len - output_cum_offsets[bid]
|
start_token_id = bid * max_seq_len - output_cum_offsets[bid]
|
||||||
accept_num_now = 1
|
accept_num_now = 1
|
||||||
stop_flag_now_int = 0
|
stop_flag_now_int = 0
|
||||||
print(
|
|
||||||
f"DEBUG: start_token_id: {start_token_id}, max_seq_len: {max_seq_len}, output_cum_offsets[{bid}]: {output_cum_offsets[bid]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# C++: if (!(is_block_step[bid] || bid >= real_bsz))
|
if not (is_block_step[bid] or bid >= real_bsz): # bid >= real_bsz reserved for consistency with gpu
|
||||||
if not (
|
|
||||||
is_block_step[bid] or bid >= real_bsz
|
|
||||||
): # bid >= real_bsz 在 Python for 循环中天然满足,但为保持一致保留
|
|
||||||
if stop_flags[bid]:
|
if stop_flags[bid]:
|
||||||
stop_flag_now_int = 1
|
stop_flag_now_int = 1
|
||||||
else:
|
else:
|
||||||
# C++: auto *verify_tokens_now = verify_tokens + start_token_id * max_candidate_len;
|
verify_tokens_now = verify_tokens_flat[start_token_id * max_candidate_len :]
|
||||||
# Python: verify_tokens_now 是一个指向当前批次 verify_tokens 起始的扁平视图
|
draft_tokens_now = draft_tokens_flat[bid * max_draft_tokens :]
|
||||||
# 模拟了 C++ 中指针偏移后的“基地址”
|
actual_candidate_len_now = actual_candidate_len[start_token_id:]
|
||||||
verify_tokens_now = verify_tokens_flat[start_token_id * max_candidate_len :] # 从基址到末尾
|
|
||||||
|
|
||||||
# C++: auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
|
|
||||||
# Python: draft_tokens_now 是当前批次 draft_tokens 起始的扁平视图
|
|
||||||
draft_tokens_now = draft_tokens_flat[bid * max_draft_tokens :] # 从基址到末尾
|
|
||||||
|
|
||||||
# C++: auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
|
|
||||||
# Python: actual_candidate_len_now 是当前批次 actual_candidate_len 起始的扁平视图
|
|
||||||
actual_candidate_len_now = actual_candidate_len[start_token_id:] # actual_candidate_len 已经是 1D
|
|
||||||
|
|
||||||
# C++: int i = 0;
|
|
||||||
i = 0
|
i = 0
|
||||||
|
for loop_i in range(seq_lens_this_time[bid] - 1):
|
||||||
|
i = loop_i
|
||||||
|
|
||||||
# C++: for (; i < seq_lens_this_time[bid] - 1; i++)
|
|
||||||
for loop_i in range(seq_lens_this_time[bid] - 1): # 使用 loop_i 作为 Python 的循环变量
|
|
||||||
i = loop_i # 保持 C++ 的 i 在每次迭代中更新为当前索引
|
|
||||||
|
|
||||||
# C++: if (seq_lens_encoder[bid] != 0)
|
|
||||||
if seq_lens_encoder[bid] != 0:
|
if seq_lens_encoder[bid] != 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
if use_topk:
|
if use_topk:
|
||||||
# C++: if (verify_tokens_now[i * max_candidate_len] == draft_tokens_now[i + 1])
|
|
||||||
if verify_tokens_now[i * max_candidate_len] == draft_tokens_now[i + 1]:
|
if verify_tokens_now[i * max_candidate_len] == draft_tokens_now[i + 1]:
|
||||||
step_idx[bid] += 1
|
step_idx[bid] += 1
|
||||||
accept_token = draft_tokens_now[i + 1]
|
accept_token = draft_tokens_now[i + 1]
|
||||||
# C++: accept_tokens[bid * max_draft_tokens + i] = accept_token;
|
|
||||||
accept_tokens_flat[bid * max_draft_tokens + i] = accept_token
|
accept_tokens_flat[bid * max_draft_tokens + i] = accept_token
|
||||||
|
|
||||||
# C++: if (is_in_end(accept_token, end_tokens, end_length) || step_idx[bid] >= max_dec_len[bid])
|
|
||||||
if is_in_end(accept_token, end_tokens, end_length) or step_idx[bid] >= max_dec_len[bid]:
|
if is_in_end(accept_token, end_tokens, end_length) or step_idx[bid] >= max_dec_len[bid]:
|
||||||
stop_flags[bid] = True
|
stop_flags[bid] = True
|
||||||
stop_flag_now_int = 1
|
stop_flag_now_int = 1
|
||||||
@@ -200,12 +138,8 @@ def speculate_verify_np(
|
|||||||
accept_num_now += 1
|
accept_num_now += 1
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
else: # C++: else (Top P verify)
|
else:
|
||||||
# C++: auto actual_candidate_len_value = actual_candidate_len_now[i] > max_candidate_len ? max_candidate_len : actual_candidate_len_now[i];
|
|
||||||
actual_candidate_len_value = min(actual_candidate_len_now[i], max_candidate_len)
|
actual_candidate_len_value = min(actual_candidate_len_now[i], max_candidate_len)
|
||||||
|
|
||||||
# C++: if (is_in(verify_tokens_now + i * max_candidate_len, draft_tokens_now[i + 1], actual_candidate_len_value))
|
|
||||||
# 传入当前候选的扁平视图
|
|
||||||
verify_tokens_current_candidate_view = verify_tokens_now[
|
verify_tokens_current_candidate_view = verify_tokens_now[
|
||||||
i * max_candidate_len : (i + 1) * max_candidate_len
|
i * max_candidate_len : (i + 1) * max_candidate_len
|
||||||
]
|
]
|
||||||
@@ -229,27 +163,23 @@ def speculate_verify_np(
|
|||||||
accept_num_now += 1
|
accept_num_now += 1
|
||||||
else:
|
else:
|
||||||
# TopK verify
|
# TopK verify
|
||||||
ii = i # C++ 中 ii 从 i 开始
|
ii = i # Start from i
|
||||||
# C++: if (max_candidate_len >= 2 && verify_tokens_now[ii * max_candidate_len + 1] == draft_tokens_now[ii + 1])
|
|
||||||
if (
|
if (
|
||||||
max_candidate_len >= 2
|
max_candidate_len >= 2
|
||||||
and verify_tokens_now[ii * max_candidate_len + 1] == draft_tokens_now[ii + 1]
|
and verify_tokens_now[ii * max_candidate_len + 1] == draft_tokens_now[ii + 1]
|
||||||
): # top-2
|
): # top-2
|
||||||
j = 0
|
j = 0
|
||||||
ii += 1 # C++ 中 ii 从下一个位置开始检查
|
ii += 1 # Start from ii next position
|
||||||
# C++: for (; j < verify_window && ii < seq_lens_this_time[bid] - 1; j++, ii++)
|
|
||||||
while j < verify_window and ii < seq_lens_this_time[bid] - 1:
|
while j < verify_window and ii < seq_lens_this_time[bid] - 1:
|
||||||
if verify_tokens_now[ii * max_candidate_len] != draft_tokens_now[ii + 1]:
|
if verify_tokens_now[ii * max_candidate_len] != draft_tokens_now[ii + 1]:
|
||||||
break
|
break
|
||||||
j += 1
|
j += 1
|
||||||
ii += 1
|
ii += 1
|
||||||
|
|
||||||
# C++: if (j >= verify_window)
|
|
||||||
if j >= verify_window: # accept all
|
if j >= verify_window: # accept all
|
||||||
accept_num_now += verify_window + 1
|
accept_num_now += verify_window + 1
|
||||||
step_idx[bid] += verify_window + 1
|
step_idx[bid] += verify_window + 1
|
||||||
# C++: for (; i < ii; i++)
|
for k_accepted_idx in range(i, ii):
|
||||||
for k_accepted_idx in range(i, ii): # i 会被更新
|
|
||||||
accept_token = draft_tokens_now[k_accepted_idx + 1]
|
accept_token = draft_tokens_now[k_accepted_idx + 1]
|
||||||
accept_tokens_flat[bid * max_draft_tokens + k_accepted_idx] = accept_token
|
accept_tokens_flat[bid * max_draft_tokens + k_accepted_idx] = accept_token
|
||||||
|
|
||||||
@@ -269,29 +199,18 @@ def speculate_verify_np(
|
|||||||
)
|
)
|
||||||
accept_num_now -= 1
|
accept_num_now -= 1
|
||||||
step_idx[bid] -= 1
|
step_idx[bid] -= 1
|
||||||
break # 跳出内层接受循环
|
break
|
||||||
break # 跳出主验证循环 (TopK 逻辑结束,无论成功与否)
|
break # TopK finish
|
||||||
# else 的 break 对应 is_in(Top P 验证失败,也不是 TopK 匹配)
|
break # Jump main loop
|
||||||
break # 跳出主验证循环
|
|
||||||
|
|
||||||
# 采样阶段 (Sampling Phase)
|
|
||||||
# C++ 中 i 变量在循环结束后会保留其最终值,直接用于采样
|
|
||||||
# Python 同样,loop_i 的最终值赋值给了 i
|
|
||||||
|
|
||||||
if not stop_flag_now_int:
|
if not stop_flag_now_int:
|
||||||
accept_token: int
|
accept_token: int
|
||||||
|
|
||||||
# C++: const float *verify_scores_now = verify_scores + start_token_id * max_candidate_len;
|
|
||||||
# Python: verify_scores_now 对应 C++ 中从 start_token_id 开始的 verify_scores 视图
|
|
||||||
verify_scores_now = verify_scores_flat[start_token_id * max_candidate_len :]
|
verify_scores_now = verify_scores_flat[start_token_id * max_candidate_len :]
|
||||||
|
|
||||||
step_idx[bid] += 1
|
step_idx[bid] += 1
|
||||||
|
|
||||||
if enable_topp:
|
if enable_topp:
|
||||||
# C++: auto actual_candidate_len_value = actual_candidate_len_now[i] > max_candidate_len ? max_candidate_len : actual_candidate_len_now[i];
|
|
||||||
actual_candidate_len_value = min(actual_candidate_len_now[i], max_candidate_len)
|
actual_candidate_len_value = min(actual_candidate_len_now[i], max_candidate_len)
|
||||||
|
|
||||||
# 传入当前候选的扁平视图
|
|
||||||
verify_tokens_sampling_view = verify_tokens_now[
|
verify_tokens_sampling_view = verify_tokens_now[
|
||||||
i * max_candidate_len : (i + 1) * max_candidate_len
|
i * max_candidate_len : (i + 1) * max_candidate_len
|
||||||
]
|
]
|
||||||
@@ -299,24 +218,17 @@ def speculate_verify_np(
|
|||||||
i * max_candidate_len : (i + 1) * max_candidate_len
|
i * max_candidate_len : (i + 1) * max_candidate_len
|
||||||
]
|
]
|
||||||
|
|
||||||
# C++: accept_token = topp_sampling_kernel(...)
|
|
||||||
accept_token = topp_sampling_kernel(
|
accept_token = topp_sampling_kernel(
|
||||||
verify_tokens_sampling_view,
|
verify_tokens_sampling_view,
|
||||||
verify_scores_sampling_view,
|
verify_scores_sampling_view,
|
||||||
dev_curand_states[i], # C++: dev_curand_states + i
|
dev_curand_states[i],
|
||||||
actual_candidate_len_value,
|
actual_candidate_len_value,
|
||||||
topp[bid], # C++: topp[bid]
|
topp[bid],
|
||||||
bid, # C++: bid
|
bid,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
accept_token = int(verify_tokens_now[i * max_candidate_len])
|
accept_token = int(verify_tokens_now[i * max_candidate_len])
|
||||||
print(
|
|
||||||
"debug python last accept_token",
|
|
||||||
accept_token,
|
|
||||||
"prefill_one_step_stop",
|
|
||||||
prefill_one_step_stop,
|
|
||||||
)
|
|
||||||
# C++: accept_tokens[bid * max_draft_tokens + i] = accept_token;
|
|
||||||
accept_tokens_flat[bid * max_draft_tokens + i] = accept_token
|
accept_tokens_flat[bid * max_draft_tokens + i] = accept_token
|
||||||
|
|
||||||
if prefill_one_step_stop:
|
if prefill_one_step_stop:
|
||||||
@@ -333,7 +245,6 @@ def speculate_verify_np(
|
|||||||
return accept_tokens, accept_num, step_idx, stop_flags
|
return accept_tokens, accept_num, step_idx, stop_flags
|
||||||
|
|
||||||
|
|
||||||
# ---------------- 生成随机输入 ----------------
|
|
||||||
def gen_speculate_verify_inputs(
|
def gen_speculate_verify_inputs(
|
||||||
real_bsz=123,
|
real_bsz=123,
|
||||||
max_draft_tokens=16,
|
max_draft_tokens=16,
|
||||||
@@ -341,12 +252,11 @@ def gen_speculate_verify_inputs(
|
|||||||
max_candidate_len=8,
|
max_candidate_len=8,
|
||||||
verify_window=2,
|
verify_window=2,
|
||||||
end_length=4,
|
end_length=4,
|
||||||
enable_topp=True,
|
enable_topp=False,
|
||||||
seed=2025,
|
seed=2025,
|
||||||
):
|
):
|
||||||
rng = np.random.default_rng(seed)
|
rng = np.random.default_rng(seed)
|
||||||
|
|
||||||
# 基础输入
|
|
||||||
seq_lens_encoder = rng.integers(0, 3, size=real_bsz, dtype=np.int32)
|
seq_lens_encoder = rng.integers(0, 3, size=real_bsz, dtype=np.int32)
|
||||||
seq_lens_decoder = rng.integers(1, max_draft_tokens, size=real_bsz, dtype=np.int32)
|
seq_lens_decoder = rng.integers(1, max_draft_tokens, size=real_bsz, dtype=np.int32)
|
||||||
draft_tokens = rng.integers(0, 1000, size=(real_bsz, max_draft_tokens), dtype=np.int64)
|
draft_tokens = rng.integers(0, 1000, size=(real_bsz, max_draft_tokens), dtype=np.int64)
|
||||||
@@ -354,10 +264,8 @@ def gen_speculate_verify_inputs(
|
|||||||
|
|
||||||
seq_lens_this_time = rng.integers(1, max_seq_len + 1, size=real_bsz, dtype=np.int32)
|
seq_lens_this_time = rng.integers(1, max_seq_len + 1, size=real_bsz, dtype=np.int32)
|
||||||
sum_seq_this_time = int(np.sum(seq_lens_this_time))
|
sum_seq_this_time = int(np.sum(seq_lens_this_time))
|
||||||
# print("debug param set sum_seq_this_time",sum_seq_this_time)
|
|
||||||
# print("debug param real_bsz * max_draft_tokens < 2k",real_bsz * max_draft_tokens)
|
|
||||||
# print("debug sum_seq_this_time * max_candidate_len < 2k",sum_seq_this_time * max_candidate_len)
|
|
||||||
|
|
||||||
|
sampled_token_ids = rng.integers(0, 1000, size=(sum_seq_this_time, 1), dtype=np.int64)
|
||||||
verify_tokens = rng.integers(0, 1000, size=(sum_seq_this_time, max_candidate_len), dtype=np.int64)
|
verify_tokens = rng.integers(0, 1000, size=(sum_seq_this_time, max_candidate_len), dtype=np.int64)
|
||||||
verify_scores = rng.random(size=(sum_seq_this_time, max_candidate_len)).astype(np.float32)
|
verify_scores = rng.random(size=(sum_seq_this_time, max_candidate_len)).astype(np.float32)
|
||||||
|
|
||||||
@@ -378,13 +286,14 @@ def gen_speculate_verify_inputs(
|
|||||||
else np.zeros(real_bsz, dtype=np.float32)
|
else np.zeros(real_bsz, dtype=np.float32)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 输出(占位)
|
# Output(inplace)
|
||||||
accept_tokens = np.zeros((real_bsz, max_draft_tokens), dtype=np.int64)
|
accept_tokens = np.zeros((real_bsz, max_draft_tokens), dtype=np.int64)
|
||||||
accept_num = np.zeros(real_bsz, dtype=np.int32)
|
accept_num = np.zeros(real_bsz, dtype=np.int32)
|
||||||
step_idx = np.zeros(real_bsz, dtype=np.int64)
|
step_idx = np.zeros(real_bsz, dtype=np.int64)
|
||||||
stop_flags = np.zeros(real_bsz, dtype=bool)
|
stop_flags = np.zeros(real_bsz, dtype=bool)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
"sampled_token_ids": sampled_token_ids,
|
||||||
"accept_tokens": accept_tokens,
|
"accept_tokens": accept_tokens,
|
||||||
"accept_num": accept_num,
|
"accept_num": accept_num,
|
||||||
"step_idx": step_idx,
|
"step_idx": step_idx,
|
||||||
@@ -405,178 +314,12 @@ def gen_speculate_verify_inputs(
|
|||||||
"max_seq_len": max_seq_len,
|
"max_seq_len": max_seq_len,
|
||||||
"verify_window": verify_window,
|
"verify_window": verify_window,
|
||||||
"enable_topp": enable_topp,
|
"enable_topp": enable_topp,
|
||||||
|
"benchmark_mode": False,
|
||||||
|
"accept_all_drafts": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# ------------------- 单测主体 -------------------
|
|
||||||
# # ---- Paddle 端 ----
|
|
||||||
def run_speculate_verify_test(
|
|
||||||
real_bsz,
|
|
||||||
max_draft_tokens,
|
|
||||||
max_seq_len,
|
|
||||||
max_candidate_len,
|
|
||||||
verify_window,
|
|
||||||
end_length,
|
|
||||||
enable_topp,
|
|
||||||
seed,
|
|
||||||
):
|
|
||||||
inputs = gen_speculate_verify_inputs(
|
|
||||||
real_bsz=real_bsz,
|
|
||||||
max_draft_tokens=max_draft_tokens,
|
|
||||||
max_seq_len=max_seq_len,
|
|
||||||
max_candidate_len=max_candidate_len,
|
|
||||||
verify_window=verify_window,
|
|
||||||
end_length=end_length,
|
|
||||||
enable_topp=enable_topp,
|
|
||||||
seed=seed,
|
|
||||||
)
|
|
||||||
|
|
||||||
paddle_inputs = {}
|
|
||||||
|
|
||||||
print("========= 1 xpu process==========")
|
|
||||||
|
|
||||||
for k, v in inputs.items():
|
|
||||||
if isinstance(v, (int, bool)):
|
|
||||||
paddle_inputs[k] = v
|
|
||||||
# print(f"{k:<25} type: {type(v).__name__}, value: {v}")
|
|
||||||
else:
|
|
||||||
# paddle_inputs[k] = paddle.to_tensor(v, place=paddle.CPUPlace())
|
|
||||||
paddle_inputs[k] = paddle.to_tensor(v, place=paddle.XPUPlace(0))
|
|
||||||
# print(f"{k:<25} type: Tensor, dtype: {paddle_inputs[k].dtype}, shape: {paddle_inputs[k].shape}")
|
|
||||||
|
|
||||||
out_pd = speculate_verify(**paddle_inputs)
|
|
||||||
(accept_tokens_pd, accept_num_pd, step_idx_pd, stop_flags_pd) = out_pd
|
|
||||||
pd_tensors = [accept_tokens_pd, accept_num_pd, step_idx_pd, stop_flags_pd]
|
|
||||||
|
|
||||||
print("========= 1 end==========")
|
|
||||||
print("========= 2 python process==========")
|
|
||||||
|
|
||||||
# np_inputs = {k: (paddle_inputs[k].numpy().copy() if isinstance(paddle_inputs[k], paddle.Tensor)
|
|
||||||
# else paddle_inputs[k])
|
|
||||||
# for k in paddle_inputs}
|
|
||||||
|
|
||||||
# out_np = speculate_verify_np(**np_inputs)
|
|
||||||
# (accept_tokens_np, accept_num_np, step_idx_np, stop_flags_np) = out_np
|
|
||||||
# np_tensors = [accept_tokens_np, accept_num_np, step_idx_np, stop_flags_np]
|
|
||||||
|
|
||||||
print("=========2 end =======")
|
|
||||||
|
|
||||||
print("========= 3 (CPU)==========")
|
|
||||||
paddle_inputs_cpu = {}
|
|
||||||
|
|
||||||
for k, v in inputs.items(): # 重新使用原始的 inputs 字典,确保数据原始状态
|
|
||||||
if isinstance(v, (int, bool)):
|
|
||||||
paddle_inputs_cpu[k] = v
|
|
||||||
# print(f"{k:<25} type: {type(v).__name__}, value: {v}")
|
|
||||||
else:
|
|
||||||
# 核心修改:使用 paddle.CPUPlace()
|
|
||||||
paddle_inputs_cpu[k] = paddle.to_tensor(v, place=paddle.CPUPlace())
|
|
||||||
# print(f"{k:<25} type: Tensor, dtype: {paddle_inputs_cpu[k].dtype}, shape: {paddle_inputs_cpu[k].shape}")
|
|
||||||
|
|
||||||
out_cpu = speculate_verify(**paddle_inputs_cpu)
|
|
||||||
(accept_tokens_cpu, accept_num_cpu, step_idx_cpu, stop_flags_cpu) = out_cpu
|
|
||||||
|
|
||||||
cpu_tensors = [
|
|
||||||
accept_tokens_cpu,
|
|
||||||
accept_num_cpu,
|
|
||||||
step_idx_cpu,
|
|
||||||
stop_flags_cpu,
|
|
||||||
]
|
|
||||||
print("========= 3 (CPU) end==========")
|
|
||||||
|
|
||||||
# ---------------- 校对 ----------------
|
|
||||||
# print("========= python/cpu vs xpu verify ==========")
|
|
||||||
|
|
||||||
# names = ["accept_tokens", "accept_num", "step_idx", "stop_flags"]
|
|
||||||
# for name, pd_val, np_val in zip(names, pd_tensors, np_tensors):
|
|
||||||
# pd_arr = pd_val.numpy()
|
|
||||||
# ok = np.array_equal(pd_arr, np_val)
|
|
||||||
# print(f"{name:20s} equal: {ok}")
|
|
||||||
# if not ok:
|
|
||||||
# print(f"{name} mismatch!\nPaddle:\n{pd_arr}\n\nNumPy:\n{np_val}")
|
|
||||||
|
|
||||||
print("========= cpu vs xpu verify ==========")
|
|
||||||
|
|
||||||
names = ["accept_tokens", "accept_num", "step_idx", "stop_flags"]
|
|
||||||
# for name, pd_val, np_val in zip(names, pd_tensors, cpu_tensors):
|
|
||||||
# pd_arr = pd_val.numpy()
|
|
||||||
# ok = np.array_equal(pd_arr, np_val)
|
|
||||||
# print(f"{name:20s} equal: {ok}")
|
|
||||||
# if not ok:
|
|
||||||
# print(f"{name} mismatch!\nPaddle:\n{pd_arr}\n\nNumPy:\n{np_val}")
|
|
||||||
|
|
||||||
for name, pd_val, np_val in zip(names, pd_tensors, cpu_tensors):
|
|
||||||
pd_arr = pd_val.numpy()
|
|
||||||
ok = np.array_equal(pd_arr, np_val)
|
|
||||||
print(f"{name:20s} equal: {ok}")
|
|
||||||
if not ok:
|
|
||||||
print(f"{name} mismatch!")
|
|
||||||
|
|
||||||
# 输出不同位置的索引和对应值
|
|
||||||
print(f"{name} mismatch!\nPaddle:\n{pd_arr}\n\nNumPy:\n{np_val}")
|
|
||||||
mismatches = np.where(pd_arr != np_val)
|
|
||||||
for idx in zip(*mismatches):
|
|
||||||
print(f" idx {idx}: Paddle = {pd_arr[idx]}, NumPy = {np_val[idx]}")
|
|
||||||
|
|
||||||
# 如果差异太多可限制输出数量
|
|
||||||
if len(mismatches[0]) > 20:
|
|
||||||
print(" ... (truncated)")
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
|
||||||
# 测试用例
|
|
||||||
# -------------------------------------
|
|
||||||
test_configs = [
|
test_configs = [
|
||||||
{
|
|
||||||
"real_bsz": 4,
|
|
||||||
"max_draft_tokens": 3,
|
|
||||||
"max_seq_len": 30,
|
|
||||||
"max_candidate_len": 4,
|
|
||||||
"verify_window": 2,
|
|
||||||
"end_length": 2,
|
|
||||||
"enable_topp": True,
|
|
||||||
"seed": 2025,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"real_bsz": 77,
|
|
||||||
"max_draft_tokens": 10,
|
|
||||||
"max_seq_len": 12000,
|
|
||||||
"max_candidate_len": 8,
|
|
||||||
"verify_window": 2,
|
|
||||||
"end_length": 4,
|
|
||||||
"enable_topp": True,
|
|
||||||
"seed": 2025,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"real_bsz": 1,
|
|
||||||
"max_draft_tokens": 2,
|
|
||||||
"max_seq_len": 10,
|
|
||||||
"max_candidate_len": 1,
|
|
||||||
"verify_window": 1,
|
|
||||||
"end_length": 1,
|
|
||||||
"enable_topp": True,
|
|
||||||
"seed": 42,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"real_bsz": 128,
|
|
||||||
"max_draft_tokens": 7,
|
|
||||||
"max_seq_len": 999,
|
|
||||||
"max_candidate_len": 5,
|
|
||||||
"verify_window": 3,
|
|
||||||
"end_length": 3,
|
|
||||||
"enable_topp": True,
|
|
||||||
"seed": 422,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"real_bsz": 99,
|
|
||||||
"max_draft_tokens": 5,
|
|
||||||
"max_seq_len": 10,
|
|
||||||
"max_candidate_len": 3,
|
|
||||||
"verify_window": 4,
|
|
||||||
"end_length": 4,
|
|
||||||
"enable_topp": True,
|
|
||||||
"seed": 42,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"real_bsz": 1,
|
"real_bsz": 1,
|
||||||
"max_draft_tokens": 9,
|
"max_draft_tokens": 9,
|
||||||
@@ -629,6 +372,45 @@ test_configs = [
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
for i, cfg in enumerate(test_configs):
|
|
||||||
print(f"\n\n======== Running Test Case {i} ========")
|
class TestSpeculateVerify(unittest.TestCase):
|
||||||
run_speculate_verify_test(**cfg)
|
def run_speculate_verify(
|
||||||
|
self,
|
||||||
|
real_bsz,
|
||||||
|
max_draft_tokens,
|
||||||
|
max_seq_len,
|
||||||
|
max_candidate_len,
|
||||||
|
verify_window,
|
||||||
|
end_length,
|
||||||
|
enable_topp,
|
||||||
|
seed,
|
||||||
|
):
|
||||||
|
inputs = gen_speculate_verify_inputs(
|
||||||
|
real_bsz=real_bsz,
|
||||||
|
max_draft_tokens=max_draft_tokens,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
max_candidate_len=max_candidate_len,
|
||||||
|
verify_window=verify_window,
|
||||||
|
end_length=end_length,
|
||||||
|
enable_topp=enable_topp,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
paddle_inputs = {k: v if isinstance(v, (int, bool)) else paddle.to_tensor(v) for k, v in inputs.items()}
|
||||||
|
inputs_xpu = list(paddle_inputs.values())
|
||||||
|
speculate_verify(*inputs_xpu)
|
||||||
|
out_xpu = [inputs_xpu[1], inputs_xpu[2], inputs_xpu[3], inputs_xpu[4]]
|
||||||
|
|
||||||
|
paddle_inputs_ref = {k: v if isinstance(v, (int, bool)) else paddle.to_tensor(v) for k, v in inputs.items()}
|
||||||
|
out_ref = speculate_verify_ref(**paddle_inputs_ref)
|
||||||
|
|
||||||
|
names = ["accept_tokens", "accept_num", "step_idx", "stop_flags"]
|
||||||
|
for _, pd_val, np_val in zip(names, out_xpu, out_ref):
|
||||||
|
np.testing.assert_allclose(pd_val.numpy(), np_val.numpy())
|
||||||
|
|
||||||
|
def test_speculate_verify(self):
|
||||||
|
for config in test_configs:
|
||||||
|
self.run_speculate_verify(**config)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_le
|
|||||||
paddle.ones_like(seq_lens_this_time),
|
paddle.ones_like(seq_lens_this_time),
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_start = (paddle.cumsum(token_lens, axis=0) - token_lens.astype("int64")).reshape(-1) # [B]
|
batch_start = (paddle.cumsum(token_lens, axis=0, dtype="int64") - token_lens.astype("int64")).reshape([-1]) # [B]
|
||||||
token_batch_ids = paddle.repeat_interleave(
|
token_batch_ids = paddle.repeat_interleave(
|
||||||
paddle.arange(token_lens.shape[0], dtype="int64"),
|
paddle.arange(token_lens.shape[0], dtype="int64"),
|
||||||
token_lens,
|
token_lens,
|
||||||
@@ -79,7 +79,7 @@ def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_le
|
|||||||
token_pos = paddle.arange(topp_seed.shape[0], dtype="int64")
|
token_pos = paddle.arange(topp_seed.shape[0], dtype="int64")
|
||||||
local_pos = token_pos - paddle.gather(batch_start, token_batch_ids)
|
local_pos = token_pos - paddle.gather(batch_start, token_batch_ids)
|
||||||
|
|
||||||
is_decoder = paddle.gather(seq_lens_encoder[:real_bsz] == 0, token_batch_ids).reshape(-1)
|
is_decoder = paddle.gather(seq_lens_encoder[:real_bsz] == 0, token_batch_ids).reshape([-1])
|
||||||
|
|
||||||
offsets = paddle.where(
|
offsets = paddle.where(
|
||||||
is_decoder,
|
is_decoder,
|
||||||
@@ -879,6 +879,15 @@ class SpeculativeSampler(nn.Layer):
|
|||||||
|
|
||||||
probs = F.softmax(logits)
|
probs = F.softmax(logits)
|
||||||
|
|
||||||
|
top_p, top_k, topp_seed = padding_sampling_params(
|
||||||
|
sampling_metadata.top_p,
|
||||||
|
sampling_metadata.top_k,
|
||||||
|
sampling_metadata.seed,
|
||||||
|
share_inputs["seq_lens_this_time"],
|
||||||
|
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
|
||||||
|
)
|
||||||
|
_, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed)
|
||||||
|
|
||||||
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
|
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
|
||||||
probs,
|
probs,
|
||||||
sampling_metadata.top_p,
|
sampling_metadata.top_p,
|
||||||
@@ -888,6 +897,7 @@ class SpeculativeSampler(nn.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
speculate_verify(
|
speculate_verify(
|
||||||
|
sampled_token_ids,
|
||||||
share_inputs["accept_tokens"],
|
share_inputs["accept_tokens"],
|
||||||
share_inputs["accept_num"],
|
share_inputs["accept_num"],
|
||||||
share_inputs["step_idx"],
|
share_inputs["step_idx"],
|
||||||
|
|||||||
Reference in New Issue
Block a user