mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[xpu] support mtp for xpu(mix) (#5274)
* [XPU] support kernel for mtp(base) * [XPU] support kernel for mtp(base) * format * format * format * fix gather next token * fix step && add test * fix * mv pre/post process * add adjust batch / gather next token for mtp * fix code style * fix mtp kernel name * fix mtp kernel test * mv xpu pre/post process * mv xpu pre/post process * [xpu] support mtp * fix code style
This commit is contained in:
@@ -182,24 +182,28 @@ def apply_speculative_penalty_multi_scores(
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
speculate_get_token_penalty_multi_scores,
|
||||
)
|
||||
|
||||
speculate_get_token_penalty_multi_scores(
|
||||
pre_token_ids,
|
||||
logits,
|
||||
repetition_penalties,
|
||||
frequency_penalties,
|
||||
presence_penalties,
|
||||
temperature,
|
||||
bad_words_token_ids,
|
||||
step_idx,
|
||||
min_dec_lens,
|
||||
eos_token_ids,
|
||||
seq_lens_this_time,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
max_len,
|
||||
elif current_platform.is_xpu():
|
||||
from fastdeploy.model_executor.ops.xpu import (
|
||||
speculate_get_token_penalty_multi_scores,
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
speculate_get_token_penalty_multi_scores(
|
||||
pre_token_ids,
|
||||
logits,
|
||||
repetition_penalties,
|
||||
frequency_penalties,
|
||||
presence_penalties,
|
||||
temperature,
|
||||
bad_words_token_ids,
|
||||
step_idx,
|
||||
min_dec_lens,
|
||||
eos_token_ids,
|
||||
seq_lens_this_time,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
max_len,
|
||||
)
|
||||
# inplace
|
||||
return logits
|
||||
|
||||
@@ -572,6 +572,8 @@ class SpeculativeSampler(nn.Layer):
|
||||
super().__init__()
|
||||
if current_platform.is_cuda():
|
||||
self.forward = self.forward_cuda
|
||||
elif current_platform.is_xpu():
|
||||
self.forward = self.forward_xpu
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.logprobs_mode = fd_config.model_config.logprobs_mode
|
||||
@@ -814,6 +816,80 @@ class SpeculativeSampler(nn.Layer):
|
||||
|
||||
return sampler_output
|
||||
|
||||
def forward_xpu(
    self,
    logits: paddle.Tensor,
    sampling_metadata: SamplingMetadata,
    max_model_len: int,
    share_inputs: List[paddle.Tensor],
    accept_all_drafts: bool = False,
    reject_all_drafts: bool = False,
) -> "SamplerOutput":
    """Run one speculative-decoding verify step on XPU.

    Applies sampling penalties to ``logits``, builds top-p candidate sets,
    then verifies the draft tokens against the target-model distribution
    with the ``speculate_verify`` XPU op. Accepted tokens are written into
    ``share_inputs["accept_tokens"]`` / ``share_inputs["accept_num"]`` by
    the op (in place) and wrapped into a ``SamplerOutput``.

    Args:
        logits: Target-model logits for the candidate positions.
        sampling_metadata: Per-request sampling parameters (penalties,
            temperature, top_p, eos ids, ...).
        max_model_len: Maximum model sequence length, forwarded to the ops.
        share_inputs: Shared runtime tensors keyed by name (despite the
            ``List`` annotation it is indexed by string keys — presumably a
            dict-like container; TODO confirm against callers).
        accept_all_drafts: If True, the verify op accepts every draft token.
        reject_all_drafts: If True (or in benchmark mode), the verify op
            rejects every draft token.

    Returns:
        SamplerOutput with the accepted token ids and per-batch accept counts.
    """
    from fastdeploy.model_executor.ops.xpu import speculate_verify, top_p_candidates

    # In-place penalty application; returns the same logits tensor.
    logits = apply_speculative_penalty_multi_scores(
        sampling_metadata.pre_token_ids,
        logits,
        sampling_metadata.repetition_penalties,
        sampling_metadata.frequency_penalties,
        sampling_metadata.presence_penalties,
        sampling_metadata.temperature,
        sampling_metadata.bad_words_token_ids,
        sampling_metadata.step_idx,
        sampling_metadata.min_dec_lens,
        sampling_metadata.eos_token_ids,
        share_inputs["seq_lens_this_time"],
        share_inputs["output_padding_offset"],
        share_inputs["output_cum_offsets"],
        max_model_len,
    )

    probs = F.softmax(logits)

    # Collect the top-p candidate tokens/scores per position for verification.
    verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
        probs,
        sampling_metadata.top_p,
        share_inputs["output_padding_offset"],
        self.speculative_max_candidate_len,
        max_model_len,
    )

    # Argument order is the op's ABI — do not reorder.
    speculate_verify(
        share_inputs["accept_tokens"],
        share_inputs["accept_num"],
        share_inputs["step_idx"],
        share_inputs["stop_flags"],
        share_inputs["seq_lens_encoder"],
        share_inputs["seq_lens_decoder"],
        share_inputs[
            "draft_tokens"
        ],  # Both input and output, need to write the last 1 token accepted to position 0.
        share_inputs["seq_lens_this_time"],
        verify_tokens,
        verify_scores,
        share_inputs["max_dec_len"],
        sampling_metadata.eos_token_ids,
        share_inputs["is_block_step"],
        share_inputs["output_cum_offsets"],
        actual_candidate_len,
        share_inputs["actual_draft_token_num"],
        sampling_metadata.top_p,
        max_model_len,
        self.speculative_verify_window,
        True,  # enable_topp
        (self.speculative_benchmark_mode or reject_all_drafts),
        accept_all_drafts,
    )
    # TODO(chenhuan09): support return logprobs
    token_ids = share_inputs["accept_tokens"]
    sampler_output = SamplerOutput(
        sampled_token_ids=token_ids,
        logprobs_tensors=None,
        token_num_per_batch=share_inputs["accept_num"],
        cu_batch_token_offset=None,
    )
    return sampler_output
|
||||
|
||||
|
||||
class MTPSampler(nn.Layer):
|
||||
""" """
|
||||
@@ -823,6 +899,8 @@ class MTPSampler(nn.Layer):
|
||||
super().__init__()
|
||||
if current_platform.is_cuda():
|
||||
self.forward = self.forward_cuda
|
||||
elif current_platform.is_xpu():
|
||||
self.forward = self.forward_xpu
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.logprobs_mode = fd_config.model_config.logprobs_mode
|
||||
@@ -1013,3 +1091,44 @@ class MTPSampler(nn.Layer):
|
||||
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
|
||||
)
|
||||
return next_tokens, sampler_output
|
||||
|
||||
def forward_xpu(
    self,
    logits: paddle.Tensor,
    sampling_metadata: SamplingMetadata,
    max_model_len: int,
    share_inputs: List[paddle.Tensor],
) -> "tuple[paddle.Tensor, SamplerOutput]":
    """Sample draft (MTP) tokens on XPU.

    Applies sampling penalties to the draft-model ``logits``, then draws the
    next draft tokens with top-k/top-p sampling.

    Args:
        logits: Draft-model logits.
        sampling_metadata: Per-request sampling parameters.
        max_model_len: Maximum model sequence length, forwarded to the
            penalty op.
        share_inputs: Shared runtime tensors keyed by name (indexed by
            string keys despite the ``List`` annotation — presumably a
            dict-like container; TODO confirm against callers).

    Returns:
        Tuple of (next_tokens, SamplerOutput). The SamplerOutput carries no
        token ids or logprobs yet (see TODO below).
    """
    # In-place penalty application; returns the same logits tensor.
    logits = apply_speculative_penalty_multi_scores(
        sampling_metadata.pre_token_ids,
        logits,
        sampling_metadata.repetition_penalties,
        sampling_metadata.frequency_penalties,
        sampling_metadata.presence_penalties,
        sampling_metadata.temperature,
        sampling_metadata.bad_words_token_ids,
        sampling_metadata.step_idx,
        sampling_metadata.min_dec_lens,
        sampling_metadata.eos_token_ids,
        share_inputs["seq_lens_this_time"],
        share_inputs["output_padding_offset"],
        share_inputs["output_cum_offsets"],
        max_model_len,
    )
    probs = F.softmax(logits)

    _, next_tokens = top_k_top_p_sampling(
        probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
    )
    # TODO(chenhuan09): add support for logprobs
    token_ids = None
    logprobs_tensors = None

    sampler_output = SamplerOutput(
        sampled_token_ids=token_ids,
        logprobs_tensors=logprobs_tensors,
        token_num_per_batch=None,
        cu_batch_token_offset=None,
    )
    return next_tokens, sampler_output
|
||||
|
||||
@@ -31,6 +31,18 @@ if current_platform.is_xpu():
|
||||
get_padding_offset,
|
||||
limit_thinking_content_length_v1,
|
||||
limit_thinking_content_length_v2,
|
||||
save_output,
|
||||
set_stop_value_multi_ends,
|
||||
speculate_clear_accept_nums,
|
||||
speculate_get_output_padding_offset,
|
||||
speculate_get_padding_offset,
|
||||
speculate_get_seq_lens_output,
|
||||
speculate_save_output,
|
||||
speculate_set_value_by_flags_and_idx,
|
||||
speculate_step_paddle,
|
||||
speculate_update_v3,
|
||||
step_paddle,
|
||||
update_inputs,
|
||||
update_inputs_v1,
|
||||
)
|
||||
|
||||
@@ -45,19 +57,53 @@ def xpu_pre_process(
|
||||
seq_lens_encoder: Optional[paddle.Tensor] = None,
|
||||
seq_lens_decoder: Optional[paddle.Tensor] = None,
|
||||
is_profiling: bool = False,
|
||||
forward_meta=None,
|
||||
) -> XPUForwardMeta:
|
||||
""" """
|
||||
max_len = input_ids.shape[1]
|
||||
cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
|
||||
token_num = paddle.sum(seq_lens_this_time)
|
||||
|
||||
(
|
||||
ids_remove_padding,
|
||||
cum_offsets,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)
|
||||
if use_speculate_method:
|
||||
(
|
||||
ids_remove_padding,
|
||||
cum_offsets,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
) = speculate_get_padding_offset(
|
||||
input_ids,
|
||||
draft_tokens,
|
||||
cum_offsets_now,
|
||||
token_num,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
)
|
||||
seq_lens_output = speculate_get_seq_lens_output(
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
)
|
||||
if isinstance(seq_lens_output, list):
|
||||
seq_lens_output = seq_lens_output[0]
|
||||
output_token_num = paddle.sum(seq_lens_output)
|
||||
output_cum_offsets_tmp = paddle.cumsum(max_len - seq_lens_output, dtype="int32")
|
||||
output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset(
|
||||
output_cum_offsets_tmp,
|
||||
output_token_num,
|
||||
seq_lens_output,
|
||||
max_len,
|
||||
)
|
||||
share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
|
||||
share_inputs["output_padding_offset"].copy_(output_padding_offset, False)
|
||||
else:
|
||||
(
|
||||
ids_remove_padding,
|
||||
cum_offsets,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)
|
||||
|
||||
share_inputs["ids_remove_padding"] = None # set this after adjust batch
|
||||
share_inputs["cum_offsets"] = cum_offsets
|
||||
@@ -173,11 +219,6 @@ def xpu_post_process_normal(
|
||||
line_break_id: int = None,
|
||||
) -> None:
|
||||
""" """
|
||||
from fastdeploy.model_executor.ops.xpu import (
|
||||
save_output,
|
||||
set_stop_value_multi_ends,
|
||||
update_inputs,
|
||||
)
|
||||
|
||||
if think_end_id > 0:
|
||||
limit_strategy = envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR
|
||||
@@ -277,39 +318,110 @@ def xpu_post_process_normal(
|
||||
)
|
||||
|
||||
|
||||
# NOTE(review): "specualate" is a typo for "speculate", kept because callers
# reference this exact name.
def xpu_post_process_specualate(
    model_output: ModelOutputData, save_each_rank: bool = False, skip_save_output: bool = False
):
    """Post-process one speculative-decoding step on XPU.

    Updates sequence-length/stop bookkeeping from the accepted tokens,
    optionally saves the accepted output, clears per-step accept counts, and
    propagates accepted tokens into ``pre_ids``. All updates are performed
    in place on ``model_output`` tensors by XPU custom ops.

    Args:
        model_output: Bundle of runtime tensors for this step.
        save_each_rank: Forwarded to ``speculate_save_output``; when False
            presumably only one rank saves — TODO confirm op semantics.
        skip_save_output: If True, skip saving output this step.
    """
    # Advance decoder state from the accepted tokens (in place).
    # Argument order is the op's ABI — do not reorder.
    speculate_update_v3(
        model_output.seq_lens_encoder,
        model_output.seq_lens_decoder,
        model_output.not_need_stop,
        model_output.draft_tokens,
        model_output.actual_draft_token_num,
        model_output.accept_tokens,
        model_output.accept_num,
        model_output.stop_flags,
        model_output.seq_lens_this_time,
        model_output.is_block_step,
        model_output.stop_nums,
    )
    if not skip_save_output:
        speculate_save_output(
            model_output.accept_tokens,
            model_output.accept_num,
            model_output.not_need_stop,
            model_output.mp_rank,
            save_each_rank,  # False
        )

    speculate_clear_accept_nums(model_output.accept_num, model_output.seq_lens_decoder)

    # Update pre_ids through accept tokens
    speculate_set_value_by_flags_and_idx(
        model_output.pre_ids,
        model_output.accept_tokens,
        model_output.accept_num,
        model_output.stop_flags,
        model_output.seq_lens_this_time,
        model_output.seq_lens_encoder,
        model_output.seq_lens_decoder,
        model_output.step_idx,
    )
|
||||
|
||||
|
||||
def step_xpu(
    share_inputs: Dict[str, paddle.Tensor],
    block_size: int,
    enc_dec_block_num: int,
    speculative_decoding: bool,
    max_draft_token_num: int,
) -> None:
    """Run the KV-cache block step/recycle pass on XPU.

    Dispatches to exactly one step kernel per call: the speculative variant
    (which additionally consumes ``accept_num`` and the draft-token budget)
    when ``speculative_decoding`` is set, otherwise the plain ``step_paddle``.

    BUG FIX: the flattened original called ``step_paddle(...)``
    unconditionally BEFORE the if/else, so the non-speculative path stepped
    twice and the speculative path ran both kernels; the duplicate leading
    call is removed so each path steps exactly once.

    Args:
        share_inputs: Shared runtime tensors keyed by name; updated in place
            by the step kernels.
        block_size: KV-cache block size.
        enc_dec_block_num: Number of blocks reserved for encoder/decoder.
        speculative_decoding: Whether speculative decoding is enabled.
        max_draft_token_num: Maximum number of draft tokens per step
            (speculative path only).

    TODO(gongshaotian): normalization name
    TODO(chenhuan09): support PD
    """
    from fastdeploy.model_executor.ops.xpu import step_paddle

    if speculative_decoding:
        # Argument order is the op's ABI — do not reorder.
        speculate_step_paddle(
            share_inputs["stop_flags"],
            share_inputs["seq_lens_this_time"],
            share_inputs["step_seq_lens_encoder"],
            share_inputs["seq_lens_encoder"],
            share_inputs["seq_lens_decoder"],
            share_inputs["block_tables"],
            share_inputs["encoder_block_lens"],
            share_inputs["is_block_step"],
            share_inputs["step_block_list"],
            share_inputs["step_lens"],
            share_inputs["recover_block_list"],
            share_inputs["recover_lens"],
            share_inputs["need_block_list"],
            share_inputs["need_block_len"],
            share_inputs["used_list_len"],
            share_inputs["free_list"],
            share_inputs["free_list_len"],
            share_inputs["input_ids"],
            share_inputs["pre_ids"],
            share_inputs["step_idx"],
            share_inputs["next_tokens"],
            share_inputs["first_token_ids"],
            share_inputs["accept_num"],
            block_size,
            enc_dec_block_num,
            max_draft_token_num,
        )
    else:
        step_paddle(
            share_inputs["stop_flags"],
            share_inputs["seq_lens_this_time"],
            share_inputs["step_seq_lens_encoder"],
            share_inputs["seq_lens_encoder"],
            share_inputs["seq_lens_decoder"],
            share_inputs["block_tables"],
            share_inputs["encoder_block_lens"],
            share_inputs["is_block_step"],
            share_inputs["step_block_list"],
            share_inputs["step_lens"],
            share_inputs["recover_block_list"],
            share_inputs["recover_lens"],
            share_inputs["need_block_list"],
            share_inputs["need_block_len"],
            share_inputs["used_list_len"],
            share_inputs["free_list"],
            share_inputs["free_list_len"],
            share_inputs["input_ids"],
            share_inputs["pre_ids"],
            share_inputs["step_idx"],
            share_inputs["next_tokens"],
            share_inputs["first_token_ids"],
            block_size,
            enc_dec_block_num,
        )
|
||||
|
||||
Reference in New Issue
Block a user