[XPU] bind some OPs for VL model with pybind (#4522)

This commit is contained in:
Lucas
2025-10-27 10:50:08 +08:00
committed by GitHub
parent cdc40cdc2a
commit 5c6105f4a2
8 changed files with 1789 additions and 1087 deletions

View File

@@ -15,15 +15,14 @@
#include "helper.h"
template <int THREADBLOCK_SIZE>
__global__ void update_inputs_beam_kernel(
int *seq_lens_this_time,
int *seq_lens_encoder,
int64_t *input_ids,
float *logits,
const int bsz,
const int seq_len,
const int hidden_size,
const int beam_width) {
__global__ void update_inputs_beam_kernel(int* seq_lens_this_time,
int* seq_lens_encoder,
int64_t* input_ids,
float* logits,
const int bsz,
const int seq_len,
const int hidden_size,
const int beam_width) {
int thread_idx = threadIdx.x;
int block_idx = blockIdx.x;
@@ -35,23 +34,22 @@ __global__ void update_inputs_beam_kernel(
seq_lens_encoder[thread_idx] = seq_lens_encoder[bsz_index];
}
if (block_idx < seq_len) {
input_ids[thread_idx * seq_len + block_idx] = input_ids[bsz_index * seq_len + block_idx];
input_ids[thread_idx * seq_len + block_idx] =
input_ids[bsz_index * seq_len + block_idx];
}
logits[thread_idx * hidden_size + block_idx] = logits[bsz_index * hidden_size + block_idx];
logits[thread_idx * hidden_size + block_idx] =
logits[bsz_index * hidden_size + block_idx];
}
}
__syncthreads();
}
void UpdateInputesBeam(
const paddle::Tensor& beam_width,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& input_ids,
const paddle::Tensor& logits) {
void UpdateInputsBeam(const paddle::Tensor& beam_width,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& input_ids,
const paddle::Tensor& logits) {
int beam_width_scalar = beam_width.data<int>()[0];
if (beam_width_scalar > 1) {
@@ -59,16 +57,16 @@ void UpdateInputesBeam(
const int seq_len = input_ids.shape()[1];
const int hidden_size = logits.shape()[1];
update_inputs_beam_kernel<1024><<<hidden_size, 1024, 0, input_ids.stream()>>>(
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
const_cast<float*>(logits.data<float>()),
bsz,
seq_len,
hidden_size,
beam_width_scalar
);
update_inputs_beam_kernel<1024>
<<<hidden_size, 1024, 0, input_ids.stream()>>>(
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
const_cast<float*>(logits.data<float>()),
bsz,
seq_len,
hidden_size,
beam_width_scalar);
}
}
@@ -86,4 +84,4 @@ PD_BUILD_STATIC_OP(update_inputs_beam)
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"input_ids", "input_ids_out"},
{"logits", "logits_out"}})
.SetKernelFn(PD_KERNEL(UpdateInputesBeam));
.SetKernelFn(PD_KERNEL(UpdateInputsBeam));