[xpu] support mtp for xpu(mix) (#5274)

* [XPU] support kernel for mtp(base)

* [XPU] support kernel for mtp(base)

* format

* format

* format

* fix gather next token

* fix step && add test

* fix

* mv pre/post process

* add adjust batch / gather next token for mtp

* fix code style

* fix mtp kenrel name

* fix mtp kernel test

* mv xpu pre/post process

* mv xpu pre/post process

* [xpu] support mtp

* fix code style
This commit is contained in:
cmcamdy
2025-12-01 11:03:14 +08:00
committed by GitHub
parent 8aec3acc8c
commit 9f4977eb74
8 changed files with 691 additions and 106 deletions

View File

@@ -39,7 +39,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
)
from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope_3d
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.model_executor.ops.xpu import (
@@ -49,12 +49,14 @@ from fastdeploy.model_executor.ops.xpu import (
set_data_ipc,
share_external_data,
)
from fastdeploy.model_executor.xpu_pre_and_post_process import ( # xpu_post_process_specualate, # TODO(chenhuan09): add xpu_post_process_specualate
from fastdeploy.model_executor.xpu_pre_and_post_process import (
step_xpu,
xpu_post_process_normal,
xpu_post_process_specualate,
xpu_pre_process,
xpu_process_output,
)
from fastdeploy.spec_decode import MTPProposer
from fastdeploy.utils import get_logger
from fastdeploy.worker.model_runner_base import ModelRunnerBase
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
@@ -102,9 +104,20 @@ class XPUModelRunner(ModelRunnerBase):
"fused_gemm_epilogue",
]
self.device_id = device_id
self.speculative_method = self.fd_config.speculative_config.method
self.speculative_decoding = self.speculative_method is not None
# used by SamplingMetadata
self.enable_logprob = False # fd_config.model_config.enable_logprob
self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
# Sampler
# TODU(lilujia): sync with GPU
self.sampler = Sampler(fd_config)
if not self.speculative_decoding:
self.sampler = Sampler(fd_config)
else:
self.sampler = SpeculativeSampler(fd_config)
# Lazy initialize kv cache after model loading
# self.kv_caches: list[paddle.Tensor] = []
@@ -143,7 +156,7 @@ class XPUModelRunner(ModelRunnerBase):
else:
return 0
def insert_tasks_v1(self, req_dicts: List[Request]):
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
"""
Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
req_dict: A list of Request dict
@@ -340,7 +353,10 @@ class XPUModelRunner(ModelRunnerBase):
if has_prefill_task or has_decode_task:
self.share_inputs["not_need_stop"][0] = True
def insert_prefill_inputs(self, req_dicts: List[Request]):
if self.speculative_method in ["mtp"]:
self.proposer.insert_tasks_v1(req_dicts, num_running_requests)
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
"""Process inputs for prefill tasks and update share_inputs buffer"""
# NOTE(luotingdan): Set environment variable of prefill node
if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
@@ -480,6 +496,15 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["not_need_stop"][0] = True
if self.speculative_method in ["mtp"]:
self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = get_attr_from_request(
request, "temp_scaled_logprobs", False
)
self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = get_attr_from_request(
request, "top_p_normalized_logprobs", False
)
self.proposer.insert_prefill_inputs(req_dicts, num_running_requests)
def _init_share_inputs(self, max_num_seqs: int):
"""Initialize all share buffers for model inputs.
Note: In the future, we may abandon share buffers.
@@ -558,6 +583,15 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["system_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int32")
self.share_inputs["ids_remove_padding"] = paddle.full(
[max_num_seqs * self.model_config.max_model_len],
0,
dtype="int64",
)
self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
# Initialize thinking related buffers
self.share_inputs["max_think_lens"] = paddle.full(shape=[max_num_seqs, 1], fill_value=-1, dtype="int32")
self.share_inputs["limit_think_status"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
@@ -629,6 +663,56 @@ class XPUModelRunner(ModelRunnerBase):
)
self.share_inputs["image_features"] = None
if self.speculative_decoding:
max_draft_token_num = self.speculative_config.num_speculative_tokens
self.share_inputs["input_ids_cpu"] = paddle.full(
shape=[max_num_seqs, self.model_config.max_model_len],
fill_value=1,
dtype="int64",
).cpu()
self.share_inputs["accept_tokens"] = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1],
fill_value=0,
dtype="int64",
)
self.share_inputs["accept_num"] = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32")
self.share_inputs["draft_tokens"] = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1],
fill_value=0,
dtype="int64",
)
self.share_inputs["actual_draft_token_num"] = paddle.full(
shape=[max_num_seqs],
fill_value=max_draft_token_num,
dtype="int32",
)
self.share_inputs["output_cum_offsets"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
self.share_inputs["output_padding_offset"] = paddle.full(
shape=[max_num_seqs * (max_draft_token_num + 1)],
fill_value=0,
dtype="int32",
)
# For V1_KVCACHE_SCHEDULER
self.share_inputs["step_draft_tokens"] = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1],
fill_value=0,
dtype="int64",
)
self.share_inputs["step_seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["temp_scaled_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype=bool)
self.share_inputs["top_p_normalized_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype=bool)
# For MTP Logprob
self.share_inputs["draft_logits"] = paddle.full(
[max_num_seqs * (self.speculative_config.num_speculative_tokens + 1), self.model_config.vocab_size],
-1,
dtype="float32",
)
self.share_inputs["cu_batch_token_offset"] = paddle.full(
shape=[max_num_seqs + 1], fill_value=0, dtype="int32"
)
self.max_num_seqs = max_num_seqs
def _prepare_inputs(self, is_dummy_run=False) -> None:
"""Prepare the model inputs"""
if envs.ENABLE_V1_KVCACHE_SCHEDULER and not is_dummy_run:
@@ -646,9 +730,9 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["input_ids"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs,
use_speculate_method=False,
use_speculate_method=self.speculative_decoding,
block_size=self.cache_config.block_size,
draft_tokens=None,
draft_tokens=self.share_inputs["draft_tokens"] if self.speculative_decoding else None,
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
is_profiling=is_dummy_run,
@@ -696,6 +780,7 @@ class XPUModelRunner(ModelRunnerBase):
# 2. Load lora model
# 3. Load drafter model(for speculative decoding)
self._init_speculative_proposer()
def get_model(self) -> nn.Layer:
"""Get current model"""
@@ -793,6 +878,44 @@ class XPUModelRunner(ModelRunnerBase):
)
head_dim = self.model_config.head_dim
if self.speculative_decoding:
# Initialize AttentionBackend buffers
encoder_block_shape_q = 64
decoder_block_shape_q = 16
decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
decode_max_tile_size = self.max_num_seqs * np.ceil(
(decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
)
group_size = np.ceil(num_heads / self.model_config.kv_num_heads)
encode_max_tile_size = self.scheduler_config.max_num_seqs * np.ceil(
(self.model_config.max_model_len * group_size) / encoder_block_shape_q
)
kv_max_tile_size = self.scheduler_config.max_num_seqs * np.ceil(
self.model_config.max_model_len / self.fd_config.cache_config.block_size
)
self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full(
[int(decode_max_tile_size)], 0, dtype="int32"
)
self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
# NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor,
# adapted to cudagraph.
self.share_inputs["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32")
self.share_inputs["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32")
self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()
self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full(
[int(encode_max_tile_size)], 0, dtype="int32"
)
self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
# Get the attention backend
attn_cls = get_attention_backend()
attn_backend = attn_cls(
@@ -851,12 +974,38 @@ class XPUModelRunner(ModelRunnerBase):
"""
self._dummy_prefill_inputs(num_tokens, batch_size)
if self.speculative_method in ["mtp"]:
self.proposer.dummy_prefill_inputs(
num_tokens=num_tokens,
batch_size=batch_size,
expected_decode_len=1,
)
while True:
self.execute_model(is_dummy_run=True)
if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
break
def _init_speculative_proposer(self):
"""
Init speculative proposer
"""
if self.speculative_method == "ngram":
# xpu not support ngram proposer now
# self.proposer = NgramProposer(self.fd_config)
self.proposer = None
elif self.speculative_method == "mtp":
self.proposer = MTPProposer(
self.fd_config,
self.get_model(),
self.local_rank,
self.device_id,
self.share_inputs,
)
else:
self.proposer = None
def _set_debug_level(
self, debug_level: int = 0x1, model_forward_batch: Optional[List[Request]] = None, is_dummy_run: bool = False
) -> None:
@@ -941,7 +1090,16 @@ class XPUModelRunner(ModelRunnerBase):
)
# 4. Compute logits, Sample
logits = self.model.compute_logits(hidden_states)
sampler_output = self.sampler(logits, self.sampling_metadata)
sampler_output = None
if not self.speculative_decoding:
sampler_output = self.sampler(logits, self.sampling_metadata)
else:
self.sampler(
logits,
self.sampling_metadata,
self.model_config.max_model_len,
self.share_inputs,
)
# 5. Speculative decode
@@ -961,26 +1119,36 @@ class XPUModelRunner(ModelRunnerBase):
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
is_block_step=self.share_inputs["is_block_step"],
# 投机解码
full_hidden_states=None,
full_hidden_states=model_output if self.speculative_decoding else None,
msg_queue_id=self.parallel_config.msg_queue_id,
mp_rank=self.local_rank,
use_ep=self.parallel_config.use_ep,
draft_tokens=None,
actual_draft_token_num=None,
accept_tokens=None,
accept_num=None,
draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
actual_draft_token_num=(
self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None
),
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
)
xpu_post_process_normal(
sampled_token_ids=sampler_output.sampled_token_ids,
model_output=model_output_data,
share_inputs=self.share_inputs,
block_size=self.cache_config.block_size,
skip_save_output=is_dummy_run,
think_end_id=self.model_config.think_end_id,
line_break_id=self.model_config.line_break_id,
)
if self.speculative_decoding:
# base model post process
xpu_post_process_specualate(model_output_data, False, is_dummy_run)
else:
xpu_post_process_normal(
sampled_token_ids=sampler_output.sampled_token_ids,
model_output=model_output_data,
share_inputs=self.share_inputs,
block_size=self.cache_config.block_size,
skip_save_output=is_dummy_run,
think_end_id=self.model_config.think_end_id,
line_break_id=self.model_config.line_break_id,
)
# draft model propose
if self.speculative_method == "mtp":
self.proposer.run(full_hidden_states=model_output)
# 7. Updata 'infer_seed' and step_paddle()
self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
@@ -989,6 +1157,8 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs,
self.cache_config.block_size,
self.cache_config.enc_dec_block_num,
self.speculative_decoding,
self.speculative_config.num_speculative_tokens,
)
if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query":
@@ -1013,6 +1183,9 @@ class XPUModelRunner(ModelRunnerBase):
self.num_gpu_blocks = self.cache_config.total_block_num
self.initialize_kv_cache(profile=True)
if self.speculative_method in ["mtp"]:
self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True)
self._dummy_run(
num_tokens=int(self.scheduler_config.max_num_batched_tokens),
batch_size=min(self.scheduler_config.max_num_seqs, 1),