mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-11 03:20:20 +08:00
[Feature] Marlin MoE backend supports DeepseekV3 (#2962)
Co-authored-by: K11OntheBoat <ruianmaidanglao@163.com>
This commit is contained in:
@@ -22,10 +22,28 @@ from fastdeploy.distributed.communication import tensor_model_parallel_all_reduc
|
|||||||
from fastdeploy.model_executor.ops.gpu import (
|
from fastdeploy.model_executor.ops.gpu import (
|
||||||
MoeWna16MarlinGemmApi,
|
MoeWna16MarlinGemmApi,
|
||||||
tritonmoe_preprocess_func,
|
tritonmoe_preprocess_func,
|
||||||
|
noaux_tc,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..quantization.quant_base import QuantMethodBase
|
from ..quantization.quant_base import QuantMethodBase
|
||||||
|
|
||||||
|
def get_moe_scores(
    gating_output: paddle.Tensor,
    n_group,
    topk_group,
    top_k,
    routed_scaling_factor,
    e_score_correction_bias,
) -> paddle.Tensor:
    """
    Derive MoE routing scores from gate logits with a correction bias.

    The gate logits are passed through a sigmoid; a bias-shifted copy of
    those scores is built alongside, and both are handed to the ``noaux_tc``
    op together with the grouping / top-k settings.

    Args:
        gating_output: Gate logits that the sigmoid is applied to.
        n_group: Forwarded unchanged to ``noaux_tc``.
        topk_group: Forwarded unchanged to ``noaux_tc``.
        top_k: Forwarded unchanged to ``noaux_tc``.
        routed_scaling_factor: Forwarded unchanged to ``noaux_tc``.
        e_score_correction_bias: Bias added to the sigmoid scores after it
            is given a leading axis via ``unsqueeze(0)``.

    Returns:
        Whatever ``noaux_tc`` yields for the inputs above.
        NOTE(review): presumably the group-limited, rescaled scores used by
        the caller's top-k selection -- confirm against the op definition.
    """
    sigmoid_scores = paddle.nn.functional.sigmoid(gating_output)
    # The bias gets a new leading axis before the add; presumably this lets a
    # per-expert bias broadcast over the token axis -- TODO confirm shapes.
    biased_scores = sigmoid_scores + e_score_correction_bias.unsqueeze(0)
    return noaux_tc(
        sigmoid_scores,
        biased_scores,
        n_group,
        topk_group,
        top_k,
        routed_scaling_factor,
    )
|
||||||
|
|
||||||
def gptq_marlin_moe_repack(
|
def gptq_marlin_moe_repack(
|
||||||
b_q_weight: paddle.Tensor,
|
b_q_weight: paddle.Tensor,
|
||||||
@@ -205,7 +223,16 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
|
|||||||
moe_intermediate_size = layer.moe_intermediate_size
|
moe_intermediate_size = layer.moe_intermediate_size
|
||||||
hidden_size = layer.hidden_size
|
hidden_size = layer.hidden_size
|
||||||
num_experts = layer.num_experts
|
num_experts = layer.num_experts
|
||||||
|
topk_method = layer.topk_method
|
||||||
|
|
||||||
|
if topk_method == "noaux_tc":
|
||||||
|
gate_out = get_moe_scores(gate_out, layer.n_group,
|
||||||
|
layer.topk_group, layer.top_k,
|
||||||
|
layer.routed_scaling_factor,
|
||||||
|
layer.gate_correction_bias)
|
||||||
|
|
||||||
|
topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
|
||||||
|
else:
|
||||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||||
gate_out,
|
gate_out,
|
||||||
layer.gate_correction_bias,
|
layer.gate_correction_bias,
|
||||||
@@ -291,7 +318,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
|
|||||||
ffn_out.reshape_([token_num, -1, hidden_size])
|
ffn_out.reshape_([token_num, -1, hidden_size])
|
||||||
ffn_out = ffn_out.sum(axis=1)
|
ffn_out = ffn_out.sum(axis=1)
|
||||||
|
|
||||||
if layer.tp_size > 1:
|
if layer.reduce_results and layer.tp_size > 1:
|
||||||
tensor_model_parallel_all_reduce(ffn_out)
|
tensor_model_parallel_all_reduce(ffn_out)
|
||||||
|
|
||||||
return ffn_out
|
return ffn_out
|
||||||
|
Reference in New Issue
Block a user