diff --git a/custom_ops/metax_ops/fused_moe.cu b/custom_ops/metax_ops/fused_moe.cu
index fbfaa952d..c1cdf14e7 100644
--- a/custom_ops/metax_ops/fused_moe.cu
+++ b/custom_ops/metax_ops/fused_moe.cu
@@ -101,6 +101,10 @@ std::vector<paddle::Tensor> FusedExpertMoe(
     const auto input_type = input.dtype();
     auto output = paddle::empty_like(input);
 
+    if (output.dims()[0] == 0) {
+        return {output};
+    }
+
     switch (input_type) {
         case paddle::DataType::BFLOAT16:
             FusedMoeKernel
 MoeExpertDispatch(
     auto permute_indices_per_token = GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
+    if (token_rows == 0) {
+        return {permute_input,
+                tokens_expert_prefix_sum,
+                permute_indices_per_token,
+                top_k_weight,
+                top_k_indices};
+    }
+
     switch (input_type) {
         case paddle::DataType::BFLOAT16:
             MoeDispatchKernel(input,
diff --git a/custom_ops/metax_ops/moe_ffn.cu b/custom_ops/metax_ops/moe_ffn.cu
index bc268f769..b390f4e87 100644
--- a/custom_ops/metax_ops/moe_ffn.cu
+++ b/custom_ops/metax_ops/moe_ffn.cu
@@ -114,6 +114,10 @@ std::vector<paddle::Tensor> MoeExpertFFN(
     const auto input_type = permute_input.dtype();
     auto ffn_out = paddle::empty_like(permute_input);
 
+    if (permute_input.numel() == 0) {
+        return {ffn_out};
+    }
+
     switch (input_type) {
         case paddle::DataType::BFLOAT16:
             McMoeFFNKernel
+            self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs)
+
+        if len(rope_3d_position_ids["position_ids_idx"]) > 0:
+            packed_position_ids = paddle.to_tensor(
+                np.concatenate(rope_3d_position_ids["position_ids_lst"]), dtype="int64"
+            )
+            rope_3d_lst = self.prepare_rope3d(
+                packed_position_ids,
+                rope_3d_position_ids["max_tokens_lst"],
+                rope_3d_position_ids["position_ids_offset"],
+            )
+            for i, idx in enumerate(rope_3d_position_ids["position_ids_idx"]):
+                self.share_inputs["rope_emb"][idx : idx + 1, :] = rope_3d_lst[i]
+
         if has_prefill_task or has_decode_task:
             self.share_inputs["not_need_stop"][0] = True
             self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
@@ -520,7 +553,7 @@ class MetaxModelRunner(ModelRunnerBase):
                 position_ids = paddle.to_tensor(
                     request.multimodal_inputs["position_ids"],
                     dtype="int64",
-                ).unsqueeze([0])
+                )
             else:
                 position_ids = None
             token_chunk_size = inputs["input_ids"].shape[1]
@@ -557,8 +590,8 @@ class MetaxModelRunner(ModelRunnerBase):
             if self.enable_mm:
                 self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
-                    position_ids, request.get("max_tokens", 2048)
-                )
+                    position_ids, [request.get("max_tokens", 2048)], [0, position_ids.shape[0]]
+                )[0]
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
 
             if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
@@ -970,10 +1003,7 @@ class MetaxModelRunner(ModelRunnerBase):
         if self.enable_mm:
             head_dim = self.model_config.head_dim
-            if "qwen" in self.model_config.model_type:  # neox style = True
-                rope_head_dim = head_dim
-            else:  # neox style = False
-                rope_head_dim = head_dim // 2
+            rope_head_dim = head_dim // 2
             self.share_inputs["rope_emb"] = paddle.full(
                 shape=[
@@ -1638,67 +1668,90 @@ class MetaxModelRunner(ModelRunnerBase):
         time_before_capture = time.perf_counter()
         expected_decode_len = 1
         capture_sizes = self.cudagraph_capture_sizes.copy()
-        if self.fd_config.graph_opt_config.cudagraph_only_prefill:
-            for num_tokens in sorted(capture_sizes, reverse=True):
-                self._dummy_run(
-                    num_tokens=num_tokens,
-                    batch_size=self.scheduler_config.max_num_seqs,
-                    in_capturing=True,
-                    expected_decode_len=expected_decode_len,
-                    capture_prefill=True,
-                )
-                logger.info(
-                    f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
-                )
-        elif self.speculative_decoding and self.speculative_method == "mtp":
-            # Capture Target Model without bsz 1
-            for batch_size in sorted(capture_sizes, reverse=True):
-                if batch_size == 1:
-                    logger.info("Skip token_num = 1, when capture target model for mtp")
-                else:
-                    assert batch_size % 2 == 0
-                    self._dummy_run(
-                        num_tokens=self.scheduler_config.max_num_batched_tokens,
-                        batch_size=int(batch_size / 2),
-                        in_capturing=True,
-                        expected_decode_len=1,
-                    )
-                    logger.info(f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}")
-            # Capture Draft Model without bsz 1
-            # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
-            for batch_size in sorted(capture_sizes, reverse=True):
-                if batch_size == 1:
-                    logger.info("Skip token_num = 1, when capture Draft model for mtp")
-                else:
-                    assert batch_size % 2 == 0
-                    self._dummy_run(
-                        num_tokens=self.scheduler_config.max_num_batched_tokens,
-                        batch_size=int(batch_size / 2),
-                        in_capturing=True,
-                        expected_decode_len=3,
-                        accept_all_drafts=True,
-                    )
-                    logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
-            # Capture Draft Model with bsz 1
-            if 1 in capture_sizes:
-                self._dummy_run(
-                    num_tokens=self.scheduler_config.max_num_batched_tokens,
-                    batch_size=int(1),
-                    in_capturing=True,
-                    expected_decode_len=3,
-                    accept_all_drafts=False,
-                )
-                logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
-        else:
-            for batch_size in sorted(capture_sizes, reverse=True):
-                self._dummy_run(
-                    num_tokens=self.scheduler_config.max_num_batched_tokens,
-                    batch_size=batch_size,
-                    in_capturing=True,
-                    expected_decode_len=expected_decode_len,
-                )
-                logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}")
+        try:
+            if self.fd_config.graph_opt_config.cudagraph_only_prefill:
+                for num_tokens in sorted(capture_sizes, reverse=True):
+                    self._dummy_run(
+                        num_tokens=num_tokens,
+                        batch_size=self.scheduler_config.max_num_seqs,
+                        in_capturing=True,
+                        expected_decode_len=expected_decode_len,
+                        capture_prefill=True,
+                    )
+                    logger.info(
+                        f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
+                    )
+            elif self.speculative_decoding and self.speculative_method == "mtp":
+                # Capture Target Model without bsz 1
+                for batch_size in sorted(capture_sizes, reverse=True):
+                    if batch_size == 1:
+                        logger.info("Skip token_num = 1, when capture target model for mtp")
+                    else:
+                        assert batch_size % 2 == 0
+                        self._dummy_run(
+                            num_tokens=self.scheduler_config.max_num_batched_tokens,
+                            batch_size=int(batch_size / 2),
+                            in_capturing=True,
+                            expected_decode_len=1,
+                        )
+                        logger.info(
+                            f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}"
+                        )
+                # Capture Draft Model without bsz 1
+                # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
+                for batch_size in sorted(capture_sizes, reverse=True):
+                    if batch_size == 1:
+                        logger.info("Skip token_num = 1, when capture Draft model for mtp")
+                    else:
+                        assert batch_size % 2 == 0
+                        self._dummy_run(
+                            num_tokens=self.scheduler_config.max_num_batched_tokens,
+                            batch_size=int(batch_size / 2),
+                            in_capturing=True,
+                            expected_decode_len=3,
+                            accept_all_drafts=True,
+                        )
+                        logger.info(
+                            f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
+                        )
+                # Capture Draft Model with bsz 1
+                if 1 in capture_sizes:
+                    self._dummy_run(
+                        num_tokens=self.scheduler_config.max_num_batched_tokens,
+                        batch_size=int(1),
+                        in_capturing=True,
+                        expected_decode_len=3,
+                        accept_all_drafts=False,
+                    )
+                    logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
+            else:
+                for batch_size in sorted(capture_sizes, reverse=True):
+                    self._dummy_run(
+                        num_tokens=self.scheduler_config.max_num_batched_tokens,
+                        batch_size=batch_size,
+                        in_capturing=True,
+                        expected_decode_len=expected_decode_len,
+                    )
+                    logger.info(
+                        f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}"
+                    )
+        except RuntimeError as e:
+            if "out of memory" in str(e):
+                raise RuntimeError(
+                    "CUDA out of memory occurred when warming up CUDAGraph "
+                    f"with the capture sizes {capture_sizes}. Please try "
+                    "lowering `max_num_seqs` or `gpu_memory_utilization` when "
+                    "initializing the engine."
+                ) from e
+            if "CUDA error(700)" in str(e):
+                raise RuntimeError(
+                    "CUDA error(700), an illegal memory access was encountered, "
+                    "when warming up CUDAGraph. Please try to set the startup parameter: "
+                    "--graph-optimization-config '{\"use_cudagraph\": false}' to close CUDAGraph"
+                ) from e
+            else:
+                raise e
 
         time_after_capture = time.perf_counter()
         logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
@@ -2176,7 +2229,7 @@ class MetaxModelRunner(ModelRunnerBase):
                 grid_thw = None
 
             if one["position_ids"] is not None:
-                position_ids = paddle.to_tensor(one["position_ids"], dtype="int64").unsqueeze([0])
+                position_ids = paddle.to_tensor(one["position_ids"], dtype="int64")
            else:
                 position_ids = None
@@ -2235,24 +2288,20 @@ class MetaxModelRunner(ModelRunnerBase):
             raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
 
     @paddle.no_grad()
-    def prepare_rope3d(self, position_ids: paddle.Tensor, max_len: int) -> paddle.Tensor:
+    def prepare_rope3d(
+        self, position_ids: paddle.Tensor, max_len_lst: list[int], cumsum_seqlens: list[int]
+    ) -> list[paddle.Tensor]:
         """prepare_rope3d"""
-        prefix_max_position_ids = paddle.max(position_ids) + 1
-        dec_pos_ids = paddle.tile(
-            paddle.arange(max_len, dtype="int64").unsqueeze(0).unsqueeze(-1),
-            [1, 1, 3],
-        )
-        dec_pos_ids = dec_pos_ids + prefix_max_position_ids
-        position_ids_3d_real = paddle.concat([position_ids, dec_pos_ids], axis=1)
-
-        rope_emb = get_rope_3d(
-            position_ids=position_ids_3d_real,
+        rope_emb_lst = get_rope_3d(
+            position_ids=position_ids,
             rotary_dim=self.model_config.head_dim,
             partial_rotary_factor=1.0,
            base=self.model_config.rope_theta,
             max_position=self.model_config.max_model_len,
             freq_allocation=getattr(self.model_config, "freq_allocation", 20),
             model_type=self.model_config.model_type,
+            max_len_lst=max_len_lst,
+            cumsum_seqlens=cumsum_seqlens,
         )
-        return rope_emb
+        return rope_emb_lst
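
Note on the rope-3D batching above: the model-runner changes stop calling prepare_rope3d once per request and instead pack every request's 3-D position ids into one tensor, passing a per-request max_tokens list and a prefix sum of sequence lengths so the helper can slice the pack apart and return one rope embedding per request, which the caller then scatters back into share_inputs["rope_emb"]. The sketch below is illustrative only and not part of the patch: the request list, array shapes, and dict construction are invented; only the dictionary keys (position_ids_lst, position_ids_idx, max_tokens_lst, position_ids_offset) mirror the ones used in the diff.

import numpy as np

# Hypothetical per-request inputs; idx is the destination slot in share_inputs["rope_emb"].
requests = [
    {"idx": 0, "position_ids": np.zeros((7, 3), dtype="int64"), "max_tokens": 2048},
    {"idx": 2, "position_ids": np.ones((5, 3), dtype="int64"), "max_tokens": 1024},
]

rope_3d_position_ids = {
    "position_ids_lst": [],      # per-request [seq_len, 3] position ids to concatenate
    "position_ids_idx": [],      # destination batch slot for each request
    "max_tokens_lst": [],        # decode budget per request
    "position_ids_offset": [0],  # prefix sum of sequence lengths (the cumsum_seqlens argument)
}
for req in requests:
    rope_3d_position_ids["position_ids_lst"].append(req["position_ids"])
    rope_3d_position_ids["position_ids_idx"].append(req["idx"])
    rope_3d_position_ids["max_tokens_lst"].append(req["max_tokens"])
    rope_3d_position_ids["position_ids_offset"].append(
        rope_3d_position_ids["position_ids_offset"][-1] + req["position_ids"].shape[0]
    )

packed = np.concatenate(rope_3d_position_ids["position_ids_lst"])
print(packed.shape)                                 # (12, 3)
print(rope_3d_position_ids["position_ids_offset"])  # [0, 7, 12]
# prepare_rope3d(packed, max_tokens_lst, position_ids_offset) would then return one rope
# embedding per request, written back to the slots listed in position_ids_idx.

The single-request decode path in the diff, prepare_rope3d(position_ids, [request.get("max_tokens", 2048)], [0, position_ids.shape[0]])[0], is the degenerate one-sequence case of the same convention.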