diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
index 69920649a..1b0e3a7cb 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
@@ -24,6 +24,7 @@
 from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
 from fastdeploy.utils import ceil_div
 from ..quantization.quant_base import QuantMethodBase
+from .ep import get_moe_scores
 
 try:
     from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
@@ -167,13 +168,24 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
         moe_intermediate_size = layer.moe_intermediate_size
         hidden_size = layer.hidden_size
 
-        topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
-            gate_out,
-            layer.gate_correction_bias,
-            top_k,
-            True,  # apply_norm_weight,
-            False,
-        )
+        if layer.topk_method == "noaux_tc":
+            _, topk_weights, topk_ids = get_moe_scores(
+                gate_out,
+                layer.n_group,
+                layer.topk_group,
+                layer.top_k,
+                layer.routed_scaling_factor,
+                layer.gate_correction_bias,
+            )
+        else:
+            topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+                gate_out,
+                layer.gate_correction_bias,
+                layer.top_k,
+                True,  # apply_norm_weight
+                False,
+            )
+
         up_gate_proj_out = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
             dtype=x.dtype,
@@ -419,13 +431,25 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
         moe_intermediate_size = layer.moe_intermediate_size
         hidden_size = layer.hidden_size
 
-        topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
-            gate_out,
-            layer.gate_correction_bias,
-            top_k,
-            True,  # apply_norm_weight,
-            False,
-        )
+        if layer.topk_method == "noaux_tc":
+
+            _, topk_weights, topk_ids = get_moe_scores(
+                gate_out,
+                layer.n_group,
+                layer.topk_group,
+                layer.top_k,
+                layer.routed_scaling_factor,
+                layer.gate_correction_bias,
+            )
+        else:
+
+            topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+                gate_out,
+                layer.gate_correction_bias,
+                top_k,
+                True,  # apply_norm_weight,
+                False,
+            )
 
         up_gate_proj_out = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
@@ -829,13 +853,23 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         E, N1, _ = getattr(layer, self.added_weight_attrs[0]).shape
         N2 = getattr(layer, self.added_weight_attrs[1]).shape[1]
 
-        topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
-            gate_out,
-            layer.gate_correction_bias,
-            layer.top_k,
-            True,  # apply_norm_weight
-            False,
-        )
+        if layer.topk_method == "noaux_tc":
+            _, topk_weights, topk_ids = get_moe_scores(
+                gate_out,
+                layer.n_group,
+                layer.topk_group,
+                layer.top_k,
+                layer.routed_scaling_factor,
+                layer.gate_correction_bias,
+            )
+        else:
+            topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+                gate_out,
+                layer.gate_correction_bias,
+                layer.top_k,
+                True,  # apply_norm_weight
+                False,
+            )
 
         config = {
             "BLOCK_SIZE_M": 64,
diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py
index 82074b70c..1fe797868 100644
--- a/fastdeploy/rl/rollout_config.py
+++ b/fastdeploy/rl/rollout_config.py
@@ -60,6 +60,7 @@ class RolloutModelConfig:
         early_stop_config: str = None,
         local_rank: int = 0,
         moba_attention_config: str = None,
+        data_parallel_size: int = 1,
     ):
         # Required parameters
         self.model = model_name_or_path
@@ -95,6 +96,7 @@
         self.splitwise_role = splitwise_role
         self.expert_parallel_size = expert_parallel_size
         self.enable_expert_parallel = enable_expert_parallel
+        self.data_parallel_size = data_parallel_size
         self.ori_vocab_size = ori_vocab_size
         self.quantization = quantization
         self.guided_decoding_backend = guided_decoding_backend
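Note on the new `noaux_tc` branch: instead of the flat `moe_topk_select` kernel, the Triton MoE backends now call `get_moe_scores` (imported from `.ep`), which performs grouped expert routing and returns `(scores, topk_weights, topk_ids)`. The sketch below is only an illustration of what such a selection computes, patterned after DeepSeek-V3-style "no-aux-loss" grouped routing; it is not the FastDeploy kernel. The sigmoid scoring, top-2 group ranking, and all local names are assumptions inferred from the call signature in the patch.

```python
import paddle
import paddle.nn.functional as F

def noaux_tc_route_sketch(gate_out, n_group, topk_group, top_k,
                          routed_scaling_factor, gate_correction_bias):
    """Illustrative grouped top-k routing (not the real get_moe_scores).

    gate_out:             [num_tokens, num_experts] router logits
    gate_correction_bias: broadcastable to [num_tokens, num_experts];
                          assumed to bias expert *selection* only
    Returns (scores, topk_weights, topk_ids), matching the unpack order
    `_, topk_weights, topk_ids = get_moe_scores(...)` in the patch.
    """
    num_tokens, num_experts = gate_out.shape
    scores = F.sigmoid(gate_out)                       # routing probabilities
    scores_for_choice = scores + gate_correction_bias  # biased copy for selection

    # Rank expert groups by the sum of their two best biased scores,
    # then keep only the top `topk_group` groups (assumed criterion).
    group_scores = scores_for_choice.reshape([num_tokens, n_group, -1])
    group_ranks = group_scores.topk(2, axis=-1)[0].sum(axis=-1)  # [tokens, n_group]
    group_idx = group_ranks.topk(topk_group, axis=-1)[1]
    group_mask = paddle.zeros_like(group_ranks)
    group_mask = paddle.put_along_axis(group_mask, group_idx, 1.0, axis=-1)

    # Expand the group mask to per-expert granularity and mask out
    # experts in dropped groups with -inf so topk never picks them.
    expert_mask = group_mask.unsqueeze(-1).expand(
        [num_tokens, n_group, num_experts // n_group]
    ).reshape([num_tokens, num_experts])
    masked_scores = paddle.where(
        expert_mask.astype("bool"),
        scores_for_choice,
        paddle.full_like(scores_for_choice, float("-inf")),
    )

    # Select top-k experts inside the surviving groups; combine weights
    # come from the *unbiased* scores, renormalized and scaled.
    topk_ids = masked_scores.topk(top_k, axis=-1)[1]
    topk_weights = paddle.take_along_axis(scores, topk_ids, axis=-1)
    topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
    topk_weights = topk_weights * routed_scaling_factor
    return scores, topk_weights, topk_ids
```

If `get_moe_scores` follows this scheme, `gate_correction_bias` influences only which experts are chosen, while the returned weights are read from the unbiased scores and then scaled by `routed_scaling_factor`; the `rollout_config.py` side of the patch is unrelated plumbing that threads a new `data_parallel_size` field through `RolloutModelConfig`.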