[Feat] ernie4_5_vl_moe support CudaGraph (#3226)

* delete dynamic control flow for decode

* code-style

* fix scatter/gather typos and use input stream instead of default stream

* support 0-Size Tensor

* update runner and model

* using static mem address as input

* fix mem leak

* refine code

* update mm_buffer

* fix typo

* fix buffersize

* fix unk token

* refine code

* refine

* support other arch

* open cudagraph in vlci

* fix

* update

* update

* update

* fix cmd

* update

---------

Co-authored-by: aquagull <hongyuh@qq.com>
Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
Ayakouji
2025-09-10 13:11:57 +08:00
committed by GitHub
parent 9d0074a91a
commit 453487d5b0
9 changed files with 207 additions and 98 deletions

View File

@@ -414,8 +414,8 @@ std::vector<paddle::Tensor> MoEDeepGEMMDePermute(
const paddle::Tensor &topk_idx, const paddle::Tensor &topk_weights);
void TextImageIndexOut(const paddle::Tensor &token_type_ids,
const paddle::Tensor &text_input,
const paddle::Tensor &image_input);
paddle::Tensor &text_input,
paddle::Tensor &image_input);
void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
paddle::Tensor &image_input,

View File

@@ -132,7 +132,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
}
PD_BUILD_STATIC_OP(get_padding_offset)
.Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
.Outputs({"x_remove_padding",
"batch_id_per_token",
"cu_seqlens_q",

View File

@@ -36,6 +36,9 @@ void MoeDispatchKernel(
paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token) {
using namespace phi;
if (num_rows == 0){
return;
}
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -185,6 +188,15 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
auto expert_idx_per_token =
GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place);
if (token_rows == 0){
return {permute_input,
tokens_expert_prefix_sum,
permute_indices_per_token,
topk_weight,
topk_idx,
expert_idx_per_token};
}
switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeDispatchKernel<paddle::DataType::BFLOAT16>(

View File

@@ -412,7 +412,9 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
(quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input, t_type);
if(permute_input.numel() == 0){
return ffn_out;
}
switch (t_type) {
case paddle::DataType::BFLOAT16:
MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,

View File

@@ -59,6 +59,10 @@ paddle::Tensor MoeExpertReduceFunc(
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
if(num_rows == 0){
return output;
}
switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeReduceKernel<paddle::DataType::BFLOAT16>(

View File

@@ -59,7 +59,7 @@ __global__ void text_image_scatter_kernel(
constexpr int HalfVecSize = VecSize / 2;
using T_Vec = AlignedVector<T, VecSize>;
T_Vec input_ptr_vec;
T_Vec text_imgaes_vec;
T_Vec text_images_vec;
int64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t step = blockDim.x * gridDim.x * VecSize;
@@ -76,16 +76,20 @@ __global__ void text_image_scatter_kernel(
Load<T, VecSize>(input_ptr + input_load_offset, &input_ptr_vec);
#pragma unroll
for(int vi = 0; vi < VecSize; ++vi) {
text_imgaes_vec[vi] = input_ptr_vec[vi];
text_images_vec[vi] = input_ptr_vec[vi];
}
if (token_type_ids_num == 0) {
int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset;
Store<T,VecSize>(text_imgaes_vec, text_gather_ptr + text_load_offset);
Store<T,VecSize>(text_images_vec, text_gather_ptr + text_load_offset);
} else if(token_type_ids_num == 1){
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
Store<T,VecSize>(text_images_vec, image_gather_ptr + image_load_offset);
} else {
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
Store<T,VecSize>(text_imgaes_vec, image_gather_ptr + image_load_offset);
// skip cuda graph padding value
continue;
}
}
}
@@ -120,9 +124,12 @@ __global__ void text_image_gather_kernel(
int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset;
Load<T,VecSize>(text_gather_ptr + text_load_offset, &text_imgaes_vec);
} else {
} else if (token_type_ids_num == 1){
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
Load<T,VecSize>(image_gather_ptr + image_load_offset, &text_imgaes_vec);
} else {
// skip cuda graph padding value
continue;
}
#pragma unroll
@@ -154,7 +161,6 @@ void LaunchTextImageGatherScatter(
const int64_t token_num = in_dims[0];
const int64_t hidden_size = in_dims[1];
const int VecSize = 16 / sizeof(data_t);
const int64_t tot_element_num = token_num * hidden_size;
@@ -168,7 +174,7 @@ void LaunchTextImageGatherScatter(
PADDLE_ENFORCE_GPU_SUCCESS(GetGridSize(tot_pack_num, block_size, kNumWaves, &grid_size_x));
dim3 grid_dim = dim3(grid_size_x, 1, 1);
if (is_scatter) {
text_image_scatter_kernel<DataType_, 8><<<grid_dim, block_size>>>(
text_image_scatter_kernel<DataType_, VecSize><<<grid_dim, block_size, 0, stream>>>(
reinterpret_cast<DataType_*>(input.data<data_t>()),
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
@@ -179,7 +185,7 @@ void LaunchTextImageGatherScatter(
tot_element_num
);
} else {
text_image_gather_kernel<DataType_, 8><<<grid_dim, block_size>>>(
text_image_gather_kernel<DataType_, VecSize><<<grid_dim, block_size, 0, stream>>>(
reinterpret_cast<DataType_*>(input.data<data_t>()),
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
reinterpret_cast<DataType_*>(image_input.data<data_t>()),

View File

@@ -16,7 +16,7 @@
template <int VecSize>
__global__ void text_image_index_out_kernel(
int32_t* token_type_ids,
const int32_t* token_type_ids,
int32_t* text_index,
int32_t* image_index,
const int64_t token_num
@@ -31,23 +31,27 @@ __global__ void text_image_index_out_kernel(
if (token_type_ids[i] == 0) {
text_index[i] = text_count;
text_count += 1;
} else {
} else if (token_type_ids[i] == 1) {
image_index[i] = images_count;
images_count += 1;
} else {
// skip cuda graph padding value
continue;
}
}
}
void TextImageIndexOut(
const paddle::Tensor& token_type_ids,
const paddle::Tensor& text_index,
const paddle::Tensor& image_index) {
paddle::Tensor& text_index,
paddle::Tensor& image_index) {
const int64_t token_num = token_type_ids.shape()[0];
text_image_index_out_kernel<1><<<1, 1>>>(
const_cast<int32_t*>(token_type_ids.data<int32_t>()),
const_cast<int32_t*>(text_index.data<int32_t>()),
const_cast<int32_t*>(image_index.data<int32_t>()),
auto stream = token_type_ids.stream();
text_image_index_out_kernel<1><<<1, 1, 0, stream>>>(
token_type_ids.data<int32_t>(),
text_index.data<int32_t>(),
image_index.data<int32_t>(),
token_num
);
}