mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[xpu] support mtp for xpu(mix) (#5274)
* [XPU] support kernel for mtp(base) * [XPU] support kernel for mtp(base) * format * format * format * fix gather next token * fix step && add test * fix * mv pre/post process * add adjust batch / gather next token for mtp * fix code style * fix mtp kernel name * fix mtp kernel test * mv xpu pre/post process * mv xpu pre/post process * [xpu] support mtp * fix code style
This commit is contained in:
@@ -34,21 +34,39 @@ from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
|
||||
from fastdeploy.model_executor.layers.sample.sampler import MTPSampler
|
||||
from fastdeploy.model_executor.model_loader import get_model_loader
|
||||
from fastdeploy.model_executor.models import ModelForCasualLM
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
draft_model_postprocess,
|
||||
draft_model_preprocess,
|
||||
draft_model_update,
|
||||
eagle_get_hidden_states,
|
||||
eagle_get_self_hidden_states,
|
||||
hybrid_mtp_ngram,
|
||||
mtp_save_first_token,
|
||||
mtp_step_paddle,
|
||||
share_external_data,
|
||||
speculate_get_logits,
|
||||
speculate_save_output_topk,
|
||||
update_attn_mask_offsets,
|
||||
)
|
||||
from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
if current_platform.is_xpu():
|
||||
from fastdeploy.model_executor.ops.xpu import (
|
||||
draft_model_postprocess,
|
||||
draft_model_preprocess,
|
||||
draft_model_update,
|
||||
eagle_get_hidden_states,
|
||||
eagle_get_self_hidden_states,
|
||||
mtp_save_first_token,
|
||||
mtp_step_paddle,
|
||||
share_external_data,
|
||||
)
|
||||
from fastdeploy.model_executor.xpu_pre_and_post_process import (
|
||||
xpu_pre_process,
|
||||
xpu_process_output,
|
||||
)
|
||||
else:
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
draft_model_postprocess,
|
||||
draft_model_preprocess,
|
||||
draft_model_update,
|
||||
eagle_get_hidden_states,
|
||||
eagle_get_self_hidden_states,
|
||||
hybrid_mtp_ngram,
|
||||
mtp_save_first_token,
|
||||
mtp_step_paddle,
|
||||
share_external_data,
|
||||
speculate_get_logits,
|
||||
speculate_save_output_topk,
|
||||
update_attn_mask_offsets,
|
||||
)
|
||||
from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding
|
||||
|
||||
from .base import Proposer
|
||||
|
||||
@@ -79,6 +97,15 @@ class MTPProposer(Proposer):
|
||||
|
||||
# [mixed, prefill, decoder]
|
||||
self.role = self.scheduler_config.splitwise_role
|
||||
if current_platform.is_xpu():
|
||||
self.role = "mixed"
|
||||
|
||||
if current_platform.is_xpu():
|
||||
self._propose = self._propose_xpu
|
||||
elif current_platform.is_cuda():
|
||||
self._propose = self._propose_cuda
|
||||
else:
|
||||
raise RuntimeError("Unsupported platform.")
|
||||
|
||||
self.sampler = MTPSampler(fd_config)
|
||||
self._init_model_inputs()
|
||||
@@ -92,7 +119,7 @@ class MTPProposer(Proposer):
|
||||
self._initialize_attn_backend()
|
||||
|
||||
# Forward meta store the global meta information of the forward
|
||||
self.forward_meta: ForwardMeta = None
|
||||
self.forward_meta = None
|
||||
|
||||
def _update_mtp_config(self, main_model):
|
||||
"""
|
||||
@@ -166,7 +193,7 @@ class MTPProposer(Proposer):
|
||||
and hasattr(self.quant_config, "kv_cache_quant_type")
|
||||
and self.quant_config.kv_cache_quant_type is not None
|
||||
):
|
||||
cache_type = "uint8"
|
||||
cache_type = self._get_cache_type()
|
||||
kv_cache_quant_type = self.quant_config.kv_cache_quant_type
|
||||
|
||||
# Get kv cache shape
|
||||
@@ -220,7 +247,7 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["caches"] = list(self.cache_kvs.values())
|
||||
for value in self.cache_kvs.values():
|
||||
del value
|
||||
paddle.device.cuda.empty_cache()
|
||||
self._empty_cache()
|
||||
|
||||
def _initialize_attn_backend(
|
||||
self,
|
||||
@@ -245,9 +272,14 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
|
||||
self.target_model_inputs["decoder_tile_ids_per_batch"]
|
||||
)
|
||||
self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
|
||||
self.target_model_inputs["decoder_num_blocks_cpu"]
|
||||
).pin_memory()
|
||||
if current_platform.is_xpu():
|
||||
self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
|
||||
self.target_model_inputs["decoder_num_blocks_cpu"]
|
||||
).cpu()
|
||||
else:
|
||||
self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
|
||||
self.target_model_inputs["decoder_num_blocks_cpu"]
|
||||
).pin_memory()
|
||||
self.model_inputs["decoder_num_blocks_device"] = paddle.zeros_like(
|
||||
self.target_model_inputs["decoder_num_blocks_device"]
|
||||
)
|
||||
@@ -669,6 +701,36 @@ class MTPProposer(Proposer):
|
||||
|
||||
self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.draft_model_use_cudagraph
|
||||
|
||||
def _initialize_forward_meta_xpu(self):
    """
    Populate ``self.forward_meta`` with the draft model's XPU scheduling
    buffers and initialize attention metadata for every attention backend.

    For mixed-role EP runs this also flips the MoE phase to "decode" when
    no rank currently has a prefill batch, and back to "prefill" otherwise.
    """
    # Decoder-side scheduling buffers consumed by the attention backends.
    # NOTE(review): the original wrapped each right-hand side in a
    # 1-element tuple (`= (tensor,)`); the backends expect the tensors
    # themselves, so the stray trailing commas/parens are removed.
    self.forward_meta.decoder_batch_ids = self.model_inputs["decoder_batch_ids"]
    self.forward_meta.decoder_tile_ids_per_batch = self.model_inputs["decoder_tile_ids_per_batch"]
    self.forward_meta.decoder_num_blocks_cpu = self.model_inputs["decoder_num_blocks_cpu"]
    self.forward_meta.decoder_num_blocks_device = self.model_inputs["decoder_num_blocks_device"]
    self.forward_meta.decoder_chunk_size_device = self.model_inputs["decoder_chunk_size_device"]
    self.forward_meta.max_len_tensor_cpu = self.model_inputs["max_len_tensor_cpu"]

    # Encoder- and KV-side scheduling buffers.
    self.forward_meta.encoder_batch_ids = self.model_inputs["encoder_batch_ids"]
    self.forward_meta.encoder_tile_ids_per_batch = self.model_inputs["encoder_tile_ids_per_batch"]
    self.forward_meta.encoder_num_blocks_x_cpu = self.model_inputs["encoder_num_blocks_x_cpu"]
    self.forward_meta.kv_batch_ids = self.model_inputs["kv_batch_ids"]
    self.forward_meta.kv_tile_ids_per_batch = self.model_inputs["kv_tile_ids_per_batch"]
    self.forward_meta.kv_num_blocks_x_cpu = self.model_inputs["kv_num_blocks_x_cpu"]
    self.forward_meta.pos_emb_type = "NORMAL"
    self.forward_meta.attn_backend = self.attn_backends[0]

    # Initialize attention meta data
    for attn_backend in self.attn_backends:
        attn_backend.init_attention_metadata(self.forward_meta)

    # Mix ep in single node: all ranks agree on whether any prefill work
    # remains; only when none does can MoE run in pure decode phase.
    if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
        only_decode_batch_list = []
        prefill_exists = self.exist_prefill()
        paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
        only_decode_batch = all(only_decode_batch_list)
        self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
|
||||
def exist_prefill(self):
|
||||
"""
|
||||
check whether prefill stage exist
|
||||
@@ -682,7 +744,7 @@ class MTPProposer(Proposer):
|
||||
"""
|
||||
Prepare MTP inputs
|
||||
"""
|
||||
use_v1_cache_scheduler = envs.ENABLE_V1_KVCACHE_SCHEDULER
|
||||
use_v1_cache_scheduler = bool(envs.ENABLE_V1_KVCACHE_SCHEDULER)
|
||||
draft_model_preprocess(
|
||||
self.model_inputs["draft_tokens"],
|
||||
self.model_inputs["input_ids"],
|
||||
@@ -767,7 +829,7 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["step_idx"],
|
||||
)
|
||||
|
||||
def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
|
||||
def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
|
||||
"""
|
||||
Main process for MTP inference.
|
||||
Args:
|
||||
@@ -928,6 +990,96 @@ class MTPProposer(Proposer):
|
||||
if hasattr(self.model, "empty_input_forward"):
|
||||
self.model.empty_input_forward()
|
||||
|
||||
def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
    """
    Main process for MTP inference.

    Runs up to ``self.num_model_steps`` draft substeps on XPU: each substep
    strips padding, rebuilds forward meta, runs the draft model, samples the
    next draft tokens, and post-processes them for the target model.

    Args:
        step_use_cudagraph: bool
            Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP.
        is_dummy_run: bool
            Accepted for signature parity with the CUDA path; not read here.
    """
    for substep in range(self.num_model_steps):
        if self.model_inputs["not_need_stop"]:
            self.model_inputs["substep"] = substep
            # Remove padding
            self.forward_meta = xpu_pre_process(
                self.model_inputs["input_ids"],
                self.model_inputs["seq_lens_this_time"],
                self.model_inputs,
                True,
                self.cache_config.block_size,
                self.model_inputs["draft_tokens"],
                self.model_inputs["seq_lens_encoder"],
                self.model_inputs["seq_lens_decoder"],
            )
            self._initialize_forward_meta_xpu()
            # Get sampling metadata
            self.sampling_metadata = SamplingMetadata(
                temperature=self.model_inputs["temperature"],
                top_p=self.model_inputs["top_p"],
                top_k=self.model_inputs["top_k"],
                seed=self.model_inputs["infer_seed"],
                step_idx=self.model_inputs["step_idx"],
                pre_token_ids=self.model_inputs["pre_ids"],
                frequency_penalties=self.model_inputs["frequency_score"],
                presence_penalties=self.model_inputs["presence_score"],
                repetition_penalties=self.model_inputs["penalty_score"],
                min_dec_lens=self.model_inputs["min_dec_len"],
                bad_words_token_ids=self.model_inputs["bad_tokens"],
                eos_token_ids=self.model_inputs["eos_token_id"],
                # 20 top logprobs when logprob output is on; None disables it.
                max_num_logprobs=20 if self.enable_logprob else None,
                temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"],
                top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
                share_inputs=self.model_inputs,
            )

            # Multi-step drafting mutates seq_lens_this_time; snapshot it so
            # the update step can see the pre-forward lengths.
            if self.num_model_steps > 1:
                self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])

            model_output = self.model(
                ids_remove_padding=self.model_inputs["ids_remove_padding"],
                previous_hidden_states=self.model_inputs["target_hidden_states"],
                forward_meta=self.forward_meta,
            )
            hidden_states = xpu_process_output(
                model_output, self.model_inputs["cum_offsets"], self.forward_meta, self.model_inputs
            )
            # 4. Compute logits, Sample
            logits = self.model.compute_logits(hidden_states)
            sampled_token_ids, sampler_output = self.sampler(
                logits,
                self.sampling_metadata,
                self.max_model_len,
                self.model_inputs,
            )

            # Persist top-k logprobs only for the first substep's tokens.
            if substep == 0 and sampler_output.logprobs_tensors is not None:
                real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
                speculate_save_output_topk(
                    sampler_output.sampled_token_ids,
                    sampler_output.logprobs_tensors.logprob_token_ids,
                    sampler_output.logprobs_tensors.logprobs,
                    sampler_output.logprobs_tensors.selected_token_ranks,
                    self.model_inputs["batch_token_num"][:real_bsz],
                    self.model_inputs["cu_batch_token_offset"][:real_bsz],
                    self.model_inputs["not_need_stop"],
                    4,  # mtype
                    self.local_rank,
                )

            # Keep all TP ranks in sync with the sampling rank's tokens.
            if self.parallel_config.tensor_parallel_size > 1:
                paddle.distributed.broadcast(
                    sampled_token_ids,
                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
                    group=self.parallel_config.tp_group,
                )

            self._post_process(sampled_token_ids)
            # Every substep except the last feeds its own hidden states into
            # the next draft step.
            if substep != self.num_model_steps - 1:
                self._get_self_hidden_states(hidden_states)
        else:
            # Nothing left to generate; run an empty forward if the model
            # supports it (keeps collective ops in lockstep across ranks).
            if hasattr(self.model, "empty_input_forward"):
                self.model.empty_input_forward()
|
||||
|
||||
def _get_self_hidden_states(self, hidden_states):
|
||||
target_hidden_states = eagle_get_self_hidden_states(
|
||||
hidden_states,
|
||||
@@ -1044,3 +1196,21 @@ class MTPProposer(Proposer):
|
||||
self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer
|
||||
self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
|
||||
return
|
||||
|
||||
def _empty_cache(self):
    """Release cached device memory on the current platform.

    Dispatches to the CUDA or XPU allocator cache-release call; any other
    platform is unsupported here.
    """
    if current_platform.is_cuda():
        paddle.device.cuda.empty_cache()
        return
    if current_platform.is_xpu():
        paddle.device.xpu.empty_cache()
        return
    raise NotImplementedError
||||
|
||||
def _get_cache_type(self):
    """Return the KV-cache dtype string for the current platform.

    "uint8" on CUDA, "int8" on XPU; any other platform is unsupported.
    """
    if current_platform.is_cuda():
        return "uint8"
    if current_platform.is_xpu():
        return "int8"
    raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user