fix ep prefill (#2762)
@@ -158,7 +158,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
     const paddle::Tensor &input, const paddle::Tensor &scale,
     const paddle::Tensor &topk_ids, const paddle::Tensor &topk_weights,
     const paddle::Tensor &token_nums_per_expert,
-    const paddle::Tensor &token_nums_per_expert_padded);
+    const paddle::Tensor &token_nums_per_expert_padded,
+    const bool use_in_ep, const int token_nums_this_rank_padded);

 std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
                                           const int block_size);
@@ -870,7 +870,9 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
     const paddle::Tensor& topk_ids,
     const paddle::Tensor& topk_weights,
     const paddle::Tensor& num_experts_per_rank_tensor,
-    const paddle::Tensor& num_experts_per_rank_padded_tensor) {
+    const paddle::Tensor& num_experts_per_rank_padded_tensor,
+    const bool use_in_ep,
+    const int token_nums_this_rank_padded) {
   const auto input_type = input.dtype();
   const int moe_topk = topk_ids.dims()[1];
   auto place = input.place();
@@ -886,22 +888,21 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
   const int hidden_size = input.dims()[input_dims.size() - 1];
   const int num_experts_per_rank = num_experts_per_rank_tensor.dims()[0];

-  int32_t token_nums_this_rank_padded = token_rows * moe_topk + num_experts_per_rank * (128-1);
-  // token_nums_this_rank_padded = token_nums_this_rank_padded_useless;
+  int32_t token_nums_feed_to_ffn = use_in_ep ? token_nums_this_rank_padded : token_rows * moe_topk + num_experts_per_rank * (128-1);

   auto permute_input = GetEmptyTensor(
-      {token_nums_this_rank_padded, hidden_size},
+      {token_nums_feed_to_ffn, hidden_size},
       input_type,
       place);
   auto permute_scale = GetEmptyTensor(
-      {token_nums_this_rank_padded, hidden_size / 128},
+      {token_nums_feed_to_ffn, hidden_size / 128},
       paddle::DataType::FLOAT32,
       place);

-  auto m_indices = paddle::full({token_nums_this_rank_padded}, -1, paddle::DataType::INT32, place);
+  auto m_indices = paddle::full({token_nums_feed_to_ffn}, -1, paddle::DataType::INT32, place);
   auto token_nums_per_expert_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
   auto token_nums_per_expert_padded_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
-  auto dst_weights = GetEmptyTensor({token_nums_this_rank_padded}, paddle::DataType::FLOAT32, place);
+  auto dst_weights = GetEmptyTensor({token_nums_feed_to_ffn}, paddle::DataType::FLOAT32, place);
   auto dst_indices = GetEmptyTensor({num_rows, num_experts_per_rank}, paddle::DataType::INT32, place);
   auto permute_indices_per_token = paddle::full({num_experts_per_rank, num_rows}, -1, paddle::DataType::INT32, place);
   auto cumsum_idx_gpu = paddle::full({num_experts_per_rank}, 0, paddle::DataType::INT32, place);
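The sizing change above is the heart of the EP prefill fix: when the op runs inside expert-parallel dispatch, the caller now passes the exact padded row count, and the old worst-case estimate is only used otherwise. A minimal Python sketch of that rule, assuming the kernel's 128-row block (names are illustrative, not the kernel's variables):

```python
def dispatch_buffer_rows(token_rows: int, moe_topk: int, num_experts_per_rank: int,
                         use_in_ep: bool, token_nums_this_rank_padded: int) -> int:
    # Worst case: every routed copy of every token lands on this rank and each
    # local expert wastes up to 127 rows of block padding.
    worst_case = token_rows * moe_topk + num_experts_per_rank * (128 - 1)
    # In EP prefill the caller already knows the exact padded count, so the
    # permute buffers, m_indices and dst_weights no longer over-allocate.
    return token_nums_this_rank_padded if use_in_ep else worst_case
```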
@@ -949,4 +950,5 @@ PD_BUILD_STATIC_OP(ep_moe_expert_dispatch_fp8)
              "dst_indices",
              "cumsum_idx_gpu",
              "m_indices"})
+    .Attrs({"use_in_ep:bool", "token_nums_this_rank_padded:int"})
     .SetKernelFn(PD_KERNEL(EPMoeExpertDispatchFP8));
@@ -33,6 +33,7 @@ from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend, AttentionMetadata)
+from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
 from fastdeploy.worker.forward_meta import ForwardMeta


@@ -91,7 +92,6 @@ class AppendAttentionBackend(AttentionBackend):
         self.use_speculate: bool = self.speculative_method is not None
         self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
         self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
-        self.rank: int = fd_config.parallel_config.tensor_parallel_rank

         self.kv_num_heads: int = kv_num_heads
         self.num_heads: int = num_heads
@@ -104,16 +104,11 @@ class AppendAttentionBackend(AttentionBackend):
         self.use_pd_disaggregation: int = int(
             os.getenv("FLAGS_use_pd_disaggregation", 0))
         self.start_layer_index: int = fd_config.model_config.start_layer_index
-        self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)

         if fd_config.parallel_config.expert_parallel_rank is None:
             fd_config.parallel_config.expert_parallel_rank = 0
-        device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \
-            fd_config.parallel_config.expert_parallel_rank
-        if self.device_id is None:
-            self.device_id = device_id
-        else:
-            self.device_id = self.device_id.split(",")[device_id]
+        self.rank, self.device_id = init_rank_and_device_id(fd_config)

     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attntion metadata hence all layers in the forward pass can reuse it."""
@@ -34,6 +34,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
 from fastdeploy.model_executor.layers.attention.ops import (
     get_block_shape_and_split_kv_block, gqa_rope_write_cache,
     init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat)
+from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
 from fastdeploy.worker.forward_meta import ForwardMeta


@@ -100,22 +101,16 @@ class FlashAttentionBackend(AttentionBackend):
         self.use_speculate = self.speculative_method is not None
         self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
         self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
-        self.rank: int = fd_config.parallel_config.tensor_parallel_rank

         # pd_disaggregation
         self.use_pd_disaggregation: int = int(
             os.getenv("FLAGS_use_pd_disaggregation", 0))
         self.start_layer_index: int = fd_config.model_config.start_layer_index
-        self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)

         if fd_config.parallel_config.expert_parallel_rank is None:
             fd_config.parallel_config.expert_parallel_rank = 0
-        device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \
-            fd_config.parallel_config.expert_parallel_rank
-        if self.device_id is None:
-            self.device_id = device_id
-        else:
-            self.device_id = self.device_id.split(",")[device_id]
+        self.rank, self.device_id = init_rank_and_device_id(fd_config)

     def get_attntion_meta(self):
         """get_attntion_meta"""
@@ -41,6 +41,7 @@ from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend, AttentionMetadata)
+from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
 from fastdeploy.worker.forward_meta import ForwardMeta


@@ -109,7 +110,6 @@ class MLAAttentionBackend(AttentionBackend):
         self.use_speculate: bool = self.speculative_method is not None
         self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
         self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
-        self.rank: int = fd_config.parallel_config.tensor_parallel_rank

         self.kv_num_heads: int = kv_num_heads
         self.num_heads: int = num_heads
@@ -135,10 +135,8 @@ class MLAAttentionBackend(AttentionBackend):
             os.getenv("FLAGS_use_pd_disaggregation", 0))
         self.start_layer_index: int = fd_config.model_config.start_layer_index
         self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
-        if self.device_id is None:
-            self.device_id = self.rank
-        else:
-            self.device_id = self.device_id.split(",")[self.rank]
+        self.rank, self.device_id = init_rank_and_device_id(fd_config)

     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attention metadata hence all layers in the forward pass can reuse it."""
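All three attention backends previously derived the device index inline and indexed CUDA_VISIBLE_DEVICES with the raw combined rank; the shared helper added below additionally wraps the index with a modulo. A small sketch of the difference, assuming each node exposes only its local GPUs in CUDA_VISIBLE_DEVICES (numbers illustrative):

```python
# Illustrative setup: 2-way TP x 2-way EP, this node exposing two local GPUs.
cuda_visible_devices = "0,1".split(",")
tensor_parallel_degree, tensor_parallel_rank, expert_parallel_rank = 2, 1, 1

# Combined rank, EP-major and TP-minor, as in init_rank_and_device_id.
rank = expert_parallel_rank * tensor_parallel_degree + tensor_parallel_rank  # 3

# Old inline logic (removed above) indexed with the raw rank:
# cuda_visible_devices[rank]  ->  IndexError for rank 3 here.

# The new helper wraps the index, so rank 3 maps back onto local GPU "1".
device_id = cuda_visible_devices[rank % len(cuda_visible_devices)]
print(rank, device_id)  # 3 1
```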
fastdeploy/model_executor/layers/attention/utils.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import os
+from fastdeploy.config import FDConfig
+
+def init_rank_and_device_id(fd_config: FDConfig):
+    """
+
+    """
+    rank = (fd_config.parallel_config.expert_parallel_rank *
+            fd_config.parallel_config.tensor_parallel_degree + fd_config.parallel_config.tensor_parallel_rank)
+
+    cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", None)
+
+    if cuda_visible_devices is None:
+        device_id = rank
+    else:
+        cuda_visible_devices = cuda_visible_devices.split(",")
+        rank_index = rank % len(cuda_visible_devices)
+        device_id = cuda_visible_devices[rank_index]
+
+    return rank, device_id
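For context, a minimal usage sketch of the new helper as the backends call it; the stub config below is illustrative only, since the real FDConfig carries many more fields:

```python
import os
from types import SimpleNamespace

from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id

# Hypothetical stand-in for FDConfig.parallel_config; the attribute names
# match the ones the helper reads.
fd_config = SimpleNamespace(parallel_config=SimpleNamespace(
    tensor_parallel_degree=2,
    tensor_parallel_rank=1,
    expert_parallel_rank=1,
))

os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
rank, device_id = init_rank_and_device_id(fd_config)
# rank == 3 (ep_rank * tp_degree + tp_rank); device_id == "7"
```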
@@ -144,7 +144,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         if token_all_num > 0:
             logger.info(f"token_all_num {token_all_num}")
             (recv_x, recv_x_scale) = recv_x
-            tmp = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
+
+            token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
+            token_nums_this_rank_padded = sum(token_nums_this_rank[1].numpy().tolist())
+
             (
                 permute_input,
                 permute_scale,
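count_tokens_per_expert_func serves two purposes here: element [0] feeds the op as the raw per-expert token counts and element [1] as the padded counts, whose sum becomes the exact number of rows the dispatch buffers need. A hedged sketch of that relationship, assuming padding to the kernel's 128-row block (the real function may pad differently):

```python
import numpy as np

# Illustrative per-expert routed-token counts on this rank.
tokens_per_expert = np.array([5, 200, 0, 130])

# Assumed padding rule: round each non-empty expert up to a 128-row block.
padded_per_expert = np.where(tokens_per_expert > 0,
                             (tokens_per_expert + 127) // 128 * 128, 0)

# Mirrors the new call site: the sum of the padded counts is what gets passed
# to the op as token_nums_this_rank_padded.
token_nums_this_rank_padded = int(padded_per_expert.sum())
print(token_nums_this_rank_padded)  # 640 under these assumptions
```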
@@ -160,8 +163,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                 recv_x_scale,
                 recv_topk_idx,
                 recv_topk_weights,
-                tmp[0],
-                tmp[1]
+                token_nums_this_rank[0],
+                token_nums_this_rank[1],
+                True, # use_in_ep
+                token_nums_this_rank_padded,
             )

             permute_scale = permute_scale.transpose([1, 0]).contiguous()
@@ -328,6 +333,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                 topk_weights,
                 tmp[0],
                 tmp[1],
+                False, # use_in_ep
+                -1,
             )

             permute_scale = permute_scale.transpose([1, 0]).contiguous()