[PD Disaggregation][XPU] Add XPU support for PD disaggregation (#5113)

* [XPU] xpu support PD disaggregation

* [XPU] fix the issue of cache KV transfer process startup failure on non-zero XPU cards

* [XPU] xpu support PD disaggregation in v1 scheduler

---------

Co-authored-by: ddchenhao66 <dhaochen163.com>
This commit is contained in:
ddchenhao66
2025-11-21 14:09:01 +08:00
committed by GitHub
parent 79f18331b6
commit e70e2279ce
16 changed files with 273 additions and 81 deletions

View File

@@ -0,0 +1,54 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ops/remote_cache_kv_ipc.h"
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
using cache_write_complete_signal_type =
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data;
paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor& kv_signal_metadata,
const int layer_id) {
auto kv_signal_metadata_out =
kv_signal_metadata.copy_to(paddle::CPUPlace(), false);
kv_signal_metadata_out.data<int64_t>()[0] = static_cast<int64_t>(layer_id);
return kv_signal_metadata_out;
}
std::vector<paddle::Tensor> InitSignalLayerwise(
const paddle::Tensor& kv_signal_metadata, const int layer_id) {
return {InitSignalLayerwiseFunc(kv_signal_metadata, layer_id)};
}
std::vector<std::vector<int64_t>> InitSignalLayerwiseShape(
const std::vector<int64_t>& kv_signal_metadata_shape, const int layer_id) {
return {kv_signal_metadata_shape};
}
std::vector<paddle::DataType> InitSignalLayerwiseDtype(
const paddle::DataType& kv_signal_metadata_dtype, const int layer_id) {
return {paddle::DataType::INT64};
}
PD_BUILD_STATIC_OP(init_signal_layerwise)
.Inputs({"kv_signal_metadata"})
.Outputs({"kv_signal_metadata_out"})
.Attrs({"layer_id: int"})
.SetKernelFn(PD_KERNEL(InitSignalLayerwise))
.SetInferShapeFn(PD_INFER_SHAPE(InitSignalLayerwiseShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InitSignalLayerwiseDtype));

View File

@@ -17,22 +17,27 @@
#include "ops/utility/env.h"
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
XPU_DECLARE_BOOL(fmt_write_cache_completed_signal, false);
using cache_write_complete_signal_type =
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data;
paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank,
const int device_id,
const bool keep_pd_step_flag) {
cache_write_complete_signal_type kv_signal_metadata;
const char *fmt_write_cache_completed_signal_str =
const char* fmt_write_cache_completed_signal_str =
std::getenv("FLAGS_fmt_write_cache_completed_signal");
if (fmt_write_cache_completed_signal_str &&
(std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
kv_signal_metadata =
RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
rank, keep_pd_step_flag);
rank, device_id, keep_pd_step_flag);
}
auto kv_signal_metadata_out =
@@ -46,9 +51,9 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank,
return kv_signal_metadata_out;
}
void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
const paddle::Tensor &seq_lens_this_time_tensor,
const paddle::Tensor &seq_lens_decoder_tensor,
void InitKVSignalPerQuery(const paddle::Tensor& seq_lens_encoder_tensor,
const paddle::Tensor& seq_lens_this_time_tensor,
const paddle::Tensor& seq_lens_decoder_tensor,
const int rank,
const int num_layers) {
if (FLAGS_fmt_write_cache_completed_signal) {
@@ -68,24 +73,24 @@ void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
}
std::vector<paddle::Tensor> OpenShmAndGetMetaSignal(
const int rank, const bool keep_pd_step_flag) {
return {OpenShmAndGetMetaSignalFunc(rank, keep_pd_step_flag)};
const int rank, const int device_id, const bool keep_pd_step_flag) {
return {OpenShmAndGetMetaSignalFunc(rank, device_id, keep_pd_step_flag)};
}
std::vector<std::vector<int64_t>> OpenShmAndGetMetaSignalShape(
const int rank, const bool keep_pd_step_flag) {
const int rank, const int device_id, const bool keep_pd_step_flag) {
return {{3}};
}
std::vector<paddle::DataType> OpenShmAndGetMetaSignalDtype(
const int rank, const bool keep_pd_step_flag) {
const int rank, const int device_id, const bool keep_pd_step_flag) {
return {paddle::DataType::INT64};
}
PD_BUILD_OP(open_shm_and_get_meta_signal)
PD_BUILD_STATIC_OP(open_shm_and_get_meta_signal)
.Inputs({})
.Outputs({"kv_signal_metadata"})
.Attrs({"rank: int", "keep_pd_step_flag: bool"})
.Attrs({"rank: int", "device_id: int", "keep_pd_step_flag: bool"})
.SetKernelFn(PD_KERNEL(OpenShmAndGetMetaSignal))
.SetInferShapeFn(PD_INFER_SHAPE(OpenShmAndGetMetaSignalShape))
.SetInferDtypeFn(PD_INFER_DTYPE(OpenShmAndGetMetaSignalDtype));

View File

@@ -26,7 +26,7 @@ bool RemoteCacheKvIpc::kv_complete_signal_shmem_opened = false;
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data
RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
const int rank_id, const bool keep_pd_step_flag) {
const int rank_id, const int device_id, const bool keep_pd_step_flag) {
if (RemoteCacheKvIpc::kv_complete_signal_shmem_opened) {
if (keep_pd_step_flag) {
return RemoteCacheKvIpc::kv_complete_signal_meta_data;
@@ -47,12 +47,13 @@ RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
std::string iflags_server_uuid_env_str(iflags_server_uuid_env_p);
flags_server_uuid = iflags_server_uuid_env_str;
}
std::string step_shm_name =
("splitwise_complete_prefilled_step_" + std::to_string(rank_id) + "_" +
flags_server_uuid);
("splitwise_complete_prefilled_step_" + std::to_string(rank_id) + "." +
std::to_string(device_id));
std::string layer_shm_name =
("splitwise_complete_prefilled_layer_" + std::to_string(rank_id) + "_" +
flags_server_uuid);
("splitwise_complete_prefilled_layer_" + std::to_string(rank_id) + "." +
std::to_string(device_id));
if (const char* use_ep = std::getenv("ENABLE_EP_DP")) {
if (std::strcmp(use_ep, "1") == 0) {
step_shm_name = "splitwise_complete_prefilled_step_tprank0_dprank" +

View File

@@ -93,6 +93,7 @@ struct RemoteCacheKvIpc {
static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data
open_shm_and_get_complete_signal_meta_data(const int rank_id,
const int device_id,
const bool keep_pd_step_flag);
static void save_cache_kv_complete_signal_layerwise(void* meta_data);
static void save_cache_kv_complete_signal_layerwise_per_query(

View File

@@ -19,26 +19,26 @@
#include "xpu/plugin.h"
#include "xpu_multiprocess.h" // NOLINT(build/include_subdir)
std::vector<paddle::Tensor> ShareExternalData(const paddle::Tensor &input,
std::vector<paddle::Tensor> ShareExternalData(const paddle::Tensor& input,
const std::string shm_name,
const std::vector<int> &shape,
const std::vector<int>& shape,
bool use_ipc) {
sharedMemoryInfo info;
int ret = sharedMemoryOpen(shm_name.c_str(), sizeof(shmStruct), &info);
PD_CHECK(ret == 0, "sharedMemoryOpen failed");
volatile shmStruct *shm = static_cast<volatile shmStruct *>(info.addr);
void *data_ptr_addr = nullptr;
volatile shmStruct* shm = static_cast<volatile shmStruct*>(info.addr);
void* data_ptr_addr = nullptr;
if (use_ipc) {
#if XPURT_VERSION_MAJOR == 5
int ret = xpu_ipc_open_memhandle(&data_ptr_addr,
*(XPUIpcMemHandle *)&shm->memHandle,
*(XPUIpcMemHandle*)&shm->memHandle,
0x01); // NOLINT
PD_CHECK(ret == XPU_SUCCESS, "xpu_ipc_open_memhandle failed");
PD_CHECK(ret == XPU_SUCCESS, shm_name, " xpu_ipc_open_memhandle failed");
#elif XPURT_VERSION_MAJOR == 4
PD_THROW("kl2 not support prefix cache");
#endif
} else {
data_ptr_addr = reinterpret_cast<void *>(shm->data_ptr_addr);
data_ptr_addr = reinterpret_cast<void*>(shm->data_ptr_addr);
}
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());

View File

@@ -25,6 +25,13 @@ import traceback
import numpy as np
import paddle
from fastdeploy.cache_manager.ops import (
get_output_kv_signal,
get_peer_mem_addr,
memory_allocated,
set_data_ipc,
set_device,
)
from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import (
@@ -32,7 +39,6 @@ from fastdeploy.inter_communicator import (
IPCSignal,
shared_memory_exists,
)
from fastdeploy.model_executor.ops.gpu import get_output_kv_signal, set_data_ipc
from fastdeploy.utils import envs, get_logger
logger = get_logger("cache_messager", "cache_messager.log")
@@ -157,8 +163,12 @@ class CacheMessager:
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
cache_k.append(key_cache)
cache_v.append(val_cache)
cache_k_ptr_list.append(key_cache.data_ptr())
cache_v_ptr_list.append(val_cache.data_ptr())
if paddle.is_compiled_with_xpu():
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
else:
cache_k_ptr_list.append(key_cache.data_ptr())
cache_v_ptr_list.append(val_cache.data_ptr())
cache_k_ptr_list = np.array(cache_k_ptr_list)
cache_v_ptr_list = np.array(cache_v_ptr_list)
@@ -166,7 +176,7 @@ class CacheMessager:
cache_shape = key_cache.shape
max_block_num = cache_shape[0]
block_bytes = math.prod(cache_shape[1:])
if key_cache.dtype == paddle.bfloat16:
if key_cache.dtype == paddle.bfloat16 or key_cache.dtype == paddle.float16:
block_bytes *= 2
logger.info(
f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
@@ -452,8 +462,12 @@ class CacheMessagerV1:
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
cache_k.append(key_cache)
cache_v.append(val_cache)
cache_k_ptr_list.append(key_cache.data_ptr())
cache_v_ptr_list.append(val_cache.data_ptr())
if paddle.is_compiled_with_xpu():
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
else:
cache_k_ptr_list.append(key_cache.data_ptr())
cache_v_ptr_list.append(val_cache.data_ptr())
cache_k_ptr_list = np.array(cache_k_ptr_list)
cache_v_ptr_list = np.array(cache_v_ptr_list)
@@ -763,7 +777,7 @@ class CacheMessagerV1:
def main():
device = args.device_id
rank = args.rank
paddle.set_device(f"gpu:{device}")
set_device(device)
cache_type = args.cache_dtype
speculative_config = SpeculativeConfig(args.speculative_config)
num_extra_layers = speculative_config.num_extra_cache_layer
@@ -823,7 +837,7 @@ def main():
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
logger.info(f"device :{device}")
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
logger.info(f"done init cache (full) gmem alloc : {memory_allocated}")
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
cache_messager = CacheMessagerV1(
@@ -875,7 +889,6 @@ if __name__ == "__main__":
args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
logger.info("create cache messager...")
logger.info(f"{args}")
main()

View File

@@ -6,6 +6,10 @@ if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
cuda_host_alloc,
cuda_host_free,
get_data_ptr_ipc,
get_output_kv_signal,
ipc_sent_key_value_cache_by_remote_ptr,
ipc_sent_key_value_cache_by_remote_ptr_block_sync,
set_data_ipc,
share_external_data,
swap_cache_all_layers,
@@ -13,10 +17,16 @@ if current_platform.is_cuda():
)
memory_allocated = paddle.device.cuda.memory_allocated
def get_peer_mem_addr(*args, **kwargs):
raise RuntimeError("CUDA no need of get_peer_mem_addr!")
elif current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import (
cuda_host_alloc,
cuda_host_free,
get_output_kv_signal,
get_peer_mem_addr,
set_data_ipc,
share_external_data,
swap_cache_all_layers,
@@ -25,6 +35,15 @@ elif current_platform.is_xpu():
unset_data_ipc = None
memory_allocated = paddle.device.xpu.memory_allocated
def get_data_ptr_ipc(*args, **kwargs):
raise RuntimeError("XPU get_data_ptr_ipc UNIMPLENENTED!")
def ipc_sent_key_value_cache_by_remote_ptr(*args, **kwargs):
raise RuntimeError("XPU ipc_sent_key_value_cache_by_remote_ptr UNIMPLENENTED")
def ipc_sent_key_value_cache_by_remote_ptr_block_sync(*args, **kwargs):
raise RuntimeError("XPU No ipc_sent_key_value_cache_by_remote_ptr UNIMPLENENTED")
else:
raise RuntimeError("Prefix cache ops only supported CUDA nor XPU platform ")
@@ -48,6 +67,13 @@ def share_external_data_(cache, cache_name, cache_shape, use_ipc):
return cache
def get_all_visible_devices():
if current_platform.is_xpu():
return "XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
else:
return "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
__all__ = [
"cuda_host_alloc",
"cuda_host_free",
@@ -57,4 +83,10 @@ __all__ = [
"unset_data_ipc", # XPU是 None
"set_device",
"memory_allocated",
"get_output_kv_signal",
"get_data_ptr_ipc",
"ipc_sent_key_value_cache_by_remote_ptr",
"ipc_sent_key_value_cache_by_remote_ptr_block_sync",
"get_peer_mem_addr",
"get_all_visible_devices",
]

View File

@@ -33,6 +33,7 @@ import numpy as np
from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
from fastdeploy.cache_manager.ops import get_all_visible_devices
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
@@ -243,9 +244,11 @@ class PrefixCacheManager:
# Run command to launch cache transfer managers
log_dir = envs.FD_LOG_DIR
cache_manager_processes = []
visible_devices = get_all_visible_devices()
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
"FLAGS_allocator_strategy=auto_growth "
+ visible_devices
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
+ f" {sys.executable} {py_path}"
@@ -328,9 +331,11 @@ class PrefixCacheManager:
py_path = os.path.join(current_dir_path, filename)
log_dir = envs.FD_LOG_DIR
cache_messager_processes = []
visible_devices = get_all_visible_devices()
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
"FLAGS_allocator_strategy=auto_growth "
+ visible_devices
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}"

View File

@@ -16,7 +16,7 @@
import paddle
from fastdeploy.model_executor.ops.gpu import (
from fastdeploy.cache_manager.ops import (
get_data_ptr_ipc,
ipc_sent_key_value_cache_by_remote_ptr,
ipc_sent_key_value_cache_by_remote_ptr_block_sync,

View File

@@ -246,6 +246,8 @@ class XPUForwardMeta(ForwardMeta):
total_enc_len: Optional[paddle.Tensor] = None
# position embedding type in rope, supports 'NORMAL' or 'HALF_HEAD_DIM'
pos_emb_type: Optional[str] = "NORMAL"
# for pd_disaggregation
kv_signal_sender: Optional[paddle.Tensor] = None
@dataclass

View File

@@ -32,6 +32,17 @@ def init_kv_signal_per_query(
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import init_kv_signal_per_query
out = init_kv_signal_per_query(
seq_lens_encoder,
seq_lens_this_time,
seq_lens_decoder,
rank,
num_layers,
)
return out
elif current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import init_kv_signal_per_query
out = init_kv_signal_per_query(
seq_lens_encoder,
seq_lens_this_time,

View File

@@ -29,6 +29,11 @@ def init_signal_layerwise(
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import init_signal_layerwise
out = init_signal_layerwise(kv_signal_metadata, layer_id)
return out
elif current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import init_signal_layerwise
out = init_signal_layerwise(kv_signal_metadata, layer_id)
return out
else:

View File

@@ -30,6 +30,11 @@ def open_shm_and_get_meta_signal(
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import open_shm_and_get_meta_signal
out = open_shm_and_get_meta_signal(rank, device_id, keep_pd_step_flag)
return out
elif current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import open_shm_and_get_meta_signal
out = open_shm_and_get_meta_signal(rank, device_id, keep_pd_step_flag)
return out
else:

View File

@@ -17,6 +17,7 @@
import os
from fastdeploy.config import FDConfig
from fastdeploy.platforms import current_platform
def init_rank_and_device_id(fd_config: FDConfig):
@@ -26,7 +27,10 @@ def init_rank_and_device_id(fd_config: FDConfig):
+ fd_config.parallel_config.tensor_parallel_rank
)
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", None)
if current_platform.is_xpu():
cuda_visible_devices = os.getenv("XPU_VISIBLE_DEVICES", None)
else: # default cuda
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", None)
if cuda_visible_devices is None:
device_id = rank

View File

@@ -16,13 +16,13 @@
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Tuple
import paddle
from fastdeploy.model_executor.layers.attention.ops import (
init_kv_signal_per_query,
init_signal_layerwise,
open_shm_and_get_meta_signal,
)
@@ -36,6 +36,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
@dataclass
@@ -90,7 +91,7 @@ class XPUAttentionBackend(AttentionBackend):
)
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads
@@ -98,8 +99,10 @@ class XPUAttentionBackend(AttentionBackend):
self.num_layers: int = fd_config.model_config.num_hidden_layers
# pd_disaggregation
self.use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
self.start_layer_index: int = fd_config.model_config.start_layer_index
self.rank, self.device_id = init_rank_and_device_id(fd_config)
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
@@ -120,8 +123,20 @@ class XPUAttentionBackend(AttentionBackend):
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
if self.use_pd_disaggregation:
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(self.rank, self.keep_pd_step_flag)
if self.pd_disaggregation_mode == "per_chunk" and not forward_meta.is_profiling:
if not self.keep_pd_step_flag:
init_kv_signal_per_query(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_decoder,
self.rank,
self.num_layers + self.num_layers_draft_model,
)
elif self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
self.rank, int(self.device_id), self.keep_pd_step_flag
)
self.attention_metadata: AttentionMetadata = metadata
def get_attntion_meta(self) -> AttentionMetadata:
@@ -154,8 +169,7 @@ class XPUAttentionBackend(AttentionBackend):
forward_mixed
"""
metadata = self.attention_metadata
if self.use_pd_disaggregation:
if self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,
layer.layer_id + self.start_layer_index,
@@ -197,9 +211,10 @@ class XPUAttentionBackend(AttentionBackend):
v_zp, # zero_point_quant_scale
None, # shift
None, # smooth
None, # kv_signal_data
None, # kv_signal_sender
metadata.kv_signal_data_list[layer.layer_id], # kv_signal_data
forward_meta.kv_signal_sender, # kv_signal_sender
forward_meta.pos_emb_type,
self.rope_3d,
)
return res

View File

@@ -14,6 +14,7 @@
# limitations under the License.
"""
import os
import random
import time
from typing import Dict, List, Optional
@@ -43,6 +44,8 @@ from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.model_executor.ops.xpu import (
adjust_batch,
create_kv_signal_sender,
destroy_kv_signal_sender,
get_infer_param,
get_padding_offset,
limit_thinking_content_length_v1,
@@ -68,6 +71,7 @@ def xpu_pre_process(
draft_tokens: Optional[paddle.Tensor] = None,
seq_lens_encoder: Optional[paddle.Tensor] = None,
seq_lens_decoder: Optional[paddle.Tensor] = None,
is_profiling: bool = False,
) -> XPUForwardMeta:
""" """
max_len = input_ids.shape[1]
@@ -152,6 +156,8 @@ def xpu_pre_process(
share_inputs["ids_remove_padding"] = adjusted_input
xpu_forward_meta.ids_remove_padding = adjusted_input
# Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends
xpu_forward_meta.is_profiling = is_profiling
return xpu_forward_meta
@@ -402,6 +408,8 @@ class XPUModelRunner(ModelRunnerBase):
# Forward meta store the global meta information of the forward
self.forward_meta: ForwardMeta = None
self.pd_disaggregation_mode: str = self.fd_config.parallel_config.pd_disaggregation_mode
def exist_prefill(self):
"""
check whether prefill stage exist
@@ -610,48 +618,75 @@ class XPUModelRunner(ModelRunnerBase):
def insert_prefill_inputs(self, req_dicts: List[Request]):
"""Process inputs for prefill tasks and update share_inputs buffer"""
# NOTE(luotingdan): Set environment variable of prefill node
if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
os.environ["PREFILL_NODE_ONE_STEP_STOP"] = "1"
req_len = len(req_dicts)
for i in range(req_len):
request = req_dicts[i]
idx = request.idx
length = len(request.prompt_token_ids)
assert length > 0, "The prompt requested must not be empty."
self.share_inputs["pre_ids"][idx : idx + 1] = -1
self.share_inputs["step_idx"][idx : idx + 1] = 0
self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
if self.enable_mm:
inputs = self._preprocess_mm_task(request.multimodal_inputs)
if inputs.get("images") is not None:
self.share_inputs["image_features"] = self.extract_vision_features(inputs)
# Is Decode Node
if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode":
self.share_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1]
self.share_inputs["input_ids"][idx : idx + 1, 0] = request.prompt_token_ids[0]
self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = length
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 1
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = 0
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = length
self.share_inputs["prompt_lens"][idx : idx + 1] = length
self.share_inputs["step_idx"][idx : idx + 1] = 1
# TODO support MTP
# if self.speculative_decoding:
# num_prefill_send_token = self.speculative_config.num_speculative_tokens + 1
# self.share_inputs["draft_tokens"][idx : idx + 1, 0:num_prefill_send_token] = paddle.to_tensor(
# request.draft_token_ids[0:num_prefill_send_token],
# dtype="int64",
# )
# self.seq_lens_this_time_buffer[idx : idx + 1] = num_prefill_send_token
else:
self.share_inputs["pre_ids"][idx : idx + 1] = -1
self.share_inputs["step_idx"][idx : idx + 1] = 0
self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
if self.enable_mm:
inputs = self._preprocess_mm_task(request.multimodal_inputs)
if inputs.get("images") is not None:
self.share_inputs["image_features"] = self.extract_vision_features(inputs)
else:
# Compatible with the situation that lacks images and videos
self.share_inputs["image_features"] = None
position_ids = inputs["position_ids"]
length = inputs["input_ids"].shape[1]
self.share_inputs["input_ids"][idx : idx + 1, :length] = inputs["input_ids"]
else:
# Compatible with the situation that lacks images and videos
self.share_inputs["image_features"] = None
position_ids = inputs["position_ids"]
length = inputs["input_ids"].shape[1]
self.share_inputs["input_ids"][idx : idx + 1, :length] = inputs["input_ids"]
else:
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
self.share_inputs["prompt_lens"][idx : idx + 1] = length
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
self.share_inputs["prompt_lens"][idx : idx + 1] = length
if self.enable_mm:
self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
position_ids, [request.get("max_tokens", 2048)], [0, position_ids.shape[0]]
)[0]
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
if self.enable_mm:
self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
position_ids, [request.get("max_tokens", 2048)], [0, position_ids.shape[0]]
)[0]
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
# Enable thinking
self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get("reasoning_max_tokens")
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
else:
# Disable thinking
self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
# Enable thinking
self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get("reasoning_max_tokens")
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
else:
# Disable thinking
self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
def get_attr_from_request(request, attr, default_value=None):
res = request.get(attr, default_value)
@@ -892,6 +927,7 @@ class XPUModelRunner(ModelRunnerBase):
draft_tokens=None,
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
is_profiling=is_dummy_run,
)
# Update bad tokens len
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
@@ -900,7 +936,8 @@ class XPUModelRunner(ModelRunnerBase):
self.forward_meta.pos_emb_type = self.share_inputs["pos_emb_type"]
self.forward_meta.attn_backend = self.attn_backends[0]
self.initialize_attention_backend()
if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query":
self.forward_meta.kv_signal_sender = self.kv_signal_sender
# Get sampling metadata
# TODU(lilujia): sync with GPU
self.sampling_metadata = SamplingMetadata(
@@ -1151,10 +1188,10 @@ class XPUModelRunner(ModelRunnerBase):
"""
# 0. set debug level
# self._set_debug_level(0x1, model_forward_batch, is_dummy_run)
if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query":
self.kv_signal_sender = create_kv_signal_sender()
# 1. Prepare inputs of model and decoder.
self._prepare_inputs(is_dummy_run=is_dummy_run)
# NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runner, the current runner is required to execute part of the model.
@@ -1229,6 +1266,8 @@ class XPUModelRunner(ModelRunnerBase):
self.cache_config.enc_dec_block_num,
)
if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query":
destroy_kv_signal_sender(self.kv_signal_sender)
return None
def _execute_empty_input(self) -> None: