From 9d0074a91a9e3c7d08c570e8f8f804dd199bc77a Mon Sep 17 00:00:00 2001
From: zhupengyang <1165938320@qq.com>
Date: Wed, 10 Sep 2025 12:22:50 +0800
Subject: [PATCH] [xpu] add ep custom ops (#3911)

---
 build.sh                                      |   4 +-
 custom_ops/setup_ops.py                       |   2 +-
 custom_ops/xpu_ops/{src => }/build.sh         |   0
 .../{src => }/download_dependencies.sh        |   0
 custom_ops/xpu_ops/{src => }/setup_ops.py     |  31 +-
 custom_ops/xpu_ops/src/ops/fused_rms_norm.cc  | 225 ++++
 custom_ops/xpu_ops/src/ops/get_output.cc      |  34 +-
 custom_ops/xpu_ops/src/ops/moe_ep_combine.cc  | 119 +++
 custom_ops/xpu_ops/src/ops/moe_ep_dispatch.cc | 201 +++++
 custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc  | 535 +++++++++++++
 .../src/ops/moe_redundant_topk_select.cc      | 134 ++++
 custom_ops/xpu_ops/src/ops/moe_topk_select.cc |  84 +++
 custom_ops/xpu_ops/src/ops/msg_utils.h        |  39 +
 .../draft_model_postprocess.cc                |   6 +-
 .../draft_model_preprocess.cc                 |   6 +-
 .../{mtp_ops => mtp}/draft_model_update.cc    |   6 +-
 .../eagle_get_hidden_states.cc                |   6 +-
 .../eagle_get_self_hidden_states.cc           |   6 +-
 .../{mtp_ops => mtp}/mtp_save_first_token.cc  |   0
 .../ops/{mtp_ops => mtp}/mtp_step_paddle.cc   |   6 +-
 .../speculate_clear_accept_nums.cc            |   6 +-
 .../{mtp_ops => mtp}/speculate_get_output.cc  |   0
 .../speculate_get_output_padding_offset.cc    |   6 +-
 .../speculate_get_padding_offset.cc           |   6 +-
 .../speculate_get_seq_lens_output.cc          |   6 +-
 .../src/ops/{mtp_ops => mtp}/speculate_msg.h  |   0
 .../speculate_rebuild_append_padding.cc       |   0
 .../{mtp_ops => mtp}/speculate_save_output.cc |   0
 .../speculate_set_stop_value_multi_seqs.cc    |   0
 .../speculate_set_value_by_flags.cc           |   6 +-
 .../speculate_step_reschedule.cc              |   6 +-
 .../speculate_token_penalty_multi_scores.cc   |  37 +-
 .../speculate_update_input_ids_cpu.cc         |   0
 .../{mtp_ops => mtp}/speculate_update_v3.cc   |   6 +-
 .../ops/{mtp_ops => mtp}/speculate_verify.cc  |   7 +-
 .../ops/{mtp_ops => mtp}/top_p_candidates.cc  |   6 +-
 .../src/ops/open_shm_and_get_meta_signal.cc   |  91 +++
 .../src/ops/pybind/alloc_cache_pinned.cc      |  46 ++
 .../pybind/cachekv_signal_thread_worker.cc    | 111 +++
 .../ops/pybind/cachekv_signal_thread_worker.h |  35 +
 .../src/ops/pybind/get_peermem_addr.cc        |  26 +
 custom_ops/xpu_ops/src/ops/pybind/profiler.cc |  26 +
 custom_ops/xpu_ops/src/ops/pybind/pybind.cc   | 704 ++++++++++++++++++
 custom_ops/xpu_ops/src/ops/pybind/pybind.h    |  29 +
 custom_ops/xpu_ops/src/ops/read_data_ipc.cc   |  95 +++
 .../xpu_ops/src/ops/remote_cache_kv_ipc.cc    | 113 +++
 .../xpu_ops/src/ops/remote_cache_kv_ipc.h     |  98 +++
 custom_ops/xpu_ops/src/ops/set_data_ipc.cc    |  69 ++
 .../xpu_ops/src/ops/share_external_data.cc    |  57 ++
 .../xpu_ops/src/ops/swap_cache_batch.cc       | 166 +++++
 custom_ops/xpu_ops/src/ops/utility/debug.cc   | 194 +++++
 custom_ops/xpu_ops/src/ops/utility/env.cc     |  63 ++
 custom_ops/xpu_ops/src/ops/utility/env.h      |  29 +
 custom_ops/xpu_ops/src/ops/utility/logging.cc |  95 +++
 custom_ops/xpu_ops/src/ops/utility/logging.h  | 114 +++
 .../xpu_ops/src/ops/weight_only_linear.cc     | 207 +++++
 .../test/test_block_attn_prefix_cache.py      | 336 +++++++++
 .../xpu_ops/test/test_fused_rms_norm.py       | 137 ++++
 .../xpu_ops/test/test_get_infer_param.py      |  95 +++
 .../xpu_ops/test/test_moe_ep_combine.py       |  93 +++
 .../xpu_ops/test/test_moe_ep_dispatch.py      | 136 ++++
 .../xpu_ops/test/test_moe_expert_ffn.py       | 295 ++++++++
 .../test/test_moe_redundant_topk_select.py    | 200 +++++
 .../xpu_ops/test/test_moe_topk_select.py      |  67 ++
 custom_ops/xpu_ops/test/test_read_data_ipc.py |  23 +
 custom_ops/xpu_ops/test/test_set_data_ipc.py  |  25 +
 .../xpu_ops/test/test_set_get_data_ipc.py     |  45 ++
 .../xpu_ops/test/test_weight_only_linear.py   | 138 ++++
 .../get_started/installation/kunlunxin_xpu.md |   8 +-
 .../get_started/installation/kunlunxin_xpu.md |   8 +-
 scripts/run_ci_xpu.sh                         |   6 +-
 71 files changed, 5436 insertions(+), 80 deletions(-)
 rename custom_ops/xpu_ops/{src => }/build.sh (100%)
 rename custom_ops/xpu_ops/{src => }/download_dependencies.sh (100%)
 rename custom_ops/xpu_ops/{src => }/setup_ops.py (84%)
 create mode 100644 custom_ops/xpu_ops/src/ops/fused_rms_norm.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/moe_ep_combine.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/moe_ep_dispatch.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/moe_redundant_topk_select.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/moe_topk_select.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/msg_utils.h
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/draft_model_postprocess.cc (94%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/draft_model_preprocess.cc (97%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/draft_model_update.cc (97%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/eagle_get_hidden_states.cc (97%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/eagle_get_self_hidden_states.cc (96%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/mtp_save_first_token.cc (100%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/mtp_step_paddle.cc (96%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/speculate_clear_accept_nums.cc (91%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/speculate_get_output.cc (100%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/speculate_get_output_padding_offset.cc (95%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/speculate_get_padding_offset.cc (97%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/speculate_get_seq_lens_output.cc (94%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/speculate_msg.h (100%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/speculate_rebuild_append_padding.cc (100%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/speculate_save_output.cc (100%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/speculate_set_stop_value_multi_seqs.cc (100%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/speculate_set_value_by_flags.cc (94%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/speculate_step_reschedule.cc (98%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/speculate_token_penalty_multi_scores.cc (83%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/speculate_update_input_ids_cpu.cc (100%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/speculate_update_v3.cc (96%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/speculate_verify.cc (98%)
 rename custom_ops/xpu_ops/src/ops/{mtp_ops => mtp}/top_p_candidates.cc (97%)
 create mode 100644 custom_ops/xpu_ops/src/ops/open_shm_and_get_meta_signal.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/pybind/alloc_cache_pinned.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/pybind/cachekv_signal_thread_worker.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/pybind/cachekv_signal_thread_worker.h
 create mode 100644 custom_ops/xpu_ops/src/ops/pybind/get_peermem_addr.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/pybind/profiler.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/pybind/pybind.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/pybind/pybind.h
 create mode 100644 custom_ops/xpu_ops/src/ops/read_data_ipc.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h
 create mode 100644 custom_ops/xpu_ops/src/ops/set_data_ipc.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/share_external_data.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/swap_cache_batch.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/utility/debug.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/utility/env.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/utility/env.h
 create mode 100644 custom_ops/xpu_ops/src/ops/utility/logging.cc
 create mode 100644 custom_ops/xpu_ops/src/ops/utility/logging.h
 create mode 100644 custom_ops/xpu_ops/src/ops/weight_only_linear.cc
 create mode 100644 custom_ops/xpu_ops/test/test_block_attn_prefix_cache.py
 create mode 100644 custom_ops/xpu_ops/test/test_fused_rms_norm.py
 create mode 100755 custom_ops/xpu_ops/test/test_get_infer_param.py
 create mode 100644 custom_ops/xpu_ops/test/test_moe_ep_combine.py
 create mode 100644 custom_ops/xpu_ops/test/test_moe_ep_dispatch.py
 create mode 100644 custom_ops/xpu_ops/test/test_moe_expert_ffn.py
 create mode 100644 custom_ops/xpu_ops/test/test_moe_redundant_topk_select.py
 create mode 100644 custom_ops/xpu_ops/test/test_moe_topk_select.py
 create mode 100644 custom_ops/xpu_ops/test/test_read_data_ipc.py
 create mode 100644 custom_ops/xpu_ops/test/test_set_data_ipc.py
 create mode 100644 custom_ops/xpu_ops/test/test_set_get_data_ipc.py
 create mode 100644 custom_ops/xpu_ops/test/test_weight_only_linear.py

diff --git a/build.sh b/build.sh
index e37fa2bdc..d8b27d03b 100644
--- a/build.sh
+++ b/build.sh
@@ -143,9 +143,9 @@ function build_and_install_ops() {
   TMP_DIR_REAL_PATH=`readlink -f ${OPS_TMP_DIR}`
   is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"`
   if [ "$is_xpu" = "True" ]; then
-    cd xpu_ops/src
+    cd xpu_ops
     bash build.sh ${TMP_DIR_REAL_PATH}
-    cd ../..
+    cd ..
   elif [ "$FD_CPU_USE_BF16" == "true" ]; then
     if [ "$FD_BUILDING_ARCS" == "" ]; then
       FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py
index ded12fb16..8bca9837d 100644
--- a/custom_ops/setup_ops.py
+++ b/custom_ops/setup_ops.py
@@ -542,7 +542,7 @@ elif paddle.is_compiled_with_cuda():
         include_package_data=True,
     )
 elif paddle.is_compiled_with_xpu():
-    assert False, "In XPU, we should use setup_ops.py in xpu_ops/src, not this."
+    assert False, "For XPU, please use setup_ops.py in the xpu_ops directory to compile custom ops."
 elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
     setup(
         name="fastdeploy_ops",
diff --git a/custom_ops/xpu_ops/src/build.sh b/custom_ops/xpu_ops/build.sh
similarity index 100%
rename from custom_ops/xpu_ops/src/build.sh
rename to custom_ops/xpu_ops/build.sh
diff --git a/custom_ops/xpu_ops/src/download_dependencies.sh b/custom_ops/xpu_ops/download_dependencies.sh
similarity index 100%
rename from custom_ops/xpu_ops/src/download_dependencies.sh
rename to custom_ops/xpu_ops/download_dependencies.sh
diff --git a/custom_ops/xpu_ops/src/setup_ops.py b/custom_ops/xpu_ops/setup_ops.py
similarity index 84%
rename from custom_ops/xpu_ops/src/setup_ops.py
rename to custom_ops/xpu_ops/setup_ops.py
index 88f450916..0ff1f3557 100755
--- a/custom_ops/xpu_ops/src/setup_ops.py
+++ b/custom_ops/xpu_ops/setup_ops.py
@@ -27,7 +27,7 @@ import paddle
 from paddle.utils.cpp_extension import CppExtension, setup
 
 current_file = Path(__file__).resolve()
-base_dir = current_file.parent
+base_dir = os.path.join(current_file.parent, "src")
 
 
 def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR, XDNN_LIB_DIR):
@@ -136,33 +136,8 @@ def xpu_setup_ops():
     # build plugin
     build_plugin(CLANG_PATH, XRE_INC_PATH, XRE_LIB_DIR, XDNN_INC_PATH, XDNN_LIB_DIR)
 
-    ops = [
-        # custom ops
-        "./ops/save_with_output_msg.cc",
-        "./ops/stop_generation_multi_ends.cc",
-        "./ops/set_value_by_flags_and_idx.cc",
-        "./ops/get_token_penalty_multi_scores.cc",
-        "./ops/get_padding_offset.cc",
-        "./ops/update_inputs.cc",
-        "./ops/recover_decode_task.cc",
-        "./ops/update_inputs_v1.cc",
-        "./ops/get_output.cc",
-        "./ops/step.cc",
-        "./ops/get_infer_param.cc",
-        "./ops/adjust_batch.cc",
-        "./ops/gather_next_token.cc",
-        "./ops/block_attn.cc",
-        "./ops/moe_layer.cc",
-        "./ops/weight_quantize_xpu.cc",
-        # device manage ops
-        "./ops/device/get_context_gm_max_mem_demand.cc",
-        "./ops/device/get_free_global_memory.cc",
-        "./ops/device/get_total_global_memory.cc",
-        "./ops/device/get_used_global_memory.cc",
-    ]
-    ops = [os.path.join(base_dir, op) for op in ops]
-
-    for root, dirs, files in os.walk(base_dir / "ops/mtp_ops"):
+    ops = []
+    for root, dirs, files in os.walk(os.path.join(base_dir, "ops")):
         for file in files:
             if file.endswith(".cc"):
                 ops.append(os.path.join(root, file))
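
Note: the rewritten `xpu_setup_ops` above stops maintaining a hand-written source list and instead compiles every `.cc` file it finds under `src/ops` (which now also picks up the renamed `ops/mtp` directory and the new EP ops). A standalone sketch of that discovery step, for reference:

```python
import os

def discover_ops(base_dir: str) -> list:
    """Mirror setup_ops.py: collect every .cc source under <base_dir>/ops."""
    ops = []
    for root, _dirs, files in os.walk(os.path.join(base_dir, "ops")):
        for file in files:
            if file.endswith(".cc"):
                ops.append(os.path.join(root, file))
    return ops

# e.g. discover_ops("custom_ops/xpu_ops/src") returns moe_ep_dispatch.cc,
# ops/mtp/*.cc, ops/pybind/*.cc, and so on.
```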
diff --git a/custom_ops/xpu_ops/src/ops/fused_rms_norm.cc b/custom_ops/xpu_ops/src/ops/fused_rms_norm.cc
new file mode 100644
index 000000000..08fbb344d
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/fused_rms_norm.cc
@@ -0,0 +1,225 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <infer_ops.h>
+#include <numeric>
+#include "paddle/extension.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "utility/debug.h"
+#include "utility/env.h"
+
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
+XPU_DECLARE_BOOL(ENABLE_XVLLM_SDNN_INFER, false);
+namespace api = baidu::xpu::api;
+
+template <typename TX>
+std::vector<paddle::Tensor> RmsNormKernel(
+    const paddle::Tensor& x,
+    const paddle::optional<paddle::Tensor>& bias,
+    const paddle::optional<paddle::Tensor>& residual,
+    const paddle::Tensor& norm_weight,
+    const paddle::optional<paddle::Tensor>& norm_bias,
+    const float epsilon,
+    const int begin_norm_axis,
+    const float quant_scale,
+    const int quant_round_type,
+    const float quant_max_bound,
+    const float quant_min_bound) {
+  using XPU_T = typename XPUTypeTrait<TX>::Type;
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+
+  int ret = -1;
+  auto x_shape = x.shape();
+  PD_CHECK(quant_scale <= 0, "Quantization is not supported");
+  PD_CHECK(begin_norm_axis > 0 && begin_norm_axis <= x_shape.size(),
+           "begin_norm_axis check fail");
+  PD_CHECK(norm_bias.get_ptr() == nullptr,
+           "rms norm kernel doesn't support norm_bias");
+
+  int64_t m = std::accumulate(x_shape.begin(),
+                              x_shape.begin() + begin_norm_axis,
+                              static_cast<int64_t>(1),
+                              std::multiplies<int64_t>());
+  int64_t n = std::accumulate(x_shape.begin() + begin_norm_axis,
+                              x_shape.end(),
+                              static_cast<int64_t>(1),
+                              std::multiplies<int64_t>());
+
+  PD_CHECK(n == norm_weight.shape()[0],
+           "The product from begin_norm_axis to the last axis of x must be "
+           "equal to the norm_weight's shape[0]");
+  if (bias.get_ptr()) {
+    PD_CHECK(n == bias.get_ptr()->shape()[0],
+             "The product from begin_norm_axis to the last axis of x must be "
+             "equal to the bias's shape[0]");
+  }
+
+  paddle::Tensor out = paddle::empty(x_shape, x.dtype(), x.place());
+  paddle::Tensor residual_out = paddle::empty(x_shape, x.dtype(), x.place());
+  const XPU_T* x_data = reinterpret_cast<const XPU_T*>(x.data<TX>());
+  const XPU_T* norm_weight_data =
+      reinterpret_cast<const XPU_T*>(norm_weight.data<TX>());
+  const XPU_T* bias_data =
+      bias.get_ptr()
+          ? reinterpret_cast<const XPU_T*>(bias.get_ptr()->data<TX>())
+          : nullptr;
+  const XPU_T* residual_data =
+      residual.get_ptr()
+          ? reinterpret_cast<const XPU_T*>(residual.get_ptr()->data<TX>())
+          : nullptr;
+  XPU_T* out_data = reinterpret_cast<XPU_T*>(const_cast<TX*>(out.data<TX>()));
+  XPU_T* residual_out_data = nullptr;
+  if (residual_data) {
+    residual_out_data =
+        reinterpret_cast<XPU_T*>(const_cast<TX*>(residual_out.data<TX>()));
+  }
+
+  XPU_T* add_out_data = const_cast<XPU_T*>(x_data);
+  if (bias_data) {
+    ret = api::broadcast_add(
+        xpu_ctx->x_context(), x_data, bias_data, out_data, {m, n}, {n});
+    PD_CHECK(ret == 0, "broadcast_add");
+    add_out_data = out_data;
+  }
+
+  bool use_sdnn = FLAGS_ENABLE_XVLLM_SDNN_INFER;
+  if (residual_data) {
+    ret = infer_ops::add_rms_layer_norm(xpu_ctx->x_context(),
+                                        add_out_data,
+                                        residual_data,
+                                        out_data,
+                                        m,
+                                        n,
+                                        epsilon,
+                                        norm_weight_data,
+                                        nullptr,
+                                        nullptr,
+                                        residual_out_data,
+                                        nullptr,
+                                        use_sdnn);
+    PD_CHECK(ret == 0, "add_rms_layer_norm");
+  } else {
+    ret = api::rms_layer_norm(xpu_ctx->x_context(),
+                              add_out_data,
+                              out_data,
+                              m,
+                              n,
+                              epsilon,
+                              norm_weight_data,
+                              nullptr,
+                              nullptr,
+                              false);
+    PD_CHECK(ret == 0, "rms_layer_norm");
+  }
+
+  return {out, residual_out};
+}
+
+std::vector<paddle::Tensor> RmsNorm(
+    const paddle::Tensor& x,
+    const paddle::optional<paddle::Tensor>& bias,
+    const paddle::optional<paddle::Tensor>& residual,
+    const paddle::Tensor& norm_weight,
+    const paddle::optional<paddle::Tensor>& norm_bias,
+    const float epsilon,
+    const int begin_norm_axis,
+    const float quant_scale,
+    const int quant_round_type,
+    const float quant_max_bound,
+    const float quant_min_bound) {
+  const auto x_type = x.dtype();
+
+#define APPLY_RMS_NORM_KERNEL(TX)              \
+  return RmsNormKernel<TX>(x,                  \
+                           bias,               \
+                           residual,           \
+                           norm_weight,        \
+                           norm_bias,          \
+                           epsilon,            \
+                           begin_norm_axis,    \
+                           quant_scale,        \
+                           quant_round_type,   \
+                           quant_max_bound,    \
+                           quant_min_bound);
+
+  if (x_type == paddle::DataType::BFLOAT16) {
+    APPLY_RMS_NORM_KERNEL(paddle::bfloat16);
+  } else if (x_type == paddle::DataType::FLOAT16) {
+    APPLY_RMS_NORM_KERNEL(paddle::float16);
+  } else if (x_type == paddle::DataType::FLOAT32) {
+    APPLY_RMS_NORM_KERNEL(float);
+  } else {
+    PD_THROW("RmsNorm not support x_type=", static_cast<int>(x_type));
+    return {};
+  }
+#undef APPLY_RMS_NORM_KERNEL
+}
+
+std::vector<std::vector<int64_t>> RmsNormInferShape(
+    const std::vector<int64_t>& x_shape,
+    const paddle::optional<std::vector<int64_t>>& bias_shape,
+    const paddle::optional<std::vector<int64_t>>& residual_shape,
+    const std::vector<int64_t>& norm_weight_shape,
+    const paddle::optional<std::vector<int64_t>>& norm_bias_shape,
+    const float epsilon,
+    const int begin_norm_axis,
+    const float quant_scale,
+    const int quant_round_type,
+    const float quant_max_bound,
+    const float quant_min_bound) {
+  PD_CHECK(begin_norm_axis > 0 && begin_norm_axis <= x_shape.size(),
+           "begin_norm_axis check fail");
+  int64_t m = std::accumulate(x_shape.begin(),
+                              x_shape.begin() + begin_norm_axis,
+                              static_cast<int64_t>(1),
+                              std::multiplies<int64_t>());
+  return {x_shape, x_shape, {m}};
+}
+
+std::vector<paddle::DataType> RmsNormInferDtype(
+    const paddle::DataType& x_dtype,
+    const paddle::optional<paddle::DataType>& bias_dtype,
+    const paddle::optional<paddle::DataType>& residual_dtype,
+    const paddle::DataType& norm_weight_dtype,
+    const paddle::optional<paddle::DataType>& norm_bias_dtype,
+    const float epsilon,
+    const int begin_norm_axis,
+    const float quant_scale,
+    const int quant_round_type,
+    const float quant_max_bound,
+    const float quant_min_bound) {
+  // out, residual_out
+  return {x_dtype, x_dtype};
+}
+
+PD_BUILD_STATIC_OP(fused_rms_norm_xpu)
+    .Inputs({"x",
+             paddle::Optional("bias"),
+             paddle::Optional("residual"),
+             "norm_weight",
+             paddle::Optional("norm_bias")})
+    .Outputs({"out", "residul_out"})
+    .Attrs({"epsilon:float",
+            "begin_norm_axis:int",
+            "quant_scale:float",
+            "quant_round_type:int",
+            "quant_max_bound:float",
+            "quant_min_bound:float"})
+    .SetKernelFn(PD_KERNEL(RmsNorm))
+    .SetInferShapeFn(PD_INFER_SHAPE(RmsNormInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(RmsNormInferDtype));
diff --git a/custom_ops/xpu_ops/src/ops/get_output.cc b/custom_ops/xpu_ops/src/ops/get_output.cc
index 58ea591c3..6886f441f 100644
--- a/custom_ops/xpu_ops/src/ops/get_output.cc
+++ b/custom_ops/xpu_ops/src/ops/get_output.cc
@@ -18,13 +18,35 @@
 #include <stdio.h>
 #include <string.h>
 #include <sys/msg.h>
+#include "msg_utils.h"
 
-#define MAX_BSZ 256
-// #define GET_OUTPUT_DEBUG
-struct msgdata {
-  long mtype;
-  int mtext[MAX_BSZ + 2];  // stop_flag, bsz, tokens
-};
+void GetOutputKVSignal(const paddle::Tensor& x,
+                       int64_t rank_id,
+                       bool wait_flag) {
+  int msg_queue_id = 1024 + rank_id;
+  static struct msgdatakv msg_rcv;
+  static key_t key = ftok("/opt/", msg_queue_id);
+  static int msgid = msgget(key, IPC_CREAT | 0666);
+
+  int* out_data = const_cast<int*>(x.data<int>());
+  int ret = -1;
+  if (!wait_flag) {
+    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT);
+  } else {
+    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, 0);
+  }
+  if (ret == -1) {
+    out_data[0] = -1;
+    out_data[1] = -1;
+    return;
+  }
+  int encoder_count = msg_rcv.mtext[0];
+
+  for (int i = 0; i < encoder_count * 3 + 2; i++) {
+    out_data[i] = msg_rcv.mtext[i];
+  }
+  return;
+}
 
 void GetOutput(const paddle::Tensor &x, int64_t rank_id, bool wait_flag,
                int msg_queue_id) {
diff --git a/custom_ops/xpu_ops/src/ops/moe_ep_combine.cc b/custom_ops/xpu_ops/src/ops/moe_ep_combine.cc
new file mode 100644
index 000000000..7ae0782b6
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/moe_ep_combine.cc
@@ -0,0 +1,119 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <infer_ops.h>
+#include <vector>
+#include "paddle/extension.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "utility/debug.h"
+
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
+template <typename TX>
+std::vector<paddle::Tensor> MoeEPCombineKernel(
+    const paddle::Tensor& ffn_out,    // expand_token_num * hidden_dim, fp16/bf16
+    const paddle::Tensor& moe_index,  // token_num * topk, int
+    const paddle::Tensor& weights,    // token_num * topk, same dtype as ffn_out
+    int64_t recv_token_num,
+    int64_t expand_token_num,
+    int64_t hidden_dim,
+    int64_t topk) {
+  using XPU_T = typename XPUTypeTrait<TX>::Type;
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+
+  auto combined_out = paddle::empty(
+      {recv_token_num, hidden_dim}, ffn_out.dtype(), ffn_out.place());
+
+  const float* dequant_score = nullptr;
+  int ret = infer_ops::moe_ep_ffn_post_fusion(
+      xpu_ctx->x_context(),
+      reinterpret_cast<const XPU_T*>(ffn_out.data<TX>()),
+      moe_index.data<int>(),
+      reinterpret_cast<const XPU_T*>(weights.data<TX>()),
+      dequant_score,
+      reinterpret_cast<XPU_T*>(combined_out.mutable_data<TX>()),
+      recv_token_num,
+      hidden_dim,
+      topk,
+      expand_token_num);
+  PD_CHECK(ret == 0);
+
+  return {combined_out};
+}
+
+std::vector<paddle::Tensor> MoeEPCombine(const paddle::Tensor& ffn_out,
+                                         const paddle::Tensor& moe_index,
+                                         const paddle::Tensor& weights,
+                                         const int recv_token_num,
+                                         const int expand_token_num,
+                                         const int hidden_dim,
+                                         const int topk) {
+#define APPLY_KERNEL(TX)                          \
+  return MoeEPCombineKernel<TX>(ffn_out,          \
+                                moe_index,        \
+                                weights,          \
+                                recv_token_num,   \
+                                expand_token_num, \
+                                hidden_dim,       \
+                                topk);
+
+  const auto ffn_out_dtype = ffn_out.dtype();
+  if (ffn_out_dtype == paddle::DataType::FLOAT16) {
+    APPLY_KERNEL(paddle::float16);
+  } else if (ffn_out_dtype == paddle::DataType::BFLOAT16) {
+    APPLY_KERNEL(paddle::bfloat16);
+  } else {
+    PD_THROW("MoeEPCombine not support ffn_out_type==%d",
+             static_cast<int>(ffn_out_dtype));
+    return {};
+  }
+
+#undef APPLY_KERNEL
+}
+
+std::vector<std::vector<int64_t>> MoeEPCombineInferShape(
+    const std::vector<int64_t>& ffn_out_shape,
+    const std::vector<int64_t>& moe_index_shape,
+    const std::vector<int64_t>& weights_shape,
+    const int recv_token_num,
+    const int expand_token_num,
+    const int hidden_dim,
+    const int topk) {
+  std::vector<int64_t> combined_out_shape = {recv_token_num, hidden_dim};
+  return {combined_out_shape};
+}
+
+std::vector<paddle::DataType> MoeEPCombineInferDtype(
+    const paddle::DataType& ffn_out_dtype,
+    const paddle::DataType& moe_index_dtype,
+    const paddle::DataType& weights_dtype) {
+  return {ffn_out_dtype};
+}
+
+PD_BUILD_STATIC_OP(ep_moe_expert_combine)
+    .Inputs({"ffn_out", "moe_index", "weights"})
+    .Outputs({"combined_out"})
+    .Attrs({"recv_token_num: int",
+            "expand_token_num: int",
+            "hidden_dim: int",
+            "topk: int"})
+    .SetKernelFn(PD_KERNEL(MoeEPCombine))
+    .SetInferShapeFn(PD_INFER_SHAPE(MoeEPCombineInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(MoeEPCombineInferDtype));
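
Note: a shape-level sketch of the combine op above (import path assumed; `moe_index` maps each (token, k) slot to a row of `ffn_out`, and the result is the weighted sum over each token's top-k experts):

```python
import paddle
from fastdeploy_ops import ep_moe_expert_combine  # assumed module name

token_num, topk, hidden_dim = 8, 2, 512
expand_token_num = token_num * topk

ffn_out = paddle.randn([expand_token_num, hidden_dim], dtype="bfloat16")
moe_index = paddle.randint(0, expand_token_num, [token_num, topk]).astype("int32")
weights = paddle.rand([token_num, topk]).astype("bfloat16")

combined_out = ep_moe_expert_combine(
    ffn_out, moe_index, weights,
    token_num, expand_token_num, hidden_dim, topk)
# combined_out: [token_num, hidden_dim], same dtype as ffn_out
```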
diff --git a/custom_ops/xpu_ops/src/ops/moe_ep_dispatch.cc b/custom_ops/xpu_ops/src/ops/moe_ep_dispatch.cc
new file mode 100644
index 000000000..2690b8b13
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/moe_ep_dispatch.cc
@@ -0,0 +1,201 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <infer_ops.h>
+#include <type_traits>
+#include <vector>
+#include "paddle/extension.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "utility/debug.h"
+
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
+template <typename TX, typename TY>
+std::vector<paddle::Tensor> EPMoeExpertDispatchKernel(
+    const paddle::Tensor& input,
+    const paddle::Tensor& topk_ids,
+    const paddle::Tensor& topk_weights,
+    const paddle::optional<paddle::Tensor>& input_scales,
+    const std::vector<int>& token_nums_per_expert,
+    const int64_t token_nums_this_rank) {
+  using XPU_TX = typename XPUTypeTrait<TX>::Type;
+  using XPU_TY = typename XPUTypeTrait<TY>::Type;
+  phi::XPUPlace xpu_place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx =
+      paddle::experimental::DeviceContextPool::Instance().Get(xpu_place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+
+  const auto input_type = input.dtype();
+  auto m = input.dims()[0];
+  auto n = input.dims()[1];
+  const int64_t expert_num = token_nums_per_expert.size();
+  const int topk = topk_ids.dims()[1];
+  auto place = input.place();
+
+  auto block_num = xpu_ctx->x_context()->ncluster();
+  paddle::Tensor permute_input;
+  auto permute_indices_per_token =
+      paddle::empty({m, topk}, paddle::DataType::INT32, place);
+  auto expert_m = paddle::empty({expert_num}, paddle::DataType::INT32, place);
+  auto recv_num_tokens_per_expert_list_cumsum =
+      paddle::empty({expert_num + 1}, paddle::DataType::INT32, place);
+  auto expand_input_scales =
+      paddle::empty({token_nums_this_rank}, paddle::DataType::FLOAT32, place);
+  const int64_t ep_size = 1;
+  const int64_t ep_rank = 0;
+
+  if (std::is_same<TY, int8_t>::value) {
+    permute_input =
+        paddle::empty({token_nums_this_rank, n}, paddle::DataType::INT8, place);
+    auto ret = infer_ops::moe_ffn_pre_sorted_quant_pe(
+        xpu_ctx->x_context(),
+        reinterpret_cast<const XPU_TX*>(input.data<TX>()),
+        topk_ids.data<int>(),
+        input_scales.get_ptr()->data<float>(),
+        nullptr,
+        reinterpret_cast<int8_t*>(permute_input.data<int8_t>()),
+        const_cast<int*>(permute_indices_per_token.data<int>()),
+        const_cast<int*>(expert_m.data<int>()),
+        const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
+        expand_input_scales.data<float>(),
+        m,
+        n,
+        expert_num,
+        topk,
+        block_num,
+        token_nums_this_rank);
+    PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
+  } else {
+    permute_input = paddle::empty({token_nums_this_rank, n}, input_type, place);
+    auto ret = infer_ops::moe_ep_ffn_pre_sorted(
+        xpu_ctx->x_context(),
+        reinterpret_cast<const XPU_TX*>(input.data<TX>()),
+        topk_ids.data<int>(),
+        nullptr,
+        reinterpret_cast<XPU_TX*>(permute_input.data<TX>()),
+        const_cast<int*>(permute_indices_per_token.data<int>()),
+        const_cast<int*>(expert_m.data<int>()),
+        const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
+        m,
+        n,
+        expert_num,
+        topk,
+        block_num,
+        ep_size,
+        ep_rank,
+        token_nums_this_rank);
+    PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
+  }
+  return {permute_input,
+          permute_indices_per_token,
+          recv_num_tokens_per_expert_list_cumsum,
+          topk_weights,
+          expand_input_scales};
+}
+
+std::vector<paddle::Tensor> EPMoeExpertDispatch(
+    const paddle::Tensor& input,
+    const paddle::Tensor& topk_ids,
+    const paddle::Tensor& topk_weights,
+    const paddle::optional<paddle::Tensor>& input_scales,
+    const std::vector<int>& token_nums_per_expert,
+    const int token_nums_this_rank,
+    const std::string quant_method) {
+#define APPLY_KERNEL(TX, TY)                                       \
+  return EPMoeExpertDispatchKernel<TX, TY>(input,                  \
+                                           topk_ids,               \
+                                           topk_weights,           \
+                                           input_scales,           \
+                                           token_nums_per_expert,  \
+                                           token_nums_this_rank);
+
+  const auto input_dtype = input.dtype();
+  if (input_dtype == paddle::DataType::FLOAT16 && quant_method == "w4a8") {
+    APPLY_KERNEL(paddle::float16, int8_t);
+  } else if (input_dtype == paddle::DataType::FLOAT16 &&
+             quant_method != "w4a8") {
+    APPLY_KERNEL(paddle::float16, paddle::float16);
+  } else if (input_dtype == paddle::DataType::BFLOAT16 &&
+             quant_method == "w4a8") {
+    APPLY_KERNEL(paddle::bfloat16, int8_t);
+  } else if (input_dtype == paddle::DataType::BFLOAT16 &&
+             quant_method != "w4a8") {
+    APPLY_KERNEL(paddle::bfloat16, paddle::bfloat16);
+  } else {
+    PD_THROW("EPMoeExpertDispatch not support input_dtype=",
+             static_cast<int>(input_dtype),
+             "quant_method=",
+             quant_method);
+    return {};
+  }
+
+#undef APPLY_KERNEL
+}
+
+std::vector<std::vector<int64_t>> EPMoeExpertDispatchInferShape(
+    const std::vector<int64_t>& input_shape,
+    const std::vector<int64_t>& topk_ids_shape,
+    const std::vector<int64_t>& topk_weights_shape,
+    const paddle::optional<std::vector<int64_t>>& input_scales_shape,
+    const std::vector<int>& token_nums_per_expert,
+    const int token_nums_this_rank,
+    const std::string quant_method) {
+  const int m = input_shape[0];
+  const int hidden_size = input_shape[input_shape.size() - 1];
+  const int topk = topk_ids_shape[topk_ids_shape.size() - 1];
+  const int expert_num = token_nums_per_expert.size();
+  return {{token_nums_this_rank, hidden_size},
+          {expert_num, m},
+          {expert_num},
+          {token_nums_this_rank},
+          {token_nums_this_rank}};
+}
+
+std::vector<paddle::DataType> EPMoeExpertDispatchInferDtype(
+    const paddle::DataType& input_dtype,
+    const paddle::DataType& topk_ids_dtype,
+    const paddle::DataType& topk_weights_dtype,
+    const paddle::optional<paddle::DataType>& input_scales_dtype,
+    const std::vector<int>& token_nums_per_expert,
+    const int token_nums_this_rank,
+    const std::string quant_method) {
+  auto output_dtype = input_dtype;
+  if (quant_method == "w4a8") {
+    output_dtype = paddle::DataType::INT8;
+  }
+  return {
+      output_dtype,
+      paddle::DataType::INT32,
+      paddle::DataType::INT32,
+      topk_weights_dtype,
+      paddle::DataType::FLOAT32,
+  };
+}
+
+PD_BUILD_STATIC_OP(ep_moe_expert_dispatch)
+    .Inputs(
+        {"input", "topk_ids", "topk_weights", paddle::Optional("input_scales")})
+    .Outputs({"permute_input",
+              "permute_indices_per_token",
+              "token_nums_per_expert_cumsum",
+              "dst_weights",
+              "expand_input_scales"})
+    .Attrs({"token_nums_per_expert: std::vector<int>",
+            "token_nums_this_rank: int",
+            "quant_method: std::string"})
+    .SetKernelFn(PD_KERNEL(EPMoeExpertDispatch))
+    .SetInferShapeFn(PD_INFER_SHAPE(EPMoeExpertDispatchInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(EPMoeExpertDispatchInferDtype))
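;

Note: a sketch of the dispatch side under the same import assumption (non-quantized path): it permutes tokens into per-expert order and returns the cumulative per-expert token counts alongside the passthrough weights:

```python
import paddle
from fastdeploy_ops import ep_moe_expert_dispatch  # assumed module name

token_num, hidden_dim, expert_num, topk = 8, 512, 4, 2
x = paddle.randn([token_num, hidden_dim], dtype="bfloat16")
topk_ids = paddle.randint(0, expert_num, [token_num, topk]).astype("int32")
topk_w = paddle.rand([token_num, topk]).astype("float32")

# token_nums_per_expert is a host-side hint; its length is the expert count
# and token_nums_this_rank bounds the expanded (permuted) token buffer.
token_nums_per_expert = [4, 4, 4, 4]
permute_input, permute_indices, cumsum, dst_weights, scales = \
    ep_moe_expert_dispatch(x, topk_ids, topk_w, None,
                           token_nums_per_expert,
                           sum(token_nums_per_expert),
                           "none")  # any quant_method other than "w4a8"
```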
diff --git a/custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc b/custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc
new file mode 100644
index 000000000..7916f38a9
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc
@@ -0,0 +1,535 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <infer_ops.h>
+#include <xft_api.h>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <vector>
+#include "paddle/extension.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "utility/debug.h"
+#include "utility/env.h"
+
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
+XPU_DECLARE_BOOL(MOE_FFN_USE_DENSE_INPUT, false);
+XPU_DECLARE_BOOL(BKCL_DISPATCH_ALL_GATHER, false);
+
+namespace xftblock = baidu::xpu::xftblock;
+namespace api = baidu::xpu::api;
+
+template <typename TX1, typename TX2, typename TW, typename TGEMM>
+void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
+                      xftblock::Tensor* token_num_info,
+                      xftblock::Tensor* ffn1_weight,
+                      xftblock::Tensor* ffn2_weight,
+                      xftblock::Tensor* ffn1_bias,
+                      xftblock::Tensor* ffn2_bias,
+                      xftblock::Tensor* ffn2_out,
+                      float* ffn2_act_scale,
+                      TX2* ffn2_shift,
+                      TX2* ffn2_smooth,
+                      const int hadamard_blocksize) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+  xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr);
+  auto rt_guard = xctx.get_rt_guard();
+  auto xftblock_tx2 = xftblock::DataTypeToEnum<TX2>::value;
+
+  int ret = -1;
+  int expert_num = ffn1_weight->get_dim(0);
+  int inter_dim = ffn1_weight->get_dim(1);
+  int outer_dim = inter_dim / 2;
+
+  bool is_padding_input = ffn_in->get_dims().size() == 3;
+  auto ffn1_out_shape = ffn_in->get_dims();
+  int hidden_dim = ffn1_out_shape[ffn1_out_shape.size() - 1];
+  ffn1_out_shape[ffn1_out_shape.size() - 1] = inter_dim;
+  xftblock::Tensor ffn1_out(rt_guard, xftblock_tx2, ffn1_out_shape);
+  ret = xftblock::xft_moe_fc_block_eb<TX1, TW, TX2, TGEMM>(
+      &xctx,
+      ffn_in,
+      ffn1_weight,
+      &ffn1_out,
+      ffn1_bias,
+      is_padding_input ? nullptr : token_num_info,
+      is_padding_input ? token_num_info : nullptr,
+      expert_num,
+      1,  // moe_topk
+      ffn1_out_shape.size() == 2 ? xftblock::MoeFCInputMode::DENSE
+                                 : xftblock::MoeFCInputMode::SPARSE);
+  PD_CHECK(ret == 0);
+
+  int token_num = ffn_in->numel() / hidden_dim;
+  auto swiglu_out_shape = ffn1_out_shape;
+  swiglu_out_shape[swiglu_out_shape.size() - 1] /= 2;
+  xftblock::Tensor swiglu_out(rt_guard, xftblock_tx2, swiglu_out_shape);
+  ret = api::fast_swiglu(xpu_ctx->x_context(),
+                         ffn1_out.data<TX2>(),
+                         swiglu_out.mutable_data<TX2>(),
+                         {token_num, inter_dim},
+                         1,
+                         true);
+  PD_CHECK(ret == 0);
+  // TODO(mayang02): use fusion_smooth_transform
+  if (ffn2_shift != nullptr) {
+    ret = api::broadcast_add(xpu_ctx->x_context(),
+                             ffn2_shift,
+                             swiglu_out.data<TX2>(),
+                             swiglu_out.mutable_data<TX2>(),
+                             {1, outer_dim},
+                             {token_num, outer_dim});
+    PD_CHECK(ret == 0);
+  }
+  if (ffn2_smooth != nullptr) {
+    ret = api::broadcast_mul(xpu_ctx->x_context(),
+                             ffn2_smooth,
+                             swiglu_out.data<TX2>(),
+                             swiglu_out.mutable_data<TX2>(),
+                             {1, outer_dim},
+                             {token_num, outer_dim});
+    PD_CHECK(ret == 0);
+  }
+
+  if (hadamard_blocksize > 0) {
+    ret = infer_ops::fast_walsh_transform(xpu_ctx->x_context(),
+                                          swiglu_out.data<TX2>(),
+                                          nullptr,
+                                          nullptr,
+                                          swiglu_out.mutable_data<TX2>(),
+                                          hadamard_blocksize,
+                                          token_num,
+                                          outer_dim);
+    PD_CHECK(ret == 0);
+  }
+
+  xftblock::Tensor ffn2_in(swiglu_out.mutable_data<TX2>(),
+                           nullptr,
+                           ffn2_act_scale,
+                           xftblock_tx2,
+                           swiglu_out_shape);
+  ret = xftblock::xft_moe_fc_block_eb<TX2, TW, TX2, TGEMM>(
+      &xctx,
+      &ffn2_in,
+      ffn2_weight,
+      ffn2_out,
+      nullptr,
+      is_padding_input ? nullptr : token_num_info,
+      is_padding_input ? token_num_info : nullptr,
+      expert_num,
+      1,  // moe_topk
+      ffn1_out_shape.size() == 2
+          ? xftblock::MoeFCInputMode::DENSE
+          : xftblock::MoeFCInputMode::SPARSE);  // bias_mode
+  PD_CHECK(ret == 0);
+}
+
+static void convert_to_lod(xftblock::XFTContext* xctx,
+                           xftblock::Tensor* token_num_info) {
+  auto rt_guard = xctx->get_rt_guard();
+  auto ctx = xctx->get_context();
+  const int expert_num = token_num_info->numel();
+  xftblock::Tensor tokens_num_lod(
+      rt_guard, xftblock::DataType::DT_INT32, {expert_num + 1});
+  int ret = api::constant(ctx, tokens_num_lod.data<int>(), expert_num + 1, 0);
+  PD_CHECK(ret == 0);
+  ret = api::cumsum(ctx,
+                    token_num_info->data<int>(),
+                    tokens_num_lod.data<int>() + 1,
+                    {expert_num},
+                    false,
+                    false,
+                    0);
+  PD_CHECK(ret == 0);
+  *token_num_info = std::move(tokens_num_lod);
+}
+
+template <typename TX1, typename TX2, typename TW>
+std::vector<paddle::Tensor> MoeExpertFFNKernel(
+    const paddle::Tensor& ffn_in,
+    const paddle::Tensor& token_num_info,
+    const paddle::Tensor& ffn1_weight,
+    const paddle::Tensor& ffn2_weight,
+    const paddle::optional<paddle::Tensor>& ffn1_bias,
+    const paddle::optional<paddle::Tensor>& ffn2_bias,
+    const paddle::optional<paddle::Tensor>& ffn1_act_scale,
+    const paddle::optional<paddle::Tensor>& ffn2_act_scale,
+    const paddle::optional<paddle::Tensor>& ffn1_weight_scale,
+    const paddle::optional<paddle::Tensor>& ffn2_weight_scale,
+    const paddle::optional<paddle::Tensor>& ffn2_shift,
+    const paddle::optional<paddle::Tensor>& ffn2_smooth,
+    const std::string& quant_method,
+    const int hadamard_blocksize,
+    const int valid_token_num) {
+  using XPU_TX1 = typename XPUTypeTrait<TX1>::Type;
+  using XPU_TX2 = typename XPUTypeTrait<TX2>::Type;
+  using XPU_TW = typename XPUTypeTrait<TW>::Type;
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+  xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr);
+  auto rt_guard = xctx.get_rt_guard();
+
+  int ret = -1;
+  auto input_shape = ffn_in.shape();
+  auto ffn1_w_shape = ffn1_weight.shape();
+  int expert_num = ffn1_w_shape[0];
+  int hidden_dim = input_shape[input_shape.size() - 1];
+  int inter_dim = ffn1_w_shape[1];
+  int outer_dim = inter_dim / 2;
+  bool is_padding_input = input_shape.size() == 3;
+  if (is_padding_input) {
+    PD_CHECK(input_shape[0] == expert_num);
+    PD_CHECK(token_num_info.numel() == expert_num,
+             "token_num_info.numel() != expert_num, "
+             "token_num_info.numel(): ",
+             token_num_info.numel(),
+             ", expert_num: ",
+             expert_num);
+  }
+
+  bool is_w4 = quant_method == "w4a8" || quant_method == "weight_only_int4";
+  auto xftblock_tx1 = xftblock::DataTypeToEnum<XPU_TX1>::value;
+  auto xftblock_tx2 = xftblock::DataTypeToEnum<XPU_TX2>::value;
+  auto xftblock_tw = xftblock::DataTypeToEnum<XPU_TW>::value;
+  if (is_w4) {
+    xftblock_tw = xftblock::DataTypeToEnum<int4_t>::value;
+  }
+  float* ffn1_act_scale_data =
+      ffn1_act_scale.get_ptr() == nullptr
+          ? nullptr
+          : const_cast<float*>(ffn1_act_scale.get_ptr()->data<float>());
+  float* ffn2_act_scale_data =
+      ffn2_act_scale.get_ptr() == nullptr
+          ? nullptr
+          : const_cast<float*>(ffn2_act_scale.get_ptr()->data<float>());
+  float* ffn1_w_scale_data =
+      ffn1_weight_scale.get_ptr() == nullptr
+          ? nullptr
+          : const_cast<float*>(ffn1_weight_scale.get_ptr()->data<float>());
+  xftblock::Tensor xffn1_w(const_cast<TW*>(ffn1_weight.data<TW>()),
+                           nullptr,
+                           ffn1_w_scale_data,
+                           xftblock_tw,
+                           {expert_num, inter_dim, hidden_dim});
+  float* ffn2_w_scale_data =
+      ffn2_weight_scale.get_ptr() == nullptr
+          ? nullptr
+          : const_cast<float*>(ffn2_weight_scale.get_ptr()->data<float>());
+  xftblock::Tensor xffn2_w(const_cast<TW*>(ffn2_weight.data<TW>()),
+                           nullptr,
+                           ffn2_w_scale_data,
+                           xftblock_tw,
+                           {expert_num, hidden_dim, outer_dim});
+  std::shared_ptr<xftblock::Tensor> xffn1_bias;
+  if (ffn1_bias.get_ptr()) {
+    xffn1_bias = std::make_shared<xftblock::Tensor>(
+        const_cast<float*>(ffn1_bias.get_ptr()->data<float>()),
+        xftblock::DataType::DT_FLOAT,
+        ffn1_bias.get_ptr()->shape());
+  }
+  std::shared_ptr<xftblock::Tensor> xffn2_bias;
+  if (ffn2_bias.get_ptr()) {
+    xffn2_bias = std::make_shared<xftblock::Tensor>(
+        const_cast<float*>(ffn2_bias.get_ptr()->data<float>()),
+        xftblock::DataType::DT_FLOAT,
+        ffn2_bias.get_ptr()->shape());
+  }
+  xftblock::Tensor xtoken_num_info(const_cast<int*>(token_num_info.data<int>()),
+                                   xftblock::DataType::DT_INT32,
+                                   token_num_info.shape());
+  XPU_TX2* shift_data = nullptr;
+  XPU_TX2* smooth_data = nullptr;
+  if (ffn2_shift.get_ptr()) {
+    shift_data = reinterpret_cast<XPU_TX2*>(
+        const_cast<TX2*>(ffn2_shift.get_ptr()->data<TX2>()));
+  }
+  if (ffn2_smooth.get_ptr()) {
+    smooth_data = reinterpret_cast<XPU_TX2*>(
+        const_cast<TX2*>(ffn2_smooth.get_ptr()->data<TX2>()));
+  }
+  paddle::Tensor ffn2_out =
+      paddle::empty_like(ffn_in, paddle::DataType::BFLOAT16);
+  xftblock::Tensor xffn1_in;
+  xftblock::Tensor xffn2_out;
+  paddle::Tensor ffn1_in_dense;
+  paddle::Tensor ffn1_in_scale_per_token;
+  if (FLAGS_MOE_FFN_USE_DENSE_INPUT && is_padding_input) {
+    convert_to_lod(&xctx, &xtoken_num_info);
+    if (quant_method == "w4a8") {
+      ffn1_in_scale_per_token = paddle::empty(
+          {valid_token_num}, paddle::DataType::FLOAT32, ffn_in.place());
+      ffn1_in_dense = paddle::empty({valid_token_num, hidden_dim},
+                                    paddle::DataType::INT8,
+                                    ffn_in.place());
+      xffn1_in = xftblock::Tensor(ffn1_in_dense.data<int8_t>(),
+                                  nullptr,
+                                  ffn1_in_scale_per_token.data<float>(),
+                                  xftblock::DataType::DT_INT8,
+                                  {valid_token_num, hidden_dim});
+      if (std::is_same<TX1, int8_t>::value) {
+        PD_CHECK(ffn1_act_scale_data != nullptr,
+                 "need ffn1_act_scale for x int8 per expert input");
+        ret = infer_ops::sequence_unpad(
+            xpu_ctx->x_context(),
+            ffn1_act_scale_data,
+            ffn1_in_scale_per_token.data<float>(),
+            xtoken_num_info.data<int>(),
+            expert_num,
+            input_shape[1],
+            1,
+            true);
+        PD_CHECK(ret == 0);
+        ret = infer_ops::sequence_unpad(
+            xpu_ctx->x_context(),
+            reinterpret_cast<const int8_t*>(ffn_in.data<TX1>()),
+            reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
+            xtoken_num_info.data<int>(),
+            expert_num,
+            input_shape[1],
+            input_shape[2],
+            true);
+        PD_CHECK(ret == 0);
+      } else {
+        ret = infer_ops::quant2d_per_expert(
+            xpu_ctx->x_context(),
+            reinterpret_cast<const XPU_TX1*>(ffn_in.data<TX1>()),
+            ffn1_act_scale_data,
+            xtoken_num_info.data<int>(),
+            reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
+            ffn1_in_scale_per_token.data<float>(),
+            expert_num,
+            valid_token_num,
+            hidden_dim,
+            true,
+            false,
+            input_shape[1]);
+        PD_CHECK(ret == 0);
+      }
+    } else {
+      ffn1_in_dense = paddle::empty(
+          {valid_token_num, hidden_dim}, ffn_in.dtype(), ffn_in.place());
+      xffn1_in = xftblock::Tensor(ffn1_in_dense.data<TX1>(),
+                                  nullptr,
+                                  ffn1_act_scale_data,
+                                  xftblock_tx1,
+                                  {valid_token_num, hidden_dim});
+      ret = infer_ops::sequence_unpad(
+          xpu_ctx->x_context(),
+          reinterpret_cast<const XPU_TX1*>(ffn_in.data<TX1>()),
+          reinterpret_cast<XPU_TX1*>(xffn1_in.data<XPU_TX1>()),
+          xtoken_num_info.data<int>(),
+          expert_num,
+          input_shape[1],
+          input_shape[2],
+          true);
+      PD_CHECK(ret == 0);
+    }
+    xffn2_out =
+        xftblock::Tensor(rt_guard, xftblock_tx2, {valid_token_num, hidden_dim});
+  } else if (FLAGS_BKCL_DISPATCH_ALL_GATHER && !is_padding_input &&
+             quant_method == "w4a8") {
+    convert_to_lod(&xctx, &xtoken_num_info);
+    ffn1_in_scale_per_token = paddle::empty(
+        {valid_token_num}, paddle::DataType::FLOAT32, ffn_in.place());
+    ffn1_in_dense = paddle::empty(
+        {valid_token_num, hidden_dim}, paddle::DataType::INT8, ffn_in.place());
+    xffn1_in = xftblock::Tensor(ffn1_in_dense.data<int8_t>(),
+                                nullptr,
+                                ffn1_in_scale_per_token.data<float>(),
+                                xftblock::DataType::DT_INT8,
+                                {valid_token_num, hidden_dim});
+    ret = infer_ops::quant2d_per_expert(
+        xpu_ctx->x_context(),
+        reinterpret_cast<const XPU_TX1*>(ffn_in.data<TX1>()),
+        ffn1_act_scale_data,
+        xtoken_num_info.data<int>(),
+        reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
+        ffn1_in_scale_per_token.data<float>(),
+        expert_num,
+        valid_token_num,
+        hidden_dim);
+    PD_CHECK(ret == 0);
+    xffn2_out =
+        xftblock::Tensor(ffn2_out.data<paddle::bfloat16>(), xftblock_tx2, input_shape);
+  } else {
+    xffn1_in = xftblock::Tensor(const_cast<TX1*>(ffn_in.data<TX1>()),
+                                nullptr,
+                                ffn1_act_scale_data,
+                                xftblock_tx1,
+                                input_shape);
+    xffn2_out = xftblock::Tensor(
+        ffn2_out.mutable_data<paddle::bfloat16>(), xftblock_tx2, input_shape);
+  }
+
+#define FFN_IMPL(TX1, TX2, TW, TGEMM)                             \
+  MoeExpertFFNImpl<TX1, TX2, TW, TGEMM>(&xffn1_in,                \
+                                        &xtoken_num_info,         \
+                                        &xffn1_w,                 \
+                                        &xffn2_w,                 \
+                                        xffn1_bias.get(),         \
+                                        xffn2_bias.get(),         \
+                                        &xffn2_out,               \
+                                        ffn2_act_scale_data,      \
+                                        shift_data,               \
+                                        smooth_data,              \
+                                        hadamard_blocksize)
+  if (quant_method == "weight_only_int8") {
+    FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, float);
+  } else if (quant_method == "weight_only_int4") {
+    FFN_IMPL(XPU_TX1, XPU_TX2, int4_t, int4_wo_int15);
+  } else if (quant_method == "w4a8") {
+    if (FLAGS_MOE_FFN_USE_DENSE_INPUT && is_padding_input) {
+      FFN_IMPL(int8_t, XPU_TX2, int4_t, int4_wo_int8);
+    } else if (FLAGS_BKCL_DISPATCH_ALL_GATHER && !is_padding_input) {
+      FFN_IMPL(int8_t, XPU_TX2, int4_t, int4_wo_int8);
+    } else {
+      FFN_IMPL(XPU_TX1, XPU_TX2, int4_t, int4_wo_int8);
+    }
+  } else {
+    FFN_IMPL(XPU_TX1, XPU_TX2, XPU_TW, float);
+  }
+#undef FFN_IMPL
+  if (FLAGS_MOE_FFN_USE_DENSE_INPUT && is_padding_input) {
+    ret = infer_ops::sequence_pad(
+        xpu_ctx->x_context(),
+        const_cast<XPU_TX2*>(xffn2_out.data<XPU_TX2>()),
+        reinterpret_cast<XPU_TX2*>(ffn2_out.data<paddle::bfloat16>()),
+        xtoken_num_info.data<int>(),
+        input_shape[0],
+        input_shape[1],
+        input_shape[2],
+        false,
+        0);
+    PD_CHECK(ret == 0);
+  }
+
+  return {ffn2_out};
+}
+
+std::vector<paddle::Tensor> MoeExpertFFN(
+    const paddle::Tensor& ffn_in,
+    const paddle::Tensor& token_num_info,
+    const paddle::Tensor& ffn1_weight,
+    const paddle::Tensor& ffn2_weight,
+    const paddle::optional<paddle::Tensor>& ffn1_bias,
+    const paddle::optional<paddle::Tensor>& ffn2_bias,
+    const paddle::optional<paddle::Tensor>& ffn1_act_scale,
+    const paddle::optional<paddle::Tensor>& ffn2_act_scale,
+    const paddle::optional<paddle::Tensor>& ffn1_weight_scale,
+    const paddle::optional<paddle::Tensor>& ffn2_weight_scale,
+    const paddle::optional<paddle::Tensor>& ffn2_shift,
+    const paddle::optional<paddle::Tensor>& ffn2_smooth,
+    const std::string& quant_method,
+    const int hadamard_blocksize,
+    const int valid_token_num) {
+  const auto x_type = ffn_in.dtype();
+  const auto w_type = ffn1_weight.dtype();
+
+#define APPLY_FFN_KERNEL(TX1, TX2, TW)                    \
+  return MoeExpertFFNKernel<TX1, TX2, TW>(ffn_in,         \
+                                          token_num_info, \
+                                          ffn1_weight,    \
+                                          ffn2_weight,    \
+                                          ffn1_bias,      \
+                                          ffn2_bias,      \
+                                          ffn1_act_scale, \
+                                          ffn2_act_scale, \
+                                          ffn1_weight_scale, \
+                                          ffn2_weight_scale, \
+                                          ffn2_shift,     \
+                                          ffn2_smooth,    \
+                                          quant_method,   \
+                                          hadamard_blocksize, \
+                                          valid_token_num);
+  if (x_type == paddle::DataType::BFLOAT16 &&
+      w_type == paddle::DataType::BFLOAT16) {
+    APPLY_FFN_KERNEL(paddle::bfloat16, paddle::bfloat16, paddle::bfloat16);
+  } else if (x_type == paddle::DataType::BFLOAT16 &&
+             w_type == paddle::DataType::INT8) {
+    APPLY_FFN_KERNEL(paddle::bfloat16, paddle::bfloat16, int8_t);
+  } else if (x_type == paddle::DataType::INT8 &&
+             w_type == paddle::DataType::INT8) {
+    APPLY_FFN_KERNEL(int8_t, paddle::bfloat16, int8_t);
+  } else {
+    PD_THROW("MoeExpertFFN not support x_type=",
+             static_cast<int>(x_type),
+             ", w_type=",
+             static_cast<int>(w_type));
+    return {};
+  }
+#undef APPLY_FFN_KERNEL
+}
+
+std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
+    const std::vector<int64_t>& permute_input_shape,
+    const std::vector<int64_t>& token_num_info_shape,
+    const std::vector<int64_t>& ffn1_weight_shape,
+    const std::vector<int64_t>& ffn2_weight_shape,
+    const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
+    const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape,
+    const paddle::optional<std::vector<int64_t>>& ffn1_act_scale_shape,
+    const paddle::optional<std::vector<int64_t>>& ffn2_act_scale_shape,
+    const paddle::optional<std::vector<int64_t>>& ffn1_weight_scale_shape,
+    const paddle::optional<std::vector<int64_t>>& ffn2_weight_scale_shape,
+    const paddle::optional<std::vector<int64_t>>& ffn2_shift_shape,
+    const paddle::optional<std::vector<int64_t>>& ffn2_smooth_shape) {
+  return {permute_input_shape};
+}
+
+std::vector<paddle::DataType> MoeExpertFFNInferDtype(
+    const paddle::DataType& permute_input_dtype,
+    const paddle::DataType& token_num_info_dtype,
+    const paddle::DataType& ffn1_weight_dtype,
+    const paddle::DataType& ffn2_weight_dtype,
+    const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
+    const paddle::optional<paddle::DataType>& ffn2_bias_dtype,
+    const paddle::optional<paddle::DataType>& ffn1_act_scale_dtype,
+    const paddle::optional<paddle::DataType>& ffn2_act_scale_dtype,
+    const paddle::optional<paddle::DataType>& ffn1_weight_scale_dtype,
+    const paddle::optional<paddle::DataType>& ffn2_weight_scale_dtype,
+    const paddle::optional<paddle::DataType>& ffn2_shift_dtype,
+    const paddle::optional<paddle::DataType>& ffn2_smooth_dtype) {
+  if (permute_input_dtype == paddle::DataType::INT8) {
+    return {paddle::DataType::BFLOAT16};
+  } else {
+    return {permute_input_dtype};
+  }
+}
+
+PD_BUILD_STATIC_OP(moe_expert_ffn)
+    .Inputs({"ffn_in",
+             "token_num_info",
+             "ffn1_weight",
+             "ffn2_weight",
+             paddle::Optional("ffn1_bias"),
+             paddle::Optional("ffn2_bias"),
+             paddle::Optional("ffn1_act_scale"),
+             paddle::Optional("ffn2_act_scale"),
+             paddle::Optional("ffn1_weight_scale"),
+             paddle::Optional("ffn2_weight_scale"),
+             paddle::Optional("ffn2_shift"),
+             paddle::Optional("ffn2_smooth")})
+    .Outputs({"ffn_out"})
+    .Attrs({"quant_method:std::string",
+            "hadamard_blocksize:int",
+            "valid_token_num:int"})
+    .SetKernelFn(PD_KERNEL(MoeExpertFFN))
+    .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
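
Note: taken together with `ep_moe_expert_dispatch` and `ep_moe_expert_combine`, the intended flow is dispatch → per-expert FFN → combine. A rough sketch under the earlier import assumptions, with the weight layouts taken from the checks in `MoeExpertFFNKernel` (`ffn1_weight` is `[expert_num, inter_dim, hidden_dim]` with gate/up fused so the swiglu halves `inter_dim`; `ffn2_weight` is `[expert_num, hidden_dim, inter_dim / 2]`; weights and scales are assumed to be prepared elsewhere, e.g. by a weight-quantize step):

```python
from fastdeploy_ops import moe_expert_ffn, ep_moe_expert_combine  # assumed

# Continuing the dispatch sketch above:
ffn_out = moe_expert_ffn(
    permute_input,            # tokens sorted into per-expert order
    cumsum,                   # per-expert token counts (cumulative)
    ffn1_weight, ffn2_weight, # hypothetical pre-quantized expert weights
    None, None,               # ffn1_bias / ffn2_bias
    None, None,               # ffn1_act_scale / ffn2_act_scale
    ffn1_weight_scale, ffn2_weight_scale,
    None, None,               # ffn2_shift / ffn2_smooth
    "weight_only_int8",       # quant_method
    0,                        # hadamard_blocksize: <= 0 skips the Walsh transform
    int(permute_input.shape[0]))  # valid_token_num

out = ep_moe_expert_combine(
    ffn_out, permute_indices, dst_weights,
    token_num, int(permute_input.shape[0]), hidden_dim, topk)
```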
diff --git a/custom_ops/xpu_ops/src/ops/moe_redundant_topk_select.cc b/custom_ops/xpu_ops/src/ops/moe_redundant_topk_select.cc
new file mode 100644
index 000000000..8f24d6e45
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/moe_redundant_topk_select.cc
@@ -0,0 +1,134 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <infer_ops.h>
+#include <vector>
+#include "paddle/extension.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "utility/debug.h"
+
+std::vector<paddle::Tensor> MoERedundantTopKSelect(
+    const paddle::Tensor& gating_logits,
+    const paddle::Tensor& expert_id_to_ep_rank_array,
+    const paddle::Tensor& expert_in_rank_num_list,
+    paddle::Tensor& tokens_per_expert_stats_list,  // NOLINT
+    const paddle::optional<paddle::Tensor>& bias,
+    const int moe_topk,
+    const bool apply_norm_weight,
+    const bool enable_softmax_top_k_fused,
+    const int redundant_ep_rank_num_plus_one) {
+  namespace api = baidu::xpu::api;
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+  api::Context* ctx = xpu_ctx->x_context();
+  if (gating_logits.is_cpu()) {
+    ctx = new api::Context(api::kCPU);
+  }
+
+  PD_CHECK(apply_norm_weight, "only support apply_norm_weight==true");
+  PD_CHECK(enable_softmax_top_k_fused,
+           "only support enable_softmax_top_k_fused==true");
+  PD_CHECK(bias.get_ptr() != nullptr, "only support bias != nullptr");
+
+  auto gating_logits_dims = gating_logits.shape();
+  int expert_num = gating_logits_dims[gating_logits_dims.size() - 1];
+  int64_t token_num = 0;
+  if (gating_logits_dims.size() == 3) {
+    token_num = gating_logits_dims[0] * gating_logits_dims[1];
+  } else {
+    token_num = gating_logits_dims[0];
+  }
+  auto topk_ids = paddle::empty(
+      {token_num, moe_topk}, paddle::DataType::INT32, gating_logits.place());
+  auto topk_ids_tmp = paddle::empty(
+      {token_num, moe_topk}, paddle::DataType::INT32, gating_logits.place());
+  auto source_rows_tmp = paddle::empty(
+      {token_num, moe_topk}, paddle::DataType::INT32, gating_logits.place());
+  auto topk_weights = paddle::empty(
+      {token_num, moe_topk}, paddle::DataType::FLOAT32, gating_logits.place());
+
+  const float* bias_data =
+      bias.get_ptr() != nullptr ? bias.get_ptr()->data<float>() : nullptr;
+  int ret = infer_ops::moe_redundant_softmax_topk_normed(
+      ctx,
+      gating_logits.data<float>(),
+      bias_data,
+      expert_id_to_ep_rank_array.data<int>(),
+      expert_in_rank_num_list.data<int>(),
+      tokens_per_expert_stats_list.data<int>(),
+      topk_weights.data<float>(),
+      topk_ids.data<int>(),
+      topk_ids_tmp.data<int>(),
+      source_rows_tmp.data<int>(),
+      expert_num,
+      moe_topk,
+      token_num,
+      redundant_ep_rank_num_plus_one);
+  PD_CHECK(ret == 0);
+
+  return {topk_ids, topk_weights};
+}
+
+std::vector<std::vector<int64_t>> MoERedundantTopKSelectInferShape(
+    const std::vector<int64_t>& gating_logits_shape,
+    const std::vector<int64_t>& expert_id_to_ep_rank_array_shape,
+    const std::vector<int64_t>& expert_in_rank_num_list_shape,
+    const std::vector<int64_t>& tokens_per_expert_stats_list_shape,
+    const paddle::optional<std::vector<int64_t>>& bias_shape,
+    const int moe_topk,
+    const bool apply_norm_weight,
+    const bool enable_softmax_top_k_fused,
+    const int redundant_ep_rank_num_plus_one) {
+  int64_t token_rows = -1;
+  if (gating_logits_shape.size() == 3) {
+    token_rows = gating_logits_shape[0] * gating_logits_shape[1];
+  } else {
+    token_rows = gating_logits_shape[0];
+  }
+
+  std::vector<int64_t> topk_ids_shape = {token_rows, moe_topk};
+  std::vector<int64_t> topk_weights_shape = {token_rows, moe_topk};
+  return {topk_ids_shape, topk_weights_shape};
+}
+
+std::vector<paddle::DataType> MoERedundantTopKSelectInferDtype(
+    const paddle::DataType& gating_logits_dtype,
+    const paddle::DataType& expert_id_to_ep_rank_array_dtype,
+    const paddle::DataType& expert_in_rank_num_list_dtype,
+    const paddle::DataType& tokens_per_expert_stats_list_dtype,
+    const paddle::optional<paddle::DataType>& bias_type,
+    const int moe_topk,
+    const bool apply_norm_weight,
+    const bool enable_softmax_top_k_fused,
+    const int redundant_ep_rank_num_plus_one) {
+  return {paddle::DataType::INT32, paddle::DataType::FLOAT32};
+}
+
+PD_BUILD_OP(moe_redundant_topk_select)
+    .Inputs({"gating_logits",
+             "expert_id_to_ep_rank_array",
+             "expert_in_rank_num_list",
+             "tokens_per_expert_stats_list",
+             paddle::Optional("bias")})
+    .Outputs({"topk_ids", "topk_weights", "tokens_per_expert_stats_list_out"})
+    .Attrs({"moe_topk: int",
+            "apply_norm_weight: bool",
+            "enable_softmax_top_k_fused:bool",
+            "redundant_ep_rank_num_plus_one:int"})
+    .SetInplaceMap({{"tokens_per_expert_stats_list",
+                     "tokens_per_expert_stats_list_out"}})
+    .SetKernelFn(PD_KERNEL(MoERedundantTopKSelect))
+    .SetInferShapeFn(PD_INFER_SHAPE(MoERedundantTopKSelectInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(MoERedundantTopKSelectInferDtype));
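
Note: the redundant variant additionally takes the expert-to-rank mapping and updates `tokens_per_expert_stats_list` in place (see the `SetInplaceMap` above). A call sketch under the same import assumption; the mapping tensors are placeholders built by the EP scheduler:

```python
import paddle
from fastdeploy_ops import moe_redundant_topk_select  # assumed module name

token_num, expert_num = 16, 64
gating_logits = paddle.randn([token_num, expert_num], dtype="float32")
bias = paddle.zeros([expert_num], dtype="float32")  # required: bias must be set
expert_id_to_ep_rank_array = ...   # int32, built by the EP scheduler (elided)
expert_in_rank_num_list = ...      # int32 replica counts per logical expert
tokens_per_expert_stats_list = paddle.zeros([expert_num], dtype="int32")

topk_ids, topk_weights = moe_redundant_topk_select(
    gating_logits, expert_id_to_ep_rank_array, expert_in_rank_num_list,
    tokens_per_expert_stats_list, bias,
    8,      # moe_topk
    True,   # apply_norm_weight: kernel only supports True
    True,   # enable_softmax_top_k_fused: kernel only supports True
    2)      # redundant_ep_rank_num_plus_one
# tokens_per_expert_stats_list now holds updated per-expert token counts.
```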
diff --git a/custom_ops/xpu_ops/src/ops/moe_topk_select.cc b/custom_ops/xpu_ops/src/ops/moe_topk_select.cc
new file mode 100644
index 000000000..7f39e4482
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/moe_topk_select.cc
@@ -0,0 +1,84 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <infer_ops.h>
+#include <vector>
+#include "paddle/extension.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "utility/debug.h"
+
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
+std::vector<paddle::Tensor> MoeTopkSelect(
+    const paddle::Tensor& gating_logits,
+    const paddle::optional<paddle::Tensor>& bias,
+    const int moe_topk,
+    const bool apply_norm_weight) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+
+  PD_CHECK(apply_norm_weight, "only support apply_norm_weight==true");
+
+  auto gating_logits_dims = gating_logits.shape();
+  int token_num = gating_logits_dims[0];
+  int expert_num = gating_logits_dims[1];
+  auto topk_ids = paddle::empty(
+      {token_num, moe_topk}, paddle::DataType::INT32, gating_logits.place());
+  auto topk_weights = paddle::empty(
+      {token_num, moe_topk}, paddle::DataType::FLOAT32, gating_logits.place());
+  int32_t* block_statistic = nullptr;
+  const float* bias_data =
+      bias.get_ptr() != nullptr ? bias.get_ptr()->data<float>() : nullptr;
+  int ret = infer_ops::moe_softmax_topk_norm_fusion(
+      xpu_ctx->x_context(),
+      gating_logits.data<float>(),
+      topk_weights.mutable_data<float>(),
+      topk_ids.mutable_data<int>(),
+      block_statistic,
+      token_num,
+      expert_num,
+      moe_topk,
+      0,
+      bias_data);
+  PD_CHECK(ret == 0);
+
+  return {topk_ids, topk_weights};
+}
+
+std::vector<std::vector<int64_t>> MoeTopkSelectInferShape(
+    const std::vector<int64_t>& gating_logits_shape,
+    const std::vector<int64_t>& bias_shape,
+    const int moe_topk,
+    const bool apply_norm_weight) {
+  std::vector<int64_t> topk_ids_shape = {gating_logits_shape[0], moe_topk};
+  std::vector<int64_t> topk_weights_shape = {gating_logits_shape[0], moe_topk};
+  return {topk_ids_shape, topk_weights_shape};
+}
+
+std::vector<paddle::DataType> MoeTopkSelectInferDtype(
+    const paddle::DataType& gating_logits_dtype,
+    const paddle::DataType& bias_dtype) {
+  return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
+}
+
+PD_BUILD_STATIC_OP(moe_topk_select)
+    .Inputs({"gating_logits", paddle::Optional("bias")})
+    .Outputs({"topk_ids", "topk_weights"})
+    .Attrs({"moe_topk: int", "apply_norm_weight: bool"})
+    .SetKernelFn(PD_KERNEL(MoeTopkSelect))
+    .SetInferShapeFn(PD_INFER_SHAPE(MoeTopkSelectInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(MoeTopkSelectInferDtype));
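
Note: the non-redundant selector is the simplest of the new ops; a minimal sketch (import path assumed, `apply_norm_weight` must be `True` per the kernel check):

```python
import paddle
from fastdeploy_ops import moe_topk_select  # assumed module name

gating_logits = paddle.randn([16, 64], dtype="float32")  # [token_num, expert_num]
topk_ids, topk_weights = moe_topk_select(gating_logits, None, 8, True)
# topk_ids: int32 [16, 8]; topk_weights: float32 [16, 8], softmax-normalized
```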
diff --git a/custom_ops/xpu_ops/src/ops/msg_utils.h b/custom_ops/xpu_ops/src/ops/msg_utils.h
new file mode 100644
index 000000000..96cc32593
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/msg_utils.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <fcntl.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/ipc.h>
+#include <sys/msg.h>
+#include <sys/shm.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include "paddle/extension.h"
+
+#define MAX_BSZ 512
+
+struct msgdata {
+  long mtype;              // NOLINT
+  int mtext[MAX_BSZ + 2];  // stop_flag, bsz, tokens
+};
+
+struct msgdatakv {
+  long mtype;                  // NOLINT
+  int mtext[MAX_BSZ * 3 + 2];  // encoder_count, layer_id, bid- pair
+};
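
Note: reading `GetOutputKVSignal` (in `get_output.cc` above) together with these structs, the `msgdatakv` payload is a two-int header (`encoder_count`, plus a word the struct comment calls `layer_id`) followed by `encoder_count` int triples, which is why the reader copies `encoder_count * 3 + 2` ints. A hypothetical host-side decoder for a raw payload, assuming exactly that layout:

```python
import struct

def parse_kv_signal(mtext: bytes):
    """Decode a msgdatakv payload: [encoder_count, layer_id, 3 ints per entry]."""
    encoder_count, layer_id = struct.unpack_from("<2i", mtext, 0)
    triples = [struct.unpack_from("<3i", mtext, 8 + 12 * i)
               for i in range(encoder_count)]
    return encoder_count, layer_id, triples
```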
"pre_ids", diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/eagle_get_hidden_states.cc b/custom_ops/xpu_ops/src/ops/mtp/eagle_get_hidden_states.cc similarity index 97% rename from custom_ops/xpu_ops/src/ops/mtp_ops/eagle_get_hidden_states.cc rename to custom_ops/xpu_ops/src/ops/mtp/eagle_get_hidden_states.cc index b45c8febd..952d7464d 100644 --- a/custom_ops/xpu_ops/src/ops/mtp_ops/eagle_get_hidden_states.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/eagle_get_hidden_states.cc @@ -16,6 +16,10 @@ #include "paddle/extension.h" #include "xpu/plugin.h" +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + namespace api = baidu::xpu::api; std::vector EagleGetHiddenStates( const paddle::Tensor& input, @@ -102,7 +106,7 @@ std::vector EagleGetHiddenStates( } } -PD_BUILD_OP(eagle_get_hidden_states) +PD_BUILD_STATIC_OP(eagle_get_hidden_states) .Inputs({"input", "seq_lens_this_time", "seq_lens_encoder", diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/eagle_get_self_hidden_states.cc b/custom_ops/xpu_ops/src/ops/mtp/eagle_get_self_hidden_states.cc similarity index 96% rename from custom_ops/xpu_ops/src/ops/mtp_ops/eagle_get_self_hidden_states.cc rename to custom_ops/xpu_ops/src/ops/mtp/eagle_get_self_hidden_states.cc index 68d09662a..8fbd642a0 100644 --- a/custom_ops/xpu_ops/src/ops/mtp_ops/eagle_get_self_hidden_states.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/eagle_get_self_hidden_states.cc @@ -16,6 +16,10 @@ #include "paddle/extension.h" #include "xpu/plugin.h" +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + namespace api = baidu::xpu::api; std::vector EagleGetSelfHiddenStates( const paddle::Tensor& input, @@ -97,7 +101,7 @@ std::vector EagleGetSelfHiddenStates( } } -PD_BUILD_OP(eagle_get_self_hidden_states) +PD_BUILD_STATIC_OP(eagle_get_self_hidden_states) .Inputs( {"input", "last_seq_lens_this_time", "seq_lens_this_time", "step_idx"}) .Outputs({"out"}) diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/mtp_save_first_token.cc b/custom_ops/xpu_ops/src/ops/mtp/mtp_save_first_token.cc similarity index 100% rename from custom_ops/xpu_ops/src/ops/mtp_ops/mtp_save_first_token.cc rename to custom_ops/xpu_ops/src/ops/mtp/mtp_save_first_token.cc diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/mtp_step_paddle.cc b/custom_ops/xpu_ops/src/ops/mtp/mtp_step_paddle.cc similarity index 96% rename from custom_ops/xpu_ops/src/ops/mtp_ops/mtp_step_paddle.cc rename to custom_ops/xpu_ops/src/ops/mtp/mtp_step_paddle.cc index c7bf2d7a1..46a1c2008 100644 --- a/custom_ops/xpu_ops/src/ops/mtp_ops/mtp_step_paddle.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/mtp_step_paddle.cc @@ -17,6 +17,10 @@ #include "paddle/phi/core/enforce.h" #include "xpu/plugin.h" +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + namespace api = baidu::xpu::api; void MTPStepPaddle( const paddle::Tensor &base_model_stop_flags, @@ -64,7 +68,7 @@ void MTPStepPaddle( } } -PD_BUILD_OP(mtp_step_paddle) +PD_BUILD_STATIC_OP(mtp_step_paddle) .Inputs({"base_model_stop_flags", "stop_flags", "batch_drop", diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_clear_accept_nums.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_clear_accept_nums.cc similarity index 91% rename from custom_ops/xpu_ops/src/ops/mtp_ops/speculate_clear_accept_nums.cc rename to custom_ops/xpu_ops/src/ops/mtp/speculate_clear_accept_nums.cc index f47244169..f18a1503d 100644 --- a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_clear_accept_nums.cc +++ 
@@ -16,6 +16,10 @@
 #include "paddle/extension.h"
 #include "xpu/plugin.h"
 
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
 void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
                               const paddle::Tensor& seq_lens_decoder) {
   // printf("enter clear \n");
@@ -31,7 +35,7 @@ void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
   PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed.");
 }
 
-PD_BUILD_OP(speculate_clear_accept_nums)
+PD_BUILD_STATIC_OP(speculate_clear_accept_nums)
     .Inputs({"accept_num", "seq_lens_decoder"})
     .Outputs({"seq_lens_decoder_out"})
     .SetInplaceMap({{"seq_lens_decoder", "seq_lens_decoder_out"}})
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_output.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_output.cc
similarity index 100%
rename from custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_output.cc
rename to custom_ops/xpu_ops/src/ops/mtp/speculate_get_output.cc
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_output_padding_offset.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_output_padding_offset.cc
similarity index 95%
rename from custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_output_padding_offset.cc
rename to custom_ops/xpu_ops/src/ops/mtp/speculate_get_output_padding_offset.cc
index b29240a08..31d0e1fac 100644
--- a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_output_padding_offset.cc
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_output_padding_offset.cc
@@ -16,6 +16,10 @@
 #include "paddle/extension.h"
 #include "xpu/plugin.h"
 
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
 std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
     const paddle::Tensor& output_cum_offsets_tmp,
     const paddle::Tensor& out_token_num,
@@ -69,7 +73,7 @@ std::vector<paddle::DataType> SpeculateGetOutputPaddingOffsetInferDtype(
   return {output_cum_offsets_tmp_dtype, output_cum_offsets_tmp_dtype};
 }
 
-PD_BUILD_OP(speculate_get_output_padding_offset)
+PD_BUILD_STATIC_OP(speculate_get_output_padding_offset)
     .Inputs({"output_cum_offsets_tmp", "out_token_num", "seq_lens_output"})
     .Outputs({"output_padding_offset", "output_cum_offsets"})
     .Attrs({"max_seq_len: int"})
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_padding_offset.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset.cc
similarity index 97%
rename from custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_padding_offset.cc
rename to custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset.cc
index bd06ef2be..1cf14b810 100644
--- a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_padding_offset.cc
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset.cc
@@ -16,6 +16,10 @@
 #include "paddle/extension.h"
 #include "xpu/plugin.h"
 
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
 std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
     const paddle::Tensor& input_ids,
     const paddle::Tensor& draft_tokens,
@@ -110,7 +114,7 @@ std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
           seq_len_dtype};
 }
 
-PD_BUILD_OP(speculate_get_padding_offset)
+PD_BUILD_STATIC_OP(speculate_get_padding_offset)
     .Inputs({"input_ids",
              "draft_tokens",
             "cum_offsets",
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_seq_lens_output.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_seq_lens_output.cc
similarity index 94%
rename from custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_seq_lens_output.cc
rename to custom_ops/xpu_ops/src/ops/mtp/speculate_get_seq_lens_output.cc
index 3caf47696..2a34ac726 100644
--- a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_get_seq_lens_output.cc
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_seq_lens_output.cc
@@ -16,6 +16,10 @@
 #include "paddle/extension.h"
 #include "xpu/plugin.h"
 
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
 std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
     const paddle::Tensor& seq_lens_this_time,
     const paddle::Tensor& seq_lens_encoder,
@@ -61,7 +65,7 @@ std::vector<paddle::DataType> SpeculateGetSeqLensOutputInferDtype(
   return {seq_lens_this_time_dtype};
 }
 
-PD_BUILD_OP(speculate_get_seq_lens_output)
+PD_BUILD_STATIC_OP(speculate_get_seq_lens_output)
     .Inputs({"seq_lens_this_time", "seq_lens_encoder", "seq_lens_decoder"})
     .Outputs({"seq_lens_output"})
    .SetKernelFn(PD_KERNEL(SpeculateGetSeqLensOutput))
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_msg.h b/custom_ops/xpu_ops/src/ops/mtp/speculate_msg.h
similarity index 100%
rename from custom_ops/xpu_ops/src/ops/mtp_ops/speculate_msg.h
rename to custom_ops/xpu_ops/src/ops/mtp/speculate_msg.h
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_rebuild_append_padding.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_rebuild_append_padding.cc
similarity index 100%
rename from custom_ops/xpu_ops/src/ops/mtp_ops/speculate_rebuild_append_padding.cc
rename to custom_ops/xpu_ops/src/ops/mtp/speculate_rebuild_append_padding.cc
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_save_output.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_save_output.cc
similarity index 100%
rename from custom_ops/xpu_ops/src/ops/mtp_ops/speculate_save_output.cc
rename to custom_ops/xpu_ops/src/ops/mtp/speculate_save_output.cc
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_set_stop_value_multi_seqs.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_set_stop_value_multi_seqs.cc
similarity index 100%
rename from custom_ops/xpu_ops/src/ops/mtp_ops/speculate_set_stop_value_multi_seqs.cc
rename to custom_ops/xpu_ops/src/ops/mtp/speculate_set_stop_value_multi_seqs.cc
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_set_value_by_flags.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_set_value_by_flags.cc
similarity index 94%
rename from custom_ops/xpu_ops/src/ops/mtp_ops/speculate_set_value_by_flags.cc
rename to custom_ops/xpu_ops/src/ops/mtp/speculate_set_value_by_flags.cc
index 60843e88e..5cef0ba27 100644
--- a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_set_value_by_flags.cc
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_set_value_by_flags.cc
@@ -16,6 +16,10 @@
 #include "paddle/extension.h"
 #include "xpu/plugin.h"
 
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
 void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
                                     const paddle::Tensor &accept_tokens,
                                     const paddle::Tensor &accept_num,
@@ -53,7 +57,7 @@ void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
   PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed.");
 }
 
-PD_BUILD_OP(speculate_set_value_by_flags_and_idx)
+PD_BUILD_STATIC_OP(speculate_set_value_by_flags_and_idx)
     .Inputs({"pre_ids_all",
              "accept_tokens",
             "accept_num",
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_step_reschedule.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_reschedule.cc
similarity index 98%
rename from custom_ops/xpu_ops/src/ops/mtp_ops/speculate_step_reschedule.cc
rename to custom_ops/xpu_ops/src/ops/mtp/speculate_step_reschedule.cc
index fb150bebc..b2d254acb 100644
--- a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_step_reschedule.cc
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_reschedule.cc
@@ -17,6 +17,10 @@
 #include "speculate_msg.h"  // NOLINT
 #include "xpu/plugin.h"
 
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
 // To keep existing call sites unchanged, the input arguments are left as-is for now.
 void SpeculateStepSchedule(
     const paddle::Tensor &stop_flags,
@@ -150,7 +154,7 @@ void SpeculateStepSchedule(
   }
 }
 
-PD_BUILD_OP(speculate_step_reschedule)
+PD_BUILD_STATIC_OP(speculate_step_reschedule)
     .Inputs({"stop_flags",
              "seq_lens_this_time",
             "ori_seq_lens_encoder",
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_token_penalty_multi_scores.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_token_penalty_multi_scores.cc
similarity index 83%
rename from custom_ops/xpu_ops/src/ops/mtp_ops/speculate_token_penalty_multi_scores.cc
rename to custom_ops/xpu_ops/src/ops/mtp/speculate_token_penalty_multi_scores.cc
index 0ecd4e139..a3a5d4a73 100644
--- a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_token_penalty_multi_scores.cc
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_token_penalty_multi_scores.cc
@@ -17,20 +17,25 @@
 #include "paddle/phi/core/enforce.h"
 #include "xpu/plugin.h"
 
-void TokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
-                             const paddle::Tensor& logits,
-                             const paddle::Tensor& penalty_scores,
-                             const paddle::Tensor& frequency_scores,
-                             const paddle::Tensor& presence_scores,
-                             const paddle::Tensor& temperatures,
-                             const paddle::Tensor& bad_tokens,
-                             const paddle::Tensor& cur_len,
-                             const paddle::Tensor& min_len,
-                             const paddle::Tensor& eos_token_id,
-                             const paddle::Tensor& seq_lens_this_time,
-                             const paddle::Tensor& output_padding_offset,
-                             const paddle::Tensor& output_cum_offsets,
-                             const int max_seq_len) {
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
+void SpeculateTokenPenaltyMultiScores(
+    const paddle::Tensor& pre_ids,
+    const paddle::Tensor& logits,
+    const paddle::Tensor& penalty_scores,
+    const paddle::Tensor& frequency_scores,
+    const paddle::Tensor& presence_scores,
+    const paddle::Tensor& temperatures,
+    const paddle::Tensor& bad_tokens,
+    const paddle::Tensor& cur_len,
+    const paddle::Tensor& min_len,
+    const paddle::Tensor& eos_token_id,
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& output_padding_offset,
+    const paddle::Tensor& output_cum_offsets,
+    const int max_seq_len) {
   namespace api = baidu::xpu::api;
   phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
   auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
@@ -137,7 +142,7 @@ void TokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
   }
 }
 
-PD_BUILD_OP(speculate_get_token_penalty_multi_scores)
+PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
     .Inputs({"pre_ids",
              "logits",
             "penalty_scores",
@@ -154,4 +159,4 @@ PD_BUILD_OP(speculate_get_token_penalty_multi_scores)
     .Outputs({"logits_out"})
     .Attrs({"max_seq_len: int"})
     .SetInplaceMap({{"logits", "logits_out"}})
-    .SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores));
+    .SetKernelFn(PD_KERNEL(SpeculateTokenPenaltyMultiScores));
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_update_input_ids_cpu.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_update_input_ids_cpu.cc
similarity index 100%
rename from custom_ops/xpu_ops/src/ops/mtp_ops/speculate_update_input_ids_cpu.cc
rename to custom_ops/xpu_ops/src/ops/mtp/speculate_update_input_ids_cpu.cc
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_update_v3.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_update_v3.cc
similarity index 96%
rename from custom_ops/xpu_ops/src/ops/mtp_ops/speculate_update_v3.cc
rename to custom_ops/xpu_ops/src/ops/mtp/speculate_update_v3.cc
index 7d06582d9..f71159703 100644
--- a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_update_v3.cc
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_update_v3.cc
@@ -16,6 +16,10 @@
 #include "paddle/extension.h"
 #include "xpu/plugin.h"
 
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
 namespace api = baidu::xpu::api;
 
 void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
@@ -66,7 +70,7 @@ void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
   not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
 }
 
-PD_BUILD_OP(speculate_update_v3)
+PD_BUILD_STATIC_OP(speculate_update_v3)
     .Inputs({"seq_lens_encoder",
              "seq_lens_decoder",
            "not_need_stop",
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_verify.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc
similarity index 98%
rename from custom_ops/xpu_ops/src/ops/mtp_ops/speculate_verify.cc
rename to custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc
index 2316d5ad7..53b5b90dc 100644
--- a/custom_ops/xpu_ops/src/ops/mtp_ops/speculate_verify.cc
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc
@@ -17,10 +17,13 @@
 #include "paddle/common/flags.h"
 #include "paddle/extension.h"
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
-#include "ops/utility/debug.h"
 #include "xpu/internal/infra_op.h"
 #include "xpu/plugin.h"
 
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
 namespace api = baidu::xpu::api;
 
 void SpeculateVerify(const paddle::Tensor &accept_tokens,
@@ -221,7 +224,7 @@
   }
 }
 
-PD_BUILD_OP(speculate_verify)
+PD_BUILD_STATIC_OP(speculate_verify)
     .Inputs({"accept_tokens",
              "accept_num",
            "step_idx",
diff --git a/custom_ops/xpu_ops/src/ops/mtp_ops/top_p_candidates.cc b/custom_ops/xpu_ops/src/ops/mtp/top_p_candidates.cc
similarity index 97%
rename from custom_ops/xpu_ops/src/ops/mtp_ops/top_p_candidates.cc
rename to custom_ops/xpu_ops/src/ops/mtp/top_p_candidates.cc
index f5c47ce7d..e261c9912 100644
--- a/custom_ops/xpu_ops/src/ops/mtp_ops/top_p_candidates.cc
+++ b/custom_ops/xpu_ops/src/ops/mtp/top_p_candidates.cc
@@ -16,6 +16,10 @@
 #include "paddle/extension.h"
 #include "xpu/plugin.h"
 
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
 #define FIXED_TOPK_BASE(topk, ...) \
   case (topk): {                   \
     constexpr auto kTopK = topk;   \
@@ -149,7 +153,7 @@ std::vector<paddle::DataType> TopPCandidatesInferDtype(
   return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32};
 }
 
-PD_BUILD_OP(top_p_candidates)
+PD_BUILD_STATIC_OP(top_p_candidates)
     .Inputs({"probs", "top_p", "output_padding_offset"})
     .Outputs({"verify_scores", "verify_tokens", "actual_candidate_lens"})
     .Attrs({"candidates_len: int", "max_seq_len: int"})
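For orientation, a hedged usage sketch of top_p_candidates follows. It is illustrative only, reuses the fastdeploy_ops module name from pybind.cc, and the tensor shapes (in particular a no-padding batch with an all-zero output_padding_offset) are assumptions, not guarantees from this patch.

    # Hedged sketch: propose top-p candidate tokens for later verification.
    import paddle
    from fastdeploy_ops import top_p_candidates  # assumed module name

    token_num, vocab, candidates_len, max_seq_len = 4, 32000, 5, 1024
    probs = paddle.nn.functional.softmax(
        paddle.rand([token_num, vocab], dtype="float32"), axis=-1)
    top_p = paddle.full([token_num], 0.8, dtype="float32")            # assumed shape
    output_padding_offset = paddle.zeros([token_num], dtype="int32")  # no padding

    verify_scores, verify_tokens, actual_candidate_lens = top_p_candidates(
        probs, top_p, output_padding_offset, candidates_len, max_seq_len)
    # verify_tokens holds int64 candidate ids; actual_candidate_lens holds the
    # int32 per-token count of candidates that survived the top-p cutoff.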
diff --git a/custom_ops/xpu_ops/src/ops/open_shm_and_get_meta_signal.cc b/custom_ops/xpu_ops/src/ops/open_shm_and_get_meta_signal.cc
new file mode 100644
index 000000000..69449f7b2
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/open_shm_and_get_meta_signal.cc
@@ -0,0 +1,91 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ops/pybind/pybind.h"
+#include "ops/remote_cache_kv_ipc.h"
+#include "ops/utility/env.h"
+#include "paddle/extension.h"
+
+XPU_DECLARE_BOOL(fmt_write_cache_completed_signal, false);
+
+using cache_write_complete_signal_type =
+    RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data;
+
+paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank,
+                                           const bool keep_pd_step_flag) {
+  cache_write_complete_signal_type kv_signal_metadata;
+  const char *fmt_write_cache_completed_signal_str =
+      std::getenv("FLAGS_fmt_write_cache_completed_signal");
+  if (fmt_write_cache_completed_signal_str &&
+      (std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
+       std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
+    kv_signal_metadata =
+        RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
+            rank, keep_pd_step_flag);
+  }
+
+  auto kv_signal_metadata_out =
+      paddle::full({3}, -1, paddle::DataType::INT64, paddle::CPUPlace());
+  kv_signal_metadata_out.data<int64_t>()[0] =
+      static_cast<int64_t>(kv_signal_metadata.layer_id);
+  kv_signal_metadata_out.data<int64_t>()[1] =
+      reinterpret_cast<int64_t>(kv_signal_metadata.shm_ptr);
+  kv_signal_metadata_out.data<int64_t>()[2] =
+      static_cast<int64_t>(kv_signal_metadata.shm_fd);
+  return kv_signal_metadata_out;
+}
+
+void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
+                          const paddle::Tensor &seq_lens_this_time_tensor,
+                          const paddle::Tensor &seq_lens_decoder_tensor,
+                          const int rank,
+                          const int num_layers) {
+  if (FLAGS_fmt_write_cache_completed_signal) {
+    int real_bsz = seq_lens_this_time_tensor.dims()[0];
+    // GPU init, cp to cpu?
+    auto seq_lens_encoder_cpu =
+        seq_lens_encoder_tensor.copy_to(paddle::CPUPlace(), false);
+    auto seq_lens_decoder_cpu =
+        seq_lens_decoder_tensor.copy_to(paddle::CPUPlace(), false);
+    RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query.init(
+        seq_lens_encoder_cpu.data<int>(),
+        seq_lens_decoder_cpu.data<int>(),
+        rank,
+        num_layers,
+        real_bsz);
+  }
+}
+
+std::vector<paddle::Tensor> OpenShmAndGetMetaSignal(
+    const int rank, const bool keep_pd_step_flag) {
+  return {OpenShmAndGetMetaSignalFunc(rank, keep_pd_step_flag)};
+}
+
+std::vector<std::vector<int64_t>> OpenShmAndGetMetaSignalShape(
+    const int rank, const bool keep_pd_step_flag) {
+  return {{3}};
+}
+
+std::vector<paddle::DataType> OpenShmAndGetMetaSignalDtype(
+    const int rank, const bool keep_pd_step_flag) {
+  return {paddle::DataType::INT64};
+}
+
+PD_BUILD_OP(open_shm_and_get_meta_signal)
+    .Inputs({})
+    .Outputs({"kv_signal_metadata"})
+    .Attrs({"rank: int", "keep_pd_step_flag: bool"})
+    .SetKernelFn(PD_KERNEL(OpenShmAndGetMetaSignal))
+    .SetInferShapeFn(PD_INFER_SHAPE(OpenShmAndGetMetaSignalShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(OpenShmAndGetMetaSignalDtype));
diff --git a/custom_ops/xpu_ops/src/ops/pybind/alloc_cache_pinned.cc b/custom_ops/xpu_ops/src/ops/pybind/alloc_cache_pinned.cc
new file mode 100644
index 000000000..500cfbf43
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/pybind/alloc_cache_pinned.cc
@@ -0,0 +1,46 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/mman.h>          // NOLINT
+#include "cuda_runtime_api.h"  // NOLINT
+#include "paddle/extension.h"
+#include "xpu/runtime.h"
+#include "ops/pybind/pybind.h"
+
+void check_xpu_error(int error) {
+  if (error != XPU_SUCCESS) {
+    throw XPUError(error);
+  }
+}
+
+// Python-facing wrapper around xpu_host_alloc
+uintptr_t custom_xpu_host_alloc(size_t size, unsigned int flags) {
+  void* ptr = nullptr;
+  // check_xpu_error(xpu_host_alloc(&ptr, size, flags));
+  ptr = malloc(size);
+  PD_CHECK(ptr != nullptr);
+  PD_CHECK(mlock(ptr, size) == 0);
+  return reinterpret_cast<uintptr_t>(ptr);
+}
+
+// Python-facing wrapper around xpu_host_free
+void custom_xpu_host_free(uintptr_t ptr) {
+  check_xpu_error(xpu_host_free(reinterpret_cast<void*>(ptr)));
+}
+
+// Python-facing wrapper around cudaHostRegister: registers pageable memory as pinned
+void xpu_cuda_host_register(uintptr_t ptr, size_t size, unsigned int flags) {
+  cudaError_t e = cudaHostRegister(reinterpret_cast<void*>(ptr), size, flags);
+  PD_CHECK(e == cudaSuccess, cudaGetErrorString(e));
+}
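These three helpers are exported to Python in pybind.cc below. A hedged usage sketch follows, assuming the module builds as fastdeploy_ops; note that cuda_host_alloc currently backs the allocation with malloc plus mlock rather than xpu_host_alloc.

    # Hedged sketch: allocate, register, and release pinned host memory.
    from fastdeploy_ops import (cuda_host_alloc, cuda_host_free,
                                cuda_host_register)

    nbytes = 16 * 1024 * 1024
    ptr = cuda_host_alloc(nbytes)    # returns the address as a plain int
    cuda_host_register(ptr, nbytes)  # pin the pages for fast H2D/D2H copies
    # ... hand `ptr` to a cache-swap path that expects pinned host memory ...
    cuda_host_free(ptr)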
diff --git a/custom_ops/xpu_ops/src/ops/pybind/cachekv_signal_thread_worker.cc b/custom_ops/xpu_ops/src/ops/pybind/cachekv_signal_thread_worker.cc
new file mode 100644
index 000000000..f7a6b3160
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/pybind/cachekv_signal_thread_worker.cc
@@ -0,0 +1,111 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "ops/pybind/cachekv_signal_thread_worker.h"
+#include <chrono>
+#include "ops/remote_cache_kv_ipc.h"
+#include "ops/utility/env.h"
+XPU_DECLARE_BOOL(fmt_write_cache_completed_signal, false);
+CacheKvSignalThreadWorker::CacheKvSignalThreadWorker() : stop(false) {
+  xpu_stream_create(&write_cache_kv_stream);
+  int devid;
+  auto ret = xpu_current_device(&devid);
+  PD_CHECK(ret == 0, "xpu_current_device failed.");
+  auto func = [this, devid]() {
+    int old_dev;
+    xpu_current_device(&old_dev);
+    auto ret = xpu_set_device(devid);
+    PD_CHECK(ret == 0, "xpu_set_device failed.");
+    ret = cudaSetDevice(devid);
+    PD_CHECK(ret == 0, "cudaSetDevice failed.");
+
+    while (true) {
+      std::function<void()> task;
+      {
+        std::unique_lock<std::mutex> lock(write_mutex);
+        if (stop) return;
+        if (!signal_task_queue.empty()) {
+          task = std::move(signal_task_queue.front());
+          signal_task_queue.pop();
+        } else {
+          lock.unlock();
+          std::this_thread::sleep_for(std::chrono::microseconds(1));
+          continue;
+        }
+      }
+      task();  // run the task
+    }
+  };
+  worker_thread = std::thread(func);
+}
+
+void CacheKvSignalThreadWorker::push_signal_task(XPUEvent e1, void* meta_data) {
+  auto func = [this, e1, meta_data]() {
+    xpu_stream_wait_event(write_cache_kv_stream, e1);
+    xpu_wait(write_cache_kv_stream);
+    RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise(meta_data);
+    xpu_event_destroy(e1);
+  };
+  std::lock_guard<std::mutex> lock(write_mutex);
+  signal_task_queue.push(func);
+}
+
+void CacheKvSignalThreadWorker::push_signal_task_per_query(XPUEvent e1,
+                                                           void* meta_data) {
+  auto func = [this, e1, meta_data]() {
+    xpu_stream_wait_event(write_cache_kv_stream, e1);
+    xpu_wait(write_cache_kv_stream);
+    RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_per_query(
+        meta_data);
+    xpu_event_destroy(e1);
+  };
+  std::lock_guard<std::mutex> lock(write_mutex);
+  signal_task_queue.push(func);
+}
+
+void CacheKvSignalThreadWorker::sync_all_signals() {
+  {
+    std::unique_lock<std::mutex> lock(write_mutex);
+    while (!signal_task_queue.empty()) {
+      // sleep for 1 microsecond while the queue drains
+      lock.unlock();
+      std::this_thread::sleep_for(std::chrono::microseconds(1));
+      lock.lock();
+    }
+    stop = true;
+  }
+  worker_thread.join();
+  xpu_stream_destroy(write_cache_kv_stream);
+}
+
+paddle::Tensor create_cachekv_signal_thread() {
+  CacheKvSignalThreadWorker* worker = nullptr;
+  if (FLAGS_fmt_write_cache_completed_signal) {
+    worker = new CacheKvSignalThreadWorker();
+  }
+  auto t = paddle::full({1}, 0, paddle::DataType::INT64, paddle::CPUPlace());
+  t.data<int64_t>()[0] = reinterpret_cast<int64_t>(worker);
+  return t;
+}
+void destroy_cachekv_signal_thread(const paddle::Tensor& t) {
+  auto worker =
+      reinterpret_cast<CacheKvSignalThreadWorker*>(t.data<int64_t>()[0]);
+  if (FLAGS_fmt_write_cache_completed_signal) {
+    PD_CHECK(worker != nullptr, "cachekv_signal_thread should not be nullptr");
+    worker->sync_all_signals();
+    delete worker;
+  } else {
+    PD_CHECK(worker == nullptr,
+             "cachekv_signal_thread should be nullptr if not pd split");
+  }
+}
diff --git a/custom_ops/xpu_ops/src/ops/pybind/cachekv_signal_thread_worker.h b/custom_ops/xpu_ops/src/ops/pybind/cachekv_signal_thread_worker.h
new file mode 100644
index 000000000..e05552c1f
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/pybind/cachekv_signal_thread_worker.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include <functional>
+#include <mutex>
+#include <queue>
+#include <thread>
+#include "paddle/extension.h"
+#include "xpu/runtime.h"
+
+struct CacheKvSignalThreadWorker {
+  CacheKvSignalThreadWorker();
+  void push_signal_task(XPUEvent e1, void* meta_data);
+  void push_signal_task_per_query(XPUEvent e1, void* meta_data);
+  void sync_all_signals();
+  std::thread worker_thread;
+  std::queue<std::function<void()>> signal_task_queue;
+  std::mutex write_mutex;
+  XPUStream write_cache_kv_stream;
+  bool stop;
+};
+
+paddle::Tensor create_cachekv_signal_thread();
+void destroy_cachekv_signal_thread(const paddle::Tensor& t);
diff --git a/custom_ops/xpu_ops/src/ops/pybind/get_peermem_addr.cc b/custom_ops/xpu_ops/src/ops/pybind/get_peermem_addr.cc
new file mode 100644
index 000000000..3fb5038a5
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/pybind/get_peermem_addr.cc
@@ -0,0 +1,26 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cuda_runtime_api.h"  // NOLINT
+#include "paddle/extension.h"
+#include "xpu/runtime.h"
+
+uintptr_t xpu_get_peer_mem_addr(uintptr_t ptr) {
+  struct cudaPointerAttributes pointerAttr;
+  cudaPointerGetAttributes(&pointerAttr, reinterpret_cast<void *>(ptr));
+  PD_CHECK(pointerAttr.hostPointer != nullptr,
+           "Failed to get host pointer from device pointer");
+  uintptr_t ptr_out = reinterpret_cast<uintptr_t>(pointerAttr.hostPointer);
+  return ptr_out;
+}
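The worker defined in cachekv_signal_thread_worker.cc above is driven from Python through the create/destroy pair bound in pybind.cc. A hedged lifecycle sketch, assuming the fastdeploy_ops module name and that the fmt_write_cache_completed_signal flag is enabled in the environment:

    # Hedged sketch: lifecycle of the cache-KV signal worker thread.
    from fastdeploy_ops import create_kv_signal_sender, destroy_kv_signal_sender

    sender = create_kv_signal_sender()  # int64 CPU tensor holding the worker pointer
    # ... attention layers enqueue "KV written" signal tasks on this worker ...
    destroy_kv_signal_sender(sender)    # drains the queue, joins the thread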
diff --git a/custom_ops/xpu_ops/src/ops/pybind/profiler.cc b/custom_ops/xpu_ops/src/ops/pybind/profiler.cc
new file mode 100644
index 000000000..1410ef825
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/pybind/profiler.cc
@@ -0,0 +1,26 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/extension.h"
+#include "xpu/runtime.h"
+
+void prof_start() {
+  int ret = xpu_profiler_start();
+  PD_CHECK(ret == 0, "xpu_profiler_start error");
+}
+
+void prof_stop() {
+  int ret = xpu_profiler_stop();
+  PD_CHECK(ret == 0, "xpu_profiler_stop error");
+}
diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
new file mode 100644
index 000000000..de81c49bf
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
@@ -0,0 +1,704 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ops/pybind/pybind.h"
+#include <pybind11/pybind11.h>
+#include "cuda_runtime_api.h"  // NOLINT
+#include "paddle/extension.h"
+
+namespace py = pybind11;
+
+uintptr_t custom_xpu_host_alloc(size_t size, unsigned int flags);
+
+void custom_xpu_host_free(uintptr_t ptr);
+
+uintptr_t xpu_get_peer_mem_addr(uintptr_t ptr);
+
+void xpu_cuda_host_register(uintptr_t ptr,
+                            size_t size,
+                            unsigned int flags = cudaHostRegisterDefault);
+
+void prof_start();
+
+void prof_stop();
+
+void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
+                          const paddle::Tensor &seq_lens_this_time_tensor,
+                          const paddle::Tensor &seq_lens_decoder_tensor,
+                          const int rank,
+                          const int num_layers);
+
+void GetOutputKVSignal(const paddle::Tensor &x,
+                       int64_t rank_id,
+                       bool wait_flag);
+
+std::vector<paddle::Tensor> MoERedundantTopKSelect(
+    const paddle::Tensor& gating_logits,
+    const paddle::Tensor& expert_id_to_ep_rank_array,
+    const paddle::Tensor& expert_in_rank_num_list,
+    paddle::Tensor& tokens_per_expert_stats_list,  // NOLINT
+    const paddle::optional<paddle::Tensor>& bias,
+    const int moe_topk,
+    const bool apply_norm_weight,
+    const bool enable_softmax_top_k_fused,
+    const int redundant_ep_rank_num_plus_one);
+
+void set_ncluster(int num) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+  xpu_ctx->x_context()->set_ncluster(num);
+}
+
+std::vector<paddle::Tensor> RmsNorm(
+    const paddle::Tensor& x,
+    const paddle::optional<paddle::Tensor>& bias,
+    const paddle::optional<paddle::Tensor>& residual,
+    const paddle::Tensor& norm_weight,
+    const paddle::optional<paddle::Tensor>& norm_bias,
+    const float epsilon,
+    const int begin_norm_axis,
+    const float quant_scale,
+    const int quant_round_type,
+    const float quant_max_bound,
+    const float quant_min_bound);
+
+std::vector<paddle::Tensor> WeightOnlyLinear(
+    const paddle::Tensor& x,
+    const paddle::Tensor& weight,
+    const paddle::Tensor& weight_scale,
+    const paddle::optional<paddle::Tensor>& bias,
+    const std::string& weight_dtype,
+    const int arch,
+    const int group_size);
+
+std::vector<paddle::Tensor> MoeEPCombine(const paddle::Tensor& ffn_out,
+                                         const paddle::Tensor& moe_index,
+                                         const paddle::Tensor& weights,
+                                         const int recv_token_num,
+                                         const int expand_token_num,
+                                         const int hidden_dim,
+                                         const int topk);
+
+std::vector<paddle::Tensor> EPMoeExpertDispatch(
+    const paddle::Tensor& input,
+    const paddle::Tensor& topk_ids,
+    const paddle::Tensor& topk_weights,
+    const paddle::optional<paddle::Tensor>& input_scales,
+    const std::vector<int>& token_nums_per_expert,
+    const int token_nums_this_rank,
+    const std::string quant_method);
+
+std::vector<paddle::Tensor> MoeExpertFFN(
+    const paddle::Tensor& ffn_in,
+    const paddle::Tensor& token_num_info,
+    const paddle::Tensor& ffn1_weight,
+    const paddle::Tensor& ffn2_weight,
+    const paddle::optional<paddle::Tensor>& ffn1_bias,
+    const paddle::optional<paddle::Tensor>& ffn2_bias,
+    const paddle::optional<paddle::Tensor>& ffn1_act_scale,
+    const paddle::optional<paddle::Tensor>& ffn2_act_scale,
+    const paddle::optional<paddle::Tensor>& ffn1_weight_scale,
+    const paddle::optional<paddle::Tensor>& ffn2_weight_scale,
+    const paddle::optional<paddle::Tensor>& ffn2_shift,
+    const paddle::optional<paddle::Tensor>& ffn2_smooth,
+    const std::string& quant_method,
+    const int hadamard_blocksize,
+    const int valid_token_num);
+
+std::vector<paddle::Tensor> MoeTopkSelect(
+    const paddle::Tensor& gating_logits,
+    const paddle::optional<paddle::Tensor>& bias,
+    const int moe_topk,
+    const bool apply_norm_weight);
+
+void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
+                      const paddle::Tensor& draft_tokens,
+                      const paddle::Tensor& pre_ids,
+                      const paddle::Tensor& seq_lens_this_time,
+                      const paddle::Tensor& seq_lens_encoder,
+                      const paddle::Tensor& seq_lens_decoder,
+                      const paddle::Tensor& step_idx,
+                      const paddle::Tensor& output_cum_offsets,
+                      const paddle::Tensor& stop_flags,
+                      const paddle::Tensor& not_need_stop,
+                      const paddle::Tensor& max_dec_len,
+                      const paddle::Tensor& end_ids,
+                      const paddle::Tensor& base_model_draft_tokens,
+                      const int max_seq_len,
+                      const int substep);
+
+void SpeculateUpdateV3(const paddle::Tensor& seq_lens_encoder,
+                       const paddle::Tensor& seq_lens_decoder,
+                       const paddle::Tensor& not_need_stop,
+                       const paddle::Tensor& draft_tokens,
+                       const paddle::Tensor& actual_draft_token_nums,
+                       const paddle::Tensor& accept_tokens,
+                       const paddle::Tensor& accept_num,
+                       const paddle::Tensor& stop_flags,
+                       const paddle::Tensor& seq_lens_this_time,
+                       const paddle::Tensor& is_block_step,
+                       const paddle::Tensor& stop_nums);
+
+void SpeculateTokenPenaltyMultiScores(
+    const paddle::Tensor& pre_ids,
+    const paddle::Tensor& logits,
+    const paddle::Tensor& penalty_scores,
+    const paddle::Tensor& frequency_scores,
+    const paddle::Tensor& presence_scores,
+    const paddle::Tensor& temperatures,
+    const paddle::Tensor& bad_tokens,
+    const paddle::Tensor& cur_len,
+    const paddle::Tensor& min_len,
+    const paddle::Tensor& eos_token_id,
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& output_padding_offset,
+    const paddle::Tensor& output_cum_offsets,
+    const int max_seq_len);
+
+std::vector<paddle::Tensor> TopPCandidates(
+    const paddle::Tensor& probs,
+    const paddle::Tensor& top_p,
+    const paddle::Tensor& output_padding_offset,
+    int candidates_len,
+    int max_seq_len);
+
+void SpeculateVerify(const paddle::Tensor& accept_tokens,
+                     const paddle::Tensor& accept_num,
+                     const paddle::Tensor& step_idx,
+                     const paddle::Tensor& stop_flags,
+                     const paddle::Tensor& seq_lens_encoder,
+                     const paddle::Tensor& seq_lens_decoder,
+                     const paddle::Tensor& draft_tokens,
+                     const paddle::Tensor& seq_lens_this_time,
+                     const paddle::Tensor& verify_tokens,
+                     const paddle::Tensor& verify_scores,
+                     const paddle::Tensor& max_dec_len,
+                     const paddle::Tensor& end_tokens,
+                     const paddle::Tensor& is_block_step,
+                     const paddle::Tensor& output_cum_offsets,
+                     const paddle::Tensor& actual_candidate_len,
+                     const paddle::Tensor& actual_draft_token_nums,
+                     const paddle::Tensor& topp,
+                     int max_seq_len,
+                     int verify_window,
+                     bool enable_topp);
+
+void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
+                              const paddle::Tensor& seq_lens_decoder);
+
+void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor& pre_ids_all,
+                                    const paddle::Tensor& accept_tokens,
+                                    const paddle::Tensor& accept_num,
+                                    const paddle::Tensor& stop_flags,
+                                    const paddle::Tensor& seq_lens_this_time,
+                                    const paddle::Tensor& seq_lens_encoder,
+                                    const paddle::Tensor& seq_lens_decoder,
+                                    const paddle::Tensor& step_idx);
+
+void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
+                          const paddle::Tensor& input_ids,
+                          const paddle::Tensor& stop_flags,
+                          const paddle::Tensor& seq_lens_this_time,
+                          const paddle::Tensor& seq_lens_encoder,
+                          const paddle::Tensor& seq_lens_decoder,
+                          const paddle::Tensor& step_idx,
+                          const paddle::Tensor& seq_lens_encoder_record,
+                          const paddle::Tensor& seq_lens_decoder_record,
+                          const paddle::Tensor& not_need_stop,
+                          const paddle::Tensor& batch_drop,
+                          const paddle::Tensor& accept_tokens,
+                          const paddle::Tensor& accept_num,
+                          const paddle::Tensor& base_model_seq_lens_encoder,
+                          const paddle::Tensor& base_model_seq_lens_decoder,
+                          const paddle::Tensor& base_model_step_idx,
+                          const paddle::Tensor& base_model_stop_flags,
+                          const paddle::Tensor& base_model_is_block_step,
+                          const paddle::Tensor& base_model_draft_tokens,
+                          const int max_draft_token,
+                          const bool truncate_first_token,
+                          const bool splitwise_prefill);
+
+void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
+                           const paddle::Tensor& base_model_seq_lens_this_time,
+                           const paddle::Tensor& base_model_seq_lens_encoder,
+                           const paddle::Tensor& base_model_stop_flags);
+
+std::vector<paddle::Tensor> EagleGetHiddenStates(
+    const paddle::Tensor& input,
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& seq_lens_encoder,
+    const paddle::Tensor& seq_lens_decoder,
+    const paddle::Tensor& stop_flags,
+    const paddle::Tensor& accept_nums,
+    const paddle::Tensor& base_model_seq_lens_this_time,
+    const paddle::Tensor& base_model_seq_lens_encoder,
+    const int actual_draft_token_num);
+
+std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
+    const paddle::Tensor& input,
+    const paddle::Tensor& last_seq_lens_this_time,
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& step_idx);
+
+std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
+    const paddle::Tensor& output_cum_offsets_tmp,
+    const paddle::Tensor& out_token_num,
+    const paddle::Tensor& seq_lens_output,
+    const int max_seq_len);
+
+std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
+    const paddle::Tensor& input_ids,
+    const paddle::Tensor& draft_tokens,
+    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& token_num,
+    const paddle::Tensor& seq_len,
+    const paddle::Tensor& seq_lens_encoder);
+
+void MTPStepPaddle(
+    const paddle::Tensor& base_model_stop_flags,
+    const paddle::Tensor& stop_flags,
+    const paddle::Tensor& batch_drop,
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& seq_lens_encoder,
+    const paddle::Tensor& seq_lens_decoder,
+    const paddle::Tensor& block_tables,  // [bsz, block_num_per_seq]
+    const paddle::Tensor& encoder_block_lens,
+    const paddle::Tensor& used_list_len,
+    const paddle::Tensor& free_list,
+    const paddle::Tensor& free_list_len,
+    const int block_size,
+    const int max_draft_tokens);
+
+void SpeculateStepSchedule(
+    const paddle::Tensor& stop_flags,
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& ori_seq_lens_encoder,
+    const paddle::Tensor& seq_lens_encoder,
+    const paddle::Tensor& seq_lens_decoder,
+    const paddle::Tensor& block_tables,  // [bsz, block_num_per_seq]
+    const paddle::Tensor& encoder_block_lens,
+    const paddle::Tensor& is_block_step,
+    const paddle::Tensor& step_block_list,
+    const paddle::Tensor& step_lens,
+    const paddle::Tensor& recover_block_list,
+    const paddle::Tensor& recover_lens,
+    const paddle::Tensor& need_block_list,
+    const paddle::Tensor& need_block_len,
+    const paddle::Tensor& used_list_len,
+    const paddle::Tensor& free_list,
+    const paddle::Tensor& free_list_len,
+    const paddle::Tensor& input_ids,
+    const paddle::Tensor& pre_ids,
+    const paddle::Tensor& step_idx,
+    const paddle::Tensor& next_tokens,
+    const paddle::Tensor& first_token_ids,
+    const paddle::Tensor& accept_num,
+    const int block_size,
+    const int encoder_decoder_block_num,
+    const int max_draft_tokens);
+
+std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& seq_lens_encoder,
+    const paddle::Tensor& seq_lens_decoder);
+
+PYBIND11_MODULE(fastdeploy_ops, m) {
+  m.def("cuda_host_alloc",
+        &custom_xpu_host_alloc,
+        "Allocate pinned memory",
+        py::arg("size"),
+        py::arg("flags") = 0x00);
+  m.def("cuda_host_free",
+        &custom_xpu_host_free,
+        "Free pinned memory",
+        py::arg("ptr"));
+  m.def("get_peer_mem_addr",
+        &xpu_get_peer_mem_addr,
+        "Get Host memory address of device pointer",
+        py::arg("ptr"));
+  m.def("cuda_host_register",
+        &xpu_cuda_host_register,
+        "Register pinned memory",
+        py::arg("ptr"),
+        py::arg("size"),
+        py::arg("flags") = cudaHostRegisterDefault);
+  m.def("create_kv_signal_sender",
+        &create_cachekv_signal_thread,
+        "init write cache kv signal thread");
+  m.def("destroy_kv_signal_sender",
+        &destroy_cachekv_signal_thread,
+        "write cache kv signal thread exit");
+  m.def("prof_start", &prof_start, "prof_start");
+  m.def("prof_stop", &prof_stop, "prof_stop");
+  m.def("moe_redundant_topk_select",
+        &MoERedundantTopKSelect,
+        py::arg("gating_logits"),
+        py::arg("expert_id_to_ep_rank_array"),
+        py::arg("expert_in_rank_num_list"),
+        py::arg("tokens_per_expert_stats_list"),
+        py::arg("bias"),
+        py::arg("moe_topk"),
+        py::arg("apply_norm_weight"),
+        py::arg("enable_softmax_top_k_fused"),
+        py::arg("redundant_ep_rank_num_plus_one"),
+        "MoE redundant top-k select function");
+  m.def("set_ncluster", &set_ncluster, "set ncluster");
+
+  /**
+   * open_shm_and_get_meta_signal.cc
+   * InitKVSignalPerQuery
+   */
+  m.def("init_kv_signal_per_query",
+        &InitKVSignalPerQuery,
+        py::arg("seq_lens_encoder_tensor"),
+        py::arg("seq_lens_this_time_tensor"),
+        py::arg("seq_lens_decoder_tensor"),
+        py::arg("rank"),
+        py::arg("num_layers"),
+        "init_kv_signal_per_query function");
+
+  /**
+   * GetOutputKVSignal
+   */
+  m.def("get_output_kv_signal",
+        &GetOutputKVSignal,
+        py::arg("x"),
+        py::arg("rank_id"),
+        py::arg("wait_flag"),
+        "get_output_kv_signal function");
+
+  m.def("fused_rms_norm_xpu",
+        &RmsNorm,
+        "Fused RMS normalization for XPU",
+        py::arg("x"),                 // input tensor
+        py::arg("bias"),              // bias (optional)
+        py::arg("residual"),          // residual connection (optional)
+        py::arg("norm_weight"),       // normalization weight
+        py::arg("norm_bias"),         // normalization bias (optional)
+        py::arg("epsilon"),           // epsilon for numerical stability
+        py::arg("begin_norm_axis"),   // first axis to normalize over
+        py::arg("quant_scale"),       // quantization scale
+        py::arg("quant_round_type"),  // quantization rounding mode
+        py::arg("quant_max_bound"),   // quantization upper bound
+        py::arg("quant_min_bound")    // quantization lower bound
+  );
+
+  m.def("weight_only_linear_xpu",
+        &WeightOnlyLinear,
+        "Weight-only quantized linear layer",
+        py::arg("x"),
+        py::arg("weight"),
+        py::arg("weight_scale"),
+        py::arg("bias"),
+        py::arg("weight_dtype"),
+        py::arg("arch"),
+        py::arg("group_size"));
+
+  m.def("ep_moe_expert_combine",
+        &MoeEPCombine,
+        "MoE (Mixture of Experts) EP combine operation",
+        py::arg("ffn_out"),           // FFN output tensor [token_num, hidden_dim]
+        py::arg("moe_index"),         // MoE expert index tensor [token_num, topk]
+        py::arg("weights"),           // expert weight tensor [token_num, topk]
+        py::arg("recv_token_num"),    // number of tokens received (int)
+        py::arg("expand_token_num"),  // number of tokens after expansion (int)
+        py::arg("hidden_dim"),        // hidden dimension (int)
+        py::arg("topk")               // number of experts selected per token (int)
+  );
+
+  m.def("ep_moe_expert_dispatch",
+        &EPMoeExpertDispatch,
+        "EP MoE expert dispatch operation",
+        py::arg("input"),
+        py::arg("topk_ids"),
+        py::arg("topk_weights"),
+        py::arg("input_scales") = py::none(),
+        py::arg("token_nums_per_expert"),
+        py::arg("token_nums_this_rank"),
+        py::arg("quant_method"));
+
+  m.def("moe_expert_ffn",
+        &MoeExpertFFN,
+        "MoE expert feed-forward network with quantization support",
+        py::arg("ffn_in"),  // [valid_token_num, hidden_dim]
+        py::arg("token_num_info"),
+        py::arg("ffn1_weight"),
+        py::arg("ffn2_weight"),
+        py::arg("ffn1_bias") = py::none(),
+        py::arg("ffn2_bias") = py::none(),
+        py::arg("ffn1_act_scale") = py::none(),
+        py::arg("ffn2_act_scale") = py::none(),
+        py::arg("ffn1_weight_scale") = py::none(),
+        py::arg("ffn2_weight_scale") = py::none(),
+        py::arg("ffn2_shift") = py::none(),
+        py::arg("ffn2_smooth") = py::none(),
+        py::arg("quant_method"),
+        py::arg("hadamard_blocksize"),
+        py::arg("valid_token_num"));
+
+  m.def("moe_topk_select",
+        &MoeTopkSelect,
+        "MoE Top-k selection: selects top-k experts via gating logits",
+        py::arg("gating_logits"),
+        py::arg("bias") = py::none(),
+        py::arg("moe_topk"),
+        py::arg("apply_norm_weight"));
+
+  m.def("draft_model_update",
+        &DraftModelUpdate,
+        "Update draft model states during speculative decoding",
+        py::arg("inter_next_tokens"),        // intermediate next-token tensor
+        py::arg("draft_tokens"),             // draft-token tensor
+        py::arg("pre_ids"),                  // pre-ids tensor
+        py::arg("seq_lens_this_time"),       // sequence lengths of this step
+        py::arg("seq_lens_encoder"),         // encoder sequence lengths
+        py::arg("seq_lens_decoder"),         // decoder sequence lengths
+        py::arg("step_idx"),                 // step-index tensor
+        py::arg("output_cum_offsets"),       // output cumulative offsets
+        py::arg("stop_flags"),               // stop flags
+        py::arg("not_need_stop"),            // not-need-stop flag
+        py::arg("max_dec_len"),              // max decode lengths
+        py::arg("end_ids"),                  // end-id tensor
+        py::arg("base_model_draft_tokens"),  // base-model draft tokens
+        py::arg("max_seq_len"),              // max sequence length (int)
+        py::arg("substep")                   // substep index (int)
+  );
+
+  m.def("speculate_get_token_penalty_multi_scores",
+        &SpeculateTokenPenaltyMultiScores,
+        py::arg("pre_ids"),
+        py::arg("logits"),
+        py::arg("penalty_scores"),
+        py::arg("frequency_scores"),
+        py::arg("presence_scores"),
+        py::arg("temperatures"),
+        py::arg("bad_tokens"),
+        py::arg("cur_len"),
+        py::arg("min_len"),
+        py::arg("eos_token_id"),
+        py::arg("seq_lens_this_time"),
+        py::arg("output_padding_offset"),
+        py::arg("output_cum_offsets"),
+        py::arg("max_seq_len"),
+        "Applies token penalty with multiple scores");
+
+  m.def("speculate_update_v3",
+        &SpeculateUpdateV3,
+        py::arg("seq_lens_encoder"),
+        py::arg("seq_lens_decoder"),
+        py::arg("not_need_stop"),
+        py::arg("draft_tokens"),
+        py::arg("actual_draft_token_nums"),
+        py::arg("accept_tokens"),
+        py::arg("accept_num"),
+        py::arg("stop_flags"),
py::arg("seq_lens_this_time"), + py::arg("is_block_step"), + py::arg("stop_nums"), + "Update speculative decoding states (V3)"); + + m.def("top_p_candidates", + &TopPCandidates, + py::arg("probs"), + py::arg("top_p"), + py::arg("output_padding_offset"), + py::arg("candidates_len"), + py::arg("max_seq_len"), + "Generate top-p candidates based on probability distributions"); + + m.def("speculate_verify", + &SpeculateVerify, + py::arg("accept_tokens"), + py::arg("accept_num"), + py::arg("step_idx"), + py::arg("stop_flags"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("draft_tokens"), + py::arg("seq_lens_this_time"), + py::arg("verify_tokens"), + py::arg("verify_scores"), + py::arg("max_dec_len"), + py::arg("end_tokens"), + py::arg("is_block_step"), + py::arg("output_cum_offsets"), + py::arg("actual_candidate_len"), + py::arg("actual_draft_token_nums"), + py::arg("topp"), + py::arg("max_seq_len"), + py::arg("verify_window"), + py::arg("enable_topp"), + "Perform speculative verification for decoding"); + + m.def("speculate_clear_accept_nums", + &SpeculateClearAcceptNums, + py::arg("accept_num"), + py::arg("seq_lens_decoder"), + "Clear accept numbers based on decoder sequence lengths"); + + m.def("speculate_set_value_by_flags_and_idx", + &SpeculateSetValueByFlagsAndIdx, + py::arg("pre_ids_all"), + py::arg("accept_tokens"), + py::arg("accept_num"), + py::arg("stop_flags"), + py::arg("seq_lens_this_time"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("step_idx"), + "Set values based on flags and indices in speculative decoding"); + + m.def("draft_model_preprocess", + &DraftModelPreprocess, + py::arg("draft_tokens"), + py::arg("input_ids"), + py::arg("stop_flags"), + py::arg("seq_lens_this_time"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("step_idx"), + py::arg("seq_lens_encoder_record"), + py::arg("seq_lens_decoder_record"), + py::arg("not_need_stop"), + py::arg("batch_drop"), + py::arg("accept_tokens"), + py::arg("accept_num"), + py::arg("base_model_seq_lens_encoder"), + py::arg("base_model_seq_lens_decoder"), + py::arg("base_model_step_idx"), + py::arg("base_model_stop_flags"), + py::arg("base_model_is_block_step"), + py::arg("base_model_draft_tokens"), + py::arg("max_draft_token"), + py::arg("truncate_first_token"), + py::arg("splitwise_prefill"), + "Preprocess data for draft model in speculative decoding"); + + m.def("draft_model_postprocess", + &DraftModelPostprocess, + py::arg("base_model_draft_tokens"), + py::arg("base_model_seq_lens_this_time"), + py::arg("base_model_seq_lens_encoder"), + py::arg("base_model_stop_flags"), + "Postprocess data for draft model in speculative decoding"); + + m.def("eagle_get_hidden_states", + &EagleGetHiddenStates, + py::arg("input"), + py::arg("seq_lens_this_time"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("stop_flags"), + py::arg("accept_nums"), + py::arg("base_model_seq_lens_this_time"), + py::arg("base_model_seq_lens_encoder"), + py::arg("actual_draft_token_num"), + "Get draft model hidden states"); + + m.def("eagle_get_self_hidden_states", + &EagleGetSelfHiddenStates, + py::arg("input"), + py::arg("last_seq_lens_this_time"), + py::arg("seq_lens_this_time"), + py::arg("step_idx"), + "Rebuild draft model hidden states"); + + m.def("speculate_get_output_padding_offset", + &SpeculateGetOutputPaddingOffset, + py::arg("output_cum_offsets_tmp"), + py::arg("out_token_num"), + py::arg("seq_lens_output"), + py::arg("max_seq_len"), + "Get output 
padding offset"); + + m.def("speculate_get_padding_offset", + &SpeculateGetPaddingOffset, + py::arg("input_ids"), + py::arg("draft_tokens"), + py::arg("cum_offsets"), + py::arg("token_num"), + py::arg("seq_len"), + py::arg("seq_lens_encoder"), + "Get padding offset"); + + m.def("mtp_step_paddle", + &MTPStepPaddle, + py::arg("base_model_stop_flags"), + py::arg("stop_flags"), + py::arg("batch_drop"), + py::arg("seq_lens_this_time"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("block_tables"), // [bsz, block_num_per_seq] + py::arg("encoder_block_lens"), + py::arg("used_list_len"), + py::arg("free_list"), + py::arg("free_list_len"), + py::arg("block_size"), + py::arg("max_draft_tokens"), + "MTP step paddle"); + + m.def("speculate_step_reschedule", + &SpeculateStepSchedule, + py::arg("stop_flags"), + py::arg("seq_lens_this_time"), + py::arg("ori_seq_lens_encoder"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("block_tables"), + py::arg("encoder_block_lens"), + py::arg("is_block_step"), + py::arg("step_block_list"), + py::arg("step_lens"), + py::arg("recover_block_list"), + py::arg("recover_lens"), + py::arg("need_block_list"), + py::arg("need_block_len"), + py::arg("used_list_len"), + py::arg("free_list"), + py::arg("free_list_len"), + py::arg("input_ids"), + py::arg("pre_ids"), + py::arg("step_idx"), + py::arg("next_tokens"), + py::arg("first_token_ids"), + py::arg("accept_num"), + py::arg("block_size"), + py::arg("encoder_decoder_block_num"), + py::arg("max_draft_tokens"), + "Step reschedule"); + + m.def("speculate_get_seq_lens_output", + &SpeculateGetSeqLensOutput, + py::arg("seq_lens_this_time"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + "Get sequence lengths output"); + + // 添加XPU错误信息的异常处理类 + py::register_exception(m, "XPUError"); +} diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.h b/custom_ops/xpu_ops/src/ops/pybind/pybind.h new file mode 100644 index 000000000..d37e2f384 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.h @@ -0,0 +1,29 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +#include +#include +#include "ops/pybind/cachekv_signal_thread_worker.h" + +// 自定义异常类,用于处理XPU错误 +class XPUError : public std::exception { + public: + explicit XPUError(int error) : error_(error) {} + + const char *what() const noexcept override { return xpu_strerror(error_); } + + private: + int error_; +}; diff --git a/custom_ops/xpu_ops/src/ops/read_data_ipc.cc b/custom_ops/xpu_ops/src/ops/read_data_ipc.cc new file mode 100644 index 000000000..efc9dd006 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/read_data_ipc.cc @@ -0,0 +1,95 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
diff --git a/custom_ops/xpu_ops/src/ops/read_data_ipc.cc b/custom_ops/xpu_ops/src/ops/read_data_ipc.cc
new file mode 100644
index 000000000..efc9dd006
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/read_data_ipc.cc
@@ -0,0 +1,95 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include <string>
+#include "paddle/extension.h"
+#include "xpu/plugin.h"
+#include "xpu_multiprocess.h"  // NOLINT
+
+void ReadDataIpc(const paddle::Tensor &tmp_input, const std::string &shm_name) {
+  volatile shmStruct *shm = NULL;
+  sharedMemoryInfo info;
+  int ret = sharedMemoryOpen(shm_name.c_str(), sizeof(shmStruct), &info);
+  PD_CHECK(ret == 0, "sharedMemoryOpen failed");
+
+  shm = static_cast<volatile shmStruct *>(info.addr);
+  void *ptr = nullptr;
+#if XPURT_VERSION_MAJOR == 5
+  ret = xpu_ipc_open_memhandle(
+      &ptr, *(XPUIpcMemHandle *)&shm->memHandle, 0x01);  // NOLINT
+#elif XPURT_VERSION_MAJOR == 4
+  PD_THROW("kl2 not support prefix cache");
+#endif
+  PD_CHECK(ret == XPU_SUCCESS, "xpu_ipc_open_memhandle failed");
+  PD_CHECK(tmp_input.place().GetType() == phi::AllocationType::CPU);
+  // switch (tmp_input.dtype()) {
+  //   case paddle::DataType::FLOAT32:
+  //     ret = xpu_memcpy(const_cast<float *>(tmp_input.data<float>()),
+  //                      ptr,
+  //                      tmp_input.numel() * sizeof(float),
+  //                      XPUMemcpyKind::XPU_DEVICE_TO_HOST);
+  //     break;
+  //   case paddle::DataType::FLOAT16:
+  //     ret = xpu_memcpy(const_cast<phi::dtype::float16 *>(
+  //                          tmp_input.data<phi::dtype::float16>()),
+  //                      ptr,
+  //                      tmp_input.numel() * sizeof(phi::dtype::float16),
+  //                      XPUMemcpyKind::XPU_DEVICE_TO_HOST);
+  //     break;
+  //   case paddle::DataType::UINT8:
+  //     ret = xpu_memcpy(const_cast<uint8_t *>(tmp_input.data<uint8_t>()),
+  //                      ptr,
+  //                      tmp_input.numel() * sizeof(uint8_t),
+  //                      XPUMemcpyKind::XPU_DEVICE_TO_HOST);
+  //     break;
+  //   default:
+  //     PD_THROW("not support dtype: ",
+  //              phi::DataTypeToString(tmp_input.dtype()));
+  // }
+  // PD_CHECK(ret == XPU_SUCCESS, "not support dtype");
+  // ret = xpu_ipc_close_memhandle(ptr);
+  // PD_CHECK(ret == XPU_SUCCESS, "not support dtype");
+
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
+  void *data_ptr = reinterpret_cast<void *>(shm->data_ptr_addr);
+  auto x = paddle::from_blob(data_ptr,
+                             tmp_input.shape(),
+                             tmp_input.dtype(),
+                             tmp_input.layout(),
+                             place);
+  paddle::Tensor y = tmp_input.copy_to(place, false);
+  ret = baidu::xpu::api::scale(xpu_ctx->x_context(),
+                               x.data<float>(),
+                               y.data<float>(),
+                               tmp_input.numel(),
+                               true,
+                               1.f,
+                               2.f);
+  PD_CHECK(ret == XPU_SUCCESS, "add2 fail");
+  ret = xpu_memcpy(const_cast<float *>(tmp_input.data<float>()),
+                   y.data<float>(),
+                   tmp_input.numel() * sizeof(float),
+                   XPUMemcpyKind::XPU_DEVICE_TO_HOST);
+  PD_CHECK(ret == XPU_SUCCESS, "xpu_memcpy fail");
+
+  sharedMemoryClose(&info);
+}
+
+PD_BUILD_OP(read_data_ipc)
+    .Inputs({"tmp_input"})
+    .Attrs({"shm_name: std::string"})
+    .Outputs({"tmp_input_out"})
+    .SetInplaceMap({{"tmp_input", "tmp_input_out"}})
+    .SetKernelFn(PD_KERNEL(ReadDataIpc));
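read_data_ipc pairs with the set_data_ipc op added elsewhere in this PR: one process publishes a device buffer under a shared-memory name, another maps it back. A hedged consumer-side sketch follows, loosely modeled on the PR's test_set_get_data_ipc.py; the import path is hypothetical and depends on how the custom-op library is loaded.

    # Hedged sketch: consumer side of the IPC round trip. Assumes a producer
    # process already published a float32 device tensor of the same shape via
    # set_data_ipc under the name "test_shm".
    import numpy as np
    import paddle
    from fastdeploy_ops import read_data_ipc  # hypothetical import path

    tmp = paddle.to_tensor(np.zeros((8, 16), np.float32), place=paddle.CPUPlace())
    read_data_ipc(tmp, "test_shm")
    # tmp now holds the published data + 2.0 (the kernel's scale self-check).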
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ops/remote_cache_kv_ipc.h"
+#include <fcntl.h>
+#include <sys/mman.h>
+#include "paddle/extension.h"
+
+RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data
+    RemoteCacheKvIpc::kv_complete_signal_meta_data;
+RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data_per_query
+    RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query;
+void* RemoteCacheKvIpc::kv_complete_signal_identity_ptr = nullptr;
+bool RemoteCacheKvIpc::kv_complete_signal_shmem_opened = false;
+
+RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data
+RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
+    const int rank_id, const bool keep_pd_step_flag) {
+  if (RemoteCacheKvIpc::kv_complete_signal_shmem_opened) {
+    if (keep_pd_step_flag) {
+      return RemoteCacheKvIpc::kv_complete_signal_meta_data;
+    }
+    int32_t current_identity = (*reinterpret_cast<int32_t*>(
+        RemoteCacheKvIpc::kv_complete_signal_identity_ptr));
+    int32_t* write_ptr = reinterpret_cast<int32_t*>(
+        RemoteCacheKvIpc::kv_complete_signal_identity_ptr);
+    *write_ptr = (current_identity + 1) % 100003;
+    RemoteCacheKvIpc::kv_complete_signal_meta_data.layer_id = -1;
+    int32_t* layer_complete_ptr =
+        reinterpret_cast<int32_t*>(kv_complete_signal_meta_data.shm_ptr);
+    *layer_complete_ptr = -1;
+    return RemoteCacheKvIpc::kv_complete_signal_meta_data;
+  }
+  std::string flags_server_uuid;
+  if (const char* iflags_server_uuid_env_p = std::getenv("SHM_UUID")) {
+    std::string iflags_server_uuid_env_str(iflags_server_uuid_env_p);
+    flags_server_uuid = iflags_server_uuid_env_str;
+  }
+  std::string step_shm_name =
+      ("splitwise_complete_prefilled_step_" + std::to_string(rank_id) + "_" +
+       flags_server_uuid);
+  std::string layer_shm_name =
+      ("splitwise_complete_prefilled_layer_" + std::to_string(rank_id) + "_" +
+       flags_server_uuid);
+  if (const char* use_ep = std::getenv("ENABLE_EP_DP")) {
+    if (std::strcmp(use_ep, "1") == 0) {
+      step_shm_name = "splitwise_complete_prefilled_step_tprank0_dprank" +
+                      std::to_string(rank_id) + "_" + flags_server_uuid;
+      layer_shm_name = "splitwise_complete_prefilled_layer_tprank0_dprank" +
+                       std::to_string(rank_id) + "_" + flags_server_uuid;
+    }
+  }
+
+  int signal_shm_fd = shm_open(layer_shm_name.c_str(), O_CREAT | O_RDWR, 0666);
+  PD_CHECK(signal_shm_fd != -1,
+           "can not open shm for cache_kv_complete_signal.");
+  int signal_shm_ftruncate = ftruncate(signal_shm_fd, 4);
+  void* signal_ptr = mmap(0, 4, PROT_WRITE, MAP_SHARED, signal_shm_fd, 0);
+  PD_CHECK(signal_ptr != MAP_FAILED,
+           "can not open shm for cache_kv_complete_identity.");
+  int32_t* write_signal_ptr = reinterpret_cast<int32_t*>(signal_ptr);
+  *write_signal_ptr = -1;
+  using type_meta_data =
+      RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data;
+
+  // std::printf("#### open_shm_and_get_complete_signal_meta_data layer idx:%d,
+  // to ptx:%p \n",
+  //             -1, signal_ptr);
+
+  type_meta_data meta_data(-1, signal_ptr, signal_shm_fd);
+  RemoteCacheKvIpc::kv_complete_signal_meta_data = meta_data;
+  int identity_shm_fd =
+      shm_open(step_shm_name.c_str(), O_CREAT | O_RDWR, 0666);
+  PD_CHECK(identity_shm_fd != -1,
+           "can not open shm for cache_kv_complete_identity.");
+
+  int identity_shm_ftruncate = ftruncate(identity_shm_fd, 4);
+  void* identity_ptr = mmap(0, 4, PROT_WRITE, MAP_SHARED, identity_shm_fd, 0);
+  PD_CHECK(identity_ptr != MAP_FAILED, "MAP_FAILED for prefill_identity.");
+
+  int32_t current_identity = (*reinterpret_cast<int32_t*>(identity_ptr));
+  int32_t* write_ptr = reinterpret_cast<int32_t*>(identity_ptr);
+  *write_ptr = (current_identity + 1) % 100003;
+  RemoteCacheKvIpc::kv_complete_signal_identity_ptr = identity_ptr;
+  RemoteCacheKvIpc::kv_complete_signal_shmem_opened = true;
+  return meta_data;
+}
+
+void RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise(
+    void* meta_data) {
+  int64_t* meta_data_ptr = reinterpret_cast<int64_t*>(meta_data);
+  int32_t layer_id = meta_data_ptr[0];
+  int32_t* ptr = reinterpret_cast<int32_t*>(meta_data_ptr[1]);
+  *ptr = layer_id;
+  // std::printf("#### save_cache_kv_complete_signal_layerwise layer idx:%d, to
+  // ptx:%p \n",
+  //             *ptr, meta_data_ptr[1]);
+}
+
+void RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_per_query(
+    void* meta_data) {
+  RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query.send_signal();
+}
diff --git a/custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h b/custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h
new file mode 100644
index 000000000..1cc4531c6
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h
@@ -0,0 +1,98 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
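For orientation, a minimal sketch (Python) of what a consumer of the layerwise completion signal above could look like. The shm name follows open_shm_and_get_complete_signal_meta_data for rank 0 without ENABLE_EP_DP; the SHM_UUID environment variable and a Linux /dev/shm mount are assumptions:

    import mmap
    import os
    import struct

    # name convention from open_shm_and_get_complete_signal_meta_data above
    shm_name = "splitwise_complete_prefilled_layer_0_" + os.environ["SHM_UUID"]
    fd = os.open("/dev/shm/" + shm_name, os.O_RDONLY)
    buf = mmap.mmap(fd, 4, prot=mmap.PROT_READ)
    # the writer keeps a single int32 here: -1 until prefill starts, then the
    # id of the last layer whose cache KV has been fully written
    (last_complete_layer,) = struct.unpack("<i", buf[:4])
    buf.close()
    os.close(fd)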
+#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include "msg_utils.h" // NOLINT + +struct RemoteCacheKvIpc { + struct save_cache_kv_complete_signal_layerwise_meta_data { + int32_t layer_id = -1; + void* shm_ptr = nullptr; + int shm_fd = -1; + save_cache_kv_complete_signal_layerwise_meta_data() {} + save_cache_kv_complete_signal_layerwise_meta_data(int32_t layer_id_, + void* shm_ptr_, + int shm_fd_) + : layer_id(layer_id_), shm_ptr(shm_ptr_), shm_fd(shm_fd_) {} + }; + + struct save_cache_kv_complete_signal_layerwise_meta_data_per_query { + int layer_id_; + int num_layers_; + bool inited = false; + struct msgdatakv msg_sed; + int msgid; + + save_cache_kv_complete_signal_layerwise_meta_data_per_query() {} + + void init(const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int rank, + const int num_layers, + const int real_bsz) { + layer_id_ = 0; + num_layers_ = num_layers; + msg_sed.mtype = 1; + int encoder_count = 0; + for (int i = 0; i < real_bsz; i++) { + if (seq_lens_encoder[i] > 0) { + msg_sed.mtext[3 * encoder_count + 2] = i; + msg_sed.mtext[3 * encoder_count + 3] = seq_lens_decoder[i]; + msg_sed.mtext[3 * encoder_count + 4] = seq_lens_encoder[i]; + encoder_count++; + } + } + msg_sed.mtext[0] = encoder_count; + + if (!inited) { + // just init once + const int msg_id = 1024 + rank; + key_t key = ftok("/opt/", msg_id); + msgid = msgget(key, IPC_CREAT | 0666); + inited = true; + } + } + + void send_signal() { + msg_sed.mtext[1] = layer_id_; + if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) { + printf("kv signal full msg buffer\n"); + } + layer_id_ = (layer_id_ + 1); + assert(layer_id_ <= num_layers_); + } + }; + + static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data + kv_complete_signal_meta_data; + static RemoteCacheKvIpc:: + save_cache_kv_complete_signal_layerwise_meta_data_per_query + kv_complete_signal_meta_data_per_query; + static void* kv_complete_signal_identity_ptr; + static bool kv_complete_signal_shmem_opened; + + static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data + open_shm_and_get_complete_signal_meta_data(const int rank_id, + const bool keep_pd_step_flag); + static void save_cache_kv_complete_signal_layerwise(void* meta_data); + static void save_cache_kv_complete_signal_layerwise_per_query( + void* meta_data); +}; diff --git a/custom_ops/xpu_ops/src/ops/set_data_ipc.cc b/custom_ops/xpu_ops/src/ops/set_data_ipc.cc new file mode 100644 index 000000000..307c1dac9 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/set_data_ipc.cc @@ -0,0 +1,69 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/extension.h"
+#include "xpu_multiprocess.h"  // NOLINT
+
+template <typename T>
+void set_data_ipc(const paddle::Tensor &tmp_input,
+                  const std::string &shm_name) {
+  sharedMemoryInfo info;
+  volatile shmStruct *shm = NULL;
+  int ret = sharedMemoryCreate(shm_name.c_str(), sizeof(*shm), &info);
+  PD_CHECK(ret == 0, "sharedMemoryCreate failed");
+  shm = (volatile shmStruct *)info.addr;
+  memset((void *)shm, 0, sizeof(*shm));  // NOLINT
+
+  void *data_ptr_now =
+      reinterpret_cast<void *>(const_cast<T *>(tmp_input.data<T>()));
+#if XPURT_VERSION_MAJOR == 5
+  ret = xpu_ipc_get_memhandle((XPUIpcMemHandle *)&shm->memHandle,  // NOLINT
+                              data_ptr_now);
+#elif XPURT_VERSION_MAJOR == 4
+  PD_THROW("kl2 not support prefix cache");
+#endif
+  PD_CHECK(ret == XPU_SUCCESS, "xpu_ipc_get_memhandle failed");
+  shm->data_ptr_addr = reinterpret_cast<uint64_t>(data_ptr_now);
+}
+
+void SetDataIpc(const paddle::Tensor &tmp_input, const std::string &shm_name) {
+  switch (tmp_input.type()) {
+    case paddle::DataType::FLOAT16: {
+      return set_data_ipc<phi::dtype::float16>(tmp_input, shm_name);
+    }
+    case paddle::DataType::FLOAT32: {
+      return set_data_ipc<float>(tmp_input, shm_name);
+    }
+    case paddle::DataType::INT8: {
+      return set_data_ipc<int8_t>(tmp_input, shm_name);
+    }
+    case paddle::DataType::UINT8: {
+      return set_data_ipc<uint8_t>(tmp_input, shm_name);
+    }
+    case paddle::DataType::BFLOAT16: {
+      return set_data_ipc<phi::dtype::bfloat16>(tmp_input, shm_name);
+    }
+    default: {
+      PD_THROW("NOT supported data type.");
+      break;
+    }
+  }
+}
+
+PD_BUILD_OP(set_data_ipc)
+    .Inputs({"tmp_input"})
+    .Attrs({"shm_name: std::string"})
+    .Outputs({"tmp_input_out"})
+    .SetInplaceMap({{"tmp_input", "tmp_input_out"}})
+    .SetKernelFn(PD_KERNEL(SetDataIpc));
diff --git a/custom_ops/xpu_ops/src/ops/share_external_data.cc b/custom_ops/xpu_ops/src/ops/share_external_data.cc
new file mode 100644
index 000000000..f4f40bebb
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/share_external_data.cc
@@ -0,0 +1,57 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
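A sketch of how the set_data_ipc / read_data_ipc pair might be driven from Python, with two cooperating processes; the module path follows the tests added in this PR, and the shm name is illustrative:

    # producer process: publish an XPU tensor's IPC handle under "demo_shm"
    import paddle
    from fastdeploy.model_executor.ops.xpu import set_data_ipc

    src = paddle.full([16], 1.0, dtype="float32")  # lives on the XPU device
    set_data_ipc(src, "demo_shm")

    # consumer process: map the buffer and read it back into a host tensor
    # (read_data_ipc above applies scale(x, 1.0, 2.0) before the copy)
    import paddle
    from fastdeploy.model_executor.ops.xpu import read_data_ipc

    dst = paddle.zeros([16], dtype="float32").cpu()  # must be a CPU tensor
    read_data_ipc(dst, "demo_shm")  # dst now holds src * 1.0 + 2.0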
+
+#include <string>
+#include "paddle/extension.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/tensor_meta.h"
+#include "xpu/plugin.h"
+#include "xpu_multiprocess.h"  // NOLINT(build/include_subdir)
+
+std::vector<paddle::Tensor> ShareExternalData(const paddle::Tensor &input,
+                                              const std::string shm_name,
+                                              const std::vector<int64_t> &shape,
+                                              bool use_ipc) {
+  sharedMemoryInfo info;
+  int ret = sharedMemoryOpen(shm_name.c_str(), sizeof(shmStruct), &info);
+  PD_CHECK(ret == 0, "sharedMemoryOpen failed");
+  volatile shmStruct *shm = static_cast<volatile shmStruct *>(info.addr);
+  void *data_ptr_addr = nullptr;
+  if (use_ipc) {
+#if XPURT_VERSION_MAJOR == 5
+    int ret = xpu_ipc_open_memhandle(
+        &data_ptr_addr, *(XPUIpcMemHandle *)&shm->memHandle, 0x01);  // NOLINT
+    PD_CHECK(ret == XPU_SUCCESS, "xpu_ipc_open_memhandle failed");
+#elif XPURT_VERSION_MAJOR == 4
+    PD_THROW("kl2 not support prefix cache");
+#endif
+  } else {
+    data_ptr_addr = reinterpret_cast<void *>(shm->data_ptr_addr);
+  }
+
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  paddle::Tensor output = paddle::from_blob(
+      data_ptr_addr, shape, input.dtype(), input.layout(), place);
+
+  sharedMemoryClose(&info);
+  return {output};
+}
+
+PD_BUILD_OP(share_external_data)
+    .Inputs({"input"})
+    .Outputs({"output"})
+    .Attrs({"shm_name: std::string",
+            "shape: std::vector<int64_t>",
+            "use_ipc: bool"})
+    .SetKernelFn(PD_KERNEL(ShareExternalData));
diff --git a/custom_ops/xpu_ops/src/ops/swap_cache_batch.cc b/custom_ops/xpu_ops/src/ops/swap_cache_batch.cc
new file mode 100644
index 000000000..4de0be75f
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/swap_cache_batch.cc
@@ -0,0 +1,166 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
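Correspondingly, a hedged sketch for share_external_data: it returns a zero-copy XPU view over a segment previously published with set_data_ipc; the input tensor only supplies dtype and layout, and the argument order follows the op registration above:

    import paddle
    from fastdeploy.model_executor.ops.xpu import share_external_data

    probe = paddle.zeros([1], dtype="float32")  # dtype/layout template only
    view = share_external_data(probe, "demo_shm", [16], True)  # shape [16]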
+
+#include <cstdint>
+#include <vector>
+#include "paddle/extension.h"
+
+template <typename T>
+void SwapCacheImplAllLayers(
+    const std::vector<paddle::Tensor>& cache_xpu_tensors,  // xpu
+    const std::vector<int64_t>& cache_cpu_ptrs,            // cpu
+    const int64_t& max_block_num_cpu,
+    const std::vector<int64_t>& swap_block_ids_xpu,
+    const std::vector<int64_t>& swap_block_ids_cpu,
+    int mode) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  for (int layer_idx = 0; layer_idx < cache_xpu_tensors.size(); layer_idx++) {
+    const paddle::Tensor& cache_xpu = cache_xpu_tensors[layer_idx];
+    const int64_t& cache_cpu_pointer = cache_cpu_ptrs[layer_idx];
+    // XPUType* cache_xpu_ptr =
+    //     reinterpret_cast<XPUType*>(const_cast<T*>(cache_xpu.data<T>()));
+    T* cache_xpu_ptr = const_cast<T*>(cache_xpu.data<T>());
+    auto* cache_cpu_ptr = reinterpret_cast<T*>(cache_cpu_pointer);
+    auto cache_shape = cache_xpu.shape();
+    const int64_t max_block_num_xpu = cache_shape[0];
+    const int64_t num_heads = cache_shape[1];
+    const int64_t block_size = cache_shape[2];
+    const int64_t head_dim = cache_shape[3];
+    const int64_t cache_stride = num_heads * block_size * head_dim;
+
+    if (swap_block_ids_xpu.size() == 0) {
+      return;
+    }
+    int i = 0;
+    int64_t consecutive_block_count = 1;
+    int64_t last_xpu_block_id = swap_block_ids_xpu[i];
+    int64_t last_cpu_block_id = swap_block_ids_cpu[i];
+    int64_t first_xpu_block_id =
+        last_xpu_block_id;  // first block id in a consecutive block ids
+    int64_t first_cpu_block_id = last_cpu_block_id;
+    i += 1;
+    while (true) {
+      if (i >= swap_block_ids_xpu.size()) {
+        break;
+      }
+      int64_t xpu_block_id = swap_block_ids_xpu[i];
+      int64_t cpu_block_id = swap_block_ids_cpu[i];
+      PD_CHECK(xpu_block_id >= 0 && xpu_block_id < max_block_num_xpu);
+      PD_CHECK(cpu_block_id >= 0 && cpu_block_id < max_block_num_cpu);
+      if (xpu_block_id == last_xpu_block_id + 1 &&
+          cpu_block_id == last_cpu_block_id + 1) {  // consecutive
+        consecutive_block_count += 1;
+        last_xpu_block_id = xpu_block_id;
+        last_cpu_block_id = cpu_block_id;
+      } else {
+        // end of a consecutive block ids
+        auto* cache_xpu_ptr_now =
+            cache_xpu_ptr + first_xpu_block_id * cache_stride;
+        auto* cache_cpu_ptr_now =
+            cache_cpu_ptr + first_cpu_block_id * cache_stride;
+        if (mode == 0) {  // copy from device to host
+          xpu_memcpy(cache_cpu_ptr_now,
+                     cache_xpu_ptr_now,
+                     cache_stride * sizeof(XPUType) * consecutive_block_count,
+                     XPU_DEVICE_TO_HOST);
+        } else {  // copy from host to device
+          xpu_memcpy(cache_xpu_ptr_now,
+                     cache_cpu_ptr_now,
+                     cache_stride * sizeof(XPUType) * consecutive_block_count,
+                     XPU_HOST_TO_DEVICE);
+        }
+        first_xpu_block_id = xpu_block_id;
+        first_cpu_block_id = cpu_block_id;
+        last_xpu_block_id = xpu_block_id;
+        last_cpu_block_id = cpu_block_id;
+        consecutive_block_count = 1;
+      }
+      i += 1;
+    }
+    // last batch
+    auto* cache_xpu_ptr_now = cache_xpu_ptr + first_xpu_block_id * cache_stride;
+    auto* cache_cpu_ptr_now = cache_cpu_ptr + first_cpu_block_id * cache_stride;
+    if (mode == 0) {  // copy from device to host
+      xpu_memcpy(cache_cpu_ptr_now,
+                 cache_xpu_ptr_now,
+                 cache_stride * sizeof(XPUType) * consecutive_block_count,
+                 XPU_DEVICE_TO_HOST);
+    } else {  // copy from host to device
+      xpu_memcpy(cache_xpu_ptr_now,
+                 cache_cpu_ptr_now,
+                 cache_stride * sizeof(XPUType) * consecutive_block_count,
+                 XPU_HOST_TO_DEVICE);
+    }
+  }
+}
+
+void SwapCacheAllLayers(
+    const std::vector<paddle::Tensor>& cache_xpu_tensors,  // xpu
+    const std::vector<int64_t>& cache_cpu_ptrs,  // cpu memory pointer
+    int64_t max_block_num_cpu,                   // cpu max block num
+    const std::vector<int64_t>& swap_block_ids_xpu,
+    const std::vector<int64_t>& swap_block_ids_cpu,
+    int rank,
+    int mode) {
+  xpu_set_device(rank);  // used for distributed launch
+  PD_CHECK(cache_xpu_tensors.size() > 0 &&
+           cache_xpu_tensors.size() == cache_cpu_ptrs.size());
+  switch (cache_xpu_tensors[0].dtype()) {
+    case paddle::DataType::FLOAT16:
+      return SwapCacheImplAllLayers<paddle::float16>(cache_xpu_tensors,
+                                                     cache_cpu_ptrs,
+                                                     max_block_num_cpu,
+                                                     swap_block_ids_xpu,
+                                                     swap_block_ids_cpu,
+                                                     mode);
+    case paddle::DataType::UINT8:
+      return SwapCacheImplAllLayers<uint8_t>(cache_xpu_tensors,
+                                             cache_cpu_ptrs,
+                                             max_block_num_cpu,
+                                             swap_block_ids_xpu,
+                                             swap_block_ids_cpu,
+                                             mode);
+    case paddle::DataType::INT8:
+      return SwapCacheImplAllLayers<int8_t>(cache_xpu_tensors,
+                                            cache_cpu_ptrs,
+                                            max_block_num_cpu,
+                                            swap_block_ids_xpu,
+                                            swap_block_ids_cpu,
+                                            mode);
+    case paddle::DataType::BFLOAT16:
+      return SwapCacheImplAllLayers<paddle::bfloat16>(cache_xpu_tensors,
+                                                      cache_cpu_ptrs,
+                                                      max_block_num_cpu,
+                                                      swap_block_ids_xpu,
+                                                      swap_block_ids_cpu,
+                                                      mode);
+    default:
+      PD_THROW("Unsupported data type.");
+  }
+}
+
+PD_BUILD_OP(swap_cache_all_layers)
+    .Inputs({paddle::Vec("cache_xpu_tensors")})
+    .Attrs({
+        "cache_cpu_ptrs: std::vector<int64_t>",
+        "max_block_num_cpu: int64_t",
+        "swap_block_ids_xpu: std::vector<int64_t>",
+        "swap_block_ids_cpu: std::vector<int64_t>",
+        "rank: int",
+        "mode: int",
+    })
+    .Outputs({paddle::Vec("cache_dst_outs")})
+    .SetInplaceMap({{paddle::Vec("cache_xpu_tensors"),
+                     paddle::Vec("cache_dst_outs")}})
+    .SetKernelFn(PD_KERNEL(SwapCacheAllLayers));
diff --git a/custom_ops/xpu_ops/src/ops/utility/debug.cc b/custom_ops/xpu_ops/src/ops/utility/debug.cc
new file mode 100644
index 000000000..2a67c07fe
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/utility/debug.cc
@@ -0,0 +1,194 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ops/utility/debug.h"
+#include <stdarg.h>
+#include <cmath>    // for std::sqrt
+#include <iostream>
+#include <memory>
+#include <numeric>  // for std::accumulate
+#include <sstream>
+#include <string>
+#include <vector>
+#include "paddle/phi/common/float16.h"
+#include "xpu/internal/infra_op.h"
+
+namespace paddle {
+
+std::string string_format(const std::string fmt_str, ...) {
+  // Reserve two times as much as the length of the fmt_str
+  int final_n, n = (static_cast<int>(fmt_str.size())) * 2;
+  std::unique_ptr<char[]> formatted;
+  va_list ap;
+  while (1) {
+    formatted.reset(new char[n]);
+    // Wrap the plain char array into the unique_ptr
+    std::strcpy(&formatted[0], fmt_str.c_str());  // NOLINT
+    va_start(ap, fmt_str);
+    final_n = vsnprintf(&formatted[0], n, fmt_str.c_str(), ap);
+    va_end(ap);
+    if (final_n < 0 || final_n >= n)
+      n += std::abs(final_n - n + 1);
+    else
+      break;
+  }
+  return std::string(formatted.get());
+}
+
+std::string shape_to_string(const std::vector<int64_t>& shape) {
+  std::ostringstream os;
+  auto rank = shape.size();
+  if (rank > 0) {
+    os << shape[0];
+    for (size_t i = 1; i < rank; i++) {
+      os << ", " << shape[i];
+    }
+  }
+  return os.str();
+}
+
+template <typename T>
+float cal_mean(const std::vector<T>& data) {
+  return std::accumulate(data.begin(), data.end(), 0.f) /
+         static_cast<float>(data.size());
+}
+
+template <typename T>
+float cal_std(const std::vector<T>& data) {
+  float mean = cal_mean(data);
+  float variance = std::accumulate(data.begin(),
+                                   data.end(),
+                                   0.0,
+                                   [mean](T acc, T val) {
+                                     return acc + (val - mean) * (val - mean);
+                                   }) /
+                   data.size();
+  return std::sqrt(variance);
+}
+
+template <typename T>
+void DebugPrintXPUTensor(const phi::XPUContext* xpu_ctx,
+                         const paddle::Tensor& input,
+                         std::string tag,
+                         int len) {
+  const T* input_data_ptr = input.data<T>();
+  std::vector<T> input_data(len);
+  xpu::do_device2host(
+      xpu_ctx->x_context(), input_data_ptr, input_data.data(), len);
+  for (int i = 0; i < len; ++i) {
+    std::cout << "DebugPrintXPUTensor " << tag << ", data: " << input_data[i]
+              << std::endl;
+  }
+
+  std::cout << "DebugPrintXPUTensor " << tag
+            << ", mean: " << cal_mean(input_data) << std::endl;
+  std::cout << "DebugPrintXPUTensor " << tag << ", std: " << cal_std(input_data)
+            << std::endl;
+}
+
+template <typename T>
+void DebugPrintXPUTensorv2(const paddle::Tensor& input,
+                           std::string tag,
+                           int len) {
+  auto input_cpu = input.copy_to(phi::CPUPlace(), false);
+  std::ostringstream os;
+
+  const T* input_data = input_cpu.data<T>();
+  for (int i = 0; i < len; ++i) {
+    os << input_data[i] << ", ";
+  }
+  std::cout << "DebugPrintXPUTensorv2 " << tag << ", data: " << os.str()
+            << std::endl;
+}
+
+template <>
+void DebugPrintXPUTensorv2<paddle::float16>(const paddle::Tensor& input,
+                                            std::string tag,
+                                            int len) {
+  auto input_cpu = input.copy_to(phi::CPUPlace(), false);
+  std::ostringstream os;
+
+  const paddle::float16* input_data = input_cpu.data<paddle::float16>();
+  for (int i = 0; i < len; ++i) {
+    os << static_cast<float>(input_data[i]) << ", ";
+  }
+  std::cout << "DebugPrintXPUTensorv2 " << tag << ", data: " << os.str()
+            << std::endl;
+}
+
+template <>
+void DebugPrintXPUTensorv2<paddle::bfloat16>(const paddle::Tensor& input,
+                                             std::string tag,
+                                             int len) {
+  auto input_cpu = input.copy_to(phi::CPUPlace(), false);
+  std::ostringstream os;
+
+  const paddle::bfloat16* input_data = input_cpu.data<paddle::bfloat16>();
+  for (int i = 0; i < len; ++i) {
+    os << static_cast<float>(input_data[i]) << ", ";
+  }
+  std::cout << "DebugPrintXPUTensorv2 " << tag << ", data: " << os.str()
+            << std::endl;
+}
+
+template <>
+void DebugPrintXPUTensorv2<int8_t>(const paddle::Tensor& input,
+                                   std::string tag,
+                                   int len) {
+  auto input_cpu = input.copy_to(phi::CPUPlace(), false);
+
+  std::ostringstream os;
+
+  const int8_t* input_data = input_cpu.data<int8_t>();
+  for (int i = 0; i < len; ++i) {
+    int8_t tmp = input_data[i] >> 4;
+    os << (int32_t)tmp << ", ";
+  }
+  std::cout << "DebugPrintXPUTensorv2 " << tag << ", data: " << os.str()
+            << std::endl;
+}
+
+#define INSTANTIATE_DEBUGPRINT_XPUTENSOR(Type, FuncName, ...) \
+  template void FuncName<Type>(__VA_ARGS__);
+
+#define INSTANTIATE_DEBUGPRINT_XPUTENSOR_V1(Type)                  \
+  INSTANTIATE_DEBUGPRINT_XPUTENSOR(Type,                           \
+                                   DebugPrintXPUTensor,            \
+                                   const phi::XPUContext* xpu_ctx, \
+                                   const paddle::Tensor& input,    \
+                                   std::string tag,                \
+                                   int len)
+
+#define INSTANTIATE_DEBUGPRINT_XPUTENSOR_V2(Type)               \
+  INSTANTIATE_DEBUGPRINT_XPUTENSOR(Type,                        \
+                                   DebugPrintXPUTensorv2,       \
+                                   const paddle::Tensor& input, \
+                                   std::string tag,             \
+                                   int len)
+
+// do not support bool type now, please use DebugPrintXPUTensorv2
+// INSTANTIATE_DEBUGPRINT_XPUTENSOR_V1(bool)
+INSTANTIATE_DEBUGPRINT_XPUTENSOR_V1(float)
+INSTANTIATE_DEBUGPRINT_XPUTENSOR_V1(int)
+INSTANTIATE_DEBUGPRINT_XPUTENSOR_V1(int64_t)
+
+INSTANTIATE_DEBUGPRINT_XPUTENSOR_V2(int8_t)
+INSTANTIATE_DEBUGPRINT_XPUTENSOR_V2(bool)
+INSTANTIATE_DEBUGPRINT_XPUTENSOR_V2(int64_t)
+INSTANTIATE_DEBUGPRINT_XPUTENSOR_V2(float)
+INSTANTIATE_DEBUGPRINT_XPUTENSOR_V2(int)
+INSTANTIATE_DEBUGPRINT_XPUTENSOR_V2(paddle::float16)
+INSTANTIATE_DEBUGPRINT_XPUTENSOR_V2(paddle::bfloat16)
+
+}  // namespace paddle
diff --git a/custom_ops/xpu_ops/src/ops/utility/env.cc b/custom_ops/xpu_ops/src/ops/utility/env.cc
new file mode 100644
index 000000000..0fd4f2ec2
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/utility/env.cc
@@ -0,0 +1,63 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "env.h"  // NOLINT
+
+namespace paddle {
+
+// Specialization for bool
+template <>
+bool get_env(const std::string& var_name, bool default_value) {
+  const char* value = std::getenv(var_name.c_str());
+  if (!value) {
+    if (var_name.size() < 6 || var_name.substr(0, 6) != "FLAGS_") {
+      return get_env("FLAGS_" + var_name, default_value);
+    }
+    return default_value;
+  }
+  std::string valStr(value);
+  std::transform(valStr.begin(), valStr.end(), valStr.begin(), ::tolower);
+  if (valStr == "true" || valStr == "1") {
+    return true;
+  } else if (valStr == "false" || valStr == "0") {
+    return false;
+  }
+  PD_THROW("Unexpected value:", valStr, ", only bool supported.");
+  return default_value;
+}
+
+template <>
+int get_env(const std::string& var_name, int default_value) {
+  const char* value = std::getenv(var_name.c_str());
+  if (!value) {
+    if (var_name.size() < 6 || var_name.substr(0, 6) != "FLAGS_") {
+      return get_env("FLAGS_" + var_name, default_value);
+    }
+    return default_value;
+  }
+  try {
+    return std::stoi(value);
+  } catch (...) {
+    PD_THROW("Unexpected value:", value, ", only int supported.");
+  }
+}
+
+#define DEFINE_GET_ENV_SPECIALIZATION(T) \
+  template <>                            \
+  T get_env(const std::string& var_name, T default_value);
+
+DEFINE_GET_ENV_SPECIALIZATION(bool)
+DEFINE_GET_ENV_SPECIALIZATION(int)
+
+}  // namespace paddle
diff --git a/custom_ops/xpu_ops/src/ops/utility/env.h b/custom_ops/xpu_ops/src/ops/utility/env.h
new file mode 100644
index 000000000..69e71d460
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/utility/env.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include "paddle/extension.h"
+
+namespace paddle {
+template <typename T>
+T get_env(const std::string& var_name, T default_value);
+}
+
+#define XPU_DECLARE_VALUE(type, env_name, default_value) \
+  static type FLAGS_##env_name =                         \
+      paddle::get_env<type>(#env_name, default_value);
+
+#define XPU_DECLARE_BOOL(env_name, default_value) \
+  XPU_DECLARE_VALUE(bool, env_name, default_value)
+#define XPU_DECLARE_INT(env_name, default_value) \
+  XPU_DECLARE_VALUE(int, env_name, default_value)
diff --git a/custom_ops/xpu_ops/src/ops/utility/logging.cc b/custom_ops/xpu_ops/src/ops/utility/logging.cc
new file mode 100644
index 000000000..d87e70571
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/utility/logging.cc
@@ -0,0 +1,95 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ops/utility/logging.h"
+#include <cstring>  // strlen
+#include <iomanip>  // std::setw
+
+namespace paddle {
+
+void gen_log(std::ostream& log_stream_,
+             const char* file,
+             const char* func,
+             int lineno,
+             const char* level,
+             const int kMaxLen = 40) {
+  const int len = strlen(file);
+
+  struct tm tm_time;  // Time of creation of LogMessage
+  time_t timestamp = time(NULL);
+#if defined(_WIN32)
+  localtime_s(&tm_time, &timestamp);
+#else
+  localtime_r(&timestamp, &tm_time);
+#endif
+  struct timeval tv;
+  gettimeofday(&tv, NULL);
+
+  // print date / time
+  log_stream_ << '[' << level << ' ' << std::setw(2) << 1 + tm_time.tm_mon
+              << '/' << std::setw(2) << tm_time.tm_mday << ' ' << std::setw(2)
+              << tm_time.tm_hour << ':' << std::setw(2) << tm_time.tm_min << ':'
+              << std::setw(2) << tm_time.tm_sec << '.' << std::setw(3)
+              << tv.tv_usec / 1000 << " ";
+
+  if (len > kMaxLen) {
+    log_stream_ << "..." << file + len - kMaxLen << ":" << lineno << " "
+                << func << "] ";
+  } else {
+    log_stream_ << file << " " << func << ":" << lineno << "] ";
+  }
+}
+
+CustomLogMessage::CustomLogMessage(const char* file,
+                                   const char* func,
+                                   int lineno,
+                                   const char* level)
+    : level_(level) {
+  gen_log(log_stream_, file, func, lineno, level);
+}
+
+CustomLogMessage::~CustomLogMessage() {
+  log_stream_ << '\n';
+  fprintf(stderr, "%s", log_stream_.str().c_str());
+}
+
+CustomLogMessageFatal::~CustomLogMessageFatal() noexcept(false) {
+  log_stream_ << '\n';
+  fprintf(stderr, "%s", log_stream_.str().c_str());
+  throw CustomException(log_stream_.str().c_str());
+  abort();
+}
+
+CustomVLogMessage::CustomVLogMessage(const char* file,
+                                     const char* func,
+                                     int lineno,
+                                     const int32_t level_int) {
+  const char* GLOG_v = std::getenv("GLOG_v");
+  GLOG_v_int = (GLOG_v && atoi(GLOG_v) > 0) ? atoi(GLOG_v) : 0;
+  this->level_int = level_int;
+  if (GLOG_v_int < level_int) {
+    return;
+  }
+  // Keep the level string in a named local: calling c_str() on a temporary
+  // std::string would leave a dangling pointer.
+  const std::string level = std::to_string(level_int);
+  gen_log(log_stream_, file, func, lineno, level.c_str());
+}
+
+CustomVLogMessage::~CustomVLogMessage() {
+  if (GLOG_v_int < this->level_int) {
+    return;
+  }
+  log_stream_ << '\n';
+  fprintf(stderr, "%s", log_stream_.str().c_str());
+}
+
+}  // namespace paddle
diff --git a/custom_ops/xpu_ops/src/ops/utility/logging.h b/custom_ops/xpu_ops/src/ops/utility/logging.h
new file mode 100644
index 000000000..919e80eca
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/utility/logging.h
@@ -0,0 +1,114 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once + +#include +#include +#if !defined(_WIN32) +#include +#include +#else +#define NOMINMAX // msvc max/min macro conflict with std::min/max +#include +#undef min +#undef max +extern struct timeval; +static int gettimeofday(struct timeval* tp, void* tzp) { + LARGE_INTEGER now, freq; + QueryPerformanceCounter(&now); + QueryPerformanceFrequency(&freq); + tp->tv_sec = now.QuadPart / freq.QuadPart; + tp->tv_usec = (now.QuadPart % freq.QuadPart) * 1000000 / freq.QuadPart; + return (0); +} +#endif + +#include +#include +#include +#include + +// LOG() +#define LOG(status) LOG_##status.stream() +#define LOG_INFO paddle::CustomLogMessage(__FILE__, __FUNCTION__, __LINE__, "I") +#define LOG_ERROR LOG_INFO +#define LOG_WARNING \ + paddle::CustomLogMessage(__FILE__, __FUNCTION__, __LINE__, "W") +#define LOG_FATAL \ + paddle::CustomLogMessageFatal(__FILE__, __FUNCTION__, __LINE__) + +// VLOG() +#define VLOG(level) \ + paddle::CustomVLogMessage(__FILE__, __FUNCTION__, __LINE__, level).stream() + +namespace paddle { + +struct CustomException : public std::exception { + const std::string exception_prefix = "Custom exception: \n"; + std::string message; + explicit CustomException(const char* detail) { + message = exception_prefix + std::string(detail); + } + const char* what() const noexcept { return message.c_str(); } +}; + +class CustomLogMessage { + public: + CustomLogMessage(const char* file, + const char* func, + int lineno, + const char* level = "I"); + ~CustomLogMessage(); + + std::ostream& stream() { return log_stream_; } + + protected: + std::stringstream log_stream_; + std::string level_; + + CustomLogMessage(const CustomLogMessage&) = delete; + void operator=(const CustomLogMessage&) = delete; +}; + +class CustomLogMessageFatal : public CustomLogMessage { + public: + CustomLogMessageFatal(const char* file, + const char* func, + int lineno, + const char* level = "F") + : CustomLogMessage(file, func, lineno, level) {} + ~CustomLogMessageFatal() noexcept(false); +}; + +class CustomVLogMessage { + public: + CustomVLogMessage(const char* file, + const char* func, + int lineno, + const int32_t level_int = 0); + ~CustomVLogMessage(); + + std::ostream& stream() { return log_stream_; } + + protected: + std::stringstream log_stream_; + int32_t GLOG_v_int; + int32_t level_int; + + CustomVLogMessage(const CustomVLogMessage&) = delete; + void operator=(const CustomVLogMessage&) = delete; +}; + +} // namespace paddle diff --git a/custom_ops/xpu_ops/src/ops/weight_only_linear.cc b/custom_ops/xpu_ops/src/ops/weight_only_linear.cc new file mode 100644 index 000000000..62a115989 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/weight_only_linear.cc @@ -0,0 +1,207 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <functional>
+#include <numeric>
+#include <string>
+#include "paddle/extension.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "utility/debug.h"
+#include "utility/env.h"
+
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
+XPU_DECLARE_BOOL(ENABLE_XVLLM_SDNN_INFER, false);
+namespace xftblock = baidu::xpu::xftblock;
+namespace api = baidu::xpu::api;
+
+template <typename TX, typename TW>
+std::vector<paddle::Tensor> WeightOnlyLinearKernel(
+    const paddle::Tensor& x,
+    const paddle::Tensor& weight,
+    const paddle::Tensor& weight_scale,
+    const paddle::optional<paddle::Tensor>& bias,
+    const std::string& weight_dtype) {
+  using XPU_TX = typename XPUTypeTrait<TX>::Type;
+  using XPU_TW = typename XPUTypeTrait<TW>::Type;
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+  xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr);
+  auto rt_guard = xctx.get_rt_guard();
+  auto xftblock_tx = xftblock::DataTypeToEnum<XPU_TX>::value;
+  auto xftblock_tw = xftblock::DataTypeToEnum<XPU_TW>::value;
+
+  int ret = -1;
+  auto x_shape = x.shape();
+  auto w_shape = weight.shape();
+  int64_t n = w_shape[0];
+  int64_t k = w_shape[1];
+  int64_t m = x.numel() / k;
+  if (weight_dtype == "int4_t") {
+    n = n * 2;
+  }
+  paddle::Tensor out = paddle::empty({m, n}, x.dtype(), x.place());
+  if (m == 0) {
+    return {out};
+  }
+
+  paddle::Tensor bias_fp32;
+  if (bias.get_ptr() && bias.get_ptr()->dtype() != paddle::DataType::FLOAT32) {
+    bias_fp32 = paddle::empty({n}, paddle::DataType::FLOAT32, x.place());
+    PD_CHECK(bias.get_ptr()->dtype() == x.dtype(), "bias.dtype != x.dtype");
+    ret = api::cast<XPU_TX, float>(
+        xpu_ctx->x_context(),
+        reinterpret_cast<const XPU_TX*>(bias.get_ptr()->data()),
+        bias_fp32.data<float>(),
+        n);
+    PD_CHECK(ret == 0, "cast");
+  }
+
+  xftblock::Tensor input_x(const_cast<void*>(x.data()), xftblock_tx, {m, k});
+  xftblock::Tensor input_w(const_cast<void*>(weight.data()),
+                           nullptr,
+                           const_cast<void*>(weight_scale.data()),
+                           xftblock_tw,
+                           {n, k});
+  xftblock::Tensor output(const_cast<void*>(out.data()), xftblock_tx, {m, n});
+  std::shared_ptr<xftblock::Tensor> input_bias;
+  if (bias.get_ptr()) {
+    if (bias.get_ptr()->dtype() != paddle::DataType::FLOAT32) {
+      input_bias = std::make_shared<xftblock::Tensor>(
+          const_cast<float*>(bias_fp32.data<float>()),
+          xftblock::DataType::DT_FLOAT,
+          std::vector<int64_t>({n}));
+    } else {
+      input_bias = std::make_shared<xftblock::Tensor>(
+          const_cast<void*>(bias.get_ptr()->data()),
+          xftblock::DataType::DT_FLOAT,
+          std::vector<int64_t>({n}));
+    }
+  }
+  bool use_sdnn = FLAGS_ENABLE_XVLLM_SDNN_INFER;
+  if (x.dtype() == paddle::DataType::BFLOAT16) {
+    ret = xftblock::xft_fc_block_cast_te_per_token(
+        &xctx,
+        &input_x,
+        &input_w,
+        &output,
+        input_bias.get(),
+        api::Activation_t::LINEAR,
+        false,
+        true,
+        1.0f,
+        0.0f,
+        0,
+        1,
+        false,
+        false,
+        use_sdnn);
+    PD_CHECK(ret == 0, "xft_fc_block_cast_te_per_token");
+  } else {
+    ret = xftblock::xft_fc_block(
+        &xctx,
+        &input_x,
+        &input_w,
+        &output,
+        input_bias.get(),
+        api::Activation_t::LINEAR,
+        false,
+        true,
+        1.0f,
+        0.0f,
+        0,
+        1,
+        false,
+        false);
+    PD_CHECK(ret == 0, "xft_fc_block");
+  }
+
+  return {out};
+}
+
+std::vector<paddle::Tensor> WeightOnlyLinear(
+    const paddle::Tensor& x,
+    const paddle::Tensor& weight,
+    const paddle::Tensor& weight_scale,
+    const paddle::optional<paddle::Tensor>& bias,
+    const std::string& weight_dtype,
+    const int arch,
+    const int group_size) {
+  const auto x_type = x.dtype();
+  const auto w_type = weight.dtype();
+
+#define APPLY_FFN_KERNEL(TX, TW)         \
+  return WeightOnlyLinearKernel<TX, TW>( \
+      x, weight, weight_scale, bias, weight_dtype);
+
+  if (x_type == paddle::DataType::BFLOAT16 &&
+      w_type == paddle::DataType::INT8) {
+    APPLY_FFN_KERNEL(paddle::bfloat16, int8_t);
+  } else if (x_type == paddle::DataType::FLOAT16 &&
+             w_type == paddle::DataType::INT8) {
+    APPLY_FFN_KERNEL(paddle::float16, int8_t);
+  } else {
+    PD_THROW("WeightOnlyLinear not support x_type=",
+             static_cast<int>(x_type),
+             ", w_type=",
+             static_cast<int>(w_type));
+    return {};
+  }
+#undef APPLY_FFN_KERNEL
+}
+
+std::vector<std::vector<int64_t>> WeightOnlyLinearInferShape(
+    const std::vector<int64_t>& x_shape,
+    const std::vector<int64_t>& weight_shape,
+    const std::vector<int64_t>& weight_scale_shape,
+    const paddle::optional<std::vector<int64_t>>& bias_shape,
+    const std::string& weight_dtype,
+    const int arch,
+    const int group_size) {
+  PD_CHECK(weight_shape.size() == 2);
+  int64_t n = weight_shape[0];
+  int64_t k = weight_shape[1];
+  int64_t x_numel = std::accumulate(x_shape.begin(),
+                                    x_shape.end(),
+                                    static_cast<int64_t>(1),
+                                    std::multiplies<int64_t>());
+  int64_t m = x_numel / k;
+  if (weight_dtype == "int4") {
+    n = n * 2;
+  }
+  return {{m, n}};
+}
+
+std::vector<paddle::DataType> WeightOnlyLinearInferDtype(
+    const paddle::DataType& x_dtype,
+    const paddle::DataType& w_dtype,
+    const paddle::DataType& weight_scale_dtype,
+    const paddle::optional<paddle::DataType>& bias_dtype,
+    const std::string& weight_dtype,
+    const int arch,
+    const int group_size) {
+  return {x_dtype};
+}
+
+PD_BUILD_STATIC_OP(weight_only_linear_xpu)
+    .Inputs({"x", "weight", "weight_scale", paddle::Optional("bias")})
+    .Outputs({"out"})
+    .Attrs({"weight_dtype:std::string", "arch:int", "group_size:int"})
+    .SetKernelFn(PD_KERNEL(WeightOnlyLinear))
+    .SetInferShapeFn(PD_INFER_SHAPE(WeightOnlyLinearInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(WeightOnlyLinearInferDtype));
diff --git a/custom_ops/xpu_ops/test/test_block_attn_prefix_cache.py b/custom_ops/xpu_ops/test/test_block_attn_prefix_cache.py
new file mode 100644
index 000000000..1a607e192
--- /dev/null
+++ b/custom_ops/xpu_ops/test/test_block_attn_prefix_cache.py
@@ -0,0 +1,336 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
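A call sketch for the op registered above (weight_only_linear_xpu). The module path mirrors the other tests in this PR; arch and group_size are passed through unused by the kernel, and "int8" as the weight_dtype string is an assumption (the kernel only special-cases "int4_t"):

    import paddle
    from fastdeploy.model_executor.ops.xpu import weight_only_linear_xpu

    m, k, n = 4, 128, 256
    x = paddle.randn([m, k], dtype="bfloat16")
    weight = paddle.randint(-127, 128, [n, k]).astype("int8")
    weight_scale = paddle.rand([n], dtype="float32")  # per-channel scales
    out = weight_only_linear_xpu(x, weight, weight_scale, None, "int8", 0, -1)
    assert out.shape == [m, n]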
+ +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import block_attn, get_infer_param + +head_num = 64 +kv_head_num = 8 +head_dim = 128 +seq_len = 128 +block_batch = 5 +max_block_per_seq = 128 +block_size = 64 + +seq_lens_encoder = paddle.to_tensor([128, 0, 0, 0, 0], dtype="int32") +seq_lens_decoder = paddle.to_tensor([0, 0, 0, 0, 0], dtype="int32") +seq_lens_this_time = paddle.to_tensor([128, 0, 0, 0, 0], dtype="int32") +block_tables = paddle.arange(0, block_batch * max_block_per_seq, dtype="int32") +block_tables = block_tables.reshape((block_batch, max_block_per_seq)) +( + encoder_batch_map, + decoder_batch_map, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + prefix_len, + decoder_context_len, + decoder_context_len_cache, + prefix_block_tables, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + prefix_len_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + len_info_cpu, +) = get_infer_param( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64 +) # block_size + +qkv = paddle.uniform( + shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim], + dtype="bfloat16", + min=-1.0, + max=1.0, +) + +cum_offsets = paddle.zeros(shape=[block_batch], dtype="bfloat16") +rotary_embs = paddle.uniform(shape=[2, 1, 8192, 1, head_dim], dtype="float32", min=-1.0, max=1.0) +key_cache = paddle.zeros( + shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim], + dtype="bfloat16", +) +value_cache = paddle.zeros( + shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim], + dtype="bfloat16", +) +# C8 +key_cache_int8 = paddle.zeros( + shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim], + dtype="int8", +) +value_cache_int8 = paddle.zeros( + shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim], + dtype="int8", +) +scale_tensor_k = paddle.uniform(shape=[kv_head_num * head_dim], dtype="bfloat16", min=1.0, max=1.0) # max +scale_tensor_v = paddle.uniform(shape=[kv_head_num * head_dim], dtype="bfloat16", min=1.0, max=1.0) # max +k_quant_scale = 127.0 / scale_tensor_k # for C8 per channel means 127 / max +v_quant_scale = 127.0 / scale_tensor_v # for C8 per channel means 127 / max +k_dequant_scale = paddle.cast(scale_tensor_k, dtype="float32") # for C8 per channel means max +v_dequant_scale = paddle.cast(scale_tensor_v, dtype="float32") # for C8 per channel means max +k_dequant_scale_zp = 1 / k_quant_scale # for C8 per channel zp means max +v_dequant_scale_zp = 1 / v_quant_scale # for C8 per channel zp means max + +k_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16") +v_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16") +attn_out = block_attn( + qkv, + key_cache, + value_cache, + cum_offsets, + rotary_embs, + block_tables, + prefix_block_tables, + len_info_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + encoder_batch_map_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + decoder_batch_map_cpu, + prefix_len_cpu, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, +) +attn_out_C8 = block_attn( + qkv, + key_cache_int8, + value_cache_int8, + cum_offsets, + rotary_embs, + block_tables, + prefix_block_tables, + len_info_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + 
encoder_batch_map_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + decoder_batch_map_cpu, + prefix_len_cpu, + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + None, + None, + None, + None, + None, + None, +) +attn_out_C8_zp = block_attn( + qkv, + key_cache_int8, + value_cache_int8, + cum_offsets, + rotary_embs, + block_tables, + prefix_block_tables, + len_info_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + encoder_batch_map_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + decoder_batch_map_cpu, + prefix_len_cpu, + k_quant_scale, + v_quant_scale, + k_dequant_scale_zp, + v_dequant_scale_zp, + k_zp, + v_zp, + None, + None, + None, + None, +) + +# prefix cache : hit 71 tokens +hit_prefix_len = 71 +seq_lens_encoder = paddle.to_tensor([seq_len - hit_prefix_len, 0, 0, 0, 0], dtype="int32") +# 71 means prefix len +seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32") +( + encoder_batch_map, + decoder_batch_map, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + prefix_len, + decoder_context_len, + decoder_context_len_cache, + prefix_block_tables, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + prefix_len_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + len_info_cpu, +) = get_infer_param( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64 +) # block_size +qkv_prefix = qkv[hit_prefix_len:] + +attn_out_prefix_cache = block_attn( + qkv_prefix, + key_cache, + value_cache, + cum_offsets, + rotary_embs, + block_tables, + prefix_block_tables, + len_info_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + encoder_batch_map_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + decoder_batch_map_cpu, + prefix_len_cpu, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, +) + +attn_out_C8_prefix_cache = block_attn( + qkv_prefix, + key_cache_int8, + value_cache_int8, + cum_offsets, + rotary_embs, + block_tables, + prefix_block_tables, + len_info_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + encoder_batch_map_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + decoder_batch_map_cpu, + prefix_len_cpu, + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + None, + None, + None, + None, + None, + None, +) + +attn_out_C8_zp_prefix_cache = block_attn( + qkv_prefix, + key_cache_int8, + value_cache_int8, + cum_offsets, + rotary_embs, + block_tables, + prefix_block_tables, + len_info_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + encoder_batch_map_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + decoder_batch_map_cpu, + prefix_len_cpu, + k_quant_scale, + v_quant_scale, + k_dequant_scale_zp, + v_dequant_scale_zp, + k_zp, + v_zp, + None, + None, + None, + None, +) +print("-- C16 prefix cache test --") +print("attn_out[hit_prefix_len:]'s mean:", attn_out[hit_prefix_len:].mean().item()) +print("attn_out_prefix_cache's mean: ", attn_out_prefix_cache.mean().item()) +attn_out_prefix_cache_np = attn_out_prefix_cache.astype("float32").numpy() +attn_out_np = attn_out[hit_prefix_len:].astype("float32").numpy() +assert np.allclose( + attn_out_prefix_cache_np, attn_out_np, rtol=1e-2, atol=1e-3 +), f"C16 prefix cache != No prefix cache,\n 
 attn_out[hit_prefix_len:]: {attn_out_np},\nattn_out_prefix_cache: {attn_out_prefix_cache_np}"
+
+
+print("\n-- C8 per channel prefix cache test --")
+print(
+    "attn_out_C8[hit_prefix_len:]'s mean:",
+    attn_out_C8[hit_prefix_len:].mean().item(),
+)
+print("attn_out_C8_prefix_cache's mean: ", attn_out_C8_prefix_cache.mean().item())
+attn_out_C8_prefix_cache_np = attn_out_C8_prefix_cache.astype("float32").numpy()
+attn_out_C8_np = attn_out_C8[hit_prefix_len:].astype("float32").numpy()
+assert np.allclose(
+    attn_out_C8_prefix_cache_np, attn_out_C8_np, rtol=1e-1, atol=1e-2
+), f"C8 per channel prefix cache != No prefix cache,\n attn_out_C8[hit_prefix_len:]: {attn_out_C8_np},\nattn_out_C8_prefix_cache: {attn_out_C8_prefix_cache_np}"
+
+print("\n-- C8 per channel zp prefix cache test --")
+print(
+    "attn_out_C8_zp[hit_prefix_len:]'s mean:",
+    attn_out_C8_zp[hit_prefix_len:].mean().item(),
+)
+print(
+    "attn_out_C8_zp_prefix_cache's mean: ",
+    attn_out_C8_zp_prefix_cache.mean().item(),
+)
+attn_out_C8_zp_prefix_cache_np = attn_out_C8_zp_prefix_cache.astype("float32").numpy()
+attn_out_C8_zp_np = attn_out_C8_zp[hit_prefix_len:].astype("float32").numpy()
+assert np.allclose(
+    attn_out_C8_zp_prefix_cache_np, attn_out_C8_zp_np, rtol=1e-1, atol=1e-2
+), f"C8 per channel zp prefix cache != No prefix cache,\n attn_out_C8_zp[hit_prefix_len:]: {attn_out_C8_zp_np},\nattn_out_C8_zp_prefix_cache: {attn_out_C8_zp_prefix_cache_np}"
diff --git a/custom_ops/xpu_ops/test/test_fused_rms_norm.py b/custom_ops/xpu_ops/test/test_fused_rms_norm.py
new file mode 100644
index 000000000..ea218c8c7
--- /dev/null
+++ b/custom_ops/xpu_ops/test/test_fused_rms_norm.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
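Aside on the C8 scale convention used by the block-attn test above: its comments define k_quant_scale as 127 / max and k_dequant_scale as max. A self-contained numpy check of that round trip (the kernel's exact rounding mode is an assumption):

    import numpy as np

    max_abs = np.array([2.0, 0.5], dtype=np.float32)  # per-channel max
    k_quant_scale = 127.0 / max_abs
    k_dequant_scale = max_abs
    x = np.array([1.0, -0.25], dtype=np.float32)
    q = np.clip(np.round(x * k_quant_scale), -127, 127)  # int8 cache value
    x_rec = q / 127.0 * k_dequant_scale  # dequantized
    assert np.allclose(x_rec, x, atol=1.5 / 127)  # within one quant step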
+
+import numpy as np
+import paddle
+
+from fastdeploy.model_executor.ops.xpu import fused_rms_norm_xpu
+
+# from paddle.incubate.nn.functional import fused_rms_norm
+
+
+def find_max_diff(arr1, arr2):
+    """Find the maximum element-wise difference between two arrays and its index.
+    Returns:
+        max_diff (float): maximum absolute difference
+        index (tuple): position of the maximum difference
+        actual_diff (float): the actual (signed) difference at that position
+    """
+    diff = arr1 - arr2
+    abs_diff = np.abs(diff)
+    flat_idx = np.argmax(abs_diff)
+    idx = np.unravel_index(flat_idx, arr1.shape)
+    return abs_diff[idx], idx, diff[idx], arr1[idx], arr2[idx]
+
+
+def naive_rmsnorm(
+    x,
+    gamma,
+    beta=None,
+    epsilon=1e-6,
+    begin_norm_axis=1,
+    bias=None,
+    residual=None,
+):
+    residual_out = None
+    if bias is not None:
+        x = x + bias
+    if residual is not None:
+        x = x + residual
+        residual_out = x
+    variance = (x * x).mean(axis=-1)
+    out = np.expand_dims(1.0 / np.sqrt(variance + epsilon), axis=-1) * x
+    out = out * gamma
+    if beta is not None:
+        out = out + beta
+    return out, residual_out
+
+
+def run_and_compare(x_in, residual, bias, norm_weight):
+    x_in_pd = paddle.to_tensor(x_in).astype(data_type)
+    residual_pd = None
+    if residual is not None:
+        residual_pd = paddle.to_tensor(residual).astype(data_type)
+    bias_pd = paddle.to_tensor(bias).astype(data_type)
+    norm_weight_pd = paddle.to_tensor(norm_weight).astype(data_type)
+    # norm_bias_pd = paddle.to_tensor(norm_bias).astype(data_type)
+
+    out_np, residual_out_np = naive_rmsnorm(x_in, norm_weight, None, epsilon, begin_norm_axis, bias, residual)
+    out_pd, residual_out_pd = fused_rms_norm_xpu(
+        x_in_pd,
+        bias_pd,
+        residual_pd,
+        norm_weight_pd,
+        None,  # norm_bias_pd,
+        epsilon,
+        begin_norm_axis,
+        -1,
+        0,
+        0,
+        0,
+    )
+    """
+    out_pd1, residual_out_pd1 = fused_rms_norm(
+        x_in_pd,
+        norm_weight=norm_weight_pd,
+        norm_bias=norm_bias_pd,
+        epsilon=epsilon,
+        begin_norm_axis=1,
+        bias=bias_pd,
+        residual=residual_pd,
+        quant_scale=-1,
+        quant_round_type=0,
+        quant_max_bound=0,
+        quant_min_bound=0,
+    )
+    """
+    abs_diff, idx, diff, val1, val2 = find_max_diff(out_np, out_pd.astype("float32").numpy())
+    print(f"out compare: abs_diff={abs_diff}, index={idx}, diff={diff}, {val1} vs {val2}")
+    assert np.allclose(out_np, out_pd.astype("float32").numpy(), rtol=1e-5, atol=1e-5)
+
+    if residual is not None:
+        abs_diff, idx, diff, val1, val2 = find_max_diff(residual_out_np, residual_out_pd.astype("float32").numpy())
+        print(f"residual_out compare: abs_diff={abs_diff}, index={idx}, diff={diff}, {val1} vs {val2}")
+        assert np.allclose(
+            residual_out_np,
+            residual_out_pd.astype("float32").numpy(),
+            rtol=1e-5,
+            atol=1e-5,
+        )
+
+
+if __name__ == "__main__":
+    seed = np.random.randint(0, 1e8)
+    print(f"numpy random seed is {seed}")
+    np.random.seed(seed)
+
+    m = 7
+    n = 8192
+    epsilon = 1e-5
+    begin_norm_axis = 1
+    data_type = "float32"
+
+    x_in = (np.random.random([m, n]) - 0.5).astype("float32")
+    residual = (np.random.random([m, n]) - 0.5).astype("float32")
+    bias = (np.random.random([n]) - 0.5).astype("float32")
+    norm_weight = (np.random.random([n]) - 0.5).astype("float32")
+    # norm_bias = np.zeros([n]).astype("float32")
+    # norm_bias = (np.random.random([n]) - 0.5).astype("float32")
+    x_in_pd = paddle.to_tensor(x_in).astype(data_type)
+    residual_pd = paddle.to_tensor(residual).astype(data_type)
+    bias_pd = paddle.to_tensor(bias).astype(data_type)
+    norm_weight_pd = paddle.to_tensor(norm_weight).astype(data_type)
+    # norm_bias_pd = paddle.to_tensor(norm_bias).astype(data_type)
+
+    run_and_compare(x_in, residual, bias, norm_weight)
+    run_and_compare(x_in, None, bias, norm_weight)
diff --git a/custom_ops/xpu_ops/test/test_get_infer_param.py b/custom_ops/xpu_ops/test/test_get_infer_param.py
new file mode 100755
index 000000000..f1b992395
--- /dev/null
+++ b/custom_ops/xpu_ops/test/test_get_infer_param.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+
+from fastdeploy.model_executor.ops.xpu import get_infer_param
+
+seq_lens_encoder = paddle.to_tensor([100, 0, 0, 0, 300], dtype="int32")
+seq_lens_decoder = paddle.to_tensor([0, 5, 0, 25, 64], dtype="int32")
+seq_lens_this_time = paddle.to_tensor([100, 1, 0, 1, 300], dtype="int32")
+block_table = paddle.arange(0, 40, dtype="int32")
+block_table = block_table.reshape((5, 8))
+(
+    encoder_batch_map,
+    decoder_batch_map,
+    encoder_batch_idx,
+    decoder_batch_idx,
+    encoder_seq_lod,
+    decoder_seq_lod,
+    encoder_kv_lod,
+    prefix_len,
+    decoder_context_len,
+    decoder_context_len_cache,
+    prefix_block_tables,
+    encoder_batch_map_cpu,
+    decoder_batch_map_cpu,
+    encoder_batch_idx_cpu,
+    decoder_batch_idx_cpu,
+    encoder_seq_lod_cpu,
+    decoder_seq_lod_cpu,
+    encoder_kv_lod_cpu,
+    prefix_len_cpu,
+    decoder_context_len_cpu,
+    decoder_context_len_cache_cpu,
+    len_info_cpu,
+) = get_infer_param(
+    seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64
+)  # block_size
+
+print("block_table", block_table)
+print("encoder_batch_map", encoder_batch_map)  # [0, 4, 0, 0, 0]
+print("decoder_batch_map", decoder_batch_map)  # [1, 3, 0, 0, 0]
+print("encoder_batch_idx", encoder_batch_idx)  # [0, 3, 0, 0, 0]
+print("decoder_batch_idx", decoder_batch_idx)  # [1, 2, 0, 0, 0]
+print("encoder_seq_lod", encoder_seq_lod)  # [0, 100, 400, 0, 0, 0]
+print("decoder_seq_lod", decoder_seq_lod)  # [0, 1, 2, 0, 0, 0]
+print("encoder_kv_lod", encoder_kv_lod)  # [0, 100, 464, 0, 0, 0]
+print("prefix_len", prefix_len)  # [0, 64, 0, 0, 0]
+print("decoder_context_len", decoder_context_len)  # [6, 26, 0, 0, 0]
+print("decoder_context_len_cache", decoder_context_len_cache)  # [5, 25, 0, 0, 0]
+print("prefix_block_tables", prefix_block_tables)
+print("encoder_batch_map_cpu", encoder_batch_map_cpu)  # [0, 4, 0, 0, 0]
+print("decoder_batch_map_cpu", decoder_batch_map_cpu)  # [1, 3, 0, 0, 0]
+print("encoder_batch_idx_cpu", encoder_batch_idx_cpu)  # [0, 3, 0, 0, 0]
+print("decoder_batch_idx_cpu", decoder_batch_idx_cpu)  # [1, 2, 0, 0, 0]
+print("encoder_seq_lod_cpu", encoder_seq_lod_cpu)  # [0, 100, 400, 0, 0, 0]
+print("decoder_seq_lod_cpu", decoder_seq_lod_cpu)  # [0, 1, 2, 0, 0, 0]
+print("encoder_kv_lod_cpu", encoder_kv_lod_cpu)  # [0, 100, 464, 0, 0, 0]
+print("prefix_len_cpu", prefix_len_cpu)  # [0, 64, 0, 0, 0]
+print("decoder_context_len_cpu", decoder_context_len_cpu)  # [6, 26, 0, 0, 0]
+print("decoder_context_len_cache_cpu", decoder_context_len_cache_cpu)  # [5, 25, 0, 0, 0]
+print(
+    "len_info_cpu", len_info_cpu
+)  # {enc_batch, dec_batch, total_enc_len, max_seq_len, max_kv_len, prefix_block_num_per_seq} = [2, 2, 400, 300, 364, 6]
+
+"""
dtype=int32, place=Place(xpu:0), stop_gradient=True,
+       [[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ],
+        [8 , 9 , 10, 11, 12, 13, 14, 15],
+        [16, 17, 18, 19, 20, 21, 22, 23],
+        [24, 25, 26, 27, 28, 29, 30, 31],
+        [32, 33, 34, 35, 36, 37, 38, 39]])
+
+prefix_block_tables Tensor(shape=[5, 8], dtype=int32, place=Place(xpu:0), stop_gradient=True,
+       [[ 0,  1, -1, -1, -1, -1, 32, 33],
+        [34, 35, 36, 37, -1, -1, -1, -1],
+        [-1, -1, -1, -1, -1, -1, -1, -1],
+        [-1, -1, -1, -1, -1, -1, -1, -1],
+        [-1, -1, -1, -1, -1, -1, -1, -1]])
+
+The prefix_block_tables tensor is allocated with the same shape as block_table to
+avoid problems with the InferShape of prefix_block_tables. However, only
+[block_bs, prefix_block_num_per_seq] of it is actually used, where
+prefix_block_num_per_seq = ceil(max_kv_len / block_size). Therefore, do not rely on
+the tensor shape of prefix_block_tables; derive the effective shape from
+block_table.dims[0] and len_info_cpu[-1], i.e. [block_table.shape[0], len_info_cpu[-1]].
+"""
diff --git a/custom_ops/xpu_ops/test/test_moe_ep_combine.py b/custom_ops/xpu_ops/test/test_moe_ep_combine.py
new file mode 100644
index 000000000..b71e05dae
--- /dev/null
+++ b/custom_ops/xpu_ops/test/test_moe_ep_combine.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
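+
+# Overview (added note; a sketch of the op's contract as exercised by this test,
+# not authoritative documentation): ep_moe_expert_combine reverses the dispatch
+# permutation. moe_index[t, k] names the ffn_out row produced for token t's k-th
+# routed expert (-1 if that slot is unused), and the op computes
+#     combined[t] = sum_k ffn_out[moe_index[t, k]] * moe_weights[t, k]
+# The numpy loop in step "2) np calculation" below is the reference for this.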
+
+import numpy as np
+import paddle
+
+from fastdeploy.model_executor.ops.xpu import ep_moe_expert_combine
+
+np.random.seed(2025)
+
+
+def np_softmax(x, axis=-1):
+    x_max = np.max(x, axis=axis, keepdims=True)
+    x_exp = np.exp(x - x_max)
+    return x_exp / np.sum(x_exp, axis=axis, keepdims=True)
+
+
+def create_moe_index(token_num, moe_topk, expand_token_num):
+    total_positions = token_num * moe_topk
+    positions = np.random.choice(total_positions, size=expand_token_num, replace=False)
+    rows = positions // moe_topk
+    cols = positions % moe_topk
+    values = np.random.permutation(expand_token_num)
+
+    # moe_index is the output of moe_ep_dispatch; each value is the row in
+    # ffn_out for the corresponding (token, expert) pair, and -1 means invalid
+    moe_index = np.full((token_num, moe_topk), -1)
+    for i in range(expand_token_num):
+        moe_index[rows[i], cols[i]] = values[i]
+    return moe_index
+
+
+# 1) preparation
+token_num = 10
+moe_topk = 8
+hidden_dim = 128
+expand_token_num = 30
+
+ffn_out = np.random.random((expand_token_num, hidden_dim))
+moe_index = create_moe_index(token_num, moe_topk, expand_token_num)
+moe_weights = np.random.random((token_num, moe_topk))
+moe_weights = np_softmax(moe_weights)
+moe_weights[moe_index == -1] = -1
+print(f"ffn_out:\n{ffn_out}")
+print(f"moe_index:\n{moe_index}")
+print(f"moe_weights:\n{moe_weights}")
+
+# 2) np calculation
+combined_out_np = np.zeros((token_num, hidden_dim))
+for token_idx, item in enumerate(moe_index):
+    for topk_idx, ffn_out_row in enumerate(item):
+        if ffn_out_row == -1:
+            continue
+        combined_out_np[token_idx] += ffn_out[ffn_out_row] * moe_weights[token_idx][topk_idx]
+print(f"combined_out_np:\n{combined_out_np}")
+
+# 3) xpu calculation
+dtype = "bfloat16"
+ffn_out_pd = paddle.to_tensor(ffn_out, dtype=dtype)
+moe_index_pd = paddle.to_tensor(moe_index, dtype="int32")
+moe_weights_pd = paddle.to_tensor(moe_weights, dtype=dtype)
+combined_out_pd = ep_moe_expert_combine(
+    ffn_out_pd,
+    moe_index_pd,
+    moe_weights_pd,
+    moe_index_pd.shape[0],
+    ffn_out_pd.shape[0],
+    ffn_out_pd.shape[1],
+    moe_index_pd.shape[1],
+)
+
+# comparison
+# print("moe_index:\n", moe_index)
+# print("moe_weights:\n", moe_weights)
+# print("combined_out_np:\n", combined_out_np)
+# print("combined_out_pd:\n", combined_out_pd)
+combined_out_pd = combined_out_pd.astype("float32").numpy()
+avg_diff = np.sum(np.abs(combined_out_pd - combined_out_np)) / combined_out_pd.size
+assert (
+    avg_diff < 2e-3
+), f"avg_diff: {avg_diff}\n combined_out_np:\n{combined_out_np}\n combined_out_pd:\n{combined_out_pd}\n"
+print(f"[Passed] avg_diff: {avg_diff}")
diff --git a/custom_ops/xpu_ops/test/test_moe_ep_dispatch.py b/custom_ops/xpu_ops/test/test_moe_ep_dispatch.py
new file mode 100644
index 000000000..9b38bb34e
--- /dev/null
+++ b/custom_ops/xpu_ops/test/test_moe_ep_dispatch.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
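+
+# Overview (added note; a sketch of the op's contract as exercised by this test,
+# not authoritative documentation): ep_moe_expert_dispatch permutes tokens into
+# contiguous per-expert segments so each expert's FFN can run on a dense batch.
+# permute_indices_per_token[t, k] records the destination row of token t's k-th
+# routed expert (-1 if unused), and recv_num_tokens_per_expert_list_cumsum[e + 1]
+# is the running row count after expert e. ep_moe_expert_dispatch_cpu below is
+# the pure-numpy reference.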
+
+import numpy as np
+import paddle
+
+from fastdeploy.model_executor.ops.xpu import ep_moe_expert_dispatch
+
+np.random.seed(2025)
+
+
+def ep_moe_expert_dispatch_cpu(input, topk_ids, topk_weights, token_nums_per_expert, token_nums_this_rank):
+    m, n = input.shape[0], input.shape[1]
+    topk = topk_ids.shape[1]
+    expert_num = len(token_nums_per_expert)
+    expert_per_rank = expert_num
+
+    permute_input = np.full((token_nums_this_rank, n), 0.0, dtype=np.float32)
+    permute_indices_per_token = np.full((m, topk), -1, dtype=np.int32)
+    recv_num_tokens_per_expert_list_cumsum = np.full(expert_num + 1, 0, dtype=np.int32)
+    dst_indices = np.full((expert_num, m), -1, dtype=np.int32)
+    cumsum_idx = np.full(expert_num, 0, dtype=np.int32)
+    offset = 0
+    for expert_id in range(expert_per_rank):
+        for token_id in range(m):
+            for k in range(topk):
+                cur_index = topk_ids[token_id, k]
+                if cur_index == expert_id:
+                    permute_indices_per_token[token_id, k] = offset
+                    permute_input[offset, :] = input[token_id, :]
+                    offset += 1
+        recv_num_tokens_per_expert_list_cumsum[expert_id + 1] = offset
+    return (
+        permute_input,
+        permute_indices_per_token,
+        recv_num_tokens_per_expert_list_cumsum,
+        topk_weights,
+        dst_indices,
+        cumsum_idx,
+    )
+
+
+def create_moe_index(token_num, topk, expert_num):
+    topk_ids = np.full((token_num, topk), -1, dtype=np.int32)
+    token_nums_per_expert = np.full(expert_num, 0, dtype=np.int32)
+    token_all_num = 0
+    for i in range(topk_ids.shape[0]):
+        pos = np.random.choice(np.arange(topk), np.random.randint(1, topk + 1), replace=False)
+        token_all_num += len(pos)
+        for j in pos:
+            topk_ids[i, j] = np.random.choice(expert_num, replace=False)
+            token_nums_per_expert[topk_ids[i, j]] += 1
+    return token_all_num, topk_ids, list(token_nums_per_expert)
+
+
+# 1) preparation
+token_num = 7
+expert_num_per_rank = 4
+topk = 8
+hidden_dim = 8192
+
+input = np.random.random((token_num, hidden_dim))
+token_nums_this_rank, topk_ids, token_nums_per_expert = create_moe_index(token_num, topk, expert_num_per_rank)
+topk_weights = np.random.random((token_num, topk))
+print(f"input:\n{input}")
+print(f"token_nums_this_rank:\n{token_nums_this_rank}")
+print(f"topk_ids:\n{topk_ids}")
+print(f"token_nums_per_expert:\n{token_nums_per_expert}")
+print(f"topk_weights:\n{topk_weights}")
+
+dtype = "bfloat16"
+input_xpu = paddle.to_tensor(input, dtype=dtype)
+topk_ids_xpu = paddle.to_tensor(topk_ids)
+topk_weights_xpu = paddle.to_tensor(topk_weights)
+
+# 2) cpu calculation
+(
+    permute_input,
+    permute_indices_per_token,
+    recv_num_tokens_per_expert_list_cumsum,
+    dst_weights,
+    dst_indices,
+    cumsum_idx,
+) = ep_moe_expert_dispatch_cpu(input, topk_ids, topk_weights, token_nums_per_expert, token_nums_this_rank)
+print(f"permute_input:\n{permute_input}")
+print(f"permute_indices_per_token:\n{permute_indices_per_token}")
+print(f"recv_num_tokens_per_expert_list_cumsum:\n{recv_num_tokens_per_expert_list_cumsum}")
+print(f"dst_weights:\n{dst_weights}")
+print(f"dst_indices:\n{dst_indices}")
+print(f"cumsum_idx:\n{cumsum_idx}")
+
+# 3) xpu calculation
+(
+    permute_input_xpu,
+    permute_indices_per_token_xpu,
+    recv_num_tokens_per_expert_list_cumsum_xpu,
+    dst_weights_xpu,
+    expand_input_scales,
+) = ep_moe_expert_dispatch(
+    input_xpu,
+    topk_ids_xpu,
+    topk_weights_xpu,
+    None,
+    token_nums_per_expert,
+    token_nums_this_rank,
+    "weight_only_int8",
+)
+
+# comparison
+permute_input_xpu = permute_input_xpu.astype("float32").numpy()
+permute_indices_per_token_xpu = permute_indices_per_token_xpu.numpy()
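+# (added note) permute_input round-trips through bfloat16 on the XPU path, so the
+# check below uses an average absolute-difference tolerance (1e-2) instead of
+# exact equality; the integer index and cumsum outputs are compared exactly.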
+recv_num_tokens_per_expert_list_cumsum_xpu = recv_num_tokens_per_expert_list_cumsum_xpu.numpy() + +diff = np.sum(np.abs(permute_input - permute_input_xpu)) / permute_input.size +assert diff < 1e-2, f"diff: {diff}\n permute_input:\n {permute_input}\n permute_input_xpu:\n {permute_input_xpu}\n" + +assert ( + permute_indices_per_token == permute_indices_per_token_xpu +).all(), f"permute_indices_per_token:\n {permute_indices_per_token}\n permute_indices_per_token_xpu:\n {permute_indices_per_token_xpu}\n" + +assert ( + recv_num_tokens_per_expert_list_cumsum == recv_num_tokens_per_expert_list_cumsum_xpu +).all(), f"recv_num_tokens_per_expert_list_cumsum:\n {recv_num_tokens_per_expert_list_cumsum}\n recv_num_tokens_per_expert_list_cumsum_xpu:\n {recv_num_tokens_per_expert_list_cumsum_xpu}\n" + +print("ep_moe_expert_dispatch test success!") diff --git a/custom_ops/xpu_ops/test/test_moe_expert_ffn.py b/custom_ops/xpu_ops/test/test_moe_expert_ffn.py new file mode 100644 index 000000000..be99cdcee --- /dev/null +++ b/custom_ops/xpu_ops/test/test_moe_expert_ffn.py @@ -0,0 +1,295 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import moe_expert_ffn + +np.random.seed(2025) + +token_num = 7 +expert_num = 64 +hidden_dim = 8192 +ffn_inter_dim = 7168 +ffn_outer_dim = ffn_inter_dim // 2 +num_max_dispatch_tokens_per_rank = 128 +num_rank = 8 +expert_num_per_rank = expert_num // num_rank +used_in_ep_low_latency = True +hadamard_blocksize = 512 + +ffn_in = (np.random.random([token_num, hidden_dim]) - 0.5).astype("float32") +token_num_lod = np.full([expert_num_per_rank + 1], 0, "int32") +token_num_lod[-1] = token_num +token_num_lod[1:-1] = np.random.randint(0, token_num, [expert_num_per_rank - 1]) +token_num_lod = np.sort(token_num_lod) +token_num_per_expert = token_num_lod[1:] - token_num_lod[:-1] +ffn1_w = (np.random.random([expert_num_per_rank, ffn_inter_dim, hidden_dim]) - 0.5).astype("float32") +ffn2_w = (np.random.random([expert_num_per_rank, hidden_dim, ffn_outer_dim]) - 0.5).astype("float32") +ffn2_shift = (np.random.random([1, ffn_outer_dim]) - 0.5).astype("float32") +ffn2_smooth = (np.random.random([1, ffn_outer_dim]) - 0.5).astype("float32") + +if used_in_ep_low_latency: + ffn_in_tmp = ffn_in + ffn_in = np.zeros( + [ + expert_num_per_rank, + num_max_dispatch_tokens_per_rank * num_rank, + hidden_dim, + ], + "float32", + ) + for i in range(expert_num_per_rank): + ffn_in[i][: token_num_per_expert[i]] = ffn_in_tmp[token_num_lod[i] : token_num_lod[i + 1]] + token_num_info = token_num_per_expert +else: + token_num_info = token_num_lod + +print(f"ffn_in: {ffn_in}") +print(f"token_num_lod: {token_num_lod}") +print(f"token_num_per_expert: {token_num_per_expert}") +print(f"ffn1_w: {ffn1_w}") +print(f"ffn2_w: {ffn2_w}") + + +def clip_and_round(x, quant_max_bound=127): + return np.clip(np.around(x), -quant_max_bound, quant_max_bound).astype("int8") + + +def 
weight_quant_wint8(w_fp32):
+    w_max = np.max(np.abs(w_fp32), axis=-1, keepdims=True)
+    w_int8 = clip_and_round(w_fp32 / w_max * 127.0)
+    return w_int8, w_max.reshape([-1])
+
+
+def weight_quant_wint4(w_fp32):
+    w_max = np.max(np.abs(w_fp32), axis=-1, keepdims=True)
+    w_int4 = clip_and_round(w_fp32 / w_max * 7.0, 7)
+    w_int4 = (w_int4[:, :, 1::2] & 0xF) << 4 | (w_int4[:, :, ::2] & 0xF)  # pack int4
+    return w_int4, w_max.reshape([-1])
+
+
+def weight_quant(w_fp32, algo="weight_only_int8"):
+    if algo == "weight_only_int8":
+        return weight_quant_wint8(w_fp32)
+    elif algo == "weight_only_int4":
+        return weight_quant_wint4(w_fp32)
+    else:
+        return None, None
+
+
+quant_method = "weight_only_int4"
+print(f"quant_method={quant_method}, used_in_ep_low_latency={used_in_ep_low_latency}")
+ffn1_quant_w, ffn1_w_scale = weight_quant(ffn1_w, quant_method)
+ffn2_quant_w, ffn2_w_scale = weight_quant(ffn2_w, quant_method)
+print(f"ffn1_w {ffn1_w.shape}: {ffn1_w}")
+print(f"ffn2_w {ffn2_w.shape}: {ffn2_w}")
+print(f"ffn1_quant_w {ffn1_quant_w.shape}: {ffn1_quant_w}")
+print(f"ffn1_w_scale {ffn1_w_scale.shape}: {ffn1_w_scale}")
+print(f"ffn2_quant_w {ffn2_quant_w.shape}: {ffn2_quant_w}")
+print(f"ffn2_w_scale {ffn2_w_scale.shape}: {ffn2_w_scale}")
+
+
+def weight_dequant_wint8(w_int, w_scale):
+    w_shape = w_int.shape
+    w_scale_new_shape = list(w_shape)
+    w_scale_new_shape[-1] = 1
+    w_scale_new = w_scale.reshape(w_scale_new_shape)
+    w_fp32 = w_int.astype("float32") / 127.0 * w_scale_new
+    return w_fp32
+
+
+def weight_dequant_wint4(w_int, w_scale):
+    w_shape = w_int.shape
+    w_scale_new_shape = list(w_shape)
+    w_scale_new_shape[-1] = 1
+    # w_scale_new_shape[-2] = w_scale_new_shape[-2] * 2
+    w_scale_new = w_scale.reshape(w_scale_new_shape)
+    w_new_shape = list(w_shape)
+    w_new_shape[-1] = w_new_shape[-1] * 2
+    w_int8 = np.zeros(w_new_shape, dtype=np.int8)
+    w_int8[:, :, ::2] = w_int & 0xF
+    w_int8[:, :, 1::2] = (w_int >> 4) & 0xF
+    w_int8 = np.where(w_int8 >= 8, w_int8 - 16, w_int8)
+    w_fp32 = w_int8.astype("float32") / 7.0 * w_scale_new
+    return w_fp32
+
+
+def weight_dequant(w_int, w_scale, algo="weight_only_int8"):
+    if algo == "weight_only_int8":
+        return weight_dequant_wint8(w_int, w_scale)
+    elif algo == "weight_only_int4":
+        return weight_dequant_wint4(w_int, w_scale)
+    else:
+        raise ValueError(f"unsupported quant algo: {algo}")
+
+
+def fwt(a):
+    """Fast Walsh-Hadamard transform (forward).
+    :param a: input array whose length must be a power of two
+    :return: the transformed array
+    """
+    n = len(a)
+    # check that the input length is a power of two
+    if n == 0 or n & (n - 1) != 0:
+        raise ValueError("input length must be a power of two")
+
+    # copy the input to avoid modifying the original data
+    a = a.copy()
+    h = 1
+    while h < n:
+        for i in range(0, n, 2 * h):
+            for j in range(i, i + h):
+                x = a[j]
+                y = a[j + h]
+                a[j] = x + y
+                a[j + h] = x - y
+        h <<= 1  # equivalent to h *= 2
+    return a
+
+
+def hadamard(_x, block_size):
+    x = np.copy(_x).reshape((-1, _x.shape[-1]))
+    if block_size == -1:
+        return x
+    m = 1
+    n = x.shape[-1]
+    for i in range(len(x.shape) - 1):
+        m = m * x.shape[i]
+    for i in range(m):
+        for j in range(0, n, block_size):
+            subx = x[i][j : j + block_size]
+            x[i][j : j + block_size] = fwt(subx)
+    return x.reshape(_x.shape)
+
+
+# print(f"ffn1_w {ffn1_w.shape}: {ffn1_w}")
+# ffn1_quant_w8, ffn1_w8_scale = weight_quant(ffn1_w, "weight_only_int8")
+# ffn1_quant_w4, ffn1_w4_scale = weight_quant(ffn1_w, "weight_only_int4")
+# print(f"ffn1_quant_w8 {ffn1_quant_w8.shape}: {ffn1_quant_w8}")
+# print(f"ffn1_w8_scale {ffn1_w8_scale.shape}: {ffn1_w8_scale}")
+# print(f"ffn1_quant_w4 {ffn1_quant_w4.shape}: {ffn1_quant_w4}")
+# print(f"ffn1_w4_scale {ffn1_w4_scale.shape}: {ffn1_w4_scale}")
+
+# 
ffn1_w8_dq = weight_dequant(ffn1_quant_w8, ffn1_w8_scale, "weight_only_int8") +# ffn1_w4_dq = weight_dequant(ffn1_quant_w4, ffn1_w4_scale, "weight_only_int4") +# print(f"ffn1_w8_dq {ffn1_w8_dq.shape}: {ffn1_w8_dq}") +# print(f"ffn1_w4_dq {ffn1_w4_dq.shape}: {ffn1_w4_dq}") + + +def batch_matmul(x, token_num_info, w, w_scale, algo): + w_fp32 = weight_dequant(w, w_scale, algo) + print(f"x {x.shape}, w {w_fp32.shape}") + out_hidden_dim = w_fp32.shape[1] + if not used_in_ep_low_latency: + y = np.zeros([x.shape[0], out_hidden_dim], "float32") + token_num_lod = token_num_info + for i in range(expert_num_per_rank): + start_i = token_num_lod[i] + end_i = token_num_lod[i + 1] + subx = x[start_i:end_i] + subw = w_fp32[i : i + 1].transpose([0, 2, 1]) + y[start_i:end_i] = np.matmul(subx, subw) + else: + y = np.zeros( + [ + expert_num_per_rank, + num_max_dispatch_tokens_per_rank, + out_hidden_dim, + ], + "float32", + ) + token_num_per_expert = token_num_info + for i in range(expert_num_per_rank): + subx = x[i][: token_num_per_expert[i]] + subw = w_fp32[i : i + 1].transpose([0, 2, 1]) + y[i][: token_num_per_expert[i]] = np.matmul(subx, subw) + return y + + +def swiglu(x): + new_shape = list(x.shape) + new_shape[-1] //= 2 + x1 = np.copy(x[..., : new_shape[-1]]) + x2 = np.copy(x[..., new_shape[-1] :]) + y = x1 * 1.0 / (1.0 + np.exp(-x1)) * x2 + return y + + +ref_ffn1_out = batch_matmul(ffn_in, token_num_info, ffn1_quant_w, ffn1_w_scale, quant_method) +print(f"ref_ffn1_out {ref_ffn1_out.shape}: {ref_ffn1_out}") +ref_swiglu_out = swiglu(ref_ffn1_out) +print(f"ref_swiglu_out {ref_swiglu_out.shape}: {ref_swiglu_out}") +ref_swiglu_out = (ref_swiglu_out + ffn2_shift) * ffn2_smooth +ref_hadamard_out = hadamard(ref_swiglu_out, hadamard_blocksize) +ref_ffn2_out = batch_matmul( + ref_hadamard_out, + token_num_info, + ffn2_quant_w, + ffn2_w_scale, + quant_method, +) + +ffn_in_tensor = paddle.to_tensor(ffn_in).astype("bfloat16") +token_num_info_tensor = paddle.to_tensor(token_num_info) +ffn1_quant_w_tensor = paddle.to_tensor(ffn1_quant_w) +ffn2_quant_w_tensor = paddle.to_tensor(ffn2_quant_w) +ffn1_w_scale_tensor = paddle.to_tensor(ffn1_w_scale) +ffn2_w_scale_tensor = paddle.to_tensor(ffn2_w_scale) +ffn2_shift_tensor = paddle.to_tensor(ffn2_shift).astype("bfloat16") +ffn2_smooth_tensor = paddle.to_tensor(ffn2_smooth).astype("bfloat16") + +ffn2_out = moe_expert_ffn( + ffn_in_tensor, + token_num_info_tensor, + ffn1_quant_w_tensor, + ffn2_quant_w_tensor, + None, # ffn1_bias + None, # ffn2_bias + None, # ffn1_act_scale + None, # ffn2_act_scale + ffn1_w_scale_tensor, + ffn2_w_scale_tensor, + ffn2_shift_tensor, + ffn2_smooth_tensor, + quant_method, + hadamard_blocksize, + token_num, +) +ffn2_out = ffn2_out.astype("float32").numpy() +print(f"ffn2_out: {ffn2_out}") +print(f"ref_ffn2_out: {ref_ffn2_out}") + +if not used_in_ep_low_latency: + diff = np.sum(np.abs(ffn2_out - ref_ffn2_out)) / np.sum(np.abs(ffn2_out)) + print(f"diff: {diff}") + assert diff < 0.01, f"diff: {diff}\nffn2_out:\n{ffn2_out}\nref_ffn2_out:\n{ref_ffn2_out}\n" +else: + diff_all = 0 + for i in range(expert_num_per_rank): + token_num_this_expert = token_num_per_expert[i] + if token_num_this_expert == 0: + continue + tmp_ffn2_out = ffn2_out[i][:token_num_this_expert] + tmp_ref_ffn2_out = ref_ffn2_out[i][:token_num_this_expert] + diff = np.sum(np.abs(tmp_ffn2_out - tmp_ref_ffn2_out)) / np.sum(np.abs(tmp_ffn2_out)) + print(f"diff: {diff}") + print(f"{i}, tmp_ffn2_out: {tmp_ffn2_out}") + print(f"{i}, tmp_ref_ffn2_out: {tmp_ref_ffn2_out}") + diff_all += diff + 
diff_avg = diff_all / expert_num_per_rank
+    print(f"diff_avg: {diff_avg}")
+    assert diff_avg < 0.03, f"diff_avg: {diff_avg}\nffn2_out:\n{ffn2_out}\nref_ffn2_out:\n{ref_ffn2_out}\n"
diff --git a/custom_ops/xpu_ops/test/test_moe_redundant_topk_select.py b/custom_ops/xpu_ops/test/test_moe_redundant_topk_select.py
new file mode 100644
index 000000000..f8d59069f
--- /dev/null
+++ b/custom_ops/xpu_ops/test/test_moe_redundant_topk_select.py
@@ -0,0 +1,200 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+
+from fastdeploy.model_executor.ops.xpu import moe_redundant_topk_select
+
+
+def ref_moe_topk_select(gating_logits, bias, moe_topk, apply_norm_weight):
+    assert apply_norm_weight is True
+
+    def _softmax(x):
+        axis = 1
+        x_max = np.max(x, axis=axis, keepdims=True)
+        e_x = np.exp(x - x_max)
+        return e_x / np.sum(e_x, axis=axis, keepdims=True)
+
+    softmax_logits = _softmax(gating_logits)
+    softmax_logits_with_bias = np.copy(softmax_logits)
+    if bias is not None:
+        softmax_logits_with_bias += bias.reshape([1, -1])
+    sorted_indices = np.argsort(softmax_logits_with_bias, axis=1, kind="stable")[:, ::-1]
+    topk_ids = sorted_indices[:, :moe_topk]
+    topk_weights = np.take_along_axis(softmax_logits, topk_ids, axis=1)
+    topk_weights = topk_weights[:, :moe_topk]
+    topk_weights /= np.sum(topk_weights, axis=1, keepdims=True)
+    return topk_ids, topk_weights
+
+
+def generate_expert_in_rank_num(num_values, extra_num):
+    if num_values <= 0:
+        return np.array([])
+    # generate all random indices at once
+    indices = np.random.randint(0, num_values, extra_num)
+    # count frequencies with bincount (vectorized)
+    bin_counts = np.bincount(indices, minlength=num_values)
+    # result = base value of 1 + extra occurrences
+    return 1 + bin_counts
+
+
+def generate_expert_id_to_ep_rank(expert_in_rank_num_list, num_rank, redundant_num_plus_one):
+    num_expert = expert_in_rank_num_list.size
+    redundant_num = redundant_num_plus_one - 1
+    # generate random rank ids (all at once)
+    rank_idx = np.random.randint(0, num_rank, num_expert)
+    # initialize the result matrix (-1 means unassigned)
+    expert_id_to_rank_id = np.full((num_expert, redundant_num + 1), -1, dtype=int)
+    # initial assignment - each expert gets one base rank id
+    expert_ids = np.arange(num_expert)
+    expert_id_to_rank_id[expert_ids, 0] = rank_idx
+    if redundant_num > 0:
+        positions = np.ones(num_expert, dtype=int)
+        for expert_id in range(expert_in_rank_num_list.size):
+            repeat_num = expert_in_rank_num_list[expert_id]
+            while repeat_num > 1:
+                rank_idx = np.random.randint(0, num_rank)
+                expert_id_to_rank_id[expert_id][positions[expert_id]] = rank_idx
+                positions[expert_id] += 1
+                repeat_num -= 1
+    return expert_id_to_rank_id
+
+
+def generate_rank_to_id(id_to_rank, rank_num):
+    max_rank = -1
+    for ranks in id_to_rank:
+        if len(ranks) > 0:
+            current_max = max(ranks)
+            if current_max > max_rank:
+                max_rank = current_max
+    if max_rank < 0 or max_rank >= rank_num:
+        return []
+
+    rank_to_id = [[] for _ in range(rank_num)]
+    for id_val, ranks in enumerate(id_to_rank):
+        for r in ranks:
+            if r < 0:  # skip negative (unassigned) values
+                continue
+            if r < len(rank_to_id):  # make sure the index is within range
+                rank_to_id[r].append(id_val)
+    return rank_to_id
+
+
+def my_sort(key_arr, val_arr):
+    if key_arr.shape != val_arr.shape:
+        return None, None
+    # process row by row instead of converting the whole arrays
+    sorted_keys = np.empty_like(key_arr)
+    sorted_vals = np.empty_like(val_arr)
+
+    for i in range(key_arr.shape[0]):
+        keys = key_arr[i]
+        vals = val_arr[i]
+        idx = np.lexsort((keys, vals))
+        sorted_keys[i] = keys[idx]
+        sorted_vals[i] = vals[idx]
+
+    return sorted_keys, sorted_vals
+
+
+if __name__ == "__main__":
+    seed = np.random.randint(1, 1e9)
+    print(f"numpy random seed={seed}")
+    np.random.seed(seed)
+
+    rank_num = 8
+    token_num = 1215
+    expert_num = 256
+    moe_topk = 8
+    redundant_ep_rank_num_plus_one = 1  # no redundant experts
+    apply_norm_weight = True
+    enable_softmax_top_k_fused = True
+    gating_logits = np.random.random([token_num, expert_num]).astype("float32")
+    bias = np.random.random([expert_num]).astype("float32")
+    expert_in_rank_num_list = generate_expert_in_rank_num(expert_num, redundant_ep_rank_num_plus_one - 1)
+    print(f"expert_in_rank_num_list={expert_in_rank_num_list}")
+    expert_id_to_ep_rank_array = generate_expert_id_to_ep_rank(
+        expert_in_rank_num_list, rank_num, redundant_ep_rank_num_plus_one
+    )
+    tokens_per_expert_stats_list = np.random.randint(0, 20, size=(expert_num))
+    print(f"expert_id_to_ep_rank_array={expert_id_to_ep_rank_array}")
+    print(f"tokens_per_expert_stats_list={tokens_per_expert_stats_list}")
+
+    # ref_topk_ids, ref_topk_weights = ref_moe_topk_select(
+    #     gating_logits, bias, moe_topk, apply_norm_weight
+    # )
+
+    gating_logits = paddle.to_tensor(gating_logits).astype("float32")
+    expert_id_to_ep_rank_array = paddle.to_tensor(expert_id_to_ep_rank_array).astype("int32")
+    expert_in_rank_num_list = paddle.to_tensor(expert_in_rank_num_list).astype("int32")
+    tokens_per_expert_stats_list = paddle.to_tensor(tokens_per_expert_stats_list).astype("int32")
+    if bias is not None:
+        bias = paddle.to_tensor(bias).astype("float32")
+
+    gating_logits_ref = gating_logits.cpu()
+    expert_id_to_ep_rank_array_ref = expert_id_to_ep_rank_array.cpu()
+    expert_in_rank_num_list_ref = expert_in_rank_num_list.cpu()
+    tokens_per_expert_stats_list_ref = tokens_per_expert_stats_list.cpu()
+    bias_ref = None
+    if bias is not None:
+        bias_ref = bias.cpu()
+
+    topk_ids, topk_weights = moe_redundant_topk_select(
+        gating_logits,
+        expert_id_to_ep_rank_array,
+        expert_in_rank_num_list,
+        tokens_per_expert_stats_list,
+        bias,
+        moe_topk,
+        apply_norm_weight,
+        enable_softmax_top_k_fused,
+        redundant_ep_rank_num_plus_one,
+    )
+    topk_ids_ref, topk_weights_ref = moe_redundant_topk_select(
+        gating_logits_ref,
+        expert_id_to_ep_rank_array_ref,
+        expert_in_rank_num_list_ref,
+        tokens_per_expert_stats_list_ref,
+        bias_ref,
+        moe_topk,
+        apply_norm_weight,
+        enable_softmax_top_k_fused,
+        redundant_ep_rank_num_plus_one,
+    )
+
+    topk_ids_np, topk_weights_np, tokens_per_expert_stats_list_np = (
+        topk_ids.numpy(),
+        topk_weights.numpy(),
+        tokens_per_expert_stats_list.numpy(),
+    )
+    topk_ids_ref, topk_weights_ref, tokens_per_expert_stats_list_ref = (
+        topk_ids_ref.numpy(),
+        topk_weights_ref.numpy(),
+        tokens_per_expert_stats_list_ref.numpy(),
+    )
+    sorted_topk_ids, sorted_topk_weights = my_sort(topk_ids_np, topk_weights_np)
+    sorted_topk_ids_ref, sorted_topk_weights_ref = my_sort(topk_ids_ref, topk_weights_ref)
+
+    assert np.array_equal(
+        tokens_per_expert_stats_list_np, tokens_per_expert_stats_list_ref
+    ), 
f"\ntokens_per_expert_stats_list:\n{tokens_per_expert_stats_list.numpy()}\ntokens_per_expert_stats_list_ref:\n{tokens_per_expert_stats_list_ref}" + assert np.array_equal( + sorted_topk_ids, sorted_topk_ids_ref + ), f"\ntopk_ids:\n{topk_ids.numpy()}\ntopk_ids_ref:\n{topk_ids_ref}" + assert np.allclose( + sorted_topk_weights, sorted_topk_weights_ref + ), f"\ntopk_weights:\n{topk_weights.numpy()}\ntopk_weights_ref:\n{topk_weights_ref}" + + print("Passed all tests.") diff --git a/custom_ops/xpu_ops/test/test_moe_topk_select.py b/custom_ops/xpu_ops/test/test_moe_topk_select.py new file mode 100644 index 000000000..0b0fe3ac2 --- /dev/null +++ b/custom_ops/xpu_ops/test/test_moe_topk_select.py @@ -0,0 +1,67 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import f_moe_topk_select + +np.random.seed(2025) + +token_num = 15 +expert_num = 256 +moe_topk = 8 +apply_norm_weight = True + +gating_logits = np.random.random([token_num, expert_num]).astype("float32") +bias = np.random.random([expert_num]).astype("float32") + + +def ref_moe_topk_select(gating_logits, bias, moe_topk, apply_norm_weight): + assert apply_norm_weight is True + + def _softmax(x): + axis = 1 + x_max = np.max(x, axis=axis, keepdims=True) + e_x = np.exp(x - x_max) + return e_x / np.sum(e_x, axis=axis, keepdims=True) + + softmax_logits = _softmax(gating_logits) + softmax_logits_with_bias = np.copy(softmax_logits) + if bias is not None: + softmax_logits_with_bias += bias.reshape([1, -1]) + sorted_indices = np.argsort(softmax_logits_with_bias, axis=1, kind="stable")[:, ::-1] + topk_ids = sorted_indices[:, :moe_topk] + topk_weights = np.take_along_axis(softmax_logits, topk_ids, axis=1) + topk_weights = topk_weights[:, :moe_topk] + topk_weights /= np.sum(topk_weights, axis=1, keepdims=True) + return topk_ids, topk_weights + + +ref_topk_ids, ref_topk_weights = ref_moe_topk_select(gating_logits, bias, moe_topk, apply_norm_weight) + +gating_logits = paddle.to_tensor(gating_logits) +if bias is not None: + bias = paddle.to_tensor(bias) + +topk_ids, topk_weights = f_moe_topk_select(gating_logits, bias, moe_topk, apply_norm_weight) + +assert np.array_equal( + topk_ids.numpy(), ref_topk_ids +), f"\ntopk_ids:\n{topk_ids.numpy()}\nref_topk_ids:\n{ref_topk_ids}" +assert np.allclose( + topk_weights.numpy(), ref_topk_weights +), f"\ntopk_weights:\n{topk_weights.numpy()}\nref_topk_weights:\n{ref_topk_weights}" + +print("Passed all tests.") diff --git a/custom_ops/xpu_ops/test/test_read_data_ipc.py b/custom_ops/xpu_ops/test/test_read_data_ipc.py new file mode 100644 index 000000000..54721b09e --- /dev/null +++ b/custom_ops/xpu_ops/test/test_read_data_ipc.py @@ -0,0 +1,23 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import read_data_ipc + +x = np.zeros([512, 8, 64, 128], dtype="float32") +x = paddle.to_tensor(x, place=paddle.CPUPlace()) +read_data_ipc(x, "test_set_data_ipc") +print(x.numpy().flatten()[:100]) diff --git a/custom_ops/xpu_ops/test/test_set_data_ipc.py b/custom_ops/xpu_ops/test/test_set_data_ipc.py new file mode 100644 index 000000000..3e3588ef3 --- /dev/null +++ b/custom_ops/xpu_ops/test/test_set_data_ipc.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import paddle + +from fastdeploy.model_executor.ops.xpu import set_data_ipc + +x = paddle.full(shape=[512, 8, 64, 128], fill_value=2, dtype="float32") +set_data_ipc(x, "test_set_data_ipc") +print("set_data_ipc done") + +time.sleep(60) diff --git a/custom_ops/xpu_ops/test/test_set_get_data_ipc.py b/custom_ops/xpu_ops/test/test_set_get_data_ipc.py new file mode 100644 index 000000000..d3191357c --- /dev/null +++ b/custom_ops/xpu_ops/test/test_set_get_data_ipc.py @@ -0,0 +1,45 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
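+
+# Usage sketch (added note, inferred from the branches below): this test needs
+# two processes, e.g.
+#   python test_set_get_data_ipc.py 0   # producer: creates a tensor, shares it as shm_name, sleeps
+#   python test_set_get_data_ipc.py 1   # consumer: attaches via share_external_data(use_ipc=True)
+#   python test_set_get_data_ipc.py 2   # consumer: attaches with use_ipc=False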
+
+import sys
+import time
+
+import paddle
+
+from fastdeploy.model_executor.ops.xpu import set_data_ipc, share_external_data
+
+shape = [8, 128]
+dtype = "bfloat16"
+shm_name = "xpu_shm_tensor"
+
+paddle.set_device("xpu:0")
+
+if sys.argv[1] == "0":
+    print("set data ipc")
+    input_tensor = paddle.cast(paddle.rand(shape), dtype)
+    set_data_ipc(input_tensor, shm_name)
+    print(input_tensor)
+    time.sleep(120)
+elif sys.argv[1] == "1":
+    print("test share_external_data")
+    tmp_input = paddle.empty([], dtype=dtype)
+    output = share_external_data(tmp_input, shm_name, shape, use_ipc=True)
+    print(output.shape)
+    print(output.cpu())  # use xpu_memcpy
+else:
+    print("test share_external_data")
+    tmp_input = paddle.empty([], dtype=dtype)
+    output = share_external_data(tmp_input, shm_name, shape, use_ipc=False)
+    temp_output = output * 1  # avoid xpu_memcpy
+    print(temp_output)
diff --git a/custom_ops/xpu_ops/test/test_weight_only_linear.py b/custom_ops/xpu_ops/test/test_weight_only_linear.py
new file mode 100644
index 000000000..fe3993e12
--- /dev/null
+++ b/custom_ops/xpu_ops/test/test_weight_only_linear.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+
+from fastdeploy.model_executor.ops.xpu import (
+    weight_only_linear_xpu as weight_only_linear,
+)
+
+np.random.seed(2025)
+
+
+def np_clip_and_round(x, abs_max=127):
+    return np.clip(np.around(x), -abs_max, abs_max).astype("int8")
+
+
+def np_quant_weight_int4(weight_np):
+    assert weight_np.dtype == np.float32  # n,k
+    weight = weight_np
+    # weight = np.transpose(weight_np, [1, 0])  # n,k
+    max_value = np.max(np.abs(weight), axis=1).reshape(-1, 1)  # n => n,1
+    quanted_weight = np_clip_and_round(weight / max_value * 7.0, 7)  # n,k
+    quanted_weight = (quanted_weight[:, 1::2] & 0xF) << 4 | (quanted_weight[:, ::2] & 0xF)  # pack int4, [n,k//2]
+    weight_scales = (max_value).astype(weight_np.dtype).reshape(-1)
+    return quanted_weight, weight_scales.astype(np.float32)
+
+
+def np_quant_weight(weight_np, algo="weight_only_int8"):
+    assert weight_np.dtype == np.float32
+
+    if algo == "weight_only_int4":
+        return np_quant_weight_int4(weight_np)
+
+    weight = weight_np
+    # weight = np.transpose(weight_np, [1, 0])
+    max_value = np.max(np.abs(weight), axis=1).reshape(-1, 1)
+    quanted_weight = np_clip_and_round(weight / max_value * 127.0)
+    weight_scales = (max_value).astype(weight_np.dtype).reshape(-1)
+    return quanted_weight, weight_scales.astype(np.float32)
+
+
+def int8_to_bin_np(value):
+    value_np = np.int8(value)
+    return np.binary_repr(value_np, width=8)
+
+
+def int8_to_bin(value):
+    if not -128 <= value <= 127:
+        raise ValueError("an int8 value must be between -128 and 127")
+    return format(value & 0xFF, "08b")  # '08b' means 8 binary digits, zero-padded
+
+
+def weight_dequant_wint8(w_int, w_scale):
+    w_shape = w_int.shape
+    # print(f"w_shape={w_shape}")
+    w_scale_new_shape = list(w_shape)
+    w_scale_new_shape[-1] = 1
+    w_scale_new = w_scale.reshape(w_scale_new_shape)
+    w_fp32 = w_int.astype("float32") / 127.0 * w_scale_new
+    return w_fp32
+
+
+def weight_dequant_wint4(w_int, w_scale):
+    w_shape = w_int.shape
+    w_scale_new_shape = list(w_shape)
+    w_scale_new_shape[-1] = 1
+    # w_scale_new_shape[-2] = w_scale_new_shape[-2] * 2
+    w_scale_new = w_scale.reshape(w_scale_new_shape)
+    w_new_shape = list(w_shape)
+    w_new_shape[-1] = w_new_shape[-1] * 2
+    w_int8 = np.zeros(w_new_shape, dtype=np.int8)
+    w_int8[..., ::2] = w_int & 0xF
+    w_int8[..., 1::2] = (w_int >> 4) & 0xF
+    w_int8 = np.where(w_int8 >= 8, w_int8 - 16, w_int8)
+    w_fp32 = w_int8.astype("float32") / 7.0 * w_scale_new
+    return w_fp32
+
+
+def weight_dequant(w_int, w_scale, algo="weight_only_int8"):
+    if algo == "weight_only_int8":
+        return weight_dequant_wint8(w_int, w_scale)
+    elif algo == "weight_only_int4":
+        return weight_dequant_wint4(w_int, w_scale)
+    else:
+        raise ValueError(f"unsupported quant algo: {algo}")
+
+
+def batch_matmul(x, qw, wscale, algo, bias=None):
+    w_fp32 = weight_dequant(qw, wscale, algo)
+    # print(f"w_dequant={w_fp32}")
+    # print(f"x.shape={x.shape}, w.shape={w_fp32.shape}")
+    w_trans = np.transpose(w_fp32, [1, 0])
+    y = np.matmul(x, w_trans)
+    if bias is not None:
+        y = y + bias
+    return y
+
+
+# 1) preparation
+m, n, k = 64, 128, 256
+algo = "weight_only_int8"
+weight_dtype = "int8"
+# m, n, k = 12, 14336, 8192
+
+x_np = (np.random.random((m, k)).astype(np.float32) - 0.5) * 10
+w_np = (np.random.random((n, k)).astype(np.float32) - 0.5) * 10
+qw_np, wscale_np = np_quant_weight(w_np, algo)
+# print(f"x_np={x_np}")
+# print(f"w_np={w_np}")
+# 2) np calculation
+out_np = batch_matmul(x_np, qw_np, wscale_np, algo)
+
+# 3) xpu calculation
+x_pd = paddle.to_tensor(x_np).astype("bfloat16")
+qw_pd = paddle.to_tensor(qw_np)
+wscale_pd = paddle.to_tensor(wscale_np).astype("float32")
+out_pd = weight_only_linear(x_pd, qw_pd, wscale_pd, None, weight_dtype, -1, -1)
+print(f"out_pd:\n{out_pd}")
+print(f"out_np:\n{out_np}")
+
+# comparison
+print(f"out_pd, mean={out_pd.mean()}, std={out_pd.std()}")
+print(f"out_np, mean={out_np.mean()}, std={out_np.std()}")
+sum_diff = np.sum(np.abs(out_pd.astype("float32").numpy() - out_np.astype("float32")))
+print(f"sum_diff: {sum_diff}")
+print(f"avg_diff: {sum_diff / (m * n)}")
diff --git a/docs/get_started/installation/kunlunxin_xpu.md b/docs/get_started/installation/kunlunxin_xpu.md
index 81356d759..b7bcdaa8b 100644
--- a/docs/get_started/installation/kunlunxin_xpu.md
+++ b/docs/get_started/installation/kunlunxin_xpu.md
@@ -83,20 +83,20 @@ cd FastDeploy
 ### Download Kunlunxin Compilation Dependency
 
 ```bash
-bash custom_ops/xpu_ops/src/download_dependencies.sh stable
+bash custom_ops/xpu_ops/download_dependencies.sh stable
 ```
 
 Alternatively, you can download the latest versions of XTDK and XVLLM (Not recommended)
 
 ```bash
-bash custom_ops/xpu_ops/src/download_dependencies.sh develop
+bash custom_ops/xpu_ops/download_dependencies.sh develop
 ```
 
 Set environment variables,
 
 ```bash
-export CLANG_PATH=$(pwd)/custom_ops/xpu_ops/src/third_party/xtdk
-export XVLLM_PATH=$(pwd)/custom_ops/xpu_ops/src/third_party/xvllm
+export CLANG_PATH=$(pwd)/custom_ops/xpu_ops/third_party/xtdk
+export XVLLM_PATH=$(pwd)/custom_ops/xpu_ops/third_party/xvllm
 ```
 
 ### Compile and Install.
diff --git a/docs/zh/get_started/installation/kunlunxin_xpu.md b/docs/zh/get_started/installation/kunlunxin_xpu.md index 42770df90..18522581a 100644 --- a/docs/zh/get_started/installation/kunlunxin_xpu.md +++ b/docs/zh/get_started/installation/kunlunxin_xpu.md @@ -83,20 +83,20 @@ cd FastDeploy ### 下载昆仑编译依赖 ```bash -bash custom_ops/xpu_ops/src/download_dependencies.sh stable +bash custom_ops/xpu_ops/download_dependencies.sh stable ``` 或者你也可以下载最新版编译依赖 ```bash -bash custom_ops/xpu_ops/src/download_dependencies.sh develop +bash custom_ops/xpu_ops/download_dependencies.sh develop ``` 设置环境变量 ```bash -export CLANG_PATH=$(pwd)/custom_ops/xpu_ops/src/third_party/xtdk -export XVLLM_PATH=$(pwd)/custom_ops/xpu_ops/src/third_party/xvllm +export CLANG_PATH=$(pwd)/custom_ops/xpu_ops/third_party/xtdk +export XVLLM_PATH=$(pwd)/custom_ops/xpu_ops/third_party/xvllm ``` ### 开始编译并安装: diff --git a/scripts/run_ci_xpu.sh b/scripts/run_ci_xpu.sh index 6ae13bd39..a240d1aca 100644 --- a/scripts/run_ci_xpu.sh +++ b/scripts/run_ci_xpu.sh @@ -20,9 +20,9 @@ python -m pip uninstall fastdeploy-xpu -y python -m pip install paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/packages/nightly/xpu-p800/ # python -m pip install https://paddle-whl.bj.bcebos.com/nightly/xpu-p800/paddlepaddle-xpu/paddlepaddle_xpu-3.0.0.dev20250901-cp310-cp310-linux_x86_64.whl echo "build whl" -bash custom_ops/xpu_ops/src/download_dependencies.sh develop -export CLANG_PATH=$(pwd)/custom_ops/xpu_ops/src/third_party/xtdk -export XVLLM_PATH=$(pwd)/custom_ops/xpu_ops/src/third_party/xvllm +bash custom_ops/xpu_ops/download_dependencies.sh develop +export CLANG_PATH=$(pwd)/custom_ops/xpu_ops/third_party/xtdk +export XVLLM_PATH=$(pwd)/custom_ops/xpu_ops/third_party/xvllm bash build.sh || exit 1 echo "pip others" python -m pip install openai -U