mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Feature][MTP] Support cacheKV transfer in per_chunk mode (#2890)
* support chunk_prefill both normal and speculative_decoding(mtp) * optimize pd-disaggregation config * fix bug
This commit is contained in:
@@ -36,9 +36,9 @@ void GetOutputKVSignal(const paddle::Tensor& x,
|
||||
int* out_data = const_cast<int*>(x.data<int>());
|
||||
int ret = -1;
|
||||
if (!wait_flag) {
|
||||
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 2 + 2) * 4, 0, IPC_NOWAIT);
|
||||
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT);
|
||||
} else {
|
||||
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 2 + 2) * 4, 0, 0);
|
||||
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, 0);
|
||||
}
|
||||
if (ret == -1) {
|
||||
out_data[0] = -1;
|
||||
@@ -47,7 +47,7 @@ void GetOutputKVSignal(const paddle::Tensor& x,
|
||||
}
|
||||
int encoder_count = msg_rcv.mtext[0];
|
||||
|
||||
for (int i = 0; i < encoder_count * 2 + 2; i++) {
|
||||
for (int i = 0; i < encoder_count * 3 + 2; i++) {
|
||||
out_data[i] = msg_rcv.mtext[i];
|
||||
}
|
||||
return;
|
||||
|
@@ -35,5 +35,5 @@ struct msgdata {
|
||||
|
||||
struct msgdatakv {
|
||||
long mtype;
|
||||
int mtext[MAX_BSZ * 2 + 2]; // encoder_count, layer_id, bid- pair
|
||||
int mtext[MAX_BSZ * 3 + 2]; // encoder_count, layer_id, bid- pair
|
||||
};
|
@@ -64,9 +64,10 @@ struct RemoteCacheKvIpc {
|
||||
int encoder_count = 0;
|
||||
for (int i = 0; i < real_bsz; i++) {
|
||||
if (seq_lens_encoder[i] > 0) {
|
||||
msg_sed.mtext[3 * encoder_count + 2] = i;
|
||||
msg_sed.mtext[3 * encoder_count + 3] = seq_lens_decoder[i];
|
||||
msg_sed.mtext[3 * encoder_count + 4] = seq_lens_encoder[i];
|
||||
encoder_count++;
|
||||
msg_sed.mtext[2 * i + 2] = i;
|
||||
msg_sed.mtext[2 * i + 3] = seq_lens_decoder[i];
|
||||
}
|
||||
}
|
||||
msg_sed.mtext[0] = encoder_count;
|
||||
@@ -82,7 +83,7 @@ struct RemoteCacheKvIpc {
|
||||
|
||||
void CUDART_CB send_signal() {
|
||||
msg_sed.mtext[1] = layer_id_;
|
||||
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 2 + 2) * 4, 0)) == -1) {
|
||||
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) {
|
||||
printf("kv signal full msg buffer\n");
|
||||
}
|
||||
layer_id_ = (layer_id_ + 1);
|
||||
|
@@ -16,6 +16,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
@@ -109,7 +110,7 @@ class ModelConfig:
|
||||
|
||||
self.ori_vocab_size = self.vocab_size
|
||||
if "Ernie4_5_ForCausalLM" in self.architectures or "Ernie4_5_MoeForCausalLM" in self.architectures:
|
||||
self.ori_vocab_size = args["ori_vocab_size"]
|
||||
self.ori_vocab_size = args.get("ori_vocab_size", self.ori_vocab_size)
|
||||
|
||||
class ParallelConfig:
|
||||
"""Configuration for the distributed execution."""
|
||||
@@ -191,6 +192,18 @@ class ParallelConfig:
|
||||
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
|
||||
self.enable_custom_all_reduce: bool = False
|
||||
|
||||
# pd_disaggregation
|
||||
use_pd_disaggregation: int = int(
|
||||
os.getenv("FLAGS_use_pd_disaggregation", 0))
|
||||
use_pd_disaggregation_per_chunk: int = int(
|
||||
os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
|
||||
if use_pd_disaggregation_per_chunk:
|
||||
self.pd_disaggregation_mode = "per_chunk"
|
||||
elif use_pd_disaggregation:
|
||||
self.pd_disaggregation_mode = "per_query"
|
||||
else:
|
||||
self.pd_disaggregation_mode = "None"
|
||||
|
||||
class SpeculativeConfig:
|
||||
"""
|
||||
Configuration for speculative decoding.
|
||||
|
@@ -24,7 +24,8 @@ import paddle
|
||||
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
append_attention, get_block_shape_and_split_kv_block,
|
||||
init_signal_layerwise, open_shm_and_get_meta_signal)
|
||||
init_signal_layerwise, open_shm_and_get_meta_signal,
|
||||
init_kv_signal_per_query)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
@@ -92,6 +93,7 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
self.use_speculate: bool = self.speculative_method is not None
|
||||
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
|
||||
|
||||
self.kv_num_heads: int = kv_num_heads
|
||||
self.num_heads: int = num_heads
|
||||
@@ -100,9 +102,8 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
self.max_partition_size: int = int(
|
||||
os.getenv("FLAGS_max_partition_size", 32768))
|
||||
|
||||
# pd_disaggregation
|
||||
self.use_pd_disaggregation: int = int(
|
||||
os.getenv("FLAGS_use_pd_disaggregation", 0))
|
||||
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
|
||||
|
||||
self.start_layer_index: int = fd_config.model_config.start_layer_index
|
||||
|
||||
if fd_config.parallel_config.expert_parallel_rank is None:
|
||||
@@ -154,9 +155,19 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
|
||||
# pd_disaggregation
|
||||
metadata.kv_signal_data_list = [None] * self.num_layers
|
||||
if self.use_pd_disaggregation:
|
||||
if self.pd_disaggregation_mode == "per_chunk":
|
||||
if not self.keep_pd_step_flag:
|
||||
init_kv_signal_per_query(
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.seq_lens_decoder,
|
||||
self.rank,
|
||||
self.num_layers + self.num_layers_draft_model,
|
||||
)
|
||||
elif self.pd_disaggregation_mode == "per_query":
|
||||
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
|
||||
self.rank, int(self.device_id), self.keep_pd_step_flag)
|
||||
|
||||
self.attention_metadata: AttentionMetadata = metadata
|
||||
forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
|
||||
forward_meta.decoder_tile_ids_per_batch.copy_(
|
||||
@@ -192,7 +203,7 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
"""
|
||||
metadata = self.attention_metadata
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
if self.pd_disaggregation_mode == "per_query":
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
|
@@ -33,7 +33,8 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
get_block_shape_and_split_kv_block, gqa_rope_write_cache,
|
||||
init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat)
|
||||
init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat,
|
||||
init_kv_signal_per_query)
|
||||
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
|
||||
if TYPE_CHECKING:
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
@@ -102,10 +103,10 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
self.use_speculate = self.speculative_method is not None
|
||||
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
|
||||
|
||||
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
|
||||
|
||||
# pd_disaggregation
|
||||
self.use_pd_disaggregation: int = int(
|
||||
os.getenv("FLAGS_use_pd_disaggregation", 0))
|
||||
self.start_layer_index: int = fd_config.model_config.start_layer_index
|
||||
|
||||
if fd_config.parallel_config.expert_parallel_rank is None:
|
||||
@@ -173,7 +174,16 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
# pd_disaggregation
|
||||
metadata.kv_signal_data_list = [None] * self.num_layers
|
||||
if self.use_pd_disaggregation:
|
||||
if self.pd_disaggregation_mode == "per_chunk":
|
||||
if not self.keep_pd_step_flag:
|
||||
init_kv_signal_per_query(
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.seq_lens_decoder,
|
||||
self.rank,
|
||||
self.num_layers + self.num_layers_draft_model,
|
||||
)
|
||||
elif self.pd_disaggregation_mode == "per_query":
|
||||
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
|
||||
self.rank, int(self.device_id), self.keep_pd_step_flag)
|
||||
self.attention_metadata = metadata
|
||||
@@ -194,7 +204,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
):
|
||||
metadata = self.attention_metadata
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
if self.pd_disaggregation_mode == "per_query":
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
|
@@ -26,7 +26,7 @@ from paddle.nn.functional.flash_attention import flash_attn_unpadded
|
||||
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
get_block_shape_and_split_kv_block, init_signal_layerwise,
|
||||
open_shm_and_get_meta_signal)
|
||||
open_shm_and_get_meta_signal, init_kv_signal_per_query)
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda() and not current_platform.is_dcu():
|
||||
@@ -109,6 +109,7 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
self.use_speculate: bool = self.speculative_method is not None
|
||||
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
|
||||
|
||||
self.kv_num_heads: int = kv_num_heads
|
||||
self.num_heads: int = num_heads
|
||||
@@ -129,9 +130,8 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale
|
||||
|
||||
# pd_disaggregation
|
||||
self.use_pd_disaggregation: int = int(
|
||||
os.getenv("FLAGS_use_pd_disaggregation", 0))
|
||||
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
|
||||
|
||||
self.start_layer_index: int = fd_config.model_config.start_layer_index
|
||||
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
|
||||
|
||||
@@ -189,7 +189,16 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
|
||||
# pd_disaggregation
|
||||
metadata.kv_signal_data_list = [None] * self.num_layers
|
||||
if self.use_pd_disaggregation:
|
||||
if self.pd_disaggregation_mode == "per_chunk":
|
||||
if not self.keep_pd_step_flag:
|
||||
init_kv_signal_per_query(
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.seq_lens_decoder,
|
||||
self.rank,
|
||||
self.num_layers + self.num_layers_draft_model,
|
||||
)
|
||||
elif self.pd_disaggregation_mode == "per_query":
|
||||
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
|
||||
self.rank, int(self.device_id), self.keep_pd_step_flag)
|
||||
|
||||
@@ -223,7 +232,7 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
"""
|
||||
metadata = self.attention_metadata
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
if self.pd_disaggregation_mode == "per_query":
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
|
@@ -21,6 +21,7 @@ from .gqa_rope_write_cache import gqa_rope_write_cache
|
||||
from .init_signal_layerwise import init_signal_layerwise
|
||||
from .open_shm_and_get_meta_signal import open_shm_and_get_meta_signal
|
||||
from .pre_cache_len_concat import pre_cache_len_concat
|
||||
from .init_kv_signal_per_query import init_kv_signal_per_query
|
||||
|
||||
__all__ = [
|
||||
"get_block_shape_and_split_kv_block",
|
||||
@@ -29,4 +30,5 @@ __all__ = [
|
||||
"init_signal_layerwise",
|
||||
"gqa_rope_write_cache",
|
||||
"pre_cache_len_concat",
|
||||
"init_kv_signal_per_query"
|
||||
]
|
||||
|
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
|
||||
def init_kv_signal_per_query(
|
||||
seq_lens_encoder: paddle.Tensor,
|
||||
seq_lens_this_time: paddle.Tensor,
|
||||
seq_lens_decoder: paddle.Tensor,
|
||||
rank: int,
|
||||
num_layers: int,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
init_kv_signal_per_query
|
||||
"""
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import init_kv_signal_per_query
|
||||
out = init_kv_signal_per_query(seq_lens_encoder, seq_lens_this_time, seq_lens_decoder, rank, num_layers)
|
||||
return out
|
||||
else:
|
||||
raise NotImplementedError()
|
@@ -73,7 +73,7 @@ def load_ep_checkpoint(model_path: str,
|
||||
range(base_range.start + config.moe_num_experts[0], base_range.stop + config.moe_num_experts[0]))
|
||||
return base_range
|
||||
|
||||
for i in range(config.moe_layer_start_index, config.num_layers):
|
||||
for i in range(config.moe_layer_start_index, config.num_hidden_layers):
|
||||
for j in get_expert_ranges(config):
|
||||
up_gate_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
|
||||
down_proj_key = (f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight")
|
||||
|
Reference in New Issue
Block a user