mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Feature][MTP] Support cacheKV transfer in per_chunk mode (#2890)
* support chunk_prefill both normal and speculative_decoding(mtp) * optimize pd-disaggregation config * fix bug
This commit is contained in:
@@ -24,7 +24,8 @@ import paddle
|
||||
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
append_attention, get_block_shape_and_split_kv_block,
|
||||
init_signal_layerwise, open_shm_and_get_meta_signal)
|
||||
init_signal_layerwise, open_shm_and_get_meta_signal,
|
||||
init_kv_signal_per_query)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
@@ -92,6 +93,7 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
self.use_speculate: bool = self.speculative_method is not None
|
||||
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
|
||||
|
||||
self.kv_num_heads: int = kv_num_heads
|
||||
self.num_heads: int = num_heads
|
||||
@@ -100,9 +102,8 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
self.max_partition_size: int = int(
|
||||
os.getenv("FLAGS_max_partition_size", 32768))
|
||||
|
||||
# pd_disaggregation
|
||||
self.use_pd_disaggregation: int = int(
|
||||
os.getenv("FLAGS_use_pd_disaggregation", 0))
|
||||
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
|
||||
|
||||
self.start_layer_index: int = fd_config.model_config.start_layer_index
|
||||
|
||||
if fd_config.parallel_config.expert_parallel_rank is None:
|
||||
@@ -154,9 +155,19 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
|
||||
# pd_disaggregation
|
||||
metadata.kv_signal_data_list = [None] * self.num_layers
|
||||
if self.use_pd_disaggregation:
|
||||
if self.pd_disaggregation_mode == "per_chunk":
|
||||
if not self.keep_pd_step_flag:
|
||||
init_kv_signal_per_query(
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.seq_lens_decoder,
|
||||
self.rank,
|
||||
self.num_layers + self.num_layers_draft_model,
|
||||
)
|
||||
elif self.pd_disaggregation_mode == "per_query":
|
||||
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
|
||||
self.rank, int(self.device_id), self.keep_pd_step_flag)
|
||||
|
||||
self.attention_metadata: AttentionMetadata = metadata
|
||||
forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
|
||||
forward_meta.decoder_tile_ids_per_batch.copy_(
|
||||
@@ -192,7 +203,7 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
"""
|
||||
metadata = self.attention_metadata
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
if self.pd_disaggregation_mode == "per_query":
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
|
Reference in New Issue
Block a user