[Feature] Marlin MoE backend supports DeepseekV3 (#2962)

Co-authored-by: K11OntheBoat <“ruianmaidanglao@163.com”>
This commit is contained in:
K11OntheBoat
2025-07-22 18:11:15 +08:00
committed by GitHub
parent dc67c10a7e
commit 93bb68aa71

View File

@@ -22,10 +22,28 @@ from fastdeploy.distributed.communication import tensor_model_parallel_all_reduc
from fastdeploy.model_executor.ops.gpu import (
MoeWna16MarlinGemmApi,
tritonmoe_preprocess_func,
noaux_tc,
)
from ..quantization.quant_base import QuantMethodBase
def get_moe_scores(gating_output: paddle.Tensor, n_group, topk_group, top_k,
routed_scaling_factor,
e_score_correction_bias) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
"""
scores = paddle.nn.functional.sigmoid(gating_output)
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
scores = noaux_tc(
scores,
scores_with_bias,
n_group,
topk_group,
top_k,
routed_scaling_factor,
)
return scores
def gptq_marlin_moe_repack(
b_q_weight: paddle.Tensor,
@@ -205,14 +223,23 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size
num_experts = layer.num_experts
topk_method = layer.topk_method
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
top_k,
True, # apply_norm_weight,
False,
)
if topk_method == "noaux_tc":
gate_out = get_moe_scores(gate_out, layer.n_group,
layer.topk_group, layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias)
topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
else:
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
top_k,
True, # apply_norm_weight,
False,
)
block_size_m = 64
@@ -291,7 +318,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
ffn_out.reshape_([token_num, -1, hidden_size])
ffn_out = ffn_out.sum(axis=1)
if layer.tp_size > 1:
if layer.reduce_results and layer.tp_size > 1:
tensor_model_parallel_all_reduce(ffn_out)
return ffn_out