mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
fix top_p_candidates and support separate setting of sampling params for mtp (#4189)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* fix top_p_candidates * For separate setting params for mtp * delete print * fix
This commit is contained in:
@@ -38,14 +38,20 @@ __device__ int64_t topp_sampling_kernel(const int64_t *candidate_ids,
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
float sum_scores = 0.0f;
|
||||
float rand_top_p = curand_uniform(dev_curand_states + tid) * topp;
|
||||
for (int i = 0; i < candidate_len; i++) {
|
||||
sum_scores += candidate_scores[i];
|
||||
}
|
||||
float tgt_topp = sum_scores < topp ? sum_scores : topp;
|
||||
|
||||
sum_scores = 0.0f;
|
||||
float rand_top_p = curand_uniform(dev_curand_states + tid) * tgt_topp;
|
||||
for (int i = 0; i < candidate_len; i++) {
|
||||
sum_scores += candidate_scores[i];
|
||||
if (rand_top_p <= sum_scores) {
|
||||
return candidate_ids[i];
|
||||
return candidate_ids[i];
|
||||
}
|
||||
}
|
||||
return candidate_ids[0];
|
||||
return candidate_ids[0];
|
||||
}
|
||||
|
||||
__global__ void setup_kernel(curandState_t *state, const uint64_t seed,
|
||||
|
Reference in New Issue
Block a user