""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ import os import paddle import fastdeploy import paddle.distributed as dist from paddle.base.core import Config from paddle.distributed.communication.group import Group from paddle.distributed.communication import deep_ep from paddlenlp.utils.log import logger from fastdeploy.model_executor.layers.moe.moe import MoELayer from fastdeploy.inference_args import GenerationPhase from fastdeploy.model_executor.layers.utils import get_tensor import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm import numpy as np class DeepEPEngine: """ A wrapper class for DeepEP engine. """ def __init__( self, group: Group, num_ranks: int, rank_id: int, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int, generation_phase: GenerationPhase, async_finish: bool = False, ): """ Initialize the DeepEP engine. Args: group: The MPI group object. num_ranks: The number of ranks. rank_id: The rank id. num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch. hidden: The hidden dimension of the model. num_experts: The number of experts. """ self.group = group self.num_ranks = num_ranks self.rank_id = rank_id self.hidden = hidden self.num_experts = num_experts self.num_local_experts = num_experts // num_ranks self.generation_phase = generation_phase self.async_finish = async_finish self.deepep_engine = None if generation_phase == GenerationPhase.DECODER: logger.info("Initializing Low Latency Buffer") self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank self.get_low_latency_buffer() elif generation_phase == GenerationPhase.PREFILL: self.deepep_engine = deep_ep.Buffer( group, int(1e9), 0, low_latency_mode=False, num_qps_per_rank=1, ) self.ep_config = Config(24, 6, 256) else: raise ValueError(f"Unknown generation phase {generation_phase}") def get_low_latency_buffer(self) -> deep_ep.Buffer: """ Get the DeepEP buffer. Args: group: The MPI group object. num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch. hidden: The hidden dimension of the model. 
""" # NOTES: the low-latency mode will consume much more space than the normal mode # So we recommend that `num_max_dispatch_tokens_per_rank` # (the actual batch size in the decoding engine) should be less than 256 num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_ranks, self.num_experts, ) # Allocate a buffer if not existed or not enough buffer size if ( self.deepep_engine is None or self.deepep_engine.group != self.group or not self.deepep_engine.low_latency_mode or self.deepep_engine.num_rdma_bytes < num_rdma_bytes ): # NOTES: for best performance, the QP number **must** be equal to the number of the local experts assert self.num_experts % self.num_ranks == 0 self.deepep_engine = deep_ep.Buffer( self.group, 0, num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=self.num_experts // self.num_ranks, ) def low_latency_dispatch( self, hidden_states: paddle.Tensor, topk_idx: paddle.Tensor, moe_in_w4a8_scale, use_fp8: bool = False, ): """ Args: hidden_states: [token_num, hidden] 'bfloat16/int8' topk_idx: [token_num, num_topk] 'int64' Returns: recv_hidden_states: [num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden] num_ranks * num_local_experts = num_experts recv_count: [num_local_experts] recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each expert receive. As mentioned before, all not tokens are valid in `recv_x`. handle: the communication handle to be used in the `low_latency_combine` function. event: the event after executing the kernel (valid only if `async_finish` is set). hook: the receiving hook function (valid only if `return_recv_hook` is set). """ ( packed_recv_x, recv_expert_count, handle, _, dispatch_hook, ) = self.deepep_engine.low_latency_dispatch( hidden_states, topk_idx, moe_in_w4a8_scale, self.num_max_dispatch_tokens_per_rank, self.num_experts, use_fp8=use_fp8, async_finish=False, return_recv_hook=True, ) return packed_recv_x, recv_expert_count, handle, dispatch_hook def low_latency_combine( self, hidden_states: paddle.Tensor, topk_idx: paddle.Tensor, topk_weights: paddle.Tensor, handle, ): """ Return: combined_hidden_states: [num_tokens, hidden] """ combined_hidden_states, _, combine_hook = ( self.deepep_engine.low_latency_combine( hidden_states, topk_idx, topk_weights, handle, async_finish=False, return_recv_hook=True, ) ) return combined_hidden_states, combine_hook def clean_low_latency_buffer(self): """ clean_low_latency_buffer """ self.deepep_engine.clean_low_latency_buffer( self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts ) def barrier_all(self): """ barrier_all """ self.deepep_engine.barrier_all() class MoeEPLayer(MoELayer): """ MOE EP Layer """ def __init__( self, ep_engine: DeepEPEngine, num_local_experts: int, redundant_table_manger=None, *args, **kwargs, ): """ Initialize MOE EP Layer """ kwargs["num_local_experts"] = num_local_experts kwargs["nranks"] = 1 # Only support 1 rank for EP MOE super().__init__(*args, **kwargs) self.ep_engine = ep_engine self.ep_size = self.ep_engine.num_ranks self.ep_rank = self.ep_engine.rank_id self.redundant_table_manger = redundant_table_manger def load_scale_state_dict(self): """ load_scale_state_dict function. 
""" up_gate_proj_weight_scale = [] down_proj_weight_scale = [] up_gate_proj_in_scale = [] down_proj_in_scale = [] for j in range( self.num_experts_start_offset, self.num_experts_start_offset + self.num_local_experts, ): up_gate_proj_in_scale_value = self.inference_args.act_scale_dict.pop( self.ffn1_expert_in_scale_key.format(j) ) up_gate_proj_weight_scale_np = np.array( self.inference_args.weight_scale_dict.pop( self.ffn1_expert_weight_scale_key.format(j) ) ) up_gate_proj_weight_scale_np = up_gate_proj_weight_scale_np / ( 127.0 * 112.0 * up_gate_proj_in_scale_value ) up_gate_proj_in_scale.append(up_gate_proj_in_scale_value) up_gate_proj_weight_scale.append( paddle.to_tensor(up_gate_proj_weight_scale_np, dtype="float32") ) down_proj_in_scale_value = self.inference_args.act_scale_dict.pop( self.ffn2_expert_in_scale_key.format(j) ) down_proj_weight_scale_np = np.array( self.inference_args.weight_scale_dict.pop( self.ffn2_expert_weight_scale_key.format(j) ) ) down_proj_weight_scale_np = down_proj_weight_scale_np / ( 127.0 * 112.0 * down_proj_in_scale_value ) down_proj_in_scale.append(down_proj_in_scale_value) down_proj_weight_scale.append( paddle.to_tensor(down_proj_weight_scale_np, dtype="float32") ) return ( up_gate_proj_weight_scale, down_proj_weight_scale, up_gate_proj_in_scale, down_proj_in_scale, ) def load_gate_state_dict(self, state_dict): """ Load Gate State Dict from state_dict Args: state_dict: state dict """ logical_expert_ids = [ i for i in range( self.num_experts_start_offset, self.num_experts_start_offset + self.num_local_experts, ) ] if self.redundant_table_manger is not None: ( ep_rank_to_expert_id_list, expert_id_to_ep_rank_array, expert_in_rank_num_list, tokens_per_expert_stats_list, ) = self.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer( self.layer_idx ) logical_expert_ids = ep_rank_to_expert_id_list[ self.num_experts_start_offset : self.num_experts_start_offset + self.num_local_experts ] up_gate_proj_weight = [] up_gate_proj_weight_scale = [] down_proj_weight = [] down_proj_weight_scale = [] if self.redundant_table_manger is not None: for j in logical_expert_ids: if expert_in_rank_num_list[j] > 1: # TODO:减一计数,最后pop up_gate = ( state_dict.get(self.ffn1_expert_weight_key.format(j)) if self.moe_quant_type == "default" or not self.use_offline_quant else state_dict.get( (self.ffn1_expert_weight_key + ".quant_weight").format(j) ) ) down = ( state_dict.get(self.ffn2_expert_weight_key.format(j)) if self.moe_quant_type == "default" or not self.use_offline_quant else state_dict.get( (self.ffn2_expert_weight_key + ".quant_weight").format(j) ) ) if self.use_offline_quant: up_gate_scale = state_dict.get( (self.ffn1_expert_weight_key + ".quant_scale").format(j) ) down_scale = state_dict.get( (self.ffn2_expert_weight_key + ".quant_scale").format(j) ) up_gate_proj_weight_scale.append(get_tensor(up_gate_scale)) down_proj_weight_scale.append(get_tensor(down_scale)) else: up_gate = ( state_dict.pop(self.ffn1_expert_weight_key.format(j)) if self.moe_quant_type == "default" or not self.use_offline_quant else state_dict.pop( (self.ffn1_expert_weight_key + ".quant_weight").format(j) ) ) down = ( state_dict.pop(self.ffn2_expert_weight_key.format(j)) if self.moe_quant_type == "default" or not self.use_offline_quant else state_dict.pop( (self.ffn2_expert_weight_key + ".quant_weight").format(j) ) ) if self.use_offline_quant: up_gate_scale = state_dict.pop( (self.ffn1_expert_weight_key + ".quant_scale").format(j) ) down_scale = state_dict.pop( (self.ffn2_expert_weight_key + 
".quant_scale").format(j) ) up_gate_proj_weight_scale.append(get_tensor(up_gate_scale)) down_proj_weight_scale.append(get_tensor(down_scale)) up_gate_proj_weight.append(get_tensor(up_gate)) down_proj_weight.append(get_tensor(down)) up_gate_proj_weight_scale.append(get_tensor(up_gate_scale)) down_proj_weight_scale.append(get_tensor(down_scale)) else: for j in logical_expert_ids: up_gate_proj_weight.append( get_tensor(state_dict.pop(self.ffn1_expert_weight_key.format(j))) if self.moe_quant_type == "default" or not self.use_offline_quant else get_tensor( state_dict.pop( (self.ffn1_expert_weight_key + ".quant_weight").format(j) ) ) ) down_proj_weight.append( get_tensor(state_dict.pop(self.ffn2_expert_weight_key.format(j))) if self.moe_quant_type == "default" or not self.use_offline_quant else get_tensor( state_dict.pop( (self.ffn2_expert_weight_key + ".quant_weight").format(j) ) ) ) if self.use_offline_quant: up_gate_proj_weight_scale.append( get_tensor( state_dict.pop( (self.ffn1_expert_weight_key + ".quant_scale").format(j) ) ) ) down_proj_weight_scale.append( get_tensor( state_dict.pop( (self.ffn2_expert_weight_key + ".quant_scale").format(j) ) ) ) return ( up_gate_proj_weight, down_proj_weight, up_gate_proj_weight_scale, down_proj_weight_scale, ) def forward(self, x, **kwargs): """ MoeEPLayer Forward Function """ raise NotImplementedError class PrefillMoeEPLayer(MoeEPLayer): """ Prefill MOE EP Layer """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) logger.debug("Init Prefill EP Layer") self.ep_async_finish = False def micro_batch_gate(self, x): """ Run the micro-batch's gate and select topk's export. Args: x (Tensor): The index of micro-batch. The shape is `[token, num_export]`. The data type should be bfloat16, float16 or float32. Returns: topk_idx (Tensor): The index of getting highest score's exports. The shape is `[token, topk]`. The data type should be int64. topk_weights (Tensor): The scores of getting highest score's exports. The shape is `[token, topk]`. The data type should be float32. """ topk_idx = None topk_weights = None gate_out = paddle.matmul(x.cast("float32"), self.gate_weight) if self.redundant_table_manger is not None: ( ep_rank_to_expert_id_list, expert_id_to_ep_rank_array, expert_in_rank_num_list, tokens_per_expert_stats_list, ) = self.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer( self.layer_idx ) topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.f_moe_redundant_topk_select( gating_logits=gate_out, expert_id_to_ep_rank_array=expert_id_to_ep_rank_array, expert_in_rank_num_list=expert_in_rank_num_list, tokens_per_expert_stats_list=tokens_per_expert_stats_list, bias=( self.gate_correction_bias if self.moe_config.moe_use_gate_correction_bias else None ), moe_topk=self.top_k, apply_norm_weight=True, # apply_norm_weight enable_softmax_top_k_fused=False, redundant_ep_rank_num_plus_one=self.inference_args.redundant_experts_num + 1, ) else: topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, ( self.gate_correction_bias if self.moe_config.moe_use_gate_correction_bias else None ), self.top_k, True, False, ) return topk_idx, topk_weights def micro_batch_dispatch(self, x, topk_idx, topk_weights, event): """ Run the micro-batch's all to all dispatch. Args: x (Tensor): The index of micro-batch. The shape is `[token, num_export]`. The data type should be bfloat16, float16 or float32. topk_idx (Tensor): The index of getting highest score's exports. The shape is `[token, topk]`. 

    def micro_batch_dispatch(self, x, topk_idx, topk_weights, event):
        """
        Run the all-to-all dispatch for one micro-batch.

        Args:
            x (Tensor): Hidden states of the micro-batch. The shape is `[token_num, hidden]`.
                The data type should be bfloat16, float16 or float32.
            topk_idx (Tensor): Indices of the selected experts. The shape is `[token_num, top_k]`.
                The data type should be int64.
            topk_weights (Tensor): Routing weights of the selected experts. The shape is
                `[token_num, top_k]`. The data type should be float32.
            event (EventOverlap): The event the dispatch communication should wait on.
        """
        (num_tokens_per_rank, _, num_tokens_per_expert, is_token_in_rank, _) = (
            self.ep_engine.deepep_engine.get_dispatch_layout(
                topk_idx,
                self.num_experts + self.inference_args.redundant_experts_num,
                previous_event=event,
                async_finish=self.ep_engine.async_finish,
                allocate_on_comm_stream=self.ep_engine.async_finish,
            )
        )
        dispatch_args = {
            "x": x,
            "num_tokens_per_rank": num_tokens_per_rank,
            "is_token_in_rank": is_token_in_rank,
            "num_tokens_per_expert": num_tokens_per_expert,
            "config": self.ep_engine.ep_config,
            "async_finish": self.ep_engine.async_finish,
            "topk_idx": topk_idx,
            "topk_weights": topk_weights,
            "previous_event": event,
            "allocate_on_comm_stream": self.ep_engine.async_finish,
        }
        (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            recv_num_tokens_per_expert_list,
            handle,
            event,
        ) = self.ep_engine.deepep_engine.dispatch(**dispatch_args)
        return (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            recv_num_tokens_per_expert_list,
            handle,
            event,
        )

    def micro_batch_ffn(
        self,
        recv_x,
        recv_topk_idx,
        recv_topk_weights,
        recv_num_tokens_per_expert_list,
        handle,
    ):
        r"""
        Run the MoE FFN for one micro-batch.
        """
        (
            rank_prefix_matrix,
            channel_prefix_matrix,
            recv_channel_prefix_matrix,
            recv_src_idx,
            is_token_in_rank,
            send_head,
        ) = handle
        token_all_num = sum(recv_num_tokens_per_expert_list)
        if self.moe_quant_type == "fp8":
            if token_all_num > 0:
                recv_num_tokens_per_expert_list_np = np.array(
                    recv_num_tokens_per_expert_list
                )
                recv_num_tokens_per_expert_list_padded = (
                    128
                    - recv_num_tokens_per_expert_list_np % 128
                    + recv_num_tokens_per_expert_list_np
                ).tolist()
                token_padded_all = sum(recv_num_tokens_per_expert_list_padded)
                (recv_x, recv_x_scale) = recv_x
                (
                    permute_input,
                    permute_scale,
                    permute_indices_per_token,
                    recv_num_tokens_per_expert_list_cumsum,
                    recv_num_tokens_per_expert_list_padded_cumsum,
                    dst_weights,
                    dst_indices,
                    cumsum_idx_gpu,
                    m_indices,
                ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8(
                    recv_x,
                    recv_x_scale,
                    recv_topk_idx,
                    recv_topk_weights,
                    recv_num_tokens_per_expert_list,
                    recv_num_tokens_per_expert_list_padded,
                    token_all_num,
                    token_padded_all,
                )
                # ffn1
                ffn_out = paddle.empty(
                    (permute_input.shape[0], self.ffn1_weight_shape[1]),
                    dtype=paddle.bfloat16,
                )
                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
                    (permute_input, permute_scale),
                    (self.moe_ffn1_weight, self.moe_ffn1_weight_scale),
                    ffn_out,
                    m_indices,
                )
                # swiglu
                ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out, None)
                # ffn2
                ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
                    ffn_out, self.inference_args.weight_block_size[0]
                )
                ffn_out = paddle.empty(
                    (ffn_out.shape[0], self.ffn2_weight_shape[1]), dtype=paddle.bfloat16
                )
                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
                    (ffn_in_x, ffn_in_x_scale_tensor),
                    (self.moe_ffn2_weight, self.moe_ffn2_weight_scale),
                    ffn_out,
                    m_indices,
                )
                # permute back per rank
                tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
                    ffn_out,
                    dst_weights,
                    permute_indices_per_token,
                    dst_indices,
                    self.moe_ffn2_bias,
                    False,  # norm_topk_prob
                    1.0,
                )[0]
            else:
                tmp_ffn_out = paddle.cast(recv_x, self._dtype)
        else:
            if token_all_num > 0:
                # The custom op cannot be used when the token count is 0
                (
                    permute_input,
                    permute_indices_per_token,
                    recv_num_tokens_per_expert_list_cumsum,
                    dst_weights,
                    dst_indices,
                    cumsum_idx_gpu,
                    expert_idx_per_token,
                ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
                    recv_x,
                    recv_topk_idx,
                    recv_topk_weights,
                    (
                        self.moe_ffn1_in_scale
                        if hasattr(self, "moe_ffn1_in_scale")
                        else None
                    ),
                    recv_num_tokens_per_expert_list,
                    token_all_num,
                    self.moe_quant_type,
                )
                # MoE FFN per rank
                ffn_out = fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
                    permute_input,
                    recv_num_tokens_per_expert_list_cumsum,
                    self.moe_ffn1_weight,
                    self.moe_ffn2_weight,
                    self.moe_ffn1_bias,
                    (
                        self.moe_ffn1_weight_scale
                        if hasattr(self, "moe_ffn1_weight_scale")
                        else None
                    ),
                    (
                        self.moe_ffn2_weight_scale
                        if hasattr(self, "moe_ffn2_weight_scale")
                        else None
                    ),
                    (
                        self.moe_ffn2_in_scale
                        if hasattr(self, "moe_ffn2_in_scale")
                        else None
                    ),
                    expert_idx_per_token,
                    self.moe_quant_type,
                    False,  # used_in_ep_low_latency
                )
                # permute back per rank
                tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
                    ffn_out,
                    dst_weights,
                    permute_indices_per_token,
                    dst_indices,
                    self.moe_ffn2_bias,
                    False,  # norm_topk_prob
                    1.0,
                )[0]
            else:
                tmp_ffn_out = recv_x
        return tmp_ffn_out
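
    # The FP8 path above pads each expert's token count up to a multiple of 128
    # (the grouped-GEMM block size). As written, the formula `128 - n % 128 + n`
    # always adds a full extra block when `n` is already aligned, e.g.
    # n=100 -> 128 but n=128 -> 256.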

    def micro_batch_combine(self, tmp_ffn_out, recv_topk_weights, handle, event):
        """
        Run the all-to-all combine for one micro-batch.
        """
        combine_args = {
            "x": tmp_ffn_out,
            "handle": handle,
            "config": self.ep_engine.ep_config,
            "async_finish": self.ep_engine.async_finish,
            "topk_weights": recv_topk_weights,
            "previous_event": event,
            "allocate_on_comm_stream": self.ep_engine.async_finish,
        }
        before_norm_fused_moe_out, combined_topk_weights, event = (
            self.ep_engine.deepep_engine.combine(**combine_args)
        )
        return before_norm_fused_moe_out, combined_topk_weights, event
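
    # Example (illustrative sketch only): how the micro-batch helpers compose
    # into one MoE step. `event` may be None when no overlap with a previous
    # kernel is needed:
    #
    #     topk_idx, topk_weights = layer.micro_batch_gate(x)
    #     recv_x, recv_idx, recv_w, recv_counts, handle, event = layer.micro_batch_dispatch(
    #         x, topk_idx, topk_weights, event=None
    #     )
    #     ffn_out = layer.micro_batch_ffn(recv_x, recv_idx, recv_w, recv_counts, handle)
    #     out, _, event = layer.micro_batch_combine(ffn_out, recv_w, handle, event)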

    def forward(self, x, **kwargs):
        """
        PrefillMoeEPLayer forward function.

        Args:
            x: [token_num, hidden_dim]
        """
        topk_idx = None
        topk_weights = None
        gate_out = paddle.matmul(x.cast("float32"), self.gate_weight)
        # get topk
        if self.redundant_table_manger is not None:
            (
                ep_rank_to_expert_id_list,
                expert_id_to_ep_rank_array,
                expert_in_rank_num_list,
                tokens_per_expert_stats_list,
            ) = self.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer(
                self.layer_idx
            )
            topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.f_moe_redundant_topk_select(
                gating_logits=gate_out,
                expert_id_to_ep_rank_array=expert_id_to_ep_rank_array,
                expert_in_rank_num_list=expert_in_rank_num_list,
                tokens_per_expert_stats_list=tokens_per_expert_stats_list,
                bias=(
                    self.gate_correction_bias
                    if self.moe_config.moe_use_gate_correction_bias
                    else None
                ),
                moe_topk=self.top_k,
                apply_norm_weight=True,  # apply_norm_weight
                enable_softmax_top_k_fused=False,
                redundant_ep_rank_num_plus_one=self.inference_args.redundant_experts_num + 1,
            )
        else:
            topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
                gate_out,
                (
                    self.gate_correction_bias
                    if self.moe_config.moe_use_gate_correction_bias
                    else None
                ),
                self.top_k,
                True,  # apply_norm_weight
                False,
            )
        # dispatch intranode
        (num_tokens_per_rank, _, num_tokens_per_expert, is_token_in_rank, _) = (
            self.ep_engine.deepep_engine.get_dispatch_layout(
                topk_idx, self.num_experts + self.inference_args.redundant_experts_num
            )
        )
        if self.moe_quant_type == "fp8":
            x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
                x, self.inference_args.weight_block_size[0]
            )
            # dispatch intranode
            dispatch_args = {
                "x": (x, x_scale_tensor),
                "num_tokens_per_rank": num_tokens_per_rank,
                "is_token_in_rank": is_token_in_rank,
                "num_tokens_per_expert": num_tokens_per_expert,
                "config": self.ep_engine.ep_config,
                "async_finish": self.ep_engine.async_finish,
                "topk_idx": topk_idx,
                "topk_weights": topk_weights,
            }
            (
                recv_x,
                recv_topk_idx,
                recv_topk_weights,
                recv_num_tokens_per_expert_list,
                handle,
                event,
            ) = self.ep_engine.deepep_engine.dispatch(**dispatch_args)
            (
                rank_prefix_matrix,
                channel_prefix_matrix,
                recv_channel_prefix_matrix,
                recv_src_idx,
                is_token_in_rank,
                send_head,
            ) = handle
            # permute per rank
            token_all_num = sum(recv_num_tokens_per_expert_list)
            if token_all_num > 0:
                recv_num_tokens_per_expert_list_np = np.array(
                    recv_num_tokens_per_expert_list
                )
                recv_num_tokens_per_expert_list_padded = (
                    128
                    - recv_num_tokens_per_expert_list_np % 128
                    + recv_num_tokens_per_expert_list_np
                ).tolist()
                token_padded_all = sum(recv_num_tokens_per_expert_list_padded)
                (recv_x, recv_x_scale) = recv_x
                # The custom op cannot be used when the token count is 0
                (
                    permute_input,
                    permute_scale,
                    permute_indices_per_token,
                    recv_num_tokens_per_expert_list_cumsum,
                    recv_num_tokens_per_expert_list_padded_cumsum,
                    dst_weights,
                    dst_indices,
                    cumsum_idx_gpu,
                    m_indices,
                ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8(
                    recv_x,
                    recv_x_scale,
                    recv_topk_idx,
                    recv_topk_weights,
                    recv_num_tokens_per_expert_list,
                    recv_num_tokens_per_expert_list_padded,
                    token_all_num,
                    token_padded_all,
                )
                # ffn1
                ffn_out = paddle.empty(
                    (permute_input.shape[0], self.ffn1_weight_shape[1]),
                    dtype=paddle.bfloat16,
                )
                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
                    (permute_input, permute_scale),
                    (self.moe_ffn1_weight, self.moe_ffn1_weight_scale),
                    ffn_out,
                    m_indices,
                )
                # swiglu
                ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out, None)
                # ffn2
                ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
                    ffn_out, self.inference_args.weight_block_size[0]
                )
                ffn_out = paddle.empty(
                    (ffn_out.shape[0], self.ffn2_weight_shape[1]), dtype=paddle.bfloat16
                )
                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
                    (ffn_in_x, ffn_in_x_scale_tensor),
                    (self.moe_ffn2_weight, self.moe_ffn2_weight_scale),
                    ffn_out,
                    m_indices,
                )
                # permute back per rank
                tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
                    ffn_out,
                    dst_weights,
                    permute_indices_per_token,
                    dst_indices,
                    self.moe_ffn2_bias,
                    False,  # norm_topk_prob
                    1.0,
                )[0]
            else:
                tmp_ffn_out = paddle.cast(recv_x, self._dtype)
            # intranode combine
            combine_args = {
                "x": tmp_ffn_out,
                "handle": handle,
                "config": self.ep_engine.ep_config,
                "async_finish": self.ep_engine.async_finish,
                "topk_weights": recv_topk_weights,
            }
            fused_moe_out, combined_topk_weights, event = (
                self.ep_engine.deepep_engine.combine(**combine_args)
            )
        else:
            # dispatch intranode
            dispatch_args = {
                "x": x,
                "num_tokens_per_rank": num_tokens_per_rank,
                "is_token_in_rank": is_token_in_rank,
                "num_tokens_per_expert": num_tokens_per_expert,
                "config": self.ep_engine.ep_config,
                "async_finish": self.ep_engine.async_finish,
                "topk_idx": topk_idx,
                "topk_weights": topk_weights,
            }
            (
                recv_x,
                recv_topk_idx,
                recv_topk_weights,
                recv_num_tokens_per_expert_list,
                handle,
                event,
            ) = self.ep_engine.deepep_engine.dispatch(**dispatch_args)
            (
                rank_prefix_matrix,
                channel_prefix_matrix,
                recv_channel_prefix_matrix,
                recv_src_idx,
                is_token_in_rank,
                send_head,
            ) = handle
            # permute per rank
            token_all_num = sum(recv_num_tokens_per_expert_list)
            if token_all_num > 0:
                # The custom op cannot be used when the token count is 0
                (
                    permute_input,
                    permute_indices_per_token,
                    recv_num_tokens_per_expert_list_cumsum,
                    dst_weights,
                    dst_indices,
                    cumsum_idx_gpu,
                    expert_idx_per_token,
                ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
                    recv_x,
                    recv_topk_idx,
                    recv_topk_weights,
                    (
                        self.moe_ffn1_in_scale
                        if hasattr(self, "moe_ffn1_in_scale")
                        else None
                    ),
                    recv_num_tokens_per_expert_list,
                    token_all_num,
                    self.moe_quant_type,
                )
                # MoE FFN per rank
                ffn_out = fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
                    permute_input,
                    recv_num_tokens_per_expert_list_cumsum,
                    self.moe_ffn1_weight,
                    self.moe_ffn2_weight,
                    self.moe_ffn1_bias,
                    (
                        self.moe_ffn1_weight_scale
                        if hasattr(self, "moe_ffn1_weight_scale")
                        else None
                    ),
                    (
                        self.moe_ffn2_weight_scale
                        if hasattr(self, "moe_ffn2_weight_scale")
                        else None
                    ),
                    (
                        self.moe_ffn2_in_scale
                        if hasattr(self, "moe_ffn2_in_scale")
                        else None
                    ),
                    expert_idx_per_token,
                    self.moe_quant_type,
                    False,  # used_in_ep_low_latency
                )
                # permute back per rank
                tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
                    ffn_out,
                    dst_weights,
                    permute_indices_per_token,
                    dst_indices,
                    self.moe_ffn2_bias,
                    False,  # norm_topk_prob
                    1.0,
                )[0]
            else:
                tmp_ffn_out = recv_x
            # intranode combine
            combine_args = {
                "x": tmp_ffn_out,
                "handle": handle,
                "config": self.ep_engine.ep_config,
                "async_finish": self.ep_engine.async_finish,
                "topk_weights": recv_topk_weights,
            }
            fused_moe_out, combined_topk_weights, event = (
                self.ep_engine.deepep_engine.combine(**combine_args)
            )
        return fused_moe_out


class DecoderMoeEPLayer(MoeEPLayer):
    """
    Decoder MoE EP layer.
    """

    def __init__(self, *args, **kwargs):
        """
        DecoderMoeEPLayer init.
        """
        super().__init__(*args, **kwargs)

    def gate(self, x):
        """
        Calculate the gate and select the top-k experts.
        """
        topk_idx = None
        topk_weights = None
        gate_out = paddle.matmul(x.cast("float32"), self.gate_weight)
        if os.getenv("EP_DECODER_PERF_TEST", "False") == "True":
            gate_out = paddle.rand(shape=gate_out.shape, dtype=gate_out.dtype)
        if self.redundant_table_manger is not None:
            (
                ep_rank_to_expert_id_list,
                expert_id_to_ep_rank_array,
                expert_in_rank_num_list,
                tokens_per_expert_stats_list,
            ) = self.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer(
                self.layer_idx
            )
            topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.f_moe_redundant_topk_select(
                gating_logits=gate_out,
                expert_id_to_ep_rank_array=expert_id_to_ep_rank_array,
                expert_in_rank_num_list=expert_in_rank_num_list,
                tokens_per_expert_stats_list=tokens_per_expert_stats_list,
                bias=(
                    self.gate_correction_bias
                    if self.moe_config.moe_use_gate_correction_bias
                    else None
                ),
                moe_topk=self.top_k,
                apply_norm_weight=True,  # apply_norm_weight
                enable_softmax_top_k_fused=False,
                redundant_ep_rank_num_plus_one=self.inference_args.redundant_experts_num + 1,
            )
        else:
            topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
                gate_out,
                (
                    self.gate_correction_bias
                    if self.moe_config.moe_use_gate_correction_bias
                    else None
                ),
                self.top_k,
                True,  # apply_norm_weight
                False,
            )
        return topk_idx, topk_weights

    def ffn(self, permute_input, token_nums_per_expert):
        """
        Calculate the MoE FFN.
        """
        if self.moe_quant_type == "fp8":
            assert isinstance(permute_input, tuple)
            ffn1_out = paddle.empty(
                [
                    self.num_local_experts,
                    self.ep_engine.num_ranks
                    * self.ep_engine.num_max_dispatch_tokens_per_rank,
                    self.moe_intermediate_size * 2,
                ],
                dtype=self._dtype,
            )
            ffn_out = paddle.empty(
                [
                    self.num_local_experts,
                    self.ep_engine.num_ranks
                    * self.ep_engine.num_max_dispatch_tokens_per_rank,
                    self.ep_engine.hidden,
                ],
                dtype=self._dtype,
            )
            expected_m = 128
            deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
                permute_input,
                (
                    self.moe_ffn1_weight,
                    self.moe_ffn1_weight_scale,
                ),
                ffn1_out,
                token_nums_per_expert,
                expected_m,
            )
            act_out = fastdeploy.model_executor.ops.gpu.group_swiglu_with_masked(
                ffn1_out, token_nums_per_expert
            )
            act_out_fp8, scale = fastdeploy.model_executor.ops.gpu.masked_per_token_quant(
                act_out, token_nums_per_expert, 128
            )
            deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
                (act_out_fp8, scale),
                (
                    self.moe_ffn2_weight,
                    self.moe_ffn2_weight_scale,
                ),
                ffn_out,
                token_nums_per_expert,
                expected_m,
            )
        else:
            expert_idx_per_token = None
            if self.moe_quant_type == "w4a8":
                # Note: build the per-token expert index map used by the w4a8 path
                num_local_experts, max_num, _ = permute_input.shape
                expert_idx_per_token = paddle.arange(num_local_experts)[:, None].tile(
                    [1, max_num]
                )
            ffn_out = fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
                permute_input,
                token_nums_per_expert.cast("int64"),
                self.moe_ffn1_weight,
                self.moe_ffn2_weight,
                self.moe_ffn1_bias,
                (
                    self.moe_ffn1_weight_scale
                    if hasattr(self, "moe_ffn1_weight_scale")
                    else None
                ),
                (
                    self.moe_ffn2_weight_scale
                    if hasattr(self, "moe_ffn2_weight_scale")
                    else None
                ),
                (
                    self.moe_ffn2_in_scale
                    if hasattr(self, "moe_ffn2_in_scale")
                    else None
                ),
                expert_idx_per_token,
                self.moe_quant_type,
                True,  # used_in_ep_low_latency
            )
        return ffn_out
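
    # Layout note (derived from the allocations above): in the decode path every
    # tensor handled by `ffn` is laid out as
    #     [num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, feature_dim]
    # and `token_nums_per_expert` masks how many rows of each expert's slab are
    # actually valid, which is what the *masked* grouped GEMM consumes.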

    def forward(self, x, **kwargs):
        """
        DecoderMoeEPLayer forward (not micro-batched).
        """
        topk_idx, topk_weights = self.gate(x)
        moe_in_w4a8_scale = None
        if self.moe_quant_type == "w4a8":
            moe_in_w4a8_scale = []
            dist.all_gather(moe_in_w4a8_scale, self.moe_ffn1_in_scale)
            moe_in_w4a8_scale = paddle.concat(moe_in_w4a8_scale, axis=0)
        recv_hidden_states, recv_expert_count, handle, dispatch_hook = (
            self.ep_engine.low_latency_dispatch(
                x, topk_idx, moe_in_w4a8_scale, self.moe_quant_type == "fp8"
            )
        )
        if dispatch_hook is not None:
            dispatch_hook()
        ffn_out = self.ffn(recv_hidden_states, recv_expert_count)
        combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(
            ffn_out, topk_idx, topk_weights, handle
        )
        if combine_hook is not None:
            combine_hook()
        return combined_hidden_states


class DecoderEPMicroBatchRunner:
    """
    Helper that issues the decoder EP dispatch/FFN/combine steps of one
    micro-batch so the caller can overlap them with other work.
    """

    def __init__(self, moe_layers: list, ep_engine: DeepEPEngine):
        """
        Initialize the runner with the MoE layers and the shared EP engine.
        """
        self.moe_layers = moe_layers
        self.ep_engine = ep_engine
        self.recv_hidden_states = None
        self.recv_expert_count = None
        self.combined_hidden_states = None
        self.handle = None
        self.dispatch_hook = None
        self.combine_hook = None
        self.topk_idx = None
        self.topk_weights = None
        self.ffn_out = None

    def dispatch_issue(self, x, topk_idx, topk_weights, layer_idx):
        """
        Issue the low-latency dispatch for one micro-batch.
        """
        self.topk_idx = topk_idx
        self.topk_weights = topk_weights
        (
            self.recv_hidden_states,
            self.recv_expert_count,
            self.handle,
            self.dispatch_hook,
        ) = self.ep_engine.low_latency_dispatch(
            x,
            self.topk_idx,
            None,  # moe_in_w4a8_scale is not used by this runner
            use_fp8=self.moe_layers[layer_idx].moe_quant_type == "fp8",
        )

    def dispatch_hook_wrap(self):
        """
        Wait on the pending dispatch receive hook.
        """
        self.dispatch_hook()
        self.dispatch_hook = None

    def ffn(self, layer_idx):
        """
        Run the expert FFN of the given layer on the received tokens.
        """
        self.ffn_out = self.moe_layers[layer_idx].ffn(
            self.recv_hidden_states, self.recv_expert_count
        )
        self.recv_hidden_states = None
        self.recv_expert_count = None

    def combine_issue(self):
        """
        Issue the low-latency combine for the pending FFN output.
        """
        self.combined_hidden_states, self.combine_hook = (
            self.ep_engine.low_latency_combine(
                self.ffn_out, self.topk_idx, self.topk_weights, self.handle
            )
        )

    def combine_hook_wrap(self):
        """
        Wait on the pending combine receive hook and return the combined output.
        """
        self.combine_hook()
        self.combine_hook = None
        self.ffn_out = None
        self.topk_idx = None
        self.topk_weights = None
        self.handle = None
        combine_out = self.combined_hidden_states
        self.combined_hidden_states = None
        return combine_out
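
# Example (illustrative sketch only): overlapping two micro-batches A and B with
# the runner above. Each runner instance owns the in-flight state of one
# micro-batch; the interleaving shown is an assumption about typical usage:
#
#     runner_a = DecoderEPMicroBatchRunner(moe_layers, ep_engine)
#     runner_b = DecoderEPMicroBatchRunner(moe_layers, ep_engine)
#
#     runner_a.dispatch_issue(x_a, idx_a, w_a, layer_idx)
#     runner_b.dispatch_issue(x_b, idx_b, w_b, layer_idx)
#     runner_a.dispatch_hook_wrap(); runner_a.ffn(layer_idx); runner_a.combine_issue()
#     runner_b.dispatch_hook_wrap(); runner_b.ffn(layer_idx); runner_b.combine_issue()
#     out_a = runner_a.combine_hook_wrap()
#     out_b = runner_b.combine_hook_wrap()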