fix cutlass ep (#5337)

This commit is contained in:
Sunny-bot1
2025-12-03 14:06:01 +08:00
committed by GitHub
parent 690bcb8e50
commit d5a9b75b4e

View File

@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Ignore CUTLASS warnings about type punning // Ignore CUTLASS warnings about type punning
#pragma GCC diagnostic push #pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing" #pragma GCC diagnostic ignored "-Wstrict-aliasing"
@@ -39,18 +38,33 @@ void moe_topk_select_kernel(const T* input,
const int64_t k, const int64_t k,
cudaStream_t stream, cudaStream_t stream,
const bool apply_norm_weight = false, const bool apply_norm_weight = false,
const bool enable_softmax_top_k_fused = false const bool enable_softmax_top_k_fused = false) {
) {
static constexpr int WARPS_PER_TB = 4; static constexpr int WARPS_PER_TB = 4;
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \ #define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \ case N: { \
if (apply_norm_weight) { \ if (apply_norm_weight) { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, true>( \ topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, true>( \
input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \ input, \
bias, \
output, \
indices, \
source_row, \
num_rows, \
num_experts, \
k, \
stream); \
} else { \ } else { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, false>( \ topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, false>( \
input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \ input, \
bias, \
output, \
indices, \
source_row, \
num_rows, \
num_experts, \
k, \
stream); \
} \ } \
break; \ break; \
} }
@@ -72,7 +86,8 @@ void moe_topk_select_kernel(const T* input,
input, softmax, num_experts, num_rows); input, softmax, num_experts, num_rows);
if (apply_norm_weight) { if (apply_norm_weight) {
moe_top_k<T, TPB, true> moe_top_k<T, TPB, true>
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(softmax, <<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(
softmax,
bias, bias,
output, output,
indices, indices,
@@ -92,12 +107,12 @@ void moe_topk_select_kernel(const T* input,
num_rows); num_rows);
} }
cudaGetLastError(); cudaGetLastError();
} } else {
else { assert(k <= TPB);
assert(k<=TPB);
if (apply_norm_weight) { if (apply_norm_weight) {
moe_softmax_top_k_fused<T, TPB, true> moe_softmax_top_k_fused<T, TPB, true>
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(input, <<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(
input,
bias, bias,
output, output,
indices, indices,
@@ -117,7 +132,6 @@ void moe_topk_select_kernel(const T* input,
num_rows); num_rows);
} }
} }
} }
} }
} }
@@ -146,6 +160,13 @@ std::vector<paddle::Tensor> MoETopKSelectKernel(
auto topk_weights = auto topk_weights =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place); GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
// NOTE(sunxin): Return early when the first dimension of the gating input is
// 0; launching kernels for an empty tensor fails with an "invalid
// configuration argument" error.
if (gating_dims[0] == 0) {
cudaGetLastError();
return {topk_ids, topk_weights};
}
const int num_moe_inputs = AlignTo16(num_rows * moe_topk); const int num_moe_inputs = AlignTo16(num_rows * moe_topk);
const int bytes = num_moe_inputs * sizeof(int); const int bytes = num_moe_inputs * sizeof(int);
@@ -213,8 +234,7 @@ std::vector<std::vector<int64_t>> MoETopKSelectKernelInferShape(
} }
const int num_rows = token_rows; const int num_rows = token_rows;
return {{num_rows, moe_topk}, return {{num_rows, moe_topk}, {num_rows, moe_topk}};
{num_rows, moe_topk}};
} }
std::vector<paddle::DataType> MoETopKSelectKernelInferDtype( std::vector<paddle::DataType> MoETopKSelectKernelInferDtype(
@@ -223,16 +243,15 @@ std::vector<paddle::DataType> MoETopKSelectKernelInferDtype(
const int moe_topk, const int moe_topk,
const bool apply_norm_weight, const bool apply_norm_weight,
const bool enable_softmax_top_k_fused) { const bool enable_softmax_top_k_fused) {
return {paddle::DataType::INT64, return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
paddle::DataType::FLOAT32};
} }
PD_BUILD_STATIC_OP(moe_topk_select) PD_BUILD_STATIC_OP(moe_topk_select)
.Inputs({"gating_logits", paddle::Optional("bias")}) .Inputs({"gating_logits", paddle::Optional("bias")})
.Outputs({"topk_ids", .Outputs({"topk_ids", "topk_weights"})
"topk_weights"}) .Attrs({"moe_topk:int",
.Attrs({"moe_topk:int", "apply_norm_weight:bool", "enable_softmax_top_k_fused:bool"}) "apply_norm_weight:bool",
"enable_softmax_top_k_fused:bool"})
.SetKernelFn(PD_KERNEL(MoETopKSelectKernel)) .SetKernelFn(PD_KERNEL(MoETopKSelectKernel))
.SetInferShapeFn(PD_INFER_SHAPE(MoETopKSelectKernelInferShape)) .SetInferShapeFn(PD_INFER_SHAPE(MoETopKSelectKernelInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoETopKSelectKernelInferDtype)); .SetInferDtypeFn(PD_INFER_DTYPE(MoETopKSelectKernelInferDtype));