[xpu] add ep custom ops (#3911)

This commit is contained in:
zhupengyang
2025-09-10 12:22:50 +08:00
committed by GitHub
parent c3b2a60fb8
commit 9d0074a91a
71 changed files with 5436 additions and 80 deletions

View File

@@ -143,9 +143,9 @@ function build_and_install_ops() {
TMP_DIR_REAL_PATH=`readlink -f ${OPS_TMP_DIR}`
is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"`
if [ "$is_xpu" = "True" ]; then
cd xpu_ops/src
cd xpu_ops
bash build.sh ${TMP_DIR_REAL_PATH}
cd ../..
cd ..
elif [ "$FD_CPU_USE_BF16" == "true" ]; then
if [ "$FD_BUILDING_ARCS" == "" ]; then
FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}

View File

@@ -542,7 +542,7 @@ elif paddle.is_compiled_with_cuda():
include_package_data=True,
)
elif paddle.is_compiled_with_xpu():
assert False, "In XPU, we should use setup_ops.py in xpu_ops/src, not this."
assert False, "For XPU, please use setup_ops.py in the xpu_ops directory to compile custom ops."
elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
setup(
name="fastdeploy_ops",

View File

@@ -27,7 +27,7 @@ import paddle
from paddle.utils.cpp_extension import CppExtension, setup
current_file = Path(__file__).resolve()
base_dir = current_file.parent
base_dir = os.path.join(current_file.parent, "src")
def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR, XDNN_LIB_DIR):
@@ -136,33 +136,8 @@ def xpu_setup_ops():
# build plugin
build_plugin(CLANG_PATH, XRE_INC_PATH, XRE_LIB_DIR, XDNN_INC_PATH, XDNN_LIB_DIR)
ops = [
# custom ops
"./ops/save_with_output_msg.cc",
"./ops/stop_generation_multi_ends.cc",
"./ops/set_value_by_flags_and_idx.cc",
"./ops/get_token_penalty_multi_scores.cc",
"./ops/get_padding_offset.cc",
"./ops/update_inputs.cc",
"./ops/recover_decode_task.cc",
"./ops/update_inputs_v1.cc",
"./ops/get_output.cc",
"./ops/step.cc",
"./ops/get_infer_param.cc",
"./ops/adjust_batch.cc",
"./ops/gather_next_token.cc",
"./ops/block_attn.cc",
"./ops/moe_layer.cc",
"./ops/weight_quantize_xpu.cc",
# device manage ops
"./ops/device/get_context_gm_max_mem_demand.cc",
"./ops/device/get_free_global_memory.cc",
"./ops/device/get_total_global_memory.cc",
"./ops/device/get_used_global_memory.cc",
]
ops = [os.path.join(base_dir, op) for op in ops]
for root, dirs, files in os.walk(base_dir / "ops/mtp_ops"):
ops = []
for root, dirs, files in os.walk(os.path.join(base_dir, "ops")):
for file in files:
if file.endswith(".cc"):
ops.append(os.path.join(root, file))

View File

@@ -0,0 +1,225 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <infer_ops.h>
#include <functional>
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "utility/debug.h"
#include "utility/env.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
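// Editorial note (not part of this diff): with the guard above,
// PD_BUILD_STATIC_OP(fused_rms_norm_xpu) expands to
// PD_BUILD_OP(static_op_fused_rms_norm_xpu), so the op is registered under a
// "static_op_" prefix whenever Paddle's headers do not already define the macro.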
XPU_DECLARE_BOOL(ENABLE_XVLLM_SDNN_INFER, false);
namespace api = baidu::xpu::api;
template <typename T>
std::vector<paddle::Tensor> RmsNormKernel(
const paddle::Tensor& x,
const paddle::optional<paddle::Tensor>& bias,
const paddle::optional<paddle::Tensor>& residual,
const paddle::Tensor& norm_weight,
const paddle::optional<paddle::Tensor>& norm_bias,
const float epsilon,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound) {
using XPU_T = typename XPUTypeTrait<T>::Type;
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
int ret = -1;
auto x_shape = x.shape();
PD_CHECK(quant_scale <= 0, "Quantization is not supported");
PD_CHECK(begin_norm_axis > 0 && begin_norm_axis <= x_shape.size(),
"begin_norm_axis check fail");
PD_CHECK(norm_bias.get_ptr() == nullptr,
"rms norm kernel don't support norm_bias");
int64_t m = std::accumulate(x_shape.begin(),
x_shape.begin() + begin_norm_axis,
static_cast<int64_t>(1),
std::multiplies<int64_t>());
int64_t n = std::accumulate(x_shape.begin() + begin_norm_axis,
x_shape.end(),
static_cast<int64_t>(1),
std::multiplies<int64_t>());
PD_CHECK(n == norm_weight.shape()[0],
"The product from begin_norm_axis to the last axis of x must be "
"equal to the norm_weight's shape[0]");
if (bias.get_ptr()) {
PD_CHECK(n == bias.get_ptr()->shape()[0],
"The product from begin_norm_axis to the last axis of x must be "
"equal to the bias's shape[0]");
}
paddle::Tensor out = paddle::empty(x_shape, x.dtype(), x.place());
paddle::Tensor residual_out = paddle::empty(x_shape, x.dtype(), x.place());
const XPU_T* x_data = reinterpret_cast<const XPU_T*>(x.data<T>());
const XPU_T* norm_weight_data =
reinterpret_cast<const XPU_T*>(norm_weight.data<T>());
const XPU_T* bias_data =
bias.get_ptr() ? reinterpret_cast<const XPU_T*>(bias.get_ptr()->data<T>())
: nullptr;
const XPU_T* residual_data =
residual.get_ptr()
? reinterpret_cast<const XPU_T*>(residual.get_ptr()->data<T>())
: nullptr;
XPU_T* out_data = reinterpret_cast<XPU_T*>(const_cast<T*>(out.data<T>()));
XPU_T* residual_out_data = nullptr;
if (residual_data) {
residual_out_data =
reinterpret_cast<XPU_T*>(const_cast<T*>(residual_out.data<T>()));
}
XPU_T* add_out_data = const_cast<XPU_T*>(x_data);
if (bias_data) {
ret = api::broadcast_add(
xpu_ctx->x_context(), x_data, bias_data, out_data, {m, n}, {n});
PD_CHECK(ret == 0, "broadcast_add");
add_out_data = out_data;
}
bool use_sdnn = FLAGS_ENABLE_XVLLM_SDNN_INFER;
if (residual_data) {
ret = infer_ops::add_rms_layer_norm<XPU_T, XPU_T>(xpu_ctx->x_context(),
add_out_data,
residual_data,
out_data,
m,
n,
epsilon,
norm_weight_data,
nullptr,
nullptr,
residual_out_data,
nullptr,
use_sdnn);
PD_CHECK(ret == 0, "add_rms_layer_norm");
} else {
ret = api::rms_layer_norm<XPU_T, XPU_T>(xpu_ctx->x_context(),
add_out_data,
out_data,
m,
n,
epsilon,
norm_weight_data,
nullptr,
nullptr,
false);
PD_CHECK(ret == 0, "rms_layer_norm");
}
return {out, residual_out};
}
std::vector<paddle::Tensor> RmsNorm(
const paddle::Tensor& x,
const paddle::optional<paddle::Tensor>& bias,
const paddle::optional<paddle::Tensor>& residual,
const paddle::Tensor& norm_weight,
const paddle::optional<paddle::Tensor>& norm_bias,
const float epsilon,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound) {
const auto x_type = x.dtype();
#define APPLY_RMS_NORM_KERNEL(TX) \
return RmsNormKernel<TX>(x, \
bias, \
residual, \
norm_weight, \
norm_bias, \
epsilon, \
begin_norm_axis, \
quant_scale, \
quant_round_type, \
quant_max_bound, \
quant_min_bound);
if (x_type == paddle::DataType::BFLOAT16) {
APPLY_RMS_NORM_KERNEL(paddle::bfloat16);
} else if (x_type == paddle::DataType::FLOAT16) {
APPLY_RMS_NORM_KERNEL(paddle::float16);
} else if (x_type == paddle::DataType::FLOAT32) {
APPLY_RMS_NORM_KERNEL(float);
} else {
PD_THROW("RmsNorm not support x_type=", static_cast<int>(x_type));
return {};
}
#undef APPLY_RMS_NORM_KERNEL
}
std::vector<std::vector<int64_t>> RmsNormInferShape(
const std::vector<int64_t>& x_shape,
const paddle::optional<std::vector<int64_t>>& bias_shape,
const paddle::optional<std::vector<int64_t>>& residual_shape,
const std::vector<int64_t>& norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& norm_bias_shape,
const float epsilon,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound) {
PD_CHECK(begin_norm_axis > 0 && begin_norm_axis <= x_shape.size(),
"begin_norm_axis check fail");
int64_t m = std::accumulate(x_shape.begin(),
x_shape.begin() + begin_norm_axis,
static_cast<int64_t>(1),
std::multiplies<int64_t>());
return {x_shape, x_shape, {m}};
}
std::vector<paddle::DataType> RmsNormInferDtype(
const paddle::DataType& x_dtype,
const paddle::optional<paddle::DataType>& bias_dtype,
const paddle::optional<paddle::DataType>& residual_dtype,
const paddle::DataType& norm_weight_dtype,
const paddle::optional<paddle::DataType>& norm_bias_dtype,
const float epsilon,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound) {
// out, residual_out
return {x_dtype, x_dtype};
}
PD_BUILD_STATIC_OP(fused_rms_norm_xpu)
.Inputs({"x",
paddle::Optional("bias"),
paddle::Optional("residual"),
"norm_weight",
paddle::Optional("norm_bias")})
.Outputs({"out", "residul_out"})
.Attrs({"epsilon:float",
"begin_norm_axis:int",
"quant_scale:float",
"quant_round_type:int",
"quant_max_bound:float",
"quant_min_bound:float"})
.SetKernelFn(PD_KERNEL(RmsNorm))
.SetInferShapeFn(PD_INFER_SHAPE(RmsNormInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(RmsNormInferDtype));
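For reference, a minimal standalone sketch (illustrative shapes, not part of this diff) of how fused_rms_norm_xpu flattens x into an m x n problem via begin_norm_axis, mirroring the std::accumulate calls above:

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical input: x of shape {2, 8, 512} with begin_norm_axis = 2.
  std::vector<int64_t> x_shape = {2, 8, 512};
  const int begin_norm_axis = 2;
  // Rows to normalize: product of dims before begin_norm_axis -> 2 * 8 = 16.
  int64_t m = std::accumulate(x_shape.begin(),
                              x_shape.begin() + begin_norm_axis,
                              static_cast<int64_t>(1),
                              std::multiplies<int64_t>());
  // Normalized length: product of the remaining dims -> 512.
  // norm_weight (and bias, if given) must then have shape {n}.
  int64_t n = std::accumulate(x_shape.begin() + begin_norm_axis,
                              x_shape.end(),
                              static_cast<int64_t>(1),
                              std::multiplies<int64_t>());
  std::cout << "m=" << m << " n=" << n << std::endl;  // m=16 n=512
}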

View File

@@ -18,13 +18,35 @@
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "msg_utils.h"
#define MAX_BSZ 256
// #define GET_OUTPUT_DEBUG
struct msgdata {
long mtype;
int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens
};
void GetOutputKVSignal(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag) {
int msg_queue_id = 1024 + rank_id;
static struct msgdatakv msg_rcv;
static key_t key = ftok("/opt/", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
int* out_data = const_cast<int*>(x.data<int>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT);
} else {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, 0);
}
if (ret == -1) {
out_data[0] = -1;
out_data[1] = -1;
return;
}
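// Editorial note: judging by the msgdatakv definition (see speculate_msg.h)
// and the copy bound below, mtext[0] holds the encoder count, mtext[1] a
// second header word (layer id), and each encoder then contributes three
// ints, so encoder_count * 3 + 2 ints are forwarded to the output tensor.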
int encoder_count = msg_rcv.mtext[0];
for (int i = 0; i < encoder_count * 3 + 2; i++) {
out_data[i] = msg_rcv.mtext[i];
}
return;
}
void GetOutput(const paddle::Tensor &x, int64_t rank_id, bool wait_flag,
int msg_queue_id) {

View File

@@ -0,0 +1,119 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <infer_ops.h>
#include <xft_api.h>
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "utility/debug.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
template <typename T>
std::vector<paddle::Tensor> MoeEPCombineKernel(
const paddle::Tensor&
ffn_out, // expand_token_num * hidden_dim dtype is fp16/bf16
const paddle::Tensor& moe_index, // token_num * topk dtype is int
const paddle::Tensor&
weights, // token_num * topk dtype is same as ffn_out
int64_t recv_token_num,
int64_t expand_token_num,
int64_t hidden_dim,
int64_t topk) {
using XPU_T = typename XPUTypeTrait<T>::Type;
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
auto combined_out = paddle::empty(
{recv_token_num, hidden_dim}, ffn_out.dtype(), ffn_out.place());
const float* dequant_score = nullptr;
int ret = infer_ops::moe_ep_ffn_post_fusion(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_T*>(ffn_out.data<T>()),
moe_index.data<int32_t>(),
reinterpret_cast<const XPU_T*>(weights.data<T>()),
dequant_score,
reinterpret_cast<XPU_T*>(combined_out.mutable_data<T>()),
recv_token_num,
hidden_dim,
topk,
expand_token_num);
PD_CHECK(ret == 0);
return {combined_out};
}
std::vector<paddle::Tensor> MoeEPCombine(const paddle::Tensor& ffn_out,
const paddle::Tensor& moe_index,
const paddle::Tensor& weights,
const int recv_token_num,
const int expand_token_num,
const int hidden_dim,
const int topk) {
#define APPLY_KERNEL(TX) \
return MoeEPCombineKernel<TX>(ffn_out, \
moe_index, \
weights, \
recv_token_num, \
expand_token_num, \
hidden_dim, \
topk);
const auto ffn_out_dtype = ffn_out.dtype();
if (ffn_out_dtype == paddle::DataType::FLOAT16) {
APPLY_KERNEL(paddle::float16);
} else if (ffn_out_dtype == paddle::DataType::BFLOAT16) {
APPLY_KERNEL(paddle::bfloat16);
} else {
PD_THROW("MoeEPCombine not support ffn_out_type==%d",
static_cast<int>(ffn_out_dtype));
return {};
}
#undef APPLY_KERNEL
}
std::vector<std::vector<int64_t>> MoeEPCombineInferShape(
const std::vector<int64_t>& ffn_out_shape,
const std::vector<int64_t>& moe_index_shape,
const std::vector<int64_t>& weights_shape,
const int recv_token_num,
const int expand_token_num,
const int hidden_dim,
const int topk) {
std::vector<int64_t> combined_out_shape = {recv_token_num, hidden_dim};
return {combined_out_shape};
}
std::vector<paddle::DataType> MoeEPCombineInferDtype(
const paddle::DataType& ffn_out_dtype,
const paddle::DataType& moe_index_dtype,
const paddle::DataType& weights_dtype) {
return {ffn_out_dtype};
}
PD_BUILD_STATIC_OP(ep_moe_expert_combine)
.Inputs({"ffn_out", "moe_index", "weights"})
.Outputs({"combined_out"})
.Attrs({"recv_token_num: int",
"expand_token_num: int",
"hidden_dim: int",
"topk: int"})
.SetKernelFn(PD_KERNEL(MoeEPCombine))
.SetInferShapeFn(PD_INFER_SHAPE(MoeEPCombineInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeEPCombineInferDtype));
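A naive CPU reference of what ep_moe_expert_combine is presumed to compute (the actual work is done by infer_ops::moe_ep_ffn_post_fusion on the XPU; the function name, flat layouts, and float accumulation here are illustrative assumptions):

#include <cstdint>
#include <vector>

// combined_out[t, d] = sum_k weights[t, k] * ffn_out[moe_index[t, k], d]
void moe_ep_combine_reference(const std::vector<float>& ffn_out,      // [expand_token_num, hidden_dim]
                              const std::vector<int32_t>& moe_index,  // [recv_token_num, topk]
                              const std::vector<float>& weights,      // [recv_token_num, topk]
                              std::vector<float>* combined_out,       // [recv_token_num, hidden_dim]
                              int64_t recv_token_num,
                              int64_t hidden_dim,
                              int64_t topk) {
  combined_out->assign(recv_token_num * hidden_dim, 0.0f);
  for (int64_t t = 0; t < recv_token_num; ++t) {
    for (int64_t k = 0; k < topk; ++k) {
      const int64_t row = moe_index[t * topk + k];  // expanded row produced by dispatch
      const float w = weights[t * topk + k];
      for (int64_t d = 0; d < hidden_dim; ++d) {
        (*combined_out)[t * hidden_dim + d] += w * ffn_out[row * hidden_dim + d];
      }
    }
  }
}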

View File

@@ -0,0 +1,201 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <infer_ops.h>
#include <infer_ops_eb.h>
#include <xft_api.h>
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "utility/debug.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
template <typename TX, typename TY>
std::vector<paddle::Tensor> EPMoeExpertDispatchKernel(
const paddle::Tensor& input,
const paddle::Tensor& topk_ids,
const paddle::Tensor& topk_weights,
const paddle::optional<paddle::Tensor>& input_scales,
const std::vector<int>& token_nums_per_expert,
const int64_t token_nums_this_rank) {
using XPU_TX = typename XPUTypeTrait<TX>::Type;
using XPU_TY = typename XPUTypeTrait<TY>::Type;
phi::XPUPlace xpu_place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(xpu_place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
const auto input_type = input.dtype();
auto m = input.dims()[0];
auto n = input.dims()[1];
const int64_t expert_num = token_nums_per_expert.size();
const int topk = topk_ids.dims()[1];
auto place = input.place();
auto block_num = xpu_ctx->x_context()->ncluster();
paddle::Tensor permute_input;
auto permute_indices_per_token =
paddle::empty({m, topk}, paddle::DataType::INT32, place);
auto expert_m = paddle::empty({expert_num}, paddle::DataType::INT32, place);
auto recv_num_tokens_per_expert_list_cumsum =
paddle::empty({expert_num + 1}, paddle::DataType::INT32, place);
auto expand_input_scales =
paddle::empty({token_nums_this_rank}, paddle::DataType::FLOAT32, place);
const int64_t ep_size = 1;
const int64_t ep_rank = 0;
if (std::is_same<TY, int8_t>::value) {
permute_input =
paddle::empty({token_nums_this_rank, n}, paddle::DataType::INT8, place);
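// Editorial note: in the w4a8 path the tokens are quantized to int8 while
// being permuted, and (presumably per-token) dequantization scales are
// written to expand_input_scales for the downstream FFN.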
auto ret = infer_ops::moe_ffn_pre_sorted_quant_pe<XPU_TX, int>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_TX*>(input.data<TX>()),
topk_ids.data<int>(),
input_scales.get_ptr()->data<float>(),
nullptr,
reinterpret_cast<int8_t*>(permute_input.data<int8_t>()),
const_cast<int*>(permute_indices_per_token.data<int>()),
const_cast<int*>(expert_m.data<int>()),
const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
expand_input_scales.data<float>(),
m,
n,
expert_num,
topk,
block_num,
token_nums_this_rank);
PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
} else {
permute_input = paddle::empty({token_nums_this_rank, n}, input_type, place);
auto ret = infer_ops::moe_ep_ffn_pre_sorted<XPU_TX, int>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_TX*>(input.data<TX>()),
topk_ids.data<int>(),
nullptr,
reinterpret_cast<XPU_TX*>(permute_input.data<TX>()),
const_cast<int*>(permute_indices_per_token.data<int>()),
const_cast<int*>(expert_m.data<int>()),
const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
m,
n,
expert_num,
topk,
block_num,
ep_size,
ep_rank,
token_nums_this_rank);
PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
}
return {permute_input,
permute_indices_per_token,
recv_num_tokens_per_expert_list_cumsum,
topk_weights,
expand_input_scales};
}
std::vector<paddle::Tensor> EPMoeExpertDispatch(
const paddle::Tensor& input,
const paddle::Tensor& topk_ids,
const paddle::Tensor& topk_weights,
const paddle::optional<paddle::Tensor>& input_scales,
const std::vector<int>& token_nums_per_expert,
const int token_nums_this_rank,
const std::string quant_method) {
#define APPLY_KERNEL(TX, TY) \
return EPMoeExpertDispatchKernel<TX, TY>(input, \
topk_ids, \
topk_weights, \
input_scales, \
token_nums_per_expert, \
token_nums_this_rank);
const auto input_dtype = input.dtype();
if (input_dtype == paddle::DataType::FLOAT16 && quant_method == "w4a8") {
APPLY_KERNEL(paddle::float16, int8_t);
} else if (input_dtype == paddle::DataType::FLOAT16 &&
quant_method != "w4a8") {
APPLY_KERNEL(paddle::float16, paddle::float16);
} else if (input_dtype == paddle::DataType::BFLOAT16 &&
quant_method == "w4a8") {
APPLY_KERNEL(paddle::bfloat16, int8_t);
} else if (input_dtype == paddle::DataType::BFLOAT16 &&
quant_method != "w4a8") {
APPLY_KERNEL(paddle::bfloat16, paddle::bfloat16);
} else {
PD_THROW("EPMoeExpertDispatch not support input_dtype=",
static_cast<int>(input_dtype),
"quant_method=",
quant_method);
return {};
}
#undef APPLY_KERNEL
}
std::vector<std::vector<int64_t>> EPMoeExpertDispatchInferShape(
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& topk_ids_shape,
const std::vector<int64_t>& topk_weights_shape,
const paddle::optional<std::vector<int64_t>>& input_scales_shape,
const std::vector<int>& token_nums_per_expert,
const int token_nums_this_rank,
const std::string quant_method) {
const int m = input_shape[0];
const int hidden_size = input_shape[input_shape.size() - 1];
const int topk = topk_ids_shape[topk_ids_shape.size() - 1];
const int expert_num = token_nums_per_expert.size();
return {{token_nums_this_rank, hidden_size},
{expert_num, m},
{expert_num},
{token_nums_this_rank},
{token_nums_this_rank}};
}
std::vector<paddle::DataType> EPMoeExpertDispatchInferDtype(
const paddle::DataType& input_dtype,
const paddle::DataType& topk_ids_dtype,
const paddle::DataType& topk_weights_dtype,
const paddle::optional<paddle::DataType>& input_scales_dtype,
const std::vector<int>& token_nums_per_expert,
const int token_nums_this_rank,
const std::string quant_method) {
auto output_dtype = input_dtype;
if (quant_method == "w4a8") {
output_dtype = paddle::DataType::INT8;
}
return {
output_dtype,
paddle::DataType::INT32,
paddle::DataType::INT32,
topk_weights_dtype,
paddle::DataType::FLOAT32,
};
}
PD_BUILD_STATIC_OP(ep_moe_expert_dispatch)
.Inputs(
{"input", "topk_ids", "topk_weights", paddle::Optional("input_scales")})
.Outputs({"permute_input",
"permute_indices_per_token",
"token_nums_per_expert_cumsum",
"dst_weights",
"expand_input_scales"})
.Attrs({"token_nums_per_expert: std::vector<int>",
"token_nums_this_rank: int",
"quant_method: std::string"})
.SetKernelFn(PD_KERNEL(EPMoeExpertDispatch))
.SetInferShapeFn(PD_INFER_SHAPE(EPMoeExpertDispatchInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(EPMoeExpertDispatchInferDtype));

View File

@@ -0,0 +1,535 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <blocks/moe_fc_block_eb.h>
#include <core/check.h>
#include <core/context.h>
#include <core/param.h>
#include <infer_ops.h>
#include <xft_api.h>
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "utility/debug.h"
#include "utility/env.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
XPU_DECLARE_BOOL(MOE_FFN_USE_DENSE_INPUT, false);
XPU_DECLARE_BOOL(BKCL_DISPATCH_ALL_GATHER, false);
namespace xftblock = baidu::xpu::xftblock;
namespace api = baidu::xpu::api;
template <typename TX1, typename TX2, typename TW, typename TGEMM>
void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
xftblock::Tensor* token_num_info,
xftblock::Tensor* ffn1_weight,
xftblock::Tensor* ffn2_weight,
xftblock::Tensor* ffn1_bias,
xftblock::Tensor* ffn2_bias,
xftblock::Tensor* ffn2_out,
float* ffn2_act_scale,
TX2* ffn2_shift,
TX2* ffn2_smooth,
const int hadamard_blocksize) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr);
auto rt_guard = xctx.get_rt_guard();
auto xftblock_tx2 = xftblock::DataTypeToEnum<TX2>::value;
int ret = -1;
int expert_num = ffn1_weight->get_dim(0);
int inter_dim = ffn1_weight->get_dim(1);
int outer_dim = inter_dim / 2;
bool is_padding_input = ffn_in->get_dims().size() == 3;
auto ffn1_out_shape = ffn_in->get_dims();
int hidden_dim = ffn1_out_shape[ffn1_out_shape.size() - 1];
ffn1_out_shape[ffn1_out_shape.size() - 1] = inter_dim;
xftblock::Tensor ffn1_out(rt_guard, xftblock_tx2, ffn1_out_shape);
ret = xftblock::xft_moe_fc_block_eb<TX1, TW, TX2, float, int, TGEMM>(
&xctx,
ffn_in,
ffn1_weight,
&ffn1_out,
ffn1_bias,
is_padding_input ? nullptr : token_num_info,
is_padding_input ? token_num_info : nullptr,
expert_num,
1, // moe_topk
ffn1_out_shape.size() == 2 ? xftblock::MoeFCInputMode::DENSE
: xftblock::MoeFCInputMode::SPARSE);
PD_CHECK(ret == 0);
int token_num = ffn_in->numel() / hidden_dim;
auto swiglu_out_shape = ffn1_out_shape;
swiglu_out_shape[swiglu_out_shape.size() - 1] /= 2;
xftblock::Tensor swiglu_out(rt_guard, xftblock_tx2, swiglu_out_shape);
ret = api::fast_swiglu<TX2>(xpu_ctx->x_context(),
ffn1_out.data<TX2>(),
swiglu_out.mutable_data<TX2>(),
{token_num, inter_dim},
1,
true);
PD_CHECK(ret == 0);
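// Editorial note: fast_swiglu gates the two halves of ffn1's output, so the
// inner dimension shrinks from inter_dim to outer_dim = inter_dim / 2 before
// the optional shift/smooth, the Hadamard transform, and the ffn2 GEMM below.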
// TODO(mayang02): use fusion_smooth_transform
if (ffn2_shift != nullptr) {
ret = api::broadcast_add<TX2>(xpu_ctx->x_context(),
ffn2_shift,
swiglu_out.data<TX2>(),
swiglu_out.mutable_data<TX2>(),
{1, outer_dim},
{token_num, outer_dim});
PD_CHECK(ret == 0);
}
if (ffn2_smooth != nullptr) {
ret = api::broadcast_mul<TX2>(xpu_ctx->x_context(),
ffn2_smooth,
swiglu_out.data<TX2>(),
swiglu_out.mutable_data<TX2>(),
{1, outer_dim},
{token_num, outer_dim});
PD_CHECK(ret == 0);
}
if (hadamard_blocksize > 0) {
ret = infer_ops::fast_walsh_transform<TX2>(xpu_ctx->x_context(),
swiglu_out.data<TX2>(),
nullptr,
nullptr,
swiglu_out.mutable_data<TX2>(),
hadamard_blocksize,
token_num,
outer_dim);
PD_CHECK(ret == 0);
}
xftblock::Tensor ffn2_in(swiglu_out.mutable_data<TX2>(),
nullptr,
ffn2_act_scale,
xftblock_tx2,
swiglu_out_shape);
ret = xftblock::xft_moe_fc_block_eb<TX2, TW, TX2, float, int, TGEMM>(
&xctx,
&ffn2_in,
ffn2_weight,
ffn2_out,
nullptr,
is_padding_input ? nullptr : token_num_info,
is_padding_input ? token_num_info : nullptr,
expert_num,
1, // moe_topk
ffn1_out_shape.size() == 2
? xftblock::MoeFCInputMode::DENSE
: xftblock::MoeFCInputMode::SPARSE); // bias_mode
PD_CHECK(ret == 0);
}
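// Editorial note: convert_to_lod below turns per-expert token counts into a
// prefix-sum array of length expert_num + 1 (e.g. {3, 1, 2} -> {0, 3, 4, 6},
// assuming an inclusive cumsum), the offset form consumed by the dense-input
// paths further down.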
static void convert_to_lod(xftblock::XFTContext* xctx,
xftblock::Tensor* token_num_info) {
auto rt_guard = xctx->get_rt_guard();
auto ctx = xctx->get_context();
const int expert_num = token_num_info->numel();
xftblock::Tensor tokens_num_lod(
rt_guard, xftblock::DataType::DT_INT32, {expert_num + 1});
int ret = api::constant(ctx, tokens_num_lod.data<int>(), expert_num + 1, 0);
PD_CHECK(ret == 0);
ret = api::cumsum<int>(ctx,
token_num_info->data<int>(),
tokens_num_lod.data<int>() + 1,
{expert_num},
false,
false,
0);
PD_CHECK(ret == 0);
*token_num_info = std::move(tokens_num_lod);
}
template <typename TX1, typename TX2, typename TW>
std::vector<paddle::Tensor> MoeExpertFFNKernel(
const paddle::Tensor& ffn_in,
const paddle::Tensor& token_num_info,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::optional<paddle::Tensor>& ffn1_act_scale,
const paddle::optional<paddle::Tensor>& ffn2_act_scale,
const paddle::optional<paddle::Tensor>& ffn1_weight_scale,
const paddle::optional<paddle::Tensor>& ffn2_weight_scale,
const paddle::optional<paddle::Tensor>& ffn2_shift,
const paddle::optional<paddle::Tensor>& ffn2_smooth,
const std::string& quant_method,
const int hadamard_blocksize,
const int valid_token_num) {
using XPU_TX1 = typename XPUTypeTrait<TX1>::Type;
using XPU_TX2 = typename XPUTypeTrait<TX2>::Type;
using XPU_TW = typename XPUTypeTrait<TW>::Type;
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr);
auto rt_guard = xctx.get_rt_guard();
int ret = -1;
auto input_shape = ffn_in.shape();
auto ffn1_w_shape = ffn1_weight.shape();
int expert_num = ffn1_w_shape[0];
int hidden_dim = input_shape[input_shape.size() - 1];
int inter_dim = ffn1_w_shape[1];
int outer_dim = inter_dim / 2;
bool is_padding_input = input_shape.size() == 3;
if (is_padding_input) {
PD_CHECK(input_shape[0] == expert_num);
PD_CHECK(token_num_info.numel() == expert_num,
"token_num_info.numel() != expert_num, "
"token_num_info.numel(): ",
token_num_info.numel(),
", expert_num: ",
expert_num);
}
bool is_w4 = quant_method == "w4a8" || quant_method == "weight_only_int4";
auto xftblock_tx1 = xftblock::DataTypeToEnum<XPU_TX1>::value;
auto xftblock_tx2 = xftblock::DataTypeToEnum<XPU_TX2>::value;
auto xftblock_tw = xftblock::DataTypeToEnum<XPU_TW>::value;
if (is_w4) {
xftblock_tw = xftblock::DataTypeToEnum<int4_t>::value;
}
float* ffn1_act_scale_data =
ffn1_act_scale.get_ptr() == nullptr
? nullptr
: const_cast<float*>(ffn1_act_scale.get_ptr()->data<float>());
float* ffn2_act_scale_data =
ffn2_act_scale.get_ptr() == nullptr
? nullptr
: const_cast<float*>(ffn2_act_scale.get_ptr()->data<float>());
float* ffn1_w_scale_data =
ffn1_weight_scale.get_ptr() == nullptr
? nullptr
: const_cast<float*>(ffn1_weight_scale.get_ptr()->data<float>());
xftblock::Tensor xffn1_w(const_cast<TW*>(ffn1_weight.data<TW>()),
nullptr,
ffn1_w_scale_data,
xftblock_tw,
{expert_num, inter_dim, hidden_dim});
float* ffn2_w_scale_data =
ffn2_weight_scale.get_ptr() == nullptr
? nullptr
: const_cast<float*>(ffn2_weight_scale.get_ptr()->data<float>());
xftblock::Tensor xffn2_w(const_cast<TW*>(ffn2_weight.data<TW>()),
nullptr,
ffn2_w_scale_data,
xftblock_tw,
{expert_num, hidden_dim, outer_dim});
std::shared_ptr<xftblock::Tensor> xffn1_bias;
if (ffn1_bias.get_ptr()) {
xffn1_bias = std::make_shared<xftblock::Tensor>(
const_cast<float*>(ffn1_bias.get_ptr()->data<float>()),
xftblock::DataType::DT_FLOAT,
ffn1_bias.get_ptr()->shape());
}
std::shared_ptr<xftblock::Tensor> xffn2_bias;
if (ffn2_bias.get_ptr()) {
xffn2_bias = std::make_shared<xftblock::Tensor>(
const_cast<float*>(ffn2_bias.get_ptr()->data<float>()),
xftblock::DataType::DT_FLOAT,
ffn2_bias.get_ptr()->shape());
}
xftblock::Tensor xtoken_num_info(const_cast<int*>(token_num_info.data<int>()),
xftblock::DataType::DT_INT32,
token_num_info.shape());
XPU_TX2* shift_data = nullptr;
XPU_TX2* smooth_data = nullptr;
if (ffn2_shift.get_ptr()) {
shift_data = reinterpret_cast<XPU_TX2*>(
const_cast<TX2*>(ffn2_shift.get_ptr()->data<TX2>()));
}
if (ffn2_smooth.get_ptr()) {
smooth_data = reinterpret_cast<XPU_TX2*>(
const_cast<TX2*>(ffn2_smooth.get_ptr()->data<TX2>()));
}
paddle::Tensor ffn2_out =
paddle::empty_like(ffn_in, paddle::DataType::BFLOAT16);
xftblock::Tensor xffn1_in;
xftblock::Tensor xffn2_out;
paddle::Tensor ffn1_in_dense;
paddle::Tensor ffn1_in_scale_per_token;
if (FLAGS_MOE_FFN_USE_DENSE_INPUT && is_padding_input) {
convert_to_lod(&xctx, &xtoken_num_info);
if (quant_method == "w4a8") {
ffn1_in_scale_per_token = paddle::empty(
{valid_token_num}, paddle::DataType::FLOAT32, ffn_in.place());
ffn1_in_dense = paddle::empty({valid_token_num, hidden_dim},
paddle::DataType::INT8,
ffn_in.place());
xffn1_in = xftblock::Tensor(ffn1_in_dense.data<int8_t>(),
nullptr,
ffn1_in_scale_per_token.data<float>(),
xftblock::DataType::DT_INT8,
{valid_token_num, hidden_dim});
if (std::is_same<XPU_TX1, int8_t>::value) {
PD_CHECK(ffn1_act_scale_data != nullptr,
"need ffn1_act_scale for x int8 per expert input");
ret = infer_ops::sequence_unpad<float, int>(
xpu_ctx->x_context(),
ffn1_act_scale_data,
ffn1_in_scale_per_token.data<float>(),
xtoken_num_info.data<int>(),
expert_num,
input_shape[1],
1,
true);
PD_CHECK(ret == 0);
ret = infer_ops::sequence_unpad<int8_t, int>(
xpu_ctx->x_context(),
reinterpret_cast<const int8_t*>(ffn_in.data<int8_t>()),
reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
xtoken_num_info.data<int>(),
expert_num,
input_shape[1],
input_shape[2],
true);
PD_CHECK(ret == 0);
} else {
ret = infer_ops::quant2d_per_expert<XPU_TX1>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_TX1*>(ffn_in.data<TX1>()),
ffn1_act_scale_data,
xtoken_num_info.data<int>(),
reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
ffn1_in_scale_per_token.data<float>(),
expert_num,
valid_token_num,
hidden_dim,
true,
false,
input_shape[1]);
PD_CHECK(ret == 0);
}
} else {
ffn1_in_dense = paddle::empty(
{valid_token_num, hidden_dim}, ffn_in.dtype(), ffn_in.place());
xffn1_in = xftblock::Tensor(ffn1_in_dense.data<TX1>(),
nullptr,
ffn1_act_scale_data,
xftblock_tx1,
{valid_token_num, hidden_dim});
ret = infer_ops::sequence_unpad<XPU_TX1, int>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_TX1*>(ffn_in.data<TX1>()),
reinterpret_cast<XPU_TX1*>(xffn1_in.data<XPU_TX1>()),
xtoken_num_info.data<int>(),
expert_num,
input_shape[1],
input_shape[2],
true);
PD_CHECK(ret == 0);
}
xffn2_out =
xftblock::Tensor(rt_guard, xftblock_tx2, {valid_token_num, hidden_dim});
} else if (FLAGS_BKCL_DISPATCH_ALL_GATHER && !is_padding_input &&
quant_method == "w4a8") {
convert_to_lod(&xctx, &xtoken_num_info);
ffn1_in_scale_per_token = paddle::empty(
{valid_token_num}, paddle::DataType::FLOAT32, ffn_in.place());
ffn1_in_dense = paddle::empty(
{valid_token_num, hidden_dim}, paddle::DataType::INT8, ffn_in.place());
xffn1_in = xftblock::Tensor(ffn1_in_dense.data<int8_t>(),
nullptr,
ffn1_in_scale_per_token.data<float>(),
xftblock::DataType::DT_INT8,
{valid_token_num, hidden_dim});
ret = infer_ops::quant2d_per_expert<XPU_TX1>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_TX1*>(ffn_in.data<TX1>()),
ffn1_act_scale_data,
xtoken_num_info.data<int>(),
reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
ffn1_in_scale_per_token.data<float>(),
expert_num,
valid_token_num,
hidden_dim);
PD_CHECK(ret == 0);
xffn2_out =
xftblock::Tensor(ffn2_out.data<TX2>(), xftblock_tx2, input_shape);
} else {
xffn1_in = xftblock::Tensor(const_cast<TX1*>(ffn_in.data<TX1>()),
nullptr,
ffn1_act_scale_data,
xftblock_tx1,
input_shape);
xffn2_out = xftblock::Tensor(
ffn2_out.mutable_data<TX2>(), xftblock_tx2, input_shape);
}
#define FFN_IMPL(TX1, TX2, TW, TGEMM) \
MoeExpertFFNImpl<TX1, TX2, TW, TGEMM>(&xffn1_in, \
&xtoken_num_info, \
&xffn1_w, \
&xffn2_w, \
xffn1_bias.get(), \
xffn2_bias.get(), \
&xffn2_out, \
ffn2_act_scale_data, \
shift_data, \
smooth_data, \
hadamard_blocksize)
if (quant_method == "weight_only_int8") {
FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, float);
} else if (quant_method == "weight_only_int4") {
FFN_IMPL(XPU_TX1, XPU_TX2, int4_t, int4_wo_int15);
} else if (quant_method == "w4a8") {
if (FLAGS_MOE_FFN_USE_DENSE_INPUT && is_padding_input) {
FFN_IMPL(int8_t, XPU_TX2, int4_t, int4_wo_int8);
} else if (FLAGS_BKCL_DISPATCH_ALL_GATHER && !is_padding_input) {
FFN_IMPL(int8_t, XPU_TX2, int4_t, int4_wo_int8);
} else {
FFN_IMPL(XPU_TX1, XPU_TX2, int4_t, int4_wo_int8);
}
} else {
FFN_IMPL(XPU_TX1, XPU_TX2, XPU_TW, float);
}
#undef FFN_IMPL
if (FLAGS_MOE_FFN_USE_DENSE_INPUT && is_padding_input) {
ret = infer_ops::sequence_pad<XPU_TX2, int>(
xpu_ctx->x_context(),
const_cast<XPU_TX2*>(xffn2_out.data<XPU_TX2>()),
reinterpret_cast<XPU_TX2*>(ffn2_out.data<TX2>()),
xtoken_num_info.data<int>(),
input_shape[0],
input_shape[1],
input_shape[2],
false,
0);
PD_CHECK(ret == 0);
}
return {ffn2_out};
}
std::vector<paddle::Tensor> MoeExpertFFN(
const paddle::Tensor& ffn_in,
const paddle::Tensor& token_num_info,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::optional<paddle::Tensor>& ffn1_act_scale,
const paddle::optional<paddle::Tensor>& ffn2_act_scale,
const paddle::optional<paddle::Tensor>& ffn1_weight_scale,
const paddle::optional<paddle::Tensor>& ffn2_weight_scale,
const paddle::optional<paddle::Tensor>& ffn2_shift,
const paddle::optional<paddle::Tensor>& ffn2_smooth,
const std::string& quant_method,
const int hadamard_blocksize,
const int valid_token_num) {
const auto x_type = ffn_in.dtype();
const auto w_type = ffn1_weight.dtype();
#define APPLY_FFN_KERNEL(TX1, TX2, TW) \
return MoeExpertFFNKernel<TX1, TX2, TW>(ffn_in, \
token_num_info, \
ffn1_weight, \
ffn2_weight, \
ffn1_bias, \
ffn2_bias, \
ffn1_act_scale, \
ffn2_act_scale, \
ffn1_weight_scale, \
ffn2_weight_scale, \
ffn2_shift, \
ffn2_smooth, \
quant_method, \
hadamard_blocksize, \
valid_token_num);
if (x_type == paddle::DataType::BFLOAT16 &&
w_type == paddle::DataType::BFLOAT16) {
APPLY_FFN_KERNEL(paddle::bfloat16, paddle::bfloat16, paddle::bfloat16);
} else if (x_type == paddle::DataType::BFLOAT16 &&
w_type == paddle::DataType::INT8) {
APPLY_FFN_KERNEL(paddle::bfloat16, paddle::bfloat16, int8_t);
} else if (x_type == paddle::DataType::INT8 &&
w_type == paddle::DataType::INT8) {
APPLY_FFN_KERNEL(int8_t, paddle::bfloat16, int8_t);
} else {
PD_THROW("MoeExpertFFN not support x_type=",
static_cast<int>(x_type),
", w_type=",
static_cast<int>(w_type));
return {};
}
#undef APPLY_FFN_KERNEL
}
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
const std::vector<int64_t>& permute_input_shape,
const std::vector<int64_t>& token_num_info_shape,
const std::vector<int64_t>& ffn1_weight_shape,
const std::vector<int64_t>& ffn2_weight_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_act_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_act_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_weight_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_weight_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_shift_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_smooth_shape) {
return {permute_input_shape};
}
std::vector<paddle::DataType> MoeExpertFFNInferDtype(
const paddle::DataType& permute_input_dtype,
const paddle::DataType& token_num_info_dtype,
const paddle::DataType& ffn1_weight_dtype,
const paddle::DataType& ffn2_weight_dtype,
const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
const paddle::optional<paddle::DataType>& ffn2_bias_dtype,
const paddle::optional<paddle::DataType>& ffn1_act_scale_dtype,
const paddle::optional<paddle::DataType>& ffn2_act_scale_dtype,
const paddle::optional<paddle::DataType>& ffn1_weight_scale_dtype,
const paddle::optional<paddle::DataType>& ffn2_weight_scale_dtype,
const paddle::optional<paddle::DataType>& ffn2_shift_dtype,
const paddle::optional<paddle::DataType>& ffn2_smooth_dtype) {
if (permute_input_dtype == paddle::DataType::INT8) {
return {paddle::DataType::BFLOAT16};
} else {
return {permute_input_dtype};
}
}
PD_BUILD_STATIC_OP(moe_expert_ffn)
.Inputs({"ffn_in",
"token_num_info",
"ffn1_weight",
"ffn2_weight",
paddle::Optional("ffn1_bias"),
paddle::Optional("ffn2_bias"),
paddle::Optional("ffn1_act_scale"),
paddle::Optional("ffn2_act_scale"),
paddle::Optional("ffn1_weight_scale"),
paddle::Optional("ffn2_weight_scale"),
paddle::Optional("ffn2_shift"),
paddle::Optional("ffn2_smooth")})
.Outputs({"ffn_out"})
.Attrs({"quant_method:std::string",
"hadamard_blocksize:int",
"valid_token_num:int"})
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));

View File

@@ -0,0 +1,134 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <infer_ops.h>
#include <xft_api.h>
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "utility/debug.h"
std::vector<paddle::Tensor> MoERedundantTopKSelect(
const paddle::Tensor& gating_logits,
const paddle::Tensor& expert_id_to_ep_rank_array,
const paddle::Tensor& expert_in_rank_num_list,
paddle::Tensor& tokens_per_expert_stats_list, // NOLINT
const paddle::optional<paddle::Tensor>& bias,
const int moe_topk,
const bool apply_norm_weight,
const bool enable_softmax_top_k_fused,
const int redundant_ep_rank_num_plus_one) {
namespace api = baidu::xpu::api;
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
api::Context* ctx = xpu_ctx->x_context();
if (gating_logits.is_cpu()) {
ctx = new api::Context(api::kCPU);
}
PD_CHECK(apply_norm_weight, "only support apply_norm_weight==true");
PD_CHECK(enable_softmax_top_k_fused,
"only support enable_softmax_top_k_fused==true");
PD_CHECK(bias.get_ptr() != nullptr, "only support bias != nullptr");
auto gating_logits_dims = gating_logits.shape();
int expert_num = gating_logits_dims[gating_logits_dims.size() - 1];
int64_t token_num = 0;
if (gating_logits_dims.size() == 3) {
token_num = gating_logits_dims[0] * gating_logits_dims[1];
} else {
token_num = gating_logits_dims[0];
}
auto topk_ids = paddle::empty(
{token_num, moe_topk}, paddle::DataType::INT32, gating_logits.place());
auto topk_ids_tmp = paddle::empty(
{token_num, moe_topk}, paddle::DataType::INT32, gating_logits.place());
auto source_rows_tmp = paddle::empty(
{token_num, moe_topk}, paddle::DataType::INT32, gating_logits.place());
auto topk_weights = paddle::empty(
{token_num, moe_topk}, paddle::DataType::FLOAT32, gating_logits.place());
const float* bias_data =
bias.get_ptr() != nullptr ? bias.get_ptr()->data<float>() : nullptr;
int ret = infer_ops::moe_redundant_softmax_topk_normed<float, float, int>(
ctx,
gating_logits.data<float>(),
bias_data,
expert_id_to_ep_rank_array.data<int>(),
expert_in_rank_num_list.data<int>(),
tokens_per_expert_stats_list.data<int>(),
topk_weights.data<float>(),
topk_ids.data<int>(),
topk_ids_tmp.data<int>(),
source_rows_tmp.data<int>(),
expert_num,
moe_topk,
token_num,
redundant_ep_rank_num_plus_one);
PD_CHECK(ret == 0);
return {topk_ids, topk_weights};
}
std::vector<std::vector<int64_t>> MoERedundantTopKSelectInferShape(
const std::vector<int64_t>& gating_logits_shape,
const std::vector<int64_t>& expert_id_to_ep_rank_array_shape,
const std::vector<int64_t>& expert_in_rank_num_list_shape,
const std::vector<int64_t>& tokens_per_expert_stats_list_shape,
const paddle::optional<std::vector<int64_t>>& bias_shape,
const int moe_topk,
const bool apply_norm_weight,
const bool enable_softmax_top_k_fused,
const int redundant_ep_rank_num_plus_one) {
int64_t token_rows = -1;
if (gating_logits_shape.size() == 3) {
token_rows = gating_logits_shape[0] * gating_logits_shape[1];
} else {
token_rows = gating_logits_shape[0];
}
std::vector<int64_t> topk_ids_shape = {token_rows, moe_topk};
std::vector<int64_t> topk_weights_shape = {token_rows, moe_topk};
return {topk_ids_shape, topk_weights_shape};
}
std::vector<paddle::DataType> MoERedundantTopKSelectInferDtype(
const paddle::DataType& gating_logits_dtype,
const paddle::DataType& expert_id_to_ep_rank_array_dtype,
const paddle::DataType& expert_in_rank_num_list_dtype,
const paddle::DataType& tokens_per_expert_stats_list_dtype,
const paddle::optional<paddle::DataType>& bias_type,
const int moe_topk,
const bool apply_norm_weight,
const bool enable_softmax_top_k_fused,
const int redundant_ep_rank_num_plus_one) {
return {paddle::DataType::INT32, paddle::DataType::FLOAT32};
}
PD_BUILD_OP(moe_redundant_topk_select)
.Inputs({"gating_logits",
"expert_id_to_ep_rank_array",
"expert_in_rank_num_list",
"tokens_per_expert_stats_list",
paddle::Optional("bias")})
.Outputs({"topk_ids", "topk_weights", "tokens_per_expert_stats_list_out"})
.Attrs({"moe_topk: int",
"apply_norm_weight: bool",
"enable_softmax_top_k_fused:bool",
"redundant_ep_rank_num_plus_one:int"})
.SetInplaceMap({{"tokens_per_expert_stats_list",
"tokens_per_expert_stats_list_out"}})
.SetKernelFn(PD_KERNEL(MoERedundantTopKSelect))
.SetInferShapeFn(PD_INFER_SHAPE(MoERedundantTopKSelectInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoERedundantTopKSelectInferDtype));

View File

@@ -0,0 +1,84 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <infer_ops.h>
#include <xft_api.h>
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "utility/debug.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
std::vector<paddle::Tensor> MoeTopkSelect(
const paddle::Tensor& gating_logits,
const paddle::optional<paddle::Tensor>& bias,
const int moe_topk,
const bool apply_norm_weight) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
PD_CHECK(apply_norm_weight, "only support apply_norm_weight==true");
auto gating_logits_dims = gating_logits.shape();
int token_num = gating_logits_dims[0];
int expert_num = gating_logits_dims[1];
auto topk_ids = paddle::empty(
{token_num, moe_topk}, paddle::DataType::INT32, gating_logits.place());
auto topk_weights = paddle::empty(
{token_num, moe_topk}, paddle::DataType::FLOAT32, gating_logits.place());
int32_t* block_statistic = nullptr;
const float* bias_data =
bias.get_ptr() != nullptr ? bias.get_ptr()->data<float>() : nullptr;
int ret = infer_ops::moe_softmax_topk_norm_fusion(
xpu_ctx->x_context(),
gating_logits.data<float>(),
topk_weights.mutable_data<float>(),
topk_ids.mutable_data<int>(),
block_statistic,
token_num,
expert_num,
moe_topk,
0,
bias_data);
PD_CHECK(ret == 0);
return {topk_ids, topk_weights};
}
std::vector<std::vector<int64_t>> MoeTopkSelectInferShape(
const std::vector<int64_t>& gating_logits_shape,
const std::vector<int64_t>& bias_shape,
const int moe_topk,
const bool apply_norm_weight) {
std::vector<int64_t> topk_ids_shape = {gating_logits_shape[0], moe_topk};
std::vector<int64_t> topk_weights_shape = {gating_logits_shape[0], moe_topk};
return {topk_ids_shape, topk_weights_shape};
}
std::vector<paddle::DataType> MoeTopkSelectInferDtype(
const paddle::DataType& gating_logits_dtype,
const paddle::DataType& bias_dtype) {
return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
}
PD_BUILD_STATIC_OP(moe_topk_select)
.Inputs({"gating_logits", paddle::Optional("bias")})
.Outputs({"topk_ids", "topk_weights"})
.Attrs({"moe_topk: int", "apply_norm_weight: bool"})
.SetKernelFn(PD_KERNEL(MoeTopkSelect))
.SetInferShapeFn(PD_INFER_SHAPE(MoeTopkSelectInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeTopkSelectInferDtype));
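A rough CPU reference (an illustrative assumption, not the actual kernel) of what moe_topk_select is presumed to compute per token when apply_norm_weight is true; bias handling and tie-breaking are omitted, and the real work is done by infer_ops::moe_softmax_topk_norm_fusion on the XPU:

#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

// For one token: softmax over expert logits, keep the top-k experts, and
// renormalize the kept probabilities so they sum to 1.
void moe_topk_select_reference(const std::vector<float>& logits,  // [expert_num]
                               int moe_topk,
                               std::vector<int>* topk_ids,
                               std::vector<float>* topk_weights) {
  const int expert_num = static_cast<int>(logits.size());
  // Numerically stabilized softmax.
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  std::vector<float> probs(expert_num);
  float denom = 0.f;
  for (int e = 0; e < expert_num; ++e) {
    probs[e] = std::exp(logits[e] - max_logit);
    denom += probs[e];
  }
  for (float& p : probs) p /= denom;
  // Indices of the top-k probabilities.
  std::vector<int> idx(expert_num);
  std::iota(idx.begin(), idx.end(), 0);
  std::partial_sort(idx.begin(), idx.begin() + moe_topk, idx.end(),
                    [&](int a, int b) { return probs[a] > probs[b]; });
  topk_ids->assign(idx.begin(), idx.begin() + moe_topk);
  // Renormalize the selected weights (apply_norm_weight == true).
  float kept = 0.f;
  for (int id : *topk_ids) kept += probs[id];
  topk_weights->clear();
  for (int id : *topk_ids) topk_weights->push_back(probs[id] / kept);
}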

View File

@@ -0,0 +1,39 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/mman.h>
#include <sys/msg.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include "paddle/extension.h"
#define MAX_BSZ 512
struct msgdata {
long mtype; // NOLINT
int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens
};
struct msgdatakv {
long mtype; // NOLINT
int mtext[MAX_BSZ * 3 + 2]; // encoder_count, layer_id, bid- pair
};

View File

@@ -17,6 +17,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
@@ -37,7 +41,7 @@ void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "");
}
PD_BUILD_OP(draft_model_postprocess)
PD_BUILD_STATIC_OP(draft_model_postprocess)
.Inputs({"base_model_draft_tokens",
"base_model_seq_lens_this_time",
"base_model_seq_lens_encoder",

View File

@@ -17,6 +17,10 @@
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
namespace api = baidu::xpu::api;
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& input_ids,
@@ -90,7 +94,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_OP(draft_model_preprocess)
PD_BUILD_STATIC_OP(draft_model_preprocess)
.Inputs({"draft_tokens",
"input_ids",
"stop_flags",

View File

@@ -17,6 +17,10 @@
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& pre_ids,
@@ -86,7 +90,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
PD_CHECK(r == 0, "draft_model_update failed.");
}
PD_BUILD_OP(draft_model_update)
PD_BUILD_STATIC_OP(draft_model_update)
.Inputs({"inter_next_tokens",
"draft_tokens",
"pre_ids",

View File

@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
namespace api = baidu::xpu::api;
std::vector<paddle::Tensor> EagleGetHiddenStates(
const paddle::Tensor& input,
@@ -102,7 +106,7 @@ std::vector<paddle::Tensor> EagleGetHiddenStates(
}
}
PD_BUILD_OP(eagle_get_hidden_states)
PD_BUILD_STATIC_OP(eagle_get_hidden_states)
.Inputs({"input",
"seq_lens_this_time",
"seq_lens_encoder",

View File

@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
namespace api = baidu::xpu::api;
std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
const paddle::Tensor& input,
@@ -97,7 +101,7 @@ std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
}
}
PD_BUILD_OP(eagle_get_self_hidden_states)
PD_BUILD_STATIC_OP(eagle_get_self_hidden_states)
.Inputs(
{"input", "last_seq_lens_this_time", "seq_lens_this_time", "step_idx"})
.Outputs({"out"})

View File

@@ -17,6 +17,10 @@
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
namespace api = baidu::xpu::api;
void MTPStepPaddle(
const paddle::Tensor &base_model_stop_flags,
@@ -64,7 +68,7 @@ void MTPStepPaddle(
}
}
PD_BUILD_OP(mtp_step_paddle)
PD_BUILD_STATIC_OP(mtp_step_paddle)
.Inputs({"base_model_stop_flags",
"stop_flags",
"batch_drop",

View File

@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder) {
// printf("enter clear \n");
@@ -31,7 +35,7 @@ void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed.");
}
PD_BUILD_OP(speculate_clear_accept_nums)
PD_BUILD_STATIC_OP(speculate_clear_accept_nums)
.Inputs({"accept_num", "seq_lens_decoder"})
.Outputs({"seq_lens_decoder_out"})
.SetInplaceMap({{"seq_lens_decoder", "seq_lens_decoder_out"}})

View File

@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
const paddle::Tensor& output_cum_offsets_tmp,
const paddle::Tensor& out_token_num,
@@ -69,7 +73,7 @@ std::vector<paddle::DataType> SpeculateGetOutputPaddingOffsetInferDtype(
return {output_cum_offsets_tmp_dtype, output_cum_offsets_tmp_dtype};
}
PD_BUILD_OP(speculate_get_output_padding_offset)
PD_BUILD_STATIC_OP(speculate_get_output_padding_offset)
.Inputs({"output_cum_offsets_tmp", "out_token_num", "seq_lens_output"})
.Outputs({"output_padding_offset", "output_cum_offsets"})
.Attrs({"max_seq_len: int"})

View File

@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
const paddle::Tensor& input_ids,
const paddle::Tensor& draft_tokens,
@@ -110,7 +114,7 @@ std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
seq_len_dtype};
}
PD_BUILD_OP(speculate_get_padding_offset)
PD_BUILD_STATIC_OP(speculate_get_padding_offset)
.Inputs({"input_ids",
"draft_tokens",
"cum_offsets",

View File

@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
@@ -61,7 +65,7 @@ std::vector<paddle::DataType> SpeculateGetSeqLensOutputInferDtype(
return {seq_lens_this_time_dtype};
}
PD_BUILD_OP(speculate_get_seq_lens_output)
PD_BUILD_STATIC_OP(speculate_get_seq_lens_output)
.Inputs({"seq_lens_this_time", "seq_lens_encoder", "seq_lens_decoder"})
.Outputs({"seq_lens_output"})
.SetKernelFn(PD_KERNEL(SpeculateGetSeqLensOutput))

View File

@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
@@ -53,7 +57,7 @@ void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed.");
}
PD_BUILD_OP(speculate_set_value_by_flags_and_idx)
PD_BUILD_STATIC_OP(speculate_set_value_by_flags_and_idx)
.Inputs({"pre_ids_all",
"accept_tokens",
"accept_num",

View File

@@ -17,6 +17,10 @@
#include "speculate_msg.h" // NOLINT
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
// To keep the interface call sites unchanged, the input parameters are left as-is for now.
void SpeculateStepSchedule(
const paddle::Tensor &stop_flags,
@@ -150,7 +154,7 @@ void SpeculateStepSchedule(
}
}
PD_BUILD_OP(speculate_step_reschedule)
PD_BUILD_STATIC_OP(speculate_step_reschedule)
.Inputs({"stop_flags",
"seq_lens_this_time",
"ori_seq_lens_encoder",

View File

@@ -17,20 +17,25 @@
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
void TokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
const paddle::Tensor& logits,
const paddle::Tensor& penalty_scores,
const paddle::Tensor& frequency_scores,
const paddle::Tensor& presence_scores,
const paddle::Tensor& temperatures,
const paddle::Tensor& bad_tokens,
const paddle::Tensor& cur_len,
const paddle::Tensor& min_len,
const paddle::Tensor& eos_token_id,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& output_cum_offsets,
const int max_seq_len) {
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
void SpeculateTokenPenaltyMultiScores(
const paddle::Tensor& pre_ids,
const paddle::Tensor& logits,
const paddle::Tensor& penalty_scores,
const paddle::Tensor& frequency_scores,
const paddle::Tensor& presence_scores,
const paddle::Tensor& temperatures,
const paddle::Tensor& bad_tokens,
const paddle::Tensor& cur_len,
const paddle::Tensor& min_len,
const paddle::Tensor& eos_token_id,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& output_cum_offsets,
const int max_seq_len) {
namespace api = baidu::xpu::api;
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
@@ -137,7 +142,7 @@ void TokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
}
}
PD_BUILD_OP(speculate_get_token_penalty_multi_scores)
PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
.Inputs({"pre_ids",
"logits",
"penalty_scores",
@@ -154,4 +159,4 @@ PD_BUILD_OP(speculate_get_token_penalty_multi_scores)
.Outputs({"logits_out"})
.Attrs({"max_seq_len: int"})
.SetInplaceMap({{"logits", "logits_out"}})
.SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores));
.SetKernelFn(PD_KERNEL(SpeculateTokenPenaltyMultiScores));

View File

@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
namespace api = baidu::xpu::api;
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
@@ -66,7 +70,7 @@ void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_OP(speculate_update_v3)
PD_BUILD_STATIC_OP(speculate_update_v3)
.Inputs({"seq_lens_encoder",
"seq_lens_decoder",
"not_need_stop",

View File

@@ -17,10 +17,13 @@
#include "paddle/common/flags.h"
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "ops/utility/debug.h"
#include "xpu/internal/infra_op.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
namespace api = baidu::xpu::api;
void SpeculateVerify(const paddle::Tensor &accept_tokens,
@@ -221,7 +224,7 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
}
}
PD_BUILD_OP(speculate_verify)
PD_BUILD_STATIC_OP(speculate_verify)
.Inputs({"accept_tokens",
"accept_num",
"step_idx",

View File

@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#define FIXED_TOPK_BASE(topk, ...) \
case (topk): { \
constexpr auto kTopK = topk; \
@@ -149,7 +153,7 @@ std::vector<paddle::DataType> TopPCandidatesInferDtype(
return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32};
}
PD_BUILD_OP(top_p_candidates)
PD_BUILD_STATIC_OP(top_p_candidates)
.Inputs({"probs", "top_p", "output_padding_offset"})
.Outputs({"verify_scores", "verify_tokens", "actual_candidate_lens"})
.Attrs({"candidates_len: int", "max_seq_len: int"})

View File

@@ -0,0 +1,91 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ops/pybind/pybind.h"
#include "ops/remote_cache_kv_ipc.h"
#include "ops/utility/env.h"
#include "paddle/extension.h"
XPU_DECLARE_BOOL(fmt_write_cache_completed_signal, false);
using cache_write_complete_signal_type =
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data;
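// OpenShmAndGetMetaSignalFunc packs the layer-wise cache-kv completion-signal
// metadata into a 3-element INT64 CPU tensor: {layer_id, shm_ptr, shm_fd}.
// The shared memory is only opened when FLAGS_fmt_write_cache_completed_signal
// is set to "true"/"1" in the environment; otherwise the default (unopened)
// metadata is returned.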
paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank,
const bool keep_pd_step_flag) {
cache_write_complete_signal_type kv_signal_metadata;
const char *fmt_write_cache_completed_signal_str =
std::getenv("FLAGS_fmt_write_cache_completed_signal");
if (fmt_write_cache_completed_signal_str &&
(std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
kv_signal_metadata =
RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
rank, keep_pd_step_flag);
}
auto kv_signal_metadata_out =
paddle::full({3}, -1, paddle::DataType::INT64, paddle::CPUPlace());
kv_signal_metadata_out.data<int64_t>()[0] =
static_cast<int64_t>(kv_signal_metadata.layer_id);
kv_signal_metadata_out.data<int64_t>()[1] =
reinterpret_cast<int64_t>(kv_signal_metadata.shm_ptr);
kv_signal_metadata_out.data<int64_t>()[2] =
static_cast<int64_t>(kv_signal_metadata.shm_fd);
return kv_signal_metadata_out;
}
void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
const paddle::Tensor &seq_lens_this_time_tensor,
const paddle::Tensor &seq_lens_decoder_tensor,
const int rank,
const int num_layers) {
if (FLAGS_fmt_write_cache_completed_signal) {
int real_bsz = seq_lens_this_time_tensor.dims()[0];
// the sequence-length tensors live on device; copy them to CPU before reading
auto seq_lens_encoder_cpu =
seq_lens_encoder_tensor.copy_to(paddle::CPUPlace(), false);
auto seq_lens_decoder_cpu =
seq_lens_decoder_tensor.copy_to(paddle::CPUPlace(), false);
RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query.init(
seq_lens_encoder_cpu.data<int>(),
seq_lens_decoder_cpu.data<int>(),
rank,
num_layers,
real_bsz);
}
}
std::vector<paddle::Tensor> OpenShmAndGetMetaSignal(
const int rank, const bool keep_pd_step_flag) {
return {OpenShmAndGetMetaSignalFunc(rank, keep_pd_step_flag)};
}
std::vector<std::vector<int64_t>> OpenShmAndGetMetaSignalShape(
const int rank, const bool keep_pd_step_flag) {
return {{3}};
}
std::vector<paddle::DataType> OpenShmAndGetMetaSignalDtype(
const int rank, const bool keep_pd_step_flag) {
return {paddle::DataType::INT64};
}
PD_BUILD_OP(open_shm_and_get_meta_signal)
.Inputs({})
.Outputs({"kv_signal_metadata"})
.Attrs({"rank: int", "keep_pd_step_flag: bool"})
.SetKernelFn(PD_KERNEL(OpenShmAndGetMetaSignal))
.SetInferShapeFn(PD_INFER_SHAPE(OpenShmAndGetMetaSignalShape))
.SetInferDtypeFn(PD_INFER_DTYPE(OpenShmAndGetMetaSignalDtype));

View File

@@ -0,0 +1,46 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sys/mman.h> // NOLINT
#include "cuda_runtime_api.h" // NOLINT
#include "paddle/extension.h"
#include "xpu/runtime.h"
#include "ops/pybind/pybind.h"
void check_xpu_error(int error) {
if (error != XPU_SUCCESS) {
throw XPUError(error);
}
}
// Python-facing pinned-host-memory allocator (xpu_host_alloc is currently replaced by malloc + mlock).
uintptr_t custom_xpu_host_alloc(size_t size, unsigned int flags) {
void* ptr = nullptr;
// check_xpu_error(xpu_host_alloc(&ptr, size, flags));
ptr = malloc(size);
PD_CHECK(ptr != nullptr);
PD_CHECK(mlock(ptr, size) == 0);
return reinterpret_cast<uintptr_t>(ptr);
}
// Python-facing wrapper around xpu_host_free.
void custom_xpu_host_free(uintptr_t ptr) {
check_xpu_error(xpu_host_free(reinterpret_cast<void*>(ptr)));
}
// Python-facing wrapper around cudaHostRegister: registers pageable memory as pinned (page-locked).
void xpu_cuda_host_register(uintptr_t ptr, size_t size, unsigned int flags) {
cudaError_t e = cudaHostRegister(reinterpret_cast<void*>(ptr), size, flags);
PD_CHECK(e == cudaSuccess, cudaGetErrorString(e));
}

View File

@@ -0,0 +1,111 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ops/pybind/cachekv_signal_thread_worker.h"
#include <cuda_runtime_api.h>
#include "ops/remote_cache_kv_ipc.h"
#include "ops/utility/env.h"
XPU_DECLARE_BOOL(fmt_write_cache_completed_signal, false);
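// CacheKvSignalThreadWorker runs a background thread bound to the device that
// was current at construction time. It drains signal_task_queue: each task
// waits on an XPU event, synchronizes the dedicated write_cache_kv_stream and
// then publishes the "cache kv written" signal via RemoteCacheKvIpc. The queue
// is polled with a 1-microsecond sleep so pending signals are flushed quickly.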
CacheKvSignalThreadWorker::CacheKvSignalThreadWorker() : stop(false) {
xpu_stream_create(&write_cache_kv_stream);
int devid;
auto ret = xpu_current_device(&devid);
PD_CHECK(ret == 0, "xpu_current_device failed.");
auto func = [this, devid]() {
int old_dev;
xpu_current_device(&old_dev);
auto ret = xpu_set_device(devid);
PD_CHECK(ret == 0, "xpu_set_device failed.");
ret = cudaSetDevice(devid);
PD_CHECK(ret == 0, "cudaSetDevice failed.");
while (true) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(write_mutex);
if (stop) return;
if (!signal_task_queue.empty()) {
task = std::move(signal_task_queue.front());
signal_task_queue.pop();
} else {
lock.unlock();
std::this_thread::sleep_for(std::chrono::microseconds(1));
continue;
}
}
task();  // run the dequeued task
}
};
worker_thread = std::thread(func);
}
void CacheKvSignalThreadWorker::push_signal_task(XPUEvent e1, void* meta_data) {
auto func = [this, e1, meta_data]() {
xpu_stream_wait_event(write_cache_kv_stream, e1);
xpu_wait(write_cache_kv_stream);
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise(meta_data);
xpu_event_destroy(e1);
};
std::lock_guard<std::mutex> lock(write_mutex);
signal_task_queue.push(func);
}
void CacheKvSignalThreadWorker::push_signal_task_per_query(XPUEvent e1,
void* meta_data) {
auto func = [this, e1, meta_data]() {
xpu_stream_wait_event(write_cache_kv_stream, e1);
xpu_wait(write_cache_kv_stream);
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_per_query(
meta_data);
xpu_event_destroy(e1);
};
std::lock_guard<std::mutex> lock(write_mutex);
signal_task_queue.push(func);
}
void CacheKvSignalThreadWorker::sync_all_signals() {
{
std::unique_lock<std::mutex> lock(write_mutex);
while (!signal_task_queue.empty()) {
// sleep for 1 microsecond between checks
lock.unlock();
std::this_thread::sleep_for(std::chrono::microseconds(1));
lock.lock();
}
stop = true;
}
worker_thread.join();
xpu_stream_destroy(write_cache_kv_stream);
}
paddle::Tensor create_cachekv_signal_thread() {
CacheKvSignalThreadWorker* worker = nullptr;
if (FLAGS_fmt_write_cache_completed_signal) {
worker = new CacheKvSignalThreadWorker();
}
auto t = paddle::full({1}, 0, paddle::DataType::INT64, paddle::CPUPlace());
t.data<int64_t>()[0] = reinterpret_cast<int64_t>(worker);
return t;
}
void destroy_cachekv_signal_thread(const paddle::Tensor& t) {
auto worker =
reinterpret_cast<CacheKvSignalThreadWorker*>(t.data<int64_t>()[0]);
if (FLAGS_fmt_write_cache_completed_signal) {
PD_CHECK(worker != nullptr, "cachekv_signal_thread should not be nullptr");
worker->sync_all_signals();
delete worker;
} else {
PD_CHECK(worker == nullptr,
"cachekv_signal_thread should be nullptr if not pd split");
}
}

View File

@@ -0,0 +1,35 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include "paddle/extension.h"
#include "xpu/runtime.h"
struct CacheKvSignalThreadWorker {
CacheKvSignalThreadWorker();
void push_signal_task(XPUEvent e1, void* meta_data);
void push_signal_task_per_query(XPUEvent e1, void* meta_data);
void sync_all_signals();
std::thread worker_thread;
std::queue<std::function<void()>> signal_task_queue;
std::mutex write_mutex;
XPUStream write_cache_kv_stream;
bool stop;
};
paddle::Tensor create_cachekv_signal_thread();
void destroy_cachekv_signal_thread(const paddle::Tensor& t);

View File

@@ -0,0 +1,26 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cuda_runtime_api.h" // NOLINT
#include "paddle/extension.h"
#include "xpu/runtime.h"
uintptr_t xpu_get_peer_mem_addr(uintptr_t ptr) {
struct cudaPointerAttributes pointerAttr;
cudaPointerGetAttributes(&pointerAttr, reinterpret_cast<void*>(ptr));
PD_CHECK(pointerAttr.hostPointer != nullptr,
"Failed to get host pointer from device pointer");
uintptr_t ptr_out = reinterpret_cast<uintptr_t>(pointerAttr.hostPointer);
return ptr_out;
}

View File

@@ -0,0 +1,26 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "xpu/runtime.h"
void prof_start() {
int ret = xpu_profiler_start();
PD_CHECK(ret == 0, "xpu_profiler_start error");
}
void prof_stop() {
int ret = xpu_profiler_stop();
PD_CHECK(ret == 0, "xpu_profiler_stop error");
}

View File

@@ -0,0 +1,704 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ops/pybind/pybind.h"
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "cuda_runtime_api.h" // NOLINT
#include "paddle/extension.h"
namespace py = pybind11;
uintptr_t custom_xpu_host_alloc(size_t size, unsigned int flags);
void custom_xpu_host_free(uintptr_t ptr);
uintptr_t xpu_get_peer_mem_addr(uintptr_t ptr);
void xpu_cuda_host_register(uintptr_t ptr,
size_t size,
unsigned int flags = cudaHostRegisterDefault);
void prof_start();
void prof_stop();
void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
const paddle::Tensor &seq_lens_this_time_tensor,
const paddle::Tensor &seq_lens_decoder_tensor,
const int rank,
const int num_layers);
void GetOutputKVSignal(const paddle::Tensor &x,
int64_t rank_id,
bool wait_flag);
std::vector<paddle::Tensor> MoERedundantTopKSelect(
const paddle::Tensor& gating_logits,
const paddle::Tensor& expert_id_to_ep_rank_array,
const paddle::Tensor& expert_in_rank_num_list,
paddle::Tensor& tokens_per_expert_stats_list, // NOLINT
const paddle::optional<paddle::Tensor>& bias,
const int moe_topk,
const bool apply_norm_weight,
const bool enable_softmax_top_k_fused,
const int redundant_ep_rank_num_plus_one);
void set_ncluster(int num) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
xpu_ctx->x_context()->set_ncluster(num);
}
std::vector<paddle::Tensor> RmsNorm(
const paddle::Tensor& x,
const paddle::optional<paddle::Tensor>& bias,
const paddle::optional<paddle::Tensor>& residual,
const paddle::Tensor& norm_weight,
const paddle::optional<paddle::Tensor>& norm_bias,
const float epsilon,
const int begin_norm_axis,
const float quant_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound);
std::vector<paddle::Tensor> WeightOnlyLinear(
const paddle::Tensor& x,
const paddle::Tensor& weight,
const paddle::Tensor& weight_scale,
const paddle::optional<paddle::Tensor>& bias,
const std::string& weight_dtype,
const int arch,
const int group_size);
std::vector<paddle::Tensor> MoeEPCombine(const paddle::Tensor& ffn_out,
const paddle::Tensor& moe_index,
const paddle::Tensor& weights,
const int recv_token_num,
const int expand_token_num,
const int hidden_dim,
const int topk);
std::vector<paddle::Tensor> EPMoeExpertDispatch(
const paddle::Tensor& input,
const paddle::Tensor& topk_ids,
const paddle::Tensor& topk_weights,
const paddle::optional<paddle::Tensor>& input_scales,
const std::vector<int>& token_nums_per_expert,
const int token_nums_this_rank,
const std::string quant_method);
std::vector<paddle::Tensor> MoeExpertFFN(
const paddle::Tensor& ffn_in,
const paddle::Tensor& token_num_info,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::optional<paddle::Tensor>& ffn1_act_scale,
const paddle::optional<paddle::Tensor>& ffn2_act_scale,
const paddle::optional<paddle::Tensor>& ffn1_weight_scale,
const paddle::optional<paddle::Tensor>& ffn2_weight_scale,
const paddle::optional<paddle::Tensor>& ffn2_shift,
const paddle::optional<paddle::Tensor>& ffn2_smooth,
const std::string& quant_method,
const int hadamard_blocksize,
const int valid_token_num);
std::vector<paddle::Tensor> MoeTopkSelect(
const paddle::Tensor& gating_logits,
const paddle::optional<paddle::Tensor>& bias,
const int moe_topk,
const bool apply_norm_weight);
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& pre_ids,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& max_dec_len,
const paddle::Tensor& end_ids,
const paddle::Tensor& base_model_draft_tokens,
const int max_seq_len,
const int substep);
void SpeculateUpdateV3(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& actual_draft_token_nums,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& is_block_step,
const paddle::Tensor& stop_nums);
void SpeculateTokenPenaltyMultiScores(
const paddle::Tensor& pre_ids,
const paddle::Tensor& logits,
const paddle::Tensor& penalty_scores,
const paddle::Tensor& frequency_scores,
const paddle::Tensor& presence_scores,
const paddle::Tensor& temperatures,
const paddle::Tensor& bad_tokens,
const paddle::Tensor& cur_len,
const paddle::Tensor& min_len,
const paddle::Tensor& eos_token_id,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& output_cum_offsets,
const int max_seq_len);
std::vector<paddle::Tensor> TopPCandidates(
const paddle::Tensor& probs,
const paddle::Tensor& top_p,
const paddle::Tensor& output_padding_offset,
int candidates_len,
int max_seq_len);
void SpeculateVerify(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& step_idx,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& verify_tokens,
const paddle::Tensor& verify_scores,
const paddle::Tensor& max_dec_len,
const paddle::Tensor& end_tokens,
const paddle::Tensor& is_block_step,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& actual_candidate_len,
const paddle::Tensor& actual_draft_token_nums,
const paddle::Tensor& topp,
int max_seq_len,
int verify_window,
bool enable_topp);
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder);
void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor& pre_ids_all,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx);
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& input_ids,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& seq_lens_encoder_record,
const paddle::Tensor& seq_lens_decoder_record,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& batch_drop,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
const paddle::Tensor& base_model_stop_flags,
const paddle::Tensor& base_model_is_block_step,
const paddle::Tensor& base_model_draft_tokens,
const int max_draft_token,
const bool truncate_first_token,
const bool splitwise_prefill);
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_stop_flags);
std::vector<paddle::Tensor> EagleGetHiddenStates(
const paddle::Tensor& input,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const paddle::Tensor& accept_nums,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const int actual_draft_token_num);
std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
const paddle::Tensor& input,
const paddle::Tensor& last_seq_lens_this_time,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& step_idx);
std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
const paddle::Tensor& output_cum_offsets_tmp,
const paddle::Tensor& out_token_num,
const paddle::Tensor& seq_lens_output,
const int max_seq_len);
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
const paddle::Tensor& input_ids,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& token_num,
const paddle::Tensor& seq_len,
const paddle::Tensor& seq_lens_encoder);
void MTPStepPaddle(
const paddle::Tensor& base_model_stop_flags,
const paddle::Tensor& stop_flags,
const paddle::Tensor& batch_drop,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor& encoder_block_lens,
const paddle::Tensor& used_list_len,
const paddle::Tensor& free_list,
const paddle::Tensor& free_list_len,
const int block_size,
const int max_draft_tokens);
void SpeculateStepSchedule(
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& ori_seq_lens_encoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor& encoder_block_lens,
const paddle::Tensor& is_block_step,
const paddle::Tensor& step_block_list,
const paddle::Tensor& step_lens,
const paddle::Tensor& recover_block_list,
const paddle::Tensor& recover_lens,
const paddle::Tensor& need_block_list,
const paddle::Tensor& need_block_len,
const paddle::Tensor& used_list_len,
const paddle::Tensor& free_list,
const paddle::Tensor& free_list_len,
const paddle::Tensor& input_ids,
const paddle::Tensor& pre_ids,
const paddle::Tensor& step_idx,
const paddle::Tensor& next_tokens,
const paddle::Tensor& first_token_ids,
const paddle::Tensor& accept_num,
const int block_size,
const int encoder_decoder_block_num,
const int max_draft_tokens);
std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("cuda_host_alloc",
&custom_xpu_host_alloc,
"Allocate pinned memory",
py::arg("size"),
py::arg("flags") = 0x00);
m.def("cuda_host_free",
&custom_xpu_host_free,
"Free pinned memory",
py::arg("ptr"));
m.def("get_peer_mem_addr",
&xpu_get_peer_mem_addr,
"Get Host memory address of device pointer",
py::arg("ptr"));
m.def("cuda_host_register",
&xpu_cuda_host_register,
"Register pinned memory",
py::arg("ptr"),
py::arg("size"),
py::arg("flags") = cudaHostRegisterDefault);
m.def("create_kv_signal_sender",
&create_cachekv_signal_thread,
"init write cache kv signal thread");
m.def("destroy_kv_signal_sender",
&destroy_cachekv_signal_thread,
"write cache kv signal thread exit");
m.def("prof_start", &prof_start, "prof_start");
m.def("prof_stop", &prof_stop, "prof_stop");
m.def("moe_redundant_topk_select",
&MoERedundantTopKSelect,
py::arg("gating_logits"),
py::arg("expert_id_to_ep_rank_array"),
py::arg("expert_in_rank_num_list"),
py::arg("tokens_per_expert_stats_list"),
py::arg("bias"),
py::arg("moe_topk"),
py::arg("apply_norm_weight"),
py::arg("enable_softmax_top_k_fused"),
py::arg("redundant_ep_rank_num_plus_one"),
"moe export RedundantTopKSelect function");
m.def("set_ncluster", &set_ncluster, "set ncluster");
/**
* open_shm_and_get_meta_signal.cc
* InitKVSignalPerQuery
*/
m.def("init_kv_signal_per_query",
&InitKVSignalPerQuery,
py::arg("seq_lens_encoder_tensor"),
py::arg("seq_lens_this_time_tensor"),
py::arg("seq_lens_decoder_tensor"),
py::arg("rank"),
py::arg("num_layers"),
"init_kv_signal_per_query function");
/**
* GetOutputKVSignal
*/
m.def("get_output_kv_signal",
&GetOutputKVSignal,
py::arg("x"),
py::arg("rank_id"),
py::arg("wait_flag"),
"get_output_kv_signal function");
m.def("fused_rms_norm_xpu",
&RmsNorm,
"Fused RMS normalization for XPU",
py::arg("x"), // 输入张量
py::arg("bias"), // 偏置(可选)
py::arg("residual"), // 残差连接(可选)
py::arg("norm_weight"), // 归一化权重
py::arg("norm_bias"), // 归一化偏置(可选)
py::arg("epsilon"), // 数值稳定项
py::arg("begin_norm_axis"), // 归一化起始维度
py::arg("quant_scale"), // 量化缩放因子
py::arg("quant_round_type"), // 量化舍入类型
py::arg("quant_max_bound"), // 量化最大值边界
py::arg("quant_min_bound") // 量化最小值边界
);
m.def("weight_only_linear_xpu",
&WeightOnlyLinear,
"Weight-only quantized linear layer",
py::arg("x"),
py::arg("weight"),
py::arg("weight_scale"),
py::arg("bias"),
py::arg("weight_dtype"),
py::arg("arch"),
py::arg("group_size"));
m.def("ep_moe_expert_combine",
&MoeEPCombine,
"MoE (Mixture of Experts) EP combine operation",
py::arg("ffn_out"), // FFN输出张量 [token_num, hidden_dim]
py::arg("moe_index"), // MoE专家索引张量 [token_num, topk]
py::arg("weights"), // 专家权重张量 [token_num, topk]
py::arg("recv_token_num"), // 接收的token数量int
py::arg("expand_token_num"), // 扩展的token数量int
py::arg("hidden_dim"), // 隐藏层维度int
py::arg("topk") // 选择的专家数量int
);
m.def("ep_moe_expert_dispatch",
&EPMoeExpertDispatch,
"EP MoE expert dispatch operation",
py::arg("input"),
py::arg("topk_ids"),
py::arg("topk_weights"),
py::arg("input_scales") = py::none(),
py::arg("token_nums_per_expert"),
py::arg("token_nums_this_rank"),
py::arg("quant_method"));
m.def("moe_expert_ffn",
&MoeExpertFFN,
"MoE expert feed-forward network with quantization support",
py::arg("ffn_in"), // [valid_token_num, hidden_dim]
py::arg("token_num_info"),
py::arg("ffn1_weight"),
py::arg("ffn2_weight"),
py::arg("ffn1_bias") = py::none(),
py::arg("ffn2_bias") = py::none(),
py::arg("ffn1_act_scale") = py::none(),
py::arg("ffn2_act_scale") = py::none(),
py::arg("ffn1_weight_scale") = py::none(),
py::arg("ffn2_weight_scale") = py::none(),
py::arg("ffn2_shift") = py::none(),
py::arg("ffn2_smooth") = py::none(),
py::arg("quant_method"),
py::arg("hadamard_blocksize"),
py::arg("valid_token_num"));
m.def("moe_topk_select",
&MoeTopkSelect,
"MoE Top-k selection: selects top-k experts via gating logits",
py::arg("gating_logits"),
py::arg("bias") = py::none(),
py::arg("moe_topk"),
py::arg("apply_norm_weight"));
m.def("draft_model_update",
&DraftModelUpdate,
"Update draft model states during speculative decoding",
py::arg("inter_next_tokens"), // 中间next tokens张量
py::arg("draft_tokens"), // 草稿token张量
py::arg("pre_ids"), // 前置ID张量
py::arg("seq_lens_this_time"), // 当前步骤序列长度张量
py::arg("seq_lens_encoder"), // 编码器序列长度张量
py::arg("seq_lens_decoder"), // 解码器序列长度张量
py::arg("step_idx"), // 步骤索引张量
py::arg("output_cum_offsets"), // 输出累积偏移量张量
py::arg("stop_flags"), // 停止标志张量
py::arg("not_need_stop"), // 无需停止标志张量
py::arg("max_dec_len"), // 最大解码长度张量
py::arg("end_ids"), // 结束ID张量
py::arg("base_model_draft_tokens"), // 基础模型草稿token张量
py::arg("max_seq_len"), // 最大序列长度int
py::arg("substep") // 子步骤编号int
);
m.def("speculate_get_token_penalty_multi_scores",
&SpeculateTokenPenaltyMultiScores,
py::arg("pre_ids"),
py::arg("logits"),
py::arg("penalty_scores"),
py::arg("frequency_scores"),
py::arg("presence_scores"),
py::arg("temperatures"),
py::arg("bad_tokens"),
py::arg("cur_len"),
py::arg("min_len"),
py::arg("eos_token_id"),
py::arg("seq_lens_this_time"),
py::arg("output_padding_offset"),
py::arg("output_cum_offsets"),
py::arg("max_seq_len"),
"Applies token penalty with multiple scores");
m.def("speculate_update_v3",
&SpeculateUpdateV3,
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("not_need_stop"),
py::arg("draft_tokens"),
py::arg("actual_draft_token_nums"),
py::arg("accept_tokens"),
py::arg("accept_num"),
py::arg("stop_flags"),
py::arg("seq_lens_this_time"),
py::arg("is_block_step"),
py::arg("stop_nums"),
"Update speculative decoding states (V3)");
m.def("top_p_candidates",
&TopPCandidates,
py::arg("probs"),
py::arg("top_p"),
py::arg("output_padding_offset"),
py::arg("candidates_len"),
py::arg("max_seq_len"),
"Generate top-p candidates based on probability distributions");
m.def("speculate_verify",
&SpeculateVerify,
py::arg("accept_tokens"),
py::arg("accept_num"),
py::arg("step_idx"),
py::arg("stop_flags"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("draft_tokens"),
py::arg("seq_lens_this_time"),
py::arg("verify_tokens"),
py::arg("verify_scores"),
py::arg("max_dec_len"),
py::arg("end_tokens"),
py::arg("is_block_step"),
py::arg("output_cum_offsets"),
py::arg("actual_candidate_len"),
py::arg("actual_draft_token_nums"),
py::arg("topp"),
py::arg("max_seq_len"),
py::arg("verify_window"),
py::arg("enable_topp"),
"Perform speculative verification for decoding");
m.def("speculate_clear_accept_nums",
&SpeculateClearAcceptNums,
py::arg("accept_num"),
py::arg("seq_lens_decoder"),
"Clear accept numbers based on decoder sequence lengths");
m.def("speculate_set_value_by_flags_and_idx",
&SpeculateSetValueByFlagsAndIdx,
py::arg("pre_ids_all"),
py::arg("accept_tokens"),
py::arg("accept_num"),
py::arg("stop_flags"),
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("step_idx"),
"Set values based on flags and indices in speculative decoding");
m.def("draft_model_preprocess",
&DraftModelPreprocess,
py::arg("draft_tokens"),
py::arg("input_ids"),
py::arg("stop_flags"),
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("step_idx"),
py::arg("seq_lens_encoder_record"),
py::arg("seq_lens_decoder_record"),
py::arg("not_need_stop"),
py::arg("batch_drop"),
py::arg("accept_tokens"),
py::arg("accept_num"),
py::arg("base_model_seq_lens_encoder"),
py::arg("base_model_seq_lens_decoder"),
py::arg("base_model_step_idx"),
py::arg("base_model_stop_flags"),
py::arg("base_model_is_block_step"),
py::arg("base_model_draft_tokens"),
py::arg("max_draft_token"),
py::arg("truncate_first_token"),
py::arg("splitwise_prefill"),
"Preprocess data for draft model in speculative decoding");
m.def("draft_model_postprocess",
&DraftModelPostprocess,
py::arg("base_model_draft_tokens"),
py::arg("base_model_seq_lens_this_time"),
py::arg("base_model_seq_lens_encoder"),
py::arg("base_model_stop_flags"),
"Postprocess data for draft model in speculative decoding");
m.def("eagle_get_hidden_states",
&EagleGetHiddenStates,
py::arg("input"),
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("stop_flags"),
py::arg("accept_nums"),
py::arg("base_model_seq_lens_this_time"),
py::arg("base_model_seq_lens_encoder"),
py::arg("actual_draft_token_num"),
"Get draft model hidden states");
m.def("eagle_get_self_hidden_states",
&EagleGetSelfHiddenStates,
py::arg("input"),
py::arg("last_seq_lens_this_time"),
py::arg("seq_lens_this_time"),
py::arg("step_idx"),
"Rebuild draft model hidden states");
m.def("speculate_get_output_padding_offset",
&SpeculateGetOutputPaddingOffset,
py::arg("output_cum_offsets_tmp"),
py::arg("out_token_num"),
py::arg("seq_lens_output"),
py::arg("max_seq_len"),
"Get output padding offset");
m.def("speculate_get_padding_offset",
&SpeculateGetPaddingOffset,
py::arg("input_ids"),
py::arg("draft_tokens"),
py::arg("cum_offsets"),
py::arg("token_num"),
py::arg("seq_len"),
py::arg("seq_lens_encoder"),
"Get padding offset");
m.def("mtp_step_paddle",
&MTPStepPaddle,
py::arg("base_model_stop_flags"),
py::arg("stop_flags"),
py::arg("batch_drop"),
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("block_tables"), // [bsz, block_num_per_seq]
py::arg("encoder_block_lens"),
py::arg("used_list_len"),
py::arg("free_list"),
py::arg("free_list_len"),
py::arg("block_size"),
py::arg("max_draft_tokens"),
"MTP step paddle");
m.def("speculate_step_reschedule",
&SpeculateStepSchedule,
py::arg("stop_flags"),
py::arg("seq_lens_this_time"),
py::arg("ori_seq_lens_encoder"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("block_tables"),
py::arg("encoder_block_lens"),
py::arg("is_block_step"),
py::arg("step_block_list"),
py::arg("step_lens"),
py::arg("recover_block_list"),
py::arg("recover_lens"),
py::arg("need_block_list"),
py::arg("need_block_len"),
py::arg("used_list_len"),
py::arg("free_list"),
py::arg("free_list_len"),
py::arg("input_ids"),
py::arg("pre_ids"),
py::arg("step_idx"),
py::arg("next_tokens"),
py::arg("first_token_ids"),
py::arg("accept_num"),
py::arg("block_size"),
py::arg("encoder_decoder_block_num"),
py::arg("max_draft_tokens"),
"Step reschedule");
m.def("speculate_get_seq_lens_output",
&SpeculateGetSeqLensOutput,
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
"Get sequence lengths output");
// Register the exception class that carries XPU error messages.
py::register_exception<XPUError>(m, "XPUError");
}
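// Example (hypothetical) Python-side usage of the bindings above, assuming the
// compiled extension is importable as `fastdeploy_ops`:
//   from fastdeploy_ops import cuda_host_alloc, cuda_host_free
//   ptr = cuda_host_alloc(1 << 20)  # pinned 1 MiB host buffer, flags = 0
//   cuda_host_free(ptr)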

View File

@@ -0,0 +1,29 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda_runtime_api.h>
#include <xpu/runtime.h>
#include <exception>
#include "ops/pybind/cachekv_signal_thread_worker.h"
// Custom exception class for reporting XPU errors.
class XPUError : public std::exception {
public:
explicit XPUError(int error) : error_(error) {}
const char *what() const noexcept override { return xpu_strerror(error_); }
private:
int error_;
};

View File

@@ -0,0 +1,95 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
#include "xpu_multiprocess.h" // NOLINT
void ReadDataIpc(const paddle::Tensor &tmp_input, const std::string &shm_name) {
volatile shmStruct *shm = NULL;
sharedMemoryInfo info;
int ret = sharedMemoryOpen(shm_name.c_str(), sizeof(shmStruct), &info);
PD_CHECK(ret == 0, "sharedMemoryOpen failed");
shm = static_cast<volatile shmStruct *>(info.addr);
void *ptr = nullptr;
#if XPURT_VERSION_MAJOR == 5
ret = xpu_ipc_open_memhandle(
&ptr, *(XPUIpcMemHandle *)&shm->memHandle, 0x01); // NOLINT
#elif XPURT_VERSION_MAJOR == 4
PD_THROW("kl2 not support prefix cache");
#endif
PD_CHECK(ret == XPU_SUCCESS, "xpu_ipc_open_memhandle failed");
PD_CHECK(tmp_input.place().GetType() == phi::AllocationType::CPU);
// switch (tmp_input.dtype()) {
// case paddle::DataType::FLOAT32:
// ret = xpu_memcpy(const_cast<float *>(tmp_input.data<float>()),
// ptr,
// tmp_input.numel() * sizeof(float),
// XPUMemcpyKind::XPU_DEVICE_TO_HOST);
// break;
// case paddle::DataType::FLOAT16:
// ret = xpu_memcpy(const_cast<phi::dtype::float16 *>(
// tmp_input.data<phi::dtype::float16>()),
// ptr,
// tmp_input.numel() * sizeof(phi::dtype::float16),
// XPUMemcpyKind::XPU_DEVICE_TO_HOST);
// break;
// case paddle::DataType::UINT8:
// ret = xpu_memcpy(const_cast<uint8_t *>(tmp_input.data<uint8_t>()),
// ptr,
// tmp_input.numel() * sizeof(uint8_t),
// XPUMemcpyKind::XPU_DEVICE_TO_HOST);
// break;
// default:
// PD_THROW("not support dtype: ",
// phi::DataTypeToString(tmp_input.dtype()));
// }
// PD_CHECK(ret == XPU_SUCCESS, "not support dtype");
// ret = xpu_ipc_close_memhandle(ptr);
// PD_CHECK(ret == XPU_SUCCESS, "not support dtype");
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
void *data_ptr = reinterpret_cast<void *>(shm->data_ptr_addr);
auto x = paddle::from_blob(data_ptr,
tmp_input.shape(),
tmp_input.dtype(),
tmp_input.layout(),
place);
paddle::Tensor y = tmp_input.copy_to(place, false);
ret = baidu::xpu::api::scale<float, float>(xpu_ctx->x_context(),
x.data<float>(),
y.data<float>(),
tmp_input.numel(),
true,
1.f,
2.f);
PD_CHECK(ret == XPU_SUCCESS, "add2 fail");
ret = xpu_memcpy(const_cast<float *>(tmp_input.data<float>()),
y.data<float>(),
tmp_input.numel() * sizeof(float),
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
PD_CHECK(ret == XPU_SUCCESS, "xpu_memcpy fail");
sharedMemoryClose(&info);
}
PD_BUILD_OP(read_data_ipc)
.Inputs({"tmp_input"})
.Attrs({"shm_name: std::string"})
.Outputs({"tmp_input_out"})
.SetInplaceMap({{"tmp_input", "tmp_input_out"}})
.SetKernelFn(PD_KERNEL(ReadDataIpc));

View File

@@ -0,0 +1,113 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ops/remote_cache_kv_ipc.h"
#include <chrono>
#include <iostream>
#include "paddle/extension.h"
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data
RemoteCacheKvIpc::kv_complete_signal_meta_data;
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data_per_query
RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query;
void* RemoteCacheKvIpc::kv_complete_signal_identity_ptr = nullptr;
bool RemoteCacheKvIpc::kv_complete_signal_shmem_opened = false;
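// open_shm_and_get_complete_signal_meta_data creates (or re-opens) two named
// POSIX shared-memory segments per rank: a "layer" segment carrying the id of
// the last layer whose cache kv has been written, and a "step" segment holding
// a prefill-step identity counter that is bumped (mod 100003) on every call
// unless keep_pd_step_flag is set. Segment names are derived from the rank id,
// the SHM_UUID environment variable and, when ENABLE_EP_DP=1, the data-parallel
// rank.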
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data
RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
const int rank_id, const bool keep_pd_step_flag) {
if (RemoteCacheKvIpc::kv_complete_signal_shmem_opened) {
if (keep_pd_step_flag) {
return RemoteCacheKvIpc::kv_complete_signal_meta_data;
}
int32_t current_identity = (*reinterpret_cast<int32_t*>(
RemoteCacheKvIpc::kv_complete_signal_identity_ptr));
int32_t* write_ptr = reinterpret_cast<int32_t*>(
RemoteCacheKvIpc::kv_complete_signal_identity_ptr);
*write_ptr = (current_identity + 1) % 100003;
RemoteCacheKvIpc::kv_complete_signal_meta_data.layer_id = -1;
int32_t* layer_complete_ptr =
reinterpret_cast<int32_t*>(kv_complete_signal_meta_data.shm_ptr);
*layer_complete_ptr = -1;
return RemoteCacheKvIpc::kv_complete_signal_meta_data;
}
std::string flags_server_uuid;
if (const char* iflags_server_uuid_env_p = std::getenv("SHM_UUID")) {
std::string iflags_server_uuid_env_str(iflags_server_uuid_env_p);
flags_server_uuid = iflags_server_uuid_env_str;
}
std::string step_shm_name =
("splitwise_complete_prefilled_step_" + std::to_string(rank_id) + "_" +
flags_server_uuid);
std::string layer_shm_name =
("splitwise_complete_prefilled_layer_" + std::to_string(rank_id) + "_" +
flags_server_uuid);
if (const char* use_ep = std::getenv("ENABLE_EP_DP")) {
if (std::strcmp(use_ep, "1") == 0) {
step_shm_name = "splitwise_complete_prefilled_step_tprank0_dprank" +
std::to_string(rank_id) + "_" + flags_server_uuid;
layer_shm_name = "splitwise_complete_prefilled_layer_tprank0_dprank" +
std::to_string(rank_id) + "_" + flags_server_uuid;
}
}
int signal_shm_fd = shm_open(layer_shm_name.c_str(), O_CREAT | O_RDWR, 0666);
PD_CHECK(signal_shm_fd != -1,
"can not open shm for cache_kv_complete_signal.");
int signal_shm_ftruncate = ftruncate(signal_shm_fd, 4);
void* signal_ptr = mmap(0, 4, PROT_WRITE, MAP_SHARED, signal_shm_fd, 0);
PD_CHECK(signal_ptr != MAP_FAILED,
"can not open shm for cache_kv_compelete_identity.");
int32_t* write_signal_ptr = reinterpret_cast<int32_t*>(signal_ptr);
*write_signal_ptr = -1;
using type_meta_data =
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data;
// std::printf("#### open_shm_and_get_complete_signal_meta_data layer idx:%d,
// to ptx:%p \n",
// -1, signal_ptr);
type_meta_data meta_data(-1, signal_ptr, signal_shm_fd);
RemoteCacheKvIpc::kv_complete_signal_meta_data = meta_data;
int identity_shm_fd = shm_open(step_shm_name.c_str(), O_CREAT | O_RDWR, 0666);
PD_CHECK(identity_shm_fd != -1,
"can not open shm for cache_kv_compelete_identity.");
int identity_shm_ftruncate = ftruncate(identity_shm_fd, 4);
void* identity_ptr = mmap(0, 4, PROT_WRITE, MAP_SHARED, identity_shm_fd, 0);
PD_CHECK(identity_ptr != MAP_FAILED, "MAP_FAILED for prefill_identity.");
int32_t current_identity = (*reinterpret_cast<int32_t*>(identity_ptr));
int32_t* write_ptr = reinterpret_cast<int32_t*>(identity_ptr);
*write_ptr = (current_identity + 1) % 100003;
RemoteCacheKvIpc::kv_complete_signal_identity_ptr = identity_ptr;
RemoteCacheKvIpc::kv_complete_signal_shmem_opened = true;
return meta_data;
}
void RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise(
void* meta_data) {
int64_t* meta_data_ptr = reinterpret_cast<int64_t*>(meta_data);
int32_t layer_id = meta_data_ptr[0];
int32_t* ptr = reinterpret_cast<int32_t*>(meta_data_ptr[1]);
*ptr = layer_id;
// std::printf("#### save_cache_kv_complete_signal_layerwise layer idx:%d, to
// ptx:%p \n",
// *ptr, meta_data_ptr[1]);
}
void RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_per_query(
void* meta_data) {
RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query.send_signal();
}

View File

@@ -0,0 +1,98 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fcntl.h>
#include <sys/ipc.h>
#include <sys/mman.h>
#include <sys/msg.h>
#include <sys/stat.h>
#include <unistd.h>
#include <cstdint>
#include <vector>
#include "msg_utils.h" // NOLINT
struct RemoteCacheKvIpc {
struct save_cache_kv_complete_signal_layerwise_meta_data {
int32_t layer_id = -1;
void* shm_ptr = nullptr;
int shm_fd = -1;
save_cache_kv_complete_signal_layerwise_meta_data() {}
save_cache_kv_complete_signal_layerwise_meta_data(int32_t layer_id_,
void* shm_ptr_,
int shm_fd_)
: layer_id(layer_id_), shm_ptr(shm_ptr_), shm_fd(shm_fd_) {}
};
struct save_cache_kv_complete_signal_layerwise_meta_data_per_query {
int layer_id_;
int num_layers_;
bool inited = false;
struct msgdatakv msg_sed;
int msgid;
save_cache_kv_complete_signal_layerwise_meta_data_per_query() {}
void init(const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int rank,
const int num_layers,
const int real_bsz) {
layer_id_ = 0;
num_layers_ = num_layers;
msg_sed.mtype = 1;
int encoder_count = 0;
for (int i = 0; i < real_bsz; i++) {
if (seq_lens_encoder[i] > 0) {
msg_sed.mtext[3 * encoder_count + 2] = i;
msg_sed.mtext[3 * encoder_count + 3] = seq_lens_decoder[i];
msg_sed.mtext[3 * encoder_count + 4] = seq_lens_encoder[i];
encoder_count++;
}
}
msg_sed.mtext[0] = encoder_count;
if (!inited) {
// just init once
const int msg_id = 1024 + rank;
key_t key = ftok("/opt/", msg_id);
msgid = msgget(key, IPC_CREAT | 0666);
inited = true;
}
}
void send_signal() {
msg_sed.mtext[1] = layer_id_;
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) {
printf("kv signal full msg buffer\n");
}
layer_id_ = (layer_id_ + 1);
assert(layer_id_ <= num_layers_);
}
};
static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data
kv_complete_signal_meta_data;
static RemoteCacheKvIpc::
save_cache_kv_complete_signal_layerwise_meta_data_per_query
kv_complete_signal_meta_data_per_query;
static void* kv_complete_signal_identity_ptr;
static bool kv_complete_signal_shmem_opened;
static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data
open_shm_and_get_complete_signal_meta_data(const int rank_id,
const bool keep_pd_step_flag);
static void save_cache_kv_complete_signal_layerwise(void* meta_data);
static void save_cache_kv_complete_signal_layerwise_per_query(
void* meta_data);
};

View File

@@ -0,0 +1,69 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "xpu_multiprocess.h" // NOLINT
template <typename T>
void set_data_ipc(const paddle::Tensor &tmp_input,
const std::string &shm_name) {
sharedMemoryInfo info;
volatile shmStruct *shm = NULL;
int ret = sharedMemoryCreate(shm_name.c_str(), sizeof(*shm), &info);
PD_CHECK(ret == 0, "sharedMemoryCreate failed");
shm = (volatile shmStruct *)info.addr;
memset((void *)shm, 0, sizeof(*shm)); // NOLINT
void *data_ptr_now =
reinterpret_cast<void *>(const_cast<T *>(tmp_input.data<T>()));
#if XPURT_VERSION_MAJOR == 5
ret = xpu_ipc_get_memhandle((XPUIpcMemHandle *)&shm->memHandle, // NOLINT
data_ptr_now);
#elif XPURT_VERSION_MAJOR == 4
PD_THROW("kl2 not support prefix cache");
#endif
PD_CHECK(ret == XPU_SUCCESS, "xpu_ipc_get_memhandle failed");
shm->data_ptr_addr = reinterpret_cast<uint64_t>((data_ptr_now));
}
void SetDataIpc(const paddle::Tensor &tmp_input, const std::string &shm_name) {
switch (tmp_input.type()) {
case paddle::DataType::FLOAT16: {
return set_data_ipc<paddle::float16>(tmp_input, shm_name);
}
case paddle::DataType::FLOAT32: {
return set_data_ipc<float>(tmp_input, shm_name);
}
case paddle::DataType::INT8: {
return set_data_ipc<int8_t>(tmp_input, shm_name);
}
case paddle::DataType::UINT8: {
return set_data_ipc<uint8_t>(tmp_input, shm_name);
}
case paddle::DataType::BFLOAT16: {
return set_data_ipc<paddle::bfloat16>(tmp_input, shm_name);
}
default: {
PD_THROW("NOT supported data type.");
break;
}
}
}
PD_BUILD_OP(set_data_ipc)
.Inputs({"tmp_input"})
.Attrs({"shm_name: std::string"})
.Outputs({"tmp_input_out"})
.SetInplaceMap({{"tmp_input", "tmp_input_out"}})
.SetKernelFn(PD_KERNEL(SetDataIpc));

View File

@@ -0,0 +1,57 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_meta.h"
#include "xpu/plugin.h"
#include "xpu_multiprocess.h" // NOLINT(build/include_subdir)
std::vector<paddle::Tensor> ShareExternalData(const paddle::Tensor &input,
const std::string shm_name,
const std::vector<int> &shape,
bool use_ipc) {
sharedMemoryInfo info;
int ret = sharedMemoryOpen(shm_name.c_str(), sizeof(shmStruct), &info);
PD_CHECK(ret == 0, "sharedMemoryOpen failed");
volatile shmStruct *shm = static_cast<volatile shmStruct *>(info.addr);
void *data_ptr_addr = nullptr;
if (use_ipc) {
#if XPURT_VERSION_MAJOR == 5
int ret = xpu_ipc_open_memhandle(
&data_ptr_addr, *(XPUIpcMemHandle *)&shm->memHandle, 0x01); // NOLINT
PD_CHECK(ret == XPU_SUCCESS, "xpu_ipc_open_memhandle failed");
#elif XPURT_VERSION_MAJOR == 4
PD_THROW("kl2 not support prefix cache");
#endif
} else {
data_ptr_addr = reinterpret_cast<void *>(shm->data_ptr_addr);
}
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
paddle::Tensor output = paddle::from_blob(
data_ptr_addr, shape, input.dtype(), input.layout(), place);
sharedMemoryClose(&info);
return {output};
}
PD_BUILD_OP(share_external_data)
.Inputs({"input"})
.Outputs({"output"})
.Attrs({"shm_name: std::string",
"shape: std::vector<int>",
"use_ipc: bool"})
.SetKernelFn(PD_KERNEL(ShareExternalData));

View File

@@ -0,0 +1,166 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <xpu/runtime.h>
#include "paddle/extension.h"
template <typename T>
void SwapCacheImplAllLayers(
const std::vector<paddle::Tensor>& cache_xpu_tensors, // xpu
const std::vector<int64_t>& cache_cpu_ptrs, // cpu
const int64_t& max_block_num_cpu,
const std::vector<int64_t>& swap_block_ids_xpu,
const std::vector<int64_t>& swap_block_ids_cpu,
int mode) {
using XPUType = typename XPUTypeTrait<T>::Type;
for (int layer_idx = 0; layer_idx < cache_xpu_tensors.size(); layer_idx++) {
const paddle::Tensor& cache_xpu = cache_xpu_tensors[layer_idx];
const int64_t& cache_cpu_pointer = cache_cpu_ptrs[layer_idx];
// XPUType* cache_xpu_ptr =
// reinterpret_cast<XPUType*>(const_cast<T*>(cache_xpu.data<T>()));
T* cache_xpu_ptr = const_cast<T*>(cache_xpu.data<T>());
auto* cache_cpu_ptr = reinterpret_cast<T*>(cache_cpu_pointer);
auto cache_shape = cache_xpu.shape();
const int64_t max_block_num_xpu = cache_shape[0];
const int64_t num_heads = cache_shape[1];
const int64_t block_size = cache_shape[2];
const int64_t head_dim = cache_shape[3];
const int64_t cache_stride = num_heads * block_size * head_dim;
if (swap_block_ids_xpu.size() == 0) {
return;
}
int i = 0;
int64_t consecutive_block_count = 1;
int64_t last_xpu_block_id = swap_block_ids_xpu[i];
int64_t last_cpu_block_id = swap_block_ids_cpu[i];
int64_t first_xpu_block_id =
last_xpu_block_id; // first block id in a consecutive block ids
int64_t first_cpu_block_id = last_cpu_block_id;
i += 1;
while (true) {
if (i >= swap_block_ids_xpu.size()) {
break;
}
int64_t xpu_block_id = swap_block_ids_xpu[i];
int64_t cpu_block_id = swap_block_ids_cpu[i];
PD_CHECK(xpu_block_id >= 0 && xpu_block_id < max_block_num_xpu);
PD_CHECK(cpu_block_id >= 0 && cpu_block_id < max_block_num_cpu);
if (xpu_block_id == last_xpu_block_id + 1 &&
cpu_block_id == last_cpu_block_id + 1) { // consecutive
consecutive_block_count += 1;
last_xpu_block_id = xpu_block_id;
last_cpu_block_id = cpu_block_id;
} else {
// end of a consecutive block ids
auto* cache_xpu_ptr_now =
cache_xpu_ptr + first_xpu_block_id * cache_stride;
auto* cache_cpu_ptr_now =
cache_cpu_ptr + first_cpu_block_id * cache_stride;
if (mode == 0) { // copy from device to host
xpu_memcpy(cache_cpu_ptr_now,
cache_xpu_ptr_now,
cache_stride * sizeof(XPUType) * consecutive_block_count,
XPU_DEVICE_TO_HOST);
} else { // copy from host to device
xpu_memcpy(cache_xpu_ptr_now,
cache_cpu_ptr_now,
cache_stride * sizeof(XPUType) * consecutive_block_count,
XPU_HOST_TO_DEVICE);
}
first_xpu_block_id = xpu_block_id;
first_cpu_block_id = cpu_block_id;
last_xpu_block_id = xpu_block_id;
last_cpu_block_id = cpu_block_id;
consecutive_block_count = 1;
}
i += 1;
}
// last batch
auto* cache_xpu_ptr_now = cache_xpu_ptr + first_xpu_block_id * cache_stride;
auto* cache_cpu_ptr_now = cache_cpu_ptr + first_cpu_block_id * cache_stride;
if (mode == 0) { // copy from device to host
xpu_memcpy(cache_cpu_ptr_now,
cache_xpu_ptr_now,
cache_stride * sizeof(XPUType) * consecutive_block_count,
XPU_DEVICE_TO_HOST);
} else { // copy from host to device
xpu_memcpy(cache_xpu_ptr_now,
cache_cpu_ptr_now,
cache_stride * sizeof(XPUType) * consecutive_block_count,
XPU_HOST_TO_DEVICE);
}
}
}
void SwapCacheAllLayers(
const std::vector<paddle::Tensor>& cache_xpu_tensors, // xpu
const std::vector<int64_t>& cache_cpu_ptrs, // cpu memory pointer
int64_t max_block_num_cpu, // cpu max block num
const std::vector<int64_t>& swap_block_ids_xpu,
const std::vector<int64_t>& swap_block_ids_cpu,
int rank,
int mode) {
xpu_set_device(rank); // used for distributed launch
PD_CHECK(cache_xpu_tensors.size() > 0 &&
cache_xpu_tensors.size() == cache_cpu_ptrs.size());
switch (cache_xpu_tensors[0].dtype()) {
case paddle::DataType::FLOAT16:
return SwapCacheImplAllLayers<paddle::float16>(cache_xpu_tensors,
cache_cpu_ptrs,
max_block_num_cpu,
swap_block_ids_xpu,
swap_block_ids_cpu,
mode);
case paddle::DataType::UINT8:
return SwapCacheImplAllLayers<uint8_t>(cache_xpu_tensors,
cache_cpu_ptrs,
max_block_num_cpu,
swap_block_ids_xpu,
swap_block_ids_cpu,
mode);
case paddle::DataType::INT8:
return SwapCacheImplAllLayers<int8_t>(cache_xpu_tensors,
cache_cpu_ptrs,
max_block_num_cpu,
swap_block_ids_xpu,
swap_block_ids_cpu,
mode);
case paddle::DataType::BFLOAT16:
return SwapCacheImplAllLayers<paddle::bfloat16>(cache_xpu_tensors,
cache_cpu_ptrs,
max_block_num_cpu,
swap_block_ids_xpu,
swap_block_ids_cpu,
mode);
default:
PD_THROW("Unsupported data type.");
}
}
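// Note: the CPU-side cache buffers are passed as raw host addresses through the
// int64 "cache_cpu_ptrs" attribute and reinterpreted as T* inside the kernel.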
PD_BUILD_OP(swap_cache_all_layers)
.Inputs({paddle::Vec("cache_xpu_tensors")})
.Attrs({
"cache_cpu_ptrs: std::vector<int64_t>",
"max_block_num_cpu: int64_t",
"swap_block_ids_xpu: std::vector<int64_t>",
"swap_block_ids_cpu: std::vector<int64_t>",
"rank: int",
"mode: int",
})
.Outputs({paddle::Vec("cache_dst_outs")})
.SetInplaceMap({{paddle::Vec("cache_xpu_tensors"),
paddle::Vec("cache_dst_outs")}})
.SetKernelFn(PD_KERNEL(SwapCacheAllLayers));

View File

@@ -0,0 +1,194 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ops/utility/debug.h"
#include <stdarg.h>
#include <cmath> // for std::sqrt
#include <cstring>
#include <memory>
#include <numeric> // for std::accumulate
#include <sstream>
#include <utility>
#include <vector>
#include "paddle/phi/common/float16.h"
#include "xpu/internal/infra_op.h"
namespace paddle {
std::string string_format(const std::string fmt_str, ...) {
// Reserve two times as much as the length of the fmt_str
int final_n, n = (static_cast<int>(fmt_str.size())) * 2;
std::unique_ptr<char[]> formatted;
va_list ap;
while (1) {
formatted.reset(new char[n]);
// Wrap the plain char array into the unique_ptr
std::strcpy(&formatted[0], fmt_str.c_str()); // NOLINT
va_start(ap, fmt_str);
final_n = vsnprintf(&formatted[0], n, fmt_str.c_str(), ap);
va_end(ap);
if (final_n < 0 || final_n >= n)
n += std::abs(final_n - n + 1);
else
break;
}
return std::string(formatted.get());
}
std::string shape_to_string(const std::vector<int64_t>& shape) {
std::ostringstream os;
auto rank = shape.size();
if (rank > 0) {
os << shape[0];
for (size_t i = 1; i < rank; i++) {
os << ", " << shape[i];
}
}
return os.str();
}
template <typename T>
float cal_mean(const std::vector<T>& data) {
return std::accumulate(data.begin(), data.end(), 0.f) /
static_cast<float>(data.size());
}
template <typename T>
float cal_std(const std::vector<T>& data) {
float mean = cal_mean(data);
  // accumulate in float so integer element types do not truncate the variance
  float variance = std::accumulate(data.begin(),
                                   data.end(),
                                   0.f,
                                   [mean](float acc, T val) {
                                     return acc + (val - mean) * (val - mean);
                                   }) /
                   data.size();
return std::sqrt(variance);
}
template <typename T>
void DebugPrintXPUTensor(const phi::XPUContext* xpu_ctx,
const paddle::Tensor& input,
std::string tag,
int len) {
const T* input_data_ptr = input.data<T>();
std::vector<T> input_data(len);
xpu::do_device2host(
xpu_ctx->x_context(), input_data_ptr, input_data.data(), len);
for (int i = 0; i < len; ++i) {
std::cout << "DebugPrintXPUTensor " << tag << ", data: " << input_data[i]
<< std::endl;
}
std::cout << "DebugPrintXPUTensor " << tag
<< ", mean: " << cal_mean(input_data) << std::endl;
std::cout << "DebugPrintXPUTensor " << tag << ", std: " << cal_std(input_data)
<< std::endl;
}
template <typename T>
void DebugPrintXPUTensorv2(const paddle::Tensor& input,
std::string tag,
int len) {
auto input_cpu = input.copy_to(phi::CPUPlace(), false);
std::ostringstream os;
const T* input_data = input_cpu.data<T>();
for (int i = 0; i < len; ++i) {
os << input_data[i] << ", ";
}
std::cout << "DebugPrintXPUTensorv2 " << tag << ", data: " << os.str()
<< std::endl;
}
template <>
void DebugPrintXPUTensorv2<paddle::float16>(const paddle::Tensor& input,
std::string tag,
int len) {
auto input_cpu = input.copy_to(phi::CPUPlace(), false);
std::ostringstream os;
const paddle::float16* input_data = input_cpu.data<paddle::float16>();
for (int i = 0; i < len; ++i) {
os << static_cast<float>(input_data[i]) << ", ";
}
std::cout << "DebugPrintXPUTensorv2 " << tag << ", data: " << os.str()
<< std::endl;
}
template <>
void DebugPrintXPUTensorv2<paddle::bfloat16>(const paddle::Tensor& input,
std::string tag,
int len) {
auto input_cpu = input.copy_to(phi::CPUPlace(), false);
std::ostringstream os;
const paddle::bfloat16* input_data = input_cpu.data<paddle::bfloat16>();
for (int i = 0; i < len; ++i) {
os << static_cast<float>(input_data[i]) << ", ";
}
std::cout << "DebugPrintXPUTensorv2 " << tag << ", data: " << os.str()
<< std::endl;
}
template <>
void DebugPrintXPUTensorv2<int8_t>(const paddle::Tensor& input,
std::string tag,
int len) {
auto input_cpu = input.copy_to(phi::CPUPlace(), false);
std::ostringstream os;
const int8_t* input_data = input_cpu.data<int8_t>();
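  // print only the high nibble of each byte (presumably int4 data packed two
  // values per byte; the low nibble is not shown)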
for (int i = 0; i < len; ++i) {
int8_t tmp = input_data[i] >> 4;
os << (int32_t)tmp << ", ";
}
std::cout << "DebugPrintXPUTensorv2 " << tag << ", data: " << os.str()
<< std::endl;
}
#define INSTANTIATE_DEBUGPRINT_XPUTENSOR(Type, FuncName, ...) \
template void FuncName<Type>(__VA_ARGS__);
#define INSTANTIATE_DEBUGPRINT_XPUTENSOR_V1(Type) \
INSTANTIATE_DEBUGPRINT_XPUTENSOR(Type, \
DebugPrintXPUTensor, \
const phi::XPUContext* xpu_ctx, \
const paddle::Tensor& input, \
std::string tag, \
int len)
#define INSTANTIATE_DEBUGPRINT_XPUTENSOR_V2(Type) \
INSTANTIATE_DEBUGPRINT_XPUTENSOR(Type, \
DebugPrintXPUTensorv2, \
const paddle::Tensor& input, \
std::string tag, \
int len)
// do not support bool type now, please use DebugPrintXPUTensorv2<bool>
// INSTANTIATE_DEBUGPRINT_XPUTENSOR_V1(bool)
INSTANTIATE_DEBUGPRINT_XPUTENSOR_V1(float)
INSTANTIATE_DEBUGPRINT_XPUTENSOR_V1(int)
INSTANTIATE_DEBUGPRINT_XPUTENSOR_V1(int64_t)
INSTANTIATE_DEBUGPRINT_XPUTENSOR_V2(int8_t)
INSTANTIATE_DEBUGPRINT_XPUTENSOR_V2(bool)
INSTANTIATE_DEBUGPRINT_XPUTENSOR_V2(int64_t)
INSTANTIATE_DEBUGPRINT_XPUTENSOR_V2(float)
INSTANTIATE_DEBUGPRINT_XPUTENSOR_V2(int)
INSTANTIATE_DEBUGPRINT_XPUTENSOR_V2(paddle::float16)
INSTANTIATE_DEBUGPRINT_XPUTENSOR_V2(paddle::bfloat16)
} // namespace paddle

View File

@@ -0,0 +1,63 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "env.h" // NOLINT
namespace paddle {
// Specialization for bool
template <>
bool get_env<bool>(const std::string& var_name, bool default_value) {
const char* value = std::getenv(var_name.c_str());
if (!value) {
if (var_name.size() < 6 || var_name.substr(0, 6) != "FLAGS_") {
return get_env<bool>("FLAGS_" + var_name, default_value);
}
return default_value;
}
std::string valStr(value);
std::transform(valStr.begin(), valStr.end(), valStr.begin(), ::tolower);
if (valStr == "true" || valStr == "1") {
return true;
} else if (valStr == "false" || valStr == "0") {
return false;
}
PD_THROW("Unexpected value:", valStr, ", only bool supported.");
return default_value;
}
template <>
int get_env<int>(const std::string& var_name, int default_value) {
const char* value = std::getenv(var_name.c_str());
if (!value) {
if (var_name.size() < 6 || var_name.substr(0, 6) != "FLAGS_") {
return get_env<int>("FLAGS_" + var_name, default_value);
}
return default_value;
}
try {
return std::stoi(value);
} catch (...) {
PD_THROW("Unexpected value:", value, ", only int supported.");
}
}
#define DEFINE_GET_ENV_SPECIALIZATION(T) \
template <> \
T get_env<T>(const std::string& var_name, T default_value);
DEFINE_GET_ENV_SPECIALIZATION(bool)
DEFINE_GET_ENV_SPECIALIZATION(int)
} // namespace paddle

View File

@@ -0,0 +1,29 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/extension.h"
namespace paddle {
template <typename T>
T get_env(const std::string& var_name, T default_value);
}
#define XPU_DECLARE_VALUE(type, env_name, default_value) \
static type FLAGS_##env_name = \
paddle::get_env<type>(#env_name, default_value);
#define XPU_DECLARE_BOOL(env_name, default_value) \
XPU_DECLARE_VALUE(bool, env_name, default_value)
#define XPU_DECLARE_INT(env_name, default_value) \
XPU_DECLARE_VALUE(int, env_name, default_value)
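// Illustrative usage (hypothetical flag name):
//   XPU_DECLARE_BOOL(MY_FEATURE_SWITCH, false);
//   if (FLAGS_MY_FEATURE_SWITCH) { /* ... */ }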

View File

@@ -0,0 +1,95 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ops/utility/logging.h"
#include <iomanip>
namespace paddle {
void gen_log(std::ostream& log_stream_,
const char* file,
const char* func,
int lineno,
const char* level,
const int kMaxLen = 40) {
const int len = strlen(file);
struct tm tm_time; // Time of creation of LogMessage
time_t timestamp = time(NULL);
#if defined(_WIN32)
localtime_s(&tm_time, &timestamp);
#else
localtime_r(&timestamp, &tm_time);
#endif
struct timeval tv;
gettimeofday(&tv, NULL);
// print date / time
log_stream_ << '[' << level << ' ' << std::setw(2) << 1 + tm_time.tm_mon
<< '/' << std::setw(2) << tm_time.tm_mday << ' ' << std::setw(2)
<< tm_time.tm_hour << ':' << std::setw(2) << tm_time.tm_min << ':'
<< std::setw(2) << tm_time.tm_sec << '.' << std::setw(3)
<< tv.tv_usec / 1000 << " ";
if (len > kMaxLen) {
log_stream_ << "..." << file + len - kMaxLen << ":" << lineno << " " << func
<< "] ";
} else {
log_stream_ << file << " " << func << ":" << lineno << "] ";
}
}
CustomLogMessage::CustomLogMessage(const char* file,
const char* func,
int lineno,
const char* level)
: level_(level) {
gen_log(log_stream_, file, func, lineno, level);
}
CustomLogMessage::~CustomLogMessage() {
log_stream_ << '\n';
fprintf(stderr, "%s", log_stream_.str().c_str());
}
CustomLogMessageFatal::~CustomLogMessageFatal() noexcept(false) {
log_stream_ << '\n';
fprintf(stderr, "%s", log_stream_.str().c_str());
throw CustomException(log_stream_.str().c_str());
abort();
}
CustomVLogMessage::CustomVLogMessage(const char* file,
const char* func,
int lineno,
const int32_t level_int) {
const char* GLOG_v = std::getenv("GLOG_v");
GLOG_v_int = (GLOG_v && atoi(GLOG_v) > 0) ? atoi(GLOG_v) : 0;
this->level_int = level_int;
if (GLOG_v_int < level_int) {
return;
}
  // std::to_string returns a temporary; keep it alive so c_str() stays valid
  const std::string level = std::to_string(level_int);
  gen_log(log_stream_, file, func, lineno, level.c_str());
}
CustomVLogMessage::~CustomVLogMessage() {
if (GLOG_v_int < this->level_int) {
return;
}
log_stream_ << '\n';
fprintf(stderr, "%s", log_stream_.str().c_str());
}
} // namespace paddle

View File

@@ -0,0 +1,114 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <assert.h>
#include <time.h>
#if !defined(_WIN32)
#include <sys/time.h>
#include <sys/types.h>
#else
#define NOMINMAX // msvc max/min macro conflict with std::min/max
#include <windows.h>
#undef min
#undef max
extern struct timeval;
static int gettimeofday(struct timeval* tp, void* tzp) {
LARGE_INTEGER now, freq;
QueryPerformanceCounter(&now);
QueryPerformanceFrequency(&freq);
tp->tv_sec = now.QuadPart / freq.QuadPart;
tp->tv_usec = (now.QuadPart % freq.QuadPart) * 1000000 / freq.QuadPart;
return (0);
}
#endif
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <string>
// LOG()
#define LOG(status) LOG_##status.stream()
#define LOG_INFO paddle::CustomLogMessage(__FILE__, __FUNCTION__, __LINE__, "I")
#define LOG_ERROR LOG_INFO
#define LOG_WARNING \
paddle::CustomLogMessage(__FILE__, __FUNCTION__, __LINE__, "W")
#define LOG_FATAL \
paddle::CustomLogMessageFatal(__FILE__, __FUNCTION__, __LINE__)
// VLOG()
#define VLOG(level) \
paddle::CustomVLogMessage(__FILE__, __FUNCTION__, __LINE__, level).stream()
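// Illustrative usage:
//   LOG(INFO) << "message";  // expands to CustomLogMessage(...).stream()
//   VLOG(1) << "verbose";    // printed only when GLOG_v >= 1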
namespace paddle {
struct CustomException : public std::exception {
const std::string exception_prefix = "Custom exception: \n";
std::string message;
explicit CustomException(const char* detail) {
message = exception_prefix + std::string(detail);
}
const char* what() const noexcept { return message.c_str(); }
};
class CustomLogMessage {
public:
CustomLogMessage(const char* file,
const char* func,
int lineno,
const char* level = "I");
~CustomLogMessage();
std::ostream& stream() { return log_stream_; }
protected:
std::stringstream log_stream_;
std::string level_;
CustomLogMessage(const CustomLogMessage&) = delete;
void operator=(const CustomLogMessage&) = delete;
};
class CustomLogMessageFatal : public CustomLogMessage {
public:
CustomLogMessageFatal(const char* file,
const char* func,
int lineno,
const char* level = "F")
: CustomLogMessage(file, func, lineno, level) {}
~CustomLogMessageFatal() noexcept(false);
};
class CustomVLogMessage {
public:
CustomVLogMessage(const char* file,
const char* func,
int lineno,
const int32_t level_int = 0);
~CustomVLogMessage();
std::ostream& stream() { return log_stream_; }
protected:
std::stringstream log_stream_;
int32_t GLOG_v_int;
int32_t level_int;
CustomVLogMessage(const CustomVLogMessage&) = delete;
void operator=(const CustomVLogMessage&) = delete;
};
} // namespace paddle

View File

@@ -0,0 +1,207 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <blocks/xft_blocks.h>
#include <infer_ops.h>
#include <functional>
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "utility/debug.h"
#include "utility/env.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
XPU_DECLARE_BOOL(ENABLE_XVLLM_SDNN_INFER, false);
namespace xftblock = baidu::xpu::xftblock;
namespace api = baidu::xpu::api;
template <typename TX, typename TW>
std::vector<paddle::Tensor> WeightOnlyLinearKernel(
const paddle::Tensor& x,
const paddle::Tensor& weight,
const paddle::Tensor& weight_scale,
const paddle::optional<paddle::Tensor>& bias,
const std::string& weight_dtype) {
using XPU_TX = typename XPUTypeTrait<TX>::Type;
using XPU_TW = typename XPUTypeTrait<TW>::Type;
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr);
auto rt_guard = xctx.get_rt_guard();
auto xftblock_tx = xftblock::DataTypeToEnum<XPU_TX>::value;
auto xftblock_tw = xftblock::DataTypeToEnum<XPU_TW>::value;
int ret = -1;
auto x_shape = x.shape();
auto w_shape = weight.shape();
int64_t n = w_shape[0];
int64_t k = w_shape[1];
int64_t m = x.numel() / k;
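  // int4 weights are packed two values per byte, so the logical output dim is
  // twice the stored dim (mirrors the x2 adjustment in the InferShape below).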
if (weight_dtype == "int4_t") {
n = n * 2;
}
paddle::Tensor out = paddle::empty({m, n}, x.dtype(), x.place());
if (m == 0) {
return {out};
}
paddle::Tensor bias_fp32;
if (bias.get_ptr() && bias.get_ptr()->dtype() != paddle::DataType::FLOAT32) {
bias_fp32 = paddle::empty({n}, paddle::DataType::FLOAT32, x.place());
PD_CHECK(bias.get_ptr()->dtype() == x.dtype(), "bias.dtype != x.dtype");
ret = api::cast<XPU_TX, float>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_TX*>(bias.get_ptr()->data<TX>()),
bias_fp32.data<float>(),
n);
PD_CHECK(ret == 0, "cast");
}
xftblock::Tensor input_x(const_cast<TX*>(x.data<TX>()), xftblock_tx, {m, k});
xftblock::Tensor input_w(const_cast<TW*>(weight.data<TW>()),
nullptr,
const_cast<float*>(weight_scale.data<float>()),
xftblock_tw,
{n, k});
xftblock::Tensor output(const_cast<TX*>(out.data<TX>()), xftblock_tx, {m, n});
std::shared_ptr<xftblock::Tensor> input_bias;
if (bias.get_ptr()) {
if (bias.get_ptr()->dtype() != paddle::DataType::FLOAT32) {
input_bias = std::make_shared<xftblock::Tensor>(
const_cast<float*>(bias_fp32.data<float>()),
xftblock::DataType::DT_FLOAT,
std::vector<int64_t>({n}));
} else {
input_bias = std::make_shared<xftblock::Tensor>(
const_cast<float*>(bias.get_ptr()->data<float>()),
xftblock::DataType::DT_FLOAT,
std::vector<int64_t>({n}));
}
}
bool use_sdnn = FLAGS_ENABLE_XVLLM_SDNN_INFER;
if (x.dtype() == paddle::DataType::BFLOAT16) {
ret = xftblock::
xft_fc_block_cast_te_per_token<bfloat16, int8_t, bfloat16, float16>(
&xctx,
&input_x,
&input_w,
&output,
input_bias.get(),
api::Activation_t::LINEAR,
false,
true,
1.0f,
0.0f,
0,
1,
false,
false,
use_sdnn);
PD_CHECK(ret == 0, "xft_fc_block_cast_te_per_token");
} else {
ret = xftblock::xft_fc_block<XPU_TX, XPU_TW, XPU_TX, XPU_TX>(
&xctx,
&input_x,
&input_w,
&output,
input_bias.get(),
api::Activation_t::LINEAR,
false,
true,
1.0f,
0.0f,
0,
1,
false,
false);
PD_CHECK(ret == 0, "xft_fc_block");
}
return {out};
}
std::vector<paddle::Tensor> WeightOnlyLinear(
const paddle::Tensor& x,
const paddle::Tensor& weight,
const paddle::Tensor& weight_scale,
const paddle::optional<paddle::Tensor>& bias,
const std::string& weight_dtype,
const int arch,
const int group_size) {
const auto x_type = x.dtype();
const auto w_type = weight.dtype();
#define APPLY_FFN_KERNEL(TX, TW) \
return WeightOnlyLinearKernel<TX, TW>( \
x, weight, weight_scale, bias, weight_dtype);
if (x_type == paddle::DataType::BFLOAT16 &&
w_type == paddle::DataType::INT8) {
APPLY_FFN_KERNEL(paddle::bfloat16, int8_t);
} else if (x_type == paddle::DataType::FLOAT16 &&
w_type == paddle::DataType::INT8) {
APPLY_FFN_KERNEL(paddle::float16, int8_t);
} else {
PD_THROW("WeightOnlyLinear not support x_type=",
static_cast<int>(x_type),
", w_type=",
static_cast<int>(w_type));
return {};
}
#undef APPLY_FFN_KERNEL
}
std::vector<std::vector<int64_t>> WeightOnlyLinearInferShape(
const std::vector<int64_t>& x_shape,
const std::vector<int64_t>& weight_shape,
const std::vector<int64_t>& weight_scale_shape,
const paddle::optional<std::vector<int64_t>>& bias_shape,
const std::string& weight_dtype,
const int arch,
const int group_size) {
PD_CHECK(weight_shape.size() == 2);
int64_t n = weight_shape[0];
int64_t k = weight_shape[1];
int64_t x_numel = std::accumulate(x_shape.begin(),
x_shape.end(),
static_cast<int64_t>(1),
std::multiplies<int64_t>());
int64_t m = x_numel / k;
if (weight_dtype == "int4") {
n = n * 2;
}
return {{m, n}};
}
std::vector<paddle::DataType> WeightOnlyLinearInferDtype(
const paddle::DataType& x_dtype,
const paddle::DataType& w_dtype,
const paddle::DataType& weight_scale_dtype,
const paddle::optional<paddle::DataType>& bias_dtype,
const std::string& weight_dtype,
const int arch,
const int group_size) {
return {x_dtype};
}
PD_BUILD_STATIC_OP(weight_only_linear_xpu)
.Inputs({"x", "weight", "weight_scale", paddle::Optional("bias")})
.Outputs({"out"})
.Attrs({"weight_dtype:std::string", "arch:int", "group_size:int"})
.SetKernelFn(PD_KERNEL(WeightOnlyLinear))
.SetInferShapeFn(PD_INFER_SHAPE(WeightOnlyLinearInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(WeightOnlyLinearInferDtype));

View File

@@ -0,0 +1,336 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import block_attn, get_infer_param
head_num = 64
kv_head_num = 8
head_dim = 128
seq_len = 128
block_batch = 5
max_block_per_seq = 128
block_size = 64
seq_lens_encoder = paddle.to_tensor([128, 0, 0, 0, 0], dtype="int32")
seq_lens_decoder = paddle.to_tensor([0, 0, 0, 0, 0], dtype="int32")
seq_lens_this_time = paddle.to_tensor([128, 0, 0, 0, 0], dtype="int32")
block_tables = paddle.arange(0, block_batch * max_block_per_seq, dtype="int32")
block_tables = block_tables.reshape((block_batch, max_block_per_seq))
(
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64
) # block_size
qkv = paddle.uniform(
shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim],
dtype="bfloat16",
min=-1.0,
max=1.0,
)
cum_offsets = paddle.zeros(shape=[block_batch], dtype="bfloat16")
rotary_embs = paddle.uniform(shape=[2, 1, 8192, 1, head_dim], dtype="float32", min=-1.0, max=1.0)
key_cache = paddle.zeros(
shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim],
dtype="bfloat16",
)
value_cache = paddle.zeros(
shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim],
dtype="bfloat16",
)
# C8
key_cache_int8 = paddle.zeros(
shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim],
dtype="int8",
)
value_cache_int8 = paddle.zeros(
shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim],
dtype="int8",
)
scale_tensor_k = paddle.uniform(shape=[kv_head_num * head_dim], dtype="bfloat16", min=1.0, max=1.0) # max
scale_tensor_v = paddle.uniform(shape=[kv_head_num * head_dim], dtype="bfloat16", min=1.0, max=1.0) # max
k_quant_scale = 127.0 / scale_tensor_k # for C8 per channel means 127 / max
v_quant_scale = 127.0 / scale_tensor_v # for C8 per channel means 127 / max
k_dequant_scale = paddle.cast(scale_tensor_k, dtype="float32") # for C8 per channel means max
v_dequant_scale = paddle.cast(scale_tensor_v, dtype="float32") # for C8 per channel means max
k_dequant_scale_zp = 1 / k_quant_scale # for C8 per channel zp means max
v_dequant_scale_zp = 1 / v_quant_scale # for C8 per channel zp means max
k_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16")
v_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16")
attn_out = block_attn(
qkv,
key_cache,
value_cache,
cum_offsets,
rotary_embs,
block_tables,
prefix_block_tables,
len_info_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
encoder_batch_map_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
decoder_batch_map_cpu,
prefix_len_cpu,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
attn_out_C8 = block_attn(
qkv,
key_cache_int8,
value_cache_int8,
cum_offsets,
rotary_embs,
block_tables,
prefix_block_tables,
len_info_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
encoder_batch_map_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
decoder_batch_map_cpu,
prefix_len_cpu,
k_quant_scale,
v_quant_scale,
k_dequant_scale,
v_dequant_scale,
None,
None,
None,
None,
None,
None,
)
attn_out_C8_zp = block_attn(
qkv,
key_cache_int8,
value_cache_int8,
cum_offsets,
rotary_embs,
block_tables,
prefix_block_tables,
len_info_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
encoder_batch_map_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
decoder_batch_map_cpu,
prefix_len_cpu,
k_quant_scale,
v_quant_scale,
k_dequant_scale_zp,
v_dequant_scale_zp,
k_zp,
v_zp,
None,
None,
None,
None,
)
# prefix cache : hit 71 tokens
hit_prefix_len = 71
seq_lens_encoder = paddle.to_tensor([seq_len - hit_prefix_len, 0, 0, 0, 0], dtype="int32")
# the first hit_prefix_len (71) tokens are treated as already-cached prefix
seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
(
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64
) # block_size
qkv_prefix = qkv[hit_prefix_len:]
attn_out_prefix_cache = block_attn(
qkv_prefix,
key_cache,
value_cache,
cum_offsets,
rotary_embs,
block_tables,
prefix_block_tables,
len_info_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
encoder_batch_map_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
decoder_batch_map_cpu,
prefix_len_cpu,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
attn_out_C8_prefix_cache = block_attn(
qkv_prefix,
key_cache_int8,
value_cache_int8,
cum_offsets,
rotary_embs,
block_tables,
prefix_block_tables,
len_info_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
encoder_batch_map_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
decoder_batch_map_cpu,
prefix_len_cpu,
k_quant_scale,
v_quant_scale,
k_dequant_scale,
v_dequant_scale,
None,
None,
None,
None,
None,
None,
)
attn_out_C8_zp_prefix_cache = block_attn(
qkv_prefix,
key_cache_int8,
value_cache_int8,
cum_offsets,
rotary_embs,
block_tables,
prefix_block_tables,
len_info_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
encoder_batch_map_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
decoder_batch_map_cpu,
prefix_len_cpu,
k_quant_scale,
v_quant_scale,
k_dequant_scale_zp,
v_dequant_scale_zp,
k_zp,
v_zp,
None,
None,
None,
None,
)
print("-- C16 prefix cache test --")
print("attn_out[hit_prefix_len:]'s mean:", attn_out[hit_prefix_len:].mean().item())
print("attn_out_prefix_cache's mean: ", attn_out_prefix_cache.mean().item())
attn_out_prefix_cache_np = attn_out_prefix_cache.astype("float32").numpy()
attn_out_np = attn_out[hit_prefix_len:].astype("float32").numpy()
assert np.allclose(
attn_out_prefix_cache_np, attn_out_np, rtol=1e-2, atol=1e-3
), f"C16 prefix cache != No prefix cache,\n attn_out[hit_prefix_len:]: {attn_out_np},\nattn_out_prefix_cache: {attn_out_prefix_cache_np}"
print("\n-- C8 per channle prefix cache test --")
print(
"attn_out_C8[hit_prefix_len:]'s mean:",
attn_out_C8[hit_prefix_len:].mean().item(),
)
print("attn_out_C8_prefix_cache's mean: ", attn_out_C8_prefix_cache.mean().item())
attn_out_C8_prefix_cache_np = attn_out_C8_prefix_cache.astype("float32").numpy()
attn_out_C8_np = attn_out_C8[hit_prefix_len:].astype("float32").numpy()
assert np.allclose(
attn_out_C8_prefix_cache_np, attn_out_C8_np, rtol=1e-1, atol=1e-2
), f"C8 per channle prefix cache != No prefix cache,\n attn_out_C8[hit_prefix_len:]: {attn_out_C8_np},\nattn_out_C8_prefix_cache: {attn_out_C8_prefix_cache_np}"
print("\n-- C8 per channle zp prefix cache test --")
print(
"attn_out_C8_zp[hit_prefix_len:]'s mean:",
attn_out_C8_zp[hit_prefix_len:].mean().item(),
)
print(
"attn_out_C8_zp_prefix_cache's mean: ",
attn_out_C8_zp_prefix_cache.mean().item(),
)
attn_out_C8_zp_prefix_cache_np = attn_out_C8_zp_prefix_cache.astype("float32").numpy()
attn_out_C8_zp_np = attn_out_C8_zp[hit_prefix_len:].astype("float32").numpy()
assert np.allclose(
attn_out_C8_zp_prefix_cache_np, attn_out_C8_zp_np, rtol=1e-1, atol=1e-2
), f"C8 per channle zp prefix cache != No prefix cache,\n attn_out_C8_zp[hit_prefix_len:]: {attn_out_C8_zp_np},\nattn_out_C8_zp_prefix_cache: {attn_out_C8_zp_prefix_cache_np}"

View File

@@ -0,0 +1,137 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import fused_rms_norm_xpu
# from paddle.incubate.nn.functional import fused_rms_norm
def find_max_diff(arr1, arr2):
"""找出两个数组元素差值的最大值及其索引
返回:
max_diff (float): 最大绝对值差
index (tuple): 最大值的位置索引
actual_diff (float): 实际差值(带符号)
"""
diff = arr1 - arr2
abs_diff = np.abs(diff)
flat_idx = np.argmax(abs_diff)
idx = np.unravel_index(flat_idx, arr1.shape)
return abs_diff[idx], idx, diff[idx], arr1[idx], arr2[idx]
def naive_rmsnorm(
x,
gamma,
beta=None,
epsilon=1e-6,
begin_norm_axis=1,
bias=None,
residual=None,
):
residual_out = None
if bias is not None:
x = x + bias
if residual is not None:
x = x + residual
residual_out = x
variance = (x * x).mean(axis=-1)
out = np.expand_dims(1.0 / np.sqrt(variance + epsilon), axis=-1) * x
out = out * gamma
if beta is not None:
out = out + beta
return out, residual_out
def run_and_compare(x_in, residual, bias, norm_weight):
x_in_pd = paddle.to_tensor(x_in).astype(data_type)
residual_pd = None
if residual is not None:
residual_pd = paddle.to_tensor(residual).astype(data_type)
bias_pd = paddle.to_tensor(bias).astype(data_type)
norm_weight_pd = paddle.to_tensor(norm_weight).astype(data_type)
# norm_bias_pd = paddle.to_tensor(norm_bias).astype(data_type)
out_np, residual_out_np = naive_rmsnorm(x_in, norm_weight, None, epsilon, begin_norm_axis, bias, residual)
out_pd, residual_out_pd = fused_rms_norm_xpu(
x_in_pd,
bias_pd,
residual_pd,
norm_weight_pd,
None, # norm_bias_pd,
epsilon,
begin_norm_axis,
-1,
0,
0,
0,
)
"""
out_pd1, residual_out_pd1 = fused_rms_norm(
x_in_pd,
norm_weight=norm_weight_pd,
norm_bias=norm_bias_pd,
epsilon=epsilon,
begin_norm_axis=1,
bias=bias_pd,
residual=residual_pd,
quant_scale=-1,
quant_round_type=0,
quant_max_bound=0,
quant_min_bound=0,
)
"""
abs_diff, idx, diff, val1, val2 = find_max_diff(out_np, out_pd.astype("float32").numpy())
print(f"out compare: abs_diff={abs_diff}, index={idx}, diff={diff}, {val1} vs {val2}")
assert np.allclose(out_np, out_pd.astype("float32").numpy(), rtol=1e-5, atol=1e-5)
if residual is not None:
abs_diff, idx, diff, val1, val2 = find_max_diff(residual_out_np, residual_out_pd.astype("float32").numpy())
print(f"residual_out compare: abs_diff={abs_diff}, index={idx}, diff={diff}, {val1} vs {val2}")
assert np.allclose(
residual_out_np,
residual_out_pd.astype("float32").numpy(),
rtol=1e-5,
atol=1e-5,
)
if __name__ == "__main__":
seed = np.random.randint(0, 1e8)
print(f"numpy random seed is {seed}")
np.random.seed(seed)
m = 7
n = 8192
epsilon = 1e-5
begin_norm_axis = 1
data_type = "float32"
x_in = (np.random.random([m, n]) - 0.5).astype("float32")
residual = (np.random.random([m, n]) - 0.5).astype("float32")
bias = (np.random.random([n]) - 0.5).astype("float32")
norm_weight = (np.random.random([n]) - 0.5).astype("float32")
# norm_bias = np.zeros([n]).astype("float32")
# norm_bias = (np.random.random([n]) - 0.5).astype("float32")
x_in_pd = paddle.to_tensor(x_in).astype(data_type)
residual_pd = paddle.to_tensor(residual).astype(data_type)
bias_pd = paddle.to_tensor(bias).astype(data_type)
norm_weight_pd = paddle.to_tensor(norm_weight).astype(data_type)
# norm_bias_pd = paddle.to_tensor(norm_bias).astype(data_type)
run_and_compare(x_in, residual, bias, norm_weight)
run_and_compare(x_in, None, bias, norm_weight)

View File

@@ -0,0 +1,95 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from fastdeploy.model_executor.ops.xpu import get_infer_param
seq_lens_encoder = paddle.to_tensor([100, 0, 0, 0, 300], dtype="int32")
seq_lens_decoder = paddle.to_tensor([0, 5, 0, 25, 64], dtype="int32")
seq_lens_this_time = paddle.to_tensor([100, 1, 0, 1, 300], dtype="int32")
block_table = paddle.arange(0, 40, dtype="int32")
block_table = block_table.reshape((5, 8))
(
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64
) # block_size
print("block_table", block_table)
print("encoder_batch_map", encoder_batch_map) # [0, 4, 0, 0, 0]
print("decoder_batch_map", decoder_batch_map) # [1, 3, 0, 0, 0]
print("encoder_batch_idx", encoder_batch_idx) # [0, 3, 0, 0, 0]
print("decoder_batch_idx", decoder_batch_idx) # [1, 2, 0, 0, 0]
print("encoder_seq_lod", encoder_seq_lod) # [0, 100, 400 ,0 ,0 ,0]
print("decoder_seq_lod", decoder_seq_lod) # [0, 1, 2 ,0 ,0 ,0]
print("encoder_kv_lod", encoder_kv_lod) # [0, 100, 464, 0, 0, 0]
print("prefix_len", prefix_len) # [0, 64, 0, 0, 0]
print("decoder_context_len", decoder_context_len) # [6, 26, 0, 0, 0]
print("decoder_context_len_cache", decoder_context_len_cache) # [5, 25, 0, 0, 0]
print("prefix_block_tables", prefix_block_tables)
print("encoder_batch_map_cpu", encoder_batch_map_cpu) # [0, 4, 0, 0, 0]
print("decoder_batch_map_cpu", decoder_batch_map_cpu) # [1, 3, 0, 0, 0]
print("encoder_batch_idx_cpu", encoder_batch_idx_cpu) # [0, 3, 0, 0, 0]
print("decoder_batch_idx_cpu", decoder_batch_idx_cpu) # [1, 2, 0, 0, 0]
print("encoder_seq_lod_cpu", encoder_seq_lod_cpu) # [0, 100, 400 ,0 ,0 ,0]
print("decoder_seq_lod_cpu", decoder_seq_lod_cpu) # [0, 1, 2 ,0 ,0 ,0]
print("encoder_kv_lod_cpu", encoder_kv_lod_cpu) # [0, 100, 464, 0, 0, 0]
print("prefix_len_cpu", prefix_len_cpu) # [0, 64, 0, 0, 0]
print("decoder_context_len_cpu", decoder_context_len_cpu) # [6, 26, 0, 0, 0]
print("decoder_context_len_cache_cpu", decoder_context_len_cache_cpu) # [5, 25, 0, 0, 0]
print(
"len_info_cpu", len_info_cpu
) # {enc_batch, dec_batch, total_enc_len, max_seq_len, max_kv_len, prefix_block_num_per_seq} = [2, 2, 400, 300, 364, 6]
"""
block_table Tensor(shape=[5, 8], dtype=int32, place=Place(xpu:0), stop_gradient=True,
[[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ],
[8 , 9 , 10, 11, 12, 13, 14, 15],
[16, 17, 18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29, 30, 31],
[32, 33, 34, 35, 36, 37, 38, 39]])
prefix_block_tables Tensor(shape=[5, 8], dtype=int32, place=Place(xpu:0), stop_gradient=True,
[[ 0, 1, -1, -1, -1, -1, 32, 33],
[34, 35, 36, 37, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1, -1]])
The prefix_block_tables tensor has the same size as block_table to avoid problems with the InferShape of prefix_block_tables.
However, only [block_bs, prefix_block_num_per_seq] of it is actually used, where prefix_block_num_per_seq = ceil(max_kv_len / block_size).
Therefore, do not rely on the tensor shape of prefix_block_tables; derive the effective shape from block_table.dims[0] and len_info_cpu[-1].
"""

View File

@@ -0,0 +1,93 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import ep_moe_expert_combine
np.random.seed(2025)
def np_softmax(x, axis=-1):
x_max = np.max(x, axis=axis, keepdims=True)
x_exp = np.exp(x - x_max)
return x_exp / np.sum(x_exp, axis=axis, keepdims=True)
def create_moe_index(token_num, moe_topk, expand_token_num):
total_positions = token_num * moe_topk
positions = np.random.choice(total_positions, size=expand_token_num, replace=False)
rows = positions // moe_topk
cols = positions % moe_topk
values = np.random.permutation(expand_token_num)
# moe_index mimics the output of moe_ep_dispatch:
# each value is the row in ffn_out for the corresponding (token, expert) pair; -1 means invalid
moe_index = np.full((token_num, moe_topk), -1)
for i in range(expand_token_num):
moe_index[rows[i], cols[i]] = values[i]
return moe_index
# 1) preparation
token_num = 10
moe_topk = 8
hidden_dim = 128
expand_token_num = 30
ffn_out = np.random.random((expand_token_num, hidden_dim))
moe_index = create_moe_index(token_num, moe_topk, expand_token_num)
moe_weights = np.random.random((token_num, moe_topk))
moe_weights = np_softmax(moe_weights)
moe_weights[moe_index == -1] = -1
print(f"ffn_out:\n{ffn_out}")
print(f"moe_index:\n{moe_index}")
print(f"moe_weights:\n{moe_weights}")
# 2) np calculation
combined_out_np = np.zeros((token_num, hidden_dim))
for token_idx, item in enumerate(moe_index):
for topk_idx, ffn_out_row in enumerate(item):
if ffn_out_row == -1:
continue
combined_out_np[token_idx] += ffn_out[ffn_out_row] * moe_weights[token_idx][topk_idx]
print(f"combined_out_np:\n{combined_out_np}")
# 3) xpu calculation
dtype = "bfloat16"
ffn_out_pd = paddle.to_tensor(ffn_out, dtype=dtype)
moe_index_pd = paddle.to_tensor(moe_index, dtype="int32")
moe_weights_pd = paddle.to_tensor(moe_weights, dtype=dtype)
combined_out_pd = ep_moe_expert_combine(
ffn_out_pd,
moe_index_pd,
moe_weights_pd,
moe_index_pd.shape[0],
ffn_out_pd.shape[0],
ffn_out_pd.shape[1],
moe_index_pd.shape[1],
)
# comparison
# print("moe_index:\n", moe_index)
# print("moe_weights:\n", moe_weights)
# print("combined_out_np:\n", combined_out_np)
# print("combined_out_pd:\n", combined_out_pd)
combined_out_pd = combined_out_pd.astype("float32").numpy()
avg_diff = np.sum(np.abs(combined_out_pd - combined_out_np)) / combined_out_pd.size
assert (
avg_diff < 2e-3
), f"avg_diff: {avg_diff}\n combined_out_np:\n{combined_out_np}\n combined_out_pd:\n{combined_out_pd}\n"
print(f"[Passed] avg_diff: {avg_diff}")

View File

@@ -0,0 +1,136 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import ep_moe_expert_dispatch
np.random.seed(2025)
def ep_moe_expert_dispatch_cpu(input, topk_ids, topk_weights, token_nums_per_expert, token_nums_this_rank):
m, n = input.shape[0], input.shape[1]
topk = topk_ids.shape[1]
expert_num = len(token_nums_per_expert)
expert_per_rank = expert_num
permute_input = np.full((token_nums_this_rank, n), 0.0, dtype=np.float32)
permute_indices_per_token = np.full((m, topk), -1, dtype=np.int32)
recv_num_tokens_per_expert_list_cumsum = np.full(expert_num + 1, 0, dtype=np.int32)
dst_indices = np.full((expert_num, m), -1, dtype=np.int32)
cumsum_idx = np.full(expert_num, 0, dtype=np.int32)
offset = 0
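    # reference dispatch: walk experts in order and copy each routed token's row
    # contiguously into permute_input, recording its destination offset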
for expert_id in range(expert_per_rank):
for token_id in range(m):
for k in range(topk):
cur_index = topk_ids[token_id, k]
if cur_index == expert_id:
permute_indices_per_token[token_id, k] = offset
permute_input[offset, :] = input[token_id, :]
offset += 1
recv_num_tokens_per_expert_list_cumsum[expert_id + 1] = offset
return (
permute_input,
permute_indices_per_token,
recv_num_tokens_per_expert_list_cumsum,
topk_weights,
dst_indices,
cumsum_idx,
)
def create_moe_index(token_num, topk, expert_num):
topk_ids = np.full((token_num, topk), -1, dtype=np.int32)
    token_nums_per_expert = np.full(expert_num, 0, dtype=np.int32)
token_all_num = 0
for i in range(topk_ids.shape[0]):
pos = np.random.choice(np.arange(topk), np.random.randint(1, topk + 1), replace=False)
token_all_num += len(pos)
for j in pos:
topk_ids[i, j] = np.random.choice(expert_num, replace=False)
token_nums_per_expert[topk_ids[i, j]] += 1
return token_all_num, topk_ids, list(token_nums_per_expert)
# 1) preparation
token_num = 7
expert_num_per_rank = 4
topk = 8
hidden_dim = 8192
input = np.random.random((token_num, hidden_dim))
token_nums_this_rank, topk_ids, token_nums_per_expert = create_moe_index(token_num, topk, expert_num_per_rank)
topk_weights = np.random.random((token_num, topk))
print(f"input:\n{input}")
print(f"token_nums_this_rank:\n{token_nums_this_rank}")
print(f"topk_ids:\n{topk_ids}")
print(f"token_nums_per_expert:\n{token_nums_per_expert}")
print(f"topk_weights:\n{topk_weights}")
dtype = "bfloat16"
input_xpu = paddle.to_tensor(input, dtype=dtype)
topk_ids_xpu = paddle.to_tensor(topk_ids)
topk_weights_xpu = paddle.to_tensor(topk_weights)
# 2) cpu calculation
(
permute_input,
permute_indices_per_token,
recv_num_tokens_per_expert_list_cumsum,
dst_weights,
dst_indices,
cumsum_idx,
) = ep_moe_expert_dispatch_cpu(input, topk_ids, topk_weights, token_nums_per_expert, token_nums_this_rank)
print(f"permute_input:\n{permute_input}")
print(f"permute_indices_per_token:\n{permute_indices_per_token}")
print(f"recv_num_tokens_per_expert_list_cumsum:\n{recv_num_tokens_per_expert_list_cumsum}")
print(f"dst_weights:\n{dst_weights}")
print(f"dst_indices:\n{dst_indices}")
print(f"cumsum_idx:\n{cumsum_idx}")
# 3) xpu calculation
(
permute_input_xpu,
permute_indices_per_token_xpu,
recv_num_tokens_per_expert_list_cumsum_xpu,
dst_weights_xpu,
expand_input_scales,
) = ep_moe_expert_dispatch(
input_xpu,
topk_ids_xpu,
topk_weights_xpu,
None,
token_nums_per_expert,
token_nums_this_rank,
"weight_only_int8",
)
# comparison
permute_input_xpu = permute_input_xpu.astype("float32").numpy()
permute_indices_per_token_xpu = permute_indices_per_token_xpu.numpy()
recv_num_tokens_per_expert_list_cumsum_xpu = recv_num_tokens_per_expert_list_cumsum_xpu.numpy()
diff = np.sum(np.abs(permute_input - permute_input_xpu)) / permute_input.size
assert diff < 1e-2, f"diff: {diff}\n permute_input:\n {permute_input}\n permute_input_xpu:\n {permute_input_xpu}\n"
assert (
permute_indices_per_token == permute_indices_per_token_xpu
).all(), f"permute_indices_per_token:\n {permute_indices_per_token}\n permute_indices_per_token_xpu:\n {permute_indices_per_token_xpu}\n"
assert (
recv_num_tokens_per_expert_list_cumsum == recv_num_tokens_per_expert_list_cumsum_xpu
).all(), f"recv_num_tokens_per_expert_list_cumsum:\n {recv_num_tokens_per_expert_list_cumsum}\n recv_num_tokens_per_expert_list_cumsum_xpu:\n {recv_num_tokens_per_expert_list_cumsum_xpu}\n"
print("ep_moe_expert_dispatch test success!")

View File

@@ -0,0 +1,295 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import moe_expert_ffn
np.random.seed(2025)
token_num = 7
expert_num = 64
hidden_dim = 8192
ffn_inter_dim = 7168
ffn_outer_dim = ffn_inter_dim // 2
num_max_dispatch_tokens_per_rank = 128
num_rank = 8
expert_num_per_rank = expert_num // num_rank
used_in_ep_low_latency = True
hadamard_blocksize = 512
ffn_in = (np.random.random([token_num, hidden_dim]) - 0.5).astype("float32")
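# build a random, sorted lod of token offsets per expert on this rank; consecutive
# differences below give token_num_per_expert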
token_num_lod = np.full([expert_num_per_rank + 1], 0, "int32")
token_num_lod[-1] = token_num
token_num_lod[1:-1] = np.random.randint(0, token_num, [expert_num_per_rank - 1])
token_num_lod = np.sort(token_num_lod)
token_num_per_expert = token_num_lod[1:] - token_num_lod[:-1]
ffn1_w = (np.random.random([expert_num_per_rank, ffn_inter_dim, hidden_dim]) - 0.5).astype("float32")
ffn2_w = (np.random.random([expert_num_per_rank, hidden_dim, ffn_outer_dim]) - 0.5).astype("float32")
ffn2_shift = (np.random.random([1, ffn_outer_dim]) - 0.5).astype("float32")
ffn2_smooth = (np.random.random([1, ffn_outer_dim]) - 0.5).astype("float32")
if used_in_ep_low_latency:
ffn_in_tmp = ffn_in
ffn_in = np.zeros(
[
expert_num_per_rank,
num_max_dispatch_tokens_per_rank * num_rank,
hidden_dim,
],
"float32",
)
for i in range(expert_num_per_rank):
ffn_in[i][: token_num_per_expert[i]] = ffn_in_tmp[token_num_lod[i] : token_num_lod[i + 1]]
token_num_info = token_num_per_expert
else:
token_num_info = token_num_lod
print(f"ffn_in: {ffn_in}")
print(f"token_num_lod: {token_num_lod}")
print(f"token_num_per_expert: {token_num_per_expert}")
print(f"ffn1_w: {ffn1_w}")
print(f"ffn2_w: {ffn2_w}")
def clip_and_round(x, quant_max_bound=127):
return np.clip(np.around(x), -quant_max_bound, quant_max_bound).astype("int8")
def weight_quant_wint8(w_fp32):
w_max = np.max(np.abs(w_fp32), axis=-1, keepdims=True)
w_int8 = clip_and_round(w_fp32 / w_max * 127.0)
return w_int8, w_max.reshape([-1])
def weight_quant_wint4(w_fp32):
w_max = np.max(np.abs(w_fp32), axis=-1, keepdims=True)
w_int4 = clip_and_round(w_fp32 / w_max * 7.0, 7)
w_int4 = (w_int4[:, :, 1::2] & 0xF) << 4 | (w_int4[:, :, ::2] & 0xF) # pack int4
return w_int4, w_max.reshape([-1])
def weight_quant(w_fp32, algo="weight_only_int8"):
if algo == "weight_only_int8":
return weight_quant_wint8(w_fp32)
elif algo == "weight_only_int4":
return weight_quant_wint4(w_fp32)
else:
return None, None
quant_method = "weight_only_int4"
print(f"quant_method={quant_method}, used_in_ep_low_latency={used_in_ep_low_latency}")
ffn1_quant_w, ffn1_w_scale = weight_quant(ffn1_w, quant_method)
ffn2_quant_w, ffn2_w_scale = weight_quant(ffn2_w, quant_method)
print(f"ffn1_w {ffn1_w.shape}: {ffn1_w}")
print(f"ffn2_w {ffn2_w.shape}: {ffn2_w}")
print(f"ffn1_quant_w {ffn1_quant_w.shape}: {ffn1_quant_w}")
print(f"ffn1_w_scale {ffn1_w_scale.shape}: {ffn1_w_scale}")
print(f"ffn2_quant_w {ffn2_quant_w.shape}: {ffn2_quant_w}")
print(f"ffn2_w_scale {ffn2_w_scale.shape}: {ffn2_w_scale}")
def weight_dequant_wint8(w_int, w_scale):
w_shape = w_int.shape
w_scale_new_shape = list(w_shape)
w_scale_new_shape[-1] = 1
w_scale_new = w_scale.reshape(w_scale_new_shape)
w_fp32 = w_int.astype("float32") / 127.0 * w_scale_new
return w_fp32
def weight_dequant_wint4(w_int, w_scale):
w_shape = w_int.shape
w_scale_new_shape = list(w_shape)
w_scale_new_shape[-1] = 1
# w_scale_new_shape[-2] = w_scale_new_shape[-2] * 2
w_scale_new = w_scale.reshape(w_scale_new_shape)
w_new_shape = list(w_shape)
w_new_shape[-1] = w_new_shape[-1] * 2
w_int8 = np.zeros(w_new_shape, dtype=np.int8)
w_int8[:, :, ::2] = w_int & 0xF
w_int8[:, :, 1::2] = (w_int >> 4) & 0xF
w_int8 = np.where(w_int8 >= 8, w_int8 - 16, w_int8)
w_fp32 = w_int8.astype("float32") / 7.0 * w_scale_new
return w_fp32
def weight_dequant(w_int, w_scale, algo="weight_only_int8"):
if algo == "weight_only_int8":
return weight_dequant_wint8(w_int, w_scale)
elif algo == "weight_only_int4":
return weight_dequant_wint4(w_int, w_scale)
else:
return None, None
def fwt(a):
"""
快速 Walsh-Hadamard 变换(正向变换)
:param a: 输入列表长度必须是2的幂
:return: 变换后的列表
"""
n = len(a)
    # check that the input length is a power of two
    if n == 0 or n & (n - 1) != 0:
        raise ValueError("input length must be a power of two")
    # copy the input so the original data is not modified
    a = a.copy()
h = 1
while h < n:
for i in range(0, n, 2 * h):
for j in range(i, i + h):
x = a[j]
y = a[j + h]
a[j] = x + y
a[j + h] = x - y
        h <<= 1  # equivalent to h *= 2
return a
def hadamard(_x, block_size):
x = np.copy(_x).reshape((-1, _x.shape[-1]))
if block_size == -1:
return x
m = 1
n = x.shape[-1]
for i in range(len(x.shape) - 1):
m = m * x.shape[i]
for i in range(m):
for j in range(0, n, block_size):
subx = x[i][j : j + block_size]
x[i][j : j + block_size] = fwt(subx)
return x.reshape(_x.shape)
# print(f"ffn1_w {ffn1_w.shape}: {ffn1_w}")
# ffn1_quant_w8, ffn1_w8_scale = weight_quant(ffn1_w, "weight_only_int8")
# ffn1_quant_w4, ffn1_w4_scale = weight_quant(ffn1_w, "weight_only_int4")
# print(f"ffn1_quant_w8 {ffn1_quant_w8.shape}: {ffn1_quant_w8}")
# print(f"ffn1_w8_scale {ffn1_w8_scale.shape}: {ffn1_w8_scale}")
# print(f"ffn1_quant_w4 {ffn1_quant_w4.shape}: {ffn1_quant_w4}")
# print(f"ffn1_w4_scale {ffn1_w4_scale.shape}: {ffn1_w4_scale}")
# ffn1_w8_dq = weight_dequant(ffn1_quant_w8, ffn1_w8_scale, "weight_only_int8")
# ffn1_w4_dq = weight_dequant(ffn1_quant_w4, ffn1_w4_scale, "weight_only_int4")
# print(f"ffn1_w8_dq {ffn1_w8_dq.shape}: {ffn1_w8_dq}")
# print(f"ffn1_w4_dq {ffn1_w4_dq.shape}: {ffn1_w4_dq}")
def batch_matmul(x, token_num_info, w, w_scale, algo):
w_fp32 = weight_dequant(w, w_scale, algo)
print(f"x {x.shape}, w {w_fp32.shape}")
out_hidden_dim = w_fp32.shape[1]
if not used_in_ep_low_latency:
y = np.zeros([x.shape[0], out_hidden_dim], "float32")
token_num_lod = token_num_info
for i in range(expert_num_per_rank):
start_i = token_num_lod[i]
end_i = token_num_lod[i + 1]
subx = x[start_i:end_i]
subw = w_fp32[i : i + 1].transpose([0, 2, 1])
y[start_i:end_i] = np.matmul(subx, subw)
else:
y = np.zeros(
[
expert_num_per_rank,
num_max_dispatch_tokens_per_rank,
out_hidden_dim,
],
"float32",
)
token_num_per_expert = token_num_info
for i in range(expert_num_per_rank):
subx = x[i][: token_num_per_expert[i]]
subw = w_fp32[i : i + 1].transpose([0, 2, 1])
y[i][: token_num_per_expert[i]] = np.matmul(subx, subw)
return y
def swiglu(x):
new_shape = list(x.shape)
new_shape[-1] //= 2
x1 = np.copy(x[..., : new_shape[-1]])
x2 = np.copy(x[..., new_shape[-1] :])
y = x1 * 1.0 / (1.0 + np.exp(-x1)) * x2
return y
ref_ffn1_out = batch_matmul(ffn_in, token_num_info, ffn1_quant_w, ffn1_w_scale, quant_method)
print(f"ref_ffn1_out {ref_ffn1_out.shape}: {ref_ffn1_out}")
ref_swiglu_out = swiglu(ref_ffn1_out)
print(f"ref_swiglu_out {ref_swiglu_out.shape}: {ref_swiglu_out}")
ref_swiglu_out = (ref_swiglu_out + ffn2_shift) * ffn2_smooth
ref_hadamard_out = hadamard(ref_swiglu_out, hadamard_blocksize)
ref_ffn2_out = batch_matmul(
ref_hadamard_out,
token_num_info,
ffn2_quant_w,
ffn2_w_scale,
quant_method,
)
ffn_in_tensor = paddle.to_tensor(ffn_in).astype("bfloat16")
token_num_info_tensor = paddle.to_tensor(token_num_info)
ffn1_quant_w_tensor = paddle.to_tensor(ffn1_quant_w)
ffn2_quant_w_tensor = paddle.to_tensor(ffn2_quant_w)
ffn1_w_scale_tensor = paddle.to_tensor(ffn1_w_scale)
ffn2_w_scale_tensor = paddle.to_tensor(ffn2_w_scale)
ffn2_shift_tensor = paddle.to_tensor(ffn2_shift).astype("bfloat16")
ffn2_smooth_tensor = paddle.to_tensor(ffn2_smooth).astype("bfloat16")
ffn2_out = moe_expert_ffn(
ffn_in_tensor,
token_num_info_tensor,
ffn1_quant_w_tensor,
ffn2_quant_w_tensor,
None, # ffn1_bias
None, # ffn2_bias
None, # ffn1_act_scale
None, # ffn2_act_scale
ffn1_w_scale_tensor,
ffn2_w_scale_tensor,
ffn2_shift_tensor,
ffn2_smooth_tensor,
quant_method,
hadamard_blocksize,
token_num,
)
ffn2_out = ffn2_out.astype("float32").numpy()
print(f"ffn2_out: {ffn2_out}")
print(f"ref_ffn2_out: {ref_ffn2_out}")
if not used_in_ep_low_latency:
diff = np.sum(np.abs(ffn2_out - ref_ffn2_out)) / np.sum(np.abs(ffn2_out))
print(f"diff: {diff}")
assert diff < 0.01, f"diff: {diff}\nffn2_out:\n{ffn2_out}\nref_ffn2_out:\n{ref_ffn2_out}\n"
else:
diff_all = 0
for i in range(expert_num_per_rank):
token_num_this_expert = token_num_per_expert[i]
if token_num_this_expert == 0:
continue
tmp_ffn2_out = ffn2_out[i][:token_num_this_expert]
tmp_ref_ffn2_out = ref_ffn2_out[i][:token_num_this_expert]
diff = np.sum(np.abs(tmp_ffn2_out - tmp_ref_ffn2_out)) / np.sum(np.abs(tmp_ffn2_out))
print(f"diff: {diff}")
print(f"{i}, tmp_ffn2_out: {tmp_ffn2_out}")
print(f"{i}, tmp_ref_ffn2_out: {tmp_ref_ffn2_out}")
diff_all += diff
diff_avg = diff_all / expert_num_per_rank
print(f"diff_avg: {diff_avg}")
assert diff_avg < 0.03, f"diff_avg: {diff_avg}\nffn2_out:\n{ffn2_out}\nref_ffn2_out:\n{ref_ffn2_out}\n"

View File

@@ -0,0 +1,200 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import moe_redundant_topk_select
def ref_moe_topk_select(gating_logits, bias, moe_topk, apply_norm_weight):
assert apply_norm_weight is True
def _softmax(x):
axis = 1
x_max = np.max(x, axis=axis, keepdims=True)
e_x = np.exp(x - x_max)
return e_x / np.sum(e_x, axis=axis, keepdims=True)
softmax_logits = _softmax(gating_logits)
softmax_logits_with_bias = np.copy(softmax_logits)
if bias is not None:
softmax_logits_with_bias += bias.reshape([1, -1])
sorted_indices = np.argsort(softmax_logits_with_bias, axis=1, kind="stable")[:, ::-1]
topk_ids = sorted_indices[:, :moe_topk]
topk_weights = np.take_along_axis(softmax_logits, topk_ids, axis=1)
topk_weights = topk_weights[:, :moe_topk]
topk_weights /= np.sum(topk_weights, axis=1, keepdims=True)
return topk_ids, topk_weights
def generate_expert_in_rank_num(num_values, extra_num):
if num_values <= 0:
return np.array([])
    # generate all random indices at once
    indices = np.random.randint(0, num_values, extra_num)
    # count frequencies with bincount (vectorized)
    bin_counts = np.bincount(indices, minlength=num_values)
    # result = base value of 1 plus the extra assignments
    return 1 + bin_counts
def generate_expert_id_to_ep_rank(expert_in_rank_num_list, num_rank, redundant_num_plus_one):
num_expert = expert_in_rank_num_list.size
redundant_num = redundant_num_plus_one - 1
# 生成随机排名ID (一次性生成)
rank_idx = np.random.randint(0, num_rank, num_expert)
# Initialize the result matrix (-1 means unassigned)
expert_id_to_rank_id = np.full((num_expert, redundant_num + 1), -1, dtype=int)
# Initial assignment - each expert gets one base rank id
expert_ids = np.arange(num_expert)
expert_id_to_rank_id[expert_ids, 0] = rank_idx
if redundant_num > 0:
positions = np.ones(num_expert, dtype=int)
for expert_id in range(expert_in_rank_num_list.size):
repeat_num = expert_in_rank_num_list[expert_id]
while repeat_num > 1:
rank_idx = np.random.randint(0, num_rank)
expert_id_to_rank_id[expert_id][positions[expert_id]] = rank_idx
positions[expert_id] += 1
repeat_num -= 1
return expert_id_to_rank_id
def generate_rank_to_id(id_to_rank, rank_num):
max_rank = -1
for ranks in id_to_rank:
if ranks:
current_max = max(ranks)
if current_max > max_rank:
max_rank = current_max
if max_rank < 0 or max_rank >= rank_num:
return []
rank_to_id = [[] for _ in range(rank_num)]
for id_val, ranks in enumerate(id_to_rank):
for r in ranks:
if r < 0:  # skip negative (unassigned) values
continue
if r < len(rank_to_id):  # make sure the index is within a valid range
rank_to_id[r].append(id_val)
return rank_to_id
def my_sort(key_arr, val_arr):
if key_arr.shape != val_arr.shape:
return None, None
# Process row by row instead of converting the whole array at once
sorted_keys = np.empty_like(key_arr)
sorted_vals = np.empty_like(val_arr)
for i in range(key_arr.shape[0]):
keys = key_arr[i]
vals = val_arr[i]
idx = np.lexsort((keys, vals))
sorted_keys[i] = keys[idx]
sorted_vals[i] = vals[idx]
return sorted_keys, sorted_vals
if __name__ == "__main__":
seed = np.random.randint(1, 1e9)
print(f"numpy random seed={seed}")
np.random.seed(seed)
rank_num = 8
token_num = 1215
expert_num = 256
moe_topk = 8
redundant_ep_rank_num_plus_one = 1 # no redundant experts
apply_norm_weight = True
enable_softmax_top_k_fused = True
gating_logits = np.random.random([token_num, expert_num]).astype("float32")
bias = np.random.random([expert_num]).astype("float32")
expert_in_rank_num_list = generate_expert_in_rank_num(expert_num, redundant_ep_rank_num_plus_one - 1)
print(f"expert_in_rank_num_list={expert_in_rank_num_list}")
expert_id_to_ep_rank_array = generate_expert_id_to_ep_rank(
expert_in_rank_num_list, rank_num, redundant_ep_rank_num_plus_one
)
tokens_per_expert_stats_list = np.random.randint(0, 20, size=(expert_num))
print(f"expert_id_to_ep_rank_array={expert_id_to_ep_rank_array}")
print(f"tokens_per_expert_stats_list={tokens_per_expert_stats_list}")
# ref_topk_ids, ref_topk_weights = ref_moe_topk_select(
# gating_logits, bias, moe_topk, apply_norm_weight
# )
gating_logits = paddle.to_tensor(gating_logits).astype("float32")
expert_id_to_ep_rank_array = paddle.to_tensor(expert_id_to_ep_rank_array).astype("int32")
expert_in_rank_num_list = paddle.to_tensor(expert_in_rank_num_list).astype("int32")
tokens_per_expert_stats_list = paddle.to_tensor(tokens_per_expert_stats_list).astype("int32")
if bias is not None:
bias = paddle.to_tensor(bias).astype("float32")
gating_logits_ref = gating_logits.cpu()
expert_id_to_ep_rank_array_ref = expert_id_to_ep_rank_array.cpu()
expert_in_rank_num_list_ref = expert_in_rank_num_list.cpu()
tokens_per_expert_stats_list_ref = tokens_per_expert_stats_list.cpu()
bias_ref = None
if bias is not None:
bias_ref = bias.cpu()
topk_ids, topk_weights = moe_redundant_topk_select(
gating_logits,
expert_id_to_ep_rank_array,
expert_in_rank_num_list,
tokens_per_expert_stats_list,
bias,
moe_topk,
apply_norm_weight,
enable_softmax_top_k_fused,
redundant_ep_rank_num_plus_one,
)
topk_ids_ref, topk_weights_ref = moe_redundant_topk_select(
gating_logits_ref,
expert_id_to_ep_rank_array_ref,
expert_in_rank_num_list_ref,
tokens_per_expert_stats_list_ref,
bias_ref,
moe_topk,
apply_norm_weight,
enable_softmax_top_k_fused,
redundant_ep_rank_num_plus_one,
)
topk_ids_np, topk_weights_np, tokens_per_expert_stats_list_np = (
topk_ids.numpy(),
topk_weights.numpy(),
tokens_per_expert_stats_list.numpy(),
)
topk_ids_ref, topk_weights_ref, tokens_per_expert_stats_list_ref = (
topk_ids_ref.numpy(),
topk_weights_ref.numpy(),
tokens_per_expert_stats_list_ref.numpy(),
)
sorted_topk_ids, sorted_topk_weights = my_sort(topk_ids_np, topk_weights_np)
sorted_topk_ids_ref, sorted_topk_weights_ref = my_sort(topk_ids_ref, topk_weights_ref)
assert np.array_equal(
tokens_per_expert_stats_list_np, tokens_per_expert_stats_list_ref
), f"\ntokens_per_expert_stats_list:\n{tokens_per_expert_stats_list.numpy()}\ntokens_per_expert_stats_list_ref:\n{tokens_per_expert_stats_list_ref}"
assert np.array_equal(
sorted_topk_ids, sorted_topk_ids_ref
), f"\ntopk_ids:\n{topk_ids.numpy()}\ntopk_ids_ref:\n{topk_ids_ref}"
assert np.allclose(
sorted_topk_weights, sorted_topk_weights_ref
), f"\ntopk_weights:\n{topk_weights.numpy()}\ntopk_weights_ref:\n{topk_weights_ref}"
print("Passed all tests.")

View File

@@ -0,0 +1,67 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import f_moe_topk_select
np.random.seed(2025)
token_num = 15
expert_num = 256
moe_topk = 8
apply_norm_weight = True
gating_logits = np.random.random([token_num, expert_num]).astype("float32")
bias = np.random.random([expert_num]).astype("float32")
def ref_moe_topk_select(gating_logits, bias, moe_topk, apply_norm_weight):
assert apply_norm_weight is True
def _softmax(x):
axis = 1
x_max = np.max(x, axis=axis, keepdims=True)
e_x = np.exp(x - x_max)
return e_x / np.sum(e_x, axis=axis, keepdims=True)
softmax_logits = _softmax(gating_logits)
softmax_logits_with_bias = np.copy(softmax_logits)
if bias is not None:
softmax_logits_with_bias += bias.reshape([1, -1])
sorted_indices = np.argsort(softmax_logits_with_bias, axis=1, kind="stable")[:, ::-1]
topk_ids = sorted_indices[:, :moe_topk]
topk_weights = np.take_along_axis(softmax_logits, topk_ids, axis=1)
topk_weights = topk_weights[:, :moe_topk]
topk_weights /= np.sum(topk_weights, axis=1, keepdims=True)
return topk_ids, topk_weights
ref_topk_ids, ref_topk_weights = ref_moe_topk_select(gating_logits, bias, moe_topk, apply_norm_weight)
gating_logits = paddle.to_tensor(gating_logits)
if bias is not None:
bias = paddle.to_tensor(bias)
topk_ids, topk_weights = f_moe_topk_select(gating_logits, bias, moe_topk, apply_norm_weight)
assert np.array_equal(
topk_ids.numpy(), ref_topk_ids
), f"\ntopk_ids:\n{topk_ids.numpy()}\nref_topk_ids:\n{ref_topk_ids}"
assert np.allclose(
topk_weights.numpy(), ref_topk_weights
), f"\ntopk_weights:\n{topk_weights.numpy()}\nref_topk_weights:\n{ref_topk_weights}"
print("Passed all tests.")

View File

@@ -0,0 +1,23 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import read_data_ipc
x = np.zeros([512, 8, 64, 128], dtype="float32")
x = paddle.to_tensor(x, place=paddle.CPUPlace())
read_data_ipc(x, "test_set_data_ipc")
print(x.numpy().flatten()[:100])

View File

@@ -0,0 +1,25 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import paddle
from fastdeploy.model_executor.ops.xpu import set_data_ipc
x = paddle.full(shape=[512, 8, 64, 128], fill_value=2, dtype="float32")
set_data_ipc(x, "test_set_data_ipc")
print("set_data_ipc done")
time.sleep(60)
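
The `set_data_ipc` script above and the `read_data_ipc` script earlier are meant to run as a producer/consumer pair against the same name, "test_set_data_ipc": the producer registers the buffer and sleeps so the consumer can attach. One possible driver, assuming the two snippets are saved as `test_set_data_ipc.py` and `test_read_data_ipc.py` (hypothetical file names, not taken from this commit):

```python
import subprocess
import time

# Hypothetical script names for the two snippets shown in this commit.
writer = subprocess.Popen(["python", "test_set_data_ipc.py"])
time.sleep(10)  # give the producer time to register the shared buffer
subprocess.run(["python", "test_read_data_ipc.py"], check=True)
writer.terminate()
```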

View File

@@ -0,0 +1,45 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import paddle
from fastdeploy.model_executor.ops.xpu import set_data_ipc, share_external_data
shape = [8, 128]
dtype = "bfloat16"
shm_name = "xpu_shm_tensor"
paddle.set_device("xpu:0")
if sys.argv[1] == "0":
print("set data ipc")
input_tensor = paddle.cast(paddle.rand(shape), dtype)
set_data_ipc(input_tensor, shm_name)
print(input_tensor)
time.sleep(120)
elif sys.argv[1] == "1":
print("test share_external_data")
tmp_input = paddle.empty([], dtype=dtype)
output = share_external_data(tmp_input, shm_name, shape, use_ipc=True)
print(output.shape)
print(output.cpu()) # use xpu_memcpy
else:
print("test share_external_data")
tmp_input = paddle.empty([], dtype=dtype)
output = share_external_data(tmp_input, shm_name, shape, use_ipc=False)
temp_output = output * 1 # avoid xpu_memcpy
print(temp_output)
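
This script is mode-switched by its first argument: "0" publishes the tensor under `xpu_shm_tensor` and keeps the process alive, "1" attaches through the IPC path (`use_ipc=True`), and any other value attaches without IPC. A possible way to exercise all three modes, assuming the snippet is saved as `test_share_external_data.py` (hypothetical name):

```python
import subprocess
import time

writer = subprocess.Popen(["python", "test_share_external_data.py", "0"])
time.sleep(10)  # wait for the writer to publish "xpu_shm_tensor"
subprocess.run(["python", "test_share_external_data.py", "1"], check=True)  # IPC attach
subprocess.run(["python", "test_share_external_data.py", "2"], check=True)  # non-IPC attach
writer.terminate()
```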

View File

@@ -0,0 +1,138 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import (
weight_only_linear_xpu as weight_only_linear,
)
np.random.seed(2025)
def np_clip_and_round(x, abs_max=127):
return np.clip(np.around(x), -abs_max, abs_max).astype("int8")
def np_quant_weight_int4(weight_np):
assert weight_np.dtype == np.float32  # n,k
weight = weight_np
# weight = np.transpose(weight_np, [1, 0]) # n,k
max_value = np.max(np.abs(weight), axis=1).reshape(-1, 1)  # n => n,1
quanted_weight = np_clip_and_round(weight / max_value * 7.0, 7) # n,k
quanted_weight = (quanted_weight[:, 1::2] & 0xF) << 4 | (quanted_weight[:, ::2] & 0xF) # pack int4, [n,k//2]
weight_scales = (max_value).astype(weight_np.dtype).reshape(-1)
return quanted_weight, weight_scales.astype(np.float32)
def np_quant_weight(weight_np, algo="weight_only_int8"):
assert weight_np.dtype == np.float32
if algo == "weight_only_int4":
return np_quant_weight_int4(weight_np)
weight = weight_np
# weight = np.transpose(weight_np, [1, 0])
max_value = np.max(np.abs(weight), axis=1).reshape(-1, 1)
quanted_weight = np_clip_and_round(weight / max_value * 127.0)
weight_scales = (max_value).astype(weight_np.dtype).reshape(-1)
return quanted_weight, weight_scales.astype(np.float32)
def int8_to_bin_np(value):
value_np = np.int8(value)
return np.binary_repr(value_np, width=8)
def int8_to_bin(value):
if not -128 <= value <= 127:
raise ValueError("int8 值必须在 -128 到 127 之间")
return format(value & 0xFF, "08b")  # '08b' means 8 binary digits, zero-padded
def weight_dequant_wint8(w_int, w_scale):
w_shape = w_int.shape
# print(f"w_shape={w_shape}")
w_scale_new_shape = list(w_shape)
w_scale_new_shape[-1] = 1
w_scale_new = w_scale.reshape(w_scale_new_shape)
w_fp32 = w_int.astype("float32") / 127.0 * w_scale_new
return w_fp32
def weight_dequant_wint4(w_int, w_scale):
w_shape = w_int.shape
w_scale_new_shape = list(w_shape)
w_scale_new_shape[-1] = 1
# w_scale_new_shape[-2] = w_scale_new_shape[-2] * 2
w_scale_new = w_scale.reshape(w_scale_new_shape)
w_new_shape = list(w_shape)
w_new_shape[-1] = w_new_shape[-1] * 2
w_int8 = np.zeros(w_new_shape, dtype=np.int8)
w_int8[..., ::2] = w_int & 0xF
w_int8[..., 1::2] = (w_int >> 4) & 0xF
w_int8 = np.where(w_int8 >= 8, w_int8 - 16, w_int8)
w_fp32 = w_int8.astype("float32") / 7.0 * w_scale_new
return w_fp32
def weight_dequant(w_int, w_scale, algo="weight_only_int8"):
if algo == "weight_only_int8":
return weight_dequant_wint8(w_int, w_scale)
elif algo == "weight_only_int4":
return weight_dequant_wint4(w_int, w_scale)
else:
return None, None
def batch_matmul(x, qw, wscale, algo, bias=None):
w_fp32 = weight_dequant(qw, wscale, algo)
# print(f"w_dequant={w_fp32}")
# print(f"x.shape={x.shape}, w.shape={w_fp32.shape}")
w_trans = np.transpose(w_fp32, [1, 0])
y = np.matmul(x, w_trans)
if bias is not None:
y = y + bias
return y
# 1) preparation
m, n, k = 64, 128, 256
algo = "weight_only_int8"
weight_dtype = "int8"
# m, n, k = 12, 14336, 8192
x_np = (np.random.random((m, k)).astype(np.float32) - 0.5) * 10
w_np = (np.random.random((n, k)).astype(np.float32) - 0.5) * 10
qw_np, wscale_np = np_quant_weight(w_np, algo)
# print(f"x_np={x_np}")
# print(f"w_np={w_np}")
# 2) np calculation
out_np = batch_matmul(x_np, qw_np, wscale_np, algo)
# 3) xpu calculation
x_pd = paddle.to_tensor(x_np).astype("bfloat16")
qw_pd = paddle.to_tensor(qw_np)
wscale_pd = paddle.to_tensor(wscale_np).astype("float32")
out_pd = weight_only_linear(x_pd, qw_pd, wscale_pd, None, weight_dtype, -1, -1)
print(f"out_pd:\n{out_pd}")
print(f"out_np:\n{out_np}")
# comparison
print(f"out_pd, mean={out_pd.mean()}, std={out_pd.std()}")
print(f"out_np, mean={out_np.mean()}, std={out_np.std()}")
sum_diff = np.sum(np.abs(out_pd.astype("float32").numpy() - out_np.astype("float32")))
print(f"sum_diff: {sum_diff}")
print(f"avg_diff: {sum_diff / (m * n)}")

View File

@@ -83,20 +83,20 @@ cd FastDeploy
### Download Kunlunxin Compilation Dependency
```bash
bash custom_ops/xpu_ops/src/download_dependencies.sh stable
bash custom_ops/xpu_ops/download_dependencies.sh stable
```
Alternatively, you can download the latest versions of XTDK and XVLLM (Not recommended)
```bash
bash custom_ops/xpu_ops/src/download_dependencies.sh develop
bash custom_ops/xpu_ops/download_dependencies.sh develop
```
Set environment variables,
```bash
export CLANG_PATH=$(pwd)/custom_ops/xpu_ops/src/third_party/xtdk
export XVLLM_PATH=$(pwd)/custom_ops/xpu_ops/src/third_party/xvllm
export CLANG_PATH=$(pwd)/custom_ops/xpu_ops/third_party/xtdk
export XVLLM_PATH=$(pwd)/custom_ops/xpu_ops/third_party/xvllm
```
### Compile and Install.

View File

@@ -83,20 +83,20 @@ cd FastDeploy
### Download Kunlunxin Compilation Dependencies
```bash
bash custom_ops/xpu_ops/src/download_dependencies.sh stable
bash custom_ops/xpu_ops/download_dependencies.sh stable
```
Alternatively, you can download the latest compilation dependencies
```bash
bash custom_ops/xpu_ops/src/download_dependencies.sh develop
bash custom_ops/xpu_ops/download_dependencies.sh develop
```
Set environment variables
```bash
export CLANG_PATH=$(pwd)/custom_ops/xpu_ops/src/third_party/xtdk
export XVLLM_PATH=$(pwd)/custom_ops/xpu_ops/src/third_party/xvllm
export CLANG_PATH=$(pwd)/custom_ops/xpu_ops/third_party/xtdk
export XVLLM_PATH=$(pwd)/custom_ops/xpu_ops/third_party/xvllm
```
### Compile and Install:

View File

@@ -20,9 +20,9 @@ python -m pip uninstall fastdeploy-xpu -y
python -m pip install paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/packages/nightly/xpu-p800/
# python -m pip install https://paddle-whl.bj.bcebos.com/nightly/xpu-p800/paddlepaddle-xpu/paddlepaddle_xpu-3.0.0.dev20250901-cp310-cp310-linux_x86_64.whl
echo "build whl"
bash custom_ops/xpu_ops/src/download_dependencies.sh develop
export CLANG_PATH=$(pwd)/custom_ops/xpu_ops/src/third_party/xtdk
export XVLLM_PATH=$(pwd)/custom_ops/xpu_ops/src/third_party/xvllm
bash custom_ops/xpu_ops/download_dependencies.sh develop
export CLANG_PATH=$(pwd)/custom_ops/xpu_ops/third_party/xtdk
export XVLLM_PATH=$(pwd)/custom_ops/xpu_ops/third_party/xvllm
bash build.sh || exit 1
echo "pip others"
python -m pip install openai -U