[xpu] support mtp for xpu(mix) (#5274)

* [XPU] support kernel for mtp(base)

* [XPU] support kernel for mtp(base)

* format

* format

* format

* fix gather next token

* fix step && add test

* fix

* mv pre/post process

* add adjust batch / gather next token for mtp

* fix code style

* fix mtp kenrel name

* fix mtp kernel test

* mv xpu pre/post process

* mv xpu pre/post process

* [xpu] support mtp

* fix code style
This commit is contained in:
cmcamdy
2025-12-01 11:03:14 +08:00
committed by GitHub
parent 8aec3acc8c
commit 9f4977eb74
8 changed files with 691 additions and 106 deletions

View File

@@ -572,6 +572,8 @@ class SpeculativeSampler(nn.Layer):
super().__init__()
if current_platform.is_cuda():
self.forward = self.forward_cuda
elif current_platform.is_xpu():
self.forward = self.forward_xpu
else:
raise NotImplementedError
self.logprobs_mode = fd_config.model_config.logprobs_mode
@@ -814,6 +816,80 @@ class SpeculativeSampler(nn.Layer):
return sampler_output
def forward_xpu(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
max_model_len: int,
share_inputs: List[paddle.Tensor],
accept_all_drafts: bool = False,
reject_all_drafts: bool = False,
) -> paddle.Tensor:
from fastdeploy.model_executor.ops.xpu import speculate_verify, top_p_candidates
logits = apply_speculative_penalty_multi_scores(
sampling_metadata.pre_token_ids,
logits,
sampling_metadata.repetition_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.presence_penalties,
sampling_metadata.temperature,
sampling_metadata.bad_words_token_ids,
sampling_metadata.step_idx,
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
share_inputs["seq_lens_this_time"],
share_inputs["output_padding_offset"],
share_inputs["output_cum_offsets"],
max_model_len,
)
probs = F.softmax(logits)
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
probs,
sampling_metadata.top_p,
share_inputs["output_padding_offset"],
self.speculative_max_candidate_len,
max_model_len,
)
speculate_verify(
share_inputs["accept_tokens"],
share_inputs["accept_num"],
share_inputs["step_idx"],
share_inputs["stop_flags"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs[
"draft_tokens"
], # Both input and output, need to write the last 1 token accepted to position 0.
share_inputs["seq_lens_this_time"],
verify_tokens,
verify_scores,
share_inputs["max_dec_len"],
sampling_metadata.eos_token_ids,
share_inputs["is_block_step"],
share_inputs["output_cum_offsets"],
actual_candidate_len,
share_inputs["actual_draft_token_num"],
sampling_metadata.top_p,
max_model_len,
self.speculative_verify_window,
True, # enable_topp
(self.speculative_benchmark_mode or reject_all_drafts),
accept_all_drafts,
)
# TODO(chenhuan09): support return logprobs
token_ids = share_inputs["accept_tokens"]
sampler_output = SamplerOutput(
sampled_token_ids=token_ids,
logprobs_tensors=None,
token_num_per_batch=share_inputs["accept_num"],
cu_batch_token_offset=None,
)
return sampler_output
class MTPSampler(nn.Layer):
""" """
@@ -823,6 +899,8 @@ class MTPSampler(nn.Layer):
super().__init__()
if current_platform.is_cuda():
self.forward = self.forward_cuda
elif current_platform.is_xpu():
self.forward = self.forward_xpu
else:
raise NotImplementedError
self.logprobs_mode = fd_config.model_config.logprobs_mode
@@ -1013,3 +1091,44 @@ class MTPSampler(nn.Layer):
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
)
return next_tokens, sampler_output
def forward_xpu(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
max_model_len: int,
share_inputs: List[paddle.Tensor],
) -> paddle.Tensor:
logits = apply_speculative_penalty_multi_scores(
sampling_metadata.pre_token_ids,
logits,
sampling_metadata.repetition_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.presence_penalties,
sampling_metadata.temperature,
sampling_metadata.bad_words_token_ids,
sampling_metadata.step_idx,
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
share_inputs["seq_lens_this_time"],
share_inputs["output_padding_offset"],
share_inputs["output_cum_offsets"],
max_model_len,
)
probs = F.softmax(logits)
_, next_tokens = top_k_top_p_sampling(
probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
)
# TODO(chenhuan09): add support for logprobs
token_ids = None
logprobs_tensors = None
sampler_output = SamplerOutput(
sampled_token_ids=token_ids,
logprobs_tensors=logprobs_tensors,
token_num_per_batch=None,
cu_batch_token_offset=None,
)
return next_tokens, sampler_output