mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[BugFix] Fix zero workspace returned by CUB size query under CUDA Graph in MoE dispatch (#5087)
* fix bug about CubKeyValueSorter::run
* pre-commit and add comment
* pre-commit
* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix precommit

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -16,8 +16,8 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include "cub/cub.cuh"
|
||||
|
||||
namespace phi {
|
||||
@@ -45,7 +45,10 @@ class CubKeyValueSorter {
|
||||
size_t getWorkspaceSize(const size_t num_key_value_pairs,
|
||||
bool descending = false) {
|
||||
num_key_value_pairs_ = num_key_value_pairs;
|
||||
size_t required_storage = 0;
|
||||
// Initialize to 1 as a workaround: under CUDA Graph capture, CUB may not
|
||||
// write to required_storage, and 1 is the minimum expected size in that
|
||||
// scenario.
|
||||
size_t required_storage = 1;
|
||||
int* null_int = nullptr;
|
||||
if (descending) {
|
||||
cub::DeviceRadixSort::SortPairsDescending(NULL,
|
||||
|
||||
@@ -87,6 +87,13 @@ void MoeDispatchKernel(
|
||||
int8_t *sorter_ws_ptr = reinterpret_cast<int8_t *>(ws_ptr + bytes);
|
||||
int *permuted_experts_ =
|
||||
reinterpret_cast<int *>(sorter_ws_ptr + sorter_ws_size_bytes);
|
||||
// If expected_ws_size > workspace_size ever occurs in sorter_.run (which
|
||||
// should be practically impossible), there is a contiguous, currently unused
|
||||
// region (permuted_experts_) right after sorter_ws_ptr. In practice, this
|
||||
// region is larger than what cub::DeviceRadixSort::SortPairs requires.
|
||||
// However, relying on this to "work" after removing the assertion is unsafe:
|
||||
// it constitutes undefined behavior, and there is no guarantee it will remain
|
||||
// correct across inputs, CUDA/CUB versions, or architectures.
|
||||
int *permuted_rows_ = permuted_experts_ + num_moe_inputs;
|
||||
|
||||
int *topk_idx_ptr = topk_idx->data<int>();
|
||||
|
||||
Reference in New Issue
Block a user