diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py
index c2d076d0d..752ead74f 100644
--- a/fastdeploy/model_executor/layers/moe/ep.py
+++ b/fastdeploy/model_executor/layers/moe/ep.py
@@ -68,8 +68,7 @@ class DeepEPEngine:
         self.num_local_experts = num_experts // ep_size
         self.async_finish = async_finish

-        self.prefill_deepep_engine = None
-        self.decode_deepep_engine = None
+        self.deepep_engine = None

         self.ep_config = Config(24, 6, 256)
         self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
@@ -77,16 +76,12 @@ class DeepEPEngine:
         # In mixed EP mode on a single node, we dynamically switch between
         # high throughput and low latency modes.
         if splitwise_role == "mixed":
-            # decode engine
-            logger.info("Initializing Low Latency Buffer")
-            self.get_low_latency_buffer()
-            # prefill engine
-            self.prefill_deepep_engine = deep_ep.Buffer(
+            self.deepep_engine = deep_ep.Buffer(
                 self.group,
-                int(5e8),
-                0,
-                low_latency_mode=False,
-                num_qps_per_rank=1,
+                int(2e9),
+                int(5e9),
+                low_latency_mode=True,
+                num_qps_per_rank=24,
             )
         # In disaggregated mode on mutiple nodes, we either use
         # high throughput mode or low latency mode.
@@ -95,7 +90,7 @@ class DeepEPEngine:
                 logger.info("Initializing Low Latency Buffer")
                 self.get_low_latency_buffer()
             elif moe_phase.phase == "prefill":
-                self.prefill_deepep_engine = deep_ep.Buffer(
+                self.deepep_engine = deep_ep.Buffer(
                     self.group,
                     int(5e8),
                     0,
@@ -124,14 +119,14 @@ class DeepEPEngine:
         )
         # Allocate a buffer if not existed or not enough buffer size
         if (
-            self.decode_deepep_engine is None
-            or self.decode_deepep_engine.group != self.group
-            or not self.decode_deepep_engine.low_latency_mode
-            or self.decode_deepep_engine.num_rdma_bytes < num_rdma_bytes
+            self.deepep_engine is None
+            or self.deepep_engine.group != self.group
+            or not self.deepep_engine.low_latency_mode
+            or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
         ):
             # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
             assert self.num_experts % self.ep_size == 0
-            self.decode_deepep_engine = deep_ep.Buffer(
+            self.deepep_engine = deep_ep.Buffer(
                 self.group,
                 0,
                 num_rdma_bytes,
@@ -168,7 +163,7 @@ class DeepEPEngine:
             handle,
             _,
             dispatch_hook,
-        ) = self.decode_deepep_engine.low_latency_dispatch(
+        ) = self.deepep_engine.low_latency_dispatch(
             hidden_states,
             topk_idx,
             expertwise_scale,
@@ -210,7 +205,7 @@ class DeepEPEngine:
             num_experts,
         )

-        combined_hidden_states, _, combine_hook = self.decode_deepep_engine.low_latency_combine(
+        combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
            hidden_states,
            topk_idx,
            topk_weights,
@@ -224,7 +219,7 @@ class DeepEPEngine:
         """
         clean_low_latency_buffer
         """
-        self.decode_deepep_engine.clean_low_latency_buffer(
+        self.deepep_engine.clean_low_latency_buffer(
             self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
         )

@@ -232,11 +227,7 @@ class DeepEPEngine:
         """
         barrier_all
         """
-        if self.prefill_deepep_engine is not None:
-            self.prefill_deepep_engine.barrier_all()
-
-        if self.decode_deepep_engine is not None:
-            self.decode_deepep_engine.barrier_all()
+        self.deepep_engine.barrier_all()


 class EPRunner:
@@ -316,6 +307,9 @@ class EPRunner:
         """
         raise NotImplementedError

+    def clean_low_latency_buffer(self):
+        self.ep_engine.clean_low_latency_buffer()
+

 class EPPrefillRunner(EPRunner):
     """
@@ -328,6 +322,7 @@ class EPPrefillRunner(EPRunner):
         hidden: int,
         num_experts: int,
         splitwise_role: str,
+        num_max_dispatch_tokens_per_rank: int,
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
@@ -339,7 +334,7 @@ class EPPrefillRunner(EPRunner):
             num_experts,
             splitwise_role,
             moe_phase,
-            num_max_dispatch_tokens_per_rank=256,
+            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
             ep_size=ep_size,
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
@@ -359,7 +354,7 @@ class EPPrefillRunner(EPRunner):
             num_tokens_per_expert,
             is_token_in_rank,
             _,
-        ) = self.ep_engine.prefill_deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
+        ) = self.ep_engine.deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)

         x_scale_tensor = kwargs.get("x_scale_tensor", None)
         dispatch_args = {
@@ -372,7 +367,7 @@ class EPPrefillRunner(EPRunner):
             "topk_idx": topk_idx,
             "topk_weights": topk_weights,
         }
-        return self.ep_engine.prefill_deepep_engine.dispatch(**dispatch_args)
+        return self.ep_engine.deepep_engine.dispatch(**dispatch_args)

     def combine(
         self,
@@ -387,14 +382,14 @@ class EPPrefillRunner(EPRunner):
             "async_finish": self.ep_engine.async_finish,
             "topk_weights": recv_topk_weights,
         }
-        fused_moe_out, _, _ = self.ep_engine.prefill_deepep_engine.combine(**combine_args)
+        fused_moe_out, _, _ = self.ep_engine.deepep_engine.combine(**combine_args)

         return fused_moe_out


 class EPDecoderRunner(EPRunner):
     """
-    EPPrefillRunner
+    EPDecoderRunner
     """

     def __init__(
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
index fe81c0616..391f8b3f3 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
@@ -51,6 +51,7 @@ class MoEMethodBase(QuantMethodBase):
                     layer.hidden_size,
                     layer.num_experts,
                     layer.fd_config.parallel_config.splitwise_role,
+                    layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
                     layer.ep_size,
                     layer.ep_rank,
                     layer.fd_config.model_config.redundant_experts_num,
@@ -74,6 +75,7 @@ class MoEMethodBase(QuantMethodBase):
                     layer.hidden_size,
                     layer.num_experts,
                     layer.fd_config.parallel_config.splitwise_role,
+                    layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
                     layer.ep_size,
                     layer.ep_rank,
                     layer.fd_config.model_config.redundant_experts_num,
@@ -165,8 +167,10 @@ class MoEMethodBase(QuantMethodBase):
         """
         if layer.ep_size > 1:
             if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
+                self.ep_prefill_runner.clean_low_latency_buffer()
                 return self.apply_ep_prefill(layer, x, gate_out)
             else:
+                self.ep_decoder_runner.clean_low_latency_buffer()
                 return self.apply_ep_decode(layer, x, gate_out)
         else:
             return self.apply_tp(layer, x, gate_out)
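
Usage note (not part of the patch): a minimal sketch of the control flow the last hunk introduces, where the runner for the active phase flushes the shared low-latency buffer before its forward path runs, since prefill and decode now share a single deep_ep.Buffer in mixed mode. The _StubRunner class and run_step helper below are hypothetical stand-ins used only to make the flow self-contained; the real objects are EPPrefillRunner and EPDecoderRunner from ep.py.

# Hypothetical stand-ins for EPPrefillRunner / EPDecoderRunner, illustration only.
class _StubRunner:
    def __init__(self, name: str):
        self.name = name

    def clean_low_latency_buffer(self):
        # In the real code this forwards to DeepEPEngine.clean_low_latency_buffer(),
        # which resets the low-latency state of the shared deep_ep.Buffer.
        print(f"{self.name}: clean_low_latency_buffer()")


def run_step(phase: str, prefill_runner: _StubRunner, decode_runner: _StubRunner) -> str:
    # Mirrors the patched MoEMethodBase.apply(): clean the shared buffer,
    # then branch into the phase-specific path.
    if phase == "prefill":
        prefill_runner.clean_low_latency_buffer()
        return "apply_ep_prefill"
    decode_runner.clean_low_latency_buffer()
    return "apply_ep_decode"


if __name__ == "__main__":
    prefill, decode = _StubRunner("prefill"), _StubRunner("decode")
    assert run_step("prefill", prefill, decode) == "apply_ep_prefill"
    assert run_step("decode", prefill, decode) == "apply_ep_decode"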