mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
revise get_moe_scores (#3164)
This commit is contained in:
@@ -33,10 +33,14 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
|
|||||||
auto input_type = scores_with_bias.dtype();
|
auto input_type = scores_with_bias.dtype();
|
||||||
auto place = scores_with_bias.place();
|
auto place = scores_with_bias.place();
|
||||||
auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
|
auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
|
||||||
|
auto topk_values = paddle::empty({num_tokens, topk}, input_type, place);
|
||||||
|
auto topk_indices = paddle::empty({num_tokens, topk}, paddle::DataType::INT32, place);
|
||||||
auto stream = scores_with_bias.stream();
|
auto stream = scores_with_bias.stream();
|
||||||
|
|
||||||
invokeNoAuxTc<float>(reinterpret_cast<float*>(scores.data<float>()),
|
invokeNoAuxTc<float, int32_t>(reinterpret_cast<float*>(scores.data<float>()),
|
||||||
reinterpret_cast<float*>(group_scores.data<float>()),
|
reinterpret_cast<float*>(group_scores.data<float>()),
|
||||||
|
reinterpret_cast<float*>(topk_values.data<float>()),
|
||||||
|
reinterpret_cast<int32_t*>(topk_indices.data<int32_t>()),
|
||||||
reinterpret_cast<float*>(scores_with_bias.data<float>()),
|
reinterpret_cast<float*>(scores_with_bias.data<float>()),
|
||||||
num_tokens,
|
num_tokens,
|
||||||
num_experts,
|
num_experts,
|
||||||
@@ -46,19 +50,23 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
|
|||||||
routed_scaling_factor,
|
routed_scaling_factor,
|
||||||
stream);
|
stream);
|
||||||
|
|
||||||
return {scores};
|
return {scores, topk_values, topk_indices};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::DataType> NoauxTcInferDtype(
|
std::vector<paddle::DataType> NoauxTcInferDtype(
|
||||||
const paddle::DataType& scores_dtype,
|
const paddle::DataType& scores_dtype,
|
||||||
const paddle::DataType& scores_with_bias_dtype) {
|
const paddle::DataType& scores_with_bias_dtype) {
|
||||||
return {scores_dtype};
|
return {scores_dtype, scores_dtype, paddle::DataType::INT32};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> NoauxTcInferShape(
|
std::vector<std::vector<int64_t>> NoauxTcInferShape(
|
||||||
const std::vector<int64_t>& scores_shape,
|
const std::vector<int64_t>& scores_shape,
|
||||||
const std::vector<int64_t>& gating_output_shape) {
|
const std::vector<int64_t>& ,
|
||||||
return {scores_shape};
|
const int topk) {
|
||||||
|
auto num_tokens = scores_shape[0];
|
||||||
|
auto topk_values_shape = std::vector<int64_t>{num_tokens, topk};
|
||||||
|
auto topk_indices_shape = std::vector<int64_t>{num_tokens, topk};
|
||||||
|
return {scores_shape, topk_values_shape, topk_indices_shape};
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(noaux_tc)
|
PD_BUILD_STATIC_OP(noaux_tc)
|
||||||
|
@@ -372,10 +372,12 @@ __global__ void topk_with_k2_kernel(T* output,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T, typename IdxT>
|
||||||
__global__ void group_idx_and_topk_idx_kernel(
|
__global__ void group_idx_and_topk_idx_kernel(
|
||||||
T* scores,
|
T* scores,
|
||||||
T const* group_scores,
|
T const* group_scores,
|
||||||
|
T* topk_values,
|
||||||
|
IdxT* topk_indices,
|
||||||
T* scores_with_bias,
|
T* scores_with_bias,
|
||||||
int64_t const num_tokens,
|
int64_t const num_tokens,
|
||||||
int64_t const n_group,
|
int64_t const n_group,
|
||||||
@@ -391,6 +393,8 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
scores_with_bias += case_id * num_experts;
|
scores_with_bias += case_id * num_experts;
|
||||||
scores += case_id * num_experts;
|
scores += case_id * num_experts;
|
||||||
group_scores += case_id * n_group;
|
group_scores += case_id * n_group;
|
||||||
|
topk_values += case_id * topk;
|
||||||
|
topk_indices += case_id * topk;
|
||||||
int32_t align_num_experts_per_group =
|
int32_t align_num_experts_per_group =
|
||||||
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
|
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
|
||||||
|
|
||||||
@@ -436,6 +440,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
|
queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
|
||||||
|
|
||||||
int count_equalto_topkth_group = 0;
|
int count_equalto_topkth_group = 0;
|
||||||
|
bool if_proceed_next_topk = (topk_group_value != cuda::std::numeric_limits<T>::min());
|
||||||
if (case_id < num_tokens) {
|
if (case_id < num_tokens) {
|
||||||
for (int i_group = 0; i_group < n_group; i_group++) {
|
for (int i_group = 0; i_group < n_group; i_group++) {
|
||||||
if ((group_scores[i_group] > topk_group_value) ||
|
if ((group_scores[i_group] > topk_group_value) ||
|
||||||
@@ -490,13 +495,23 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||||
float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
|
float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
|
||||||
scores[s_topk_idx[i]] = value;
|
scores[s_topk_idx[i]] = value;
|
||||||
|
if (if_proceed_next_topk) {
|
||||||
|
topk_indices[i] = s_topk_idx[i];
|
||||||
|
topk_values[i] = static_cast<T>(value);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
topk_indices[i] = i;
|
||||||
|
topk_values[i] = static_cast<float>(1.0f / topk);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T, typename IdxT>
|
||||||
void invokeNoAuxTc(T* scores,
|
void invokeNoAuxTc(T* scores,
|
||||||
T* group_scores,
|
T* group_scores,
|
||||||
|
T* topk_values,
|
||||||
|
IdxT* topk_indices,
|
||||||
T* scores_with_bias,
|
T* scores_with_bias,
|
||||||
int64_t const num_tokens,
|
int64_t const num_tokens,
|
||||||
int64_t const num_experts,
|
int64_t const num_experts,
|
||||||
@@ -526,6 +541,8 @@ void invokeNoAuxTc(T* scores,
|
|||||||
dynamic_smem_in_bytes,
|
dynamic_smem_in_bytes,
|
||||||
stream>>>(scores,
|
stream>>>(scores,
|
||||||
group_scores,
|
group_scores,
|
||||||
|
topk_values,
|
||||||
|
topk_indices,
|
||||||
scores_with_bias,
|
scores_with_bias,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
n_group,
|
n_group,
|
||||||
@@ -536,9 +553,11 @@ void invokeNoAuxTc(T* scores,
|
|||||||
routed_scaling_factor);
|
routed_scaling_factor);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define INSTANTIATE_NOAUX_TC(T) \
|
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
|
||||||
template void invokeNoAuxTc<T>(T * scores, \
|
template void invokeNoAuxTc<T, IdxT>(T * scores, \
|
||||||
T * group_scores, \
|
T * group_scores, \
|
||||||
|
T* topk_values, \
|
||||||
|
IdxT* topk_indices, \
|
||||||
T * scores_with_bias, \
|
T * scores_with_bias, \
|
||||||
int64_t const num_tokens, \
|
int64_t const num_tokens, \
|
||||||
int64_t const num_experts, \
|
int64_t const num_experts, \
|
||||||
@@ -548,4 +567,4 @@ void invokeNoAuxTc(T* scores,
|
|||||||
double const routed_scaling_factor, \
|
double const routed_scaling_factor, \
|
||||||
cudaStream_t const stream);
|
cudaStream_t const stream);
|
||||||
|
|
||||||
INSTANTIATE_NOAUX_TC(float);
|
INSTANTIATE_NOAUX_TC(float, int32_t);
|
||||||
|
@@ -31,6 +31,35 @@ import fastdeploy
|
|||||||
from fastdeploy.config import MoEPhase
|
from fastdeploy.config import MoEPhase
|
||||||
from fastdeploy.utils import singleton
|
from fastdeploy.utils import singleton
|
||||||
|
|
||||||
|
try:
|
||||||
|
from fastdeploy.model_executor.ops.gpu import noaux_tc
|
||||||
|
except:
|
||||||
|
logger.warning("import noaux_tc Failed!")
|
||||||
|
|
||||||
|
|
||||||
|
def get_moe_scores(
|
||||||
|
gating_output: paddle.Tensor,
|
||||||
|
n_group,
|
||||||
|
topk_group,
|
||||||
|
top_k,
|
||||||
|
routed_scaling_factor,
|
||||||
|
e_score_correction_bias,
|
||||||
|
) -> paddle.Tensor:
|
||||||
|
"""
|
||||||
|
compute moe scores using e_score_correction_bias.
|
||||||
|
"""
|
||||||
|
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||||
|
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
|
||||||
|
scores, topk_values, topk_idx = noaux_tc(
|
||||||
|
scores,
|
||||||
|
scores_with_bias,
|
||||||
|
n_group,
|
||||||
|
topk_group,
|
||||||
|
top_k,
|
||||||
|
routed_scaling_factor,
|
||||||
|
)
|
||||||
|
return scores, topk_values, topk_idx
|
||||||
|
|
||||||
|
|
||||||
@singleton
|
@singleton
|
||||||
class DeepEPEngine:
|
class DeepEPEngine:
|
||||||
@@ -283,6 +312,16 @@ class EPRunner:
|
|||||||
enable_softmax_top_k_fused=False,
|
enable_softmax_top_k_fused=False,
|
||||||
redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
|
redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
if layer.topk_method == "noaux_tc":
|
||||||
|
score, topk_weights, topk_idx = get_moe_scores(
|
||||||
|
gate_out,
|
||||||
|
layer.n_group,
|
||||||
|
layer.topk_group,
|
||||||
|
layer.top_k,
|
||||||
|
layer.routed_scaling_factor,
|
||||||
|
layer.gate_correction_bias,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||||
gate_out,
|
gate_out,
|
||||||
|
@@ -53,7 +53,7 @@ def get_moe_scores(
|
|||||||
"""
|
"""
|
||||||
scores = paddle.nn.functional.sigmoid(gating_output)
|
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||||
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
|
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
|
||||||
scores = noaux_tc(
|
scores, topk_values, topk_idx = noaux_tc(
|
||||||
scores,
|
scores,
|
||||||
scores_with_bias,
|
scores_with_bias,
|
||||||
n_group,
|
n_group,
|
||||||
@@ -61,7 +61,7 @@ def get_moe_scores(
|
|||||||
top_k,
|
top_k,
|
||||||
routed_scaling_factor,
|
routed_scaling_factor,
|
||||||
)
|
)
|
||||||
return scores
|
return scores, topk_values, topk_idx
|
||||||
|
|
||||||
|
|
||||||
class CutlassMoEMethod(MoEMethodBase):
|
class CutlassMoEMethod(MoEMethodBase):
|
||||||
@@ -248,7 +248,7 @@ class CutlassMoEMethod(MoEMethodBase):
|
|||||||
Paddle Cutlass compute Fused MoE.
|
Paddle Cutlass compute Fused MoE.
|
||||||
"""
|
"""
|
||||||
if layer.topk_method == "noaux_tc":
|
if layer.topk_method == "noaux_tc":
|
||||||
gate_out = get_moe_scores(
|
gate_out, _, _ = get_moe_scores(
|
||||||
gate_out,
|
gate_out,
|
||||||
layer.n_group,
|
layer.n_group,
|
||||||
layer.topk_group,
|
layer.topk_group,
|
||||||
|
@@ -41,7 +41,7 @@ def get_moe_scores(
|
|||||||
"""
|
"""
|
||||||
scores = paddle.nn.functional.sigmoid(gating_output)
|
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||||
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
|
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
|
||||||
scores = noaux_tc(
|
scores, topk_values, topk_idx = noaux_tc(
|
||||||
scores,
|
scores,
|
||||||
scores_with_bias,
|
scores_with_bias,
|
||||||
n_group,
|
n_group,
|
||||||
@@ -49,7 +49,7 @@ def get_moe_scores(
|
|||||||
top_k,
|
top_k,
|
||||||
routed_scaling_factor,
|
routed_scaling_factor,
|
||||||
)
|
)
|
||||||
return scores
|
return scores, topk_values, topk_idx
|
||||||
|
|
||||||
|
|
||||||
def gptq_marlin_moe_repack(
|
def gptq_marlin_moe_repack(
|
||||||
@@ -233,7 +233,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
|
|||||||
topk_method = layer.topk_method
|
topk_method = layer.topk_method
|
||||||
|
|
||||||
if topk_method == "noaux_tc":
|
if topk_method == "noaux_tc":
|
||||||
gate_out = get_moe_scores(
|
gate_out, _, _ = get_moe_scores(
|
||||||
gate_out,
|
gate_out,
|
||||||
layer.n_group,
|
layer.n_group,
|
||||||
layer.topk_group,
|
layer.topk_group,
|
||||||
|
76
test/operators/test_noaux_tc.py
Normal file
76
test/operators/test_noaux_tc.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from fastdeploy.model_executor.ops.gpu import noaux_tc
|
||||||
|
|
||||||
|
|
||||||
|
class TestMoeRouting(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.num_tokens = 10
|
||||||
|
self.num_experts = 64
|
||||||
|
self.gating_output = paddle.rand([self.num_tokens, self.num_experts])
|
||||||
|
self.e_score_correction_bias = paddle.rand([self.num_experts])
|
||||||
|
self.n_group = 8
|
||||||
|
self.topk_group = 4
|
||||||
|
self.top_k = 8
|
||||||
|
self.routed_scaling_factor = 1.5
|
||||||
|
|
||||||
|
def node_limit_routing(self, gate_probs):
|
||||||
|
"""将所有专家分组, 只在topk_group个group内选择专家"""
|
||||||
|
assert len(gate_probs.shape) == 2
|
||||||
|
seq_length, n_experts = gate_probs.shape
|
||||||
|
|
||||||
|
group_scores = gate_probs.reshape([seq_length, 8, -1]).topk(2, axis=-1)[0].sum(axis=-1)
|
||||||
|
group_idx = paddle.topk(group_scores, k=4, axis=-1, sorted=True)[1]
|
||||||
|
group_mask = paddle.zeros_like(group_scores).put_along_axis(
|
||||||
|
group_idx, paddle.ones([], dtype="float32"), axis=-1
|
||||||
|
)
|
||||||
|
score_mask = group_mask.unsqueeze(-1).expand([seq_length, 8, n_experts // 8]).reshape([seq_length, -1])
|
||||||
|
gate_probs = gate_probs.masked_fill(~score_mask.astype(paddle.bool), float("-inf"))
|
||||||
|
return gate_probs
|
||||||
|
|
||||||
|
def ref_moe_routing(self):
|
||||||
|
scores = paddle.nn.functional.sigmoid(self.gating_output)
|
||||||
|
prob_for_choice = scores + self.e_score_correction_bias.unsqueeze(0)
|
||||||
|
prob_for_choice = self.node_limit_routing(prob_for_choice)
|
||||||
|
top_logits, topk_idx_ref = paddle.topk(prob_for_choice, self.top_k, axis=1)
|
||||||
|
|
||||||
|
token_num, top_k = topk_idx_ref.shape
|
||||||
|
_, num_expert = prob_for_choice.shape
|
||||||
|
topk_idx_expanded = paddle.unsqueeze(topk_idx_ref, axis=-1)
|
||||||
|
indices = paddle.concat(
|
||||||
|
[
|
||||||
|
paddle.arange(token_num, dtype="int64").unsqueeze(1).tile([1, top_k]).unsqueeze(-1),
|
||||||
|
topk_idx_expanded,
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
selected_gate_probs = paddle.gather_nd(scores, indices)
|
||||||
|
|
||||||
|
selected_gate_probs_sum = paddle.sum(selected_gate_probs, axis=1, keepdim=True)
|
||||||
|
topk_weights_ref = selected_gate_probs / selected_gate_probs_sum
|
||||||
|
topk_weights_ref = topk_weights_ref * self.routed_scaling_factor
|
||||||
|
return topk_weights_ref, topk_idx_ref
|
||||||
|
|
||||||
|
def test_moe_select(self):
|
||||||
|
scores = paddle.nn.functional.sigmoid(self.gating_output)
|
||||||
|
scores_with_bias = scores + self.e_score_correction_bias.unsqueeze(0)
|
||||||
|
|
||||||
|
scores, topk_values, topk_idx = noaux_tc(
|
||||||
|
scores,
|
||||||
|
scores_with_bias,
|
||||||
|
self.n_group,
|
||||||
|
self.topk_group,
|
||||||
|
self.top_k,
|
||||||
|
self.routed_scaling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
ref_topk_values, ref_topk_idx = self.ref_moe_routing()
|
||||||
|
|
||||||
|
paddle.allclose(topk_values, ref_topk_values)
|
||||||
|
paddle.allclose(topk_idx.cast(int), ref_topk_idx.cast(int))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Reference in New Issue
Block a user