diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py index 1fea7d06f..52d7dadee 100644 --- a/fastdeploy/model_executor/layers/embeddings.py +++ b/fastdeploy/model_executor/layers/embeddings.py @@ -135,6 +135,8 @@ class VocabParallelEmbedding(nn.Layer): self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings self.params_dtype: str = params_dtype + self.embedding_dim = embedding_dim + self.general = general # used for general Embedding self.num_embeddings = num_embeddings self.padding_size = padding_size @@ -297,6 +299,8 @@ class VocabParallelEmbedding(nn.Layer): Returns: Tensor: Embedded tensor representation of the input IDs. """ + if ids_remove_padding.shape[0] == 0: + return paddle.empty([0, self.embedding_dim], dtype=self.embeddings.weight.dtype) if self.column_cut: input_embedings = self.embeddings(ids_remove_padding) inputs_embeds_temp = [] diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index be40b3d04..b61fe48f6 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -505,6 +505,8 @@ class EPPrefillRunner(EPRunner): EPPrefillRunner """ + allocate_on_comm_stream = False + def __init__( self, top_k: int, @@ -533,6 +535,12 @@ class EPPrefillRunner(EPRunner): use_internode_ll_two_stage=use_internode_ll_two_stage, ) + def set_allocate_on_comm_stream(allocate_on_comm_stream: bool = False): + logger.info( + f"set allocate_on_comm_stream to {allocate_on_comm_stream}, this will force Prefill dispatch's output tensor is allocated on communication stream" + ) + EPPrefillRunner.allocate_on_comm_stream = allocate_on_comm_stream + def dispatch( self, x: paddle.Tensor, @@ -552,7 +560,13 @@ class EPPrefillRunner(EPRunner): num_tokens_per_expert, is_token_in_rank, event, - ) = buffer.get_dispatch_layout(topk_idx, self.num_experts, async_finish=self.ep_engine.async_finish) + ) = buffer.get_dispatch_layout( + topk_idx, + self.num_experts, + previous_event=kwargs.get("previous_event", None), + allocate_on_comm_stream=EPPrefillRunner.allocate_on_comm_stream, + async_finish=self.ep_engine.async_finish, + ) x_scale_tensor = kwargs.get("x_scale_tensor", None) dispatch_args = { @@ -566,6 +580,7 @@ class EPPrefillRunner(EPRunner): "topk_idx": topk_idx, "topk_weights": topk_weights, "expert_alignment": expert_alignment, + "allocate_on_comm_stream": EPPrefillRunner.allocate_on_comm_stream, "previous_event": event, } return buffer.dispatch(**dispatch_args) @@ -575,6 +590,7 @@ class EPPrefillRunner(EPRunner): tmp_ffn_out: paddle.Tensor, handle: tuple, recv_topk_weights: paddle.Tensor, + event=None, ): buffer = self.ep_engine.deepep_engine if buffer is None: @@ -586,6 +602,7 @@ class EPPrefillRunner(EPRunner): "config": self.ep_engine.ep_config, "async_finish": self.ep_engine.async_finish, "topk_weights": recv_topk_weights, + "previous_event": event, } fused_moe_out, _, event = buffer.combine(**combine_args) return fused_moe_out, event diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index 4e591f8e0..1245cddce 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -16,11 +16,13 @@ import paddle from paddle import nn +from paddle.distributed.communication import deep_ep from paddleformers.utils.log import logger import fastdeploy from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func, deep_gemm +from fastdeploy.worker.tbo import let_another_thread_run from .fused_moe_backend_base import MoEMethodBase from .fused_moe_triton_backend import BlockWiseFP8MoEMethod @@ -142,12 +144,17 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): Apply the EP prefill method. """ gate_out = gate(x.cast("float32")) + # 1. Select topk experts and weights topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out) # 2. Dynamic compute blockwise quantization scales x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant( x, self.quant_config.weight_block_size[0] ) + + event = deep_ep.Buffer.capture() + let_another_thread_run() + # 3. EP Dispatch ( recv_x, @@ -157,8 +164,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): handle, event, ) = self.ep_prefill_runner.dispatch( - x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128 + x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128, previous_event=event ) + if self.ep_prefill_runner.ep_engine.async_finish: event.current_stream_wait() @@ -241,7 +249,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): tmp_ffn_out = paddle.cast(recv_x[0], paddle.bfloat16) # 5. EP combine - tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights) + event = deep_ep.Buffer.capture() + let_another_thread_run() + + tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights, event) if self.ep_prefill_runner.ep_engine.async_finish: event.current_stream_wait() diff --git a/fastdeploy/worker/tbo.py b/fastdeploy/worker/tbo.py new file mode 100644 index 000000000..856437959 --- /dev/null +++ b/fastdeploy/worker/tbo.py @@ -0,0 +1,108 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import threading + +from fastdeploy.model_executor.forward_meta import ForwardMeta + +event0 = threading.Event() +event1 = threading.Event() + + +GLOBAL_THREAD_INFO = {} + +GLOBAL_THREAD_INFO["thread0"] = [event0, event1] +GLOBAL_THREAD_INFO["thread1"] = [event1, event0] + + +GLOBAL_ATTN_BUFFERS = {} + + +def let_another_thread_run(): + thread_name = threading.current_thread().name + + if thread_name in GLOBAL_THREAD_INFO: + GLOBAL_THREAD_INFO[thread_name][1].set() + GLOBAL_THREAD_INFO[thread_name][0].wait() + GLOBAL_THREAD_INFO[thread_name][0].clear() + + +def split_batch_decoder_layers(forward_meta: ForwardMeta): + split_num = 2 + real_bs = forward_meta.seq_lens_this_time.shape[0] + + res = [forward_meta] * split_num + + if real_bs < split_num or forward_meta.ids_remove_padding.shape[0] == 0: + return res + + mc_bs = (real_bs + split_num - 1) // split_num + + for i in range(0, split_num): + start_bs = i * mc_bs + + end_bs = start_bs + mc_bs + end_bs = min(end_bs, real_bs) + + if start_bs >= end_bs: + continue + + start_token_id = forward_meta.cu_seqlens_q[start_bs].item() + end_token_id = forward_meta.cu_seqlens_q[end_bs].item() + + if start_token_id >= end_token_id: + continue + + res[i] = ForwardMeta( + ids_remove_padding=None, + rotary_embs=forward_meta.rotary_embs, + attn_backend=forward_meta.attn_backend, + caches=forward_meta.caches, + ) + + res[i].rotary_embs = forward_meta.rotary_embs[start_bs:end_bs] + + res[i].ids_remove_padding = forward_meta.ids_remove_padding[start_token_id:end_token_id] + res[i].batch_id_per_token = forward_meta.batch_id_per_token[start_token_id:end_token_id] - start_bs + + res[i].seq_lens_encoder = forward_meta.seq_lens_encoder[start_bs:end_bs] + res[i].seq_lens_decoder = forward_meta.seq_lens_decoder[start_bs:end_bs] + res[i].seq_lens_this_time = forward_meta.seq_lens_this_time[start_bs:end_bs] + + res[i].block_tables = forward_meta.block_tables[start_bs:end_bs] + + res[i].cu_seqlens_q = forward_meta.cu_seqlens_q[start_bs : end_bs + 1] - start_token_id + res[i].cu_seqlens_k = forward_meta.cu_seqlens_k[start_bs : end_bs + 1] - start_token_id + + for key in GLOBAL_ATTN_BUFFERS[i]: + setattr(res[i], key, GLOBAL_ATTN_BUFFERS[i][key]) + + if forward_meta.attn_mask_offsets is not None: + mask_num = forward_meta.attn_mask_offsets.shape[0] + token_num = forward_meta.ids_remove_padding.shape[0] + if mask_num == token_num * 2: + res[i].attn_mask_offsets = forward_meta.attn_mask_offsets[start_token_id * 2 : end_token_id * 2] + elif mask_num == token_num: + res[i].attn_mask_offsets = forward_meta.attn_mask_offsets[start_token_id:end_token_id] + else: + assert False, "Invalid attn_mask_offsets shape" + + # This is to adapt 5 + if hasattr(forward_meta, "hidden_states"): + res[i].hidden_states = forward_meta.hidden_states[start_token_id:end_token_id] + res[i].decode_states = forward_meta.decode_states[start_bs:end_bs] + + return res