[BugFix] Fix zero workspace returned by CUB size query under CUDA Graph in MoE dispatch (#5087)

* Fix bug in CubKeyValueSorter::run

* pre-commit and add comment

* pre-commit

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix precommit

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
Jundong Liu
2025-11-20 20:00:29 +08:00
committed by GitHub
parent 0857099191
commit 147b2e5eb0
2 changed files with 12 additions and 2 deletions

View File

@@ -16,8 +16,8 @@
*/
#pragma once
#include <string>
#include <sstream>
#include <string>
#include "cub/cub.cuh"
namespace phi {
@@ -45,7 +45,10 @@ class CubKeyValueSorter {
size_t getWorkspaceSize(const size_t num_key_value_pairs,
bool descending = false) {
num_key_value_pairs_ = num_key_value_pairs;
size_t required_storage = 0;
// Initialize to 1 as a workaround: under CUDA Graph capture, CUB may not
// write to required_storage, and 1 is the minimum expected size in that
// scenario.
size_t required_storage = 1;
int* null_int = nullptr;
if (descending) {
cub::DeviceRadixSort::SortPairsDescending(NULL,

View File

@@ -87,6 +87,13 @@ void MoeDispatchKernel(
int8_t *sorter_ws_ptr = reinterpret_cast<int8_t *>(ws_ptr + bytes);
int *permuted_experts_ =
reinterpret_cast<int *>(sorter_ws_ptr + sorter_ws_size_bytes);
// If expected_ws_size > workspace_size ever occurs in sorter_.run (which
// should be practically impossible), there is a contiguous, currently unused
// region (permuted_experts_) right after the sorter workspace that begins at
// sorter_ws_ptr. In practice, this region is larger than what
// cub::DeviceRadixSort::SortPairs requires. However, relying on this to
// "work" after disabling the assertion is unsafe: it constitutes undefined
// behavior, and there is no guarantee it will remain correct across inputs,
// CUDA/CUB versions, or architectures.
int *permuted_rows_ = permuted_experts_ + num_moe_inputs;
int *topk_idx_ptr = topk_idx->data<int>();