mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -33,7 +33,7 @@ __global__ void update_inputs_beam_kernel(
|
||||
if (block_idx == 0) {
|
||||
seq_lens_this_time[thread_idx] = seq_lens_this_time[bsz_index];
|
||||
seq_lens_encoder[thread_idx] = seq_lens_encoder[bsz_index];
|
||||
}
|
||||
}
|
||||
if (block_idx < seq_len) {
|
||||
input_ids[thread_idx * seq_len + block_idx] = input_ids[bsz_index * seq_len + block_idx];
|
||||
}
|
||||
@@ -74,8 +74,8 @@ void UpdateInputesBeam(
|
||||
|
||||
PD_BUILD_STATIC_OP(update_inputs_beam)
|
||||
.Inputs({"beam_width",
|
||||
"seq_lens_this_time",
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_this_time",
|
||||
"seq_lens_encoder",
|
||||
"input_ids",
|
||||
"logits"})
|
||||
.Outputs({"seq_lens_this_time_out",
|
||||
@@ -86,4 +86,4 @@ PD_BUILD_STATIC_OP(update_inputs_beam)
|
||||
{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||
{"input_ids", "input_ids_out"},
|
||||
{"logits", "logits_out"}})
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputesBeam));
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputesBeam));
|
||||
|
||||
Reference in New Issue
Block a user