[xpu] support mtp for xpu(mix) (#5274)

* [XPU] support kernel for mtp(base)

* [XPU] support kernel for mtp(base)

* format

* format

* format

* fix gather next token

* fix step && add test

* fix

* mv pre/post process

* add adjust batch / gather next token for mtp

* fix code style

* fix mtp kernel name

* fix mtp kernel test

* mv xpu pre/post process

* mv xpu pre/post process

* [xpu] support mtp

* fix code style
This commit is contained in:
cmcamdy
2025-12-01 11:03:14 +08:00
committed by GitHub
parent 8aec3acc8c
commit 9f4977eb74
8 changed files with 691 additions and 106 deletions

View File

@@ -34,21 +34,39 @@ from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import MTPSampler
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.models import ModelForCasualLM
from fastdeploy.model_executor.ops.gpu import (
draft_model_postprocess,
draft_model_preprocess,
draft_model_update,
eagle_get_hidden_states,
eagle_get_self_hidden_states,
hybrid_mtp_ngram,
mtp_save_first_token,
mtp_step_paddle,
share_external_data,
speculate_get_logits,
speculate_save_output_topk,
update_attn_mask_offsets,
)
from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding
from fastdeploy.platforms import current_platform
if current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import (
draft_model_postprocess,
draft_model_preprocess,
draft_model_update,
eagle_get_hidden_states,
eagle_get_self_hidden_states,
mtp_save_first_token,
mtp_step_paddle,
share_external_data,
)
from fastdeploy.model_executor.xpu_pre_and_post_process import (
xpu_pre_process,
xpu_process_output,
)
else:
from fastdeploy.model_executor.ops.gpu import (
draft_model_postprocess,
draft_model_preprocess,
draft_model_update,
eagle_get_hidden_states,
eagle_get_self_hidden_states,
hybrid_mtp_ngram,
mtp_save_first_token,
mtp_step_paddle,
share_external_data,
speculate_get_logits,
speculate_save_output_topk,
update_attn_mask_offsets,
)
from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding
from .base import Proposer
@@ -79,6 +97,15 @@ class MTPProposer(Proposer):
# [mixed, prefill, decoder]
self.role = self.scheduler_config.splitwise_role
if current_platform.is_xpu():
self.role = "mixed"
if current_platform.is_xpu():
self._propose = self._propose_xpu
elif current_platform.is_cuda():
self._propose = self._propose_cuda
else:
raise RuntimeError("Unsupported platform.")
self.sampler = MTPSampler(fd_config)
self._init_model_inputs()
@@ -92,7 +119,7 @@ class MTPProposer(Proposer):
self._initialize_attn_backend()
# Forward meta store the global meta information of the forward
self.forward_meta: ForwardMeta = None
self.forward_meta = None
def _update_mtp_config(self, main_model):
"""
@@ -166,7 +193,7 @@ class MTPProposer(Proposer):
and hasattr(self.quant_config, "kv_cache_quant_type")
and self.quant_config.kv_cache_quant_type is not None
):
cache_type = "uint8"
cache_type = self._get_cache_type()
kv_cache_quant_type = self.quant_config.kv_cache_quant_type
# Get kv cache shape
@@ -220,7 +247,7 @@ class MTPProposer(Proposer):
self.model_inputs["caches"] = list(self.cache_kvs.values())
for value in self.cache_kvs.values():
del value
paddle.device.cuda.empty_cache()
self._empty_cache()
def _initialize_attn_backend(
self,
@@ -245,9 +272,14 @@ class MTPProposer(Proposer):
self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
self.target_model_inputs["decoder_tile_ids_per_batch"]
)
self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
self.target_model_inputs["decoder_num_blocks_cpu"]
).pin_memory()
if current_platform.is_xpu():
self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
self.target_model_inputs["decoder_num_blocks_cpu"]
).cpu()
else:
self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
self.target_model_inputs["decoder_num_blocks_cpu"]
).pin_memory()
self.model_inputs["decoder_num_blocks_device"] = paddle.zeros_like(
self.target_model_inputs["decoder_num_blocks_device"]
)
@@ -669,6 +701,36 @@ class MTPProposer(Proposer):
self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.draft_model_use_cudagraph
def _initialize_forward_meta_xpu(self):
self.forward_meta.decoder_batch_ids = (self.model_inputs["decoder_batch_ids"],)
self.forward_meta.decoder_tile_ids_per_batch = (self.model_inputs["decoder_tile_ids_per_batch"],)
self.forward_meta.decoder_num_blocks_cpu = (self.model_inputs["decoder_num_blocks_cpu"],)
self.forward_meta.decoder_num_blocks_device = (self.model_inputs["decoder_num_blocks_device"],)
self.forward_meta.decoder_chunk_size_device = (self.model_inputs["decoder_chunk_size_device"],)
self.forward_meta.max_len_tensor_cpu = (self.model_inputs["max_len_tensor_cpu"],)
self.forward_meta.encoder_batch_ids = (self.model_inputs["encoder_batch_ids"],)
self.forward_meta.encoder_tile_ids_per_batch = (self.model_inputs["encoder_tile_ids_per_batch"],)
self.forward_meta.encoder_num_blocks_x_cpu = (self.model_inputs["encoder_num_blocks_x_cpu"],)
self.forward_meta.kv_batch_ids = (self.model_inputs["kv_batch_ids"],)
self.forward_meta.kv_tile_ids_per_batch = (self.model_inputs["kv_tile_ids_per_batch"],)
self.forward_meta.kv_num_blocks_x_cpu = (self.model_inputs["kv_num_blocks_x_cpu"],)
self.forward_meta.pos_emb_type = "NORMAL"
self.forward_meta.attn_backend = self.attn_backends[0]
# Initialzie attention meta data
for attn_backend in self.attn_backends:
attn_backend.init_attention_metadata(self.forward_meta)
# Mix ep in single node
if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
only_decode_batch_list = []
prefill_exists = self.exist_prefill()
paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
only_decode_batch = all(only_decode_batch_list)
self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
def exist_prefill(self):
"""
check whether prefill stage exist
@@ -682,7 +744,7 @@ class MTPProposer(Proposer):
"""
Prepare MTP inputs
"""
use_v1_cache_scheduler = envs.ENABLE_V1_KVCACHE_SCHEDULER
use_v1_cache_scheduler = bool(envs.ENABLE_V1_KVCACHE_SCHEDULER)
draft_model_preprocess(
self.model_inputs["draft_tokens"],
self.model_inputs["input_ids"],
@@ -767,7 +829,7 @@ class MTPProposer(Proposer):
self.model_inputs["step_idx"],
)
def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
"""
Main process for MTP inference.
Args:
@@ -928,6 +990,96 @@ class MTPProposer(Proposer):
if hasattr(self.model, "empty_input_forward"):
self.model.empty_input_forward()
def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
    """
    Main process for MTP inference on XPU.

    Runs up to ``self.num_model_steps`` draft substeps. Each substep removes
    padding, runs the draft model, samples the next tokens, and post-processes
    the results; between substeps the self hidden states are refreshed.

    Args:
        step_use_cudagraph: bool
            Whether to use cuda graph. Use the target model flag to avoid
            hanging problems with EP. (Not read in this XPU path — kept for
            signature parity with ``_propose_cuda``.)
        is_dummy_run: bool
            Kept for signature parity with ``_propose_cuda``; not read here.
    """
    for substep in range(self.num_model_steps):
        # "not_need_stop" gates whether there is still work for this substep.
        if self.model_inputs["not_need_stop"]:
            self.model_inputs["substep"] = substep
            # Remove padding and build the forward metadata for this substep.
            self.forward_meta = xpu_pre_process(
                self.model_inputs["input_ids"],
                self.model_inputs["seq_lens_this_time"],
                self.model_inputs,
                True,
                self.cache_config.block_size,
                self.model_inputs["draft_tokens"],
                self.model_inputs["seq_lens_encoder"],
                self.model_inputs["seq_lens_decoder"],
            )
            # Wire the XPU attention buffers into forward_meta and init backends.
            self._initialize_forward_meta_xpu()

            # Get sampling metadata for this substep.
            self.sampling_metadata = SamplingMetadata(
                temperature=self.model_inputs["temperature"],
                top_p=self.model_inputs["top_p"],
                top_k=self.model_inputs["top_k"],
                seed=self.model_inputs["infer_seed"],
                step_idx=self.model_inputs["step_idx"],
                pre_token_ids=self.model_inputs["pre_ids"],
                frequency_penalties=self.model_inputs["frequency_score"],
                presence_penalties=self.model_inputs["presence_score"],
                repetition_penalties=self.model_inputs["penalty_score"],
                min_dec_lens=self.model_inputs["min_dec_len"],
                bad_words_token_ids=self.model_inputs["bad_tokens"],
                eos_token_ids=self.model_inputs["eos_token_id"],
                # 20 top logprobs when logprob reporting is enabled.
                max_num_logprobs=20 if self.enable_logprob else None,
                temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"],
                top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
                share_inputs=self.model_inputs,
            )

            # Snapshot seq lens before the step so multi-step updates can
            # reference the pre-step values.
            if self.num_model_steps > 1:
                self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])

            # Draft model forward pass.
            model_output = self.model(
                ids_remove_padding=self.model_inputs["ids_remove_padding"],
                previous_hidden_states=self.model_inputs["target_hidden_states"],
                forward_meta=self.forward_meta,
            )

            hidden_states = xpu_process_output(
                model_output, self.model_inputs["cum_offsets"], self.forward_meta, self.model_inputs
            )

            # 4. Compute logits, Sample
            logits = self.model.compute_logits(hidden_states)

            sampled_token_ids, sampler_output = self.sampler(
                logits,
                self.sampling_metadata,
                self.max_model_len,
                self.model_inputs,
            )

            # Only the first substep persists top-k logprobs.
            if substep == 0 and sampler_output.logprobs_tensors is not None:
                real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
                speculate_save_output_topk(
                    sampler_output.sampled_token_ids,
                    sampler_output.logprobs_tensors.logprob_token_ids,
                    sampler_output.logprobs_tensors.logprobs,
                    sampler_output.logprobs_tensors.selected_token_ranks,
                    self.model_inputs["batch_token_num"][:real_bsz],
                    self.model_inputs["cu_batch_token_offset"][:real_bsz],
                    self.model_inputs["not_need_stop"],
                    4, # mtype
                    self.local_rank,
                )

            # Broadcast sampled tokens from the TP root of this DP group.
            if self.parallel_config.tensor_parallel_size > 1:
                paddle.distributed.broadcast(
                    sampled_token_ids,
                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
                    group=self.parallel_config.tp_group,
                )

            self._post_process(sampled_token_ids)

            # Refresh self hidden states for every substep but the last.
            if substep != self.num_model_steps - 1:
                self._get_self_hidden_states(hidden_states)
        else:
            # No work this substep: run an empty forward if the model supports
            # it (keeps collective ops aligned across ranks).
            if hasattr(self.model, "empty_input_forward"):
                self.model.empty_input_forward()
def _get_self_hidden_states(self, hidden_states):
target_hidden_states = eagle_get_self_hidden_states(
hidden_states,
@@ -1044,3 +1196,21 @@ class MTPProposer(Proposer):
self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer
self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
return
def _empty_cache(self):
    """Release the framework's cached device memory on the current platform."""
    if current_platform.is_cuda():
        paddle.device.cuda.empty_cache()
        return
    if current_platform.is_xpu():
        paddle.device.xpu.empty_cache()
        return
    # Any other platform is unsupported here.
    raise NotImplementedError
def _get_cache_type(self):
    """Return the KV-cache quantized dtype name for the current platform."""
    if current_platform.is_cuda():
        return "uint8"
    if current_platform.is_xpu():
        return "int8"
    # Any other platform is unsupported here.
    raise NotImplementedError