diff --git a/custom_ops/gpu_ops/get_output_ep.cc b/custom_ops/gpu_ops/get_output_ep.cc index 9fbc34cb6..f5f742022 100644 --- a/custom_ops/gpu_ops/get_output_ep.cc +++ b/custom_ops/gpu_ops/get_output_ep.cc @@ -36,9 +36,9 @@ void GetOutputKVSignal(const paddle::Tensor& x, int* out_data = const_cast(x.data()); int ret = -1; if (!wait_flag) { - ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 2 + 2) * 4, 0, IPC_NOWAIT); + ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT); } else { - ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 2 + 2) * 4, 0, 0); + ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, 0); } if (ret == -1) { out_data[0] = -1; @@ -47,7 +47,7 @@ void GetOutputKVSignal(const paddle::Tensor& x, } int encoder_count = msg_rcv.mtext[0]; - for (int i = 0; i < encoder_count * 2 + 2; i++) { + for (int i = 0; i < encoder_count * 3 + 2; i++) { out_data[i] = msg_rcv.mtext[i]; } return; diff --git a/custom_ops/gpu_ops/msg_utils.h b/custom_ops/gpu_ops/msg_utils.h index b4c33551e..e3ca0f646 100644 --- a/custom_ops/gpu_ops/msg_utils.h +++ b/custom_ops/gpu_ops/msg_utils.h @@ -35,5 +35,5 @@ struct msgdata { struct msgdatakv { long mtype; - int mtext[MAX_BSZ * 2 + 2]; // encoder_count, layer_id, bid- pair + int mtext[MAX_BSZ * 3 + 2]; // encoder_count, layer_id, bid- pair }; \ No newline at end of file diff --git a/custom_ops/gpu_ops/remote_cache_kv_ipc.h b/custom_ops/gpu_ops/remote_cache_kv_ipc.h index 5a4f6065d..4694e0b39 100644 --- a/custom_ops/gpu_ops/remote_cache_kv_ipc.h +++ b/custom_ops/gpu_ops/remote_cache_kv_ipc.h @@ -64,9 +64,10 @@ struct RemoteCacheKvIpc { int encoder_count = 0; for (int i = 0; i < real_bsz; i++) { if (seq_lens_encoder[i] > 0) { + msg_sed.mtext[3 * encoder_count + 2] = i; + msg_sed.mtext[3 * encoder_count + 3] = seq_lens_decoder[i]; + msg_sed.mtext[3 * encoder_count + 4] = seq_lens_encoder[i]; encoder_count++; - msg_sed.mtext[2 * i + 2] = i; - msg_sed.mtext[2 * i + 3] = seq_lens_decoder[i]; } } msg_sed.mtext[0] = encoder_count; @@ -82,7 +83,7 @@ struct RemoteCacheKvIpc { void CUDART_CB send_signal() { msg_sed.mtext[1] = layer_id_; - if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 2 + 2) * 4, 0)) == -1) { + if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) { printf("kv signal full msg buffer\n"); } layer_id_ = (layer_id_ + 1); diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 2e5473728..30f0b4781 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -16,6 +16,7 @@ from __future__ import annotations +import os from dataclasses import dataclass, field from enum import Enum from typing import Literal, Optional @@ -109,7 +110,7 @@ class ModelConfig: self.ori_vocab_size = self.vocab_size if "Ernie4_5_ForCausalLM" in self.architectures or "Ernie4_5_MoeForCausalLM" in self.architectures: - self.ori_vocab_size = args["ori_vocab_size"] + self.ori_vocab_size = args.get("ori_vocab_size", self.ori_vocab_size) class ParallelConfig: """Configuration for the distributed execution.""" @@ -191,6 +192,18 @@ class ParallelConfig: # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce). self.enable_custom_all_reduce: bool = False + # pd_disaggregation + use_pd_disaggregation: int = int( + os.getenv("FLAGS_use_pd_disaggregation", 0)) + use_pd_disaggregation_per_chunk: int = int( + os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0)) + if use_pd_disaggregation_per_chunk: + self.pd_disaggregation_mode = "per_chunk" + elif use_pd_disaggregation: + self.pd_disaggregation_mode = "per_query" + else: + self.pd_disaggregation_mode = "None" + class SpeculativeConfig: """ Configuration for speculative decoding. diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index af5313054..f41deb62e 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -24,7 +24,8 @@ import paddle from fastdeploy.model_executor.layers.attention.ops import ( append_attention, get_block_shape_and_split_kv_block, - init_signal_layerwise, open_shm_and_get_meta_signal) + init_signal_layerwise, open_shm_and_get_meta_signal, + init_kv_signal_per_query) if TYPE_CHECKING: from fastdeploy.model_executor.forward_meta import ForwardMeta @@ -92,6 +93,7 @@ class AppendAttentionBackend(AttentionBackend): self.use_speculate: bool = self.speculative_method is not None self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" + self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"]) self.kv_num_heads: int = kv_num_heads self.num_heads: int = num_heads @@ -100,9 +102,8 @@ class AppendAttentionBackend(AttentionBackend): self.max_partition_size: int = int( os.getenv("FLAGS_max_partition_size", 32768)) - # pd_disaggregation - self.use_pd_disaggregation: int = int( - os.getenv("FLAGS_use_pd_disaggregation", 0)) + self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode + self.start_layer_index: int = fd_config.model_config.start_layer_index if fd_config.parallel_config.expert_parallel_rank is None: @@ -154,9 +155,19 @@ class AppendAttentionBackend(AttentionBackend): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers - if self.use_pd_disaggregation: + if self.pd_disaggregation_mode == "per_chunk": + if not self.keep_pd_step_flag: + init_kv_signal_per_query( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_this_time, + forward_meta.seq_lens_decoder, + self.rank, + self.num_layers + self.num_layers_draft_model, + ) + elif self.pd_disaggregation_mode == "per_query": metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag) + self.attention_metadata: AttentionMetadata = metadata forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False) forward_meta.decoder_tile_ids_per_batch.copy_( @@ -192,7 +203,7 @@ class AppendAttentionBackend(AttentionBackend): """ metadata = self.attention_metadata - if self.use_pd_disaggregation: + if self.pd_disaggregation_mode == "per_query": metadata.kv_signal_data_list[ layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index d78b444d2..4c1cde80b 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -33,7 +33,8 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, AttentionMetadata) from fastdeploy.model_executor.layers.attention.ops import ( get_block_shape_and_split_kv_block, gqa_rope_write_cache, - init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat) + init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat, + init_kv_signal_per_query) from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id if TYPE_CHECKING: from fastdeploy.model_executor.forward_meta import ForwardMeta @@ -102,10 +103,10 @@ class FlashAttentionBackend(AttentionBackend): self.use_speculate = self.speculative_method is not None self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" + self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"]) + + self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode - # pd_disaggregation - self.use_pd_disaggregation: int = int( - os.getenv("FLAGS_use_pd_disaggregation", 0)) self.start_layer_index: int = fd_config.model_config.start_layer_index if fd_config.parallel_config.expert_parallel_rank is None: @@ -173,7 +174,16 @@ class FlashAttentionBackend(AttentionBackend): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers - if self.use_pd_disaggregation: + if self.pd_disaggregation_mode == "per_chunk": + if not self.keep_pd_step_flag: + init_kv_signal_per_query( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_this_time, + forward_meta.seq_lens_decoder, + self.rank, + self.num_layers + self.num_layers_draft_model, + ) + elif self.pd_disaggregation_mode == "per_query": metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag) self.attention_metadata = metadata @@ -194,7 +204,7 @@ class FlashAttentionBackend(AttentionBackend): ): metadata = self.attention_metadata - if self.use_pd_disaggregation: + if self.pd_disaggregation_mode == "per_query": metadata.kv_signal_data_list[ layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index a29d5fe68..3940eb780 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -26,7 +26,7 @@ from paddle.nn.functional.flash_attention import flash_attn_unpadded from fastdeploy.model_executor.layers.attention.ops import ( get_block_shape_and_split_kv_block, init_signal_layerwise, - open_shm_and_get_meta_signal) + open_shm_and_get_meta_signal, init_kv_signal_per_query) from fastdeploy.platforms import current_platform if current_platform.is_cuda() and not current_platform.is_dcu(): @@ -109,6 +109,7 @@ class MLAAttentionBackend(AttentionBackend): self.use_speculate: bool = self.speculative_method is not None self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" + self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"]) self.kv_num_heads: int = kv_num_heads self.num_heads: int = num_heads @@ -129,9 +130,8 @@ class MLAAttentionBackend(AttentionBackend): mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale - # pd_disaggregation - self.use_pd_disaggregation: int = int( - os.getenv("FLAGS_use_pd_disaggregation", 0)) + self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode + self.start_layer_index: int = fd_config.model_config.start_layer_index self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None) @@ -189,7 +189,16 @@ class MLAAttentionBackend(AttentionBackend): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers - if self.use_pd_disaggregation: + if self.pd_disaggregation_mode == "per_chunk": + if not self.keep_pd_step_flag: + init_kv_signal_per_query( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_this_time, + forward_meta.seq_lens_decoder, + self.rank, + self.num_layers + self.num_layers_draft_model, + ) + elif self.pd_disaggregation_mode == "per_query": metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag) @@ -223,7 +232,7 @@ class MLAAttentionBackend(AttentionBackend): """ metadata = self.attention_metadata - if self.use_pd_disaggregation: + if self.pd_disaggregation_mode == "per_query": metadata.kv_signal_data_list[ layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, diff --git a/fastdeploy/model_executor/layers/attention/ops/__init__.py b/fastdeploy/model_executor/layers/attention/ops/__init__.py index 95cc06129..a44ca7cbf 100644 --- a/fastdeploy/model_executor/layers/attention/ops/__init__.py +++ b/fastdeploy/model_executor/layers/attention/ops/__init__.py @@ -21,6 +21,7 @@ from .gqa_rope_write_cache import gqa_rope_write_cache from .init_signal_layerwise import init_signal_layerwise from .open_shm_and_get_meta_signal import open_shm_and_get_meta_signal from .pre_cache_len_concat import pre_cache_len_concat +from .init_kv_signal_per_query import init_kv_signal_per_query __all__ = [ "get_block_shape_and_split_kv_block", @@ -29,4 +30,5 @@ __all__ = [ "init_signal_layerwise", "gqa_rope_write_cache", "pre_cache_len_concat", + "init_kv_signal_per_query" ] diff --git a/fastdeploy/model_executor/layers/attention/ops/init_kv_signal_per_query.py b/fastdeploy/model_executor/layers/attention/ops/init_kv_signal_per_query.py new file mode 100644 index 000000000..866c0f168 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/init_kv_signal_per_query.py @@ -0,0 +1,37 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import paddle + +from fastdeploy.platforms import current_platform + + +def init_kv_signal_per_query( + seq_lens_encoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + rank: int, + num_layers: int, +) -> paddle.Tensor: + """ + init_kv_signal_per_query + """ + if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import init_kv_signal_per_query + out = init_kv_signal_per_query(seq_lens_encoder, seq_lens_this_time, seq_lens_decoder, rank, num_layers) + return out + else: + raise NotImplementedError() diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py index 64a160884..d1698776b 100644 --- a/fastdeploy/model_executor/load_weight_utils.py +++ b/fastdeploy/model_executor/load_weight_utils.py @@ -73,7 +73,7 @@ def load_ep_checkpoint(model_path: str, range(base_range.start + config.moe_num_experts[0], base_range.stop + config.moe_num_experts[0])) return base_range - for i in range(config.moe_layer_start_index, config.num_layers): + for i in range(config.moe_layer_start_index, config.num_hidden_layers): for j in get_expert_ranges(config): up_gate_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight" down_proj_key = (f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight")