[XPU] support XPU VL model inference (#4030)

* [XPU] support XPU VL model inference

* fix image op import and device check

* rebase develop

* fix perf
This commit is contained in:
Lucas
2025-09-25 14:34:15 +08:00
committed by GitHub
parent e36eccfdad
commit 87179cb744
18 changed files with 1300 additions and 146 deletions

View File

@@ -41,7 +41,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
const paddle::Tensor &encoder_seq_lod_cpu,
const paddle::Tensor &encoder_batch_map_cpu,
const paddle::Tensor &decoder_context_len_cpu,
const paddle::Tensor &decoder_batch_map_cpu) {
const paddle::Tensor &decoder_batch_map_cpu,
const std::string &pos_emb_type="NORMAL",
bool rope_3d=false) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
@@ -72,6 +74,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
int enc_batch = enc_batch_tensor.data<int32_t>()[0];
int dec_batch = dec_batch_tensor.data<int32_t>()[0];
int total_enc_len = total_enc_len_tensor.data<int32_t>()[0];
int rope_max_seqlen = 0;
int rope_3d_num_seqs = 1;
if (rope_3d) {
rope_max_seqlen = rotary_embs.dims()[3];
rope_3d_num_seqs = rotary_embs.dims()[0];
} else {
rope_max_seqlen = rotary_embs.dims()[2];
}
auto block_attn_out =
paddle::full({token_num, hidden_dim}, -1, qkv.type(), qkv.place());
@@ -151,10 +161,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
prefix_lens_vp, // start_tokens
param.batch_size, // batch_size
1, // emb_batch_size
rotary_embs.dims()[2], // max_seqlen
rope_max_seqlen, // max_seqlen
param.head_num, param.kv_head_num, param.head_dim,
param.max_batch_size, block_size, max_block_per_seq, "BLHD",
"HLD", "NORMAL",
"HLD", pos_emb_type,
!p_kcache_perhead_scale.defined()
? nullptr
: p_kcache_perhead_scale.data<float>() +
@@ -246,10 +256,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
vsl.slot_mapping_vp, // real_batch
param.batch_size, // batch_size
1, // emb_batch_size
rotary_embs.dims()[2], // max_seqlen TODO!!double check
rope_max_seqlen, // max_seqlen
param.head_num, param.kv_head_num, param.head_dim,
param.max_batch_size, block_size, max_block_per_seq, "BLHD", "HLD",
"NORMAL",
pos_emb_type,
!p_kcache_perhead_scale.defined()
? nullptr
: p_kcache_perhead_scale.data<float>() +
@@ -260,7 +270,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
param.kv_head_num, // v_cache_scale_inv
nullptr, // k_cache_zp
nullptr, // v_cache_zp
false); // b_c8_pc
false, // b_c8_pc
rope_3d, // rope_3d
rope_3d_num_seqs);
XFTBLOCK_CHECK_EQ(ret, api::SUCCESS);
// attn decode
@@ -314,6 +326,7 @@ PD_BUILD_OP(block_attn)
"decoder_context_len_cpu",
"decoder_batch_map_cpu",
})
.Attrs({"pos_emb_type:std::string", "rope_3d:bool"})
.Outputs({"block_attn_out"})
.SetKernelFn(PD_KERNEL(BlockAttnKernel))
.SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))

View File

@@ -0,0 +1,60 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
// Scans a prompt whose tokens interleave plain text and image patches and
// records, at every segment boundary, (a) the token index of the boundary and
// (b) how many images have been consumed up to it.
//
// task_input_ids : int64 CPU tensor holding the full prompt token ids.
// grid_thw       : int64 CPU tensor, rows of (t, h, w) patch grids per image.
// image_patch_id : the token id that marks an image patch.
//
// Returns one int64 CPU tensor of shape [2, num_boundaries]:
//   row 0 = token boundaries, row 1 = cumulative image counts.
std::vector<paddle::Tensor> GetImgBoundaries(const paddle::Tensor& task_input_ids,
                                             const paddle::Tensor& grid_thw,
                                             const int64_t image_patch_id) {
  // All tensors are on CPU, so raw data pointers can be read directly.
  const auto* input_ids_ptr = task_input_ids.data<int64_t>();
  const int64_t seq_lens_origin = task_input_ids.numel();
  const auto* grid_thw_ptr = grid_thw.data<int64_t>();

  // Divisor of 4 — presumably a 2x2 spatial merge of patches into one token;
  // TODO confirm against the visual tokenizer.
  const int64_t token_times = 4;
  // Fixed: indices and boundary values are kept in 64-bit throughout; the
  // previous `int` counters could truncate on very long sequences.
  int64_t token_idx = 0;
  int64_t image_idx = 0;
  std::vector<int64_t> img_boundaries, img_nums;
  img_boundaries.emplace_back(0);
  img_nums.emplace_back(0);

  while (token_idx < seq_lens_origin) {
    if (input_ids_ptr[token_idx] != image_patch_id) {
      // Skip over a contiguous run of ordinary text tokens.
      do {
        token_idx++;
      } while (token_idx < seq_lens_origin &&
               input_ids_ptr[token_idx] != image_patch_id);
    } else {
      // Consume the whole patch block of the current image: h * w patches
      // merged by token_times.  The t dimension at [image_idx * 3] is not
      // read here — presumably t == 1 for images; verify for video input.
      const int64_t cur_image_token_len =
          (grid_thw_ptr[image_idx * 3 + 1] * grid_thw_ptr[image_idx * 3 + 2]) /
          token_times;
      image_idx++;
      token_idx += cur_image_token_len;
    }
    img_boundaries.emplace_back(token_idx);
    img_nums.emplace_back(image_idx);
  }

  const int64_t num_img_boundaries = static_cast<int64_t>(img_boundaries.size());
  auto out = paddle::full({2, num_img_boundaries}, 0, paddle::DataType::INT64,
                          paddle::CPUPlace());
  auto* out_ptr = out.data<int64_t>();
  for (int64_t i = 0; i < num_img_boundaries; i++) {
    out_ptr[i] = img_boundaries[i];
    out_ptr[num_img_boundaries + i] = img_nums[i];
  }
  return {out};
}

PD_BUILD_OP(get_img_boundaries)
    .Inputs({"task_input_ids", "grid_thw"})
    .Attrs({"image_patch_id: int64_t"})
    .Outputs({"img_boundaries"})
    .SetKernelFn(PD_KERNEL(GetImgBoundaries));

View File

@@ -145,7 +145,8 @@ std::vector<paddle::Tensor> MoeLayerKernel(
? up_gate_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, inter_dim, hidden_dim});
std::vector<int64_t>{expert_num, inter_dim, hidden_dim}
);
xdown_proj_w = std::make_shared<xftblock::Tensor>(
const_cast<TW *>(down_proj_weight.data<TW>()), nullptr,
@@ -153,7 +154,8 @@ std::vector<paddle::Tensor> MoeLayerKernel(
? down_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, hidden_dim, outer_dim});
std::vector<int64_t>{expert_num, hidden_dim, outer_dim}
);
}
std::shared_ptr<xftblock::Tensor> xup_gate_proj_bias;
std::shared_ptr<xftblock::Tensor> xdown_proj_bias;

View File

@@ -0,0 +1,83 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <xft/xdnn_plugin.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
// In-place gather/scatter between a fused token buffer and separate
// text/image buffers on the current XPU device.
//
// is_scatter == true : rows of `input` are scattered into `text_input` /
//                      `image_input` according to token_type_ids and the
//                      per-modality destination indices.
// is_scatter == false: rows are gathered back from `text_input` /
//                      `image_input` into `input`.
//
// Only BFLOAT16 payload tensors are handled by the underlying plugin kernel.
void TextImageGatherScatter(
    paddle::Tensor& input,
    paddle::Tensor& text_input,
    paddle::Tensor& image_input,
    paddle::Tensor& token_type_ids,
    paddle::Tensor& text_index,
    paddle::Tensor& image_index,
    const bool is_scatter) {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto* raw_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto* xpu_ctx = static_cast<const phi::XPUContext*>(raw_ctx);

  // Shapes: input is [token_num, hidden_size]; the per-modality buffers
  // contribute their own row counts.
  const int64_t token_num = input.dims()[0];
  const int64_t hidden_size = input.dims()[1];
  const int64_t text_token_num = text_input.dims()[0];
  const int64_t image_token_num = image_input.dims()[0];

  if (input.type() != paddle::DataType::BFLOAT16) {
    PD_THROW(
        "NOT supported data type. Only support BFLOAT16. ");
  }

  using XPUType = typename XPUTypeTrait<bfloat16>::Type;
  using data_t = paddle::bfloat16;
  const int ret = baidu::xpu::api::plugin::text_image_gather_scatter<XPUType>(
      xpu_ctx->x_context(),
      reinterpret_cast<XPUType*>(input.data<data_t>()),
      reinterpret_cast<XPUType*>(text_input.data<data_t>()),
      reinterpret_cast<XPUType*>(image_input.data<data_t>()),
      token_type_ids.data<int>(),
      text_index.data<int>(),
      image_index.data<int>(),
      token_num,
      text_token_num,
      image_token_num,
      hidden_size,
      is_scatter);
  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "text_image_gather_scatter");
}

PD_BUILD_OP(text_image_gather_scatter)
    .Inputs({"input",
             "text_input",
             "image_input",
             "token_type_ids",
             "text_index",
             "image_index"})
    .Outputs({"text_input_out",
              "image_input_out",
              "text_index_out",
              "image_index_out"})
    .Attrs({"is_scatter:bool"})
    .SetInplaceMap({{"text_input", "text_input_out"},
                    {"image_input", "image_input_out"},
                    {"text_index", "text_index_out"},
                    {"image_index", "image_index_out"}})
    .SetKernelFn(PD_KERNEL(TextImageGatherScatter));

View File

@@ -0,0 +1,48 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
// Assigns, for every position in token_type_ids, a running ordinal within its
// own modality: text tokens (type 0) get consecutive numbers written into
// text_index, all other tokens get consecutive numbers written into
// image_index.  text_index/image_index are updated in place on the XPU.
void TextImageIndexOut(
    const paddle::Tensor& token_type_ids,
    const paddle::Tensor& text_index,
    const paddle::Tensor& image_index) {
  // The plugin kernel only understands int32 buffers.
  if (token_type_ids.type() != paddle::DataType::INT32 || text_index.type()
      != paddle::DataType::INT32 || image_index.type() != paddle::DataType::INT32) {
    // Bug fix: the previous message wrongly claimed BFLOAT16 was the
    // supported type; this op requires INT32 inputs.
    PD_THROW("NOT supported data type. Only support INT32. ");
  }
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
  const int64_t token_num = token_type_ids.shape()[0];
  // text_index/image_index are declared as inplace outputs in PD_BUILD_OP
  // below, but the Tensor parameters arrive const — hence the const_cast.
  int r = baidu::xpu::api::plugin::text_image_index_out(xpu_ctx->x_context(),
                                                        token_type_ids.data<int32_t>(),
                                                        const_cast<int32_t*>(text_index.data<int32_t>()),
                                                        const_cast<int32_t*>(image_index.data<int32_t>()),
                                                        token_num);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_index_out");
}

PD_BUILD_OP(text_image_index_out)
    .Inputs({"token_type_ids",
             "text_index",
             "image_index"})
    .Outputs({"text_index_out",
              "image_index_out"})
    .SetInplaceMap({{"text_index", "text_index_out"},
                    {"image_index", "image_index_out"}})
    .SetKernelFn(PD_KERNEL(TextImageIndexOut));

View File

@@ -140,6 +140,25 @@ DLL_EXPORT int quant2d_per_channel(api::Context *ctx, const TX *x,
const TSCALE *scale_in, TY *y,
TSCALE *scale_out, int64_t m, int64_t n);
DLL_EXPORT int text_image_index_out(Context* ctx,
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num);
template <typename T>
DLL_EXPORT int text_image_gather_scatter(api::Context* ctx,
T* input,
T* text_input,
T* image_input,
int* token_type_ids,
int* text_index,
int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter);
/*--------------------------------------- MTP begin --------------------------------------------*/

View File

@@ -0,0 +1,175 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/xtdk_io.h"
namespace xpu3 {
namespace plugin {
// Gather mode (device side): for every token row assigned to this core, copy
// the row from text_input (token_type == 0) or image_input (otherwise) into
// the fused `input` buffer.  text_index/image_index give the source row
// inside the per-modality buffer.  `input_lm` is a local-memory staging
// buffer provided by the __global__ entry point.
template <typename T>
static __device__ inline void text_image_gather(
    __global_ptr__ T* input,
    __global_ptr__ T* text_input,
    __global_ptr__ T* image_input,
    __global_ptr__ int* token_type_ids,
    __global_ptr__ int* text_index,
    __global_ptr__ int* image_index,
    int64_t token_num,
    int64_t text_token_num,
    int64_t image_token_num,
    int64_t hidden_size,
    T* input_lm) {
  int cid = core_id();
  int clusterid = cluster_id();
  int token_start_cluster;
  int token_end_cluster;
  int token_start_core;
  int token_end_core;
  const int BUFSIZE = 2 * 1024 / sizeof(T);  // 1024 for bf16, 512 for fp32
  // cluster partition: split the token rows across clusters first ...
  partition(cluster_id(), cluster_num(), (int)token_num, 1, &token_start_cluster, &token_end_cluster);
  if (token_start_cluster >= token_end_cluster) {
    return;  // this cluster has no rows to process
  }
  int rows_cluster = token_end_cluster - token_start_cluster;  // total rows for a cluster
  // core partition: ... then split this cluster's rows across its cores.
  partition(core_id(), core_num(), rows_cluster, 1, &token_start_core, &token_end_core);
  int rows_core = token_end_core - token_start_core;  // total rows for a core
  token_start_core += token_start_cluster;
  token_end_core += token_start_cluster;
  int read_len;
  for (int i = token_start_core; i < token_end_core; i += 1) {
    int token_type, text_image_token_idx;
    __global_ptr__ T* text_image_input = nullptr;
    __global_ptr__ int* text_image_index = nullptr;
    GM2LM(token_type_ids + i, &token_type, sizeof(int));  // fetch this token's modality
    if (token_type == 0) {
      text_image_input = text_input;
      text_image_index = text_index;
    } else {
      text_image_input = image_input;
      text_image_index = image_index;
    }
    GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
    // NOTE(review): these offsets are 32-bit; they could overflow once
    // token_num * hidden_size exceeds INT_MAX — confirm expected tensor sizes.
    int input_offset = i * hidden_size;
    int text_image_offset = text_image_token_idx * hidden_size;
    // Stream the row through local memory in BUFSIZE-element chunks:
    // per-modality buffer -> LM staging -> fused input.
    for (int j = 0; j < hidden_size; j += BUFSIZE) {
      read_len = min(hidden_size - j, BUFSIZE);
      GM2LM(text_image_input + text_image_offset + j, input_lm, sizeof(T) * read_len);
      LM2GM(input_lm, input + input_offset + j, sizeof(T) * read_len);
    }
  }
}
// Scatter mode (device side): the mirror of text_image_gather — for every
// token row assigned to this core, copy the row from the fused `input`
// buffer out to text_input (token_type == 0) or image_input (otherwise),
// at the destination row given by text_index/image_index.
template <typename T>
static __device__ inline void text_image_scatter(
    __global_ptr__ T* input,
    __global_ptr__ T* text_input,
    __global_ptr__ T* image_input,
    __global_ptr__ int* token_type_ids,
    __global_ptr__ int* text_index,
    __global_ptr__ int* image_index,
    int64_t token_num,
    int64_t text_token_num,
    int64_t image_token_num,
    int64_t hidden_size,
    T* input_lm) {
  int cid = core_id();
  int clusterid = cluster_id();
  int token_start_cluster;
  int token_end_cluster;
  int token_start_core;
  int token_end_core;
  const int BUFSIZE = 2 * 1024 / sizeof(T);  // 1024 for bf16, 512 for fp32
  // cluster partition: split the token rows across clusters first ...
  partition(cluster_id(), cluster_num(), (int)token_num, 1, &token_start_cluster, &token_end_cluster);
  if (token_start_cluster >= token_end_cluster) {
    return;  // this cluster has no rows to process
  }
  int rows_cluster = token_end_cluster - token_start_cluster;  // total rows for a cluster
  // core partition: ... then split this cluster's rows across its cores.
  partition(core_id(), core_num(), rows_cluster, 1, &token_start_core, &token_end_core);
  int rows_core = token_end_core - token_start_core;  // total rows for a core
  token_start_core += token_start_cluster;
  token_end_core += token_start_cluster;
  int read_len;
  for (int i = token_start_core; i < token_end_core; i += 1) {
    int token_type, text_image_token_idx;
    __global_ptr__ T* text_image_input = nullptr;
    __global_ptr__ int* text_image_index = nullptr;
    GM2LM(token_type_ids + i, &token_type, sizeof(int));  // fetch this token's modality
    if (token_type == 0) {
      text_image_input = text_input;
      text_image_index = text_index;
    } else {
      text_image_input = image_input;
      text_image_index = image_index;
    }
    GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
    // NOTE(review): 32-bit offsets, same overflow caveat as the gather path.
    int input_offset = i * hidden_size;
    int text_image_offset = text_image_token_idx * hidden_size;
    // Stream the row through local memory in BUFSIZE-element chunks:
    // fused input -> LM staging -> per-modality buffer.
    for (int j = 0; j < hidden_size; j += BUFSIZE) {
      read_len = min(hidden_size - j, BUFSIZE);
      GM2LM(input + input_offset + j, input_lm, sizeof(T) * read_len);
      LM2GM(input_lm, text_image_input + text_image_offset + j, sizeof(T) * read_len);
    }
  }
}
// Kernel entry point: dispatches to the scatter or gather device routine.
// Row partitioning across clusters/cores happens inside the helpers; the
// 2KB __simd__ array is the per-core local-memory staging buffer shared by
// both directions.
template <typename T>
__global__ void text_image_gather_scatter(
    T* input,
    T* text_input,
    T* image_input,
    int* token_type_ids,
    int* text_index,
    int* image_index,
    int64_t token_num,
    int64_t text_token_num,
    int64_t image_token_num,
    int64_t hidden_size,
    bool is_scatter) {
  // NOTE(review): cid/ncores/clusterid/nclusters are currently unused here;
  // the helper routines query core/cluster ids themselves.
  int cid = core_id();
  int ncores = core_num();
  int clusterid = cluster_id();
  int nclusters = cluster_num();
  const int BUFSIZE = 2 * 1024 / sizeof(T);  // 1024 for bf16, 512 for fp32
  __simd__ T input_lm[BUFSIZE];  // 2KB for bf16 and fp32
  if (is_scatter) {
    text_image_scatter(
        input, text_input, image_input, token_type_ids, text_index, image_index,
        token_num, text_token_num, image_token_num, hidden_size, input_lm);
  } else {
    text_image_gather(
        input, text_input, image_input, token_type_ids, text_index, image_index,
        token_num, text_token_num, image_token_num, hidden_size, input_lm);
  }
}
#define _XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(T) \
template __global__ void text_image_gather_scatter<T>( \
T* input, \
T* text_input, \
T* image_input, \
int* token_type_ids, \
int* text_index, \
int* image_index, \
int64_t token_num, \
int64_t text_token_num, \
int64_t image_token_num, \
int64_t hidden_size, \
bool is_scatter);
_XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(bfloat16);
} // namespace plugin
} // namespace xpu3

View File

@@ -0,0 +1,97 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2025 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/cluster_primitive_template.h"
namespace xpu3 {
namespace plugin {
// Assigns running per-modality ordinals for `size` token-type values held in
// shared memory: type-0 entries receive consecutive text ordinals in lm_y1,
// all other entries receive consecutive image ordinals in lm_y2.
// text_count/images_count carry the running totals across chunks.
static __device__ void do_calc(const _shared_ptr_ int* lm_x, int* lm_y1, int* lm_y2, int64_t size, int& text_count, int& images_count) {
  for (int j = 0; j < size; j++) {
    if (lm_x[j] == 0) {
      lm_y1[j] = text_count;
      text_count += 1;
    } else {
      lm_y2[j] = images_count;
      images_count += 1;
    }
  }
  // Flush local/shared-memory writes before the caller issues async stores.
  mfence_lm_sm();
}
// Computes per-modality running indices over token_type_ids.  The running
// counters make this a sequential prefix operation, so only one logical
// thread does the work; it pipelines chunked async loads/stores through
// double/triple buffers to overlap memory traffic with the scan.
__global__ void text_image_index_out_kernel(
    const int* token_type_ids,  // x
    int* text_index,            // y1
    int* image_index,           // y2
    const int64_t token_num) {
  const int cid = core_id();
  const int tid = core_id() * cluster_num() + cluster_id();
  const int nthreads = core_num() * cluster_num();
  // Everything below runs on thread 0 only; all other threads exit.
  if (tid >= 1) return;
  constexpr int BUFSIZE = 1024;
  constexpr int READ_MAX_SIZE = BUFSIZE / sizeof(int);
  const int64_t len = token_num;
  __simd__ char buffer0[BUFSIZE * 3];  // triple-buffered y1 staging (local mem)
  __simd__ char buffer1[BUFSIZE * 3];  // triple-buffered y2 staging (local mem)
  // Shared-memory staging for x, one slot per core id even though only one
  // thread proceeds — NOTE(review): confirm the [64] dimension is required.
  __simd__ __shared__ char buffer2[64][BUFSIZE * 2];
  DoublePtr<READ_MAX_SIZE, SmPtr<int>> buffer_ptr_x((SmPtr<int>((_shared_ptr_ int*)buffer2[cid])));
  TriplePtr<READ_MAX_SIZE, LmPtr<int>> buffer_ptr_y1((LmPtr<int>((int*)buffer0)));
  TriplePtr<READ_MAX_SIZE, LmPtr<int>> buffer_ptr_y2((LmPtr<int>((int*)buffer1)));
  int64_t buflen = get_1d_buflen(len, nthreads, READ_MAX_SIZE, 64);
  int64_t i = tid * buflen;
  int read_size = 0;
  int offset = nthreads * buflen;
  int text_count = 0;    // running ordinal for text tokens
  int images_count = 0;  // running ordinal for image tokens
  // Prologue: kick off the first chunk's async loads.
  if (i < len) {
    read_size = min<int64_t>(buflen, len - i);
    buffer_ptr_y1.gm_load_async(text_index + tid * buflen, read_size);
    buffer_ptr_y2.gm_load_async(image_index + tid * buflen, read_size);
    buffer_ptr_x.gm_load_async(token_type_ids + tid * buflen, read_size);
    mfence();
  }
  // Steady state: prefetch chunk k+1, scan chunk k, store chunk k's results.
  while (i < len && i + offset < len) {
    i = i + offset;
    int read_size_next = min<int64_t>(buflen, len - i);
    buffer_ptr_x.next().gm_load_async(token_type_ids + i, read_size_next);
    buffer_ptr_y1.next().gm_load_async(text_index + i, read_size_next);
    buffer_ptr_y2.next().gm_load_async(image_index + i, read_size_next);
    do_calc(buffer_ptr_x.ptr, buffer_ptr_y1.ptr, buffer_ptr_y2.ptr, read_size, text_count, images_count);
    buffer_ptr_y1.gm_store_async(text_index + i - offset, read_size);
    buffer_ptr_y2.gm_store_async(image_index + i - offset, read_size);
    buffer_ptr_x.toggle();
    buffer_ptr_y1.toggle();
    buffer_ptr_y2.toggle();
    read_size = read_size_next;
  }
  // Epilogue: scan and store the final chunk (last store is blocking).
  if (i < len) {
    do_calc(buffer_ptr_x.ptr, buffer_ptr_y1.ptr, buffer_ptr_y2.ptr, read_size, text_count, images_count);
    buffer_ptr_y1.gm_store_async(text_index + i, read_size);
    buffer_ptr_y2.gm_store(image_index + i, read_size);
  }
}
} // namespace plugin
} // namespace xpu3

View File

@@ -0,0 +1,182 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
template <typename T>
__attribute__((global)) void text_image_gather_scatter(
T* input,
T* text_input,
T* image_input,
int* token_type_ids,
int* text_index,
int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
// Reference CPU implementation of text_image_gather_scatter.
//
// Scatter (is_scatter == true): each row i of `input` is copied to row
// text_index[i]/image_index[i] of text_input/image_input, chosen by
// token_type_ids[i] (0 = text, otherwise image).
// Gather (is_scatter == false): the reverse copy direction.
template <typename T>
static int cpu_wrapper(
    Context* ctx,
    T* input,            // shape [token_num, hidden_size]
    T* text_input,       // shape [text_token_num, hidden_size]
    T* image_input,      // shape [image_token_num, hidden_size]
    int* token_type_ids, // shape [token_num], 0 for text, 1 for image
    int* text_index,     // shape [token_num], mapping from input to text_input
    int* image_index,    // shape [token_num], mapping from input to image_input
    int64_t token_num,
    int64_t text_token_num,
    int64_t image_token_num,
    int64_t hidden_size,
    bool is_scatter) {
  for (int64_t i = 0; i < token_num; i++) {
    // Select the modality-specific buffer and index map for this token.
    const bool is_text = (token_type_ids[i] == 0);
    T* side_input = is_text ? text_input : image_input;
    const int* side_index = is_text ? text_index : image_index;
    // Fixed: offsets are kept in 64-bit — the previous `int` offsets could
    // overflow once token_num * hidden_size exceeds INT_MAX.
    const int64_t input_offset = i * hidden_size;
    const int64_t side_offset =
        static_cast<int64_t>(side_index[i]) * hidden_size;
    if (is_scatter) {
      // Scatter: fused input row -> per-modality buffer.
      for (int64_t j = 0; j < hidden_size; j++) {
        side_input[side_offset + j] = input[input_offset + j];
      }
    } else {
      // Gather: per-modality buffer -> fused input row.
      for (int64_t j = 0; j < hidden_size; j++) {
        input[input_offset + j] = side_input[side_offset + j];
      }
    }
  }
  return api::SUCCESS;
}
// XPU3 dispatch: launches the plugin kernel on the device stream with
// ctx->ncluster() clusters and 64 cores per cluster.
template <typename T>
static int xpu3_wrapper(
    Context* ctx,
    T* input,
    T* text_input,
    T* image_input,
    int* token_type_ids,
    int* text_index,
    int* image_index,
    int64_t token_num,
    int64_t text_token_num,
    int64_t image_token_num,
    int64_t hidden_size,
    bool is_scatter) {
  xpu3::plugin::text_image_gather_scatter<T> <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
      input, text_input, image_input, token_type_ids, text_index, image_index,
      token_num, text_token_num, image_token_num, hidden_size, is_scatter
  );
  // NOTE(review): the launch result is not checked here; errors would
  // surface on a later synchronizing call.
  return api::SUCCESS;
}
// Public wrapper: validates arguments, then dispatches to the CPU reference
// implementation or the XPU3 kernel depending on the context's device type.
// Returns api::SUCCESS on success; the WRAPPER_* macros return early with an
// error code on validation failure.
template <typename T>
int text_image_gather_scatter(
    Context* ctx,
    T* input,            // shape [token_num, hidden_size]
    T* text_input,       // shape [text_token_num, hidden_size]
    T* image_input,      // shape [image_token_num, hidden_size]
    int* token_type_ids, // shape [token_num], 0 for text, 1 for image
    int* text_index,     // shape [token_num], mapping from input to text_input
    int* image_index,    // shape [token_num], mapping from input to image_input
    int64_t token_num,
    int64_t text_token_num,
    int64_t image_token_num,
    int64_t hidden_size,
    bool is_scatter) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(ctx, "text_image_gather_scatter", T);
  WRAPPER_DUMP_PARAM6(ctx, input, text_input, image_input, token_type_ids, text_index, image_index);
  WRAPPER_DUMP_PARAM5(ctx, token_num, text_token_num, image_token_num, hidden_size, is_scatter);
  WRAPPER_DUMP(ctx);
  WRAPPER_CHECK_PTR(ctx, T, token_num * hidden_size, input);
  // Pointer checks are skipped for empty modality buffers, which legally
  // arrive as null when the batch has no text (or no image) tokens.
  if (text_token_num != 0) { // avoiding text_input tensor with shape [0, hidden_size]
    WRAPPER_CHECK_PTR(ctx, T, text_token_num * hidden_size, text_input);
  }
  if (image_token_num != 0) { // avoiding image_input tensor with shape [0, hidden_size]
    WRAPPER_CHECK_PTR(ctx, T, image_token_num * hidden_size, image_input);
  }
  WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids);
  WRAPPER_CHECK_PTR(ctx, int, token_num, text_index);
  WRAPPER_CHECK_PTR(ctx, int, token_num, image_index);
  // Every token is exactly one of text/image, so the counts must add up.
  WRAPPER_ASSERT_EQ(ctx, token_num, text_token_num + image_token_num);
  if (ctx->dev().type() == api::kCPU) {
    return cpu_wrapper<T>(
        ctx, input, text_input, image_input, token_type_ids, text_index, image_index,
        token_num, text_token_num, image_token_num, hidden_size, is_scatter
    );
  }
  if (ctx->dev().type() == api::kXPU3) {
    return xpu3_wrapper<T>(
        ctx, input, text_input, image_input, token_type_ids, text_index, image_index,
        token_num, text_token_num, image_token_num, hidden_size, is_scatter
    );
  }
  // Any other device type is unsupported.
  WRAPPER_UNIMPLEMENTED(ctx);
}
template int text_image_gather_scatter(
Context*, bfloat16*, bfloat16*, bfloat16*, int*, int*, int*, const int64_t, const int64_t, const int64_t, const int64_t, bool);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu

View File

@@ -0,0 +1,103 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void text_image_index_out_kernel(const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
// Reference CPU implementation: numbers text tokens (type 0) and image
// tokens (any other type) with two independent running counters, writing
// each token's ordinal into the array for its own modality.
static int cpu_wrapper(Context* ctx,
                       const int* token_type_ids,  // x
                       int* text_index,            // y1
                       int* image_index,           // y2
                       const int64_t token_num) {
  int counters[2] = {0, 0};  // [0] = text tokens seen, [1] = image tokens seen
  for (int64_t pos = 0; pos < token_num; ++pos) {
    const bool is_image = (token_type_ids[pos] != 0);
    if (is_image) {
      image_index[pos] = counters[1];
      counters[1] += 1;
    } else {
      text_index[pos] = counters[0];
      counters[0] += 1;
    }
  }
  return api::SUCCESS;
}
// XPU3 dispatch: the running-counter scan is inherently sequential, so the
// kernel is launched with a single cluster and a single core (<<<1, 1>>>);
// the kernel itself also guards with `if (tid >= 1) return;`.
static int xpu3_wrapper(Context* ctx,
                        const int* token_type_ids,  // x
                        int* text_index,            // y1
                        int* image_index,           // y2
                        const int64_t token_num) {
  xpu3::plugin::text_image_index_out_kernel<<<1, 1, ctx->xpu_stream>>>(
      token_type_ids,
      text_index,
      image_index,
      token_num);
  // NOTE(review): the launch result is not checked here; errors would
  // surface on a later synchronizing call.
  return api::SUCCESS;
}
// Public wrapper: validates arguments, then dispatches the per-modality
// index computation to the CPU reference or the XPU3 kernel depending on
// the context's device type.  text_index/image_index are written in place.
int text_image_index_out(Context* ctx,
                         const int* token_type_ids,  // x
                         int* text_index,            // y1
                         int* image_index,           // y2
                         const int64_t token_num) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(ctx, "text_image_index_out", int);
  WRAPPER_DUMP_PARAM4(
      ctx, token_type_ids, text_index, image_index, token_num);
  WRAPPER_DUMP(ctx);
  // Empty inputs are rejected: token_num must be strictly positive.
  WRAPPER_ASSERT_GT(ctx, token_num, 0);
  WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids);
  WRAPPER_CHECK_PTR(ctx, int, token_num, text_index);
  WRAPPER_CHECK_PTR(ctx, int, token_num, image_index);
  if (ctx->dev().type() == api::kCPU) {
    return cpu_wrapper(ctx,
                       token_type_ids,
                       text_index,
                       image_index,
                       token_num);
  } else if (ctx->dev().type() == api::kXPU3) {
    return xpu3_wrapper(ctx,
                        token_type_ids,
                        text_index,
                        image_index,
                        token_num);
  }
  // Any other device type is unsupported.
  WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu