diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index b4eeb1fc7..9927f31a9 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -158,7 +158,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
     const paddle::Tensor &input, const paddle::Tensor &scale,
     const paddle::Tensor &topk_ids, const paddle::Tensor &topk_weights,
     const paddle::Tensor &token_nums_per_expert,
-    const paddle::Tensor &token_nums_per_expert_padded);
+    const paddle::Tensor &token_nums_per_expert_padded,
+    const bool use_in_ep, const int token_nums_this_rank_padded);

 std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
                                           const int block_size);
diff --git a/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu b/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu
index 105fa79b8..09e006cdc 100644
--- a/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu
+++ b/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu
@@ -870,7 +870,9 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
     const paddle::Tensor& topk_ids,
     const paddle::Tensor& topk_weights,
     const paddle::Tensor& num_experts_per_rank_tensor,
-    const paddle::Tensor& num_experts_per_rank_padded_tensor) {
+    const paddle::Tensor& num_experts_per_rank_padded_tensor,
+    const bool use_in_ep,
+    const int token_nums_this_rank_padded) {
   const auto input_type = input.dtype();
   const int moe_topk = topk_ids.dims()[1];
   auto place = input.place();
@@ -886,22 +888,21 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
   const int hidden_size = input.dims()[input_dims.size() - 1];
   const int num_experts_per_rank = num_experts_per_rank_tensor.dims()[0];

-  int32_t token_nums_this_rank_padded = token_rows * moe_topk + num_experts_per_rank * (128-1);
-  // token_nums_this_rank_padded = token_nums_this_rank_padded_useless;
+  int32_t token_nums_feed_to_ffn = use_in_ep ? token_nums_this_rank_padded : token_rows * moe_topk + num_experts_per_rank * (128-1);

   auto permute_input = GetEmptyTensor(
-      {token_nums_this_rank_padded, hidden_size},
+      {token_nums_feed_to_ffn, hidden_size},
       input_type,
       place);
   auto permute_scale = GetEmptyTensor(
-      {token_nums_this_rank_padded, hidden_size / 128},
+      {token_nums_feed_to_ffn, hidden_size / 128},
       paddle::DataType::FLOAT32,
       place);

-  auto m_indices = paddle::full({token_nums_this_rank_padded}, -1, paddle::DataType::INT32, place);
+  auto m_indices = paddle::full({token_nums_feed_to_ffn}, -1, paddle::DataType::INT32, place);
   auto token_nums_per_expert_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
   auto token_nums_per_expert_padded_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
-  auto dst_weights = GetEmptyTensor({token_nums_this_rank_padded}, paddle::DataType::FLOAT32, place);
+  auto dst_weights = GetEmptyTensor({token_nums_feed_to_ffn}, paddle::DataType::FLOAT32, place);
   auto dst_indices = GetEmptyTensor({num_rows, num_experts_per_rank}, paddle::DataType::INT32, place);
   auto permute_indices_per_token = paddle::full({num_experts_per_rank, num_rows}, -1, paddle::DataType::INT32, place);
   auto cumsum_idx_gpu = paddle::full({num_experts_per_rank}, 0, paddle::DataType::INT32, place);
@@ -949,4 +950,5 @@ PD_BUILD_STATIC_OP(ep_moe_expert_dispatch_fp8)
               "dst_indices",
               "cumsum_idx_gpu",
               "m_indices"})
+    .Attrs({"use_in_ep:bool", "token_nums_this_rank_padded:int"})
     .SetKernelFn(PD_KERNEL(EPMoeExpertDispatchFP8));
diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
index 5bc7f420a..9210ff241 100644
--- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
@@ -33,6 +33,7 @@ from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend, AttentionMetadata)
+from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
 from fastdeploy.worker.forward_meta import ForwardMeta


@@ -91,7 +92,6 @@ class AppendAttentionBackend(AttentionBackend):
         self.use_speculate: bool = self.speculative_method is not None
         self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
         self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
-        self.rank: int = fd_config.parallel_config.tensor_parallel_rank

         self.kv_num_heads: int = kv_num_heads
         self.num_heads: int = num_heads
@@ -104,16 +104,11 @@ class AppendAttentionBackend(AttentionBackend):
         self.use_pd_disaggregation: int = int(
             os.getenv("FLAGS_use_pd_disaggregation", 0))
         self.start_layer_index: int = fd_config.model_config.start_layer_index
-        self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)

         if fd_config.parallel_config.expert_parallel_rank is None:
             fd_config.parallel_config.expert_parallel_rank = 0
-        device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \
-            fd_config.parallel_config.expert_parallel_rank
-        if self.device_id is None:
-            self.device_id = device_id
-        else:
-            self.device_id = self.device_id.split(",")[device_id]
+
+        self.rank, self.device_id = init_rank_and_device_id(fd_config)

     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attntion metadata hence all layers in the forward pass can reuse it."""
diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
index e0aef5ae0..b68d6cab4 100644
--- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
@@ -34,6 +34,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
 from fastdeploy.model_executor.layers.attention.ops import (
     get_block_shape_and_split_kv_block, gqa_rope_write_cache,
     init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat)
+from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
 from fastdeploy.worker.forward_meta import ForwardMeta


@@ -100,22 +101,16 @@ class FlashAttentionBackend(AttentionBackend):
         self.use_speculate = self.speculative_method is not None
         self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
         self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
-        self.rank: int = fd_config.parallel_config.tensor_parallel_rank

         # pd_disaggregation
         self.use_pd_disaggregation: int = int(
             os.getenv("FLAGS_use_pd_disaggregation", 0))
         self.start_layer_index: int = fd_config.model_config.start_layer_index
-        self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)

         if fd_config.parallel_config.expert_parallel_rank is None:
             fd_config.parallel_config.expert_parallel_rank = 0
-        device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \
-            fd_config.parallel_config.expert_parallel_rank
-        if self.device_id is None:
-            self.device_id = device_id
-        else:
-            self.device_id = self.device_id.split(",")[device_id]
+
+        self.rank, self.device_id = init_rank_and_device_id(fd_config)

     def get_attntion_meta(self):
         """get_attntion_meta"""
diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
index 1d9c9773b..8c9ce9302 100644
--- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
+++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
@@ -41,6 +41,7 @@ from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend, AttentionMetadata)
+from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
 from fastdeploy.worker.forward_meta import ForwardMeta


@@ -109,7 +110,6 @@ class MLAAttentionBackend(AttentionBackend):
         self.use_speculate: bool = self.speculative_method is not None
         self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
         self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
-        self.rank: int = fd_config.parallel_config.tensor_parallel_rank

         self.kv_num_heads: int = kv_num_heads
         self.num_heads: int = num_heads
@@ -135,10 +135,8 @@ class MLAAttentionBackend(AttentionBackend):
             os.getenv("FLAGS_use_pd_disaggregation", 0))
         self.start_layer_index: int = fd_config.model_config.start_layer_index
         self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
-        if self.device_id is None:
-            self.device_id = self.rank
-        else:
-            self.device_id = self.device_id.split(",")[self.rank]
+
+        self.rank, self.device_id = init_rank_and_device_id(fd_config)

     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attention metadata hence all layers in the forward pass can reuse it."""
diff --git a/fastdeploy/model_executor/layers/attention/utils.py b/fastdeploy/model_executor/layers/attention/utils.py
new file mode 100644
index 000000000..1ba93e3bb
--- /dev/null
+++ b/fastdeploy/model_executor/layers/attention/utils.py
@@ -0,0 +1,36 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import os
+from fastdeploy.config import FDConfig
+
+def init_rank_and_device_id(fd_config: FDConfig):
+    """
+    Compute the global rank of the current worker and the device id it should use.
+    """
+    rank = (fd_config.parallel_config.expert_parallel_rank *
+            fd_config.parallel_config.tensor_parallel_degree + fd_config.parallel_config.tensor_parallel_rank)
+
+    cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", None)
+
+    if cuda_visible_devices is None:
+        device_id = rank
+    else:
+        cuda_visible_devices = cuda_visible_devices.split(",")
+        rank_index = rank % len(cuda_visible_devices)
+        device_id = cuda_visible_devices[rank_index]
+
+    return rank, device_id
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
index c3bb8d3f1..dc01f1714 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
@@ -144,7 +144,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         if token_all_num > 0:
             logger.info(f"token_all_num {token_all_num}")
             (recv_x, recv_x_scale) = recv_x
-            tmp = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
+
+            token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
+            token_nums_this_rank_padded = sum(token_nums_this_rank[1].numpy().tolist())
+
             (
                 permute_input,
                 permute_scale,
@@ -160,8 +163,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                 recv_x_scale,
                 recv_topk_idx,
                 recv_topk_weights,
-                tmp[0],
-                tmp[1]
+                token_nums_this_rank[0],
+                token_nums_this_rank[1],
+                True,  # use_in_ep
+                token_nums_this_rank_padded,
             )
             permute_scale = permute_scale.transpose([1, 0]).contiguous()

@@ -328,6 +333,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                topk_weights,
                tmp[0],
                tmp[1],
+               False,  # use_in_ep
+               -1,
            )
            permute_scale = permute_scale.transpose([1, 0]).contiguous()
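
Why the new `use_in_ep:bool` / `token_nums_this_rank_padded:int` attributes matter: in the EP prefill path the Python caller already knows the exact padded token count per local expert, so the dispatch buffers no longer need the old worst-case allocation. Below is a minimal sketch of the two sizing paths; all numbers and the round-up-to-128 padding rule are illustrative assumptions, only the formulas mirror the patch.

```python
# Illustrative sketch of the buffer sizing selected inside EPMoeExpertDispatchFP8.
# The counts below are made up; padding each expert's count up to BLOCK is an assumption.
BLOCK = 128                                   # assumed per-expert padding granularity
token_rows, moe_topk = 512, 8                 # tokens on this rank, experts chosen per token
tokens_per_expert = [130, 0, 257, 64]         # made-up counts for 4 local experts
num_experts_per_rank = len(tokens_per_expert)

# Non-EP path (use_in_ep=False): keep the old worst-case bound.
worst_case = token_rows * moe_topk + num_experts_per_rank * (BLOCK - 1)

# EP path (use_in_ep=True): the caller sums the padded per-expert counts
# (token_nums_this_rank[1] in fused_moe_deepgemm_backend.py) and passes the
# exact total down as the token_nums_this_rank_padded attribute.
padded_per_expert = [-(-n // BLOCK) * BLOCK for n in tokens_per_expert]
token_nums_this_rank_padded = sum(padded_per_expert)

print(worst_case, token_nums_this_rank_padded)   # 4604 vs 768 for these numbers
```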
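
The three attention backends now delegate rank and device selection to the shared `init_rank_and_device_id` helper. The standalone sketch below shows the mapping it performs with made-up parallel degrees and a made-up `CUDA_VISIBLE_DEVICES` value; it mirrors only the helper's arithmetic, not the FDConfig API.

```python
# Illustrative only: flatten (ep_rank, tp_rank) into a global rank, then pick the
# matching entry of CUDA_VISIBLE_DEVICES. All values below are made up.
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
tensor_parallel_degree, expert_parallel_degree = 2, 2

for ep_rank in range(expert_parallel_degree):
    for tp_rank in range(tensor_parallel_degree):
        rank = ep_rank * tensor_parallel_degree + tp_rank
        visible = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        device_id = visible[rank % len(visible)]
        print(f"ep_rank={ep_rank} tp_rank={tp_rank} -> rank={rank}, device_id={device_id}")
# ep_rank=0 tp_rank=0 -> rank=0, device_id=4
# ep_rank=0 tp_rank=1 -> rank=1, device_id=5
# ep_rank=1 tp_rank=0 -> rank=2, device_id=6
# ep_rank=1 tp_rank=1 -> rank=3, device_id=7
```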