[Feature] support Two batch overlap, mainly used in Prefill (#5078)

This commit is contained in:
周周周
2025-12-05 14:58:50 +08:00
committed by GitHub
parent 1aefbef0b3
commit c83dc58105
4 changed files with 143 additions and 3 deletions

View File

@@ -135,6 +135,8 @@ class VocabParallelEmbedding(nn.Layer):
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
self.params_dtype: str = params_dtype
self.embedding_dim = embedding_dim
self.general = general # used for general Embedding
self.num_embeddings = num_embeddings
self.padding_size = padding_size
@@ -297,6 +299,8 @@ class VocabParallelEmbedding(nn.Layer):
Returns:
Tensor: Embedded tensor representation of the input IDs.
"""
if ids_remove_padding.shape[0] == 0:
return paddle.empty([0, self.embedding_dim], dtype=self.embeddings.weight.dtype)
if self.column_cut:
input_embedings = self.embeddings(ids_remove_padding)
inputs_embeds_temp = []

View File

@@ -505,6 +505,8 @@ class EPPrefillRunner(EPRunner):
EPPrefillRunner
"""
allocate_on_comm_stream = False
def __init__(
self,
top_k: int,
@@ -533,6 +535,12 @@ class EPPrefillRunner(EPRunner):
use_internode_ll_two_stage=use_internode_ll_two_stage,
)
def set_allocate_on_comm_stream(allocate_on_comm_stream: bool = False):
logger.info(
f"set allocate_on_comm_stream to {allocate_on_comm_stream}, this will force Prefill dispatch's output tensor is allocated on communication stream"
)
EPPrefillRunner.allocate_on_comm_stream = allocate_on_comm_stream
def dispatch(
self,
x: paddle.Tensor,
@@ -552,7 +560,13 @@ class EPPrefillRunner(EPRunner):
num_tokens_per_expert,
is_token_in_rank,
event,
) = buffer.get_dispatch_layout(topk_idx, self.num_experts, async_finish=self.ep_engine.async_finish)
) = buffer.get_dispatch_layout(
topk_idx,
self.num_experts,
previous_event=kwargs.get("previous_event", None),
allocate_on_comm_stream=EPPrefillRunner.allocate_on_comm_stream,
async_finish=self.ep_engine.async_finish,
)
x_scale_tensor = kwargs.get("x_scale_tensor", None)
dispatch_args = {
@@ -566,6 +580,7 @@ class EPPrefillRunner(EPRunner):
"topk_idx": topk_idx,
"topk_weights": topk_weights,
"expert_alignment": expert_alignment,
"allocate_on_comm_stream": EPPrefillRunner.allocate_on_comm_stream,
"previous_event": event,
}
return buffer.dispatch(**dispatch_args)
@@ -575,6 +590,7 @@ class EPPrefillRunner(EPRunner):
tmp_ffn_out: paddle.Tensor,
handle: tuple,
recv_topk_weights: paddle.Tensor,
event=None,
):
buffer = self.ep_engine.deepep_engine
if buffer is None:
@@ -586,6 +602,7 @@ class EPPrefillRunner(EPRunner):
"config": self.ep_engine.ep_config,
"async_finish": self.ep_engine.async_finish,
"topk_weights": recv_topk_weights,
"previous_event": event,
}
fused_moe_out, _, event = buffer.combine(**combine_args)
return fused_moe_out, event

View File

@@ -16,11 +16,13 @@
import paddle
from paddle import nn
from paddle.distributed.communication import deep_ep
from paddleformers.utils.log import logger
import fastdeploy
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func, deep_gemm
from fastdeploy.worker.tbo import let_another_thread_run
from .fused_moe_backend_base import MoEMethodBase
from .fused_moe_triton_backend import BlockWiseFP8MoEMethod
@@ -142,12 +144,17 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
Apply the EP prefill method.
"""
gate_out = gate(x.cast("float32"))
# 1. Select topk experts and weights
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
# 2. Dynamic compute blockwise quantization scales
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
x, self.quant_config.weight_block_size[0]
)
event = deep_ep.Buffer.capture()
let_another_thread_run()
# 3. EP Dispatch
(
recv_x,
@@ -157,8 +164,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
handle,
event,
) = self.ep_prefill_runner.dispatch(
x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128
x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128, previous_event=event
)
if self.ep_prefill_runner.ep_engine.async_finish:
event.current_stream_wait()
@@ -241,7 +249,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
tmp_ffn_out = paddle.cast(recv_x[0], paddle.bfloat16)
# 5. EP combine
tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights)
event = deep_ep.Buffer.capture()
let_another_thread_run()
tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights, event)
if self.ep_prefill_runner.ep_engine.async_finish:
event.current_stream_wait()

108
fastdeploy/worker/tbo.py Normal file
View File

@@ -0,0 +1,108 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import threading
from fastdeploy.model_executor.forward_meta import ForwardMeta
event0 = threading.Event()
event1 = threading.Event()
GLOBAL_THREAD_INFO = {}
GLOBAL_THREAD_INFO["thread0"] = [event0, event1]
GLOBAL_THREAD_INFO["thread1"] = [event1, event0]
GLOBAL_ATTN_BUFFERS = {}
def let_another_thread_run():
thread_name = threading.current_thread().name
if thread_name in GLOBAL_THREAD_INFO:
GLOBAL_THREAD_INFO[thread_name][1].set()
GLOBAL_THREAD_INFO[thread_name][0].wait()
GLOBAL_THREAD_INFO[thread_name][0].clear()
def split_batch_decoder_layers(forward_meta: ForwardMeta):
split_num = 2
real_bs = forward_meta.seq_lens_this_time.shape[0]
res = [forward_meta] * split_num
if real_bs < split_num or forward_meta.ids_remove_padding.shape[0] == 0:
return res
mc_bs = (real_bs + split_num - 1) // split_num
for i in range(0, split_num):
start_bs = i * mc_bs
end_bs = start_bs + mc_bs
end_bs = min(end_bs, real_bs)
if start_bs >= end_bs:
continue
start_token_id = forward_meta.cu_seqlens_q[start_bs].item()
end_token_id = forward_meta.cu_seqlens_q[end_bs].item()
if start_token_id >= end_token_id:
continue
res[i] = ForwardMeta(
ids_remove_padding=None,
rotary_embs=forward_meta.rotary_embs,
attn_backend=forward_meta.attn_backend,
caches=forward_meta.caches,
)
res[i].rotary_embs = forward_meta.rotary_embs[start_bs:end_bs]
res[i].ids_remove_padding = forward_meta.ids_remove_padding[start_token_id:end_token_id]
res[i].batch_id_per_token = forward_meta.batch_id_per_token[start_token_id:end_token_id] - start_bs
res[i].seq_lens_encoder = forward_meta.seq_lens_encoder[start_bs:end_bs]
res[i].seq_lens_decoder = forward_meta.seq_lens_decoder[start_bs:end_bs]
res[i].seq_lens_this_time = forward_meta.seq_lens_this_time[start_bs:end_bs]
res[i].block_tables = forward_meta.block_tables[start_bs:end_bs]
res[i].cu_seqlens_q = forward_meta.cu_seqlens_q[start_bs : end_bs + 1] - start_token_id
res[i].cu_seqlens_k = forward_meta.cu_seqlens_k[start_bs : end_bs + 1] - start_token_id
for key in GLOBAL_ATTN_BUFFERS[i]:
setattr(res[i], key, GLOBAL_ATTN_BUFFERS[i][key])
if forward_meta.attn_mask_offsets is not None:
mask_num = forward_meta.attn_mask_offsets.shape[0]
token_num = forward_meta.ids_remove_padding.shape[0]
if mask_num == token_num * 2:
res[i].attn_mask_offsets = forward_meta.attn_mask_offsets[start_token_id * 2 : end_token_id * 2]
elif mask_num == token_num:
res[i].attn_mask_offsets = forward_meta.attn_mask_offsets[start_token_id:end_token_id]
else:
assert False, "Invalid attn_mask_offsets shape"
# This is to adapt 5
if hasattr(forward_meta, "hidden_states"):
res[i].hidden_states = forward_meta.hidden_states[start_token_id:end_token_id]
res[i].decode_states = forward_meta.decode_states[start_bs:end_bs]
return res