mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Feature][MTP] Support cacheKV transfer in per_chunk mode (#2890)
* support chunk_prefill both normal and speculative_decoding(mtp) * optimize pd-disaggregation config * fix bug
This commit is contained in:
@@ -26,7 +26,7 @@ from paddle.nn.functional.flash_attention import flash_attn_unpadded
|
||||
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
get_block_shape_and_split_kv_block, init_signal_layerwise,
|
||||
open_shm_and_get_meta_signal)
|
||||
open_shm_and_get_meta_signal, init_kv_signal_per_query)
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda() and not current_platform.is_dcu():
|
||||
@@ -109,6 +109,7 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
self.use_speculate: bool = self.speculative_method is not None
|
||||
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
|
||||
|
||||
self.kv_num_heads: int = kv_num_heads
|
||||
self.num_heads: int = num_heads
|
||||
@@ -129,9 +130,8 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale
|
||||
|
||||
# pd_disaggregation
|
||||
self.use_pd_disaggregation: int = int(
|
||||
os.getenv("FLAGS_use_pd_disaggregation", 0))
|
||||
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
|
||||
|
||||
self.start_layer_index: int = fd_config.model_config.start_layer_index
|
||||
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
|
||||
|
||||
@@ -189,7 +189,16 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
|
||||
# pd_disaggregation
|
||||
metadata.kv_signal_data_list = [None] * self.num_layers
|
||||
if self.use_pd_disaggregation:
|
||||
if self.pd_disaggregation_mode == "per_chunk":
|
||||
if not self.keep_pd_step_flag:
|
||||
init_kv_signal_per_query(
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.seq_lens_decoder,
|
||||
self.rank,
|
||||
self.num_layers + self.num_layers_draft_model,
|
||||
)
|
||||
elif self.pd_disaggregation_mode == "per_query":
|
||||
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
|
||||
self.rank, int(self.device_id), self.keep_pd_step_flag)
|
||||
|
||||
@@ -223,7 +232,7 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
"""
|
||||
metadata = self.attention_metadata
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
if self.pd_disaggregation_mode == "per_query":
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
|
Reference in New Issue
Block a user