[Graph Optimization] Fix IR graph dependency error exposed after enabling SOT by updating the return value of TextImageGatherScatter (#4610)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

* fix TextImageGatherScatter in sot

* fix codestyle
This commit is contained in:
Ryan
2025-10-28 18:31:23 +08:00
committed by GitHub
parent 4d2f478d53
commit 07956a87b3
3 changed files with 159 additions and 155 deletions

View File

@@ -485,13 +485,14 @@ void TextImageIndexOut(const paddle::Tensor& token_type_ids,
paddle::Tensor& text_input,
paddle::Tensor& image_input);
void TextImageGatherScatter(paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter);
std::vector<paddle::Tensor> TextImageGatherScatter(
paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter);
paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor& topk_ids,
int64_t num_experts);

View File

@@ -14,7 +14,10 @@
#include "helper.h"
inline cudaError_t GetGridSize(int64_t n, int block_size, int num_waves, int* num_blocks) {
inline cudaError_t GetGridSize(int64_t n,
int block_size,
int num_waves,
int* num_blocks) {
int dev;
{
cudaError_t err = cudaGetDevice(&dev);
@@ -45,136 +48,136 @@ inline cudaError_t GetGridSize(int64_t n, int block_size, int num_waves, int* nu
return cudaSuccess;
}
template<typename T, int VecSize>
__global__ void text_image_scatter_kernel(
T* input_ptr,
T* text_gather_ptr,
T* image_gather_ptr,
int32_t* token_type_ids,
int32_t* text_index,
int32_t* image_index,
const int64_t hidden_size,
const int64_t total_element_num
){
constexpr int HalfVecSize = VecSize / 2;
using T_Vec = AlignedVector<T, VecSize>;
T_Vec input_ptr_vec;
T_Vec text_images_vec;
template <typename T, int VecSize>
__global__ void text_image_scatter_kernel(T* input_ptr,
T* text_gather_ptr,
T* image_gather_ptr,
int32_t* token_type_ids,
int32_t* text_index,
int32_t* image_index,
const int64_t hidden_size,
const int64_t total_element_num) {
constexpr int HalfVecSize = VecSize / 2;
using T_Vec = AlignedVector<T, VecSize>;
T_Vec input_ptr_vec;
T_Vec text_images_vec;
int64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t step = blockDim.x * gridDim.x * VecSize;
int64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t step = blockDim.x * gridDim.x * VecSize;
for(int64_t element_idx = global_thread_id * VecSize;
element_idx < total_element_num;
element_idx += step){
int64_t token_idx = element_idx / hidden_size;
int64_t hidden_offset = element_idx % hidden_size;
int32_t token_type_ids_num = token_type_ids[token_idx];
for (int64_t element_idx = global_thread_id * VecSize;
element_idx < total_element_num;
element_idx += step) {
int64_t token_idx = element_idx / hidden_size;
int64_t hidden_offset = element_idx % hidden_size;
int32_t token_type_ids_num = token_type_ids[token_idx];
int64_t input_load_offset = token_idx * hidden_size + hidden_offset;
int64_t input_load_offset = token_idx * hidden_size + hidden_offset;
Load<T, VecSize>(input_ptr + input_load_offset, &input_ptr_vec);
#pragma unroll
for(int vi = 0; vi < VecSize; ++vi) {
text_images_vec[vi] = input_ptr_vec[vi];
}
if (token_type_ids_num == 0) {
int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset;
Store<T,VecSize>(text_images_vec, text_gather_ptr + text_load_offset);
} else if(token_type_ids_num == 1){
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
Store<T,VecSize>(text_images_vec, image_gather_ptr + image_load_offset);
} else {
// skip cuda graph padding value
continue;
}
Load<T, VecSize>(input_ptr + input_load_offset, &input_ptr_vec);
#pragma unroll
for (int vi = 0; vi < VecSize; ++vi) {
text_images_vec[vi] = input_ptr_vec[vi];
}
if (token_type_ids_num == 0) {
int64_t text_load_offset =
text_index[token_idx] * hidden_size + hidden_offset;
Store<T, VecSize>(text_images_vec, text_gather_ptr + text_load_offset);
} else if (token_type_ids_num == 1) {
int64_t image_load_offset =
image_index[token_idx] * hidden_size + hidden_offset;
Store<T, VecSize>(text_images_vec, image_gather_ptr + image_load_offset);
} else {
// skip cuda graph padding value
continue;
}
}
}
template<typename T, int VecSize>
__global__ void text_image_gather_kernel(
T* output_ptr,
T* text_gather_ptr,
T* image_gather_ptr,
int32_t* token_type_ids,
int32_t* text_index,
int32_t* image_index,
const int64_t hidden_size,
const int64_t total_element_num
){
constexpr int HalfVecSize = VecSize / 2;
using T_Vec = AlignedVector<T, VecSize>;
T_Vec output_ptr_vec;
T_Vec text_imgaes_vec;
template <typename T, int VecSize>
__global__ void text_image_gather_kernel(T* output_ptr,
T* text_gather_ptr,
T* image_gather_ptr,
int32_t* token_type_ids,
int32_t* text_index,
int32_t* image_index,
const int64_t hidden_size,
const int64_t total_element_num) {
constexpr int HalfVecSize = VecSize / 2;
using T_Vec = AlignedVector<T, VecSize>;
T_Vec output_ptr_vec;
T_Vec text_imgaes_vec;
int64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t step = blockDim.x * gridDim.x * VecSize;
int64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t step = blockDim.x * gridDim.x * VecSize;
for(int64_t element_idx = global_thread_id * VecSize;
element_idx < total_element_num;
element_idx += step){
int64_t token_idx = element_idx / hidden_size;
int64_t hidden_offset = element_idx % hidden_size;
int32_t token_type_ids_num = token_type_ids[token_idx];
for (int64_t element_idx = global_thread_id * VecSize;
element_idx < total_element_num;
element_idx += step) {
int64_t token_idx = element_idx / hidden_size;
int64_t hidden_offset = element_idx % hidden_size;
int32_t token_type_ids_num = token_type_ids[token_idx];
if (token_type_ids_num == 0) {
int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset;
Load<T,VecSize>(text_gather_ptr + text_load_offset, &text_imgaes_vec);
if (token_type_ids_num == 0) {
int64_t text_load_offset =
text_index[token_idx] * hidden_size + hidden_offset;
Load<T, VecSize>(text_gather_ptr + text_load_offset, &text_imgaes_vec);
} else if (token_type_ids_num == 1){
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
Load<T,VecSize>(image_gather_ptr + image_load_offset, &text_imgaes_vec);
} else {
// skip cuda graph padding value
continue;
}
#pragma unroll
for(int vi = 0; vi < VecSize; ++vi) {
output_ptr_vec[vi] = text_imgaes_vec[vi];
}
int64_t input_load_offset = token_idx * hidden_size + hidden_offset;
Store<T, VecSize>(output_ptr_vec, output_ptr + input_load_offset);
} else if (token_type_ids_num == 1) {
int64_t image_load_offset =
image_index[token_idx] * hidden_size + hidden_offset;
Load<T, VecSize>(image_gather_ptr + image_load_offset, &text_imgaes_vec);
} else {
// skip cuda graph padding value
continue;
}
#pragma unroll
for (int vi = 0; vi < VecSize; ++vi) {
output_ptr_vec[vi] = text_imgaes_vec[vi];
}
int64_t input_load_offset = token_idx * hidden_size + hidden_offset;
Store<T, VecSize>(output_ptr_vec, output_ptr + input_load_offset);
}
}
template <paddle::DataType D>
void LaunchTextImageGatherScatter(
paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter) {
void LaunchTextImageGatherScatter(paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto stream = input.stream();
const auto& in_dims = input.dims();
const int64_t token_num = in_dims[0];
const int64_t hidden_size = in_dims[1];
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto stream = input.stream();
const auto& in_dims = input.dims();
const int64_t token_num = in_dims[0];
const int64_t hidden_size = in_dims[1];
const int VecSize = 16 / sizeof(data_t);
const int64_t tot_element_num = token_num * hidden_size;
const int VecSize = 16 / sizeof(data_t);
const int64_t tot_element_num = token_num * hidden_size;
int64_t tot_pack_num = (tot_element_num + VecSize - 1) / VecSize;
int64_t tot_pack_num = (tot_element_num + VecSize - 1) / VecSize;
const int block_size = 128;
int grid_index = (token_num + block_size - 1) / block_size;
constexpr int32_t kNumWaves = 16;
int grid_size_x = -1;
const int block_size = 128;
int grid_index = (token_num + block_size - 1) / block_size;
constexpr int32_t kNumWaves = 16;
int grid_size_x = -1;
PADDLE_ENFORCE_GPU_SUCCESS(GetGridSize(tot_pack_num, block_size, kNumWaves, &grid_size_x));
dim3 grid_dim = dim3(grid_size_x, 1, 1);
if (is_scatter) {
text_image_scatter_kernel<DataType_, VecSize><<<grid_dim, block_size, 0, stream>>>(
PADDLE_ENFORCE_GPU_SUCCESS(
GetGridSize(tot_pack_num, block_size, kNumWaves, &grid_size_x));
dim3 grid_dim = dim3(grid_size_x, 1, 1);
if (is_scatter) {
text_image_scatter_kernel<DataType_, VecSize>
<<<grid_dim, block_size, 0, stream>>>(
reinterpret_cast<DataType_*>(input.data<data_t>()),
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
@@ -182,10 +185,10 @@ void LaunchTextImageGatherScatter(
reinterpret_cast<int32_t*>(text_index.data<int32_t>()),
reinterpret_cast<int32_t*>(image_index.data<int32_t>()),
hidden_size,
tot_element_num
);
} else {
text_image_gather_kernel<DataType_, VecSize><<<grid_dim, block_size, 0, stream>>>(
tot_element_num);
} else {
text_image_gather_kernel<DataType_, VecSize>
<<<grid_dim, block_size, 0, stream>>>(
reinterpret_cast<DataType_*>(input.data<data_t>()),
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
@@ -193,33 +196,37 @@ void LaunchTextImageGatherScatter(
reinterpret_cast<int32_t*>(text_index.data<int32_t>()),
reinterpret_cast<int32_t*>(image_index.data<int32_t>()),
hidden_size,
tot_element_num
);
}
tot_element_num);
}
}
void TextImageGatherScatter(
paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter) {
switch (input.type()) {
case paddle::DataType::BFLOAT16: {
return LaunchTextImageGatherScatter<paddle::DataType::BFLOAT16>(input, text_input, image_input, token_type_ids, text_index, image_index, is_scatter);
}
default: {
PD_THROW(
"NOT supported data type. Only support BFLOAT16. ");
break;
}
std::vector<paddle::Tensor> TextImageGatherScatter(
paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter) {
switch (input.dtype()) {
case paddle::DataType::BFLOAT16: {
LaunchTextImageGatherScatter<paddle::DataType::BFLOAT16>(input,
text_input,
image_input,
token_type_ids,
text_index,
image_index,
is_scatter);
break;
}
default: {
PD_THROW("NOT supported data type. Only support BFLOAT16, but got",
input.dtype());
}
}
return {input, text_input, image_input};
}
PD_BUILD_STATIC_OP(text_image_gather_scatter)
.Inputs({"input",
"text_input",
@@ -227,13 +234,9 @@ PD_BUILD_STATIC_OP(text_image_gather_scatter)
"token_type_ids",
"text_index",
"image_index"})
.Outputs({"text_input_out",
"image_input_out",
"text_index_out",
"image_index_out"})
.Outputs({"output", "text_input_out", "image_input_out"})
.Attrs({"is_scatter:bool"})
.SetInplaceMap({{"text_input", "text_input_out"},
{"image_input", "image_input_out"},
{"text_index", "text_index_out"},
{"image_index", "image_index_out"}})
.SetInplaceMap({{"input", "output"},
{"text_input", "text_input_out"},
{"image_input", "image_input_out"}})
.SetKernelFn(PD_KERNEL(TextImageGatherScatter));

View File

@@ -277,7 +277,7 @@ class Ernie4_5_VLMoE(nn.Layer):
def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta):
if self.num_shared_experts > 0:
shared_experts_out = self.shared_experts(hidden_states)
text_image_gather_scatter(
hidden_states, vl_moe_meta.text_input, vl_moe_meta.image_input = text_image_gather_scatter(
hidden_states,
vl_moe_meta.text_input,
vl_moe_meta.image_input,
@@ -288,7 +288,7 @@ class Ernie4_5_VLMoE(nn.Layer):
)
text_out = self.text_fused_moe(vl_moe_meta.text_input)
image_out = self.image_fused_moe(vl_moe_meta.image_input)
text_image_gather_scatter(
hidden_states, _, _ = text_image_gather_scatter(
hidden_states,
text_out,
image_out,