mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feat] ernie4_5_vl_moe support CudaGraph (#3226)
* delete dynamic control flow for decode * coda-style * fix scatter/gather typos and use input stream instead default stream * support 0-Size Tensor * update runner and model * using static mem address as input * fix mem leak * refine code * update mm_buffer * fix typo * fix buffersize * fix unk token * refine code * refine * support other arch * open cudagraph in vlci * fix * update * update * update * fix cmd * update --------- Co-authored-by: aquagull <hongyuh@qq.com> Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
@@ -414,8 +414,8 @@ std::vector<paddle::Tensor> MoEDeepGEMMDePermute(
|
||||
const paddle::Tensor &topk_idx, const paddle::Tensor &topk_weights);
|
||||
|
||||
void TextImageIndexOut(const paddle::Tensor &token_type_ids,
|
||||
const paddle::Tensor &text_input,
|
||||
const paddle::Tensor &image_input);
|
||||
paddle::Tensor &text_input,
|
||||
paddle::Tensor &image_input);
|
||||
|
||||
void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
|
||||
paddle::Tensor &image_input,
|
||||
|
||||
@@ -132,7 +132,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(get_padding_offset)
|
||||
.Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
|
||||
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
|
||||
.Outputs({"x_remove_padding",
|
||||
"batch_id_per_token",
|
||||
"cu_seqlens_q",
|
||||
|
||||
@@ -36,6 +36,9 @@ void MoeDispatchKernel(
|
||||
paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token) {
|
||||
using namespace phi;
|
||||
|
||||
if (num_rows == 0){
|
||||
return;
|
||||
}
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
@@ -185,6 +188,15 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
auto expert_idx_per_token =
|
||||
GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place);
|
||||
|
||||
if (token_rows == 0){
|
||||
return {permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
permute_indices_per_token,
|
||||
topk_weight,
|
||||
topk_idx,
|
||||
expert_idx_per_token};
|
||||
}
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeDispatchKernel<paddle::DataType::BFLOAT16>(
|
||||
|
||||
@@ -412,7 +412,9 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
|
||||
(quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
|
||||
permute_input.dtype();
|
||||
auto ffn_out = paddle::empty_like(permute_input, t_type);
|
||||
|
||||
if(permute_input.numel() == 0){
|
||||
return ffn_out;
|
||||
}
|
||||
switch (t_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,
|
||||
|
||||
@@ -59,6 +59,10 @@ paddle::Tensor MoeExpertReduceFunc(
|
||||
|
||||
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
|
||||
|
||||
if(num_rows == 0){
|
||||
return output;
|
||||
}
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeReduceKernel<paddle::DataType::BFLOAT16>(
|
||||
|
||||
@@ -59,7 +59,7 @@ __global__ void text_image_scatter_kernel(
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using T_Vec = AlignedVector<T, VecSize>;
|
||||
T_Vec input_ptr_vec;
|
||||
T_Vec text_imgaes_vec;
|
||||
T_Vec text_images_vec;
|
||||
|
||||
int64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int64_t step = blockDim.x * gridDim.x * VecSize;
|
||||
@@ -76,16 +76,20 @@ __global__ void text_image_scatter_kernel(
|
||||
Load<T, VecSize>(input_ptr + input_load_offset, &input_ptr_vec);
|
||||
#pragma unroll
|
||||
for(int vi = 0; vi < VecSize; ++vi) {
|
||||
text_imgaes_vec[vi] = input_ptr_vec[vi];
|
||||
text_images_vec[vi] = input_ptr_vec[vi];
|
||||
}
|
||||
|
||||
if (token_type_ids_num == 0) {
|
||||
int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset;
|
||||
Store<T,VecSize>(text_imgaes_vec, text_gather_ptr + text_load_offset);
|
||||
Store<T,VecSize>(text_images_vec, text_gather_ptr + text_load_offset);
|
||||
|
||||
} else if(token_type_ids_num == 1){
|
||||
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
|
||||
Store<T,VecSize>(text_images_vec, image_gather_ptr + image_load_offset);
|
||||
|
||||
} else {
|
||||
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
|
||||
Store<T,VecSize>(text_imgaes_vec, image_gather_ptr + image_load_offset);
|
||||
// skip cuda graph padding value
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -120,9 +124,12 @@ __global__ void text_image_gather_kernel(
|
||||
int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset;
|
||||
Load<T,VecSize>(text_gather_ptr + text_load_offset, &text_imgaes_vec);
|
||||
|
||||
} else {
|
||||
} else if (token_type_ids_num == 1){
|
||||
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
|
||||
Load<T,VecSize>(image_gather_ptr + image_load_offset, &text_imgaes_vec);
|
||||
} else {
|
||||
// skip cuda graph padding value
|
||||
continue;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@@ -154,7 +161,6 @@ void LaunchTextImageGatherScatter(
|
||||
const int64_t token_num = in_dims[0];
|
||||
const int64_t hidden_size = in_dims[1];
|
||||
|
||||
|
||||
const int VecSize = 16 / sizeof(data_t);
|
||||
const int64_t tot_element_num = token_num * hidden_size;
|
||||
|
||||
@@ -168,7 +174,7 @@ void LaunchTextImageGatherScatter(
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(GetGridSize(tot_pack_num, block_size, kNumWaves, &grid_size_x));
|
||||
dim3 grid_dim = dim3(grid_size_x, 1, 1);
|
||||
if (is_scatter) {
|
||||
text_image_scatter_kernel<DataType_, 8><<<grid_dim, block_size>>>(
|
||||
text_image_scatter_kernel<DataType_, VecSize><<<grid_dim, block_size, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(input.data<data_t>()),
|
||||
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
|
||||
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
|
||||
@@ -179,7 +185,7 @@ void LaunchTextImageGatherScatter(
|
||||
tot_element_num
|
||||
);
|
||||
} else {
|
||||
text_image_gather_kernel<DataType_, 8><<<grid_dim, block_size>>>(
|
||||
text_image_gather_kernel<DataType_, VecSize><<<grid_dim, block_size, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(input.data<data_t>()),
|
||||
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
|
||||
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
template <int VecSize>
|
||||
__global__ void text_image_index_out_kernel(
|
||||
int32_t* token_type_ids,
|
||||
const int32_t* token_type_ids,
|
||||
int32_t* text_index,
|
||||
int32_t* image_index,
|
||||
const int64_t token_num
|
||||
@@ -31,23 +31,27 @@ __global__ void text_image_index_out_kernel(
|
||||
if (token_type_ids[i] == 0) {
|
||||
text_index[i] = text_count;
|
||||
text_count += 1;
|
||||
} else {
|
||||
} else if (token_type_ids[i] == 1) {
|
||||
image_index[i] = images_count;
|
||||
images_count += 1;
|
||||
} else {
|
||||
// skip cuda graph padding value
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TextImageIndexOut(
|
||||
const paddle::Tensor& token_type_ids,
|
||||
const paddle::Tensor& text_index,
|
||||
const paddle::Tensor& image_index) {
|
||||
paddle::Tensor& text_index,
|
||||
paddle::Tensor& image_index) {
|
||||
|
||||
const int64_t token_num = token_type_ids.shape()[0];
|
||||
text_image_index_out_kernel<1><<<1, 1>>>(
|
||||
const_cast<int32_t*>(token_type_ids.data<int32_t>()),
|
||||
const_cast<int32_t*>(text_index.data<int32_t>()),
|
||||
const_cast<int32_t*>(image_index.data<int32_t>()),
|
||||
auto stream = token_type_ids.stream();
|
||||
text_image_index_out_kernel<1><<<1, 1, 0, stream>>>(
|
||||
token_type_ids.data<int32_t>(),
|
||||
text_index.data<int32_t>(),
|
||||
image_index.data<int32_t>(),
|
||||
token_num
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user