mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
revise get_moe_scores (#3164)
This commit is contained in:
@@ -31,6 +31,35 @@ import fastdeploy
|
||||
from fastdeploy.config import MoEPhase
|
||||
from fastdeploy.utils import singleton
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import noaux_tc
|
||||
except:
|
||||
logger.warning("import noaux_tc Failed!")
|
||||
|
||||
|
||||
def get_moe_scores(
|
||||
gating_output: paddle.Tensor,
|
||||
n_group,
|
||||
topk_group,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
compute moe scores using e_score_correction_bias.
|
||||
"""
|
||||
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
|
||||
scores, topk_values, topk_idx = noaux_tc(
|
||||
scores,
|
||||
scores_with_bias,
|
||||
n_group,
|
||||
topk_group,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
return scores, topk_values, topk_idx
|
||||
|
||||
|
||||
@singleton
|
||||
class DeepEPEngine:
|
||||
@@ -284,13 +313,23 @@ class EPRunner:
|
||||
redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
|
||||
)
|
||||
else:
|
||||
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
self.top_k,
|
||||
True, # apply_norm_weight,
|
||||
False,
|
||||
)
|
||||
if layer.topk_method == "noaux_tc":
|
||||
score, topk_weights, topk_idx = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
layer.topk_group,
|
||||
layer.top_k,
|
||||
layer.routed_scaling_factor,
|
||||
layer.gate_correction_bias,
|
||||
)
|
||||
else:
|
||||
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
self.top_k,
|
||||
True, # apply_norm_weight,
|
||||
False,
|
||||
)
|
||||
return topk_idx, topk_weights
|
||||
|
||||
@abstractmethod
|
||||
|
@@ -53,7 +53,7 @@ def get_moe_scores(
|
||||
"""
|
||||
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
|
||||
scores = noaux_tc(
|
||||
scores, topk_values, topk_idx = noaux_tc(
|
||||
scores,
|
||||
scores_with_bias,
|
||||
n_group,
|
||||
@@ -61,7 +61,7 @@ def get_moe_scores(
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
return scores
|
||||
return scores, topk_values, topk_idx
|
||||
|
||||
|
||||
class CutlassMoEMethod(MoEMethodBase):
|
||||
@@ -248,7 +248,7 @@ class CutlassMoEMethod(MoEMethodBase):
|
||||
Paddle Cutlass compute Fused MoE.
|
||||
"""
|
||||
if layer.topk_method == "noaux_tc":
|
||||
gate_out = get_moe_scores(
|
||||
gate_out, _, _ = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
layer.topk_group,
|
||||
|
@@ -41,7 +41,7 @@ def get_moe_scores(
|
||||
"""
|
||||
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
|
||||
scores = noaux_tc(
|
||||
scores, topk_values, topk_idx = noaux_tc(
|
||||
scores,
|
||||
scores_with_bias,
|
||||
n_group,
|
||||
@@ -49,7 +49,7 @@ def get_moe_scores(
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
return scores
|
||||
return scores, topk_values, topk_idx
|
||||
|
||||
|
||||
def gptq_marlin_moe_repack(
|
||||
@@ -233,7 +233,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
|
||||
topk_method = layer.topk_method
|
||||
|
||||
if topk_method == "noaux_tc":
|
||||
gate_out = get_moe_scores(
|
||||
gate_out, _, _ = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
layer.topk_group,
|
||||
|
Reference in New Issue
Block a user