[Feature][MTP] Support cacheKV transfer in per_chunk mode (#2890)

* support chunk_prefill for both normal decoding and speculative decoding (MTP)

* optimize pd-disaggregation config

* fix bug
freeliuzc authored 2025-07-17 17:58:08 +08:00, committed by GitHub
parent 67180c1ff9, commit d49f8fb30a
10 changed files with 110 additions and 27 deletions

View File

@@ -36,9 +36,9 @@ void GetOutputKVSignal(const paddle::Tensor& x,
   int* out_data = const_cast<int*>(x.data<int>());
   int ret = -1;
   if (!wait_flag) {
-    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 2 + 2) * 4, 0, IPC_NOWAIT);
+    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT);
   } else {
-    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 2 + 2) * 4, 0, 0);
+    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, 0);
   }
   if (ret == -1) {
     out_data[0] = -1;
@@ -47,7 +47,7 @@ void GetOutputKVSignal(const paddle::Tensor& x,
   }
   int encoder_count = msg_rcv.mtext[0];
-  for (int i = 0; i < encoder_count * 2 + 2; i++) {
+  for (int i = 0; i < encoder_count * 3 + 2; i++) {
     out_data[i] = msg_rcv.mtext[i];
   }
   return;

View File

@@ -35,5 +35,5 @@ struct msgdata {
 struct msgdatakv {
   long mtype;
-  int mtext[MAX_BSZ * 2 + 2];  // encoder_count, layer_id, bid- pair
+  int mtext[MAX_BSZ * 3 + 2];  // encoder_count, layer_id, bid- pair
 };

View File

@@ -64,9 +64,10 @@ struct RemoteCacheKvIpc {
     int encoder_count = 0;
     for (int i = 0; i < real_bsz; i++) {
       if (seq_lens_encoder[i] > 0) {
+        msg_sed.mtext[3 * encoder_count + 2] = i;
+        msg_sed.mtext[3 * encoder_count + 3] = seq_lens_decoder[i];
+        msg_sed.mtext[3 * encoder_count + 4] = seq_lens_encoder[i];
         encoder_count++;
-        msg_sed.mtext[2 * i + 2] = i;
-        msg_sed.mtext[2 * i + 3] = seq_lens_decoder[i];
       }
     }
     msg_sed.mtext[0] = encoder_count;
@@ -82,7 +83,7 @@ struct RemoteCacheKvIpc {
   void CUDART_CB send_signal() {
     msg_sed.mtext[1] = layer_id_;
-    if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 2 + 2) * 4, 0)) == -1) {
+    if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) {
       printf("kv signal full msg buffer\n");
     }
     layer_id_ = (layer_id_ + 1);
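The two C++ hunks above widen the layerwise KV signal: mtext[0] still carries encoder_count and mtext[1] the layer id, but each prefill request in the batch now contributes a (batch id, seq_lens_decoder, seq_lens_encoder) triple instead of a (batch id, seq_lens_decoder) pair, and the msgsnd/msgrcv sizes grow from (MAX_BSZ * 2 + 2) * 4 to (MAX_BSZ * 3 + 2) * 4 bytes to match. A minimal Python sketch of that payload layout, for reference only; the MAX_BSZ value and the pack/unpack helpers are illustrative, not part of the diff:

# Illustrative sketch of the widened KV-signal payload; mirrors the C++ above.
MAX_BSZ = 256  # assumption: the real constant is defined in the C++ headers

def pack_kv_signal(layer_id, batch_ids, dec_lens, enc_lens):
    """Build [encoder_count, layer_id, (bid, seq_len_decoder, seq_len_encoder) * n]."""
    mtext = [len(batch_ids), layer_id]
    for bid, dec_len, enc_len in zip(batch_ids, dec_lens, enc_lens):
        mtext += [bid, dec_len, enc_len]
    # the sender always ships a fixed-size buffer of MAX_BSZ * 3 + 2 ints
    return mtext + [0] * (MAX_BSZ * 3 + 2 - len(mtext))

def unpack_kv_signal(mtext):
    """Counterpart of GetOutputKVSignal: read encoder_count * 3 + 2 ints."""
    encoder_count, layer_id = mtext[0], mtext[1]
    triples = [tuple(mtext[2 + 3 * i:2 + 3 * i + 3]) for i in range(encoder_count)]
    return layer_id, triples  # [(bid, seq_len_decoder, seq_len_encoder), ...]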

View File

@@ -16,6 +16,7 @@
 from __future__ import annotations

+import os
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import Literal, Optional
@@ -109,7 +110,7 @@ class ModelConfig:
         self.ori_vocab_size = self.vocab_size
         if "Ernie4_5_ForCausalLM" in self.architectures or "Ernie4_5_MoeForCausalLM" in self.architectures:
-            self.ori_vocab_size = args["ori_vocab_size"]
+            self.ori_vocab_size = args.get("ori_vocab_size", self.ori_vocab_size)


 class ParallelConfig:
     """Configuration for the distributed execution."""
@@ -191,6 +192,18 @@ class ParallelConfig:
         # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
         self.enable_custom_all_reduce: bool = False

+        # pd_disaggregation
+        use_pd_disaggregation: int = int(
+            os.getenv("FLAGS_use_pd_disaggregation", 0))
+        use_pd_disaggregation_per_chunk: int = int(
+            os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
+        if use_pd_disaggregation_per_chunk:
+            self.pd_disaggregation_mode = "per_chunk"
+        elif use_pd_disaggregation:
+            self.pd_disaggregation_mode = "per_query"
+        else:
+            self.pd_disaggregation_mode = "None"
+
 class SpeculativeConfig:
     """
     Configuration for speculative decoding.
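With this change the two environment flags collapse into a single pd_disaggregation_mode string that the attention backends read from ParallelConfig, with the per-chunk flag taking precedence. A hedged, self-contained sketch of that precedence; the resolve helper is illustrative, only the flag names and mode strings come from the diff:

import os

# Assumption: flags are exported in the serving environment before FastDeploy starts.
os.environ["FLAGS_use_pd_disaggregation_per_chunk"] = "1"

def resolve_pd_disaggregation_mode() -> str:
    """Mirrors the precedence implemented in ParallelConfig above."""
    if int(os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0)):
        return "per_chunk"  # signal cacheKV transfer once per prefill chunk
    if int(os.getenv("FLAGS_use_pd_disaggregation", 0)):
        return "per_query"  # previous behaviour: signal per query, layer by layer
    return "None"           # pd-disaggregation disabled

assert resolve_pd_disaggregation_mode() == "per_chunk"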

View File

@@ -24,7 +24,8 @@ import paddle
 from fastdeploy.model_executor.layers.attention.ops import (
     append_attention, get_block_shape_and_split_kv_block,
-    init_signal_layerwise, open_shm_and_get_meta_signal)
+    init_signal_layerwise, open_shm_and_get_meta_signal,
+    init_kv_signal_per_query)

 if TYPE_CHECKING:
     from fastdeploy.model_executor.forward_meta import ForwardMeta
@@ -92,6 +93,7 @@ class AppendAttentionBackend(AttentionBackend):
         self.use_speculate: bool = self.speculative_method is not None
         self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
         self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
+        self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])

         self.kv_num_heads: int = kv_num_heads
         self.num_heads: int = num_heads
@@ -100,9 +102,8 @@ class AppendAttentionBackend(AttentionBackend):
         self.max_partition_size: int = int(
             os.getenv("FLAGS_max_partition_size", 32768))

-        # pd_disaggregation
-        self.use_pd_disaggregation: int = int(
-            os.getenv("FLAGS_use_pd_disaggregation", 0))
+        self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
+
         self.start_layer_index: int = fd_config.model_config.start_layer_index

         if fd_config.parallel_config.expert_parallel_rank is None:
@@ -154,9 +155,19 @@ class AppendAttentionBackend(AttentionBackend):

         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
-        if self.use_pd_disaggregation:
+        if self.pd_disaggregation_mode == "per_chunk":
+            if not self.keep_pd_step_flag:
+                init_kv_signal_per_query(
+                    forward_meta.seq_lens_encoder,
+                    forward_meta.seq_lens_this_time,
+                    forward_meta.seq_lens_decoder,
+                    self.rank,
+                    self.num_layers + self.num_layers_draft_model,
+                )
+        elif self.pd_disaggregation_mode == "per_query":
             metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
                 self.rank, int(self.device_id), self.keep_pd_step_flag)
+
         self.attention_metadata: AttentionMetadata = metadata
         forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
         forward_meta.decoder_tile_ids_per_batch.copy_(
@@ -192,7 +203,7 @@ class AppendAttentionBackend(AttentionBackend):
         """
         metadata = self.attention_metadata

-        if self.use_pd_disaggregation:
+        if self.pd_disaggregation_mode == "per_query":
             metadata.kv_signal_data_list[
                 layer.layer_id] = init_signal_layerwise(
                     metadata.kv_signal_metadata,
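The FlashAttention and MLA backends below receive the same treatment, so the branch is worth reading once in condensed form. A sketch follows; names come from the diff, the surrounding method is illustrative, and the comments on intent are the editor's reading rather than statements from the source:

# Condensed sketch of the dispatch shared by the three attention backends.
def init_kv_signal(self, forward_meta, metadata):
    metadata.kv_signal_data_list = [None] * self.num_layers
    if self.pd_disaggregation_mode == "per_chunk":
        # One bulk signal per forward step covers every layer, plus the extra
        # draft-model layer when the MTP speculative method is enabled.
        # The MTP draft model itself (keep_pd_step_flag) skips this call,
        # presumably because the target model's signal already covers its layers.
        if not self.keep_pd_step_flag:
            init_kv_signal_per_query(
                forward_meta.seq_lens_encoder,
                forward_meta.seq_lens_this_time,
                forward_meta.seq_lens_decoder,
                self.rank,
                self.num_layers + self.num_layers_draft_model,
            )
    elif self.pd_disaggregation_mode == "per_query":
        # Layer-by-layer signalling: open the shared-memory meta signal here;
        # init_signal_layerwise() is then called per layer inside forward().
        metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
            self.rank, int(self.device_id), self.keep_pd_step_flag)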

View File

@@ -33,7 +33,8 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend, AttentionMetadata)
 from fastdeploy.model_executor.layers.attention.ops import (
     get_block_shape_and_split_kv_block, gqa_rope_write_cache,
-    init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat)
+    init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat,
+    init_kv_signal_per_query)
 from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id

 if TYPE_CHECKING:
     from fastdeploy.model_executor.forward_meta import ForwardMeta
@@ -102,10 +103,10 @@ class FlashAttentionBackend(AttentionBackend):
         self.use_speculate = self.speculative_method is not None
         self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
         self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
+        self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
+        self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode

-        # pd_disaggregation
-        self.use_pd_disaggregation: int = int(
-            os.getenv("FLAGS_use_pd_disaggregation", 0))
         self.start_layer_index: int = fd_config.model_config.start_layer_index

         if fd_config.parallel_config.expert_parallel_rank is None:
@@ -173,7 +174,16 @@ class FlashAttentionBackend(AttentionBackend):
         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
-        if self.use_pd_disaggregation:
+        if self.pd_disaggregation_mode == "per_chunk":
+            if not self.keep_pd_step_flag:
+                init_kv_signal_per_query(
+                    forward_meta.seq_lens_encoder,
+                    forward_meta.seq_lens_this_time,
+                    forward_meta.seq_lens_decoder,
+                    self.rank,
+                    self.num_layers + self.num_layers_draft_model,
+                )
+        elif self.pd_disaggregation_mode == "per_query":
             metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
                 self.rank, int(self.device_id), self.keep_pd_step_flag)
         self.attention_metadata = metadata
@@ -194,7 +204,7 @@ class FlashAttentionBackend(AttentionBackend):
     ):
         metadata = self.attention_metadata

-        if self.use_pd_disaggregation:
+        if self.pd_disaggregation_mode == "per_query":
             metadata.kv_signal_data_list[
                 layer.layer_id] = init_signal_layerwise(
                     metadata.kv_signal_metadata,

View File

@@ -26,7 +26,7 @@ from paddle.nn.functional.flash_attention import flash_attn_unpadded
 from fastdeploy.model_executor.layers.attention.ops import (
     get_block_shape_and_split_kv_block, init_signal_layerwise,
-    open_shm_and_get_meta_signal)
+    open_shm_and_get_meta_signal, init_kv_signal_per_query)
 from fastdeploy.platforms import current_platform

 if current_platform.is_cuda() and not current_platform.is_dcu():
@@ -109,6 +109,7 @@ class MLAAttentionBackend(AttentionBackend):
         self.use_speculate: bool = self.speculative_method is not None
         self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
         self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
+        self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])

         self.kv_num_heads: int = kv_num_heads
         self.num_heads: int = num_heads
@@ -129,9 +130,8 @@ class MLAAttentionBackend(AttentionBackend):
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale

-        # pd_disaggregation
-        self.use_pd_disaggregation: int = int(
-            os.getenv("FLAGS_use_pd_disaggregation", 0))
+        self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
+
         self.start_layer_index: int = fd_config.model_config.start_layer_index
         self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
@@ -189,7 +189,16 @@ class MLAAttentionBackend(AttentionBackend):
         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
-        if self.use_pd_disaggregation:
+        if self.pd_disaggregation_mode == "per_chunk":
+            if not self.keep_pd_step_flag:
+                init_kv_signal_per_query(
+                    forward_meta.seq_lens_encoder,
+                    forward_meta.seq_lens_this_time,
+                    forward_meta.seq_lens_decoder,
+                    self.rank,
+                    self.num_layers + self.num_layers_draft_model,
+                )
+        elif self.pd_disaggregation_mode == "per_query":
             metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
                 self.rank, int(self.device_id), self.keep_pd_step_flag)
@@ -223,7 +232,7 @@ class MLAAttentionBackend(AttentionBackend):
         """
         metadata = self.attention_metadata

-        if self.use_pd_disaggregation:
+        if self.pd_disaggregation_mode == "per_query":
             metadata.kv_signal_data_list[
                 layer.layer_id] = init_signal_layerwise(
                     metadata.kv_signal_metadata,

View File

@@ -21,6 +21,7 @@ from .gqa_rope_write_cache import gqa_rope_write_cache
 from .init_signal_layerwise import init_signal_layerwise
 from .open_shm_and_get_meta_signal import open_shm_and_get_meta_signal
 from .pre_cache_len_concat import pre_cache_len_concat
+from .init_kv_signal_per_query import init_kv_signal_per_query

 __all__ = [
     "get_block_shape_and_split_kv_block",
@@ -29,4 +30,5 @@ __all__ = [
     "init_signal_layerwise",
     "gqa_rope_write_cache",
     "pre_cache_len_concat",
+    "init_kv_signal_per_query"
 ]

View File

@@ -0,0 +1,37 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from fastdeploy.platforms import current_platform
def init_kv_signal_per_query(
seq_lens_encoder: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
seq_lens_decoder: paddle.Tensor,
rank: int,
num_layers: int,
) -> paddle.Tensor:
"""
init_kv_signal_per_query
"""
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import init_kv_signal_per_query
out = init_kv_signal_per_query(seq_lens_encoder, seq_lens_this_time, seq_lens_decoder, rank, num_layers)
return out
else:
raise NotImplementedError()
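A hedged usage sketch of the new wrapper, roughly as the attention backends call it once per forward step on the prefill instance. It requires a CUDA build of the custom ops; the tensor values and layer count below are made up for illustration, only the argument order comes from the file above:

import paddle
from fastdeploy.model_executor.layers.attention.ops import init_kv_signal_per_query

# Hypothetical batch of 3: requests 0 and 2 are mid chunked-prefill, request 1 is decoding.
seq_lens_encoder = paddle.to_tensor([512, 0, 256], dtype="int32")    # tokens in this chunk
seq_lens_this_time = paddle.to_tensor([512, 1, 256], dtype="int32")
seq_lens_decoder = paddle.to_tensor([1024, 77, 0], dtype="int32")    # tokens already cached

# rank of this worker; total layers to signal = base model layers + MTP draft layer.
init_kv_signal_per_query(seq_lens_encoder, seq_lens_this_time, seq_lens_decoder,
                         rank=0, num_layers=29)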

View File

@@ -73,7 +73,7 @@ def load_ep_checkpoint(model_path: str,
             range(base_range.start + config.moe_num_experts[0], base_range.stop + config.moe_num_experts[0]))
         return base_range

-    for i in range(config.moe_layer_start_index, config.num_layers):
+    for i in range(config.moe_layer_start_index, config.num_hidden_layers):
         for j in get_expert_ranges(config):
             up_gate_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
             down_proj_key = (f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight")