diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 88280b079..a8348fed1 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -414,8 +414,8 @@ std::vector MoEDeepGEMMDePermute( const paddle::Tensor &topk_idx, const paddle::Tensor &topk_weights); void TextImageIndexOut(const paddle::Tensor &token_type_ids, - const paddle::Tensor &text_input, - const paddle::Tensor &image_input); + paddle::Tensor &text_input, + paddle::Tensor &image_input); void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input, paddle::Tensor &image_input, diff --git a/custom_ops/gpu_ops/get_padding_offset.cu b/custom_ops/gpu_ops/get_padding_offset.cu index 560310148..f36201389 100644 --- a/custom_ops/gpu_ops/get_padding_offset.cu +++ b/custom_ops/gpu_ops/get_padding_offset.cu @@ -132,7 +132,7 @@ std::vector GetPaddingOffsetInferDtype( } PD_BUILD_STATIC_OP(get_padding_offset) - .Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"}) + .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"}) .Outputs({"x_remove_padding", "batch_id_per_token", "cu_seqlens_q", diff --git a/custom_ops/gpu_ops/moe/moe_dispatch.cu b/custom_ops/gpu_ops/moe/moe_dispatch.cu index 8fa663c10..bc18ece45 100644 --- a/custom_ops/gpu_ops/moe/moe_dispatch.cu +++ b/custom_ops/gpu_ops/moe/moe_dispatch.cu @@ -36,6 +36,9 @@ void MoeDispatchKernel( paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token) { using namespace phi; + if (num_rows == 0){ + return; + } typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; @@ -185,6 +188,15 @@ std::vector MoeExpertDispatch( auto expert_idx_per_token = GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place); + if (token_rows == 0){ + return {permute_input, + tokens_expert_prefix_sum, + permute_indices_per_token, + topk_weight, + topk_idx, + expert_idx_per_token}; + } + switch (input_type) { case paddle::DataType::BFLOAT16: MoeDispatchKernel( diff --git a/custom_ops/gpu_ops/moe/moe_ffn.cu b/custom_ops/gpu_ops/moe/moe_ffn.cu index 7387246ab..c13590377 100644 --- a/custom_ops/gpu_ops/moe/moe_ffn.cu +++ b/custom_ops/gpu_ops/moe/moe_ffn.cu @@ -412,7 +412,9 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() (quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 : permute_input.dtype(); auto ffn_out = paddle::empty_like(permute_input, t_type); - + if(permute_input.numel() == 0){ + return ffn_out; + } switch (t_type) { case paddle::DataType::BFLOAT16: MoeFFNKernel(permute_input, diff --git a/custom_ops/gpu_ops/moe/moe_reduce.cu b/custom_ops/gpu_ops/moe/moe_reduce.cu index e8532d5cd..9a7bad147 100644 --- a/custom_ops/gpu_ops/moe/moe_reduce.cu +++ b/custom_ops/gpu_ops/moe/moe_reduce.cu @@ -59,6 +59,10 @@ paddle::Tensor MoeExpertReduceFunc( auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place); + if(num_rows == 0){ + return output; + } + switch (input_type) { case paddle::DataType::BFLOAT16: MoeReduceKernel( diff --git a/custom_ops/gpu_ops/text_image_gather_scatter.cu b/custom_ops/gpu_ops/text_image_gather_scatter.cu index 09fc07f96..59823af47 100644 --- a/custom_ops/gpu_ops/text_image_gather_scatter.cu +++ b/custom_ops/gpu_ops/text_image_gather_scatter.cu @@ -59,7 +59,7 @@ __global__ void text_image_scatter_kernel( constexpr int HalfVecSize = VecSize / 2; using T_Vec = AlignedVector; T_Vec input_ptr_vec; - T_Vec text_imgaes_vec; + T_Vec text_images_vec; int64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x; const int64_t step = blockDim.x * gridDim.x * VecSize; @@ -76,16 +76,20 @@ __global__ void text_image_scatter_kernel( Load(input_ptr + input_load_offset, &input_ptr_vec); #pragma unroll for(int vi = 0; vi < VecSize; ++vi) { - text_imgaes_vec[vi] = input_ptr_vec[vi]; + text_images_vec[vi] = input_ptr_vec[vi]; } if (token_type_ids_num == 0) { int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset; - Store(text_imgaes_vec, text_gather_ptr + text_load_offset); + Store(text_images_vec, text_gather_ptr + text_load_offset); + + } else if(token_type_ids_num == 1){ + int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset; + Store(text_images_vec, image_gather_ptr + image_load_offset); } else { - int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset; - Store(text_imgaes_vec, image_gather_ptr + image_load_offset); + // skip cuda graph padding value + continue; } } } @@ -120,9 +124,12 @@ __global__ void text_image_gather_kernel( int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset; Load(text_gather_ptr + text_load_offset, &text_imgaes_vec); - } else { + } else if (token_type_ids_num == 1){ int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset; Load(image_gather_ptr + image_load_offset, &text_imgaes_vec); + } else { + // skip cuda graph padding value + continue; } #pragma unroll @@ -154,7 +161,6 @@ void LaunchTextImageGatherScatter( const int64_t token_num = in_dims[0]; const int64_t hidden_size = in_dims[1]; - const int VecSize = 16 / sizeof(data_t); const int64_t tot_element_num = token_num * hidden_size; @@ -168,7 +174,7 @@ void LaunchTextImageGatherScatter( PADDLE_ENFORCE_GPU_SUCCESS(GetGridSize(tot_pack_num, block_size, kNumWaves, &grid_size_x)); dim3 grid_dim = dim3(grid_size_x, 1, 1); if (is_scatter) { - text_image_scatter_kernel<<>>( + text_image_scatter_kernel<<>>( reinterpret_cast(input.data()), reinterpret_cast(text_input.data()), reinterpret_cast(image_input.data()), @@ -179,7 +185,7 @@ void LaunchTextImageGatherScatter( tot_element_num ); } else { - text_image_gather_kernel<<>>( + text_image_gather_kernel<<>>( reinterpret_cast(input.data()), reinterpret_cast(text_input.data()), reinterpret_cast(image_input.data()), diff --git a/custom_ops/gpu_ops/text_image_index_out.cu b/custom_ops/gpu_ops/text_image_index_out.cu index b6d8941d6..7a5b44472 100644 --- a/custom_ops/gpu_ops/text_image_index_out.cu +++ b/custom_ops/gpu_ops/text_image_index_out.cu @@ -16,7 +16,7 @@ template __global__ void text_image_index_out_kernel( - int32_t* token_type_ids, + const int32_t* token_type_ids, int32_t* text_index, int32_t* image_index, const int64_t token_num @@ -31,23 +31,27 @@ __global__ void text_image_index_out_kernel( if (token_type_ids[i] == 0) { text_index[i] = text_count; text_count += 1; - } else { + } else if (token_type_ids[i] == 1) { image_index[i] = images_count; images_count += 1; + } else { + // skip cuda graph padding value + continue; } } } void TextImageIndexOut( const paddle::Tensor& token_type_ids, - const paddle::Tensor& text_index, - const paddle::Tensor& image_index) { + paddle::Tensor& text_index, + paddle::Tensor& image_index) { const int64_t token_num = token_type_ids.shape()[0]; - text_image_index_out_kernel<1><<<1, 1>>>( - const_cast(token_type_ids.data()), - const_cast(text_index.data()), - const_cast(image_index.data()), + auto stream = token_type_ids.stream(); + text_image_index_out_kernel<1><<<1, 1, 0, stream>>>( + token_type_ids.data(), + text_index.data(), + image_index.data(), token_num ); } diff --git a/fastdeploy/model_executor/graph_optimization/decorator.py b/fastdeploy/model_executor/graph_optimization/decorator.py index ef4f54f98..2937579b0 100644 --- a/fastdeploy/model_executor/graph_optimization/decorator.py +++ b/fastdeploy/model_executor/graph_optimization/decorator.py @@ -99,3 +99,35 @@ class GraphOptWrapper: fd_config.graph_opt_config.graph_opt_level < 1 ), "Currently unable to update weights in static graph mode." self.graph_opt_backend.clear_cudagraph_piecewise_backend() + + +def cuda_graph_buffers(buffer_meta): + def decorator(cls): + original_init = cls.__init__ + + def __init__(self, fd_config: FDConfig, **kwargs): + original_init(self, fd_config=fd_config, **kwargs) + + def _resolve_path(root, path: str): + cur = root + for p in path.split("."): + cur = getattr(cur, p) + return cur + + if not hasattr(self, "_mm_buffers"): + self._mm_buffers = {} + for name, meta in buffer_meta.items(): + shape = [_resolve_path(fd_config, s) if isinstance(s, str) else s for s in meta["shape"]] + dtype = meta["dtype"] + if "." in meta["dtype"]: + dtype = _resolve_path(fd_config, meta["dtype"]) + self._mm_buffers[name] = paddle.full( + shape=shape, + dtype=dtype, + fill_value=meta.get("value", 0), + ) + + cls.__init__ = __init__ + return cls + + return decorator diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index 34e2d9881..702b05d50 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -32,6 +32,7 @@ from paddleformers.utils.log import logger from fastdeploy.config import FDConfig from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.model_executor.graph_optimization.decorator import ( + cuda_graph_buffers, support_graph_optimization, ) from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding @@ -66,12 +67,23 @@ class Ernie4_5_VLAttention(Ernie4_5_Attention): @dataclass class VLMoEMeta: - image_input: Optional[paddle.Tensor] = None - text_input: Optional[paddle.Tensor] = None - text_index: Optional[paddle.Tensor] = None - image_index: Optional[paddle.Tensor] = None - token_type_ids: Optional[paddle.Tensor] = None - fake_hidden_states: Optional[paddle.Tensor] = None + image_input: paddle.Tensor + text_input: paddle.Tensor + text_index: paddle.Tensor + image_index: paddle.Tensor + token_type_ids: paddle.Tensor + image_token_num: paddle.Tensor + + def __str__(self): + return ( + f"VLMoEMeta(\n" + f" image_input: {self.image_input}, pointer: {self.image_input.data_ptr()}\n" + f" text_input: {self.text_input}, pointer: {self.text_input.data_ptr()}\n" + f" text_index: {self.text_index}, pointer: {self.text_index.data_ptr()}\n" + f" image_index: {self.image_index}, pointer: {self.image_index.data_ptr()}\n" + f" token_type_ids: {self.token_type_ids}, pointer: {self.token_type_ids.data_ptr()}\n\n" + f")" + ) class Ernie4_5_VLMoeBlock(nn.Layer): @@ -266,31 +278,26 @@ class Ernie4_5_VLMoE(nn.Layer): def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta): if self.num_shared_experts > 0: shared_experts_out = self.shared_experts(hidden_states) - if vl_moe_meta.image_input is not None: - text_image_gather_scatter( - hidden_states, - vl_moe_meta.text_input, - vl_moe_meta.image_input, - vl_moe_meta.token_type_ids, - vl_moe_meta.text_index, - vl_moe_meta.image_index, - True, - ) - text_out = self.text_fused_moe(vl_moe_meta.text_input) - image_out = self.image_fused_moe(vl_moe_meta.image_input) - text_image_gather_scatter( - hidden_states, - text_out, - image_out, - vl_moe_meta.token_type_ids, - vl_moe_meta.text_index, - vl_moe_meta.image_index, - False, - ) - else: - hidden_states = self.text_fused_moe(hidden_states) - if vl_moe_meta.fake_hidden_states is not None: - self.image_fused_moe(vl_moe_meta.fake_hidden_states) + text_image_gather_scatter( + hidden_states, + vl_moe_meta.text_input, + vl_moe_meta.image_input, + vl_moe_meta.token_type_ids, + vl_moe_meta.text_index, + vl_moe_meta.image_index, + True, + ) + text_out = self.text_fused_moe(vl_moe_meta.text_input) + image_out = self.image_fused_moe(vl_moe_meta.image_input) + text_image_gather_scatter( + hidden_states, + text_out, + image_out, + vl_moe_meta.token_type_ids, + vl_moe_meta.text_index, + vl_moe_meta.image_index, + False, + ) if self.num_shared_experts > 0: hidden_states += shared_experts_out if self.tp_size > 1: @@ -394,6 +401,40 @@ class Ernie4_5_VLDecoderLayer(nn.Layer): return hidden_states, residual +@cuda_graph_buffers( + { + "text_input": { + "shape": ["parallel_config.max_model_len", "model_config.hidden_size"], + "dtype": "model_config.dtype", + "value": 1, + }, + "image_input": { + "shape": ["parallel_config.max_model_len", "model_config.hidden_size"], + "dtype": "model_config.dtype", + "value": 1, + }, + "text_index": { + "shape": ["parallel_config.max_model_len"], + "dtype": "int32", + "value": 0, + }, + "image_index": { + "shape": ["parallel_config.max_model_len"], + "dtype": "int32", + "value": 0, + }, + "token_type_ids": { + "shape": ["parallel_config.max_model_len"], + "dtype": "int32", + "value": -1, + }, + "image_token_num": { + "shape": [1], + "dtype": "int64", + "value": 0, + }, + } +) @support_graph_optimization class Ernie4_5_VLModel(nn.Layer): def __init__( @@ -454,59 +495,46 @@ class Ernie4_5_VLModel(nn.Layer): logger.info(f"Start load layer {i}") self.layers[i].load_state_dict(state_dict) - def forward( + def prepare_vl_moe_meta( self, ids_remove_padding: paddle.Tensor, - image_features: Optional[paddle.Tensor], - forward_meta: ForwardMeta, - ): - text_input = None - image_input = None - text_index = None - image_index = None - fake_hidden_states = None + ) -> VLMoEMeta: - hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding) - token_num, hidden_dim = hidden_states.shape - - # ----------------------- image_mask = ids_remove_padding == self.im_patch_id + token_type_ids = image_mask.cast("int32") image_token_num = image_mask.sum() + token_num = ids_remove_padding.shape[0] text_token_num = paddle.maximum((token_num - image_token_num), paddle.ones([], dtype="int64")) - token_type_ids = image_mask.cast("int32") - if self.fd_config.parallel_config.use_ep is True: - fake_hidden_states = paddle.empty( - shape=[0, self.fd_config.model_config.hidden_size], - dtype=paddle.get_default_dtype(), - ) - text_input = fake_hidden_states + # The scenario requiring padding is CUDA graph, thus we only need to pad the maximum capture size. + self._mm_buffers["token_type_ids"][: self.fd_config.graph_opt_config.max_capture_size].fill_(-1) + self._mm_buffers["token_type_ids"].copy_(token_type_ids, False) + self._mm_buffers["image_token_num"].copy_(image_token_num, False) - if image_token_num > 0: - hidden_states[image_mask] = image_features.cast(self._dtype) - text_input = paddle.ones( - shape=[text_token_num, hidden_dim], - dtype=self._dtype, - ) - image_input = paddle.ones( - shape=[image_token_num, hidden_dim], - dtype=self._dtype, - ) - text_index = paddle.zeros_like(image_mask, dtype="int32") - image_index = paddle.zeros_like(image_mask, dtype="int32") - text_image_index_out(token_type_ids, text_index, image_index) - - vl_moe_meta = VLMoEMeta( - text_input=text_input, - image_input=image_input, - text_index=text_index, - image_index=image_index, - token_type_ids=token_type_ids, - fake_hidden_states=fake_hidden_states, + return VLMoEMeta( + text_input=self._mm_buffers["text_input"][:text_token_num], + image_input=self._mm_buffers["image_input"][:image_token_num], + text_index=self._mm_buffers["text_index"][:token_num], + image_index=self._mm_buffers["image_index"][:token_num], + token_type_ids=self._mm_buffers["token_type_ids"][:token_num], + image_token_num=self._mm_buffers["image_token_num"], ) - # ----------------------- + def get_input_embeddings(self, ids_remove_padding: paddle.Tensor) -> paddle.Tensor: + return self.embed_tokens(ids_remove_padding=ids_remove_padding) + + def forward( + self, + input_embeddings: paddle.Tensor, + ids_remove_padding: paddle.Tensor, + forward_meta: ForwardMeta, + vl_moe_meta: VLMoEMeta, + ): + text_image_index_out(vl_moe_meta.token_type_ids, vl_moe_meta.text_index, vl_moe_meta.image_index) + + hidden_states = input_embeddings residual = None + for i in range(self.num_layers): hidden_states, residual = self.layers[i]( forward_meta, @@ -517,17 +545,15 @@ class Ernie4_5_VLModel(nn.Layer): hidden_states = hidden_states + residual - # ----------------------- max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time, k=1) hidden_states = extract_text_token_output( max_seq_len, max_seq_len_index.cast("int32"), - image_token_num.cast("int32"), + vl_moe_meta.image_token_num.cast("int32"), forward_meta.seq_lens_this_time, forward_meta.cu_seqlens_q, hidden_states.cast("float32"), ).cast(self._dtype) - # ----------------------- out = self.norm(hidden_states) @@ -552,6 +578,12 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): # ernie self.ernie = Ernie4_5_VLModel(fd_config=fd_config) + # Persistent buffers for CUDA graphs. + self._input_embeddings = paddle.zeros( + [fd_config.parallel_config.max_model_len, fd_config.model_config.hidden_size], + dtype=fd_config.model_config.dtype, + ) + self.ori_vocab_size = fd_config.model_config.ori_vocab_size self.lm_head = ParallelLMHead( @@ -733,16 +765,33 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states) self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states) + def get_input_embeddings( + self, + ids_remove_padding: paddle.Tensor, + image_features: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + input_embeddings = self.ernie.get_input_embeddings(ids_remove_padding=ids_remove_padding) + if image_features is not None and len(image_features) > 0: + input_embeddings[ids_remove_padding == self.ernie.im_patch_id] = image_features.cast(self.ernie._dtype) + return input_embeddings + def forward( self, ids_remove_padding: paddle.Tensor, image_features: Optional[paddle.Tensor], forward_meta: ForwardMeta, ): + input_embeddings = self.get_input_embeddings( + ids_remove_padding=ids_remove_padding, image_features=image_features + ) + self._input_embeddings.copy_(input_embeddings, False) + vl_moe_meta = self.ernie.prepare_vl_moe_meta(ids_remove_padding=ids_remove_padding) + hidden_states = self.ernie( + input_embeddings=self._input_embeddings, ids_remove_padding=ids_remove_padding, - image_features=image_features, forward_meta=forward_meta, + vl_moe_meta=vl_moe_meta, ) return hidden_states