fix ep prefill (#2762)

This commit is contained in:
RichardWooSJTU
2025-07-09 14:03:05 +08:00
committed by GitHub
parent c4718fd693
commit fee544e808
7 changed files with 66 additions and 32 deletions

View File

@@ -158,7 +158,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
const paddle::Tensor &input, const paddle::Tensor &scale, const paddle::Tensor &input, const paddle::Tensor &scale,
const paddle::Tensor &topk_ids, const paddle::Tensor &topk_weights, const paddle::Tensor &topk_ids, const paddle::Tensor &topk_weights,
const paddle::Tensor &token_nums_per_expert, const paddle::Tensor &token_nums_per_expert,
const paddle::Tensor &token_nums_per_expert_padded); const paddle::Tensor &token_nums_per_expert_padded,
const bool use_in_ep, const int token_nums_this_rank_padded);
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input, std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
const int block_size); const int block_size);

View File

@@ -870,7 +870,9 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
const paddle::Tensor& topk_ids, const paddle::Tensor& topk_ids,
const paddle::Tensor& topk_weights, const paddle::Tensor& topk_weights,
const paddle::Tensor& num_experts_per_rank_tensor, const paddle::Tensor& num_experts_per_rank_tensor,
const paddle::Tensor& num_experts_per_rank_padded_tensor) { const paddle::Tensor& num_experts_per_rank_padded_tensor,
const bool use_in_ep,
const int token_nums_this_rank_padded) {
const auto input_type = input.dtype(); const auto input_type = input.dtype();
const int moe_topk = topk_ids.dims()[1]; const int moe_topk = topk_ids.dims()[1];
auto place = input.place(); auto place = input.place();
@@ -886,22 +888,21 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
const int hidden_size = input.dims()[input_dims.size() - 1]; const int hidden_size = input.dims()[input_dims.size() - 1];
const int num_experts_per_rank = num_experts_per_rank_tensor.dims()[0]; const int num_experts_per_rank = num_experts_per_rank_tensor.dims()[0];
int32_t token_nums_this_rank_padded = token_rows * moe_topk + num_experts_per_rank * (128-1); int32_t token_nums_feed_to_ffn = use_in_ep ? token_nums_this_rank_padded : token_rows * moe_topk + num_experts_per_rank * (128-1);
// token_nums_this_rank_padded = token_nums_this_rank_padded_useless;
auto permute_input = GetEmptyTensor( auto permute_input = GetEmptyTensor(
{token_nums_this_rank_padded, hidden_size}, {token_nums_feed_to_ffn, hidden_size},
input_type, input_type,
place); place);
auto permute_scale = GetEmptyTensor( auto permute_scale = GetEmptyTensor(
{token_nums_this_rank_padded, hidden_size / 128}, {token_nums_feed_to_ffn, hidden_size / 128},
paddle::DataType::FLOAT32, paddle::DataType::FLOAT32,
place); place);
auto m_indices = paddle::full({token_nums_this_rank_padded}, -1, paddle::DataType::INT32, place); auto m_indices = paddle::full({token_nums_feed_to_ffn}, -1, paddle::DataType::INT32, place);
auto token_nums_per_expert_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place); auto token_nums_per_expert_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
auto token_nums_per_expert_padded_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place); auto token_nums_per_expert_padded_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
auto dst_weights = GetEmptyTensor({token_nums_this_rank_padded}, paddle::DataType::FLOAT32, place); auto dst_weights = GetEmptyTensor({token_nums_feed_to_ffn}, paddle::DataType::FLOAT32, place);
auto dst_indices = GetEmptyTensor({num_rows, num_experts_per_rank}, paddle::DataType::INT32, place); auto dst_indices = GetEmptyTensor({num_rows, num_experts_per_rank}, paddle::DataType::INT32, place);
auto permute_indices_per_token = paddle::full({num_experts_per_rank, num_rows}, -1, paddle::DataType::INT32, place); auto permute_indices_per_token = paddle::full({num_experts_per_rank, num_rows}, -1, paddle::DataType::INT32, place);
auto cumsum_idx_gpu = paddle::full({num_experts_per_rank}, 0, paddle::DataType::INT32, place); auto cumsum_idx_gpu = paddle::full({num_experts_per_rank}, 0, paddle::DataType::INT32, place);
@@ -949,4 +950,5 @@ PD_BUILD_STATIC_OP(ep_moe_expert_dispatch_fp8)
"dst_indices", "dst_indices",
"cumsum_idx_gpu", "cumsum_idx_gpu",
"m_indices"}) "m_indices"})
.Attrs({"use_in_ep:bool", "token_nums_this_rank_padded:int"})
.SetKernelFn(PD_KERNEL(EPMoeExpertDispatchFP8)); .SetKernelFn(PD_KERNEL(EPMoeExpertDispatchFP8));

View File

@@ -33,6 +33,7 @@ from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import ( from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata) AttentionBackend, AttentionMetadata)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
from fastdeploy.worker.forward_meta import ForwardMeta from fastdeploy.worker.forward_meta import ForwardMeta
@@ -91,7 +92,6 @@ class AppendAttentionBackend(AttentionBackend):
self.use_speculate: bool = self.speculative_method is not None self.use_speculate: bool = self.speculative_method is not None
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
self.kv_num_heads: int = kv_num_heads self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads self.num_heads: int = num_heads
@@ -104,16 +104,11 @@ class AppendAttentionBackend(AttentionBackend):
self.use_pd_disaggregation: int = int( self.use_pd_disaggregation: int = int(
os.getenv("FLAGS_use_pd_disaggregation", 0)) os.getenv("FLAGS_use_pd_disaggregation", 0))
self.start_layer_index: int = fd_config.model_config.start_layer_index self.start_layer_index: int = fd_config.model_config.start_layer_index
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
if fd_config.parallel_config.expert_parallel_rank is None: if fd_config.parallel_config.expert_parallel_rank is None:
fd_config.parallel_config.expert_parallel_rank = 0 fd_config.parallel_config.expert_parallel_rank = 0
device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \
fd_config.parallel_config.expert_parallel_rank self.rank, self.device_id = init_rank_and_device_id(fd_config)
if self.device_id is None:
self.device_id = device_id
else:
self.device_id = self.device_id.split(",")[device_id]
def init_attention_metadata(self, forward_meta: ForwardMeta): def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it.""" """Initialize attntion metadata hence all layers in the forward pass can reuse it."""

View File

@@ -34,6 +34,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
from fastdeploy.model_executor.layers.attention.ops import ( from fastdeploy.model_executor.layers.attention.ops import (
get_block_shape_and_split_kv_block, gqa_rope_write_cache, get_block_shape_and_split_kv_block, gqa_rope_write_cache,
init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat) init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
from fastdeploy.worker.forward_meta import ForwardMeta from fastdeploy.worker.forward_meta import ForwardMeta
@@ -100,22 +101,16 @@ class FlashAttentionBackend(AttentionBackend):
self.use_speculate = self.speculative_method is not None self.use_speculate = self.speculative_method is not None
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
# pd_disaggregation # pd_disaggregation
self.use_pd_disaggregation: int = int( self.use_pd_disaggregation: int = int(
os.getenv("FLAGS_use_pd_disaggregation", 0)) os.getenv("FLAGS_use_pd_disaggregation", 0))
self.start_layer_index: int = fd_config.model_config.start_layer_index self.start_layer_index: int = fd_config.model_config.start_layer_index
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
if fd_config.parallel_config.expert_parallel_rank is None: if fd_config.parallel_config.expert_parallel_rank is None:
fd_config.parallel_config.expert_parallel_rank = 0 fd_config.parallel_config.expert_parallel_rank = 0
device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \
fd_config.parallel_config.expert_parallel_rank self.rank, self.device_id = init_rank_and_device_id(fd_config)
if self.device_id is None:
self.device_id = device_id
else:
self.device_id = self.device_id.split(",")[device_id]
def get_attntion_meta(self): def get_attntion_meta(self):
"""get_attntion_meta""" """get_attntion_meta"""

View File

@@ -41,6 +41,7 @@ from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import ( from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata) AttentionBackend, AttentionMetadata)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
from fastdeploy.worker.forward_meta import ForwardMeta from fastdeploy.worker.forward_meta import ForwardMeta
@@ -109,7 +110,6 @@ class MLAAttentionBackend(AttentionBackend):
self.use_speculate: bool = self.speculative_method is not None self.use_speculate: bool = self.speculative_method is not None
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
self.kv_num_heads: int = kv_num_heads self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads self.num_heads: int = num_heads
@@ -135,10 +135,8 @@ class MLAAttentionBackend(AttentionBackend):
os.getenv("FLAGS_use_pd_disaggregation", 0)) os.getenv("FLAGS_use_pd_disaggregation", 0))
self.start_layer_index: int = fd_config.model_config.start_layer_index self.start_layer_index: int = fd_config.model_config.start_layer_index
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None) self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
if self.device_id is None:
self.device_id = self.rank self.rank, self.device_id = init_rank_and_device_id(fd_config)
else:
self.device_id = self.device_id.split(",")[self.rank]
def init_attention_metadata(self, forward_meta: ForwardMeta): def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attention metadata hence all layers in the forward pass can reuse it.""" """Initialize attention metadata hence all layers in the forward pass can reuse it."""

View File

@@ -0,0 +1,36 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
from fastdeploy.config import FDConfig
def init_rank_and_device_id(fd_config: FDConfig):
"""
"""
rank = (fd_config.parallel_config.expert_parallel_rank *
fd_config.parallel_config.tensor_parallel_degree + fd_config.parallel_config.tensor_parallel_rank)
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", None)
if cuda_visible_devices is None:
device_id = rank
else:
cuda_visible_devices = cuda_visible_devices.split(",")
rank_index = rank % len(cuda_visible_devices)
device_id = cuda_visible_devices[rank_index]
return rank, device_id

View File

@@ -144,7 +144,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
if token_all_num > 0: if token_all_num > 0:
logger.info(f"token_all_num {token_all_num}") logger.info(f"token_all_num {token_all_num}")
(recv_x, recv_x_scale) = recv_x (recv_x, recv_x_scale) = recv_x
tmp = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
token_nums_this_rank_padded = sum(token_nums_this_rank[1].numpy().tolist())
( (
permute_input, permute_input,
permute_scale, permute_scale,
@@ -160,8 +163,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
recv_x_scale, recv_x_scale,
recv_topk_idx, recv_topk_idx,
recv_topk_weights, recv_topk_weights,
tmp[0], token_nums_this_rank[0],
tmp[1] token_nums_this_rank[1],
True, # use_in_ep
token_nums_this_rank_padded,
) )
permute_scale = permute_scale.transpose([1, 0]).contiguous() permute_scale = permute_scale.transpose([1, 0]).contiguous()
@@ -328,6 +333,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
topk_weights, topk_weights,
tmp[0], tmp[0],
tmp[1], tmp[1],
False, # use_in_ep
-1,
) )
permute_scale = permute_scale.transpose([1, 0]).contiguous() permute_scale = permute_scale.transpose([1, 0]).contiguous()