mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[xpu] support mtp for xpu(mix) (#5274)
* [XPU] support kernel for mtp(base) * [XPU] support kernel for mtp(base) * format * format * format * fix gather next token * fix step && add test * fix * mv pre/post process * add adjust batch / gather next token for mtp * fix code style * fix mtp kenrel name * fix mtp kernel test * mv xpu pre/post process * mv xpu pre/post process * [xpu] support mtp * fix code style
This commit is contained in:
@@ -572,6 +572,8 @@ class SpeculativeSampler(nn.Layer):
|
||||
super().__init__()
|
||||
if current_platform.is_cuda():
|
||||
self.forward = self.forward_cuda
|
||||
elif current_platform.is_xpu():
|
||||
self.forward = self.forward_xpu
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.logprobs_mode = fd_config.model_config.logprobs_mode
|
||||
@@ -814,6 +816,80 @@ class SpeculativeSampler(nn.Layer):
|
||||
|
||||
return sampler_output
|
||||
|
||||
def forward_xpu(
|
||||
self,
|
||||
logits: paddle.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
max_model_len: int,
|
||||
share_inputs: List[paddle.Tensor],
|
||||
accept_all_drafts: bool = False,
|
||||
reject_all_drafts: bool = False,
|
||||
) -> paddle.Tensor:
|
||||
from fastdeploy.model_executor.ops.xpu import speculate_verify, top_p_candidates
|
||||
|
||||
logits = apply_speculative_penalty_multi_scores(
|
||||
sampling_metadata.pre_token_ids,
|
||||
logits,
|
||||
sampling_metadata.repetition_penalties,
|
||||
sampling_metadata.frequency_penalties,
|
||||
sampling_metadata.presence_penalties,
|
||||
sampling_metadata.temperature,
|
||||
sampling_metadata.bad_words_token_ids,
|
||||
sampling_metadata.step_idx,
|
||||
sampling_metadata.min_dec_lens,
|
||||
sampling_metadata.eos_token_ids,
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["output_padding_offset"],
|
||||
share_inputs["output_cum_offsets"],
|
||||
max_model_len,
|
||||
)
|
||||
|
||||
probs = F.softmax(logits)
|
||||
|
||||
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
|
||||
probs,
|
||||
sampling_metadata.top_p,
|
||||
share_inputs["output_padding_offset"],
|
||||
self.speculative_max_candidate_len,
|
||||
max_model_len,
|
||||
)
|
||||
|
||||
speculate_verify(
|
||||
share_inputs["accept_tokens"],
|
||||
share_inputs["accept_num"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs[
|
||||
"draft_tokens"
|
||||
], # Both input and output, need to write the last 1 token accepted to position 0.
|
||||
share_inputs["seq_lens_this_time"],
|
||||
verify_tokens,
|
||||
verify_scores,
|
||||
share_inputs["max_dec_len"],
|
||||
sampling_metadata.eos_token_ids,
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["output_cum_offsets"],
|
||||
actual_candidate_len,
|
||||
share_inputs["actual_draft_token_num"],
|
||||
sampling_metadata.top_p,
|
||||
max_model_len,
|
||||
self.speculative_verify_window,
|
||||
True, # enable_topp
|
||||
(self.speculative_benchmark_mode or reject_all_drafts),
|
||||
accept_all_drafts,
|
||||
)
|
||||
# TODO(chenhuan09): support return logprobs
|
||||
token_ids = share_inputs["accept_tokens"]
|
||||
sampler_output = SamplerOutput(
|
||||
sampled_token_ids=token_ids,
|
||||
logprobs_tensors=None,
|
||||
token_num_per_batch=share_inputs["accept_num"],
|
||||
cu_batch_token_offset=None,
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
|
||||
class MTPSampler(nn.Layer):
|
||||
""" """
|
||||
@@ -823,6 +899,8 @@ class MTPSampler(nn.Layer):
|
||||
super().__init__()
|
||||
if current_platform.is_cuda():
|
||||
self.forward = self.forward_cuda
|
||||
elif current_platform.is_xpu():
|
||||
self.forward = self.forward_xpu
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.logprobs_mode = fd_config.model_config.logprobs_mode
|
||||
@@ -1013,3 +1091,44 @@ class MTPSampler(nn.Layer):
|
||||
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
|
||||
)
|
||||
return next_tokens, sampler_output
|
||||
|
||||
def forward_xpu(
|
||||
self,
|
||||
logits: paddle.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
max_model_len: int,
|
||||
share_inputs: List[paddle.Tensor],
|
||||
) -> paddle.Tensor:
|
||||
|
||||
logits = apply_speculative_penalty_multi_scores(
|
||||
sampling_metadata.pre_token_ids,
|
||||
logits,
|
||||
sampling_metadata.repetition_penalties,
|
||||
sampling_metadata.frequency_penalties,
|
||||
sampling_metadata.presence_penalties,
|
||||
sampling_metadata.temperature,
|
||||
sampling_metadata.bad_words_token_ids,
|
||||
sampling_metadata.step_idx,
|
||||
sampling_metadata.min_dec_lens,
|
||||
sampling_metadata.eos_token_ids,
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["output_padding_offset"],
|
||||
share_inputs["output_cum_offsets"],
|
||||
max_model_len,
|
||||
)
|
||||
probs = F.softmax(logits)
|
||||
|
||||
_, next_tokens = top_k_top_p_sampling(
|
||||
probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
|
||||
)
|
||||
# TODO(chenhuan09): add support for logprobs
|
||||
token_ids = None
|
||||
logprobs_tensors = None
|
||||
|
||||
sampler_output = SamplerOutput(
|
||||
sampled_token_ids=token_ids,
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
token_num_per_batch=None,
|
||||
cu_batch_token_offset=None,
|
||||
)
|
||||
return next_tokens, sampler_output
|
||||
|
||||
Reference in New Issue
Block a user