fix cutlass ep (#5337)

This commit is contained in:
Sunny-bot1
2025-12-03 14:06:01 +08:00
committed by GitHub
parent 690bcb8e50
commit d5a9b75b4e

View File

@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Ignore CUTLASS warnings about type punning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
@@ -39,20 +38,35 @@ void moe_topk_select_kernel(const T* input,
const int64_t k,
cudaStream_t stream,
const bool apply_norm_weight = false,
const bool enable_softmax_top_k_fused = false
) {
const bool enable_softmax_top_k_fused = false) {
static constexpr int WARPS_PER_TB = 4;
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \
if (apply_norm_weight) { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, true>( \
input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
} else { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, false>( \
input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
} \
break; \
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \
if (apply_norm_weight) { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, true>( \
input, \
bias, \
output, \
indices, \
source_row, \
num_rows, \
num_experts, \
k, \
stream); \
} else { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, false>( \
input, \
bias, \
output, \
indices, \
source_row, \
num_rows, \
num_experts, \
k, \
stream); \
} \
break; \
}
switch (num_experts) {
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
@@ -68,56 +82,56 @@ void moe_topk_select_kernel(const T* input,
static constexpr int TPB = 256;
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
if (!enable_softmax_top_k_fused) {
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
input, softmax, num_experts, num_rows);
if (apply_norm_weight) {
moe_top_k<T, TPB, true>
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(softmax,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
} else {
moe_top_k<T, TPB, false>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
}
cudaGetLastError();
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
input, softmax, num_experts, num_rows);
if (apply_norm_weight) {
moe_top_k<T, TPB, true>
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(
softmax,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
} else {
moe_top_k<T, TPB, false>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
}
cudaGetLastError();
} else {
assert(k <= TPB);
if (apply_norm_weight) {
moe_softmax_top_k_fused<T, TPB, true>
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(
input,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
} else {
moe_softmax_top_k_fused<T, TPB, false>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
}
}
else {
assert(k<=TPB);
if (apply_norm_weight) {
moe_softmax_top_k_fused<T, TPB, true>
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(input,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
} else {
moe_softmax_top_k_fused<T, TPB, false>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
}
}
}
}
}
@@ -146,6 +160,13 @@ std::vector<paddle::Tensor> MoETopKSelectKernel(
auto topk_weights =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
// NOTE(sunxin): Avoid "invalid configuration argument" error caused by empty
// tensors.
if (gating_dims[0] == 0) {
cudaGetLastError();
return {topk_ids, topk_weights};
}
const int num_moe_inputs = AlignTo16(num_rows * moe_topk);
const int bytes = num_moe_inputs * sizeof(int);
@@ -213,8 +234,7 @@ std::vector<std::vector<int64_t>> MoETopKSelectKernelInferShape(
}
const int num_rows = token_rows;
return {{num_rows, moe_topk},
{num_rows, moe_topk}};
return {{num_rows, moe_topk}, {num_rows, moe_topk}};
}
std::vector<paddle::DataType> MoETopKSelectKernelInferDtype(
@@ -223,16 +243,15 @@ std::vector<paddle::DataType> MoETopKSelectKernelInferDtype(
const int moe_topk,
const bool apply_norm_weight,
const bool enable_softmax_top_k_fused) {
return {paddle::DataType::INT64,
paddle::DataType::FLOAT32};
return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
}
// Declares the static custom op `moe_topk_select`.
//   Inputs:  gating_logits, plus an optional bias.
//   Outputs: topk_ids, topk_weights.
//   Attrs:   moe_topk (int), apply_norm_weight (bool),
//            enable_softmax_top_k_fused (bool).
// The kernel, shape-inference and dtype-inference callbacks are bound to
// MoETopKSelectKernel, MoETopKSelectKernelInferShape and
// MoETopKSelectKernelInferDtype respectively.
// NOTE(review): .Outputs(...) and .Attrs(...) each appear twice in this chain
// with the same values and differ only in line wrapping. This looks like the
// before/after lines of a formatting change captured together; confirm against
// the real file and keep a single copy of each.
PD_BUILD_STATIC_OP(moe_topk_select)
.Inputs({"gating_logits", paddle::Optional("bias")})
.Outputs({"topk_ids",
"topk_weights"})
.Attrs({"moe_topk:int", "apply_norm_weight:bool", "enable_softmax_top_k_fused:bool"})
.Outputs({"topk_ids", "topk_weights"})
.Attrs({"moe_topk:int",
"apply_norm_weight:bool",
"enable_softmax_top_k_fused:bool"})
.SetKernelFn(PD_KERNEL(MoETopKSelectKernel))
.SetInferShapeFn(PD_INFER_SHAPE(MoETopKSelectKernelInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoETopKSelectKernelInferDtype));