diff --git a/build.sh b/build.sh
index 51e3cb05c..0ddc2588b 100644
--- a/build.sh
+++ b/build.sh
@@ -104,6 +104,15 @@ function copy_ops(){
     return
   fi

+  if_corex=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device(\"iluvatar_gpu\"))"`
+  if [ "$if_corex" = "True" ]; then
+    DEVICE_TYPE="iluvatar-gpu"
+    cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
+    cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/iluvatar
+    echo -e "BASE and Iluvatar ops have been copied to fastdeploy"
+    return
+  fi
+
   DEVICE_TYPE="cpu"
   cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
   cd ../../../../
diff --git a/custom_ops/gpu_ops/get_padding_offset.cu b/custom_ops/gpu_ops/get_padding_offset.cu
index 345affe97..2e1152e42 100644
--- a/custom_ops/gpu_ops/get_padding_offset.cu
+++ b/custom_ops/gpu_ops/get_padding_offset.cu
@@ -13,6 +13,7 @@
 // limitations under the License.

 #include "paddle/extension.h"
+#include "helper.h"

 #ifndef PD_BUILD_STATIC_OP
 #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
@@ -59,7 +60,12 @@ std::vector GetPaddingOffset(const paddle::Tensor &input_ids,
                              const paddle::Tensor &cum_offsets,
                              const paddle::Tensor &token_num,
                              const paddle::Tensor &seq_len) {
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
+  auto cu_stream = dev_ctx->stream();
+#else
   auto cu_stream = input_ids.stream();
+#endif
   std::vector input_ids_shape = input_ids.shape();
   const int bsz = seq_len.shape()[0];
   const int seq_length = input_ids_shape[1];
@@ -75,7 +81,11 @@ std::vector GetPaddingOffset(const paddle::Tensor &input_ids,
       paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
   auto cu_seqlens_k =
       paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
-  int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
+#ifdef PADDLE_WITH_COREX
+  int blockSize = std::min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
+#else
+  int blockSize = min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
+#endif
   GetPaddingOffsetKernel<<>>(
       padding_offset.data(),
       cum_offsets_out.data(),
diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h
index f829bf1ff..48a79ef07 100644
--- a/custom_ops/gpu_ops/helper.h
+++ b/custom_ops/gpu_ops/helper.h
@@ -14,7 +14,9 @@
 #pragma once

+#ifndef PADDLE_WITH_COREX
 #include "glog/logging.h"
+#endif
 #include
 #include
 #include
@@ -35,22 +37,35 @@
 namespace cub = hipcub;
 #else
 #include
 #endif
+#ifndef PADDLE_WITH_COREX
 #include "nlohmann/json.hpp"
+#endif
 #include
 #include
 #include "env.h"
 #include "paddle/extension.h"
 #include "paddle/phi/core/allocator.h"
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+#include "paddle/phi/backends/custom/custom_context.h"
+#else
 #include "paddle/phi/core/cuda_stream.h"
+#endif
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
+#ifdef PADDLE_WITH_COREX
+#define WARP_SIZE 64
+#else
+#define WARP_SIZE 32
+#endif

 #ifndef PD_BUILD_STATIC_OP
 #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
 #endif

+#ifndef PADDLE_WITH_COREX
 using json = nlohmann::json;
+#endif

 #define CUDA_CHECK(call) \
   do { \
@@ -237,6 +252,7 @@ inline int GetBlockSize(int vocab_size) {
   }
 }

+#ifndef PADDLE_WITH_COREX
 inline json readJsonFromFile(const std::string &filePath) {
   std::ifstream file(filePath);
   if (!file.is_open()) {
@@ -247,6 +263,7 @@ inline json readJsonFromFile(const std::string &filePath) {
   file >> j;
   return j;
 }
+#endif

 #define cudaCheckError() \
   { \
@@ -418,6 +435,7 @@ inline std::string base64_decode(const std::string &encoded_string) {
   return ret;
 }

+#ifndef PADDLE_WITH_COREX
 template
 inline T get_relative_best(nlohmann::json *json_data,
                            const std::string &target_key,
@@ -430,6 +448,7 @@ inline T get_relative_best(nlohmann::json *json_data,
     return default_value;
   }
 }
+#endif

 __device__ inline bool is_in_end(const int64_t id,
                                  const int64_t *end_ids,
                                  int length) {
diff --git a/custom_ops/gpu_ops/noaux_tc.cu b/custom_ops/gpu_ops/noaux_tc.cu
index a14f7443b..d98b5b4b4 100644
--- a/custom_ops/gpu_ops/noaux_tc.cu
+++ b/custom_ops/gpu_ops/noaux_tc.cu
@@ -18,7 +18,6 @@
 #include
 #include

-#include "helper.h"
 #include "noauxtc_kernel.h"

 std::vector NoauxTc(paddle::Tensor& scores,
diff --git a/custom_ops/gpu_ops/noauxtc_kernel.h b/custom_ops/gpu_ops/noauxtc_kernel.h
index bce305edc..c91d4f5b3 100644
--- a/custom_ops/gpu_ops/noauxtc_kernel.h
+++ b/custom_ops/gpu_ops/noauxtc_kernel.h
@@ -17,11 +17,11 @@
 #pragma once
 #include
 #include
+#include "helper.h"

 namespace cg = cooperative_groups;

 constexpr unsigned FULL_WARP_MASK = 0xffffffff;
-constexpr int32_t WARP_SIZE = 32;
 constexpr int32_t BLOCK_SIZE = 512;
 constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
diff --git a/custom_ops/gpu_ops/rebuild_padding.cu b/custom_ops/gpu_ops/rebuild_padding.cu
index a20948001..3d69e9e45 100644
--- a/custom_ops/gpu_ops/rebuild_padding.cu
+++ b/custom_ops/gpu_ops/rebuild_padding.cu
@@ -91,7 +91,12 @@ std::vector rebuild_padding(
   typedef typename traits_::DataType DataType_;
   typedef typename traits_::data_t data_t;
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(tmp_out.place()));
+  auto cu_stream = dev_ctx->stream();
+#else
   auto cu_stream = tmp_out.stream();
+#endif
   std::vector tmp_out_shape = tmp_out.shape();
   const int token_num = tmp_out_shape[0];
   const int dim_embed = tmp_out_shape[1];
@@ -125,7 +130,7 @@ std::vector rebuild_padding(
   if (output_padding_offset) {
     RebuildAppendPaddingKernel
-        <<>>(
+        <<>>(
             reinterpret_cast(out.data()),
             reinterpret_cast(tmp_out.data()),
             cum_offsets.data(),
@@ -138,7 +143,7 @@ std::vector rebuild_padding(
             elem_nums);
   } else {
     RebuildPaddingKernel
-        <<>>(
+        <<>>(
             reinterpret_cast(out.data()),
             reinterpret_cast(
                 const_cast(tmp_out.data())),
diff --git a/custom_ops/gpu_ops/sample_kernels/air_top_p_sampling.cu b/custom_ops/gpu_ops/sample_kernels/air_top_p_sampling.cu
index 0d73e0bd5..ade1d74b5 100644
--- a/custom_ops/gpu_ops/sample_kernels/air_top_p_sampling.cu
+++ b/custom_ops/gpu_ops/sample_kernels/air_top_p_sampling.cu
@@ -376,7 +376,6 @@ __global__ void air_topp_sampling(Counter *counters, T *histograms,
   }

   // scan/find
-  constexpr int WARP_SIZE = 32;
   constexpr int WARP_COUNT = NumBuckets / WARP_SIZE;
   namespace cg = cooperative_groups;
   cg::thread_block block = cg::this_thread_block();
diff --git a/custom_ops/gpu_ops/set_value_by_flags.cu b/custom_ops/gpu_ops/set_value_by_flags.cu
index 6c92eaf3f..38d2ea045 100644
--- a/custom_ops/gpu_ops/set_value_by_flags.cu
+++ b/custom_ops/gpu_ops/set_value_by_flags.cu
@@ -13,6 +13,7 @@
 // limitations under the License.
#include "paddle/extension.h" +#include "helper.h" #ifndef PD_BUILD_STATIC_OP #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) @@ -51,13 +52,18 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags) { +#ifdef PADDLE_WITH_CUSTOM_DEVICE + auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(stop_flags.place())); + auto cu_stream = dev_ctx->stream(); +#else auto cu_stream = stop_flags.stream(); +#endif std::vector pre_ids_all_shape = pre_ids_all.shape(); int bs = seq_lens_this_time.shape()[0]; int length = pre_ids_all_shape[1]; int length_input_ids = input_ids.shape()[1]; - int block_size = (bs + 32 - 1) / 32 * 32; + int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; set_value_by_flag_and_id<<<1, block_size, 0, cu_stream>>>( stop_flags.data(), const_cast(pre_ids_all.data()), diff --git a/custom_ops/gpu_ops/step.cu b/custom_ops/gpu_ops/step.cu index dc2487c9f..90b95c983 100644 --- a/custom_ops/gpu_ops/step.cu +++ b/custom_ops/gpu_ops/step.cu @@ -189,7 +189,7 @@ __global__ void free_and_dispatch_block(bool *stop_flags, ? tmp_used_len + 1 : max_decoder_block_num_this_seq; #ifdef DEBUG_STEP - printf("#### ori_step_len:%d, ori_free_list_len:%d, used_len:%d \n", + printf("#### ori_step_len:%d, ori_free_list_len:%d, used_len:%d \n", ori_step_len, ori_free_list_len, used_len); #endif while (ori_step_len > 0 && ori_free_list_len >= used_len) { @@ -323,7 +323,12 @@ void StepPaddle(const paddle::Tensor &stop_flags, const paddle::Tensor &first_token_ids, const int block_size, const int encoder_decoder_block_num) { +#ifdef PADDLE_WITH_CUSTOM_DEVICE + auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(seq_lens_this_time.place())); + auto cu_stream = dev_ctx->stream(); +#else auto cu_stream = seq_lens_this_time.stream(); +#endif const int bsz = seq_lens_this_time.shape()[0]; const int block_num_per_seq = block_tables.shape()[1]; const int length = input_ids.shape()[1]; diff --git a/custom_ops/gpu_ops/stop_generation_multi_ends.cu b/custom_ops/gpu_ops/stop_generation_multi_ends.cu index a804eba43..fcabc009b 100644 --- a/custom_ops/gpu_ops/stop_generation_multi_ends.cu +++ b/custom_ops/gpu_ops/stop_generation_multi_ends.cu @@ -74,11 +74,16 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids, } } +#ifdef PADDLE_WITH_CUSTOM_DEVICE + auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place())); + auto cu_stream = dev_ctx->stream(); +#else auto cu_stream = topk_ids.stream(); +#endif std::vector shape = topk_ids.shape(); int64_t bs_now = shape[0]; int64_t end_length = end_ids.shape()[0]; - int block_size = (bs_now + 32 - 1) / 32 * 32; + int block_size = (bs_now + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; set_value_by_flags<<<1, block_size, 0, cu_stream>>>( const_cast(stop_flags.data()), const_cast(topk_ids.data()), diff --git a/custom_ops/gpu_ops/stop_generation_multi_stop_seqs.cu b/custom_ops/gpu_ops/stop_generation_multi_stop_seqs.cu index a053a939d..c2a14c2cc 100644 --- a/custom_ops/gpu_ops/stop_generation_multi_stop_seqs.cu +++ b/custom_ops/gpu_ops/stop_generation_multi_stop_seqs.cu @@ -21,6 +21,7 @@ #include #include #include "paddle/extension.h" +#include "helper.h" #ifndef PD_BUILD_STATIC_OP #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) @@ -88,7 +89,12 @@ void GetStopFlagsMultiSeqs(const paddle::Tensor &topk_ids, PD_CHECK(topk_ids.dtype() == 
paddle::DataType::INT64); PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL); +#ifdef PADDLE_WITH_CUSTOM_DEVICE + auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place())); + auto cu_stream = dev_ctx->stream(); +#else auto cu_stream = topk_ids.stream(); +#endif std::vector shape = topk_ids.shape(); std::vector stop_seqs_shape = stop_seqs.shape(); int bs_now = shape[0]; @@ -96,7 +102,7 @@ void GetStopFlagsMultiSeqs(const paddle::Tensor &topk_ids, int stop_seqs_max_len = stop_seqs_shape[1]; int pre_ids_len = pre_ids.shape()[1]; - int block_size = (stop_seqs_bs + 31) / 32 * 32; + int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; set_value_by_stop_seqs<<>>( const_cast(stop_flags.data()), const_cast(topk_ids.data()), diff --git a/custom_ops/gpu_ops/token_penalty_multi_scores.cu b/custom_ops/gpu_ops/token_penalty_multi_scores.cu index c15289e0c..a930791e7 100644 --- a/custom_ops/gpu_ops/token_penalty_multi_scores.cu +++ b/custom_ops/gpu_ops/token_penalty_multi_scores.cu @@ -132,7 +132,12 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; +#ifdef PADDLE_WITH_CUSTOM_DEVICE + auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(logits.place())); + auto cu_stream = dev_ctx->stream(); +#else auto cu_stream = logits.stream(); +#endif std::vector shape = logits.shape(); auto repeat_times = paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place()); @@ -143,7 +148,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, int64_t end_length = eos_token_id.shape()[0]; - int block_size = (bs + 32 - 1) / 32 * 32; + int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; min_length_logits_process<<<1, block_size, 0, cu_stream>>>( reinterpret_cast( const_cast(logits.data())), @@ -154,8 +159,12 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, length, end_length); - block_size = (length_id + 32 - 1) / 32 * 32; + block_size = (length_id + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; +#ifdef PADDLE_WITH_COREX + block_size = std::min(block_size, 512); +#else block_size = min(block_size, 512); +#endif update_repeat_times<<>>( pre_ids.data(), cur_len.data(), @@ -164,8 +173,12 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, length, length_id); - block_size = (length + 32 - 1) / 32 * 32; + block_size = (length + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; +#ifdef PADDLE_WITH_COREX + block_size = std::min(block_size, 512); +#else block_size = min(block_size, 512); +#endif update_value_by_repeat_times<<>>( repeat_times.data(), reinterpret_cast( @@ -180,8 +193,12 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, bs, length); - block_size = (length_bad_words + 32 - 1) / 32 * 32; + block_size = (length_bad_words + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; +#ifdef PADDLE_WITH_COREX + block_size = std::min(block_size, 512); +#else block_size = min(block_size, 512); +#endif ban_bad_words<<>>( reinterpret_cast( const_cast(logits.data())), diff --git a/custom_ops/gpu_ops/update_inputs.cu b/custom_ops/gpu_ops/update_inputs.cu index 78f39e353..c58aeb39c 100644 --- a/custom_ops/gpu_ops/update_inputs.cu +++ b/custom_ops/gpu_ops/update_inputs.cu @@ -75,11 +75,17 @@ void UpdateInputes(const paddle::Tensor &stop_flags, const paddle::Tensor &stop_nums, const paddle::Tensor &next_tokens, const paddle::Tensor 
&is_block_step) { +#ifdef PADDLE_WITH_CUSTOM_DEVICE + auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place())); + auto cu_stream = dev_ctx->stream(); +#else + auto cu_stream = input_ids.stream(); +#endif const int max_bsz = stop_flags.shape()[0]; const int now_bsz = seq_lens_this_time.shape()[0]; const int input_ids_stride = input_ids.shape()[1]; auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false); - update_inputs_kernel<1024><<<1, 1024, 0, input_ids.stream()>>>( + update_inputs_kernel<1024><<<1, 1024, 0, cu_stream>>>( const_cast(not_need_stop_gpu.data()), const_cast(seq_lens_this_time.data()), const_cast(seq_lens_encoder.data()), diff --git a/custom_ops/iluvatar_ops/fused_moe_helper.h b/custom_ops/iluvatar_ops/fused_moe_helper.h new file mode 100644 index 000000000..4a9ce04db --- /dev/null +++ b/custom_ops/iluvatar_ops/fused_moe_helper.h @@ -0,0 +1,55 @@ + +/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "fused_moe_op.h" + +namespace phi { + +template +__global__ void moe_token_type_ids_kernel(T *gating_output, + const int *moe_token_type_ids_out, + const int num_rows, + const int num_experts, + const int k) { + const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x; + + if (moe_token_index >= num_rows) { + return; + } + + gating_output[moe_token_index * 2] = + gating_output[moe_token_index * 2] + + (moe_token_type_ids_out[moe_token_index]) * -1e10; + gating_output[moe_token_index * 2 + 1] = + gating_output[moe_token_index * 2 + 1] + + (1 - moe_token_type_ids_out[moe_token_index]) * -1e10; +} + +template +void moe_token_type_ids_kernelLauncher(T *gating_output, + const int *moe_token_type_ids_out, + const int num_rows, + const int num_experts, + const int k, + cudaStream_t stream) { + const int blocks = num_rows * k / 512 + 1; + const int threads = 512; + moe_token_type_ids_kernel<<>>( + gating_output, moe_token_type_ids_out, num_rows, num_experts, k); +} + +} // namespace phi diff --git a/custom_ops/iluvatar_ops/fused_moe_imp_op.h b/custom_ops/iluvatar_ops/fused_moe_imp_op.h new file mode 100644 index 000000000..254f80e67 --- /dev/null +++ b/custom_ops/iluvatar_ops/fused_moe_imp_op.h @@ -0,0 +1,127 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include +#include +#include "cub/cub.cuh" + +namespace phi { + +static const float HALF_FLT_MAX = 65504.F; +static const float HALF_FLT_MIN = -65504.F; +static inline size_t AlignTo16(const size_t& input) { + static constexpr int ALIGNMENT = 16; + return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +} + +class CubKeyValueSorter { + public: + CubKeyValueSorter() : num_experts_(0), num_bits_(sizeof(int) * 8) {} + + explicit CubKeyValueSorter(const int num_experts) + : num_experts_(num_experts), + num_bits_(static_cast(log2(num_experts)) + 1) {} + + void update_num_experts(const int num_experts) { + num_experts_ = num_experts; + num_bits_ = static_cast(log2(num_experts)) + 1; + } + + size_t getWorkspaceSize(const size_t num_key_value_pairs, + bool descending = false) { + num_key_value_pairs_ = num_key_value_pairs; + size_t required_storage = 0; + int* null_int = nullptr; + if (descending) { + cub::DeviceRadixSort::SortPairsDescending(NULL, + required_storage, + null_int, + null_int, + null_int, + null_int, + num_key_value_pairs, + 0, + 32); + } else { + cub::DeviceRadixSort::SortPairs(NULL, + required_storage, + null_int, + null_int, + null_int, + null_int, + num_key_value_pairs, + 0, + num_bits_); + } + return required_storage; + } + + template + void run(void* workspace, + const size_t workspace_size, + const KeyT* keys_in, + KeyT* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + bool descending, + cudaStream_t stream) { + size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs); + size_t actual_ws_size = workspace_size; + + if (expected_ws_size > workspace_size) { + std::stringstream err_ss; + err_ss << "[Error][CubKeyValueSorter::run]\n"; + err_ss << "Error. The allocated workspace is too small to run this " + "problem.\n"; + err_ss << "Expected workspace size of at least " << expected_ws_size + << " but got problem size " << workspace_size << "\n"; + throw std::runtime_error(err_ss.str()); + } + if (descending) { + cub::DeviceRadixSort::SortPairsDescending(workspace, + actual_ws_size, + keys_in, + keys_out, + values_in, + values_out, + num_key_value_pairs, + 0, + 32, + stream); + } else { + cub::DeviceRadixSort::SortPairs(workspace, + actual_ws_size, + keys_in, + keys_out, + values_in, + values_out, + num_key_value_pairs, + 0, + num_bits_, + stream); + } + } + + private: + size_t num_key_value_pairs_; + int num_experts_; + int num_bits_; +}; + +} // namespace phi diff --git a/custom_ops/iluvatar_ops/fused_moe_op.h b/custom_ops/iluvatar_ops/fused_moe_op.h new file mode 100644 index 000000000..91bd589f7 --- /dev/null +++ b/custom_ops/iluvatar_ops/fused_moe_op.h @@ -0,0 +1,990 @@ +// /* +// * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & +// * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +// * +// * Licensed under the Apache License, Version 2.0 (the "License"); +// * you may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. 
+// */ + +#pragma once + +#include +#include +#include "fused_moe_imp_op.h" +#include "fused_moe_helper.h" +// Ignore CUTLASS warnings about type punning +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wunused-function" + +// #include "paddle/phi/backends/gpu/gpu_info.h" +#pragma GCC diagnostic pop + +#include "helper.h" + +namespace phi { + +struct GpuLaunchConfig { + dim3 block_per_grid; + dim3 thread_per_block; +}; + +inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) { + int blocks_x = cols; + int blocks_y = 1; + int blocks_z = 1; + if (blocks_x > 1024) { + blocks_y = 256; + blocks_x = (blocks_x + blocks_y - 1) / blocks_y; + } + + GpuLaunchConfig config; + config.block_per_grid.x = blocks_x; + config.block_per_grid.y = blocks_y; + config.block_per_grid.z = blocks_z; + return config; +} + +// ====================== Softmax things =============================== +// We have our own implementation of softmax here so we can support transposing +// the output in the softmax kernel when we extend this module to support +// expert-choice routing. +template +__launch_bounds__(TPB) __global__ + void group_moe_softmax(const T* input, + T* output, + T* softmax_max_prob, + const int64_t num_cols, + const int64_t softmax_num_rows) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + __shared__ float max_out; + + int globalIdx = blockIdx.x + blockIdx.y * gridDim.x; + if (globalIdx >= softmax_num_rows) { + return; + } + const int64_t thread_row_offset = globalIdx * num_cols; + + cub::Sum sum; + float threadData(-FLT_MAX); + + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData = max(static_cast(input[idx]), threadData); + } + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) { + float_max = maxElem; + } + __syncthreads(); + + threadData = 0; + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData += exp((static_cast(input[idx]) - float_max)); + } + + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + + if (threadIdx.x == 0) { + normalizing_factor = 1.f / Z; + } + __syncthreads(); + + threadData = 0; + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + const float val = + exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + output[idx] = T(val); + threadData = max(static_cast(T(val)), threadData); + } + + const float maxOut = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) { + // group max probs + max_out = 1.f / maxOut; + softmax_max_prob[globalIdx] = T(max_out); + } + __syncthreads(); + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + // group softmax normalization + output[idx] = output[idx] * static_cast(max_out); + } +} + +template +__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, + T* output, + IdxT* indices, + int* source_rows, + T* softmax_max_prob, + const int64_t num_experts, + const int64_t k, + const int64_t num_rows) { + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + const int block_row = 
blockIdx.x + blockIdx.y * gridDim.x; + if (block_row >= num_rows) { + return; + } + + const bool should_process_row = true; + const int thread_read_offset = block_row * num_experts; + + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const IdxT prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = + BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int idx = k * block_row + k_idx; + // restore normalized probes + output[idx] = result_kvp.value / T(softmax_max_prob[idx]); + indices[idx] = should_process_row ? result_kvp.key : num_experts; + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + } +} + +template +__launch_bounds__(TPB) __global__ void moe_softmax(const T* input, + T* output, + const int64_t num_cols, + const int64_t num_rows) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + + int globalIdx = blockIdx.x + blockIdx.y * gridDim.x; + if (globalIdx >= num_rows) { + return; + } + const int64_t thread_row_offset = globalIdx * num_cols; + + cub::Sum sum; + float threadData(-FLT_MAX); + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData = max(static_cast(input[idx]), threadData); + } + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) { + float_max = maxElem; + } + __syncthreads(); + + threadData = 0; + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData += exp((static_cast(input[idx]) - float_max)); + } + + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + + if (threadIdx.x == 0) { + normalizing_factor = 1.f / Z; + } + __syncthreads(); + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + const float val = + exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + output[idx] = T(val); + } +} + +template +__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, + const T* bias, + T* output, + IdxT* indices, + int* source_rows, + const int64_t num_experts, + const int64_t k, + const int64_t num_rows) { + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + const int block_row = blockIdx.x + blockIdx.y * gridDim.x; + if (block_row >= num_rows) { + return; + } + + const bool should_process_row = true; + const int thread_read_offset = block_row * num_experts; + + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = bias ? 
inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const IdxT prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = + BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int idx = k * block_row + k_idx; + output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value; + indices[idx] = should_process_row ? result_kvp.key : num_experts; + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + } +} + +template +__launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input, + const T* bias, + T* output, + IdxT* indices, + int* source_rows, + const int64_t num_experts, + const int64_t k, + const int64_t num_rows) { + // softmax + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + + int globalIdx = blockIdx.x + blockIdx.y * gridDim.x; + if (globalIdx >= num_rows) { + return; + } + const int64_t thread_row_offset = globalIdx * num_experts; + const int64_t idx = thread_row_offset+threadIdx.x; + + cub::Sum sum; + + float threadData = (threadIdx.x < num_experts) ? static_cast(input[idx]) :(-FLT_MAX); + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) { + float_max = maxElem; + } + __syncthreads(); + + float threadDataSub = threadData - float_max; + float threadDataExp = exp(threadDataSub); + + const auto Z = BlockReduce(tmpStorage).Reduce(threadDataExp, sum); + + if (threadIdx.x == 0) { + normalizing_factor = 1.f / Z; + } + __syncthreads(); + + T val = T(threadDataExp * normalizing_factor); + + // top_k + using cub_kvp = cub::KeyValuePair; + using BlockReduceP = cub::BlockReduce; + __shared__ typename BlockReduceP::TempStorage tmpStorageP; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities + + if (threadIdx.x < num_experts) { + cub_kvp inp_kvp; + int expert = threadIdx.x; + inp_kvp.key = expert; + inp_kvp.value = bias ? val + bias[expert] : val; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const IdxT prior_winning_expert = indices[k * globalIdx + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = + BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int cur_idx = k * globalIdx + k_idx; + output[cur_idx] = bias ? 
(result_kvp.value - bias[result_kvp.key]) : result_kvp.value; + indices[cur_idx] = result_kvp.key; + source_rows[cur_idx] = k_idx * num_rows + globalIdx; + } + __syncthreads(); + } +} + +template +__launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_softmax, + const T* bias, + T* output, + IdxT* indices, + int* source_rows, + const int64_t num_experts, + const int64_t k, + const int64_t num_rows) { + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + const int block_row = blockIdx.x + blockIdx.y * gridDim.x; + if (block_row >= num_rows) { + return; + } + + const bool should_process_row = true; + const int thread_read_offset = block_row * num_experts; + T weight_sum = static_cast(0); + + extern __shared__ char smem[]; + + T* row_outputs = reinterpret_cast(smem); + + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = + BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int idx = k * block_row + k_idx; + // output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value; + indices[idx] = should_process_row ? result_kvp.key : num_experts; + source_rows[idx] = k_idx * num_rows + block_row; + + T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value; + row_outputs[k_idx] = row_out; + weight_sum += row_out; + } + __syncthreads(); + } + if (threadIdx.x < WARP_SIZE) { + weight_sum = __shfl_sync(0xffffffff, weight_sum, 0); + } + + if (threadIdx.x < k) { + output[k * block_row + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum; + } +} + + +template +__launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* input, + const T* bias, + T* output, + IdxT* indices, + int* source_rows, + const int64_t num_experts, + const int64_t k, + const int64_t num_rows) { + // softmax + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + + int globalIdx = blockIdx.x + blockIdx.y * gridDim.x; + if (globalIdx >= num_rows) { + return; + } + const int64_t thread_row_offset = globalIdx * num_experts; + const int64_t idx = thread_row_offset+threadIdx.x; + + cub::Sum sum; + + float threadData = (threadIdx.x < num_experts) ? 
static_cast(input[idx]) :(-FLT_MAX); + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) { + float_max = maxElem; + } + __syncthreads(); + + float threadDataSub = threadData - float_max; + float threadDataExp = exp(threadDataSub); + + const auto Z = BlockReduce(tmpStorage).Reduce(threadDataExp, sum); + + if (threadIdx.x == 0) { + normalizing_factor = 1.f / Z; + } + + __syncthreads(); + + T val = T(threadDataExp * normalizing_factor); + + // top_k + using cub_kvp = cub::KeyValuePair; + using BlockReduceP = cub::BlockReduce; + __shared__ typename BlockReduceP::TempStorage tmpStorageP; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + T weight_sum = static_cast(0); + extern __shared__ char smem[]; + T* row_outputs = reinterpret_cast(smem); + + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities + + if (threadIdx.x < num_experts) { + cub_kvp inp_kvp; + int expert = threadIdx.x; + inp_kvp.key = expert; + inp_kvp.value = bias ? val + bias[expert] : val; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const IdxT prior_winning_expert = indices[k * globalIdx + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = + BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int cur_idx = k * globalIdx + k_idx; + + T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value; + row_outputs[k_idx] = row_out; + weight_sum += row_out; + + indices[cur_idx] = result_kvp.key; + source_rows[cur_idx] = k_idx * num_rows + globalIdx; + } + __syncthreads(); + } + + if (threadIdx.x < WARP_SIZE) { + weight_sum = __shfl_sync(0xffffffff, weight_sum, 0); + } + + if (threadIdx.x < k) { + output[k * globalIdx + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum; + } +} + +namespace detail { +// Constructs some constants needed to partition the work across threads at +// compile time. 
+template +struct TopkConstants { + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || + EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, + ""); + static constexpr int VECs_PER_THREAD = + std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +}; +} // namespace detail + +template +void topk_gating_softmax_kernelLauncher(const T* input, + const T* gating_correction_bias, + T* output, + T* softmax, + IdxT* indices, + int* source_row, + T* softmax_max_prob, + const int64_t num_rows, + const int64_t num_experts, + const int64_t k, + const bool group_moe, + cudaStream_t stream, + const bool topk_only_mode = false) { + if (topk_only_mode) { + static constexpr int TPB = 256; + const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows); + moe_top_k<<>>( + input, gating_correction_bias, output, indices, source_row, num_experts, k, num_rows); + return; + } + static constexpr int WARPS_PER_TB = 4; + + #define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \ + case N: { \ + topk_gating_softmax_launcher_helper( \ + input, output, indices, source_row, num_rows, num_experts, k, stream); \ + break; \ + } + int64_t tem_num_experts = num_experts; + if(gating_correction_bias != nullptr) tem_num_experts = 0; + switch (tem_num_experts) { + //LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2) + //LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4) + //LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8) + //LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16) + //LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32) + //LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64) + //LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128) + //LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256) + + default: { + static constexpr int TPB = 256; + if (group_moe) { + const int group_experts = num_experts / k; + const int softmax_num_rows = num_rows * k; + const auto config_softmax = Get1DBlocksAnd2DGridsMoe(softmax_num_rows); + group_moe_softmax + <<>>( + input, + softmax, + softmax_max_prob, + group_experts, + softmax_num_rows); + const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows); + moe_top_k + <<>>(softmax, + output, + indices, + source_row, + softmax_max_prob, + num_experts, + k, + num_rows); + } else { + const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows); + moe_softmax<<>>( + input, softmax, num_experts, num_rows); + moe_top_k + <<>>(softmax, + gating_correction_bias, + output, + indices, + source_row, + num_experts, + k, + num_rows); + } + } + } +} + +// ========================== Permutation things +// ======================================= + +// Duplicated and permutes rows for MoE. In addition, reverse the permutation +// map to help with finalizing routing. + +// "expanded_x_row" simply means that the number of values is num_rows x k. It +// is "expanded" since we will have to duplicate some rows in the input matrix +// to match the dimensions. Duplicates will always get routed to separate +// experts in the end. + +// Note that the expanded_dest_row_to_expanded_source_row map referred to here +// has indices in the range (0, k*rows_in_input - 1). However, it is set up so +// that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input all map +// to row 0 in the original matrix. Thus, to know where to read in the source +// matrix, we simply take the modulus of the expanded index. 
+ +template +__global__ void initialize_moe_routing_kernel( + const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int64_t num_rows, + const int64_t active_rows, + const int64_t cols, + const int64_t num_rows_k) { + using LoadT = AlignedVector; + LoadT src_vec; + + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way + // reduction and unpermuting. I need the reverse map for that reduction to + // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + const int expanded_dest_row = blockIdx.x + blockIdx.y * gridDim.x; + if (expanded_dest_row >= num_rows_k) return; + const int expanded_source_row = + expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + if (threadIdx.x == 0) { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = + expanded_dest_row; + } + + if ((blockIdx.x + blockIdx.y * gridDim.x) < active_rows) { + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr = permuted_output + expanded_dest_row * cols; + + for (int tid = threadIdx.x * VecSize; tid < cols; + tid += blockDim.x * VecSize) { + // dest_row_ptr[tid] = source_row_ptr[tid]; + Load(&source_row_ptr[tid], &src_vec); + Store(src_vec, &dest_row_ptr[tid]); + } + } +} + +template +void initialize_moe_routing_kernelLauncher( + const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int64_t num_rows, + const int64_t active_rows, + const int64_t cols, + const int64_t k, + cudaStream_t stream) { + const int threads = std::min(cols, int64_t(1024)); + constexpr int max_pack_size = 16 / sizeof(T); + const auto config_initialize = Get1DBlocksAnd2DGridsMoe(num_rows * k); + if (cols % max_pack_size == 0) { + initialize_moe_routing_kernel + <<>>( + unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + num_rows, + k * active_rows, + cols, + num_rows * k); + } else { + initialize_moe_routing_kernel + <<>>( + unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + num_rows, + k * active_rows, + cols, + num_rows * k); + } +} + +// ============================== Infer GEMM sizes +// ================================= +__device__ inline int find_total_elts_leq_target(int* sorted_indices, + const int64_t arr_length, + const int64_t target) { + int64_t low = 0, high = arr_length - 1, target_location = -1; + while (low <= high) { + int64_t mid = (low + high) / 2; + + if (sorted_indices[mid] > target) { + high = mid - 1; + } else { + low = mid + 1; + target_location = mid; + } + } + return target_location + 1; +} + +// Final kernel to unpermute and scale +// This kernel unpermutes the original data, does the k-way reduction and +// performs the final skip connection. 
+template +__global__ void finalize_moe_routing_kernel( + const T* expanded_permuted_rows, + T* reduced_unpermuted_output, + const T* bias, + const float* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, + const int64_t cols, + const int64_t k, + const int64_t compute_bias, + const bool norm_topk_prob, + const float routed_scaling_factor, + const int64_t num_rows) { + const int original_row = blockIdx.x + blockIdx.y * gridDim.x; + // const int original_row = blockIdx.x; + // const int num_rows = gridDim.x; + if (original_row >= num_rows) return; + T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols; + + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + T thread_output{0.f}; + float row_rescale{0.f}; + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int expanded_original_row = original_row + k_idx * num_rows; + const int expanded_permuted_row = + expanded_source_row_to_expanded_dest_row[expanded_original_row]; + + const int64_t k_offset = original_row * k + k_idx; + const float row_scale = scales[k_offset]; + row_rescale = row_rescale + row_scale; + + const T* expanded_permuted_rows_row_ptr = + expanded_permuted_rows + expanded_permuted_row * cols; + + const int expert_idx = expert_for_source_row[k_offset]; + const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; + const T bias_value = bias_ptr ? bias_ptr[tid] : T{0.f}; + + thread_output = + static_cast(thread_output) + + row_scale * static_cast( + expanded_permuted_rows_row_ptr[tid] + + bias_value * + static_cast(static_cast(compute_bias))); + } + + thread_output = static_cast(thread_output) / + (norm_topk_prob ? row_rescale : 1.0f) * + routed_scaling_factor; + reduced_row_ptr[tid] = thread_output; + } +} + +template +void finalize_moe_routing_kernelLauncher( + const T* expanded_permuted_rows, + T* reduced_unpermuted_output, + const T* bias, + const float* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, + const int64_t num_rows, + const int64_t cols, + const int64_t k, + const int64_t compute_bias, + const bool norm_topk_prob, + const float routed_scaling_factor, + cudaStream_t stream) { + const int threads = std::min(cols, int64_t(1024)); + const auto config_final = Get1DBlocksAnd2DGridsMoe(num_rows); + + finalize_moe_routing_kernel + <<>>( + expanded_permuted_rows, + reduced_unpermuted_output, + bias, + scales, + expanded_source_row_to_expanded_dest_row, + expert_for_source_row, + cols, + k, + compute_bias, + norm_topk_prob, + routed_scaling_factor, + num_rows); +} + +// ========================= TopK Softmax specializations +// =========================== +template void topk_gating_softmax_kernelLauncher(const float*, + const float*, + float*, + float*, + int*, + int*, + float*, + const int64_t, + const int64_t, + const int64_t, + const bool, + cudaStream_t, + const bool); +template void topk_gating_softmax_kernelLauncher(const half*, + const half*, + half*, + half*, + int*, + int*, + half*, + const int64_t, + const int64_t, + const int64_t, + const bool, + cudaStream_t, + const bool); +#ifdef PADDLE_CUDA_BF16 +template void topk_gating_softmax_kernelLauncher(const __nv_bfloat16*, + const __nv_bfloat16*, + __nv_bfloat16*, + __nv_bfloat16*, + int*, + int*, + __nv_bfloat16*, + const int64_t, + const int64_t, + const int64_t, + const bool, + cudaStream_t, + const bool); +#endif +// ===================== Specializations for init routing +// ========================= +template void 
initialize_moe_routing_kernelLauncher(const float*, + float*, + const int*, + int*, + const int64_t, + const int64_t, + const int64_t, + const int64_t, + cudaStream_t); +template void initialize_moe_routing_kernelLauncher(const half*, + half*, + const int*, + int*, + const int64_t, + const int64_t, + const int64_t, + const int64_t, + cudaStream_t); +#ifdef PADDLE_CUDA_BF16 +template void initialize_moe_routing_kernelLauncher(const __nv_bfloat16*, + __nv_bfloat16*, + const int*, + int*, + const int64_t, + const int64_t, + const int64_t, + const int64_t, + cudaStream_t); +#endif +// ==================== Specializations for final routing +// =================================== +template void finalize_moe_routing_kernelLauncher(const float*, + float*, + const float*, + const float*, + const int*, + const int*, + const int64_t, + const int64_t, + const int64_t, + const int64_t, + const bool, + const float, + cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const half*, + half*, + const half*, + const float*, + const int*, + const int*, + const int64_t, + const int64_t, + const int64_t, + const int64_t, + const bool, + const float, + cudaStream_t); +#ifdef PADDLE_CUDA_BF16 +template void finalize_moe_routing_kernelLauncher(const __nv_bfloat16*, + __nv_bfloat16*, + const __nv_bfloat16*, + const float*, + const int*, + const int*, + const int64_t, + const int64_t, + const int64_t, + const int64_t, + const bool, + const float, + cudaStream_t); +#endif + +} // namespace phi diff --git a/custom_ops/iluvatar_ops/moe_dispatch.cu b/custom_ops/iluvatar_ops/moe_dispatch.cu new file mode 100644 index 000000000..a6195f44e --- /dev/null +++ b/custom_ops/iluvatar_ops/moe_dispatch.cu @@ -0,0 +1,311 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ + +// Ignore CUTLASS warnings about type punning +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wunused-function" +#pragma once + +#include "fused_moe_helper.h" +#include "fused_moe_op.h" +#pragma GCC diagnostic pop +#include "helper.h" + +__global__ void compute_total_rows_before_expert_kernel( + int* sorted_experts, + const int64_t sorted_experts_len, + const int64_t num_experts, + int64_t* total_rows_before_expert) { + const int expert = blockIdx.x * blockDim.x + threadIdx.x; + if (expert >= num_experts) return; + total_rows_before_expert[expert] = + phi::find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); +} + +void compute_total_rows_before_expert(int* sorted_indices, + const int64_t total_indices, + const int64_t num_experts, + int64_t* total_rows_before_expert, + cudaStream_t stream) { + const int threads = std::min(int64_t(1024), num_experts); + const int blocks = (num_experts + threads - 1) / threads; + + compute_total_rows_before_expert_kernel<<>>( + sorted_indices, total_indices, num_experts, total_rows_before_expert); +} + +template +void MoeDispatchKernel(const paddle::Tensor& input, + const paddle::Tensor& gating_output, + const paddle::optional& gating_correction_bias, + const int moe_topk, + const bool group_moe, + const bool topk_only_mode, + const int num_rows, + const int hidden_size, + const int expert_num, + paddle::Tensor* permute_input, + paddle::Tensor* tokens_expert_prefix_sum, + paddle::Tensor* permute_indices_per_token, + paddle::Tensor* top_k_weight, + paddle::Tensor* top_k_indices) { + using namespace phi; + + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + auto place = input.place(); + auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(input.place())); + auto stream = static_cast(dev_ctx->stream()); + if (group_moe) { + // Check if expert_num is divisible by moe_topk, else throw an error + PADDLE_ENFORCE_EQ(expert_num % moe_topk, + 0, + common::errors::InvalidArgument( + "The number of experts (expert_num) " + "must be divisible by moe_topk. " + "Got expert_num = %d and moe_topk = %d.", + expert_num, + moe_topk)); + } + + const int num_moe_inputs = AlignTo16(num_rows * moe_topk); + const int bytes = num_moe_inputs * sizeof(int); + + CubKeyValueSorter sorter_; + sorter_.update_num_experts(expert_num); + + const int sorter_ws_size_bytes = + AlignTo16(sorter_.getWorkspaceSize(moe_topk * num_rows)); + const int sort_tmp_in_out_size = num_moe_inputs * 2 * sizeof(int); + + paddle::Tensor ws_ptr_tensor = + GetEmptyTensor({bytes + sorter_ws_size_bytes + sort_tmp_in_out_size}, + paddle::DataType::INT8, + place); + + int8_t* ws_ptr = ws_ptr_tensor.data(); + int* source_rows_ = reinterpret_cast(ws_ptr); + int8_t* sorter_ws_ptr = reinterpret_cast(ws_ptr + bytes); + int* permuted_experts_ = + reinterpret_cast(sorter_ws_ptr + sorter_ws_size_bytes); + int* permuted_rows_ = permuted_experts_ + num_moe_inputs; + + int* expert_for_source_row = top_k_indices->data(); + + float* softmax_max_prob = nullptr; + if (group_moe) { + paddle::Tensor softmax_max_prob_tensor = + GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place); + // (TODO: check fill sucess ?) 
+ paddle::experimental::fill(softmax_max_prob_tensor, 0.f); + softmax_max_prob = softmax_max_prob_tensor.data(); + } + + float* softmax_out_; + + const bool is_pow_2 = + (expert_num != 0) && ((expert_num & (expert_num - 1)) == 0); + + paddle::Tensor softmax_buffer; + + if (!is_pow_2 || expert_num > 256 || group_moe || gating_correction_bias) { + softmax_buffer = GetEmptyTensor( + {num_rows * expert_num}, paddle::DataType::FLOAT32, place); + softmax_out_ = softmax_buffer.data(); + } else { + softmax_out_ = nullptr; + } + + topk_gating_softmax_kernelLauncher(gating_output.data(), + gating_correction_bias ? gating_correction_bias.get().data() : nullptr, + top_k_weight->data(), + softmax_out_, + expert_for_source_row, + source_rows_, + softmax_max_prob, + num_rows, + expert_num, + moe_topk, + group_moe, + stream, + topk_only_mode); + + sorter_.run(reinterpret_cast(sorter_ws_ptr), + sorter_ws_size_bytes, + expert_for_source_row, + permuted_experts_, + source_rows_, + permuted_rows_, + moe_topk * num_rows, + false, + stream); + + + initialize_moe_routing_kernelLauncher( + input.data(), + permute_input->data(), + permuted_rows_, + permute_indices_per_token->data(), + num_rows, + num_rows, + hidden_size, + moe_topk, + stream); + + + compute_total_rows_before_expert( + permuted_experts_, + moe_topk * num_rows, + expert_num, + tokens_expert_prefix_sum->data(), + stream); +} + + +std::vector MoeExpertDispatch( + const paddle::Tensor& input, + const paddle::Tensor& gating_output, + const paddle::optional& gating_correction_bias, + const paddle::optional& w4a8_in_scale, + const int moe_topk, + const bool group_moe, + const bool topk_only_mode) { + const auto input_type = input.dtype(); + auto place = input.place(); + int token_rows = 0; + auto input_dims = input.dims(); + auto gating_dims = gating_output.dims(); + const int expert_num = gating_dims[gating_dims.size() - 1]; + + if (input_dims.size() == 3) { + token_rows = input_dims[0] * input_dims[1]; + } else { + token_rows = input_dims[0]; + } + const int num_rows = token_rows; + const int hidden_size = input.dims()[input_dims.size() - 1]; + + auto permute_input = + GetEmptyTensor({moe_topk * num_rows, hidden_size}, input_type, place); + // correspond to the weighted coefficients of the results from each expert. 
+ auto top_k_weight = + GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place); + auto top_k_indices = + GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::INT32, place); + + auto tokens_expert_prefix_sum = + GetEmptyTensor({expert_num}, paddle::DataType::INT64, place); + auto permute_indices_per_token = + GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place); + + + switch (input_type) { + case paddle::DataType::BFLOAT16: + MoeDispatchKernel(input, + gating_output, + gating_correction_bias, + moe_topk, + group_moe, + topk_only_mode, + num_rows, + hidden_size, + expert_num, + &permute_input, + &tokens_expert_prefix_sum, + &permute_indices_per_token, + &top_k_weight, + &top_k_indices); + break; + case paddle::DataType::FLOAT16: + MoeDispatchKernel(input, + gating_output, + gating_correction_bias, + moe_topk, + group_moe, + topk_only_mode, + num_rows, + hidden_size, + expert_num, + &permute_input, + &tokens_expert_prefix_sum, + &permute_indices_per_token, + &top_k_weight, + &top_k_indices); + break; + default: + PD_THROW("Unsupported data type for MoeDispatchKernel"); + } + return {permute_input, + tokens_expert_prefix_sum, + permute_indices_per_token, + top_k_weight, + top_k_indices, + top_k_indices}; +} + + +std::vector> MoeExpertDispatchInferShape( + const std::vector& input_shape, + const std::vector& gating_output_shape, + const paddle::optional>& bias_shape, + const int moe_topk) { + int token_rows = -1; + + if (input_shape.size() == 3) { + token_rows = input_shape[0] * input_shape[1]; + } else { + token_rows = input_shape[0]; + } + const int expert_num = gating_output_shape[gating_output_shape.size() - 1]; + const int num_rows = token_rows; + const int hidden_size = input_shape[input_shape.size() - 1]; + + return {{moe_topk * num_rows, hidden_size}, + {expert_num}, + {moe_topk, num_rows}, + {num_rows, moe_topk}, + {num_rows, moe_topk}, + {num_rows, moe_topk}}; +} + +std::vector MoeExpertDispatchInferDtype( + const paddle::DataType& input_dtype, + const paddle::DataType& gating_output_dtype, + const paddle::optional& bias_type, + const int moe_topk) { + return {input_dtype, + paddle::DataType::INT64, + paddle::DataType::INT32, + paddle::DataType::FLOAT32, + paddle::DataType::INT32, + paddle::DataType::INT32}; +} + + +PD_BUILD_STATIC_OP(moe_expert_dispatch) + .Inputs({"input", "gating_output", paddle::Optional("gating_correction_bias"), + paddle::Optional("w4a8_in_scale")}) + .Outputs({"permute_input", + "tokens_expert_prefix_sum", + "permute_indices_per_token", + "top_k_weight", + "top_k_indices", + "expert_idx_per_token"}) + .Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"}) + .SetKernelFn(PD_KERNEL(MoeExpertDispatch)) + .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype)); diff --git a/custom_ops/iluvatar_ops/moe_reduce.cu b/custom_ops/iluvatar_ops/moe_reduce.cu new file mode 100644 index 000000000..dda0ce44b --- /dev/null +++ b/custom_ops/iluvatar_ops/moe_reduce.cu @@ -0,0 +1,155 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Ignore CUTLASS warnings about type punning + +#pragma once + +#include "helper.h" +#include "fused_moe_helper.h" +#include "fused_moe_op.h" + +template +void MoeReduceKernel(const paddle::Tensor& ffn_out, + const paddle::Tensor& top_k_weight, + const paddle::Tensor& permute_indices_per_token, + const paddle::Tensor& top_k_indices, + const paddle::optional& ffn2_bias, + const bool norm_topk_prob, + const float routed_scaling_factor, + const int num_rows, + const int hidden_size, + const int topk, + paddle::Tensor* output) { + using namespace phi; + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(ffn_out.place())); + auto stream = static_cast(dev_ctx->stream()); + + finalize_moe_routing_kernelLauncher( + ffn_out.data(), + output->data(), + ffn2_bias ? ffn2_bias->data() : nullptr, + top_k_weight.data(), + permute_indices_per_token.data(), + top_k_indices.data(), + num_rows, + hidden_size, + topk, + static_cast(1), + norm_topk_prob, + routed_scaling_factor, + stream); +} + +paddle::Tensor MoeExpertReduceFunc( + const paddle::Tensor& ffn_out, + const paddle::Tensor& top_k_weight, + const paddle::Tensor& permute_indices_per_token, + const paddle::Tensor& top_k_indices, + const paddle::optional& ffn2_bias, + const bool norm_topk_prob, + const float routed_scaling_factor) { + const auto input_type = ffn_out.dtype(); + auto place = ffn_out.place(); + + const int topk = top_k_indices.dims()[1]; + const int num_rows = ffn_out.dims()[0] / topk; + const int hidden_size = ffn_out.dims()[1]; + + auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place); + + switch (input_type) { + case paddle::DataType::BFLOAT16: + MoeReduceKernel( + ffn_out, + top_k_weight, + permute_indices_per_token, + top_k_indices, + ffn2_bias, + norm_topk_prob, + routed_scaling_factor, + num_rows, + hidden_size, + topk, + &output); + break; + case paddle::DataType::FLOAT16: + MoeReduceKernel( + ffn_out, + top_k_weight, + permute_indices_per_token, + top_k_indices, + ffn2_bias, + norm_topk_prob, + routed_scaling_factor, + num_rows, + hidden_size, + topk, + &output); + break; + default: + PD_THROW("Unsupported data type for MoeDispatchKernel"); + } + return output; +} + +std::vector MoeExpertReduce( + const paddle::Tensor& ffn_out, + const paddle::Tensor& top_k_weight, + const paddle::Tensor& permute_indices_per_token, + const paddle::Tensor& top_k_indices, + const paddle::optional& ffn2_bias, + const bool norm_topk_prob, + const float routed_scaling_factor) { + return {MoeExpertReduceFunc(ffn_out, + top_k_weight, + permute_indices_per_token, + top_k_indices, + ffn2_bias, + norm_topk_prob, + routed_scaling_factor)}; +} + +std::vector> MoeExpertReduceInferShape( + const std::vector& ffn_out_shape, + const std::vector& top_k_weight_shape, + const std::vector& permute_indices_per_token_shape, + const std::vector& top_k_indices_shape, + const paddle::optional>& ffn2_bias_shape) { + return {ffn_out_shape}; +} + +std::vector MoeExpertReduceInferDtype( 
+ const paddle::DataType& ffn_out_dtype, + const paddle::DataType& top_k_weight_dtype, + const paddle::DataType& permute_indices_per_token_dtype, + const paddle::DataType& top_k_indices_dtype, + const paddle::optional& ffn2_bias_dtype) { + return {ffn_out_dtype}; +} + +PD_BUILD_STATIC_OP(moe_expert_reduce) + .Inputs({"ffn_out", + "top_k_weight", + "permute_indices_per_token", + "top_k_indices", + paddle::Optional("ffn2_bias")}) + .Outputs({"output"}) + .Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"}) + .SetKernelFn(PD_KERNEL(MoeExpertReduce)) + .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertReduceInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertReduceInferDtype)); diff --git a/custom_ops/iluvatar_ops/paged_attn.cu b/custom_ops/iluvatar_ops/paged_attn.cu new file mode 100644 index 000000000..7c9ead54d --- /dev/null +++ b/custom_ops/iluvatar_ops/paged_attn.cu @@ -0,0 +1,337 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "iluvatar_context.h" + +#define CUINFER_CHECK(func) \ + do { \ + cuinferStatus_t status = (func); \ + if (status != CUINFER_STATUS_SUCCESS) { \ + std::cerr << "Error in file " << __FILE__ << " on line " << __LINE__ << ": " \ + << cuinferGetErrorString(status) << std::endl; \ + throw std::runtime_error("CUINFER_CHECK ERROR"); \ + } \ + } while (0) + +template +void PagedAttnKernel(const paddle::Tensor& q, + const paddle::Tensor& k_cache, + const paddle::Tensor& v_cache, + const paddle::Tensor& block_table, + const paddle::Tensor& seq_lens, + const paddle::optional &alibi_slopes, + const paddle::optional &k, + const paddle::optional &v, + int num_kv_heads, + float scale, + int block_size, + int max_context_len, + bool causal, + int window_left, + int window_right, + float softcap, + bool enable_cuda_graph, + bool use_sqrt_alibi, + paddle::Tensor& out) { + if (alibi_slopes) { + PADDLE_ENFORCE_EQ(alibi_slopes.get().dtype(), + paddle::DataType::FLOAT32, + common::errors::InvalidArgument( + "paged_attention expects alibi_slopes float tensor")); + PADDLE_ENFORCE_EQ(alibi_slopes.get().is_contiguous(), + true, + common::errors::InvalidArgument( + "paged_attention expects alibi_slopes is contiguous")); + } + + // check dtype and contiguous + const auto& dtype = q.dtype(); + cudaDataType_t data_type; + if (dtype == paddle::DataType::FLOAT16) { + data_type = CUDA_R_16F; + } else if (dtype == paddle::DataType::BFLOAT16) { + data_type = CUDA_R_16BF; + } else { + common::errors::InvalidArgument("paged_attention support half and bfloat16 now"); + } + + PADDLE_ENFORCE_EQ(k_cache.dtype(), + dtype, + common::errors::InvalidArgument( + "k_cache dtype must be the same as query dtype")); + PADDLE_ENFORCE_EQ(k_cache.is_contiguous(), + true, + common::errors::InvalidArgument( + "paged_attention expects k_cache is contiguous")); + PADDLE_ENFORCE_EQ(v_cache.dtype(), + dtype, + common::errors::InvalidArgument( + "v_cache dtype must be the same as query dtype")); + 
PADDLE_ENFORCE_EQ(v_cache.is_contiguous(), + true, + common::errors::InvalidArgument( + "paged_attention expects v_cache is contiguous")); + PADDLE_ENFORCE_EQ(block_table.dtype(), + paddle::DataType::INT32, + common::errors::InvalidArgument( + "block_table dtype must be int32")); + PADDLE_ENFORCE_EQ(block_table.is_contiguous(), + true, + common::errors::InvalidArgument( + "paged_attention expects block_table is contiguous")); + PADDLE_ENFORCE_EQ(seq_lens.dtype(), + paddle::DataType::INT32, + common::errors::InvalidArgument( + "seq_lens dtype must be int32")); + PADDLE_ENFORCE_EQ(seq_lens.is_contiguous(), + true, + common::errors::InvalidArgument( + "paged_attention expects seq_lens is contiguous")); + + // check dim and shape + // out: [num_seqs, num_heads, head_size] + // q: [num_seqs, num_heads, head_size] + // k_chache: [num_blocks, kv_num_heads, block_size, head_size] + // v_chache: [num_blocks, kv_num_heads, block_size, head_size] + // block_table: [num_seqs, max_num_blocks_per_seq] + // seq_lens: [num_seqs] + + const auto& q_dims = q.dims(); + PADDLE_ENFORCE_EQ(q_dims.size(), + 3, + common::errors::InvalidArgument( + "paged_attn receive query dims is " + "[num_seqs, num_heads, head_size]")); + PADDLE_ENFORCE_EQ(out.dims().size(), + 3, + common::errors::InvalidArgument( + "paged_attn receive out dims is " + "[num_seqs, num_heads, head_size]")); + PADDLE_ENFORCE_EQ(k_cache.dims(), + v_cache.dims(), + common::errors::InvalidArgument( + "paged_attn requires k_cache size is the " + "same as v_cache")); + + const auto& kv_cache_dims = k_cache.dims(); + PADDLE_ENFORCE_EQ(kv_cache_dims.size(), + 4, + common::errors::InvalidArgument( + "paged_attn receive kv cache dims is " + "[num_blocks, kv_num_heads, block_size, head_size]")); + + const auto& block_table_dims = block_table.dims(); + PADDLE_ENFORCE_EQ(block_table_dims.size(), + 2, + common::errors::InvalidArgument( + "paged_attn receive block_table dims is " + "[num_seqs, max_num_blocks_per_seq]")); + + const auto& seq_lens_dims = seq_lens.dims(); + PADDLE_ENFORCE_EQ(seq_lens_dims.size(), + 1, + common::errors::InvalidArgument( + "paged_attn receive seq_lens dims is [num_seqs]")); + + int num_seqs = q_dims[0]; + int num_heads = q_dims[1]; + int head_size = q_dims[2]; + int max_num_blocks_per_seq = block_table_dims[1]; + int q_stride = q.strides()[0]; + int num_blocks = kv_cache_dims[0]; + + PADDLE_ENFORCE_EQ(kv_cache_dims[1], + num_kv_heads, + common::errors::InvalidArgument( + "kv_cache_dims[1] must be equal to num_kv_head")); + PADDLE_ENFORCE_EQ(kv_cache_dims[2], + block_size, + common::errors::InvalidArgument( + "kv_cache_dims[2] must be equal to block_size")); + PADDLE_ENFORCE_EQ(kv_cache_dims[3], + head_size, + common::errors::InvalidArgument( + "kv_cache_dims[3] must be equal to head_size")); + PADDLE_ENFORCE_EQ(block_table_dims[0], + num_seqs, + common::errors::InvalidArgument( + "block_table_dims[0] must be equal to num_seqs")); + PADDLE_ENFORCE_EQ(seq_lens_dims[0], + num_seqs, + common::errors::InvalidArgument( + "seq_lens_dims[0] must be equal to num_seqs")); + + int kv_block_stride = k_cache.strides()[0]; + int kv_head_stride = k_cache.strides()[1]; + const float *alibi_slopes_ptr = alibi_slopes ? alibi_slopes.get().data() : nullptr; + const void *key_ptr = k ? k.get().data() : nullptr; + const void *value_ptr = v ? 
v.get().data() : nullptr; + + size_t workspace_size = 0; + void* workspace_ptr = nullptr; + CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7( + num_seqs, num_heads, num_kv_heads, head_size, block_size, max_context_len, &workspace_size)); + + CUDA_CHECK(cudaMalloc((void**)&workspace_ptr, workspace_size)); + CUDA_CHECK(cudaMemset(workspace_ptr, 0xff, workspace_size)); + + auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(q.place())); + auto stream = static_cast(dev_ctx->stream()); + cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle(); + + PageAttentionWithKVCacheArguments args{ + static_cast(scale), 1.0, 1.0, static_cast(softcap), window_left, window_right, + causal, use_sqrt_alibi, enable_cuda_graph, false, alibi_slopes_ptr, key_ptr, value_ptr, workspace_ptr}; + CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle, + out.data(), + data_type, + q.data(), + data_type, + num_seqs, + num_heads, + num_kv_heads, + head_size, + q_stride, + kv_block_stride, + kv_head_stride, + k_cache.data(), + data_type, + v_cache.data(), + data_type, + block_size, + max_num_blocks_per_seq, + max_context_len, + block_table.data(), + seq_lens.data(), + args)); + + CUDA_CHECK(cudaFree(workspace_ptr)); +} + +std::vector PagedAttn(const paddle::Tensor& q, + const paddle::Tensor& k_cache, + const paddle::Tensor& v_cache, + const paddle::Tensor& block_table, + const paddle::Tensor& seq_lens, + const paddle::optional &alibi_slopes, + const paddle::optional &k, + const paddle::optional &v, + int num_kv_heads, + float scale, + int block_size, + int max_context_len, + bool causal, + int window_left, + int window_right, + float softcap, + bool enable_cuda_graph, + bool use_sqrt_alibi) { + + const auto dtype = q.dtype(); + auto out = paddle::empty_like(q, dtype); + + switch (dtype) { + case paddle::DataType::BFLOAT16: + PagedAttnKernel(q, + k_cache, + v_cache, + block_table, + seq_lens, + alibi_slopes, + k, + v, + num_kv_heads, + scale, + block_size, + max_context_len, + causal, + window_left, + window_right, + softcap, + enable_cuda_graph, + use_sqrt_alibi, + out); + break; + case paddle::DataType::FLOAT16: + PagedAttnKernel(q, + k_cache, + v_cache, + block_table, + seq_lens, + alibi_slopes, + k, + v, + num_kv_heads, + scale, + block_size, + max_context_len, + causal, + window_left, + window_right, + softcap, + enable_cuda_graph, + use_sqrt_alibi, + out); + break; + default: + PD_THROW("Unsupported data type for Paged attn"); + } + return {out}; +} + +std::vector> PagedAttnInferShape(const std::vector& q_shape, + const std::vector& k_cache_shape, + const std::vector& v_cache_shape, + const std::vector& block_table_shape, + const std::vector& seq_lens_shape, + const std::vector& alibi_slopes_shape, + const std::vector& k_shape, + const std::vector& v_shape) { + return {q_shape}; +} + +std::vector PagedAttnInferDtype(const paddle::DataType& q_dtype, + const paddle::DataType& k_cache_dtype, + const paddle::DataType& v_cache_dtype, + const paddle::DataType& block_table_dtype, + const paddle::DataType& seq_lens_dtype, + const paddle::DataType& alibi_slopes_dtype, + const paddle::DataType& k_dtype, + const paddle::DataType& v_dtype) { + return {q_dtype}; +} + + +PD_BUILD_STATIC_OP(paged_attn) + .Inputs({"q", "k_cache", "v_cache", "block_table", "seq_lens", paddle::Optional("alibi_slopes"), paddle::Optional("k"), paddle::Optional("v")}) + .Outputs({"out"}) + .Attrs({"num_kv_heads:int", + "scale:float", + "block_size:int", + "max_context_len:int", + "causal:bool", + 
"window_left:int", + "window_right:int", + "softcap:float", + "enable_cuda_graph:bool", + "use_sqrt_alibi:bool"}) + .SetKernelFn(PD_KERNEL(PagedAttn)) + .SetInferShapeFn(PD_INFER_SHAPE(PagedAttnInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(PagedAttnInferDtype)); + + +PYBIND11_MODULE(fastdeploy_ops, m) { + m.def("paged_attn", &PagedAttn, "paged attn function"); +} diff --git a/custom_ops/iluvatar_ops/runtime/iluvatar_context.cc b/custom_ops/iluvatar_ops/runtime/iluvatar_context.cc new file mode 100644 index 000000000..d64f57d11 --- /dev/null +++ b/custom_ops/iluvatar_ops/runtime/iluvatar_context.cc @@ -0,0 +1,37 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "iluvatar_context.h" + +#include +#include +namespace iluvatar { +IluvatarContext::~IluvatarContext() { + if (ixinfer_handle_) { + cuinferDestroy(ixinfer_handle_); + } +} +cuinferHandle_t IluvatarContext::getIxInferHandle() { + if (!ixinfer_handle_) { + cuinferCreate(&ixinfer_handle_); + } + return ixinfer_handle_; +} + +IluvatarContext* getContextInstance() { + static IluvatarContext context; + return &context; +} +} // namespace iluvatar diff --git a/custom_ops/iluvatar_ops/runtime/iluvatar_context.h b/custom_ops/iluvatar_ops/runtime/iluvatar_context.h new file mode 100644 index 000000000..4865fe816 --- /dev/null +++ b/custom_ops/iluvatar_ops/runtime/iluvatar_context.h @@ -0,0 +1,33 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#pragma once +#include + +namespace iluvatar { + +class IluvatarContext { + public: + IluvatarContext() = default; + ~IluvatarContext(); + + cuinferHandle_t getIxInferHandle(); + + private: + cuinferHandle_t ixinfer_handle_{nullptr}; +}; +IluvatarContext* getContextInstance(); + +} // namespace iluvatar diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index dd26d6e90..eca3349ac 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -470,6 +470,36 @@ elif paddle.is_compiled_with_cuda(): ) elif paddle.is_compiled_with_xpu(): assert False, "In XPU, we should use setup_ops.py in xpu_ops/src, not this." 
+elif paddle.is_compiled_with_custom_device("iluvatar_gpu"): + setup( + name="fastdeploy_ops", + ext_modules=CUDAExtension( + extra_compile_args={ + "nvcc": [ + "-DPADDLE_DEV", + "-DPADDLE_WITH_CUSTOM_DEVICE", + ] + }, + sources=[ + "gpu_ops/get_padding_offset.cu", + "gpu_ops/set_value_by_flags.cu", + "gpu_ops/stop_generation_multi_stop_seqs.cu", + "gpu_ops/rebuild_padding.cu", + "gpu_ops/update_inputs.cu", + "gpu_ops/stop_generation_multi_ends.cu", + "gpu_ops/step.cu", + "gpu_ops/token_penalty_multi_scores.cu", + "iluvatar_ops/moe_dispatch.cu", + "iluvatar_ops/moe_reduce.cu", + "iluvatar_ops/paged_attn.cu", + "iluvatar_ops/runtime/iluvatar_context.cc", + ], + include_dirs=["iluvatar_ops/runtime", "gpu_ops"], + extra_link_args=[ + "-lcuinfer", + ], + ), + ) else: use_bf16 = envs.FD_CPU_USE_BF16 == "True" diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py index d92f7c2a9..65f67254f 100644 --- a/fastdeploy/engine/config.py +++ b/fastdeploy/engine/config.py @@ -42,7 +42,7 @@ class ModelConfig: model_name_or_path: str, config_json_file: str = "config.json", dynamic_load_weight: bool = False, - load_strategy: str="meta", + load_strategy: str = "meta", quantization: str = None, download_dir: Optional[str] = None): """ @@ -590,7 +590,7 @@ class Config: self.nnode = 1 else: self.nnode = len(self.pod_ips) - + assert self.splitwise_role in ["mixed", "prefill", "decode"] # TODO @@ -608,8 +608,9 @@ class Config: == 1), "TP and EP cannot be enabled at the same time" num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size - if num_ranks > 8: - self.worker_num_per_node = 8 + self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 + if num_ranks > self.max_chips_per_node: + self.worker_num_per_node = self.max_chips_per_node nnode = ceil_div(num_ranks, self.worker_num_per_node) assert nnode == self.nnode, \ f"nnode: {nnode}, but got {self.nnode}" @@ -679,8 +680,8 @@ class Config: is_port_available('0.0.0.0', self.engine_worker_queue_port) ), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use." 
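+        # Iluvatar nodes expose up to 16 chips, so the parallel-size bounds below
+        # are platform dependent rather than hard-coded to 8.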
assert ( - 8 >= self.tensor_parallel_size > 0 - ), f"tensor_parallel_size: {self.tensor_parallel_size} should be between 1 and 8" + self.max_chips_per_node >= self.tensor_parallel_size > 0 + ), f"tensor_parallel_size: {self.tensor_parallel_size} should be between 1 and {self.max_chips_per_node}" assert (self.nnode >= 1), f"nnode: {self.nnode} should no less than 1" assert ( self.max_model_len >= 16 @@ -816,7 +817,7 @@ class Config: def _check_master(self): return self.is_master - + def _str_to_list(self, attr_name, default_type): if hasattr(self, attr_name): val = getattr(self, attr_name) diff --git a/fastdeploy/model_executor/layers/activation.py b/fastdeploy/model_executor/layers/activation.py index aa8ff7f2c..09126aa6c 100644 --- a/fastdeploy/model_executor/layers/activation.py +++ b/fastdeploy/model_executor/layers/activation.py @@ -63,7 +63,8 @@ class SiluAndMul(nn.Layer): """ super().__init__() - if current_platform.is_cuda() or current_platform.is_xpu(): + if current_platform.is_cuda() or current_platform.is_xpu( + ) or current_platform.is_iluvatar(): self.forward = self.forward_cuda else: raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/attention/__init__.py b/fastdeploy/model_executor/layers/attention/__init__.py index 6a1d0e1c1..9b0f2b289 100644 --- a/fastdeploy/model_executor/layers/attention/__init__.py +++ b/fastdeploy/model_executor/layers/attention/__init__.py @@ -19,9 +19,10 @@ from .flash_attn_backend import FlashAttentionBackend from .mla_attention_backend import MLAAttentionBackend from .native_paddle_backend import PaddleNativeAttnBackend from .xpu_attn_backend import XPUAttentionBackend +from .iluvatar_attn_backend import IluvatarAttnBackend __all__ = [ "AttentionBackend", "PaddleNativeAttnBackend", "get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend", - "MLAAttentionBackend", "FlashAttentionBackend" + "MLAAttentionBackend", "FlashAttentionBackend", "IluvatarAttnBackend" ] diff --git a/fastdeploy/model_executor/layers/attention/attention.py b/fastdeploy/model_executor/layers/attention/attention.py index 3f676f031..0ee0b41f4 100644 --- a/fastdeploy/model_executor/layers/attention/attention.py +++ b/fastdeploy/model_executor/layers/attention/attention.py @@ -67,7 +67,7 @@ class Attention(nn.Layer): self.num_heads: int = fd_config.model_config.num_attention_heads // fd_config.parallel_config.tensor_parallel_degree self.head_dim: int = fd_config.model_config.head_dim self.kv_num_heads: int = \ - fd_config.model_config.num_key_value_heads // fd_config.parallel_config.tensor_parallel_degree + max(1, fd_config.model_config.num_key_value_heads // fd_config.parallel_config.tensor_parallel_degree) self.layer_id: int = layer_id self.v_head_dim: int = v_head_dim if v_head_dim > 0 else self.head_dim self.rope_type: str = rope_type diff --git a/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py b/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py new file mode 100644 index 000000000..43e034194 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py @@ -0,0 +1,613 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from __future__ import annotations + +import os +import paddle + +from dataclasses import dataclass +from typing import Optional +from math import sqrt + +from paddle.nn.functional.flash_attention import flash_attn_unpadded +from fastdeploy.model_executor.ops.iluvatar import paged_attention + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.layers.attention.attention import Attention +from fastdeploy.model_executor.layers.attention.base_attention_backend import ( + AttentionBackend, AttentionMetadata) +from fastdeploy.worker.forward_meta import ForwardMeta + + +@dataclass +class IluvatarAttentionMetadata(AttentionMetadata): + """ + IluvatarAttentionMetadata + """ + # flash_attn metadata + cu_seqlens_q: Optional[paddle.Tensor] = None + cu_seqlens_k: Optional[paddle.Tensor] = None + fixed_seed_offset: Optional[paddle.Tensor] = None + attn_mask: Optional[paddle.Tensor] = None + attn_mask_start_row_indices: Optional[paddle.Tensor] = None + dropout: float = 0.0 + causal: bool = True + return_softmax: bool = False + rng_name: str = "" + + # paged_attn metadata + block_tables: Optional[paddle.Tensor] = None + seq_lens: Optional[paddle.Tensor] = None + num_kv_heads: int = 1 + scale: float = 1.0 + block_size: int = 1 + max_context_len: int = 1 + alibi_slopes: Optional[paddle.Tensor] = None + # causal: bool = True + window_left: int = -1 + window_right: int = -1 + softcap: float = 0.0 + use_cuda_graph: bool = False + use_sqrt_alibi: bool = False + + +# qk[seq, h, d], cos/sin [seq, 1, d] +def apply_rope(qk, cos, sin): + rotate_half = paddle.reshape( + paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1), + paddle.shape(qk), + ) + out = paddle.add(paddle.multiply(qk, cos), + paddle.multiply(rotate_half, sin)) + return paddle.cast(out, qk.dtype) + + +class IluvatarAttnBackend(AttentionBackend): + """ + The backend class that uses paddle native attention implementation. + Which is used only for testing purpose. 
+    """
+
+    def __init__(self, llm_config: FDConfig, kv_num_heads: int, num_heads: int,
+                 head_dim: int):
+        super().__init__()
+        self.attention_metadata = IluvatarAttentionMetadata()
+        self.attention_metadata.block_size = llm_config.parallel_config.block_size
+        assert llm_config.parallel_config.enc_dec_block_num == 0, "Iluvatar does not support enc_dec_block_num > 0 yet"
+
+        self.attention_metadata.max_context_len = llm_config.parallel_config.max_model_len
+        self.attention_metadata.causal = getattr(llm_config.model_config,
+                                                 "causal", True)
+        self.speculate_method = getattr(llm_config.parallel_config,
+                                        "speculate_method", None)
+        self.use_speculate = self.speculate_method is not None
+        self.attention_metadata.num_kv_heads = kv_num_heads
+        self.attention_metadata.dropout = llm_config.model_config.hidden_dropout_prob
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        # note: the scale needs to change if MLA is used
+        self.attention_metadata.scale = 1.0 / sqrt(head_dim)
+        self.num_layers = llm_config.model_config.num_layers
+        self.record_block_table_metadata = {}
+        self.only_use_flash_attn = int(
+            os.getenv("FD_ILUVATAR_ONLY_USE_FLASH_ATTN", 0)) == 1
+        self.do_check_kv_cache = int(
+            os.getenv("FD_ILUVATAR_CHECK_KV_CACHE_CORRECTNESS", 0)) == 1
+        if not self.only_use_flash_attn:
+            assert self.attention_metadata.block_size == 16, "Iluvatar paged attn requires block_size to be 16."
+        if self.do_check_kv_cache:
+            self.record_batched_k = [{} for _ in range(self.num_layers)]
+            self.record_batched_v = [{} for _ in range(self.num_layers)]
+
+    def init_attention_metadata(self, forward_meta: ForwardMeta):
+        """Initialize attention metadata so that all layers in the forward pass can reuse it."""
+        self.attention_metadata.block_tables = forward_meta.block_tables
+        self.attention_metadata.attn_mask = forward_meta.attn_mask
+        self.attention_metadata.seq_lens = forward_meta.seq_lens_decoder
+        self.attention_metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
+        self.attention_metadata.cu_seqlens_k = forward_meta.cu_seqlens_k
+
+    def get_attntion_meta(self):
+        """Return the attention metadata."""
+        return self.attention_metadata
+
+    def get_kv_cache_shape(
+        self,
+        max_num_blocks: int,
+    ):
+        """
+        Calculate the kv cache shape
+        """
+        return (max_num_blocks, self.attention_metadata.num_kv_heads,
+                self.attention_metadata.block_size, self.head_dim)
+
+    def get_new_kv(self,
+                   k,
+                   v,
+                   k_cache_id: int,
+                   v_cache_id: int,
+                   forward_meta: ForwardMeta,
+                   debug_paged_attn=False):
+        new_k = []
+        new_v = []
+        tensor_start = 0
+        for batch_idx in range(forward_meta.block_tables.shape[0]):
+            seq_len = forward_meta.seq_lens_this_time[batch_idx]
+            if seq_len == 0:
+                continue
+
+            tensor_end = tensor_start + seq_len
+            slice_k = k[tensor_start:tensor_end, :, :]
+            slice_v = v[tensor_start:tensor_end, :, :]
+
+            if seq_len > 1:
+                # prefill
+                new_k.append(slice_k)
+                new_v.append(slice_v)
+            else:
+                # decode
+                assert seq_len == 1
+                cur_block_tables = forward_meta.block_tables[batch_idx]
+                cur_used_block_tables = cur_block_tables[cur_block_tables !=
+                                                         -1]
+                assert batch_idx in self.record_block_table_metadata, \
+                    f"Key error: {batch_idx} vs {self.record_block_table_metadata}."
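+                # Decode step: rebuild this batch's full K/V from the paged cache.
+                # record_block_table_metadata (refreshed while running the last layer)
+                # remembers the last block written and how far it is filled, so that
+                # block is sliced up to cache_end while earlier blocks are copied whole.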
+ cur_block_table_metadata = self.record_block_table_metadata[ + batch_idx] + record_last_block_id = cur_block_table_metadata["block_id"] + assert record_last_block_id != -1 + for block_id in cur_used_block_tables: + if block_id == record_last_block_id: + cache_end = cur_block_table_metadata["cache_end"] + block_k_cache = forward_meta.caches[k_cache_id][ + block_id, :, 0:cache_end, :] + block_v_cache = forward_meta.caches[v_cache_id][ + block_id, :, 0:cache_end, :] + else: + block_k_cache = forward_meta.caches[k_cache_id][ + block_id] + block_v_cache = forward_meta.caches[v_cache_id][ + block_id] + + # [num_kv_heads, block_size, head_dim] -> [block_size, num_kv_heads, head_dim] + new_k.append( + block_k_cache.transpose([1, 0, 2]).contiguous()) + new_v.append( + block_v_cache.transpose([1, 0, 2]).contiguous()) + if block_id == record_last_block_id: + break + + # as line 301 show, record_block_table_metadata updates when executing the last layer, + # so slice_k and slice_v has been updated in block_k_cache and block_v_cache + if not (debug_paged_attn and + (k_cache_id / 2 == self.num_layers - 1)): + new_k.append(slice_k) + new_v.append(slice_v) + + tensor_start = tensor_end + + if len(new_k) == 1: + return new_k[0], new_v[0] + else: + new_k = paddle.concat(new_k, axis=0) + new_v = paddle.concat(new_v, axis=0) + return new_k, new_v + + def update_kv_cache(self, + k, + v, + k_cache_id: int, + v_cache_id: int, + layer_id: int, + forward_meta: ForwardMeta, + specific_batch_ids=None, + debug_paged_attn=False): + # [num_tokens, num_kv_heads, head_dim] -> [num_kv_heads, num_tokens, head_dim] + trans_k = k.transpose([1, 0, 2]).contiguous() + trans_v = v.transpose([1, 0, 2]).contiguous() + tensor_start = 0 + for batch_idx in range(forward_meta.block_tables.shape[0]): + if specific_batch_ids is not None and batch_idx not in specific_batch_ids: + continue + seq_len = forward_meta.seq_lens_this_time[batch_idx] + if seq_len == 0: + continue + + tensor_end = tensor_start + seq_len + slice_trans_k = trans_k[:, tensor_start:tensor_end, :] + slice_trans_v = trans_v[:, tensor_start:tensor_end, :] + + cur_block_tables = forward_meta.block_tables[batch_idx] + cur_used_block_tables = cur_block_tables[cur_block_tables != -1] + + # prefill + if seq_len > 1: + cache_start = 0 + cur_used_num_blocks = cur_used_block_tables.shape[0] + for i, block_id in enumerate(cur_used_block_tables): + # last block: seq_len - cache_start <= block_size + if i == cur_used_num_blocks - 1: + cache_end = seq_len - cache_start + assert cache_end <= self.attention_metadata.block_size + forward_meta.caches[k_cache_id][ + block_id, :, + 0:cache_end, :] = slice_trans_k[:, cache_start: + seq_len, :] + forward_meta.caches[v_cache_id][ + block_id, :, + 0:cache_end, :] = slice_trans_v[:, cache_start: + seq_len, :] + if layer_id == self.num_layers - 1: + self.record_block_table_metadata[batch_idx] = { + "block_id": block_id.item(), + "cache_end": cache_end + } + # non last block: seq_lens_this_time > block_size + else: + assert seq_len > self.attention_metadata.block_size + cache_end = cache_start + self.attention_metadata.block_size + forward_meta.caches[k_cache_id][ + block_id] = slice_trans_k[:, + cache_start:cache_end, :] + forward_meta.caches[v_cache_id][ + block_id] = slice_trans_v[:, + cache_start:cache_end, :] + cache_start += self.attention_metadata.block_size + else: + # decode + assert seq_len == 1 + cur_last_block_id = cur_used_block_tables[-1].item() + assert cur_last_block_id != -1 + assert batch_idx in 
self.record_block_table_metadata, \ + f"Key error: {batch_idx} vs {self.record_block_table_metadata}." + cur_block_table_metadata = self.record_block_table_metadata[ + batch_idx] + record_last_block_id = cur_block_table_metadata["block_id"] + + if cur_last_block_id == record_last_block_id: + # not alloc new block in decode stage + cache_start = cur_block_table_metadata["cache_end"] + else: + # alloc new block in decode stage + cache_start = 0 + + cache_end = cache_start + 1 + assert cache_end <= self.attention_metadata.block_size + + # paged attn API will update kv cache with inplace mode + if not debug_paged_attn: + forward_meta.caches[k_cache_id][ + cur_last_block_id, :, + cache_start:cache_end, :] = slice_trans_k + forward_meta.caches[v_cache_id][ + cur_last_block_id, :, + cache_start:cache_end, :] = slice_trans_v + + # update record_block_table_metadata + if layer_id == self.num_layers - 1: + self.record_block_table_metadata[batch_idx][ + "block_id"] = cur_last_block_id + self.record_block_table_metadata[batch_idx][ + "cache_end"] = cache_end + + tensor_start = tensor_end + + def _check_new_kv_correctness(self, k, v, new_k, new_v, layer_id: int, + forward_meta: ForwardMeta): + tensor_start = 0 + for batch_idx, seq_lens_this_time in enumerate( + forward_meta.seq_lens_this_time): + if seq_lens_this_time == 0: + continue + # note: the second request will also use the batch_idx 0 instead of 1 in + # the streaming inference mode, so use seq_lens_this_time > 1 with the same + # batch_idx represents the second request comes. + if seq_lens_this_time > 1 and batch_idx in self.record_batched_k[ + layer_id]: + print( + f"clear self.record_batched_batched_k: " + f"layer_id={layer_id}, batch_id={batch_idx}, " + f"record_lens={len(self.record_batched_k[layer_id][batch_idx])}" + ) + self.record_batched_k[layer_id][batch_idx].clear() + self.record_batched_v[layer_id][batch_idx].clear() + tensor_end = tensor_start + seq_lens_this_time + slice_k = k[tensor_start:tensor_end, :, :] + slice_v = v[tensor_start:tensor_end, :, :] + if batch_idx not in self.record_batched_k[layer_id]: + self.record_batched_k[layer_id][batch_idx] = [] + self.record_batched_v[layer_id][batch_idx] = [] + self.record_batched_k[layer_id][batch_idx].append(slice_k) + self.record_batched_v[layer_id][batch_idx].append(slice_v) + tensor_start = tensor_end + + ref_k, ref_v = [], [] + for batch_idx, seq_lens_this_time in enumerate( + forward_meta.seq_lens_this_time): + if seq_lens_this_time == 0: + continue + bached_k_list = self.record_batched_k[layer_id][batch_idx] + bached_v_list = self.record_batched_v[layer_id][batch_idx] + ref_k.extend(bached_k_list) + ref_v.extend(bached_v_list) + + ref_k = paddle.concat(ref_k, axis=0) + ref_v = paddle.concat(ref_v, axis=0) + print( + f"_check_new_kv_correctness: layer_id={layer_id}, " + f"k.shape={k.shape}, v.shape={v.shape}, " + f"ref_k.shape={ref_k.shape}, ref_v.shape={ref_v.shape}, " + f"new_k.shape={new_k.shape}, new_v.shape={new_v.shape}, " + f"len(self.record_batched_k[layer_id])={len(self.record_batched_k[layer_id])}, " + f"len(self.record_batched_k[layer_id][0])={len(self.record_batched_k[layer_id][0])}, " + f"forward_meta.seq_lens_this_time={forward_meta.seq_lens_this_time}" + f"ref_k[-2:, 0:2, 0:2]={ref_k[-2:, 0:2, 0:2]}, " + f"ref_v[-2:, 0:2, 0:2]={ref_v[-2:, 0:2, 0:2]}, " + f"new_k[-2:, 0:2, 0:2]={new_k[-2:, 0:2, 0:2]}, " + f"new_v[-2:, 0:2, 0:2]={new_v[-2:, 0:2, 0:2]}") + assert paddle.allclose( + ref_k.to("cpu").to(paddle.float32), + new_k.to("cpu").to(paddle.float32)) + assert 
paddle.allclose( + ref_v.to("cpu").to(paddle.float32), + new_v.to("cpu").to(paddle.float32)) + + def get_splited_qkv(self, qkv: paddle.Tensor, forward_meta: ForwardMeta): + q_end = self.num_heads * self.head_dim + k_end = q_end + self.attention_metadata.num_kv_heads * self.head_dim + v_end = k_end + self.attention_metadata.num_kv_heads * self.head_dim + assert v_end == qkv.shape[ + -1], f"Shape mistach: {v_end} vs {qkv.shape[-1]}" + assert qkv.shape[0] == forward_meta.cu_seqlens_q[-1] + + q = qkv[..., 0:q_end] + k = qkv[..., q_end:k_end] + v = qkv[..., k_end:v_end] + q = q.view([-1, self.num_heads, self.head_dim]).contiguous() + k = k.view([-1, self.attention_metadata.num_kv_heads, + self.head_dim]).contiguous() + v = v.view([-1, self.attention_metadata.num_kv_heads, + self.head_dim]).contiguous() + # forward_meta.seq_lens_this_time [max_batch,] + for batch_idx in range(forward_meta.seq_lens_this_time.shape[0]): + seq_len_i = forward_meta.seq_lens_this_time[batch_idx] + if seq_len_i == 0: + continue + cached_kv_len = forward_meta.seq_lens_decoder[batch_idx][0] + cu_seq_start_q = forward_meta.cu_seqlens_q[batch_idx] + cu_seq_end_q = forward_meta.cu_seqlens_q[batch_idx + 1] + # forward_meta.rotary_embs is [2, 1, S, 1, D] + if forward_meta.rotary_embs is not None: + cos = forward_meta.rotary_embs[0, 0, + cached_kv_len:cached_kv_len + + seq_len_i, :, :] + sin = forward_meta.rotary_embs[1, 0, + cached_kv_len:cached_kv_len + + seq_len_i, :, :] + q[cu_seq_start_q:cu_seq_end_q] = apply_rope( + q[cu_seq_start_q:cu_seq_end_q], cos, sin) + k[cu_seq_start_q:cu_seq_end_q] = apply_rope( + k[cu_seq_start_q:cu_seq_end_q], cos, sin) + + return q, k, v + + def get_splited_info_by_stage(self, q, k, v, forward_meta: ForwardMeta): + prefill_info_dict = {"q": [], "k": [], "v": [], "batch_ids": []} + decode_info_dict = {"q": [], "k": [], "v": [], "batch_ids": []} + tensor_start = 0 + for batch_idx, seq_lens_this_time in enumerate( + forward_meta.seq_lens_this_time): + if seq_lens_this_time == 0: + continue + tensor_end = tensor_start + seq_lens_this_time + slice_q = q[tensor_start:tensor_end, :, :] + slice_k = k[tensor_start:tensor_end, :, :] + slice_v = v[tensor_start:tensor_end, :, :] + if seq_lens_this_time > 1: + prefill_info_dict["q"].append(slice_q) + prefill_info_dict["k"].append(slice_k) + prefill_info_dict["v"].append(slice_v) + prefill_info_dict["batch_ids"].append(batch_idx) + else: + assert seq_lens_this_time == 1 + decode_info_dict["q"].append(slice_q) + decode_info_dict["k"].append(slice_k) + decode_info_dict["v"].append(slice_v) + decode_info_dict["batch_ids"].append(batch_idx) + tensor_start = tensor_end + + if len(prefill_info_dict["batch_ids"]) > 0: + prefill_info_dict["q"] = paddle.concat(prefill_info_dict["q"], + axis=0) + prefill_info_dict["k"] = paddle.concat(prefill_info_dict["k"], + axis=0) + prefill_info_dict["v"] = paddle.concat(prefill_info_dict["v"], + axis=0) + cu_seq_ids = list( + map(lambda x: x + 1, prefill_info_dict["batch_ids"])) + prefill_info_dict["cu_seq_ids"] = [0, *cu_seq_ids] + + if len(decode_info_dict["batch_ids"]) > 0: + decode_info_dict["q"] = paddle.concat(decode_info_dict["q"], + axis=0) + decode_info_dict["k"] = paddle.concat(decode_info_dict["k"], + axis=0) + decode_info_dict["v"] = paddle.concat(decode_info_dict["v"], + axis=0) + + return prefill_info_dict, decode_info_dict + + def merge_output(self, prefill_out, decode_out, forward_meta: ForwardMeta): + assert not (prefill_out is None and decode_out + is None), "prefill and decode output cannot both be None" + if 
prefill_out is None: + return decode_out + elif decode_out is None: + return prefill_out + else: + merged_output = [] + prefill_tensor_start = 0 + decode_tensor_start = 0 + for seq_lens_this_time in forward_meta.seq_lens_this_time: + if seq_lens_this_time == 0: + continue + if seq_lens_this_time > 1: + tensor_end = prefill_tensor_start + seq_lens_this_time + merged_output.append( + prefill_out[prefill_tensor_start:tensor_end, :, :]) + prefill_tensor_start = tensor_end + else: + assert seq_lens_this_time == 1 + tensor_end = decode_tensor_start + seq_lens_this_time + merged_output.append( + decode_out[decode_tensor_start:tensor_end, :, :]) + decode_tensor_start = tensor_end + + assert prefill_tensor_start == prefill_out.shape[0], \ + f"prefill merged unfinished: {prefill_tensor_start} vs {prefill_out.shape[0]}" + assert decode_tensor_start == decode_out.shape[0], \ + f"decode merged unfinished: {decode_tensor_start} vs {decode_out.shape[0]}" + merged_output = paddle.concat(merged_output, axis=0) + return merged_output + + def forward_mixed( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + layer: Attention, + forward_meta: ForwardMeta, + ): + """ + forward_mixed + """ + assert not self.use_speculate, "IluvatarAttnBackend cannot support speculate now" + layer_id = layer.layer_id + k_cache_id = layer_id * 2 + v_cache_id = k_cache_id + 1 + + assert qkv is not None + q_dim = qkv.dim() + q, k, v = self.get_splited_qkv(qkv, forward_meta) + + if self.only_use_flash_attn: + new_k, new_v = self.get_new_kv(k, v, k_cache_id, v_cache_id, + forward_meta) + if self.do_check_kv_cache: + self._check_new_kv_correctness(k, v, new_k, new_v, layer_id, + forward_meta) + + out = flash_attn_unpadded( + q, + new_k, + new_v, + cu_seqlens_q=self.attention_metadata.cu_seqlens_q, + cu_seqlens_k=self.attention_metadata.cu_seqlens_k, + max_seqlen_q=self.attention_metadata.max_context_len, + max_seqlen_k=self.attention_metadata.max_context_len, + scale=self.attention_metadata.scale, + dropout=self.attention_metadata.dropout, + causal=self.attention_metadata.causal, + return_softmax=self.attention_metadata.return_softmax)[0] + + self.update_kv_cache(k, v, k_cache_id, v_cache_id, layer_id, + forward_meta) + else: + prefill_info_dict, decode_info_dict = self.get_splited_info_by_stage( + q, k, v, forward_meta) + prefill_out, decode_out = None, None + + if len(prefill_info_dict["batch_ids"]) > 0: + prefill_out = flash_attn_unpadded( + prefill_info_dict["q"], + prefill_info_dict["k"], + prefill_info_dict["v"], + cu_seqlens_q=forward_meta.cu_seqlens_q[ + prefill_info_dict["cu_seq_ids"]], + cu_seqlens_k=forward_meta.cu_seqlens_k[ + prefill_info_dict["cu_seq_ids"]], + max_seqlen_q=self.attention_metadata.max_context_len, + max_seqlen_k=self.attention_metadata.max_context_len, + scale=self.attention_metadata.scale, + dropout=self.attention_metadata.dropout, + causal=self.attention_metadata.causal, + return_softmax=self.attention_metadata.return_softmax)[0] + self.update_kv_cache( + prefill_info_dict["k"], + prefill_info_dict["v"], + k_cache_id, + v_cache_id, + layer_id, + forward_meta, + specific_batch_ids=prefill_info_dict['batch_ids']) + + if len(decode_info_dict["batch_ids"]) > 0: + k_cache = forward_meta.caches[k_cache_id] + v_cache = forward_meta.caches[v_cache_id] + + decode_out = paged_attention( + decode_info_dict["q"], + k_cache, + v_cache, + block_tables=forward_meta.block_tables[ + decode_info_dict["batch_ids"], :], + 
seq_lens=forward_meta.seq_lens_decoder[ + decode_info_dict["batch_ids"], 0] + 1, + num_kv_heads=self.attention_metadata.num_kv_heads, + scale=self.attention_metadata.scale, + block_size=self.attention_metadata.block_size, + max_context_len=self.attention_metadata.max_context_len, + alibi_slopes=self.attention_metadata.alibi_slopes, + causal=self.attention_metadata.causal, + window_left=self.attention_metadata.window_left, + window_right=self.attention_metadata.window_right, + softcap=self.attention_metadata.softcap, + use_cuda_graph=self.attention_metadata.use_cuda_graph, + use_sqrt_alibi=self.attention_metadata.use_sqrt_alibi, + k=decode_info_dict["k"], + v=decode_info_dict["v"]) + + if self.do_check_kv_cache: + self.update_kv_cache( + decode_info_dict['k'], + decode_info_dict['v'], + k_cache_id, + v_cache_id, + layer_id, + forward_meta, + specific_batch_ids=decode_info_dict['batch_ids'], + debug_paged_attn=True) + + if self.do_check_kv_cache: + new_k, new_v = self.get_new_kv(k, + v, + k_cache_id, + v_cache_id, + forward_meta, + debug_paged_attn=True) + self._check_new_kv_correctness(k, v, new_k, new_v, layer_id, + forward_meta) + + out = self.merge_output(prefill_out, decode_out, forward_meta) + + if q_dim == 2: + out = out.view([-1, self.num_heads * self.head_dim]) + + return out diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index b8dc49e1b..77c8e56fd 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -57,7 +57,8 @@ class LinearBase(nn.Layer): NotImplementedError: Raised if the current platform is not a CUDA platform. """ super().__init__() - if current_platform.is_cuda() or current_platform.is_xpu(): + if current_platform.is_cuda() or current_platform.is_xpu( + ) or current_platform.is_iluvatar(): self.forward = self.forward_cuda else: raise NotImplementedError @@ -411,9 +412,14 @@ class QKVParallelLinear(ColumnParallelLinear): self.head_dim = fd_config.model_config.head_dim self.nranks = fd_config.parallel_config.tensor_parallel_degree self.num_heads_per_rank = divide(self.num_heads, self.nranks) - self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks) + if self.kv_num_heads < self.nranks and self.nranks % self.kv_num_heads == 0: + self.kv_num_heads_per_rank = 1 + output_size = (self.num_heads + 2 * self.nranks) * self.head_dim + else: + self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks) + output_size = (self.num_heads + + 2 * self.kv_num_heads) * self.head_dim input_size = self.hidden_size - output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim super().__init__(fd_config=fd_config, prefix=prefix, input_size=input_size, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 3c00ddfe4..5b49a4428 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -30,6 +30,8 @@ from .fused_moe_backend_base import MoEMethodBase if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import (moe_expert_dispatch, moe_expert_reduce, noaux_tc) +elif current_platform.is_iluvatar(): + from fastdeploy.model_executor.ops.iluvatar import moe_expert_dispatch, moe_expert_reduce # used for deepseek_v3 @@ -89,6 +91,23 @@ class CutlassMoEMethod(MoEMethodBase): """ Paddle Cutlass compute Fused MoE. 
""" + if current_platform.is_iluvatar(): + return fastdeploy.model_executor.ops.iluvatar.moe_expert_ffn( + permute_input, + token_nums_per_expert, + layer.moe_ffn1_weight, + layer.moe_ffn2_weight, + None, + (layer.moe_ffn1_weight_scale if hasattr( + layer, "moe_ffn1_weight_scale") else None), + (layer.moe_ffn2_weight_scale if hasattr( + layer, "moe_ffn2_weight_scale") else None), + (layer.moe_ffn2_in_scale + if hasattr(layer, "moe_ffn2_in_scale") else None), + expert_idx_per_token, + self.moe_quant_type, + used_in_ep_low_latency, + ) return fastdeploy.model_executor.ops.gpu.moe_expert_ffn( permute_input, token_nums_per_expert, diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py index 6d25df345..e4f78a05d 100644 --- a/fastdeploy/model_executor/layers/normalization.py +++ b/fastdeploy/model_executor/layers/normalization.py @@ -20,6 +20,7 @@ import numpy as np import paddle from paddle import nn from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm +from fastdeploy.platforms import current_platform from fastdeploy.config import FDConfig @@ -265,6 +266,18 @@ class LayerNorm(nn.Layer): The `residual_output` is the result of applying the normalization and possibly other operations (like linear transformation) on the `residual_input`. """ + if current_platform.is_iluvatar(): + if self.ln_weight is None and self.ln_bias is None: + out = x + if self.linear_bias is not None: + out += self.linear_bias + if residual_input is not None: + out += residual_input + return out, out + else: + return out + else: + raise NotImplementedError("Iluvatar does not support yet!") norm_out = self.norm_func( x, diff --git a/fastdeploy/model_executor/layers/rotary_embedding.py b/fastdeploy/model_executor/layers/rotary_embedding.py index 0521c4166..17c7dffc1 100644 --- a/fastdeploy/model_executor/layers/rotary_embedding.py +++ b/fastdeploy/model_executor/layers/rotary_embedding.py @@ -48,7 +48,8 @@ class ErnieRotaryEmbedding: freqs = paddle.einsum("ij,k->ijk", partial_rotary_position_ids.cast("float32"), inv_freq) - if paddle.is_compiled_with_xpu(): + if paddle.is_compiled_with_xpu( + ) or paddle.is_compiled_with_custom_device("iluvatar_gpu"): # shape: [B, S, D] rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32") diff --git a/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py b/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py index 63125f4c5..2e21a85bc 100644 --- a/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py +++ b/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py @@ -64,6 +64,21 @@ def apply_penalty_multi_scores( min_dec_lens, eos_token_ids, ) + elif current_platform.is_iluvatar(): + from fastdeploy.model_executor.ops.iluvatar import \ + get_token_penalty_multi_scores + logits = get_token_penalty_multi_scores( + pre_token_ids, + logits, + repetition_penalties, + frequency_penalties, + presence_penalties, + temperature, + bad_words_token_ids, + step_idx, + min_dec_lens, + eos_token_ids, + ) else: raise NotImplementedError() diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 988fa4443..217776861 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -170,7 +170,8 @@ class Sampler(nn.Layer): """ """ super().__init__() - if current_platform.is_cuda() or 
current_platform.is_xpu(): + if current_platform.is_cuda() or current_platform.is_xpu( + ) or current_platform.is_iluvatar(): self.forward = self.forward_cuda else: raise NotImplementedError() diff --git a/fastdeploy/model_executor/layers/utils.py b/fastdeploy/model_executor/layers/utils.py index 255c17e7a..d635ef285 100644 --- a/fastdeploy/model_executor/layers/utils.py +++ b/fastdeploy/model_executor/layers/utils.py @@ -33,6 +33,8 @@ if current_platform.is_cuda() and current_platform.available(): "Verify environment consistency between compilation and FastDeploy installation. " "And ensure the Paddle version supports FastDeploy's custom operators" ) +if current_platform.is_iluvatar(): + from fastdeploy.model_executor.ops.iluvatar import get_padding_offset import re from fastdeploy import envs @@ -377,4 +379,4 @@ def create_and_set_parameter(layer: nn.Layer, name: str, dtype=tensor.dtype, default_initializer=paddle.nn.initializer.Constant(0), )) - getattr(layer, name).set_value(tensor) \ No newline at end of file + getattr(layer, name).set_value(tensor) diff --git a/fastdeploy/model_executor/models/tp_utils.py b/fastdeploy/model_executor/models/tp_utils.py index f360c5106..493201fa7 100644 --- a/fastdeploy/model_executor/models/tp_utils.py +++ b/fastdeploy/model_executor/models/tp_utils.py @@ -213,8 +213,14 @@ def gqa_qkv_split_func( return np.split(tensor, degree, axis=0) q_list = split_tensor(q, tensor_parallel_degree) - k_list = split_tensor(k, tensor_parallel_degree) - v_list = split_tensor(v, tensor_parallel_degree) + repeat_kv = num_key_value_heads < tensor_parallel_degree and tensor_parallel_degree % num_key_value_heads == 0 + repeat_num = tensor_parallel_degree // num_key_value_heads if repeat_kv else 1 + if repeat_kv: + k_list = split_tensor(k, num_key_value_heads) + v_list = split_tensor(v, num_key_value_heads) + else: + k_list = split_tensor(k, tensor_parallel_degree) + v_list = split_tensor(v, tensor_parallel_degree) if tensor_parallel_rank is None: res = [] @@ -236,8 +242,8 @@ def gqa_qkv_split_func( return paddle.concat( [ q_list[tensor_parallel_rank], - k_list[tensor_parallel_rank], - v_list[tensor_parallel_rank], + k_list[tensor_parallel_rank // repeat_num], + v_list[tensor_parallel_rank // repeat_num], ], axis=-1, ) @@ -245,8 +251,8 @@ def gqa_qkv_split_func( return paddle.concat( [ q_list[tensor_parallel_rank], - k_list[tensor_parallel_rank], - v_list[tensor_parallel_rank], + k_list[tensor_parallel_rank // repeat_num], + v_list[tensor_parallel_rank // repeat_num], ], axis=0, ) @@ -255,8 +261,8 @@ def gqa_qkv_split_func( return np.concatenate( [ q_list[tensor_parallel_rank], - k_list[tensor_parallel_rank], - v_list[tensor_parallel_rank], + k_list[tensor_parallel_rank // repeat_num], + v_list[tensor_parallel_rank // repeat_num], ], axis=-1, ) @@ -264,8 +270,8 @@ def gqa_qkv_split_func( return np.concatenate( [ q_list[tensor_parallel_rank], - k_list[tensor_parallel_rank], - v_list[tensor_parallel_rank], + k_list[tensor_parallel_rank // repeat_num], + v_list[tensor_parallel_rank // repeat_num], ], axis=0, ) @@ -281,8 +287,8 @@ def gqa_qkv_merge_func(num_attention_heads, num_key_value_heads, head_dim): def fn(weight_list, is_column=True): """fn""" tensor_parallel_degree = len(weight_list) - num_attention_heads = num_attention_heads // tensor_parallel_degree - num_key_value_heads = num_key_value_heads // tensor_parallel_degree + num_attention_heads = num_attention_heads // tensor_parallel_degree # noqa: F823 + num_key_value_heads = num_key_value_heads // tensor_parallel_degree 
# noqa: F823 is_paddle_tensor = not isinstance(weight_list[0], np.ndarray) diff --git a/fastdeploy/model_executor/models/utils.py b/fastdeploy/model_executor/models/utils.py index 350f10651..c792c07d3 100644 --- a/fastdeploy/model_executor/models/utils.py +++ b/fastdeploy/model_executor/models/utils.py @@ -196,6 +196,9 @@ def convert_ndarray_dtype(np_array: np.ndarray, np.ndarray: converted numpy ndarray instance """ source_dtype = convert_dtype(np_array.dtype) + if source_dtype == "uint16" and target_dtype == "bfloat16" and paddle.is_compiled_with_custom_device( + "iluvatar_gpu"): + return np_array.view(dtype=target_dtype) if source_dtype == "uint16" or target_dtype == "bfloat16": if paddle.is_compiled_with_xpu(): # xpu not support bf16. diff --git a/fastdeploy/model_executor/ops/__init__.py b/fastdeploy/model_executor/ops/__init__.py index 6f3618eed..ebd011e95 100644 --- a/fastdeploy/model_executor/ops/__init__.py +++ b/fastdeploy/model_executor/ops/__init__.py @@ -16,5 +16,6 @@ from . import gpu from . import cpu from . import xpu from . import npu +from . import iluvatar -__all__ = ["gpu", "cpu", "xpu", "npu"] +__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar"] diff --git a/fastdeploy/model_executor/ops/iluvatar/__init__.py b/fastdeploy/model_executor/ops/iluvatar/__init__.py new file mode 100644 index 000000000..7c1eeb6f2 --- /dev/null +++ b/fastdeploy/model_executor/ops/iluvatar/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""fastdeploy gpu ops""" + +from fastdeploy.import_ops import import_custom_ops + +PACKAGE = "fastdeploy.model_executor.ops.iluvatar" + +import_custom_ops(PACKAGE, "..base.fastdeploy_base_ops", globals()) +import_custom_ops(PACKAGE, ".fastdeploy_ops", globals()) + +from .moe_ops import iluvatar_moe_expert_ffn as moe_expert_ffn # noqa: E402, F401 +from .paged_attention import paged_attention # noqa: E402, F401 diff --git a/fastdeploy/model_executor/ops/iluvatar/moe_ops.py b/fastdeploy/model_executor/ops/iluvatar/moe_ops.py new file mode 100644 index 000000000..327cffb6e --- /dev/null +++ b/fastdeploy/model_executor/ops/iluvatar/moe_ops.py @@ -0,0 +1,101 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from typing import Optional +import paddle +from paddle.nn.quant import weight_only_linear +from paddle.incubate.nn.functional import swiglu + + +def group_gemm( + input: paddle.Tensor, + tokens_expert_prefix_sum: paddle.Tensor, + weight: paddle.Tensor, + scale: paddle.Tensor, + output: paddle.Tensor, +): + assert (input.dim() == 2 and tokens_expert_prefix_sum.dim() == 1 + and weight.dim() == 3 and scale.dim() == 2 and output.dim() == 2) + num_tokens = input.shape[0] + dim_in = input.shape[1] + dim_out = weight.shape[1] + num_experts = weight.shape[0] + + # check shape + assert tokens_expert_prefix_sum.shape == [ + num_experts, + ] + assert weight.shape == [num_experts, dim_out, dim_in] + assert scale.shape == [num_experts, dim_out] + assert output.shape == [num_tokens, dim_out] + + # check dtype + assert input.dtype in (paddle.float16, paddle.bfloat16) + assert scale.dtype == input.dtype and output.dtype == input.dtype + assert tokens_expert_prefix_sum.dtype == paddle.int64 + assert weight.dtype == paddle.int8 + + # check others + assert tokens_expert_prefix_sum.place.is_cpu_place() + assert tokens_expert_prefix_sum[-1] == num_tokens + for i in range(num_experts): + expert_start = 0 if i == 0 else tokens_expert_prefix_sum[i - 1] + expert_end = tokens_expert_prefix_sum[i] + if expert_start == expert_end: + continue + input_i = input[expert_start:expert_end] + weight_i = weight[i] + scale_i = scale[i] + # avoid d2d? + output[expert_start:expert_end] = weight_only_linear( + input_i, + weight_i, + weight_scale=scale_i, + weight_dtype="int8", + group_size=-1) + + +def iluvatar_moe_expert_ffn( + permute_input: paddle.Tensor, + tokens_expert_prefix_sum: paddle.Tensor, + ffn1_weight: paddle.Tensor, + ffn2_weight: paddle.Tensor, + ffn1_bias: Optional[paddle.Tensor], + ffn1_scale: Optional[paddle.Tensor], + ffn2_scale: Optional[paddle.Tensor], + ffn2_in_scale: Optional[paddle.Tensor], + expert_idx_per_token: Optional[paddle.Tensor], + quant_method: str, + used_in_ep_low_latency: bool, +): + assert ffn1_bias is None + assert ffn1_scale is not None + assert ffn2_scale is not None + assert ffn2_in_scale is None + assert expert_idx_per_token is None + assert quant_method in ("weight_only_int8") + assert not used_in_ep_low_latency + tokens_expert_prefix_sum_cpu = tokens_expert_prefix_sum.to("cpu") + ffn1_output = paddle.empty([permute_input.shape[0], ffn1_weight.shape[1]], + dtype=permute_input.dtype) + group_gemm(permute_input, tokens_expert_prefix_sum_cpu, ffn1_weight, + ffn1_scale, ffn1_output) + act_out = swiglu(ffn1_output) + output = paddle.empty([act_out.shape[0], ffn2_weight.shape[1]], + dtype=act_out.dtype) + group_gemm(act_out, tokens_expert_prefix_sum_cpu, ffn2_weight, ffn2_scale, + output) + return output diff --git a/fastdeploy/model_executor/ops/iluvatar/paged_attention.py b/fastdeploy/model_executor/ops/iluvatar/paged_attention.py new file mode 100644 index 000000000..f52bfe672 --- /dev/null +++ b/fastdeploy/model_executor/ops/iluvatar/paged_attention.py @@ -0,0 +1,46 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import paddle +try: + from fastdeploy.model_executor.ops.iluvatar import paged_attn +except ImportError: + paged_attn = None + + +def paged_attention(q: paddle.Tensor, + k_cache: paddle.Tensor, + v_cache: paddle.Tensor, + block_tables: paddle.Tensor, + seq_lens: paddle.Tensor, + num_kv_heads: int, + scale: float, + block_size: int, + max_context_len: int, + alibi_slopes: paddle.Tensor = None, + causal: bool = True, + window_left: int = -1, + window_right: int = -1, + softcap: float = 0.0, + use_cuda_graph: bool = False, + use_sqrt_alibi: bool = False, + k: paddle.Tensor = None, + v: paddle.Tensor = None): + output = paged_attn(q, k_cache, v_cache, block_tables, seq_lens, + alibi_slopes, k, v, num_kv_heads, scale, block_size, + max_context_len, causal, window_left, window_right, + softcap, use_cuda_graph, use_sqrt_alibi) + return output[0] if isinstance(output, list) else output diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 8f3601ff6..526197f2a 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -19,18 +19,25 @@ import paddle from fastdeploy import envs from fastdeploy.engine.config import SpeculativeConfig -from fastdeploy.model_executor.ops.gpu import ( - get_padding_offset, save_output, set_stop_value_multi_ends, - speculate_clear_accept_nums, speculate_get_output_padding_offset, - speculate_get_padding_offset, speculate_get_seq_lens_output, - speculate_save_output, speculate_set_value_by_flags_and_idx, - speculate_step_paddle, speculate_step_system_cache, speculate_update_v3, - step_paddle, step_system_cache, update_inputs, step_reschedule) from fastdeploy.platforms import current_platform +if current_platform.is_iluvatar(): + from fastdeploy.model_executor.ops.iluvatar import ( + get_padding_offset, save_output, set_stop_value_multi_ends, + step_paddle, update_inputs) +else: + from fastdeploy.model_executor.ops.gpu import ( + get_padding_offset, save_output, set_stop_value_multi_ends, + speculate_clear_accept_nums, speculate_get_output_padding_offset, + speculate_get_padding_offset, speculate_get_seq_lens_output, + speculate_save_output, speculate_set_value_by_flags_and_idx, + speculate_step_paddle, speculate_step_system_cache, + speculate_update_v3, step_paddle, step_system_cache, update_inputs, + step_reschedule) from fastdeploy.worker.output import ModelOutputData DISABLE_RECOVER = (envs.FD_DISABLED_RECOVER == "1") + def pre_process( max_len: int, input_ids: paddle.Tensor, @@ -151,6 +158,7 @@ def post_process_normal(sampled_token_ids: paddle.Tensor, save_each_rank, # save_each_rank ) + def post_process_specualate(model_output, skip_save_output: bool = False): """""" speculate_update_v3( @@ -217,7 +225,6 @@ def step_cuda( TODO(gongshaotian): normalization name """ - if speculative_config.method is not None: if enable_prefix_caching: speculate_step_system_cache( @@ -373,6 +380,17 @@ def rebuild_padding(tmp_out: paddle.Tensor, output_padding_offset, max_input_length, ) + elif current_platform.is_iluvatar(): + from fastdeploy.model_executor.ops.iluvatar import rebuild_padding + hidden_states = rebuild_padding( + tmp_out, + cum_offsets, + seq_len_this_time, + seq_lens_decoder, + seq_lens_encoder, + output_padding_offset, + max_input_length, + ) elif current_platform.is_cpu(): from fastdeploy.model_executor.ops.cpu import 
rebuild_padding_cpu hidden_states = rebuild_padding_cpu( diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 7e6e951fa..136197f9c 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -122,9 +122,13 @@ class TokenProcessor(object): if current_platform.is_xpu(): from fastdeploy.model_executor.ops.xpu import get_output + elif current_platform.is_iluvatar(): + from fastdeploy.model_executor.ops.iluvatar import get_output else: - from fastdeploy.model_executor.ops.gpu import ( - get_output, get_output_ep, speculate_get_output) + from fastdeploy.model_executor.ops.gpu import (get_output, + get_output_ep, + speculate_get_output + ) rank_id = self.cfg.parallel_config.local_data_parallel_id while True: @@ -413,9 +417,12 @@ class WarmUpTokenProcessor(TokenProcessor): if current_platform.is_xpu(): from fastdeploy.model_executor.ops.xpu import get_output + elif current_platform.is_iluvatar(): + from fastdeploy.model_executor.ops.iluvatar import get_output else: - from fastdeploy.model_executor.ops.gpu import ( - get_output, speculate_get_output) + from fastdeploy.model_executor.ops.gpu import (get_output, + speculate_get_output + ) while self._is_running: try: diff --git a/fastdeploy/platforms/__init__.py b/fastdeploy/platforms/__init__.py index 94282a6ec..5fbbc0d89 100644 --- a/fastdeploy/platforms/__init__.py +++ b/fastdeploy/platforms/__init__.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ platform module """ @@ -22,7 +21,8 @@ from .cpu import CPUPlatform from .xpu import XPUPlatform from .npu import NPUPlatform from .dcu import DCUPlatform -from .base import _Backend +from .iluvatar import IluvatarPlatform +from .base import _Backend # noqa: F401 _current_platform = None @@ -40,10 +40,13 @@ def __getattr__(name: str): _current_platform = NPUPlatform() elif paddle.is_compiled_with_rocm(): _current_platform = DCUPlatform() + elif paddle.is_compiled_with_custom_device("iluvatar_gpu"): + _current_platform = IluvatarPlatform() else: _current_platform = CPUPlatform() return _current_platform elif name in globals(): return globals()[name] else: - raise AttributeError(f"No attribute named '{name}' exists in {__name__}.") + raise AttributeError( + f"No attribute named '{name}' exists in {__name__}.") diff --git a/fastdeploy/platforms/base.py b/fastdeploy/platforms/base.py index 9b0c86a99..6d93893fa 100644 --- a/fastdeploy/platforms/base.py +++ b/fastdeploy/platforms/base.py @@ -63,6 +63,12 @@ class Platform: """ return paddle.is_compiled_with_rocm() + def is_iluvatar(self) -> bool: + """ + whether platform is iluvatar gpu + """ + return paddle.is_compiled_with_custom_device("iluvatar_gpu") + @classmethod def get_attention_backend_cls(self, selected_backend): """Get the attention backend""" diff --git a/fastdeploy/platforms/iluvatar.py b/fastdeploy/platforms/iluvatar.py new file mode 100644 index 000000000..cd1892058 --- /dev/null +++ b/fastdeploy/platforms/iluvatar.py @@ -0,0 +1,26 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .base import Platform + + +class IluvatarPlatform(Platform): + device_name = "iluvatar_gpu" + + @classmethod + def get_attention_backend_cls(cls, selected_backend): + """ + get_attention_backend_cls + """ + return ( + "fastdeploy.model_executor.layers.attention.IluvatarAttnBackend") diff --git a/fastdeploy/worker/iluvatar_model_runner.py b/fastdeploy/worker/iluvatar_model_runner.py new file mode 100644 index 000000000..534853c72 --- /dev/null +++ b/fastdeploy/worker/iluvatar_model_runner.py @@ -0,0 +1,1142 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os +import time +from typing import List, Optional + +import numpy as np +import paddle +import paddle.nn as nn +from paddleformers.utils.log import logger + +from fastdeploy.config import FDConfig +from fastdeploy.engine.request import Request +from fastdeploy.model_executor.layers.attention import get_attention_backend +from fastdeploy.model_executor.layers.attention.base_attention_backend import \ + AttentionBackend +from fastdeploy.model_executor.layers.rotary_embedding import get_rope +from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata +from fastdeploy.model_executor.layers.sample.sampler import (Sampler, + SpeculativeSampler + ) +from fastdeploy.model_executor.model_loader import get_model_from_loader +from fastdeploy.model_executor.ops.iluvatar import set_value_by_flags_and_idx +from fastdeploy.model_executor.pre_and_post_process import (post_process, + pre_process, + rebuild_padding, + step_cuda) +from fastdeploy.worker.forward_meta import ForwardMeta +from fastdeploy.worker.model_runner_base import ModelRunnerBase +from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput + + +class IluvatarModelRunner(ModelRunnerBase): + """ """ + + def __init__( + self, + fd_config: FDConfig, + device: str, # logic device + device_id: int, # physical device id + rank: int, + local_rank: int): + super().__init__(fd_config=fd_config, device=device) + self.rank = rank + self.local_rank = local_rank + self.device_id = device_id + self.speculative_method = self.fd_config.speculative_config.method + self.speculative_decoding = self.speculative_method is not None + assert not self.speculative_decoding, "Iluvatar does not support yet" + + self.guided_backend = None + + # Sampler + if not self.speculative_decoding: + self.sampler = Sampler() + else: + self.sampler = SpeculativeSampler(fd_config) + + # Lazy initialize kv cache after model loading + # self.kv_caches: list[paddle.Tensor] = [] + + # Cuda Graph + 
self.use_cudagraph = self.graph_opt_config.use_cudagraph + self.cudagraph_capture_sizes = list( + reversed(self.graph_opt_config.cudagraph_capture_sizes)) + self.cudagraph_num_of_warmups = self.graph_opt_config.cudagraph_num_of_warmups + self.input_ids = paddle.zeros(self.parallel_config.max_num_seqs, + dtype='int32') + + # Initialize share inputs + self._init_share_inputs(self.parallel_config.max_num_seqs) + self.infer_seed_increment = paddle.full( + shape=[self.parallel_config.max_num_seqs, 1], + fill_value=4, + dtype="int64") + self.restore_chunked_prefill_request = dict() + + # Initialize attention Backend + # Note(gonshaotian): Currently, all attention layers share one attention backend instance. + # In the future, we will expand it as a list. + self.attn_backends: list[AttentionBackend] = [] + # self.attn_metadatas: list[AttentionMetadata] = [] + self.initialize_attn_backend() + + # Forward meta store the global meta information of the forward + self.forward_meta: ForwardMeta = None + + # Postprocess Env params + os.environ["INFERENCE_MSG_QUEUE_ID"] = str( + self.local_rank + + int(self.parallel_config.engine_worker_queue_port)) + + def prefill_finished(self): + """ + check whether prefill stage finished + """ + if int(paddle.max(self.share_inputs['seq_lens_encoder'])) != 0: + return 1 + else: + return 0 + + def _init_logits_processor(self, request): + """ + init logits processor for guided decoding + """ + assert self.guided_backend is not None, "guided_backend is None, use "\ + "--guided-decoding-backend to specify the backend at server startup." + + if request.guided_json is not None: + schemata_key = ("json", request.guided_json) + elif request.guided_regex is not None: + schemata_key = ("regex", request.guided_regex) + elif request.guided_grammar is not None: + schemata_key = ("grammar", request.guided_grammar) + elif request.structural_tag is not None: + schemata_key = ("structural_tag", request.structural_tag) + + return self.guided_backend.get_logits_processor( + schemata_key=schemata_key), schemata_key + + def insert_prefill_inputs(self, req_dicts: List[Request]): + """ + Process inputs for prefill tasks and insert it to share_inputs buffer + TODO(gongshaotian): Refactor this func + """ + # NOTE(luotingdan): Lazy initialize kv cache + if "caches" not in self.share_inputs: + self.initialize_kv_cache() + + # NOTE(luotingdan): Set environment variable of prefill node + if req_dicts[-1].disaggregate_info is not None and req_dicts[ + -1].disaggregate_info["role"] == "prefill": + os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1" + + req_len = len(req_dicts) + for i in range(req_len): + request = req_dicts[i] + idx = request.idx + length = len(request.prompt_token_ids) + + prefill_tokens = [] + if (request.guided_json is not None + or request.guided_regex is not None + or request.structural_tag is not None + or request.guided_grammar is not None): + logits_info, schemata_key = self._init_logits_processor( + request) + request.logits_processor, request.logits_cached = logits_info + request.schemata_key = schemata_key + + # Is Decode Node + if req_dicts[i].disaggregate_info is not None and req_dicts[ + i].disaggregate_info["role"] == "decode": + prefill_tokens.append(request.prompt_token_ids[0]) + self.share_inputs["pre_ids"][idx:idx + + 1] = request.prompt_token_ids[-1] + self.share_inputs["input_ids"][idx:idx + 1, + 0] = request.prompt_token_ids[0] + self.share_inputs['seq_lens_encoder'][idx:idx + 1] = 0 + self.share_inputs['seq_lens_decoder'][idx:idx + 1] = length + 
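+                # NOTE: for a request whose disaggregate role is "decode", the
+                # prompt is not encoded again on this runner: seq_lens_encoder
+                # stays 0, seq_lens_decoder records the prompt length, and
+                # seq_lens_this_time is set to 1 below (one token per step).
+                # The ordinary prefill branch further below instead sets both
+                # seq_lens_encoder and seq_lens_this_time to the prompt length.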
self.share_inputs['seq_lens_this_time'][idx:idx + 1] = 1 + self.share_inputs['step_seq_lens_encoder'][idx:idx + 1] = 0 + self.share_inputs['step_seq_lens_decoder'][idx:idx + + 1] = length + self.share_inputs['step_idx'][idx:idx + 1] = 1 + + if self.speculative_decoding: + num_prefill_send_token = self.speculative_config.num_speculative_tokens + 1 + self.share_inputs['draft_tokens'][idx:idx + 1, 0:num_prefill_send_token] =\ + paddle.to_tensor(request.draft_token_ids[0:num_prefill_send_token], dtype="int64") + self.share_inputs['seq_lens_this_time'][ + idx:idx + 1] = num_prefill_send_token + else: + self.share_inputs["pre_ids"][idx:idx + 1] = -1 + self.share_inputs["step_idx"][idx:idx + 1] = 0 + self.share_inputs["input_ids"][idx:idx + + 1, :length] = np.array( + request.prompt_token_ids) + + # Use chunked prefill + if self.parallel_config.enable_chunked_prefill: + request.set("chunk_idx", 1) + logger.info( + f"prefill_chunk_info: {request.prefill_chunk_info}") + token_chunk_size = request.prefill_chunk_info[0] + self.share_inputs["seq_lens_this_time"][ + idx:idx + 1] = token_chunk_size + self.share_inputs['input_ids'][ + idx, :token_chunk_size] = np.array( + request.prompt_token_ids[:token_chunk_size]) + self.share_inputs['step_seq_lens_encoder'][ + idx:idx + 1] = token_chunk_size + self.share_inputs['seq_lens_encoder'][idx:idx + + 1] = token_chunk_size + self.share_inputs['seq_lens_decoder'][ + idx:idx + 1] = request.get("seq_lens_decoder", 0) + self.share_inputs['step_seq_lens_decoder'][ + idx:idx + 1] = request.get("seq_lens_decoder", 0) + else: + self.share_inputs['seq_lens_decoder'][ + idx:idx + 1] = request.get("seq_lens_decoder", 0) + self.share_inputs['step_seq_lens_decoder'][ + idx:idx + 1] = request.get("seq_lens_decoder", 0) + self.share_inputs['seq_lens_this_time'][idx:idx + + 1] = length + self.share_inputs['step_seq_lens_encoder'][idx:idx + + 1] = length + self.share_inputs['seq_lens_encoder'][idx:idx + 1] = length + + if len(request.eos_token_ids + ) < self.parallel_config.eos_tokens_lens: + request.eos_token_ids.append(request.eos_token_ids[0]) + self.share_inputs["eos_token_id"][:] = np.array( + request.eos_token_ids, dtype="int64").reshape(-1, 1) + + self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7) + self.share_inputs["temperature"][idx:idx + 1] = request.get( + "temperature", 0.95) + self.share_inputs["penalty_score"][idx:idx + 1] = request.get( + "repetition_penalty", 1.0) + self.share_inputs["frequency_score"][idx:idx + 1] = request.get( + "frequency_penalty", 0.0) + self.share_inputs["presence_score"][idx:idx + 1] = request.get( + "presence_penalty", 0.0) + + self.share_inputs["min_dec_len"][idx:idx + 1] = request.get( + "min_tokens", 1) + self.share_inputs["max_dec_len"][idx:idx + 1] = request.get( + "max_tokens", self.model_config.max_length) + self.share_inputs["stop_flags"][idx:idx + 1] = False + + self.share_inputs["first_token_ids"][ + idx:idx + 1] = self.share_inputs["input_ids"][idx:idx + 1, :1] + self.share_inputs["ori_seq_lens_encoder"][idx:idx + 1] = length + + if request.get("seed") is not None: + self.share_inputs["infer_seed"][idx:idx + + 1] = request.get("seed") + encoder_block_num = len(request.get("block_tables")) + self.share_inputs["encoder_block_lens"][idx:idx + + 1] = encoder_block_num + self.share_inputs["block_tables"][idx:idx + 1, :] = -1 + self.share_inputs["block_tables"][ + idx:idx + 1, :encoder_block_num] = np.array( + request.block_tables, dtype="int32") + + if request.get("stop_token_ids") is not None and request.get( + 
"stop_seqs_len") is not None: + stop_seqs_num = len(request.get("stop_seqs_len")) + for i in range(stop_seqs_num, + self.model_config.max_stop_seqs_num): + request.stop_seqs_len.append(0) + self.share_inputs["stop_seqs_len"][:] = np.array( + request.stop_seqs_len, dtype="int32") + self.share_inputs["stop_seqs"][:stop_seqs_num, :len( + request.get("stop_token_ids")[0])] = np.array( + request.get("stop_token_ids"), dtype="int64") + + self.sampler.apply_logits_processor( + idx, request.get("logits_processor"), prefill_tokens) + + self.share_inputs["not_need_stop"][0] = True + + def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, + expected_decode_len: int): + """ Set dummy prefill inputs to share_inputs """ + # NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token + max_dec_len = expected_decode_len + 1 + full_length = min(num_tokens // batch_size, + self.parallel_config.max_model_len - max_dec_len) + input_length = int(full_length * self.parallel_config.kv_cache_ratio) + block_num = ( + input_length + self.parallel_config.block_size - 1 + ) // self.parallel_config.block_size + self.parallel_config.enc_dec_block_num + + for i in range(batch_size): + idx = i + self.share_inputs["input_ids"][idx:idx + + 1, :input_length] = np.array( + [5] * input_length) + self.share_inputs["eos_token_id"][:] = np.array( + [2], dtype="int64").reshape(-1, 1) + self.share_inputs["seq_lens_this_time"][idx:idx + 1] = input_length + self.share_inputs["step_seq_lens_encoder"][idx:idx + + 1] = input_length + self.share_inputs["seq_lens_encoder"][idx:idx + 1] = input_length + self.share_inputs["seq_lens_decoder"][idx:idx + 1] = 0 + self.share_inputs["step_idx"][idx:idx + 1] = 0 + self.share_inputs["max_dec_len"][idx:idx + 1] = max_dec_len + self.share_inputs["stop_flags"][idx:idx + 1] = False + + self.share_inputs["first_token_ids"][ + idx:idx + 1] = self.share_inputs["input_ids"][idx:idx + 1, :1] + self.share_inputs["ori_seq_lens_encoder"][idx:idx + + 1] = input_length + + self.share_inputs["encoder_block_lens"][idx:idx + 1] = block_num + self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(idx * block_num, \ + (idx + 1) * block_num, 1) + + def _init_share_inputs(self, max_num_seqs: int): + """Initialize all share buffers for model inputs. + Note: In the future, we may abandon share buffers. 
+ """ + self.MAX_INFER_SEED = 9223372036854775806 + self.share_inputs = {} + + self.share_inputs["pre_ids"] = paddle.full( + [max_num_seqs, self.parallel_config.max_model_len], + -1, + dtype='int64') + self.share_inputs["input_ids"] = paddle.full( + [max_num_seqs, self.parallel_config.max_model_len], + self.parallel_config.pad_token_id, + dtype='int64') + self.share_inputs["eos_token_id"] = paddle.full( + [self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64') + self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], + self.model_config.top_p, + dtype='float32') + self.share_inputs["temperature"] = paddle.full( + [max_num_seqs, 1], self.model_config.temperature, dtype='float32') + self.share_inputs["penalty_score"] = paddle.full( + [max_num_seqs, 1], + self.model_config.penalty_score, + dtype='float32') + self.share_inputs["frequency_score"] = paddle.full( + [max_num_seqs, 1], + self.model_config.frequency_score, + dtype='float32') + self.share_inputs["presence_score"] = paddle.full( + [max_num_seqs, 1], + self.model_config.presence_score, + dtype='float32') + + self.share_inputs["min_dec_len"] = paddle.full( + [max_num_seqs, 1], self.model_config.min_length, dtype='int64') + self.share_inputs["max_dec_len"] = paddle.full( + [max_num_seqs, 1], self.model_config.max_length, dtype='int64') + self.share_inputs["min_length"] = paddle.full( + [max_num_seqs, 1], self.model_config.min_length, dtype='int64') + self.share_inputs["max_length"] = paddle.full( + [max_num_seqs, 1], self.model_config.max_length, dtype='int64') + self.share_inputs["seq_lens_this_time"] = paddle.full(max_num_seqs, + 0, + dtype='int32') + self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int32') + self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int32') + self.share_inputs["step_seq_lens_encoder"] = paddle.full( + [max_num_seqs, 1], 0, dtype='int32') + self.share_inputs["step_seq_lens_decoder"] = paddle.full( + [max_num_seqs, 1], 0, dtype='int32') + self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int64') + self.share_inputs["not_need_stop"] = paddle.full( + [1], False, + dtype='bool').cpu() # TODO(gongshaotian): move to pinnd memory + self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], + True, + dtype='bool') + self.share_inputs["stop_nums"] = paddle.full([1], + max_num_seqs, + dtype='int64') + + self.share_inputs["bad_tokens"] = paddle.full([1], -1, dtype='int64') + self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], + -1, + dtype='int64') + self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], + False, + dtype='bool') + self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], + 0, + dtype='int32') + self.share_inputs["step_block_list"] = paddle.full([max_num_seqs], + -1, + dtype='int32') + self.share_inputs["step_lens"] = paddle.full([1], 0, dtype='int32') + self.share_inputs["recover_block_list"] = paddle.full([max_num_seqs], + -1, + dtype='int32') + self.share_inputs["recover_lens"] = paddle.full([1], 0, dtype='int32') + self.share_inputs["need_block_list"] = paddle.full([max_num_seqs], + -1, + dtype='int32') + self.share_inputs["need_block_len"] = paddle.full([1], + 0, + dtype='int32') + self.share_inputs["used_list_len"] = paddle.full([max_num_seqs], + 0, + dtype='int32') + self.share_inputs["infer_seed"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int64') + self.share_inputs["first_token_ids"] = paddle.full([max_num_seqs, 1], + -1, + dtype='int64') + 
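The MAX_INFER_SEED constant above caps the per-slot sampling seed; each step the runner adds infer_seed_increment (filled with 4) and wraps the result, as in step 7 of execute_model further below. A standalone sketch of that update, with an illustrative batch size:

    import paddle

    MAX_INFER_SEED = 9223372036854775806
    max_num_seqs = 4

    infer_seed = paddle.full([max_num_seqs, 1], 0, dtype="int64")
    infer_seed_increment = paddle.full([max_num_seqs, 1], 4, dtype="int64")

    # One post-sampling update: bump every slot's seed in place and wrap it.
    infer_seed.add_(infer_seed_increment)
    infer_seed[:] %= MAX_INFER_SEED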
self.share_inputs["ori_seq_lens_encoder"] = paddle.full( + [max_num_seqs, 1], 0, dtype='int32') + self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int32') + self.share_inputs["system_ids"] = paddle.full([max_num_seqs, 1], + -1, + dtype='int32') + + self.share_inputs["ids_remove_padding"] = paddle.full( + [max_num_seqs * self.parallel_config.max_model_len], + 0, + dtype='int64') + self.share_inputs["cum_offsets"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int32') + self.share_inputs["padding_offset"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int32') + self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int32') + self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int32') + # AttentionBackend buffers + self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int32') + self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full( + [max_num_seqs, 1], 0, dtype='int32') + + # Initialize rotary position embedding + tmp_position_ids = paddle.arange( + self.parallel_config.max_model_len).reshape((1, -1)) + # TODO(gongshaotian): move to models + self.share_inputs["rope_emb"] = get_rope( + rotary_dim=self.model_config.head_dim, + position_ids=tmp_position_ids, + base=self.model_config.rope_theta, + model_config=self.model_config) + + # Set block tables + pre_max_block_num = ( + self.parallel_config.max_model_len + + self.parallel_config.block_size - 1 + ) // self.parallel_config.block_size + self.parallel_config.enc_dec_block_num + self.share_inputs["block_tables"] = paddle.full( + [max_num_seqs, pre_max_block_num], -1, dtype='int32') + + # Initialize free list + free_list = list( + range( + self.parallel_config.max_block_num - 1, + int(self.parallel_config.max_block_num * + self.parallel_config.kv_cache_ratio) - 1, -1)) + self.free_list_len = len(free_list) + self.share_inputs["free_list"] = paddle.to_tensor(free_list, + dtype="int32") + self.share_inputs["free_list_len"] = paddle.full([1], + self.free_list_len, + dtype="int32") + + # Initialize stop seqs + self.share_inputs["stop_seqs_len"] = paddle.full( + [self.model_config.max_stop_seqs_num], 0, dtype="int32") + self.share_inputs["stop_seqs"] = paddle.full([ + self.model_config.max_stop_seqs_num, + self.model_config.stop_seqs_max_len + ], + -1, + dtype="int32") + if self.speculative_decoding: + max_draft_token_num = self.speculative_config.num_speculative_tokens + self.share_inputs["input_ids_cpu"] = paddle.full( + shape=[max_num_seqs, self.parallel_config.max_model_len], + fill_value=1, + dtype='int64').cpu() + self.share_inputs['accept_tokens'] = paddle.full( + shape=[max_num_seqs, max_draft_token_num + 1], + fill_value=0, + dtype="int64") + self.share_inputs['accept_num'] = paddle.full(shape=[max_num_seqs], + fill_value=0, + dtype='int32') + self.share_inputs['draft_tokens'] = paddle.full( + shape=[max_num_seqs, max_draft_token_num + 1], + fill_value=0, + dtype="int64") + + self.share_inputs['actual_draft_token_num'] = paddle.full( + shape=[max_num_seqs], + fill_value=max_draft_token_num, + dtype="int32") + self.share_inputs["output_cum_offsets"] = paddle.full( + shape=[max_num_seqs, 1], fill_value=0, dtype='int32') + self.share_inputs["output_padding_offset"] = paddle.full( + shape=[max_num_seqs * (max_draft_token_num + 1)], + fill_value=0, + dtype="int32") + + def _prepare_inputs(self) -> None: + """ prepare the model inputs """ + # Remove padding + ( + ids_remove_padding, + cum_offsets, + padding_offset, + 
cu_seqlens_q, + cu_seqlens_k, + output_cum_offsets, + output_padding_offset, + ) = pre_process( + self.parallel_config.max_model_len, self.share_inputs["input_ids"], + self.share_inputs["seq_lens_this_time"], self.speculative_decoding, + self.share_inputs["draft_tokens"] if self.speculative_decoding else + None, self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"]) + cu_seqlens_k = paddle.concat([ + paddle.to_tensor([0], dtype=paddle.int32), + paddle.cumsum(self.share_inputs["seq_lens_this_time"] + + self.share_inputs["seq_lens_decoder"][:, 0]) + ]) + + self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, + False) + self.share_inputs["cum_offsets"].copy_(cum_offsets, False) + self.share_inputs["padding_offset"].copy_(padding_offset, False) + self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) + self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) + + # For speculative decoding + if self.speculative_decoding: + self.share_inputs["output_cum_offsets"].copy_( + output_cum_offsets, False) + self.share_inputs["output_padding_offset"].copy_( + output_padding_offset, False) + + # Initialize forward meta data + self.initialize_forward_meta() + + # Get sampling metadata + self.sampling_metadata = SamplingMetadata( + temperature=self.share_inputs["temperature"], + top_p=self.share_inputs["top_p"], + step_idx=self.share_inputs["step_idx"], + pre_token_ids=self.share_inputs["pre_ids"], + frequency_penalties=self.share_inputs["frequency_score"], + presence_penalties=self.share_inputs["presence_score"], + repetition_penalties=self.share_inputs["penalty_score"], + min_dec_lens=self.share_inputs["min_dec_len"], + bad_words_token_ids=self.share_inputs["bad_tokens"], + eos_token_ids=self.share_inputs["eos_token_id"], + ) + + def load_model(self) -> None: + """ load or download model """ + logger.info( + f"Starting to load model {self.model_config.architectures[0]}") + time_before_load = time.perf_counter() + # 1. Load original model + self.model = get_model_from_loader(fd_config=self.fd_config) + + # 2. Load lora model + + # 3. 
Load drafter model(for speculative decoding) + + time_after_load = time.perf_counter() + logger.info( + f"Model loading took {time_after_load - time_before_load} seconds") + + def get_model(self) -> nn.Layer: + """ get current model """ + return self.model + + def initialize_forward_meta(self): + """ + Initialize forward meta and attention meta data + """ + # Initialize forward meta + self.forward_meta = ForwardMeta.init_forward_meta( + self.share_inputs, self.attn_backends[0]) + + # Initialzie attention meta data + for attn_backend in self.attn_backends: + attn_backend.init_attention_metadata(self.forward_meta) + + def clear_cache(self): + """Clear cached data from shared inputs and forward metadata.""" + self.share_inputs.pop("caches", None) + if self.forward_meta is not None: + self.forward_meta.clear_caches() + + def initialize_kv_cache(self) -> None: + """ + Initialize kv cache + """ + cache_kvs = {} + max_block_num = self.num_gpu_blocks + + # Get kv cache dtype + cache_type = self.parallel_config.dtype + + if (self.quant_config + and hasattr(self.quant_config, "kv_cache_quant_type") + and self.quant_config.kv_cache_quant_type is not None): + cache_type = 'uint8' + + # Get kv cache shape + kv_cache_shape = self.attn_backends[0].get_kv_cache_shape( + max_num_blocks=max_block_num) + + if not self.parallel_config.do_profile and ( + self.parallel_config.enable_prefix_caching \ + or self.parallel_config.splitwise_role != "mixed"): + raise NotImplementedError("Iluvatar does not support yet") + else: + for i in range(self.model_config.num_layers): + + cache_kvs["key_caches_{}".format(i)] = paddle.full( + shape=kv_cache_shape, + fill_value=0, + dtype=cache_type, + ) + cache_kvs["value_caches_{}".format(i)] = paddle.full( + shape=kv_cache_shape, + fill_value=0, + dtype=cache_type, + ) + self.share_inputs["caches"] = list(cache_kvs.values()) + for value in cache_kvs.values(): + del value + paddle.device.cuda.empty_cache() + + def initialize_attn_backend(self) -> None: + """ + Initialize attention backends and forward metadata + """ + assert len(self.attn_backends) == 0 + + # TODO(gongshaotian): Get rank from config + num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_degree + self.model_config.kv_num_heads = max( + 1, + int(self.model_config.num_key_value_heads) // + self.parallel_config.tensor_parallel_degree) + head_dim = self.model_config.head_dim + + # Get the attention backend + attn_cls = get_attention_backend() + attn_backend = attn_cls(self.fd_config, + kv_num_heads=self.model_config.kv_num_heads, + num_heads=num_heads, + head_dim=head_dim) + if attn_backend is None: + raise NotImplementedError( + "Attention backend which you chose is not support by GPUModelRunner" + ) + self.attn_backends.append(attn_backend) + + def _dummy_run(self, + num_tokens: paddle.Tensor, + batch_size: paddle.Tensor, + expected_decode_len: int = 1, + in_capturing: bool = False) -> paddle.Tensor: + """ + Use dummy inputs to run before formal execution. + Args: + num_tokens: + expected_decode_len: Expected number of tokens generated + """ + self._dummy_prefill_inputs(num_tokens=num_tokens, + batch_size=batch_size, + expected_decode_len=expected_decode_len) + while True: + + # 1. Compute real num_tokens + self._prepare_inputs() + + # 2. Initialize attention backend and forward meta data + + # 3. Prepare lora + + # 4. 
Run model + is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] + > 1).sum() > 0) + self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing + self.forward_meta.is_decode_batch = is_decode_batch + model_output = self.model( + ids_remove_padding=self.share_inputs["ids_remove_padding"], + forward_meta=self.forward_meta) + + hiddden_states = rebuild_padding( + model_output, + self.share_inputs["cum_offsets"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["seq_lens_encoder"], + None, # speculative decoding requires + self.parallel_config.max_model_len, + ) + + # 5. Execute spec decode + logits = self.model.compute_logits(hiddden_states) + + if not self.speculative_decoding: + set_value_by_flags_and_idx( + self.share_inputs["pre_ids"], + self.share_inputs["input_ids"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["step_idx"], + self.share_inputs["stop_flags"], + ) + sampled_token_ids = self.sampler(logits, + self.sampling_metadata) + if self.parallel_config.tensor_parallel_degree > 1: + paddle.distributed.broadcast(sampled_token_ids, 0) + else: + self.sampler(logits, self.sampling_metadata, + self.parallel_config.max_model_len, + self.share_inputs) + sampled_token_ids = None + if self.parallel_config.tensor_parallel_degree > 1: + paddle.distributed.broadcast( + self.share_inputs["accept_tokens"], 0) + paddle.distributed.broadcast( + self.share_inputs["accept_num"], 0) + paddle.distributed.broadcast(self.share_inputs["step_idx"], + 0) + paddle.distributed.broadcast( + self.share_inputs["stop_flags"], 0) + + # 6. post process + model_output_data = ModelOutputData( + next_tokens=self.share_inputs["next_tokens"], + stop_flags=self.share_inputs["stop_flags"], + step_idx=self.share_inputs["step_idx"], + max_dec_len=self.share_inputs["max_dec_len"], + pre_ids=self.share_inputs["pre_ids"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + eos_token_id=self.share_inputs["eos_token_id"], + not_need_stop=self.share_inputs["not_need_stop"], + input_ids=self.share_inputs["input_ids"], + stop_nums=self.share_inputs["stop_nums"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + is_block_step=self.share_inputs["is_block_step"], + full_hidden_states=model_output, + msg_queue_id=self.parallel_config.msg_queue_id, + mp_rank=self.local_rank, + use_ep=self.parallel_config.use_ep, + draft_tokens=self.share_inputs["draft_tokens"] + if self.speculative_decoding else None, + actual_draft_token_num=self. + share_inputs["actual_draft_token_num"] + if self.speculative_decoding else None, + accept_tokens=self.share_inputs["accept_tokens"] + if self.speculative_decoding else None, + accept_num=self.share_inputs["accept_num"] + if self.speculative_decoding else None) + + post_process(sampled_token_ids=sampled_token_ids, + model_output=model_output_data, + speculative_decoding=self.speculative_decoding, + skip_save_output=True) + + # 7. 
Update 'infer_seed' and step_cuda()
+            self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
+            self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
+            step_cuda(self.share_inputs, self.parallel_config.block_size,
+                      self.parallel_config.enc_dec_block_num,
+                      self.speculative_config,
+                      self.parallel_config.enable_prefix_caching)
+
+            if int((self.share_inputs['seq_lens_this_time'] > 0).sum()) == 0:
+                break
+
+    def _update_chunked_prefill(self, tasks):
+        """
+        Update chunked prefill related parameters
+        """
+        if not self.parallel_config.enable_chunked_prefill:
+            return
+
+        for task in tasks:
+            if task.get("prefill_chunk_info", None) is None:
+                continue
+
+            if task.chunk_idx > len(task.prefill_chunk_info):
+                continue
+            self.restore_chunked_prefill_request[task.request_id] = task
+
+        for id, task in list(self.restore_chunked_prefill_request.items()):
+            idx = task.idx
+            logger.debug(
+                f"{task.request_id} chunked prefill {task.chunk_idx}/{len(task.prefill_chunk_info)}"
+            )
+            start_idx = sum(task.prefill_chunk_info[:task.chunk_idx])
+            if task.chunk_idx == len(task.prefill_chunk_info):
+                self.share_inputs["seq_lens_this_time"][idx:idx + 1] = 1
+                self.share_inputs['seq_lens_encoder'][idx:idx + 1] = 0
+                self.share_inputs["step_idx"][idx:idx + 1] = 1
+                self.share_inputs["seq_lens_decoder"][
+                    idx:idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
+                del self.restore_chunked_prefill_request[task.request_id]
+            else:
+                token_chunk_size = task.prefill_chunk_info[task.chunk_idx]
+
+                self.share_inputs["seq_lens_this_time"][idx:idx +
+                                                        1] = token_chunk_size
+                self.share_inputs['input_ids'][
+                    idx, :token_chunk_size] = np.array(
+                        task.prompt_token_ids[start_idx:start_idx +
+                                              token_chunk_size])
+                self.share_inputs['seq_lens_encoder'][idx:idx +
+                                                      1] = token_chunk_size
+                self.share_inputs["step_idx"][idx:idx + 1] = 0
+                self.share_inputs["seq_lens_decoder"][
+                    idx:idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
+            task.chunk_idx += 1
+
+    def _dummy_sampler_run(self) -> paddle.Tensor:
+        """ """
+        pass
+
+    def capture_model(self) -> None:
+        """
+        Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes'
+        """
+        if not self.use_cudagraph:
+            logger.info(
+                "Skipping CUDA graph capture. Please check GraphOptimizationConfig"
+            )
+            return
+        time_before_capture = time.perf_counter()
+        expected_decode_len = 1
+        capture_sizes = self.cudagraph_capture_sizes.copy()
+        for batch_size in sorted(capture_sizes, reverse=True):
+            self._dummy_run(num_tokens=self.parallel_config.max_model_len,
+                            batch_size=batch_size,
+                            in_capturing=True,
+                            expected_decode_len=expected_decode_len)
+            logger.info(
+                f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}"
+            )
+
+        time_after_capture = time.perf_counter()
+        logger.info(
+            f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds"
+        )
+
+    def _get_skip_idx(self, model_forward_batch):
+        """
+        Get the index of the request that needs to be skipped during execution.
+        Args:
+            model_forward_batch: A list of requests to be executed by this runner.
+        Returns:
+            A list of indices corresponding to the requests that need to be skipped.
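+
+        Example (illustrative values, assuming chunked prefill and a guided
+        backend are enabled): a request with prefill_chunk_info == [512, 512, 256]
+        and chunk_idx == 1 is still mid-prefill, so its idx is added to the
+        returned skip_idx_list; once chunk_idx reaches 3 it is no longer skipped.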
+ """ + skip_idx_list = [] + if not self.parallel_config.enable_chunked_prefill or self.guided_backend is None: + return skip_idx_list + + for task in model_forward_batch: + if task.get("prefill_chunk_info", + None) is None or task.chunk_idx >= len( + task.prefill_chunk_info): + continue + skip_idx_list.append(task.idx) + + for task in self.restore_chunked_prefill_request.values(): + if task.idx in skip_idx_list or task.chunk_idx >= len( + task.prefill_chunk_info): + continue + skip_idx_list.append(task.idx) + + return skip_idx_list + + def execute_model( + self, + model_forward_batch: Optional[List[Request]] = None, + ) -> Optional[ModelRunnerOutput]: + """ + The Entrance of model execute. + Args: + model_forward_batch: 'Request' contains information related to prompt and is an abstract + class at the server level, which is too granular for ModelRunner. + We plan to replace it with 'ModelForwardBatch'. + intermediate_tensors: + """ + # Note(@wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state. + # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode, + # when there is data on other runner, the current runner is required to execute part of the model. + if not self.not_need_stop(): + self._execute_empty_input() + return None + + # 1. Prepare inputs of model and decoder. + # sampler create async operation + skip_idx_list = self._get_skip_idx(model_forward_batch) + self._prepare_inputs() + self.sampler.pre_process(skip_idx_list) + + # 2. Padding inputs for cuda grph + + # 3. Execute model + # TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch + is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] + > 1).sum() > 0) + self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch + self.forward_meta.is_decode_batch = is_decode_batch + model_output = self.model( + ids_remove_padding=self.share_inputs["ids_remove_padding"], + forward_meta=self.forward_meta) + + hiddden_states = rebuild_padding( + model_output, + self.share_inputs["cum_offsets"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["output_padding_offset"] + if self.speculative_decoding else None, + self.parallel_config.max_model_len, + ) + + # 4. Compute logits, Sample + logits = self.model.compute_logits(hiddden_states) + + if not self.speculative_decoding: + set_value_by_flags_and_idx( + self.share_inputs["pre_ids"], + self.share_inputs["input_ids"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["step_idx"], + self.share_inputs["stop_flags"], + ) + sampled_token_ids = self.sampler( + logits, + self.sampling_metadata, + skip_idx_list, + ) + if self.parallel_config.tensor_parallel_degree > 1: + paddle.distributed.broadcast(sampled_token_ids, 0) + + else: + self.sampler(logits, self.sampling_metadata, + self.parallel_config.max_model_len, self.share_inputs) + sampled_token_ids = None + if self.parallel_config.tensor_parallel_degree > 1: + paddle.distributed.broadcast( + self.share_inputs["accept_tokens"], 0) + paddle.distributed.broadcast(self.share_inputs["accept_num"], + 0) + paddle.distributed.broadcast(self.share_inputs["step_idx"], 0) + paddle.distributed.broadcast(self.share_inputs["stop_flags"], + 0) + + # 5. 
Post Process + model_output_data = ModelOutputData( + next_tokens=self.share_inputs["next_tokens"], + stop_flags=self.share_inputs["stop_flags"], + step_idx=self.share_inputs["step_idx"], + max_dec_len=self.share_inputs["max_dec_len"], + pre_ids=self.share_inputs["pre_ids"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + eos_token_id=self.share_inputs["eos_token_id"], + not_need_stop=self.share_inputs["not_need_stop"], + input_ids=self.share_inputs["input_ids"], + stop_nums=self.share_inputs["stop_nums"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + is_block_step=self.share_inputs["is_block_step"], + full_hidden_states=model_output, + msg_queue_id=self.parallel_config.msg_queue_id, + mp_rank=self.local_rank, + use_ep=self.parallel_config.use_ep, + draft_tokens=self.share_inputs["draft_tokens"] + if self.speculative_decoding else None, + actual_draft_token_num=self.share_inputs["actual_draft_token_num"] + if self.speculative_decoding else None, + accept_tokens=self.share_inputs["accept_tokens"] + if self.speculative_decoding else None, + accept_num=self.share_inputs["accept_num"] + if self.speculative_decoding else None) + + if self.speculative_config.method in ["mtp"] and \ + self.parallel_config.splitwise_role == "prefill": + skip_save_output = True + else: + skip_save_output = False + post_process(sampled_token_ids=sampled_token_ids, + model_output=model_output_data, + save_each_rank=self.parallel_config.use_ep, + speculative_decoding=self.speculative_decoding, + skip_save_output=skip_save_output) + + # 7. Updata 'infer_seed' and step_cuda() + self.share_inputs["infer_seed"].add_(self.infer_seed_increment) + self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED + step_cuda( + self.share_inputs, + self.parallel_config.block_size, + self.parallel_config.enc_dec_block_num, + self.speculative_config, + self.parallel_config.enable_prefix_caching, + ) + + self._update_chunked_prefill(model_forward_batch) + self._add_cache(model_forward_batch) + return None + + def _add_cache(self, model_forward_batch) -> None: + """ + Add cache for guided decoding. + """ + if self.guided_backend is None: + return + + for request in model_forward_batch: + logits_cached = request.get("logits_cached", None) + if logits_cached is None or logits_cached: + continue + + raise NotImplementedError("Iluvatar does not support yet") + + def _execute_empty_input(self) -> None: + """ + In certain scenarios, such as during EP, + the runner needs to execute partial modules of the model without input data. + This requires the model to implement the `empty_input_forward` method. + """ + if hasattr(self.model, "empty_input_forward"): + self.model.empty_input_forward() + else: + raise ValueError( + f"{type(self.model)} has no attribute 'empty_input_forward") + + def profile_run(self) -> None: + """Execute a forward pass with dummy inputs to profile the memory usage of the model.""" + + # Initialize kv cache for profile run. After profile run kv cache will be reset. + # TODO(gongshaotian): Optimize the management logic of kvcache + self.num_gpu_blocks = self.parallel_config.max_block_num + self.initialize_kv_cache() + + # 1. Profile with multimodal encoder & encoder cache + + # 2. Dummy run + self._dummy_run(num_tokens=self.parallel_config.max_num_batched_tokens, + batch_size=min(self.parallel_config.max_num_seqs, 3)) + + # 3. 
gc + self.clear_cache() + + # paddle.device.cuda.synchronize() + + def update_share_input_block_num(self, num_gpu_blocks: int) -> None: + """ + Set a globally unified block number and update the model's shared input. + Args: + num_gpu_blocks: + """ + self.num_gpu_blocks = num_gpu_blocks + + # Reset block table and kv cache with global block num + if not (self.parallel_config.enable_prefix_caching \ + or self.parallel_config.splitwise_role != "mixed"): + self.initialize_kv_cache() + + # Reset free list + free_list = list( + range( + self.num_gpu_blocks - 1, + int(self.num_gpu_blocks * self.parallel_config.kv_cache_ratio) + - 1, -1)) + self.free_list_len = len(free_list) + self.share_inputs.update({ + "free_list": + paddle.to_tensor(free_list, dtype="int32"), + "free_list_len": + paddle.full([1], self.free_list_len, dtype="int32"), + }) + + self.parallel_config.do_profile = False + + def cal_theortical_kvcache(self): + """ + Calculate the total block memory required at the model level + TODO(gongshaotian): Move to Attention Backend + """ + """ + Byte of dtype: + - default(bf16): 2 + - cache_int8: 1 + - cache_int4: + """ + cache_quant_dtype = None + if (self.quant_config + and hasattr(self.quant_config, "kv_cache_quant_type") + and self.quant_config.kv_cache_quant_type is not None): + cache_quant_dtype = self.quant_config.kv_cache_quant_type + + if cache_quant_dtype is not None: # int8, int8_zp, fp8, fp8_zp + byte_of_dtype = 1 + else: # default + byte_of_dtype = 2 + + hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads + # NOTE(liuzichang): Implement multi-layer MTP architecture in the future + num_layers = self.model_config.num_layers + \ + self.speculative_config.num_gpu_block_expand_ratio if \ + self.speculative_method in [ + "mtp" + ] else self.model_config.num_layers + required_memory = ( + byte_of_dtype * 2 * # k + v + (self.parallel_config.block_size * hidden_dim) * num_layers) + return required_memory + + def not_need_stop(self) -> bool: + """ """ + return self.share_inputs["not_need_stop"][0] diff --git a/fastdeploy/worker/iluvatar_worker.py b/fastdeploy/worker/iluvatar_worker.py new file mode 100644 index 000000000..590e7e662 --- /dev/null +++ b/fastdeploy/worker/iluvatar_worker.py @@ -0,0 +1,143 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +import gc +import os +from typing import List, Optional + +import paddle +import paddle.nn as nn + +from fastdeploy.config import FDConfig +from fastdeploy.engine.request import Request +from fastdeploy.utils import get_logger +from fastdeploy.worker.iluvatar_model_runner import IluvatarModelRunner +from fastdeploy.worker.output import ModelRunnerOutput +from fastdeploy.worker.worker_base import WorkerBase + +logger = get_logger("iluvatar_worker", "iluvatar_worker.log") + + +class IluvatarWorker(WorkerBase): + """ """ + + def __init__( + self, + fd_config: FDConfig, + local_rank: int, + rank: int, + ): + super().__init__( + fd_config=fd_config, + local_rank=local_rank, + rank=rank, + ) + pass + + def init_device(self): + """ Initialize device and Construct model runner + """ + if paddle.is_compiled_with_custom_device("iluvatar_gpu"): + # Set evironment variable + self.device = f"iluvatar_gpu:{self.local_rank}" + paddle.device.set_device(self.device) + paddle.set_default_dtype(self.parallel_config.dtype) + self.device_ids = self.parallel_config.device_ids.split(",") + + gc.collect() + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + + # Construct model runner + self.model_runner: IluvatarModelRunner = IluvatarModelRunner( + fd_config=self.fd_config, + device=self.device, + device_id=self.device_ids[self.local_rank], + rank=self.rank, + local_rank=self.local_rank) + + def prefill_finished(self): + """ + check whether prefill stage finished + """ + return self.model_runner.prefill_finished() + + def determine_available_memory(self) -> int: + """ + Profiles the peak memory usage of the model to determine how much + memory can be used for KV cache without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + Tip: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # 1. Record memory state before profile run + return int(float(os.getenv("FD_ILUVATAR_KVCACHE_MEM", "3")) * 1024**3) + + def load_model(self) -> None: + """ """ + self.model_runner.load_model() + + def get_model(self) -> nn.Layer: + """ """ + return self.model_runner.get_model() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """ """ + pass + + def execute_model( + self, + model_forward_batch: Optional[List[Request]] = None, + ) -> Optional[ModelRunnerOutput]: + """ """ + output = self.model_runner.execute_model(model_forward_batch) + return output + + def preprocess_new_task(self, req_dicts: List[Request]) -> None: + """ Process new requests and then start the decode loop + TODO(gongshaotian):The scheduler should schedule the handling of prefill, + and workers and modelrunners should not perceive it. + """ + self.model_runner.insert_prefill_inputs(req_dicts=req_dicts) + + def graph_optimize_and_warm_up_model(self) -> None: + """ + Perform the warm-up and the graph optimization + """ + # 1. Warm up model + # NOTE(gongshaotian): may be not need warm_up at this place + + # 2. 
Triger cuda grpah capture + self.model_runner.capture_model() + + def check_health(self) -> bool: + """ """ + return True + + def cal_theortical_kvcache(self) -> int: + """ """ + return self.model_runner.cal_theortical_kvcache() + + def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None: + """ """ + self.model_runner.update_share_input_block_num( + num_gpu_blocks=num_gpu_blocks) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index ba7a5541a..39c33574c 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -48,6 +48,11 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase: if current_platform.is_xpu(): from fastdeploy.worker.xpu_worker import XpuWorker return XpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank) + if current_platform.is_iluvatar(): + from fastdeploy.worker.iluvatar_worker import IluvatarWorker + return IluvatarWorker(fd_config=fd_config, + local_rank=local_rank, + rank=rank) class PaddleDisWorkerProc(): @@ -125,9 +130,9 @@ class PaddleDisWorkerProc(): model_weights_status: """ # init worker_ready_signal - + max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 array_size = min( - 8, self.parallel_config.tensor_parallel_degree * + max_chips_per_node, self.parallel_config.tensor_parallel_degree * self.parallel_config.expert_parallel_degree) workers_ready = np.zeros(shape=[array_size], dtype=np.int32) self.worker_ready_signal = IPCSignal( @@ -136,7 +141,8 @@ class PaddleDisWorkerProc(): dtype=np.int32, suffix=self.parallel_config.engine_pid, create=False) - self.worker_ready_signal.value[self.local_rank % 8] = 1 + self.worker_ready_signal.value[self.local_rank % + max_chips_per_node] = 1 # init worker_healthy_live_signal workers_alive = np.zeros(shape=[self.ranks], dtype=np.int32) diff --git a/requirements_iluvatar.txt b/requirements_iluvatar.txt new file mode 100644 index 000000000..75d549a83 --- /dev/null +++ b/requirements_iluvatar.txt @@ -0,0 +1,29 @@ +setuptools>=62.3.0,<80.0 +pre-commit +yapf +flake8 +ruamel.yaml +zmq +aiozmq +openai +tqdm +pynvml +uvicorn +fastapi +paddleformers +redis +etcd3 +httpx +tool_helpers +pybind11[global] +tabulate +gradio +xlwt +visualdl +setuptools-scm>=8 +prometheus-client +decord +moviepy +use-triton-in-paddle +crcmod +fastsafetensors==0.1.14 diff --git a/setup.py b/setup.py index 6957270df..0447388a4 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ import subprocess from pathlib import Path from setuptools import Extension, find_packages, setup from setuptools.command.build_ext import build_ext +from wheel.bdist_wheel import bdist_wheel long_description = "FastDeploy: Large Language Model Serving.\n\n" long_description += "GitHub: https://github.com/PaddlePaddle/FastDeploy\n" @@ -35,7 +36,6 @@ PLAT_TO_CMAKE = { "win-arm64": "ARM64", } -from wheel.bdist_wheel import bdist_wheel class CustomBdistWheel(bdist_wheel): """Custom wheel builder for pure Python packages.""" @@ -49,10 +49,14 @@ class CustomBdistWheel(bdist_wheel): self.plat_name_supplied = True self.plat_name = 'any' + class CMakeExtension(Extension): """A setuptools Extension for CMake-based builds.""" - def __init__(self, name: str, sourcedir: str = "", version: str = None) -> None: + def __init__(self, + name: str, + sourcedir: str = "", + version: str = None) -> None: """ Initialize CMake extension. 
@@ -83,16 +87,12 @@ class CMakeBuild(build_ext): ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) extdir = ext_fullpath.parent.resolve() cfg = "Debug" if int(os.environ.get("DEBUG", 0)) else "Release" - - python_version = f"{sys.version_info.major}.{sys.version_info.minor}" - + cmake_args = [ f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}", f"-DPYTHON_EXECUTABLE={sys.executable}", - f"-DCMAKE_BUILD_TYPE={cfg}", - f"-DVERSION_INFO=", - f"-DPYBIND11_PYTHON_VERSION=", - f"-DPYTHON_VERSION=", + f"-DCMAKE_BUILD_TYPE={cfg}", "-DVERSION_INFO=", + "-DPYBIND11_PYTHON_VERSION=", "-DPYTHON_VERSION=", f"-DPYTHON_INCLUDE_DIR={sys.prefix}/include/python{sys.version_info.major}.{sys.version_info.minor}", f"-DPYTHON_LIBRARY={sys.prefix}/lib/libpython{sys.version_info.major}.{sys.version_info.minor}.so" ] @@ -134,22 +134,27 @@ class CMakeBuild(build_ext): build_temp.mkdir(parents=True, exist_ok=True) subprocess.run(["cmake", ext.sourcedir, *cmake_args], - cwd=build_temp, - check=True) + cwd=build_temp, + check=True) subprocess.run(["cmake", "--build", ".", *build_args], - cwd=build_temp, - check=True) + cwd=build_temp, + check=True) + def load_requirements(): """Load dependencies from requirements.txt""" + requirements_file_name = 'requirements.txt' + if paddle.is_compiled_with_custom_device('iluvatar_gpu'): + requirements_file_name = 'requirements_iluvatar.txt' requirements_path = os.path.join(os.path.dirname(__file__), - 'requirements.txt') + requirements_file_name) with open(requirements_path, 'r') as f: return [ line.strip() for line in f if line.strip() and not line.startswith('#') ] + def get_device_type(): """Get the device type (rocm/gpu/xpu/npu/cpu) that paddle is compiled with.""" if paddle.is_compiled_with_rocm(): @@ -160,13 +165,17 @@ def get_device_type(): return "xpu" elif paddle.is_compiled_with_custom_device('npu'): return "npu" + elif paddle.is_compiled_with_custom_device('iluvatar_gpu'): + return "iluvatar-gpu" else: return "cpu" + def get_name(): """get package name""" return "fastdeploy-" + get_device_type() + cmdclass_dict = {'bdist_wheel': CustomBdistWheel} cmdclass_dict['build_ext'] = CMakeBuild FASTDEPLOY_VERSION = os.environ.get("FASTDEPLOY_VERSION", "2.0.0") @@ -187,8 +196,8 @@ setup( "model_executor/ops/gpu/*", "model_executor/ops/gpu/deep_gemm/include/**/*", "model_executor/ops/cpu/*", "model_executor/ops/xpu/*", - "model_executor/ops/xpu/libs/*", - "model_executor/ops/npu/*", "model_executor/ops/base/*", + "model_executor/ops/xpu/libs/*", "model_executor/ops/npu/*", + "model_executor/ops/base/*", "model_executor/ops/iluvatar/*", "model_executor/models/*", "model_executor/layers/*", "input/mm_processor/utils/*", "version.txt" @@ -198,9 +207,10 @@ setup( ext_modules=[ CMakeExtension( "rdma_comm", - sourcedir="fastdeploy/cache_manager/transfer_factory/kvcache_transfer", + sourcedir= + "fastdeploy/cache_manager/transfer_factory/kvcache_transfer", version=None) - ], + ] if os.getenv("ENABLE_FD_RDMA", "0") == "1" else [], cmdclass=cmdclass_dict if os.getenv("ENABLE_FD_RDMA", "0") == "1" else {}, zip_safe=False, classifiers=[