[Metax] adapt cutlass moe for ernie-vl (#4685)

This commit is contained in:
Neil Zhu
2025-11-03 17:44:27 +08:00
committed by GitHub
parent 69c2f3cda1
commit c95d0740ec
6 changed files with 174 additions and 101 deletions

View File

@@ -101,6 +101,10 @@ std::vector<paddle::Tensor> FusedExpertMoe(
const auto input_type = input.dtype();
auto output = paddle::empty_like(input);
if (output.dims()[0] == 0) {
return {output};
}
switch (input_type) {
case paddle::DataType::BFLOAT16:
FusedMoeKernel<paddle::DataType::BFLOAT16,

View File

@@ -178,6 +178,14 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
auto permute_indices_per_token =
GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
if (token_rows == 0) {
return {permute_input,
tokens_expert_prefix_sum,
permute_indices_per_token,
top_k_weight,
top_k_indices};
}
switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,

View File

@@ -114,6 +114,10 @@ std::vector<paddle::Tensor> MoeExpertFFN(
const auto input_type = permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input);
if (permute_input.numel() == 0) {
return {ffn_out};
}
switch (input_type) {
case paddle::DataType::BFLOAT16:
McMoeFFNKernel<paddle::DataType::BFLOAT16,

View File

@@ -614,6 +614,8 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
"gpu_ops/text_image_gather_scatter.cu",
"gpu_ops/text_image_index_out.cu",
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/limit_thinking_content_length_v1.cu",
"gpu_ops/limit_thinking_content_length_v2.cu",
"gpu_ops/append_attn/mla_cache_kernel.cu",
"gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu",
"gpu_ops/moe/tritonmoe_preprocess.cu",

View File

@@ -50,8 +50,12 @@ elif current_platform.is_dcu():
elif current_platform.is_maca():
from fastdeploy.model_executor.ops.gpu import (
get_padding_offset,
limit_thinking_content_length_v1,
limit_thinking_content_length_v2,
save_output,
set_stop_value_multi_ends,
speculate_limit_thinking_content_length_v1,
speculate_limit_thinking_content_length_v2,
step_paddle,
update_inputs,
update_inputs_v1,
@@ -810,7 +814,9 @@ def rebuild_padding(
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
first_token_out,
max_input_length,
enable_logprob,
)
else:
raise RuntimeError("Not supported platform")

View File

@@ -129,7 +129,7 @@ class MetaxModelRunner(ModelRunnerBase):
# self.kv_caches: list[paddle.Tensor] = []
# CUDA Graph
self.use_cudagraph = self.graph_opt_config.use_cudagraph
self.use_cudagraph = False
self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
self.cudagraph_only_prefill = self.graph_opt_config.cudagraph_only_prefill
@@ -285,6 +285,14 @@ class MetaxModelRunner(ModelRunnerBase):
req_len = len(req_dicts)
has_prefill_task = False
has_decode_task = False
multi_vision_inputs = {"images_lst": [], "grid_thw_lst": [], "vit_position_ids_lst": [], "cu_seqlens": [0]}
rope_3d_position_ids = {
"position_ids_idx": [],
"position_ids_lst": [],
"position_ids_offset": [0],
"max_tokens_lst": [],
}
for i in range(req_len):
request = req_dicts[i]
idx = request.idx
@@ -295,39 +303,49 @@ class MetaxModelRunner(ModelRunnerBase):
if self.enable_mm:
inputs = request.multimodal_inputs
if request.with_image:
vision_inputs = {}
vision_inputs["input_ids"] = paddle.to_tensor(
inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["token_type_ids"] = paddle.to_tensor(
inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["image_type_ids"] = paddle.to_tensor(
inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
dtype=paddle.int64,
)
vision_inputs["images"] = paddle.to_tensor(
inputs["images"][request.image_start : request.image_end],
dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
)
vision_inputs["grid_thw"] = paddle.to_tensor(
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
)
self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
if envs.FD_ENABLE_MAX_PREFILL:
multi_vision_inputs["images_lst"].append(
inputs["images"][request.image_start : request.image_end].cuda()
)
multi_vision_inputs["grid_thw_lst"].extend(
inputs["grid_thw"][request.num_image_start : request.num_image_end]
)
multi_vision_inputs["cu_seqlens"].extend(
inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
)
multi_vision_inputs["vit_position_ids_lst"].extend(
inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
)
else:
vision_inputs = {}
vision_inputs["input_ids"] = paddle.to_tensor(
inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["token_type_ids"] = paddle.to_tensor(
inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["image_type_ids"] = paddle.to_tensor(
inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
dtype=paddle.int64,
)
vision_inputs["images"] = paddle.to_tensor(
inputs["images"][request.image_start : request.image_end],
dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
)
vision_inputs["grid_thw"] = paddle.to_tensor(
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
)
self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
else:
self.share_inputs["image_features"] = None
if inputs["position_ids"] is not None:
position_ids = paddle.to_tensor(
request.multimodal_inputs["position_ids"],
dtype="int64",
).unsqueeze([0])
else:
position_ids = None
self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
position_ids, request.get("max_tokens", 2048)
position_ids = request.multimodal_inputs["position_ids"]
rope_3d_position_ids["position_ids_idx"].append(idx)
rope_3d_position_ids["position_ids_lst"].append(position_ids)
rope_3d_position_ids["position_ids_offset"].append(
position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
)
rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
# Enable thinking
@@ -441,6 +459,21 @@ class MetaxModelRunner(ModelRunnerBase):
else:
self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
if len(multi_vision_inputs["images_lst"]) > 0:
self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs)
if len(rope_3d_position_ids["position_ids_idx"]) > 0:
packed_position_ids = paddle.to_tensor(
np.concatenate(rope_3d_position_ids["position_ids_lst"]), dtype="int64"
)
rope_3d_lst = self.prepare_rope3d(
packed_position_ids,
rope_3d_position_ids["max_tokens_lst"],
rope_3d_position_ids["position_ids_offset"],
)
for i, idx in enumerate(rope_3d_position_ids["position_ids_idx"]):
self.share_inputs["rope_emb"][idx : idx + 1, :] = rope_3d_lst[i]
if has_prefill_task or has_decode_task:
self.share_inputs["not_need_stop"][0] = True
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
@@ -520,7 +553,7 @@ class MetaxModelRunner(ModelRunnerBase):
position_ids = paddle.to_tensor(
request.multimodal_inputs["position_ids"],
dtype="int64",
).unsqueeze([0])
)
else:
position_ids = None
token_chunk_size = inputs["input_ids"].shape[1]
@@ -557,8 +590,8 @@ class MetaxModelRunner(ModelRunnerBase):
if self.enable_mm:
self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
position_ids, request.get("max_tokens", 2048)
)
position_ids, [request.get("max_tokens", 2048)], [0, position_ids.shape[0]]
)[0]
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
@@ -970,10 +1003,7 @@ class MetaxModelRunner(ModelRunnerBase):
if self.enable_mm:
head_dim = self.model_config.head_dim
if "qwen" in self.model_config.model_type: # neox style = True
rope_head_dim = head_dim
else: # neox style = False
rope_head_dim = head_dim // 2
rope_head_dim = head_dim // 2
self.share_inputs["rope_emb"] = paddle.full(
shape=[
@@ -1638,67 +1668,90 @@ class MetaxModelRunner(ModelRunnerBase):
time_before_capture = time.perf_counter()
expected_decode_len = 1
capture_sizes = self.cudagraph_capture_sizes.copy()
if self.fd_config.graph_opt_config.cudagraph_only_prefill:
for num_tokens in sorted(capture_sizes, reverse=True):
self._dummy_run(
num_tokens=num_tokens,
batch_size=self.scheduler_config.max_num_seqs,
in_capturing=True,
expected_decode_len=expected_decode_len,
capture_prefill=True,
)
logger.info(
f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
)
elif self.speculative_decoding and self.speculative_method == "mtp":
# Capture Target Model without bsz 1
for batch_size in sorted(capture_sizes, reverse=True):
if batch_size == 1:
logger.info("Skip token_num = 1, when capture target model for mtp")
else:
assert batch_size % 2 == 0
try:
if self.fd_config.graph_opt_config.cudagraph_only_prefill:
for num_tokens in sorted(capture_sizes, reverse=True):
self._dummy_run(
num_tokens=self.scheduler_config.max_num_batched_tokens,
batch_size=int(batch_size / 2),
num_tokens=num_tokens,
batch_size=self.scheduler_config.max_num_seqs,
in_capturing=True,
expected_decode_len=1,
expected_decode_len=expected_decode_len,
capture_prefill=True,
)
logger.info(f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}")
# Capture Draft Model without bsz 1
# NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
for batch_size in sorted(capture_sizes, reverse=True):
if batch_size == 1:
logger.info("Skip token_num = 1, when capture Draft model for mtp")
else:
assert batch_size % 2 == 0
logger.info(
f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
)
elif self.speculative_decoding and self.speculative_method == "mtp":
# Capture Target Model without bsz 1
for batch_size in sorted(capture_sizes, reverse=True):
if batch_size == 1:
logger.info("Skip token_num = 1, when capture target model for mtp")
else:
assert batch_size % 2 == 0
self._dummy_run(
num_tokens=self.scheduler_config.max_num_batched_tokens,
batch_size=int(batch_size / 2),
in_capturing=True,
expected_decode_len=1,
)
logger.info(
f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}"
)
# Capture Draft Model without bsz 1
# NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
for batch_size in sorted(capture_sizes, reverse=True):
if batch_size == 1:
logger.info("Skip token_num = 1, when capture Draft model for mtp")
else:
assert batch_size % 2 == 0
self._dummy_run(
num_tokens=self.scheduler_config.max_num_batched_tokens,
batch_size=int(batch_size / 2),
in_capturing=True,
expected_decode_len=3,
accept_all_drafts=True,
)
logger.info(
f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
)
# Capture Draft Model with bsz 1
if 1 in capture_sizes:
self._dummy_run(
num_tokens=self.scheduler_config.max_num_batched_tokens,
batch_size=int(batch_size / 2),
batch_size=int(1),
in_capturing=True,
expected_decode_len=3,
accept_all_drafts=True,
accept_all_drafts=False,
)
logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
# Capture Draft Model with bsz 1
if 1 in capture_sizes:
self._dummy_run(
num_tokens=self.scheduler_config.max_num_batched_tokens,
batch_size=int(1),
in_capturing=True,
expected_decode_len=3,
accept_all_drafts=False,
)
logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
else:
for batch_size in sorted(capture_sizes, reverse=True):
self._dummy_run(
num_tokens=self.scheduler_config.max_num_batched_tokens,
batch_size=batch_size,
in_capturing=True,
expected_decode_len=expected_decode_len,
)
logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}")
else:
for batch_size in sorted(capture_sizes, reverse=True):
self._dummy_run(
num_tokens=self.scheduler_config.max_num_batched_tokens,
batch_size=batch_size,
in_capturing=True,
expected_decode_len=expected_decode_len,
)
logger.info(
f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}"
)
except RuntimeError as e:
if "out of memory" in str(e):
raise RuntimeError(
"CUDA out of memory occurred when warming up CUDAGraph "
f"with the capture sizes {capture_sizes}. Please try "
"lowering `max_num_seqs` or `gpu_memory_utilization` when "
"initializing the engine."
) from e
if "CUDA error(700)" in str(e):
raise RuntimeError(
"CUDA error(700), an illegal memory access was encountered, "
"when warming up CUDAGraph. Please try to set the startup parameter: "
"--graph-optimization-config '{\"use_cudagraph\": false}' to close CUDAGraph"
) from e
else:
raise e
time_after_capture = time.perf_counter()
logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
@@ -2176,7 +2229,7 @@ class MetaxModelRunner(ModelRunnerBase):
grid_thw = None
if one["position_ids"] is not None:
position_ids = paddle.to_tensor(one["position_ids"], dtype="int64").unsqueeze([0])
position_ids = paddle.to_tensor(one["position_ids"], dtype="int64")
else:
position_ids = None
@@ -2235,24 +2288,20 @@ class MetaxModelRunner(ModelRunnerBase):
raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
@paddle.no_grad()
def prepare_rope3d(self, position_ids: paddle.Tensor, max_len: int) -> paddle.Tensor:
def prepare_rope3d(
self, position_ids: paddle.Tensor, max_len_lst: list[int], cumsum_seqlens: list[int]
) -> list[paddle.Tensor]:
"""prepare_rope3d"""
prefix_max_position_ids = paddle.max(position_ids) + 1
dec_pos_ids = paddle.tile(
paddle.arange(max_len, dtype="int64").unsqueeze(0).unsqueeze(-1),
[1, 1, 3],
)
dec_pos_ids = dec_pos_ids + prefix_max_position_ids
position_ids_3d_real = paddle.concat([position_ids, dec_pos_ids], axis=1)
rope_emb = get_rope_3d(
position_ids=position_ids_3d_real,
rope_emb_lst = get_rope_3d(
position_ids=position_ids,
rotary_dim=self.model_config.head_dim,
partial_rotary_factor=1.0,
base=self.model_config.rope_theta,
max_position=self.model_config.max_model_len,
freq_allocation=getattr(self.model_config, "freq_allocation", 20),
model_type=self.model_config.model_type,
max_len_lst=max_len_lst,
cumsum_seqlens=cumsum_seqlens,
)
return rope_emb
return rope_emb_lst