[Feature][MTP] Support cacheKV transfer in per_chunk mode (#2890)

* support chunk_prefill both normal and speculative_decoding(mtp)

* optimize pd-disaggregation config

* fix bug
This commit is contained in:
freeliuzc
2025-07-17 17:58:08 +08:00
committed by GitHub
parent 67180c1ff9
commit d49f8fb30a
10 changed files with 110 additions and 27 deletions

View File

@@ -36,9 +36,9 @@ void GetOutputKVSignal(const paddle::Tensor& x,
int* out_data = const_cast<int*>(x.data<int>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 2 + 2) * 4, 0, IPC_NOWAIT);
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT);
} else {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 2 + 2) * 4, 0, 0);
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, 0);
}
if (ret == -1) {
out_data[0] = -1;
@@ -47,7 +47,7 @@ void GetOutputKVSignal(const paddle::Tensor& x,
}
int encoder_count = msg_rcv.mtext[0];
for (int i = 0; i < encoder_count * 2 + 2; i++) {
for (int i = 0; i < encoder_count * 3 + 2; i++) {
out_data[i] = msg_rcv.mtext[i];
}
return;

View File

@@ -35,5 +35,5 @@ struct msgdata {
struct msgdatakv {
long mtype;
int mtext[MAX_BSZ * 2 + 2]; // encoder_count, layer_id, bid- pair
int mtext[MAX_BSZ * 3 + 2]; // encoder_count, layer_id, bid- pair
};

View File

@@ -64,9 +64,10 @@ struct RemoteCacheKvIpc {
int encoder_count = 0;
for (int i = 0; i < real_bsz; i++) {
if (seq_lens_encoder[i] > 0) {
msg_sed.mtext[3 * encoder_count + 2] = i;
msg_sed.mtext[3 * encoder_count + 3] = seq_lens_decoder[i];
msg_sed.mtext[3 * encoder_count + 4] = seq_lens_encoder[i];
encoder_count++;
msg_sed.mtext[2 * i + 2] = i;
msg_sed.mtext[2 * i + 3] = seq_lens_decoder[i];
}
}
msg_sed.mtext[0] = encoder_count;
@@ -82,7 +83,7 @@ struct RemoteCacheKvIpc {
void CUDART_CB send_signal() {
msg_sed.mtext[1] = layer_id_;
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 2 + 2) * 4, 0)) == -1) {
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) {
printf("kv signal full msg buffer\n");
}
layer_id_ = (layer_id_ + 1);

View File

@@ -16,6 +16,7 @@
from __future__ import annotations
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Literal, Optional
@@ -109,7 +110,7 @@ class ModelConfig:
self.ori_vocab_size = self.vocab_size
if "Ernie4_5_ForCausalLM" in self.architectures or "Ernie4_5_MoeForCausalLM" in self.architectures:
self.ori_vocab_size = args["ori_vocab_size"]
self.ori_vocab_size = args.get("ori_vocab_size", self.ori_vocab_size)
class ParallelConfig:
"""Configuration for the distributed execution."""
@@ -191,6 +192,18 @@ class ParallelConfig:
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
self.enable_custom_all_reduce: bool = False
# pd_disaggregation
use_pd_disaggregation: int = int(
os.getenv("FLAGS_use_pd_disaggregation", 0))
use_pd_disaggregation_per_chunk: int = int(
os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
if use_pd_disaggregation_per_chunk:
self.pd_disaggregation_mode = "per_chunk"
elif use_pd_disaggregation:
self.pd_disaggregation_mode = "per_query"
else:
self.pd_disaggregation_mode = "None"
class SpeculativeConfig:
"""
Configuration for speculative decoding.

View File

@@ -24,7 +24,8 @@ import paddle
from fastdeploy.model_executor.layers.attention.ops import (
append_attention, get_block_shape_and_split_kv_block,
init_signal_layerwise, open_shm_and_get_meta_signal)
init_signal_layerwise, open_shm_and_get_meta_signal,
init_kv_signal_per_query)
if TYPE_CHECKING:
from fastdeploy.model_executor.forward_meta import ForwardMeta
@@ -92,6 +93,7 @@ class AppendAttentionBackend(AttentionBackend):
self.use_speculate: bool = self.speculative_method is not None
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads
@@ -100,9 +102,8 @@ class AppendAttentionBackend(AttentionBackend):
self.max_partition_size: int = int(
os.getenv("FLAGS_max_partition_size", 32768))
# pd_disaggregation
self.use_pd_disaggregation: int = int(
os.getenv("FLAGS_use_pd_disaggregation", 0))
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
self.start_layer_index: int = fd_config.model_config.start_layer_index
if fd_config.parallel_config.expert_parallel_rank is None:
@@ -154,9 +155,19 @@ class AppendAttentionBackend(AttentionBackend):
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
if self.use_pd_disaggregation:
if self.pd_disaggregation_mode == "per_chunk":
if not self.keep_pd_step_flag:
init_kv_signal_per_query(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_decoder,
self.rank,
self.num_layers + self.num_layers_draft_model,
)
elif self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
self.rank, int(self.device_id), self.keep_pd_step_flag)
self.attention_metadata: AttentionMetadata = metadata
forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
forward_meta.decoder_tile_ids_per_batch.copy_(
@@ -192,7 +203,7 @@ class AppendAttentionBackend(AttentionBackend):
"""
metadata = self.attention_metadata
if self.use_pd_disaggregation:
if self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_data_list[
layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,

View File

@@ -33,7 +33,8 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.model_executor.layers.attention.ops import (
get_block_shape_and_split_kv_block, gqa_rope_write_cache,
init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat)
init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat,
init_kv_signal_per_query)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
if TYPE_CHECKING:
from fastdeploy.model_executor.forward_meta import ForwardMeta
@@ -102,10 +103,10 @@ class FlashAttentionBackend(AttentionBackend):
self.use_speculate = self.speculative_method is not None
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
# pd_disaggregation
self.use_pd_disaggregation: int = int(
os.getenv("FLAGS_use_pd_disaggregation", 0))
self.start_layer_index: int = fd_config.model_config.start_layer_index
if fd_config.parallel_config.expert_parallel_rank is None:
@@ -173,7 +174,16 @@ class FlashAttentionBackend(AttentionBackend):
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
if self.use_pd_disaggregation:
if self.pd_disaggregation_mode == "per_chunk":
if not self.keep_pd_step_flag:
init_kv_signal_per_query(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_decoder,
self.rank,
self.num_layers + self.num_layers_draft_model,
)
elif self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
self.rank, int(self.device_id), self.keep_pd_step_flag)
self.attention_metadata = metadata
@@ -194,7 +204,7 @@ class FlashAttentionBackend(AttentionBackend):
):
metadata = self.attention_metadata
if self.use_pd_disaggregation:
if self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_data_list[
layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,

View File

@@ -26,7 +26,7 @@ from paddle.nn.functional.flash_attention import flash_attn_unpadded
from fastdeploy.model_executor.layers.attention.ops import (
get_block_shape_and_split_kv_block, init_signal_layerwise,
open_shm_and_get_meta_signal)
open_shm_and_get_meta_signal, init_kv_signal_per_query)
from fastdeploy.platforms import current_platform
if current_platform.is_cuda() and not current_platform.is_dcu():
@@ -109,6 +109,7 @@ class MLAAttentionBackend(AttentionBackend):
self.use_speculate: bool = self.speculative_method is not None
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads
@@ -129,9 +130,8 @@ class MLAAttentionBackend(AttentionBackend):
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale
# pd_disaggregation
self.use_pd_disaggregation: int = int(
os.getenv("FLAGS_use_pd_disaggregation", 0))
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
self.start_layer_index: int = fd_config.model_config.start_layer_index
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
@@ -189,7 +189,16 @@ class MLAAttentionBackend(AttentionBackend):
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
if self.use_pd_disaggregation:
if self.pd_disaggregation_mode == "per_chunk":
if not self.keep_pd_step_flag:
init_kv_signal_per_query(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_decoder,
self.rank,
self.num_layers + self.num_layers_draft_model,
)
elif self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
self.rank, int(self.device_id), self.keep_pd_step_flag)
@@ -223,7 +232,7 @@ class MLAAttentionBackend(AttentionBackend):
"""
metadata = self.attention_metadata
if self.use_pd_disaggregation:
if self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_data_list[
layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,

View File

@@ -21,6 +21,7 @@ from .gqa_rope_write_cache import gqa_rope_write_cache
from .init_signal_layerwise import init_signal_layerwise
from .open_shm_and_get_meta_signal import open_shm_and_get_meta_signal
from .pre_cache_len_concat import pre_cache_len_concat
from .init_kv_signal_per_query import init_kv_signal_per_query
__all__ = [
"get_block_shape_and_split_kv_block",
@@ -29,4 +30,5 @@ __all__ = [
"init_signal_layerwise",
"gqa_rope_write_cache",
"pre_cache_len_concat",
"init_kv_signal_per_query"
]

View File

@@ -0,0 +1,37 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from fastdeploy.platforms import current_platform
def init_kv_signal_per_query(
seq_lens_encoder: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
seq_lens_decoder: paddle.Tensor,
rank: int,
num_layers: int,
) -> paddle.Tensor:
"""
init_kv_signal_per_query
"""
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import init_kv_signal_per_query
out = init_kv_signal_per_query(seq_lens_encoder, seq_lens_this_time, seq_lens_decoder, rank, num_layers)
return out
else:
raise NotImplementedError()

View File

@@ -73,7 +73,7 @@ def load_ep_checkpoint(model_path: str,
range(base_range.start + config.moe_num_experts[0], base_range.stop + config.moe_num_experts[0]))
return base_range
for i in range(config.moe_layer_start_index, config.num_layers):
for i in range(config.moe_layer_start_index, config.num_hidden_layers):
for j in get_expert_ranges(config):
up_gate_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
down_proj_key = (f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight")