mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[XPU] support XPU VL model inference (#4030)
* [XPU] support XPU VL model inference * fix image op import and device check * rebase develop * fix perf
This commit is contained in:
@@ -41,7 +41,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
const paddle::Tensor &encoder_seq_lod_cpu,
|
||||
const paddle::Tensor &encoder_batch_map_cpu,
|
||||
const paddle::Tensor &decoder_context_len_cpu,
|
||||
const paddle::Tensor &decoder_batch_map_cpu) {
|
||||
const paddle::Tensor &decoder_batch_map_cpu,
|
||||
const std::string &pos_emb_type="NORMAL",
|
||||
bool rope_3d=false) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx =
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
@@ -72,6 +74,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
int enc_batch = enc_batch_tensor.data<int32_t>()[0];
|
||||
int dec_batch = dec_batch_tensor.data<int32_t>()[0];
|
||||
int total_enc_len = total_enc_len_tensor.data<int32_t>()[0];
|
||||
int rope_max_seqlen = 0;
|
||||
int rope_3d_num_seqs = 1;
|
||||
if (rope_3d) {
|
||||
rope_max_seqlen = rotary_embs.dims()[3];
|
||||
rope_3d_num_seqs = rotary_embs.dims()[0];
|
||||
} else {
|
||||
rope_max_seqlen = rotary_embs.dims()[2];
|
||||
}
|
||||
|
||||
auto block_attn_out =
|
||||
paddle::full({token_num, hidden_dim}, -1, qkv.type(), qkv.place());
|
||||
@@ -151,10 +161,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
prefix_lens_vp, // start_tokens
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size
|
||||
rotary_embs.dims()[2], // max_seqlen
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num, param.kv_head_num, param.head_dim,
|
||||
param.max_batch_size, block_size, max_block_per_seq, "BLHD",
|
||||
"HLD", "NORMAL",
|
||||
"HLD", pos_emb_type,
|
||||
!p_kcache_perhead_scale.defined()
|
||||
? nullptr
|
||||
: p_kcache_perhead_scale.data<float>() +
|
||||
@@ -246,10 +256,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
vsl.slot_mapping_vp, // real_batch
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size
|
||||
rotary_embs.dims()[2], // max_seqlen TODO!!double check
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num, param.kv_head_num, param.head_dim,
|
||||
param.max_batch_size, block_size, max_block_per_seq, "BLHD", "HLD",
|
||||
"NORMAL",
|
||||
pos_emb_type,
|
||||
!p_kcache_perhead_scale.defined()
|
||||
? nullptr
|
||||
: p_kcache_perhead_scale.data<float>() +
|
||||
@@ -260,7 +270,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
param.kv_head_num, // v_cache_scale_inv
|
||||
nullptr, // k_cache_zp
|
||||
nullptr, // v_cache_zp
|
||||
false); // b_c8_pc
|
||||
false, // b_c8_pc
|
||||
rope_3d, // rope_3d
|
||||
rope_3d_num_seqs);
|
||||
XFTBLOCK_CHECK_EQ(ret, api::SUCCESS);
|
||||
|
||||
// attn decode
|
||||
@@ -314,6 +326,7 @@ PD_BUILD_OP(block_attn)
|
||||
"decoder_context_len_cpu",
|
||||
"decoder_batch_map_cpu",
|
||||
})
|
||||
.Attrs({"pos_emb_type:std::string", "rope_3d:bool"})
|
||||
.Outputs({"block_attn_out"})
|
||||
.SetKernelFn(PD_KERNEL(BlockAttnKernel))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))
|
||||
|
||||
60
custom_ops/xpu_ops/src/ops/get_img_boundaries.cc
Normal file
60
custom_ops/xpu_ops/src/ops/get_img_boundaries.cc
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
std::vector<paddle::Tensor> GetImgBoundaries(const paddle::Tensor& task_input_ids,
|
||||
const paddle::Tensor& grid_thw,
|
||||
const int64_t image_patch_id) {
|
||||
// All tensor in cpu
|
||||
auto input_ids_ptr = task_input_ids.data<int64_t>();
|
||||
int64_t seq_lens_origin = task_input_ids.numel();
|
||||
auto grid_thw_ptr = grid_thw.data<int64_t>();
|
||||
|
||||
int token_times = 4;
|
||||
int token_idx = 0;
|
||||
int image_idx = 0;
|
||||
std::vector<int> img_boundaries, img_nums;
|
||||
img_boundaries.emplace_back(0);
|
||||
img_nums.emplace_back(0);
|
||||
while (token_idx < seq_lens_origin) {
|
||||
if (input_ids_ptr[token_idx] != image_patch_id) {
|
||||
do {
|
||||
token_idx++;
|
||||
} while (token_idx < seq_lens_origin && input_ids_ptr[token_idx] != image_patch_id);
|
||||
} else {
|
||||
int cur_image_token_len = (grid_thw_ptr[image_idx * 3 + 1] * grid_thw_ptr[image_idx * 3 + 2]) / token_times;
|
||||
image_idx++;
|
||||
token_idx += cur_image_token_len;
|
||||
}
|
||||
img_boundaries.emplace_back(token_idx);
|
||||
img_nums.emplace_back(image_idx);
|
||||
}
|
||||
|
||||
int64_t num_img_boundaries = static_cast<int64_t>(img_boundaries.size());
|
||||
auto out = paddle::full({2, num_img_boundaries}, 0, paddle::DataType::INT64, paddle::CPUPlace());
|
||||
|
||||
for (int i = 0; i < num_img_boundaries; i++) {
|
||||
out.data<int64_t>()[i] = img_boundaries[i];
|
||||
out.data<int64_t>()[num_img_boundaries + i] = img_nums[i];
|
||||
}
|
||||
|
||||
return {out};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(get_img_boundaries)
|
||||
.Inputs({"task_input_ids", "grid_thw"})
|
||||
.Attrs({"image_patch_id: int64_t"})
|
||||
.Outputs({"img_boundaries"})
|
||||
.SetKernelFn(PD_KERNEL(GetImgBoundaries));
|
||||
@@ -145,7 +145,8 @@ std::vector<paddle::Tensor> MoeLayerKernel(
|
||||
? up_gate_proj_weight_scale.get_ptr()->data<float>()
|
||||
: nullptr),
|
||||
xftblock_tw,
|
||||
std::vector<int64_t>{expert_num, inter_dim, hidden_dim});
|
||||
std::vector<int64_t>{expert_num, inter_dim, hidden_dim}
|
||||
);
|
||||
|
||||
xdown_proj_w = std::make_shared<xftblock::Tensor>(
|
||||
const_cast<TW *>(down_proj_weight.data<TW>()), nullptr,
|
||||
@@ -153,7 +154,8 @@ std::vector<paddle::Tensor> MoeLayerKernel(
|
||||
? down_proj_weight_scale.get_ptr()->data<float>()
|
||||
: nullptr),
|
||||
xftblock_tw,
|
||||
std::vector<int64_t>{expert_num, hidden_dim, outer_dim});
|
||||
std::vector<int64_t>{expert_num, hidden_dim, outer_dim}
|
||||
);
|
||||
}
|
||||
std::shared_ptr<xftblock::Tensor> xup_gate_proj_bias;
|
||||
std::shared_ptr<xftblock::Tensor> xdown_proj_bias;
|
||||
|
||||
83
custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc
Normal file
83
custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc
Normal file
@@ -0,0 +1,83 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
#include <xft/xdnn_plugin.h>
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
|
||||
void TextImageGatherScatter(
|
||||
paddle::Tensor& input,
|
||||
paddle::Tensor& text_input,
|
||||
paddle::Tensor& image_input,
|
||||
paddle::Tensor& token_type_ids,
|
||||
paddle::Tensor& text_index,
|
||||
paddle::Tensor& image_index,
|
||||
const bool is_scatter) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
|
||||
const int64_t token_num = input.dims()[0];
|
||||
const int64_t hidden_size = input.dims()[1];
|
||||
const int64_t text_token_num = text_input.dims()[0];
|
||||
const int64_t image_token_num = image_input.dims()[0];
|
||||
|
||||
switch (input.type()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
using XPUType = typename XPUTypeTrait<bfloat16>::Type;
|
||||
typedef paddle::bfloat16 data_t;
|
||||
int r = baidu::xpu::api::plugin::text_image_gather_scatter<XPUType>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<XPUType*>(input.data<data_t>()),
|
||||
reinterpret_cast<XPUType*>(text_input.data<data_t>()),
|
||||
reinterpret_cast<XPUType*>(image_input.data<data_t>()),
|
||||
reinterpret_cast<int*>(token_type_ids.data<int>()),
|
||||
reinterpret_cast<int*>(text_index.data<int>()),
|
||||
reinterpret_cast<int*>(image_index.data<int>()),
|
||||
token_num,
|
||||
text_token_num,
|
||||
image_token_num,
|
||||
hidden_size,
|
||||
is_scatter
|
||||
);
|
||||
PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_gather_scatter");
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
PD_THROW(
|
||||
"NOT supported data type. Only support BFLOAT16. ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_OP(text_image_gather_scatter)
|
||||
.Inputs({"input",
|
||||
"text_input",
|
||||
"image_input",
|
||||
"token_type_ids",
|
||||
"text_index",
|
||||
"image_index"})
|
||||
.Outputs({"text_input_out",
|
||||
"image_input_out",
|
||||
"text_index_out",
|
||||
"image_index_out"})
|
||||
.Attrs({"is_scatter:bool"})
|
||||
.SetInplaceMap({{"text_input", "text_input_out"},
|
||||
{"image_input", "image_input_out"},
|
||||
{"text_index", "text_index_out"},
|
||||
{"image_index", "image_index_out"}})
|
||||
.SetKernelFn(PD_KERNEL(TextImageGatherScatter));
|
||||
48
custom_ops/xpu_ops/src/ops/text_image_index_out.cc
Normal file
48
custom_ops/xpu_ops/src/ops/text_image_index_out.cc
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/plugin.h"
|
||||
|
||||
void TextImageIndexOut(
|
||||
const paddle::Tensor& token_type_ids,
|
||||
const paddle::Tensor& text_index,
|
||||
const paddle::Tensor& image_index) {
|
||||
if (token_type_ids.type() != paddle::DataType::INT32 || text_index.type()
|
||||
!= paddle::DataType::INT32 || image_index.type() != paddle::DataType::INT32) {
|
||||
PD_THROW("NOT supported data type. Only support BFLOAT16. ");
|
||||
}
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
const int64_t token_num = token_type_ids.shape()[0];
|
||||
int r = baidu::xpu::api::plugin::text_image_index_out(xpu_ctx->x_context(),
|
||||
token_type_ids.data<int32_t>(),
|
||||
const_cast<int32_t*>(text_index.data<int32_t>()),
|
||||
const_cast<int32_t*>(image_index.data<int32_t>()),
|
||||
token_num);
|
||||
PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_index_out");
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_OP(text_image_index_out)
|
||||
.Inputs({"token_type_ids",
|
||||
"text_index",
|
||||
"image_index"})
|
||||
.Outputs({"text_index_out",
|
||||
"image_index_out"})
|
||||
.SetInplaceMap({{"text_index", "text_index_out"},
|
||||
{"image_index", "image_index_out"}})
|
||||
.SetKernelFn(PD_KERNEL(TextImageIndexOut));
|
||||
@@ -140,6 +140,25 @@ DLL_EXPORT int quant2d_per_channel(api::Context *ctx, const TX *x,
|
||||
const TSCALE *scale_in, TY *y,
|
||||
TSCALE *scale_out, int64_t m, int64_t n);
|
||||
|
||||
// For every position i in [0, token_num): when token_type_ids[i] == 0 the
// running text counter is stored in text_index[i], otherwise the running
// image counter is stored in image_index[i]. The slot of the other array is
// left untouched.
DLL_EXPORT int text_image_index_out(api::Context* ctx,
                                    const int* token_type_ids,  // x
                                    int* text_index,            // y1
                                    int* image_index,           // y2
                                    const int64_t token_num);

// Copies rows of hidden_size elements between `input` ([token_num, hidden_size])
// and text_input / image_input. Row i of `input` is paired with row
// text_index[i] of text_input when token_type_ids[i] == 0, otherwise with row
// image_index[i] of image_input.
//   is_scatter == true : input -> text_input / image_input
//   is_scatter == false: text_input / image_input -> input
template <typename T>
DLL_EXPORT int text_image_gather_scatter(api::Context* ctx,
                                         T* input,
                                         T* text_input,
                                         T* image_input,
                                         int* token_type_ids,
                                         int* text_index,
                                         int* image_index,
                                         int64_t token_num,
                                         int64_t text_token_num,
                                         int64_t image_token_num,
                                         int64_t hidden_size,
                                         bool is_scatter);
|
||||
|
||||
/*--------------------------------------- MTP begin --------------------------------------------*/
|
||||
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
#include "xpu/kernel/cluster.h"
|
||||
#include "xpu/kernel/cluster_partition.h"
|
||||
#include "xpu/kernel/cluster_primitive.h"
|
||||
#include "xpu/kernel/xtdk_io.h"
|
||||
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
template <typename T>
|
||||
static __device__ inline void text_image_gather(
|
||||
__global_ptr__ T* input,
|
||||
__global_ptr__ T* text_input,
|
||||
__global_ptr__ T* image_input,
|
||||
__global_ptr__ int* token_type_ids,
|
||||
__global_ptr__ int* text_index,
|
||||
__global_ptr__ int* image_index,
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
T* input_lm) {
|
||||
int cid = core_id();
|
||||
int clusterid = cluster_id();
|
||||
int token_start_cluster;
|
||||
int token_end_cluster;
|
||||
int token_start_core;
|
||||
int token_end_core;
|
||||
|
||||
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
|
||||
// cluster partition
|
||||
partition(cluster_id(), cluster_num(), (int)token_num, 1, &token_start_cluster, &token_end_cluster);
|
||||
if (token_start_cluster >= token_end_cluster) {
|
||||
return;
|
||||
}
|
||||
int rows_cluster = token_end_cluster - token_start_cluster; // total rows for a cluster
|
||||
// core partition
|
||||
partition(core_id(), core_num(), rows_cluster, 1, &token_start_core, &token_end_core);
|
||||
int rows_core = token_end_core - token_start_core; // total rows for a core
|
||||
token_start_core += token_start_cluster;
|
||||
token_end_core += token_start_cluster;
|
||||
|
||||
int read_len;
|
||||
for (int i = token_start_core; i < token_end_core; i += 1) {
|
||||
int token_type, text_image_token_idx;
|
||||
__global_ptr__ T* text_image_input = nullptr;
|
||||
__global_ptr__ int* text_image_index = nullptr;
|
||||
|
||||
GM2LM(token_type_ids + i, &token_type, sizeof(int));
|
||||
if (token_type == 0) {
|
||||
text_image_input = text_input;
|
||||
text_image_index = text_index;
|
||||
} else {
|
||||
text_image_input = image_input;
|
||||
text_image_index = image_index;
|
||||
}
|
||||
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
|
||||
int input_offset = i * hidden_size;
|
||||
int text_image_offset = text_image_token_idx * hidden_size;
|
||||
|
||||
for (int j = 0; j < hidden_size; j += BUFSIZE) {
|
||||
read_len = min(hidden_size - j, BUFSIZE);
|
||||
GM2LM(text_image_input + text_image_offset + j, input_lm, sizeof(T) * read_len);
|
||||
LM2GM(input_lm, input + input_offset + j, sizeof(T) * read_len);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ inline void text_image_scatter(
|
||||
__global_ptr__ T* input,
|
||||
__global_ptr__ T* text_input,
|
||||
__global_ptr__ T* image_input,
|
||||
__global_ptr__ int* token_type_ids,
|
||||
__global_ptr__ int* text_index,
|
||||
__global_ptr__ int* image_index,
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
T* input_lm) {
|
||||
int cid = core_id();
|
||||
int clusterid = cluster_id();
|
||||
int token_start_cluster;
|
||||
int token_end_cluster;
|
||||
int token_start_core;
|
||||
int token_end_core;
|
||||
|
||||
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
|
||||
// cluster partition
|
||||
partition(cluster_id(), cluster_num(), (int)token_num, 1, &token_start_cluster, &token_end_cluster);
|
||||
if (token_start_cluster >= token_end_cluster) {
|
||||
return;
|
||||
}
|
||||
int rows_cluster = token_end_cluster - token_start_cluster; // total rows for a cluster
|
||||
// core partition
|
||||
partition(core_id(), core_num(), rows_cluster, 1, &token_start_core, &token_end_core);
|
||||
int rows_core = token_end_core - token_start_core; // total rows for a core
|
||||
token_start_core += token_start_cluster;
|
||||
token_end_core += token_start_cluster;
|
||||
|
||||
int read_len;
|
||||
for (int i = token_start_core; i < token_end_core; i += 1) {
|
||||
int token_type, text_image_token_idx;
|
||||
__global_ptr__ T* text_image_input = nullptr;
|
||||
__global_ptr__ int* text_image_index = nullptr;
|
||||
|
||||
GM2LM(token_type_ids + i, &token_type, sizeof(int));
|
||||
if (token_type == 0) {
|
||||
text_image_input = text_input;
|
||||
text_image_index = text_index;
|
||||
} else {
|
||||
text_image_input = image_input;
|
||||
text_image_index = image_index;
|
||||
}
|
||||
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
|
||||
int input_offset = i * hidden_size;
|
||||
int text_image_offset = text_image_token_idx * hidden_size;
|
||||
|
||||
for (int j = 0; j < hidden_size; j += BUFSIZE) {
|
||||
read_len = min(hidden_size - j, BUFSIZE);
|
||||
GM2LM(input + input_offset + j, input_lm, sizeof(T) * read_len);
|
||||
LM2GM(input_lm, text_image_input + text_image_offset + j, sizeof(T) * read_len);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void text_image_gather_scatter(
|
||||
T* input,
|
||||
T* text_input,
|
||||
T* image_input,
|
||||
int* token_type_ids,
|
||||
int* text_index,
|
||||
int* image_index,
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
bool is_scatter) {
|
||||
int cid = core_id();
|
||||
int ncores = core_num();
|
||||
int clusterid = cluster_id();
|
||||
int nclusters = cluster_num();
|
||||
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
|
||||
__simd__ T input_lm[BUFSIZE]; // 2KB for bf16 and fp32
|
||||
if (is_scatter) {
|
||||
text_image_scatter(
|
||||
input, text_input, image_input, token_type_ids, text_index, image_index,
|
||||
token_num, text_token_num, image_token_num, hidden_size, input_lm);
|
||||
} else {
|
||||
text_image_gather(
|
||||
input, text_input, image_input, token_type_ids, text_index, image_index,
|
||||
token_num, text_token_num, image_token_num, hidden_size, input_lm);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#define _XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(T) \
|
||||
template __global__ void text_image_gather_scatter<T>( \
|
||||
T* input, \
|
||||
T* text_input, \
|
||||
T* image_input, \
|
||||
int* token_type_ids, \
|
||||
int* text_index, \
|
||||
int* image_index, \
|
||||
int64_t token_num, \
|
||||
int64_t text_token_num, \
|
||||
int64_t image_token_num, \
|
||||
int64_t hidden_size, \
|
||||
bool is_scatter);
|
||||
|
||||
_XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(bfloat16);
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
@@ -0,0 +1,97 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
/*
|
||||
* copyright (C) 2025 KUNLUNXIN, Inc
|
||||
*/
|
||||
|
||||
#include "xpu/kernel/cluster.h"
|
||||
#include "xpu/kernel/cluster_partition.h"
|
||||
#include "xpu/kernel/cluster_primitive.h"
|
||||
#include "xpu/kernel/cluster_primitive_template.h"
|
||||
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
// Walks `size` entries of lm_x. A zero entry stores the current text_count in
// the same slot of lm_y1 and bumps text_count; any other value stores
// images_count in lm_y2 and bumps images_count. The slot of the other output
// is not written. Both counters are in/out references so that numbering
// continues across calls. Ends with mfence_lm_sm().
static __device__ void do_calc(const _shared_ptr_ int* lm_x, int* lm_y1, int* lm_y2, int64_t size, int& text_count, int& images_count) {
    int pos = 0;
    while (pos < size) {
        const bool is_text = (lm_x[pos] == 0);
        if (!is_text) {
            lm_y2[pos] = images_count++;
        } else {
            lm_y1[pos] = text_count++;
        }
        pos++;
    }
    mfence_lm_sm();
}
|
||||
|
||||
// Numbers text and image tokens sequentially.
//
// Only the worker whose computed id is 0 does any work; every other worker
// returns immediately, so the whole sequence is processed serially and the two
// counters run over the full range [0, token_num).
//
// The sequence is processed chunk by chunk: the next chunk of token_type_ids /
// text_index / image_index is requested while the current one is numbered by
// do_calc() and written back. The existing contents of text_index/image_index
// are loaded first because do_calc() only overwrites one of the two per
// position.
// NOTE(review): DoublePtr/TriplePtr/get_1d_buflen come from the XPU kernel
// headers; their exact buffering semantics are not visible here.
__global__ void text_image_index_out_kernel(
        const int* token_type_ids,  // x
        int* text_index,            // y1
        int* image_index,           // y2
        const int64_t token_num) {
    const int cid = core_id();
    const int tid = core_id() * cluster_num() + cluster_id();
    const int nthreads = core_num() * cluster_num();
    if (tid > 0) {
        return;
    }
    constexpr int BUFSIZE = 1024;
    constexpr int READ_MAX_SIZE = BUFSIZE / sizeof(int);
    const int64_t len = token_num;

    __simd__ char text_buf[BUFSIZE * 3];
    __simd__ char image_buf[BUFSIZE * 3];
    __simd__ __shared__ char type_buf[64][BUFSIZE * 2];

    DoublePtr<READ_MAX_SIZE, SmPtr<int>> buffer_ptr_x((SmPtr<int>((_shared_ptr_ int*)type_buf[cid])));
    TriplePtr<READ_MAX_SIZE, LmPtr<int>> buffer_ptr_y1((LmPtr<int>((int*)text_buf)));
    TriplePtr<READ_MAX_SIZE, LmPtr<int>> buffer_ptr_y2((LmPtr<int>((int*)image_buf)));
    int64_t buflen = get_1d_buflen(len, nthreads, READ_MAX_SIZE, 64);
    int64_t cursor = tid * buflen;    // start of the chunk currently held
    int cur_len = 0;                  // element count of the chunk currently held
    int stride = nthreads * buflen;   // distance between consecutive chunks

    int text_count = 0;
    int images_count = 0;

    // Prime the pipeline with the first chunk.
    if (cursor < len) {
        cur_len = min<int64_t>(buflen, len - cursor);
        buffer_ptr_y1.gm_load_async(text_index + tid * buflen, cur_len);
        buffer_ptr_y2.gm_load_async(image_index + tid * buflen, cur_len);
        buffer_ptr_x.gm_load_async(token_type_ids + tid * buflen, cur_len);
        mfence();
    }

    // Steady state: request chunk k+1, number and store chunk k.
    while (cursor < len && cursor + stride < len) {
        cursor += stride;
        int next_len = min<int64_t>(buflen, len - cursor);
        buffer_ptr_x.next().gm_load_async(token_type_ids + cursor, next_len);
        buffer_ptr_y1.next().gm_load_async(text_index + cursor, next_len);
        buffer_ptr_y2.next().gm_load_async(image_index + cursor, next_len);

        do_calc(buffer_ptr_x.ptr, buffer_ptr_y1.ptr, buffer_ptr_y2.ptr, cur_len, text_count, images_count);

        // `cursor` already points at the next chunk; store to the previous one.
        buffer_ptr_y1.gm_store_async(text_index + cursor - stride, cur_len);
        buffer_ptr_y2.gm_store_async(image_index + cursor - stride, cur_len);
        buffer_ptr_x.toggle();
        buffer_ptr_y1.toggle();
        buffer_ptr_y2.toggle();
        cur_len = next_len;
    }

    // Drain: number and store the last chunk that was loaded.
    if (cursor < len) {
        do_calc(buffer_ptr_x.ptr, buffer_ptr_y1.ptr, buffer_ptr_y2.ptr, cur_len, text_count, images_count);
        buffer_ptr_y1.gm_store_async(text_index + cursor, cur_len);
        buffer_ptr_y2.gm_store(image_index + cursor, cur_len);
    }
}
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
@@ -0,0 +1,182 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
|
||||
namespace xpu3 {
namespace plugin {

// Forward declaration of the device kernel launched by xpu3_wrapper below.
template <typename T>
__attribute__((global)) void text_image_gather_scatter(T* input,
                                                       T* text_input,
                                                       T* image_input,
                                                       int* token_type_ids,
                                                       int* text_index,
                                                       int* image_index,
                                                       int64_t token_num,
                                                       int64_t text_token_num,
                                                       int64_t image_token_num,
                                                       int64_t hidden_size,
                                                       bool is_scatter);

}  // namespace plugin
}  // namespace xpu3
|
||||
|
||||
namespace baidu {
|
||||
namespace xpu {
|
||||
namespace api {
|
||||
namespace plugin {
|
||||
|
||||
template <typename T>
|
||||
static int cpu_wrapper(
|
||||
Context* ctx,
|
||||
T* input, // shape [token_num, hidden_size]
|
||||
T* text_input, // shape [text_token_num, hidden_size]
|
||||
T* image_input, // shape [image_token_num, hidden_size]
|
||||
int* token_type_ids,// shape [token_num], 0 for text, 1 for image
|
||||
int* text_index, // shape [token_num], mapping from input to text_input
|
||||
int* image_index, // shape [token_num], mapping from input to image_input
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
bool is_scatter) {
|
||||
|
||||
if (is_scatter) {
|
||||
// Scatter mode: input -> text_input/image_input
|
||||
for (int64_t i = 0; i < token_num; i++) {
|
||||
int token_type = token_type_ids[i];
|
||||
|
||||
T* text_image_input = nullptr;
|
||||
int* text_image_index = nullptr;
|
||||
if (token_type == 0) {
|
||||
text_image_input = text_input;
|
||||
text_image_index = text_index;
|
||||
} else { // token_type == 1
|
||||
text_image_input = image_input;
|
||||
text_image_index = image_index;
|
||||
}
|
||||
|
||||
int text_image_token_idx = text_image_index[i];
|
||||
int input_offset = i * hidden_size;
|
||||
int text_image_offset = text_image_token_idx * hidden_size;
|
||||
|
||||
for (int64_t j = 0; j < hidden_size; j++) {
|
||||
T value = input[input_offset + j];
|
||||
text_image_input[text_image_offset + j] = value;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Gather mode: text_input/image_input -> input
|
||||
for (int64_t i = 0; i < token_num; i++) {
|
||||
int token_type = token_type_ids[i];
|
||||
|
||||
T* text_image_input = nullptr;
|
||||
int* text_image_index = nullptr;
|
||||
if (token_type == 0) {
|
||||
text_image_input = text_input;
|
||||
text_image_index = text_index;
|
||||
} else { // token_type == 1
|
||||
text_image_input = image_input;
|
||||
text_image_index = image_index;
|
||||
}
|
||||
|
||||
int text_image_token_idx = text_image_index[i];
|
||||
int input_offset = i * hidden_size;
|
||||
int text_image_offset = text_image_token_idx * hidden_size;
|
||||
|
||||
for (int64_t j = 0; j < hidden_size; j++) {
|
||||
T value = text_image_input[text_image_offset + j];
|
||||
input[input_offset + j] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static int xpu3_wrapper(
|
||||
Context* ctx,
|
||||
T* input,
|
||||
T* text_input,
|
||||
T* image_input,
|
||||
int* token_type_ids,
|
||||
int* text_index,
|
||||
int* image_index,
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
bool is_scatter) {
|
||||
xpu3::plugin::text_image_gather_scatter<T> <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
input, text_input, image_input, token_type_ids, text_index, image_index,
|
||||
token_num, text_token_num, image_token_num, hidden_size, is_scatter
|
||||
);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
int text_image_gather_scatter(
|
||||
Context* ctx,
|
||||
T* input, // shape [token_num, hidden_size]
|
||||
T* text_input, // shape [text_token_num, hidden_size]
|
||||
T* image_input, // shape [image_token_num, hidden_size]
|
||||
int* token_type_ids,// shape [token_num], 0 for text, 1 for image
|
||||
int* text_index, // shape [token_num], mapping from input to text_input
|
||||
int* image_index, // shape [token_num], mapping from input to image_input
|
||||
int64_t token_num,
|
||||
int64_t text_token_num,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
bool is_scatter) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "text_image_gather_scatter", T);
|
||||
WRAPPER_DUMP_PARAM6(ctx, input, text_input, image_input, token_type_ids, text_index, image_index);
|
||||
WRAPPER_DUMP_PARAM5(ctx, token_num, text_token_num, image_token_num, hidden_size, is_scatter);
|
||||
WRAPPER_DUMP(ctx);
|
||||
WRAPPER_CHECK_PTR(ctx, T, token_num * hidden_size, input);
|
||||
if (text_token_num != 0) { // avoiding text_input tensor with shape [0, hidden_size]
|
||||
WRAPPER_CHECK_PTR(ctx, T, text_token_num * hidden_size, text_input);
|
||||
}
|
||||
if (image_token_num != 0) { // avoiding image_input tensor with shape [0, hidden_size]
|
||||
WRAPPER_CHECK_PTR(ctx, T, image_token_num * hidden_size, image_input);
|
||||
}
|
||||
WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids);
|
||||
WRAPPER_CHECK_PTR(ctx, int, token_num, text_index);
|
||||
WRAPPER_CHECK_PTR(ctx, int, token_num, image_index);
|
||||
WRAPPER_ASSERT_EQ(ctx, token_num, text_token_num + image_token_num);
|
||||
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper<T>(
|
||||
ctx, input, text_input, image_input, token_type_ids, text_index, image_index,
|
||||
token_num, text_token_num, image_token_num, hidden_size, is_scatter
|
||||
);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper<T>(
|
||||
ctx, input, text_input, image_input, token_type_ids, text_index, image_index,
|
||||
token_num, text_token_num, image_token_num, hidden_size, is_scatter
|
||||
);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
|
||||
|
||||
template int text_image_gather_scatter(
|
||||
Context*, bfloat16*, bfloat16*, bfloat16*, int*, int*, int*, const int64_t, const int64_t, const int64_t, const int64_t, bool);
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
@@ -0,0 +1,103 @@
|
||||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
|
||||
namespace xpu3 {
namespace plugin {

// Forward declaration of the device kernel launched by xpu3_wrapper below.
__attribute__((global)) void text_image_index_out_kernel(
        const int* token_type_ids,  // x
        int* text_index,            // y1
        int* image_index,           // y2
        const int64_t token_num);

}  // namespace plugin
}  // namespace xpu3
|
||||
|
||||
namespace baidu {
|
||||
namespace xpu {
|
||||
namespace api {
|
||||
namespace plugin {
|
||||
|
||||
// Host reference implementation of text_image_index_out.
//
// Scans token_type_ids front to back keeping two counters that start at 0.
// A zero entry writes the text counter to text_index at that position; any
// other value writes the image counter to image_index at that position. The
// counter that was written is then incremented; the other array's slot is left
// as it was. Always returns api::SUCCESS; ctx is not read.
static int cpu_wrapper(Context* ctx,
                       const int* token_type_ids,  // x
                       int* text_index,            // y1
                       int* image_index,           // y2
                       const int64_t token_num) {
    int next_text = 0;
    int next_image = 0;

    for (int64_t pos = 0; pos < token_num; pos++) {
        const bool is_text = (token_type_ids[pos] == 0);
        int* target = is_text ? text_index : image_index;
        int& counter = is_text ? next_text : next_image;
        target[pos] = counter;
        counter += 1;
    }
    return api::SUCCESS;
}
|
||||
|
||||
// Launches text_image_index_out_kernel on ctx->xpu_stream with a <<<1, 1>>>
// launch configuration, forwarding all arguments unchanged.
// Always returns api::SUCCESS.
static int xpu3_wrapper(Context* ctx,
                        const int* token_type_ids,  // x
                        int* text_index,            // y1
                        int* image_index,           // y2
                        const int64_t token_num) {
    constexpr int kClusters = 1;
    constexpr int kCores = 1;
    xpu3::plugin::text_image_index_out_kernel<<<kClusters, kCores, ctx->xpu_stream>>>(
            token_type_ids, text_index, image_index, token_num);
    return api::SUCCESS;
}
|
||||
|
||||
// Public plugin entry point. Dumps and validates the arguments (token_num must
// be > 0 and each pointer must cover token_num ints), then dispatches to the
// CPU or XPU3 implementation according to ctx->dev().type(); any other device
// type ends in WRAPPER_UNIMPLEMENTED.
int text_image_index_out(Context* ctx,
                         const int* token_type_ids,  // x
                         int* text_index,            // y1
                         int* image_index,           // y2
                         const int64_t token_num) {
    WRAPPER_CHECK_CTX(ctx);
    WRAPPER_DUMP_FUNCTION_T1(ctx, "text_image_index_out", int);
    WRAPPER_DUMP_PARAM4(ctx, token_type_ids, text_index, image_index, token_num);
    WRAPPER_DUMP(ctx);

    WRAPPER_ASSERT_GT(ctx, token_num, 0);
    WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids);
    WRAPPER_CHECK_PTR(ctx, int, token_num, text_index);
    WRAPPER_CHECK_PTR(ctx, int, token_num, image_index);

    if (ctx->dev().type() == api::kCPU) {
        return cpu_wrapper(ctx, token_type_ids, text_index, image_index, token_num);
    }
    if (ctx->dev().type() == api::kXPU3) {
        return xpu3_wrapper(ctx, token_type_ids, text_index, image_index, token_num);
    }
    WRAPPER_UNIMPLEMENTED(ctx);
}
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
Reference in New Issue
Block a user