Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[xpu] support mtp for xpu(mix) (#5274)
* [XPU] support kernel for mtp(base)
* [XPU] support kernel for mtp(base)
* format
* format
* format
* fix gather next token
* fix step && add test
* fix
* mv pre/post process
* add adjust batch / gather next token for mtp
* fix code style
* fix mtp kernel name
* fix mtp kernel test
* mv xpu pre/post process
* mv xpu pre/post process
* [xpu] support mtp
* fix code style
@@ -182,24 +182,28 @@ def apply_speculative_penalty_multi_scores(
        from fastdeploy.model_executor.ops.gpu import (
            speculate_get_token_penalty_multi_scores,
        )

        speculate_get_token_penalty_multi_scores(
            pre_token_ids,
            logits,
            repetition_penalties,
            frequency_penalties,
            presence_penalties,
            temperature,
            bad_words_token_ids,
            step_idx,
            min_dec_lens,
            eos_token_ids,
            seq_lens_this_time,
            output_padding_offset,
            output_cum_offsets,
            max_len,
    elif current_platform.is_xpu():
        from fastdeploy.model_executor.ops.xpu import (
            speculate_get_token_penalty_multi_scores,
        )

    else:
        raise NotImplementedError
    speculate_get_token_penalty_multi_scores(
        pre_token_ids,
        logits,
        repetition_penalties,
        frequency_penalties,
        presence_penalties,
        temperature,
        bad_words_token_ids,
        step_idx,
        min_dec_lens,
        eos_token_ids,
        seq_lens_this_time,
        output_padding_offset,
        output_cum_offsets,
        max_len,
    )
    # inplace
    return logits
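For orientation, the sketch below spells out in plain Paddle the penalty math that the fused speculate_get_token_penalty_multi_scores op applies to logits in place, for a single request. The helper name, shapes, and flow are assumptions made for illustration, not the kernel's actual contract.

# Illustrative sketch only (single request); the fused kernel does this batched and in place.
import paddle

def reference_penalties(logits, prev_token_ids, repetition, frequency, presence, temperature):
    vocab_size = logits.shape[-1]
    # How often each vocab id already appeared in the generated prefix.
    counts = paddle.bincount(prev_token_ids, minlength=vocab_size).astype("float32")
    appeared = counts > 0

    # Repetition penalty: shrink positive logits, push negative logits further down.
    penalized = paddle.where(logits > 0, logits / repetition, logits * repetition)
    logits = paddle.where(appeared, penalized, logits)

    # Frequency and presence penalties, then temperature scaling.
    logits = logits - counts * frequency - appeared.astype("float32") * presence
    return logits / temperature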
@@ -572,6 +572,8 @@ class SpeculativeSampler(nn.Layer):
        super().__init__()
        if current_platform.is_cuda():
            self.forward = self.forward_cuda
        elif current_platform.is_xpu():
            self.forward = self.forward_xpu
        else:
            raise NotImplementedError
        self.logprobs_mode = fd_config.model_config.logprobs_mode
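The hunk above binds the platform-specific forward method once at construction instead of branching on every call. A minimal, self-contained sketch of that pattern (the class and names here are invented for the illustration):

# Illustrative sketch only: bind the platform-specific forward at construction.
class TinySampler:
    def __init__(self, platform: str):
        self.forward = self.forward_xpu if platform == "xpu" else self.forward_cuda

    def forward_cuda(self, logits):
        return "cuda path"

    def forward_xpu(self, logits):
        return "xpu path"

print(TinySampler("xpu").forward(None))  # xpu path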
@@ -814,6 +816,80 @@ class SpeculativeSampler(nn.Layer):
        return sampler_output

    def forward_xpu(
        self,
        logits: paddle.Tensor,
        sampling_metadata: SamplingMetadata,
        max_model_len: int,
        share_inputs: List[paddle.Tensor],
        accept_all_drafts: bool = False,
        reject_all_drafts: bool = False,
    ) -> paddle.Tensor:
        from fastdeploy.model_executor.ops.xpu import speculate_verify, top_p_candidates

        logits = apply_speculative_penalty_multi_scores(
            sampling_metadata.pre_token_ids,
            logits,
            sampling_metadata.repetition_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.presence_penalties,
            sampling_metadata.temperature,
            sampling_metadata.bad_words_token_ids,
            sampling_metadata.step_idx,
            sampling_metadata.min_dec_lens,
            sampling_metadata.eos_token_ids,
            share_inputs["seq_lens_this_time"],
            share_inputs["output_padding_offset"],
            share_inputs["output_cum_offsets"],
            max_model_len,
        )

        probs = F.softmax(logits)

        verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
            probs,
            sampling_metadata.top_p,
            share_inputs["output_padding_offset"],
            self.speculative_max_candidate_len,
            max_model_len,
        )

        speculate_verify(
            share_inputs["accept_tokens"],
            share_inputs["accept_num"],
            share_inputs["step_idx"],
            share_inputs["stop_flags"],
            share_inputs["seq_lens_encoder"],
            share_inputs["seq_lens_decoder"],
            share_inputs[
                "draft_tokens"
            ],  # Both input and output; the last accepted token is written back to position 0.
            share_inputs["seq_lens_this_time"],
            verify_tokens,
            verify_scores,
            share_inputs["max_dec_len"],
            sampling_metadata.eos_token_ids,
            share_inputs["is_block_step"],
            share_inputs["output_cum_offsets"],
            actual_candidate_len,
            share_inputs["actual_draft_token_num"],
            sampling_metadata.top_p,
            max_model_len,
            self.speculative_verify_window,
            True,  # enable_topp
            (self.speculative_benchmark_mode or reject_all_drafts),
            accept_all_drafts,
        )
        # TODO(chenhuan09): support return logprobs
        token_ids = share_inputs["accept_tokens"]
        sampler_output = SamplerOutput(
            sampled_token_ids=token_ids,
            logprobs_tensors=None,
            token_num_per_batch=share_inputs["accept_num"],
            cu_batch_token_offset=None,
        )
        return sampler_output
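The speculate_verify call above decides, per request, how many draft tokens survive. The pure-Python sketch below shows the greedy form of that accept/reject rule; the real kernel additionally handles top-p candidate sets, stop flags, and the verify window, and the helper name here is invented for the illustration.

# Illustrative sketch only: greedy draft verification for one request.
def greedy_verify(draft_tokens, target_tokens):
    """Accept drafts while they match the target model's choice at the same position."""
    accepted = []
    for pos, draft in enumerate(draft_tokens):
        accepted.append(target_tokens[pos])   # the target token at this position is always emitted
        if draft != target_tokens[pos]:       # first mismatch rejects the rest of the draft
            break
    return accepted                           # accept_num == len(accepted)

# 3 of 4 drafts match, so 4 tokens come out of a single verify step.
print(greedy_verify([5, 9, 2, 7], [5, 9, 2, 3, 8]))  # [5, 9, 2, 3]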


class MTPSampler(nn.Layer):
    """ """
@@ -823,6 +899,8 @@ class MTPSampler(nn.Layer):
        super().__init__()
        if current_platform.is_cuda():
            self.forward = self.forward_cuda
        elif current_platform.is_xpu():
            self.forward = self.forward_xpu
        else:
            raise NotImplementedError
        self.logprobs_mode = fd_config.model_config.logprobs_mode
@@ -1013,3 +1091,44 @@ class MTPSampler(nn.Layer):
            cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
        )
        return next_tokens, sampler_output

    def forward_xpu(
        self,
        logits: paddle.Tensor,
        sampling_metadata: SamplingMetadata,
        max_model_len: int,
        share_inputs: List[paddle.Tensor],
    ) -> paddle.Tensor:

        logits = apply_speculative_penalty_multi_scores(
            sampling_metadata.pre_token_ids,
            logits,
            sampling_metadata.repetition_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.presence_penalties,
            sampling_metadata.temperature,
            sampling_metadata.bad_words_token_ids,
            sampling_metadata.step_idx,
            sampling_metadata.min_dec_lens,
            sampling_metadata.eos_token_ids,
            share_inputs["seq_lens_this_time"],
            share_inputs["output_padding_offset"],
            share_inputs["output_cum_offsets"],
            max_model_len,
        )
        probs = F.softmax(logits)

        _, next_tokens = top_k_top_p_sampling(
            probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
        )
        # TODO(chenhuan09): add support for logprobs
        token_ids = None
        logprobs_tensors = None

        sampler_output = SamplerOutput(
            sampled_token_ids=token_ids,
            logprobs_tensors=logprobs_tensors,
            token_num_per_batch=None,
            cu_batch_token_offset=None,
        )
        return next_tokens, sampler_output
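As a quick illustration of the nucleus filtering that top_k_top_p_sampling performs before drawing next_tokens, here is the top-p step in NumPy for a single distribution. This is a sketch with made-up numbers, not the production op.

# Illustrative sketch only: nucleus (top-p) filtering for one probability vector.
import numpy as np

def top_p_filter(probs, top_p):
    order = np.argsort(-probs)                              # tokens by descending probability
    sorted_probs = probs[order]
    keep = np.cumsum(sorted_probs) - sorted_probs < top_p   # keep until the kept mass reaches top_p
    filtered = np.zeros_like(probs)
    filtered[order[keep]] = sorted_probs[keep]
    return filtered / filtered.sum()                        # renormalize before sampling

print(top_p_filter(np.array([0.5, 0.3, 0.15, 0.05]), top_p=0.8))  # [0.625 0.375 0. 0.]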
@@ -31,6 +31,18 @@ if current_platform.is_xpu():
        get_padding_offset,
        limit_thinking_content_length_v1,
        limit_thinking_content_length_v2,
        save_output,
        set_stop_value_multi_ends,
        speculate_clear_accept_nums,
        speculate_get_output_padding_offset,
        speculate_get_padding_offset,
        speculate_get_seq_lens_output,
        speculate_save_output,
        speculate_set_value_by_flags_and_idx,
        speculate_step_paddle,
        speculate_update_v3,
        step_paddle,
        update_inputs,
        update_inputs_v1,
    )

@@ -45,19 +57,53 @@ def xpu_pre_process(
    seq_lens_encoder: Optional[paddle.Tensor] = None,
    seq_lens_decoder: Optional[paddle.Tensor] = None,
    is_profiling: bool = False,
    forward_meta=None,
) -> XPUForwardMeta:
    """ """
    max_len = input_ids.shape[1]
    cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
    token_num = paddle.sum(seq_lens_this_time)

    (
        ids_remove_padding,
        cum_offsets,
        batch_id_per_token,
        cu_seqlens_q,
        cu_seqlens_k,
    ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)
    if use_speculate_method:
        (
            ids_remove_padding,
            cum_offsets,
            batch_id_per_token,
            cu_seqlens_q,
            cu_seqlens_k,
        ) = speculate_get_padding_offset(
            input_ids,
            draft_tokens,
            cum_offsets_now,
            token_num,
            seq_lens_this_time,
            seq_lens_encoder,
        )
        seq_lens_output = speculate_get_seq_lens_output(
            seq_lens_this_time,
            seq_lens_encoder,
            seq_lens_decoder,
        )
        if isinstance(seq_lens_output, list):
            seq_lens_output = seq_lens_output[0]
        output_token_num = paddle.sum(seq_lens_output)
        output_cum_offsets_tmp = paddle.cumsum(max_len - seq_lens_output, dtype="int32")
        output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset(
            output_cum_offsets_tmp,
            output_token_num,
            seq_lens_output,
            max_len,
        )
        share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
        share_inputs["output_padding_offset"].copy_(output_padding_offset, False)
    else:
        (
            ids_remove_padding,
            cum_offsets,
            batch_id_per_token,
            cu_seqlens_q,
            cu_seqlens_k,
        ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)

    share_inputs["ids_remove_padding"] = None  # set this after adjust batch
    share_inputs["cum_offsets"] = cum_offsets
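To make the offset bookkeeping above concrete, the tiny NumPy example below (with made-up sequence lengths) shows what cum_offsets_now and token_num evaluate to for a padded [batch, max_len] layout before the padding is stripped.

# Illustrative sketch only: cumulative padding offsets for three requests.
import numpy as np

max_len = 8
seq_lens_this_time = np.array([3, 1, 5], dtype=np.int32)

cum_offsets_now = np.cumsum(max_len - seq_lens_this_time)   # [ 5 12 15]: padding slots up to and including each request
token_num = int(seq_lens_this_time.sum())                   # 9 real tokens survive in ids_remove_padding
print(cum_offsets_now, token_num)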
@@ -173,11 +219,6 @@ def xpu_post_process_normal(
    line_break_id: int = None,
) -> None:
    """ """
    from fastdeploy.model_executor.ops.xpu import (
        save_output,
        set_stop_value_multi_ends,
        update_inputs,
    )

    if think_end_id > 0:
        limit_strategy = envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR
@@ -277,39 +318,110 @@ def xpu_post_process_normal(
        )


def xpu_post_process_specualate(
    model_output: ModelOutputData, save_each_rank: bool = False, skip_save_output: bool = False
):
    """"""
    speculate_update_v3(
        model_output.seq_lens_encoder,
        model_output.seq_lens_decoder,
        model_output.not_need_stop,
        model_output.draft_tokens,
        model_output.actual_draft_token_num,
        model_output.accept_tokens,
        model_output.accept_num,
        model_output.stop_flags,
        model_output.seq_lens_this_time,
        model_output.is_block_step,
        model_output.stop_nums,
    )
    if not skip_save_output:
        speculate_save_output(
            model_output.accept_tokens,
            model_output.accept_num,
            model_output.not_need_stop,
            model_output.mp_rank,
            save_each_rank,  # False
        )

    speculate_clear_accept_nums(model_output.accept_num, model_output.seq_lens_decoder)

    # Update pre_ids with the accepted tokens
    speculate_set_value_by_flags_and_idx(
        model_output.pre_ids,
        model_output.accept_tokens,
        model_output.accept_num,
        model_output.stop_flags,
        model_output.seq_lens_this_time,
        model_output.seq_lens_encoder,
        model_output.seq_lens_decoder,
        model_output.step_idx,
    )


def step_xpu(
    share_inputs: Dict[str, paddle.Tensor],
    block_size: int,
    enc_dec_block_num: int,
    speculative_decoding: bool,
    max_draft_token_num: int,
) -> None:
    """
    TODO(gongshaotian): normalize the name
    TODO(chenhuan09): support PD
    """
    from fastdeploy.model_executor.ops.xpu import step_paddle

    step_paddle(
        share_inputs["stop_flags"],
        share_inputs["seq_lens_this_time"],
        share_inputs["step_seq_lens_encoder"],
        share_inputs["seq_lens_encoder"],
        share_inputs["seq_lens_decoder"],
        share_inputs["block_tables"],
        share_inputs["encoder_block_lens"],
        share_inputs["is_block_step"],
        share_inputs["step_block_list"],
        share_inputs["step_lens"],
        share_inputs["recover_block_list"],
        share_inputs["recover_lens"],
        share_inputs["need_block_list"],
        share_inputs["need_block_len"],
        share_inputs["used_list_len"],
        share_inputs["free_list"],
        share_inputs["free_list_len"],
        share_inputs["input_ids"],
        share_inputs["pre_ids"],
        share_inputs["step_idx"],
        share_inputs["next_tokens"],
        share_inputs["first_token_ids"],
        block_size,
        enc_dec_block_num,
    )
    if speculative_decoding:
        speculate_step_paddle(
            share_inputs["stop_flags"],
            share_inputs["seq_lens_this_time"],
            share_inputs["step_seq_lens_encoder"],
            share_inputs["seq_lens_encoder"],
            share_inputs["seq_lens_decoder"],
            share_inputs["block_tables"],
            share_inputs["encoder_block_lens"],
            share_inputs["is_block_step"],
            share_inputs["step_block_list"],
            share_inputs["step_lens"],
            share_inputs["recover_block_list"],
            share_inputs["recover_lens"],
            share_inputs["need_block_list"],
            share_inputs["need_block_len"],
            share_inputs["used_list_len"],
            share_inputs["free_list"],
            share_inputs["free_list_len"],
            share_inputs["input_ids"],
            share_inputs["pre_ids"],
            share_inputs["step_idx"],
            share_inputs["next_tokens"],
            share_inputs["first_token_ids"],
            share_inputs["accept_num"],
            block_size,
            enc_dec_block_num,
            max_draft_token_num,
        )
    else:
        step_paddle(
            share_inputs["stop_flags"],
            share_inputs["seq_lens_this_time"],
            share_inputs["step_seq_lens_encoder"],
            share_inputs["seq_lens_encoder"],
            share_inputs["seq_lens_decoder"],
            share_inputs["block_tables"],
            share_inputs["encoder_block_lens"],
            share_inputs["is_block_step"],
            share_inputs["step_block_list"],
            share_inputs["step_lens"],
            share_inputs["recover_block_list"],
            share_inputs["recover_lens"],
            share_inputs["need_block_list"],
            share_inputs["need_block_len"],
            share_inputs["used_list_len"],
            share_inputs["free_list"],
            share_inputs["free_list_len"],
            share_inputs["input_ids"],
            share_inputs["pre_ids"],
            share_inputs["step_idx"],
            share_inputs["next_tokens"],
            share_inputs["first_token_ids"],
            block_size,
            enc_dec_block_num,
        )
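For intuition about what the speculative post-processing above tracks per request: only the first accept_num entries of accept_tokens are real, they extend pre_ids, and the decoder length advances by that count. The snippet below walks through one request with made-up values; the variable names mirror the buffers, but this is not the kernel.

# Illustrative sketch only: per-request bookkeeping after a verify step.
accept_tokens = [101, 57, 9, 0, 0]      # padded to max_draft_token_num + 1 slots
accept_num = 3                          # only the first 3 entries were accepted

pre_ids = [1, 2, 3]
pre_ids.extend(accept_tokens[:accept_num])   # history of every emitted token
seq_len_decoder_delta = accept_num           # KV cache / decoder length advances by the accepted count
print(pre_ids, seq_len_decoder_delta)        # [1, 2, 3, 101, 57, 9] 3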
@@ -340,7 +340,11 @@ class TokenProcessor:
        """

        if current_platform.is_xpu():
            from fastdeploy.model_executor.ops.xpu import get_output, get_output_ep
            from fastdeploy.model_executor.ops.xpu import (
                get_output,
                get_output_ep,
                speculate_get_output,
            )
        elif current_platform.is_iluvatar():
            from fastdeploy.model_executor.ops.iluvatar import get_output
        elif current_platform.is_gcu():
@@ -14,9 +14,12 @@
"""
speculative decoding module
"""
from fastdeploy.platforms import current_platform

from .base import Proposer
from .mtp import MTPProposer
from .ngram import NgramProposer

# XPU does not support the ngram proposer yet
if not current_platform.is_xpu():
    from .ngram import NgramProposer
__all__ = ["Proposer", "MTPProposer", "NgramProposer"]
@@ -34,21 +34,39 @@ from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import MTPSampler
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.models import ModelForCasualLM
from fastdeploy.model_executor.ops.gpu import (
    draft_model_postprocess,
    draft_model_preprocess,
    draft_model_update,
    eagle_get_hidden_states,
    eagle_get_self_hidden_states,
    hybrid_mtp_ngram,
    mtp_save_first_token,
    mtp_step_paddle,
    share_external_data,
    speculate_get_logits,
    speculate_save_output_topk,
    update_attn_mask_offsets,
)
from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding
from fastdeploy.platforms import current_platform

if current_platform.is_xpu():
    from fastdeploy.model_executor.ops.xpu import (
        draft_model_postprocess,
        draft_model_preprocess,
        draft_model_update,
        eagle_get_hidden_states,
        eagle_get_self_hidden_states,
        mtp_save_first_token,
        mtp_step_paddle,
        share_external_data,
    )
    from fastdeploy.model_executor.xpu_pre_and_post_process import (
        xpu_pre_process,
        xpu_process_output,
    )
else:
    from fastdeploy.model_executor.ops.gpu import (
        draft_model_postprocess,
        draft_model_preprocess,
        draft_model_update,
        eagle_get_hidden_states,
        eagle_get_self_hidden_states,
        hybrid_mtp_ngram,
        mtp_save_first_token,
        mtp_step_paddle,
        share_external_data,
        speculate_get_logits,
        speculate_save_output_topk,
        update_attn_mask_offsets,
    )
    from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding

from .base import Proposer
@@ -79,6 +97,15 @@ class MTPProposer(Proposer):

        # [mixed, prefill, decoder]
        self.role = self.scheduler_config.splitwise_role
        if current_platform.is_xpu():
            self.role = "mixed"

        if current_platform.is_xpu():
            self._propose = self._propose_xpu
        elif current_platform.is_cuda():
            self._propose = self._propose_cuda
        else:
            raise RuntimeError("Unsupported platform.")

        self.sampler = MTPSampler(fd_config)
        self._init_model_inputs()
@@ -92,7 +119,7 @@ class MTPProposer(Proposer):
        self._initialize_attn_backend()

        # Forward meta stores the global meta information of the forward pass
        self.forward_meta: ForwardMeta = None
        self.forward_meta = None

    def _update_mtp_config(self, main_model):
        """
@@ -166,7 +193,7 @@ class MTPProposer(Proposer):
            and hasattr(self.quant_config, "kv_cache_quant_type")
            and self.quant_config.kv_cache_quant_type is not None
        ):
            cache_type = "uint8"
            cache_type = self._get_cache_type()
            kv_cache_quant_type = self.quant_config.kv_cache_quant_type

        # Get kv cache shape
@@ -220,7 +247,7 @@ class MTPProposer(Proposer):
        self.model_inputs["caches"] = list(self.cache_kvs.values())
        for value in self.cache_kvs.values():
            del value
        paddle.device.cuda.empty_cache()
        self._empty_cache()

    def _initialize_attn_backend(
        self,
@@ -245,9 +272,14 @@ class MTPProposer(Proposer):
        self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
            self.target_model_inputs["decoder_tile_ids_per_batch"]
        )
        self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
            self.target_model_inputs["decoder_num_blocks_cpu"]
        ).pin_memory()
        if current_platform.is_xpu():
            self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
                self.target_model_inputs["decoder_num_blocks_cpu"]
            ).cpu()
        else:
            self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
                self.target_model_inputs["decoder_num_blocks_cpu"]
            ).pin_memory()
        self.model_inputs["decoder_num_blocks_device"] = paddle.zeros_like(
            self.target_model_inputs["decoder_num_blocks_device"]
        )
@@ -669,6 +701,36 @@ class MTPProposer(Proposer):

        self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.draft_model_use_cudagraph

    def _initialize_forward_meta_xpu(self):

        self.forward_meta.decoder_batch_ids = (self.model_inputs["decoder_batch_ids"],)
        self.forward_meta.decoder_tile_ids_per_batch = (self.model_inputs["decoder_tile_ids_per_batch"],)
        self.forward_meta.decoder_num_blocks_cpu = (self.model_inputs["decoder_num_blocks_cpu"],)
        self.forward_meta.decoder_num_blocks_device = (self.model_inputs["decoder_num_blocks_device"],)
        self.forward_meta.decoder_chunk_size_device = (self.model_inputs["decoder_chunk_size_device"],)
        self.forward_meta.max_len_tensor_cpu = (self.model_inputs["max_len_tensor_cpu"],)

        self.forward_meta.encoder_batch_ids = (self.model_inputs["encoder_batch_ids"],)
        self.forward_meta.encoder_tile_ids_per_batch = (self.model_inputs["encoder_tile_ids_per_batch"],)
        self.forward_meta.encoder_num_blocks_x_cpu = (self.model_inputs["encoder_num_blocks_x_cpu"],)
        self.forward_meta.kv_batch_ids = (self.model_inputs["kv_batch_ids"],)
        self.forward_meta.kv_tile_ids_per_batch = (self.model_inputs["kv_tile_ids_per_batch"],)
        self.forward_meta.kv_num_blocks_x_cpu = (self.model_inputs["kv_num_blocks_x_cpu"],)
        self.forward_meta.pos_emb_type = "NORMAL"
        self.forward_meta.attn_backend = self.attn_backends[0]

        # Initialize attention metadata
        for attn_backend in self.attn_backends:
            attn_backend.init_attention_metadata(self.forward_meta)

        # Mixed EP in a single node
        if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
            only_decode_batch_list = []
            prefill_exists = self.exist_prefill()
            paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
            only_decode_batch = all(only_decode_batch_list)
            self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"

    def exist_prefill(self):
        """
        check whether prefill stage exist
@@ -682,7 +744,7 @@ class MTPProposer(Proposer):
        """
        Prepare MTP inputs
        """
        use_v1_cache_scheduler = envs.ENABLE_V1_KVCACHE_SCHEDULER
        use_v1_cache_scheduler = bool(envs.ENABLE_V1_KVCACHE_SCHEDULER)
        draft_model_preprocess(
            self.model_inputs["draft_tokens"],
            self.model_inputs["input_ids"],
@@ -767,7 +829,7 @@ class MTPProposer(Proposer):
            self.model_inputs["step_idx"],
        )

    def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
    def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
        """
        Main process for MTP inference.
        Args:
@@ -928,6 +990,96 @@ class MTPProposer(Proposer):
            if hasattr(self.model, "empty_input_forward"):
                self.model.empty_input_forward()

    def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
        """
        Main process for MTP inference.
        Args:
            step_use_cudagraph: bool
                Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP.
        """
        for substep in range(self.num_model_steps):
            if self.model_inputs["not_need_stop"]:
                self.model_inputs["substep"] = substep
                # Remove padding
                self.forward_meta = xpu_pre_process(
                    self.model_inputs["input_ids"],
                    self.model_inputs["seq_lens_this_time"],
                    self.model_inputs,
                    True,
                    self.cache_config.block_size,
                    self.model_inputs["draft_tokens"],
                    self.model_inputs["seq_lens_encoder"],
                    self.model_inputs["seq_lens_decoder"],
                )
                self._initialize_forward_meta_xpu()
                # Get sampling metadata
                self.sampling_metadata = SamplingMetadata(
                    temperature=self.model_inputs["temperature"],
                    top_p=self.model_inputs["top_p"],
                    top_k=self.model_inputs["top_k"],
                    seed=self.model_inputs["infer_seed"],
                    step_idx=self.model_inputs["step_idx"],
                    pre_token_ids=self.model_inputs["pre_ids"],
                    frequency_penalties=self.model_inputs["frequency_score"],
                    presence_penalties=self.model_inputs["presence_score"],
                    repetition_penalties=self.model_inputs["penalty_score"],
                    min_dec_lens=self.model_inputs["min_dec_len"],
                    bad_words_token_ids=self.model_inputs["bad_tokens"],
                    eos_token_ids=self.model_inputs["eos_token_id"],
                    max_num_logprobs=20 if self.enable_logprob else None,
                    temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"],
                    top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
                    share_inputs=self.model_inputs,
                )

                if self.num_model_steps > 1:
                    self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])

                model_output = self.model(
                    ids_remove_padding=self.model_inputs["ids_remove_padding"],
                    previous_hidden_states=self.model_inputs["target_hidden_states"],
                    forward_meta=self.forward_meta,
                )
                hidden_states = xpu_process_output(
                    model_output, self.model_inputs["cum_offsets"], self.forward_meta, self.model_inputs
                )
                # 4. Compute logits, Sample
                logits = self.model.compute_logits(hidden_states)
                sampled_token_ids, sampler_output = self.sampler(
                    logits,
                    self.sampling_metadata,
                    self.max_model_len,
                    self.model_inputs,
                )

                if substep == 0 and sampler_output.logprobs_tensors is not None:
                    real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
                    speculate_save_output_topk(
                        sampler_output.sampled_token_ids,
                        sampler_output.logprobs_tensors.logprob_token_ids,
                        sampler_output.logprobs_tensors.logprobs,
                        sampler_output.logprobs_tensors.selected_token_ranks,
                        self.model_inputs["batch_token_num"][:real_bsz],
                        self.model_inputs["cu_batch_token_offset"][:real_bsz],
                        self.model_inputs["not_need_stop"],
                        4,  # mtype
                        self.local_rank,
                    )

                if self.parallel_config.tensor_parallel_size > 1:
                    paddle.distributed.broadcast(
                        sampled_token_ids,
                        self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
                        group=self.parallel_config.tp_group,
                    )

                self._post_process(sampled_token_ids)
                if substep != self.num_model_steps - 1:
                    self._get_self_hidden_states(hidden_states)
            else:
                if hasattr(self.model, "empty_input_forward"):
                    self.model.empty_input_forward()

    def _get_self_hidden_states(self, hidden_states):
        target_hidden_states = eagle_get_self_hidden_states(
            hidden_states,
@@ -1044,3 +1196,21 @@ class MTPProposer(Proposer):
        self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer
        self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
        return

    def _empty_cache(self):
        if current_platform.is_cuda():
            paddle.device.cuda.empty_cache()
        elif current_platform.is_xpu():
            paddle.device.xpu.empty_cache()
        else:
            raise NotImplementedError

    def _get_cache_type(self):
        cache_type = None
        if current_platform.is_cuda():
            cache_type = "uint8"
        elif current_platform.is_xpu():
            cache_type = "int8"
        else:
            raise NotImplementedError
        return cache_type
@@ -39,7 +39,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
)
from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope_3d
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.model_executor.ops.xpu import (
@@ -49,12 +49,14 @@ from fastdeploy.model_executor.ops.xpu import (
    set_data_ipc,
    share_external_data,
)
from fastdeploy.model_executor.xpu_pre_and_post_process import (  # xpu_post_process_specualate, # TODO(chenhuan09): add xpu_post_process_specualate
from fastdeploy.model_executor.xpu_pre_and_post_process import (
    step_xpu,
    xpu_post_process_normal,
    xpu_post_process_specualate,
    xpu_pre_process,
    xpu_process_output,
)
from fastdeploy.spec_decode import MTPProposer
from fastdeploy.utils import get_logger
from fastdeploy.worker.model_runner_base import ModelRunnerBase
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
@@ -102,9 +104,20 @@ class XPUModelRunner(ModelRunnerBase):
            "fused_gemm_epilogue",
        ]

        self.device_id = device_id
        self.speculative_method = self.fd_config.speculative_config.method
        self.speculative_decoding = self.speculative_method is not None

        # used by SamplingMetadata
        self.enable_logprob = False  # fd_config.model_config.enable_logprob
        self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop

        # Sampler
        # TODO(lilujia): sync with GPU
        self.sampler = Sampler(fd_config)
        if not self.speculative_decoding:
            self.sampler = Sampler(fd_config)
        else:
            self.sampler = SpeculativeSampler(fd_config)

        # Lazy initialize kv cache after model loading
        # self.kv_caches: list[paddle.Tensor] = []
@@ -143,7 +156,7 @@ class XPUModelRunner(ModelRunnerBase):
        else:
            return 0

    def insert_tasks_v1(self, req_dicts: List[Request]):
    def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
        """
        Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
        req_dict: A list of Request dict
@@ -340,7 +353,10 @@ class XPUModelRunner(ModelRunnerBase):
        if has_prefill_task or has_decode_task:
            self.share_inputs["not_need_stop"][0] = True

    def insert_prefill_inputs(self, req_dicts: List[Request]):
        if self.speculative_method in ["mtp"]:
            self.proposer.insert_tasks_v1(req_dicts, num_running_requests)

    def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
        """Process inputs for prefill tasks and update share_inputs buffer"""
        # NOTE(luotingdan): Set environment variable of prefill node
        if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
@@ -480,6 +496,15 @@ class XPUModelRunner(ModelRunnerBase):

        self.share_inputs["not_need_stop"][0] = True

        if self.speculative_method in ["mtp"]:
            self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = get_attr_from_request(
                request, "temp_scaled_logprobs", False
            )
            self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = get_attr_from_request(
                request, "top_p_normalized_logprobs", False
            )
            self.proposer.insert_prefill_inputs(req_dicts, num_running_requests)

    def _init_share_inputs(self, max_num_seqs: int):
        """Initialize all share buffers for model inputs.
        Note: In the future, we may abandon share buffers.
@@ -558,6 +583,15 @@ class XPUModelRunner(ModelRunnerBase):
        self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.share_inputs["system_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int32")

        self.share_inputs["ids_remove_padding"] = paddle.full(
            [max_num_seqs * self.model_config.max_model_len],
            0,
            dtype="int64",
        )
        self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")

        # Initialize thinking related buffers
        self.share_inputs["max_think_lens"] = paddle.full(shape=[max_num_seqs, 1], fill_value=-1, dtype="int32")
        self.share_inputs["limit_think_status"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
@@ -629,6 +663,56 @@ class XPUModelRunner(ModelRunnerBase):
            )
            self.share_inputs["image_features"] = None

        if self.speculative_decoding:
            max_draft_token_num = self.speculative_config.num_speculative_tokens
            self.share_inputs["input_ids_cpu"] = paddle.full(
                shape=[max_num_seqs, self.model_config.max_model_len],
                fill_value=1,
                dtype="int64",
            ).cpu()
            self.share_inputs["accept_tokens"] = paddle.full(
                shape=[max_num_seqs, max_draft_token_num + 1],
                fill_value=0,
                dtype="int64",
            )
            self.share_inputs["accept_num"] = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32")
            self.share_inputs["draft_tokens"] = paddle.full(
                shape=[max_num_seqs, max_draft_token_num + 1],
                fill_value=0,
                dtype="int64",
            )

            self.share_inputs["actual_draft_token_num"] = paddle.full(
                shape=[max_num_seqs],
                fill_value=max_draft_token_num,
                dtype="int32",
            )
            self.share_inputs["output_cum_offsets"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
            self.share_inputs["output_padding_offset"] = paddle.full(
                shape=[max_num_seqs * (max_draft_token_num + 1)],
                fill_value=0,
                dtype="int32",
            )
            # For V1_KVCACHE_SCHEDULER
            self.share_inputs["step_draft_tokens"] = paddle.full(
                shape=[max_num_seqs, max_draft_token_num + 1],
                fill_value=0,
                dtype="int64",
            )
            self.share_inputs["step_seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
            self.share_inputs["temp_scaled_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype=bool)
            self.share_inputs["top_p_normalized_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype=bool)
            # For MTP Logprob
            self.share_inputs["draft_logits"] = paddle.full(
                [max_num_seqs * (self.speculative_config.num_speculative_tokens + 1), self.model_config.vocab_size],
                -1,
                dtype="float32",
            )
            self.share_inputs["cu_batch_token_offset"] = paddle.full(
                shape=[max_num_seqs + 1], fill_value=0, dtype="int32"
            )
        self.max_num_seqs = max_num_seqs

    def _prepare_inputs(self, is_dummy_run=False) -> None:
        """Prepare the model inputs"""
        if envs.ENABLE_V1_KVCACHE_SCHEDULER and not is_dummy_run:
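To make the sizing above concrete: with assumed values max_num_seqs = 4 and num_speculative_tokens = 3, the speculative buffers come out as below. This is a worked example, not additional runner code.

# Illustrative sketch only: buffer shapes for assumed configuration values.
max_num_seqs = 4
max_draft_token_num = 3                                                   # num_speculative_tokens

accept_tokens_shape = [max_num_seqs, max_draft_token_num + 1]             # [4, 4]: up to 3 drafts + 1 bonus token per request
output_padding_offset_shape = [max_num_seqs * (max_draft_token_num + 1)]  # [16]: one slot per possible output token
print(accept_tokens_shape, output_padding_offset_shape)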
@@ -646,9 +730,9 @@ class XPUModelRunner(ModelRunnerBase):
            self.share_inputs["input_ids"],
            self.share_inputs["seq_lens_this_time"],
            self.share_inputs,
            use_speculate_method=False,
            use_speculate_method=self.speculative_decoding,
            block_size=self.cache_config.block_size,
            draft_tokens=None,
            draft_tokens=self.share_inputs["draft_tokens"] if self.speculative_decoding else None,
            seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
            seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
            is_profiling=is_dummy_run,
@@ -696,6 +780,7 @@ class XPUModelRunner(ModelRunnerBase):
        # 2. Load lora model

        # 3. Load drafter model (for speculative decoding)
        self._init_speculative_proposer()

    def get_model(self) -> nn.Layer:
        """Get current model"""
@@ -793,6 +878,44 @@ class XPUModelRunner(ModelRunnerBase):
        )
        head_dim = self.model_config.head_dim

        if self.speculative_decoding:
            # Initialize AttentionBackend buffers
            encoder_block_shape_q = 64
            decoder_block_shape_q = 16
            decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
            decode_max_tile_size = self.max_num_seqs * np.ceil(
                (decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
            )

            group_size = np.ceil(num_heads / self.model_config.kv_num_heads)
            encode_max_tile_size = self.scheduler_config.max_num_seqs * np.ceil(
                (self.model_config.max_model_len * group_size) / encoder_block_shape_q
            )
            kv_max_tile_size = self.scheduler_config.max_num_seqs * np.ceil(
                self.model_config.max_model_len / self.fd_config.cache_config.block_size
            )
            self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
            self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full(
                [int(decode_max_tile_size)], 0, dtype="int32"
            )
            self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
            # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor,
            # adapted to cudagraph.
            self.share_inputs["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32")
            self.share_inputs["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32")
            self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()

            self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
            self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full(
                [int(encode_max_tile_size)], 0, dtype="int32"
            )
            self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()

            self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
            self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
            self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
            self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu()

        # Get the attention backend
        attn_cls = get_attention_backend()
        attn_backend = attn_cls(
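For a sense of scale, here is the decode tile-size arithmetic from the block above evaluated with assumed values (num_heads = 32, kv_num_heads = 8, num_speculative_tokens = 3, max_num_seqs = 4); the numbers are illustrative only.

# Illustrative sketch only: decode_max_tile_size with assumed configuration values.
import numpy as np

num_heads, kv_num_heads = 32, 8
decoder_block_shape_q = 16
decoder_step_token_num = 3 + 1                                   # num_speculative_tokens + 1
group_size = np.ceil(num_heads / kv_num_heads)                   # 4.0
decode_max_tile_size = 4 * np.ceil(decoder_step_token_num * group_size / decoder_block_shape_q)
print(decode_max_tile_size)                                      # 4.0 -> decoder_batch_ids gets 4 int32 slots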
@@ -851,12 +974,38 @@ class XPUModelRunner(ModelRunnerBase):
        """
        self._dummy_prefill_inputs(num_tokens, batch_size)

        if self.speculative_method in ["mtp"]:
            self.proposer.dummy_prefill_inputs(
                num_tokens=num_tokens,
                batch_size=batch_size,
                expected_decode_len=1,
            )

        while True:
            self.execute_model(is_dummy_run=True)

            if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
                break

    def _init_speculative_proposer(self):
        """
        Init speculative proposer
        """
        if self.speculative_method == "ngram":
            # XPU does not support the ngram proposer yet
            # self.proposer = NgramProposer(self.fd_config)
            self.proposer = None
        elif self.speculative_method == "mtp":
            self.proposer = MTPProposer(
                self.fd_config,
                self.get_model(),
                self.local_rank,
                self.device_id,
                self.share_inputs,
            )
        else:
            self.proposer = None

    def _set_debug_level(
        self, debug_level: int = 0x1, model_forward_batch: Optional[List[Request]] = None, is_dummy_run: bool = False
    ) -> None:
@@ -941,7 +1090,16 @@ class XPUModelRunner(ModelRunnerBase):
        )
        # 4. Compute logits, Sample
        logits = self.model.compute_logits(hidden_states)
        sampler_output = self.sampler(logits, self.sampling_metadata)
        sampler_output = None
        if not self.speculative_decoding:
            sampler_output = self.sampler(logits, self.sampling_metadata)
        else:
            self.sampler(
                logits,
                self.sampling_metadata,
                self.model_config.max_model_len,
                self.share_inputs,
            )

        # 5. Speculative decode
@@ -961,26 +1119,36 @@ class XPUModelRunner(ModelRunnerBase):
            seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
            is_block_step=self.share_inputs["is_block_step"],
            # Speculative decoding
            full_hidden_states=None,
            full_hidden_states=model_output if self.speculative_decoding else None,
            msg_queue_id=self.parallel_config.msg_queue_id,
            mp_rank=self.local_rank,
            use_ep=self.parallel_config.use_ep,
            draft_tokens=None,
            actual_draft_token_num=None,
            accept_tokens=None,
            accept_num=None,
            draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
            actual_draft_token_num=(
                self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None
            ),
            accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
            accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
            stop_token_ids=self.share_inputs["stop_seqs"],
            stop_seqs_len=self.share_inputs["stop_seqs_len"],
        )
        xpu_post_process_normal(
            sampled_token_ids=sampler_output.sampled_token_ids,
            model_output=model_output_data,
            share_inputs=self.share_inputs,
            block_size=self.cache_config.block_size,
            skip_save_output=is_dummy_run,
            think_end_id=self.model_config.think_end_id,
            line_break_id=self.model_config.line_break_id,
        )
        if self.speculative_decoding:
            # base model post process
            xpu_post_process_specualate(model_output_data, False, is_dummy_run)
        else:
            xpu_post_process_normal(
                sampled_token_ids=sampler_output.sampled_token_ids,
                model_output=model_output_data,
                share_inputs=self.share_inputs,
                block_size=self.cache_config.block_size,
                skip_save_output=is_dummy_run,
                think_end_id=self.model_config.think_end_id,
                line_break_id=self.model_config.line_break_id,
            )

        # draft model propose
        if self.speculative_method == "mtp":
            self.proposer.run(full_hidden_states=model_output)

        # 7. Update 'infer_seed' and step_paddle()
        self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
@@ -989,6 +1157,8 @@ class XPUModelRunner(ModelRunnerBase):
            self.share_inputs,
            self.cache_config.block_size,
            self.cache_config.enc_dec_block_num,
            self.speculative_decoding,
            self.speculative_config.num_speculative_tokens,
        )

        if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query":
@@ -1013,6 +1183,9 @@ class XPUModelRunner(ModelRunnerBase):
        self.num_gpu_blocks = self.cache_config.total_block_num
        self.initialize_kv_cache(profile=True)

        if self.speculative_method in ["mtp"]:
            self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True)

        self._dummy_run(
            num_tokens=int(self.scheduler_config.max_num_batched_tokens),
            batch_size=min(self.scheduler_config.max_num_seqs, 1),
@@ -167,9 +167,9 @@ class XpuWorker(WorkerBase):
        and workers and modelrunners should not perceive it.
        """
        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
            self.model_runner.insert_tasks_v1(req_dicts=req_dicts)
            self.model_runner.insert_tasks_v1(req_dicts=req_dicts, num_running_requests=num_running_requests)
        else:
            self.model_runner.insert_prefill_inputs(req_dicts=req_dicts)
            self.model_runner.insert_prefill_inputs(req_dicts=req_dicts, num_running_requests=num_running_requests)

    def graph_optimize_and_warm_up_model(self) -> None:
        """