Adapt for iluvatar gpu (#2684)

liddk1121
2025-07-07 16:53:14 +08:00
committed by GitHub
parent 2579e8fea8
commit 1b54a2831e
50 changed files with 4485 additions and 80 deletions

View File

@@ -104,6 +104,15 @@ function copy_ops(){
return
fi
if_corex=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device(\"iluvatar_gpu\"))"`
if [ "$if_corex" = "True" ]; then
DEVICE_TYPE="iluvatar-gpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/iluvatar
echo -e "BASE and Iluvatar ops have been copy to fastdeploy"
return
fi
DEVICE_TYPE="cpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cd ../../../../

View File

@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/extension.h"
#include "helper.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
@@ -59,7 +60,12 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &token_num,
const paddle::Tensor &seq_len) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = input_ids.stream();
#endif
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
const int seq_length = input_ids_shape[1];
@@ -75,7 +81,11 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
#ifdef PADDLE_WITH_COREX
int blockSize = std::min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
#else
int blockSize = min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
#endif
GetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
padding_offset.data<int>(),
cum_offsets_out.data<int>(),
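Note on the pattern above: the same PADDLE_WITH_CUSTOM_DEVICE stream lookup is repeated in nearly every op this commit touches. A minimal sketch of how it could be factored into one helper is below; the name GetKernelStream is hypothetical and not part of this change, and it assumes the headers that helper.h already pulls in (phi::CustomContext under PADDLE_WITH_CUSTOM_DEVICE).
// Hedged sketch only, not committed code: one helper for the repeated #ifdef.
inline cudaStream_t GetKernelStream(const paddle::Tensor &t) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  auto *dev_ctx = static_cast<const phi::CustomContext *>(
      paddle::experimental::DeviceContextPool::Instance().Get(t.place()));
  // The cast mirrors the explicit static_cast used in the MoE dispatch op below.
  return static_cast<cudaStream_t>(dev_ctx->stream());
#else
  return t.stream();
#endif
}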

View File

@@ -14,7 +14,9 @@
#pragma once
#ifndef PADDLE_WITH_COREX
#include "glog/logging.h"
#endif
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
@@ -35,22 +37,35 @@ namespace cub = hipcub;
#else
#include <cub/cub.cuh>
#endif
#ifndef PADDLE_WITH_COREX
#include "nlohmann/json.hpp"
#endif
#include <fstream>
#include <iostream>
#include "env.h"
#include "paddle/extension.h"
#include "paddle/phi/core/allocator.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/custom/custom_context.h"
#else
#include "paddle/phi/core/cuda_stream.h"
#endif
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#ifdef PADDLE_WITH_COREX
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#ifndef PADDLE_WITH_COREX
using json = nlohmann::json;
#endif
#define CUDA_CHECK(call) \
do { \
@@ -237,6 +252,7 @@ inline int GetBlockSize(int vocab_size) {
}
}
#ifndef PADDLE_WITH_COREX
inline json readJsonFromFile(const std::string &filePath) {
std::ifstream file(filePath);
if (!file.is_open()) {
@@ -247,6 +263,7 @@ inline json readJsonFromFile(const std::string &filePath) {
file >> j;
return j;
}
#endif
#define cudaCheckError() \
{ \
@@ -418,6 +435,7 @@ inline std::string base64_decode(const std::string &encoded_string) {
return ret;
}
#ifndef PADDLE_WITH_COREX
template <typename T>
inline T get_relative_best(nlohmann::json *json_data,
const std::string &target_key,
@@ -430,6 +448,7 @@ inline T get_relative_best(nlohmann::json *json_data,
return default_value;
}
}
#endif
__device__ inline bool is_in_end(const int64_t id, const int64_t *end_ids,
int length) {
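A quick illustration of the WARP_SIZE-dependent block-size arithmetic introduced above (the helper name RoundUpToWarp is hypothetical; the expression is the one used at the call sites): (n + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE rounds n up to the nearest warp multiple, so for example n = 70 becomes 96 when WARP_SIZE is 32 but 128 when WARP_SIZE is 64 on Iluvatar.
// Illustrative sketch only; same arithmetic as the launch-size computations below.
constexpr int RoundUpToWarp(int n) {
  return (n + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
}
static_assert(RoundUpToWarp(1) == WARP_SIZE, "a single element still occupies one warp");
static_assert(RoundUpToWarp(WARP_SIZE) == WARP_SIZE, "exact multiples are unchanged");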

View File

@@ -18,7 +18,6 @@
#include <algorithm>
#include <optional>
#include "helper.h"
#include "noauxtc_kernel.h"
std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,

View File

@@ -17,11 +17,11 @@
#pragma once
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include "helper.h"
namespace cg = cooperative_groups;
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
constexpr int32_t WARP_SIZE = 32;
constexpr int32_t BLOCK_SIZE = 512;
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;

View File

@@ -91,7 +91,12 @@ std::vector<paddle::Tensor> rebuild_padding(
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(tmp_out.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = tmp_out.stream();
#endif
std::vector<int64_t> tmp_out_shape = tmp_out.shape();
const int token_num = tmp_out_shape[0];
const int dim_embed = tmp_out_shape[1];
@@ -125,7 +130,7 @@ std::vector<paddle::Tensor> rebuild_padding(
if (output_padding_offset) {
RebuildAppendPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, tmp_out.stream()>>>(
<<<grid_size, blocksize, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(out.data<data_t>()),
reinterpret_cast<const DataType_ *>(tmp_out.data<data_t>()),
cum_offsets.data<int>(),
@@ -138,7 +143,7 @@ std::vector<paddle::Tensor> rebuild_padding(
elem_nums);
} else {
RebuildPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, tmp_out.stream()>>>(
<<<grid_size, blocksize, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(out.data<data_t>()),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(tmp_out.data<data_t>())),

View File

@@ -376,7 +376,6 @@ __global__ void air_topp_sampling(Counter<T> *counters, T *histograms,
}
// scan/find
constexpr int WARP_SIZE = 32;
constexpr int WARP_COUNT = NumBuckets / WARP_SIZE;
namespace cg = cooperative_groups;
cg::thread_block block = cg::this_thread_block();

View File

@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/extension.h"
#include "helper.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
@@ -51,13 +52,18 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_flags) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(stop_flags.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = stop_flags.stream();
#endif
std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
int bs = seq_lens_this_time.shape()[0];
int length = pre_ids_all_shape[1];
int length_input_ids = input_ids.shape()[1];
int block_size = (bs + 32 - 1) / 32 * 32;
int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
set_value_by_flag_and_id<<<1, block_size, 0, cu_stream>>>(
stop_flags.data<bool>(),
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),

View File

@@ -189,7 +189,7 @@ __global__ void free_and_dispatch_block(bool *stop_flags,
? tmp_used_len + 1
: max_decoder_block_num_this_seq;
#ifdef DEBUG_STEP
printf("#### ori_step_len:%d, ori_free_list_len:%d, used_len:%d \n",
printf("#### ori_step_len:%d, ori_free_list_len:%d, used_len:%d \n",
ori_step_len, ori_free_list_len, used_len);
#endif
while (ori_step_len > 0 && ori_free_list_len >= used_len) {
@@ -323,7 +323,12 @@ void StepPaddle(const paddle::Tensor &stop_flags,
const paddle::Tensor &first_token_ids,
const int block_size,
const int encoder_decoder_block_num) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(seq_lens_this_time.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = seq_lens_this_time.stream();
#endif
const int bsz = seq_lens_this_time.shape()[0];
const int block_num_per_seq = block_tables.shape()[1];
const int length = input_ids.shape()[1];

View File

@@ -74,11 +74,16 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
}
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = topk_ids.stream();
#endif
std::vector<int64_t> shape = topk_ids.shape();
int64_t bs_now = shape[0];
int64_t end_length = end_ids.shape()[0];
int block_size = (bs_now + 32 - 1) / 32 * 32;
int block_size = (bs_now + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
set_value_by_flags<<<1, block_size, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),

View File

@@ -21,6 +21,7 @@
#include <sys/types.h>
#include <unistd.h>
#include "paddle/extension.h"
#include "helper.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
@@ -88,7 +89,12 @@ void GetStopFlagsMultiSeqs(const paddle::Tensor &topk_ids,
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = topk_ids.stream();
#endif
std::vector<int64_t> shape = topk_ids.shape();
std::vector<int64_t> stop_seqs_shape = stop_seqs.shape();
int bs_now = shape[0];
@@ -96,7 +102,7 @@ void GetStopFlagsMultiSeqs(const paddle::Tensor &topk_ids,
int stop_seqs_max_len = stop_seqs_shape[1];
int pre_ids_len = pre_ids.shape()[1];
int block_size = (stop_seqs_bs + 31) / 32 * 32;
int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
set_value_by_stop_seqs<<<bs_now, block_size, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),

View File

@@ -132,7 +132,12 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(logits.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = logits.stream();
#endif
std::vector<int64_t> shape = logits.shape();
auto repeat_times =
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
@@ -143,7 +148,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
int64_t end_length = eos_token_id.shape()[0];
int block_size = (bs + 32 - 1) / 32 * 32;
int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
@@ -154,8 +159,12 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
length,
end_length);
block_size = (length_id + 32 - 1) / 32 * 32;
block_size = (length_id + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
#ifdef PADDLE_WITH_COREX
block_size = std::min(block_size, 512);
#else
block_size = min(block_size, 512);
#endif
update_repeat_times<<<bs, block_size, 0, cu_stream>>>(
pre_ids.data<int64_t>(),
cur_len.data<int64_t>(),
@@ -164,8 +173,12 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
length,
length_id);
block_size = (length + 32 - 1) / 32 * 32;
block_size = (length + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
#ifdef PADDLE_WITH_COREX
block_size = std::min(block_size, 512);
#else
block_size = min(block_size, 512);
#endif
update_value_by_repeat_times<DataType_><<<bs, block_size, 0, cu_stream>>>(
repeat_times.data<int>(),
reinterpret_cast<DataType_ *>(
@@ -180,8 +193,12 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
bs,
length);
block_size = (length_bad_words + 32 - 1) / 32 * 32;
block_size = (length_bad_words + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
#ifdef PADDLE_WITH_COREX
block_size = std::min(block_size, 512);
#else
block_size = min(block_size, 512);
#endif
ban_bad_words<DataType_><<<bs, block_size, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
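The block-size computations above repeat two steps: round the count up to a warp multiple, then cap it at 512, with std::min on the COREX (Iluvatar) build and the CUDA min overload elsewhere. A hedged sketch of a combined helper follows; the name CappedBlockSize is hypothetical and not part of the commit.
// Sketch only; mirrors the guarded min used above. Assumes helper.h (WARP_SIZE)
// and, on the COREX path, <algorithm> for std::min.
inline int CappedBlockSize(int n, int limit = 512) {
  int block = (n + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
#ifdef PADDLE_WITH_COREX
  return std::min(block, limit);
#else
  return min(block, limit);
#endif
}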

View File

@@ -75,11 +75,17 @@ void UpdateInputes(const paddle::Tensor &stop_flags,
const paddle::Tensor &stop_nums,
const paddle::Tensor &next_tokens,
const paddle::Tensor &is_block_step) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = input_ids.stream();
#endif
const int max_bsz = stop_flags.shape()[0];
const int now_bsz = seq_lens_this_time.shape()[0];
const int input_ids_stride = input_ids.shape()[1];
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
update_inputs_kernel<1024><<<1, 1024, 0, input_ids.stream()>>>(
update_inputs_kernel<1024><<<1, 1024, 0, cu_stream>>>(
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),

View File

@@ -0,0 +1,55 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "fused_moe_op.h"
namespace phi {
template <typename T, int VecSize>
__global__ void moe_token_type_ids_kernel(T *gating_output,
const int *moe_token_type_ids_out,
const int num_rows,
const int num_experts,
const int k) {
const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x;
if (moe_token_index >= num_rows) {
return;
}
gating_output[moe_token_index * 2] =
gating_output[moe_token_index * 2] +
(moe_token_type_ids_out[moe_token_index]) * -1e10;
gating_output[moe_token_index * 2 + 1] =
gating_output[moe_token_index * 2 + 1] +
(1 - moe_token_type_ids_out[moe_token_index]) * -1e10;
}
template <typename T>
void moe_token_type_ids_kernelLauncher(T *gating_output,
const int *moe_token_type_ids_out,
const int num_rows,
const int num_experts,
const int k,
cudaStream_t stream) {
const int blocks = num_rows * k / 512 + 1;
const int threads = 512;
moe_token_type_ids_kernel<T, 1><<<blocks, 512, 0, stream>>>(
gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
}
} // namespace phi
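For orientation, a hedged usage sketch of the launcher above; the wrapper and buffer names are hypothetical. Per token, the kernel adds -1e10 to gating column 0 when the token type id is 1 and to column 1 when it is 0, so the masked column effectively drops out of the subsequent top-k routing.
// Sketch only (not part of the commit); assumes two gating values per token,
// matching the fixed *2 indexing in the kernel above.
inline void MaskGatingByTokenType(float *d_gating, const int *d_token_type_ids,
                                  int num_rows, int num_experts, int k,
                                  cudaStream_t stream) {
  phi::moe_token_type_ids_kernelLauncher<float>(
      d_gating,          // [num_rows, 2], modified in place
      d_token_type_ids,  // [num_rows], 0 or 1
      num_rows, num_experts, k, stream);
}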

View File

@@ -0,0 +1,127 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <string>
#include <sstream>
#include "cub/cub.cuh"
namespace phi {
static const float HALF_FLT_MAX = 65504.F;
static const float HALF_FLT_MIN = -65504.F;
static inline size_t AlignTo16(const size_t& input) {
static constexpr int ALIGNMENT = 16;
return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
}
class CubKeyValueSorter {
public:
CubKeyValueSorter() : num_experts_(0), num_bits_(sizeof(int) * 8) {}
explicit CubKeyValueSorter(const int num_experts)
: num_experts_(num_experts),
num_bits_(static_cast<int>(log2(num_experts)) + 1) {}
void update_num_experts(const int num_experts) {
num_experts_ = num_experts;
num_bits_ = static_cast<int>(log2(num_experts)) + 1;
}
size_t getWorkspaceSize(const size_t num_key_value_pairs,
bool descending = false) {
num_key_value_pairs_ = num_key_value_pairs;
size_t required_storage = 0;
int* null_int = nullptr;
if (descending) {
cub::DeviceRadixSort::SortPairsDescending(NULL,
required_storage,
null_int,
null_int,
null_int,
null_int,
num_key_value_pairs,
0,
32);
} else {
cub::DeviceRadixSort::SortPairs(NULL,
required_storage,
null_int,
null_int,
null_int,
null_int,
num_key_value_pairs,
0,
num_bits_);
}
return required_storage;
}
template <typename KeyT>
void run(void* workspace,
const size_t workspace_size,
const KeyT* keys_in,
KeyT* keys_out,
const int* values_in,
int* values_out,
const size_t num_key_value_pairs,
bool descending,
cudaStream_t stream) {
size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs);
size_t actual_ws_size = workspace_size;
if (expected_ws_size > workspace_size) {
std::stringstream err_ss;
err_ss << "[Error][CubKeyValueSorter::run]\n";
err_ss << "Error. The allocated workspace is too small to run this "
"problem.\n";
err_ss << "Expected workspace size of at least " << expected_ws_size
<< " but got problem size " << workspace_size << "\n";
throw std::runtime_error(err_ss.str());
}
if (descending) {
cub::DeviceRadixSort::SortPairsDescending(workspace,
actual_ws_size,
keys_in,
keys_out,
values_in,
values_out,
num_key_value_pairs,
0,
32,
stream);
} else {
cub::DeviceRadixSort::SortPairs(workspace,
actual_ws_size,
keys_in,
keys_out,
values_in,
values_out,
num_key_value_pairs,
0,
num_bits_,
stream);
}
}
private:
size_t num_key_value_pairs_;
int num_experts_;
int num_bits_;
};
} // namespace phi
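A hedged sketch of how CubKeyValueSorter is driven; the MoE dispatch op later in this commit follows the same order (size the workspace, then sort expert ids with their source-row indices as values). The wrapper name and buffer names are illustrative only.
// Sketch only, not committed code.
inline void SortExpertIds(phi::CubKeyValueSorter &sorter, void *ws, size_t ws_bytes,
                          const int *expert_ids_in, int *expert_ids_out,
                          const int *rows_in, int *rows_out,
                          size_t n, cudaStream_t stream) {
  // `ws` must hold at least sorter.getWorkspaceSize(n) bytes; run() itself
  // throws if the workspace is too small (see the check above).
  sorter.run(ws, ws_bytes, expert_ids_in, expert_ids_out,
             rows_in, rows_out, n, /*descending=*/false, stream);
}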

View File

@@ -0,0 +1,990 @@
// /*
// * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
// * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
// *
// * Licensed under the Apache License, Version 2.0 (the "License");
// * you may not use this file except in compliance with the License.
// * You may obtain a copy of the License at
// *
// * http://www.apache.org/licenses/LICENSE-2.0
// *
// * Unless required by applicable law or agreed to in writing, software
// * distributed under the License is distributed on an "AS IS" BASIS,
// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// * See the License for the specific language governing permissions and
// * limitations under the License.
// */
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include "fused_moe_imp_op.h"
#include "fused_moe_helper.h"
// Ignore CUTLASS warnings about type punning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wunused-function"
// #include "paddle/phi/backends/gpu/gpu_info.h"
#pragma GCC diagnostic pop
#include "helper.h"
namespace phi {
struct GpuLaunchConfig {
dim3 block_per_grid;
dim3 thread_per_block;
};
inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
int blocks_x = cols;
int blocks_y = 1;
int blocks_z = 1;
if (blocks_x > 1024) {
blocks_y = 256;
blocks_x = (blocks_x + blocks_y - 1) / blocks_y;
}
GpuLaunchConfig config;
config.block_per_grid.x = blocks_x;
config.block_per_grid.y = blocks_y;
config.block_per_grid.z = blocks_z;
return config;
}
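// Worked example (editorial note, not part of the commit): a call with 50000
// rows yields blocks_y = 256 and blocks_x = (50000 + 255) / 256 = 196; the
// kernels below recover the row index as blockIdx.x + blockIdx.y * gridDim.x
// and early-return for the 196 * 256 - 50000 = 176 excess indices.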
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing
// the output in the softmax kernel when we extend this module to support
// expert-choice routing.
template <typename T, int TPB>
__launch_bounds__(TPB) __global__
void group_moe_softmax(const T* input,
T* output,
T* softmax_max_prob,
const int64_t num_cols,
const int64_t softmax_num_rows) {
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
__shared__ float normalizing_factor;
__shared__ float float_max;
__shared__ float max_out;
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
if (globalIdx >= softmax_num_rows) {
return;
}
const int64_t thread_row_offset = globalIdx * num_cols;
cub::Sum sum;
float threadData(-FLT_MAX);
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData = max(static_cast<float>(input[idx]), threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
float_max = maxElem;
}
__syncthreads();
threadData = 0;
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData += exp((static_cast<float>(input[idx]) - float_max));
}
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;
}
__syncthreads();
threadData = 0;
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
const float val =
exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
output[idx] = T(val);
threadData = max(static_cast<float>(T(val)), threadData);
}
const float maxOut = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
// group max probs
max_out = 1.f / maxOut;
softmax_max_prob[globalIdx] = T(max_out);
}
__syncthreads();
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
// group softmax normalization
output[idx] = output[idx] * static_cast<T>(max_out);
}
}
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
T* output,
IdxT* indices,
int* source_rows,
T* softmax_max_prob,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
cub_kvp thread_kvp;
cub::ArgMax arg_max;
const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
if (block_row >= num_rows) {
return;
}
const bool should_process_row = true;
const int thread_read_offset = block_row * num_experts;
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = inputs_after_softmax[idx];
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const IdxT prior_winning_expert = indices[k * block_row + prior_k];
if (prior_winning_expert == expert) {
inp_kvp = thread_kvp;
}
}
thread_kvp = arg_max(inp_kvp, thread_kvp);
}
const cub_kvp result_kvp =
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = k * block_row + k_idx;
// restore normalized probes
output[idx] = result_kvp.value / T(softmax_max_prob[idx]);
indices[idx] = should_process_row ? result_kvp.key : num_experts;
source_rows[idx] = k_idx * num_rows + block_row;
}
__syncthreads();
}
}
template <typename T, int TPB>
__launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
T* output,
const int64_t num_cols,
const int64_t num_rows) {
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
__shared__ float normalizing_factor;
__shared__ float float_max;
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
if (globalIdx >= num_rows) {
return;
}
const int64_t thread_row_offset = globalIdx * num_cols;
cub::Sum sum;
float threadData(-FLT_MAX);
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData = max(static_cast<float>(input[idx]), threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
float_max = maxElem;
}
__syncthreads();
threadData = 0;
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData += exp((static_cast<float>(input[idx]) - float_max));
}
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;
}
__syncthreads();
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
const float val =
exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
output[idx] = T(val);
}
}
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
const T* bias,
T* output,
IdxT* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
cub_kvp thread_kvp;
cub::ArgMax arg_max;
const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
if (block_row >= num_rows) {
return;
}
const bool should_process_row = true;
const int thread_read_offset = block_row * num_experts;
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const IdxT prior_winning_expert = indices[k * block_row + prior_k];
if (prior_winning_expert == expert) {
inp_kvp = thread_kvp;
}
}
thread_kvp = arg_max(inp_kvp, thread_kvp);
}
const cub_kvp result_kvp =
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = k * block_row + k_idx;
output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
indices[idx] = should_process_row ? result_kvp.key : num_experts;
source_rows[idx] = k_idx * num_rows + block_row;
}
__syncthreads();
}
}
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
const T* bias,
T* output,
IdxT* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
// softmax
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
__shared__ float normalizing_factor;
__shared__ float float_max;
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
if (globalIdx >= num_rows) {
return;
}
const int64_t thread_row_offset = globalIdx * num_experts;
const int64_t idx = thread_row_offset+threadIdx.x;
cub::Sum sum;
float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) :(-FLT_MAX);
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
float_max = maxElem;
}
__syncthreads();
float threadDataSub = threadData - float_max;
float threadDataExp = exp(threadDataSub);
const auto Z = BlockReduce(tmpStorage).Reduce(threadDataExp, sum);
if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;
}
__syncthreads();
T val = T(threadDataExp * normalizing_factor);
// top_k
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduceP::TempStorage tmpStorageP;
cub_kvp thread_kvp;
cub::ArgMax arg_max;
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
if (threadIdx.x < num_experts) {
cub_kvp inp_kvp;
int expert = threadIdx.x;
inp_kvp.key = expert;
inp_kvp.value = bias ? val + bias[expert] : val;
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const IdxT prior_winning_expert = indices[k * globalIdx + prior_k];
if (prior_winning_expert == expert) {
inp_kvp = thread_kvp;
}
}
thread_kvp = arg_max(inp_kvp, thread_kvp);
}
const cub_kvp result_kvp =
BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int cur_idx = k * globalIdx + k_idx;
output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
indices[cur_idx] = result_kvp.key;
source_rows[cur_idx] = k_idx * num_rows + globalIdx;
}
__syncthreads();
}
}
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_softmax,
const T* bias,
T* output,
IdxT* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
cub_kvp thread_kvp;
cub::ArgMax arg_max;
const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
if (block_row >= num_rows) {
return;
}
const bool should_process_row = true;
const int thread_read_offset = block_row * num_experts;
T weight_sum = static_cast<T>(0);
extern __shared__ char smem[];
T* row_outputs = reinterpret_cast<T*>(smem);
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const int prior_winning_expert = indices[k * block_row + prior_k];
if (prior_winning_expert == expert) {
inp_kvp = thread_kvp;
}
}
thread_kvp = arg_max(inp_kvp, thread_kvp);
}
const cub_kvp result_kvp =
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = k * block_row + k_idx;
// output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
indices[idx] = should_process_row ? result_kvp.key : num_experts;
source_rows[idx] = k_idx * num_rows + block_row;
T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
row_outputs[k_idx] = row_out;
weight_sum += row_out;
}
__syncthreads();
}
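// weight_sum was accumulated by thread 0 only; broadcast it from lane 0 so that
// threads 0..k-1 (k is assumed to fit within one warp here) can normalize their
// row_outputs by the same sum below.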
if (threadIdx.x < WARP_SIZE) {
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
}
if (threadIdx.x < k) {
output[k * block_row + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
}
}
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* input,
const T* bias,
T* output,
IdxT* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
// softmax
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
__shared__ float normalizing_factor;
__shared__ float float_max;
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
if (globalIdx >= num_rows) {
return;
}
const int64_t thread_row_offset = globalIdx * num_experts;
const int64_t idx = thread_row_offset+threadIdx.x;
cub::Sum sum;
float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) :(-FLT_MAX);
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
float_max = maxElem;
}
__syncthreads();
float threadDataSub = threadData - float_max;
float threadDataExp = exp(threadDataSub);
const auto Z = BlockReduce(tmpStorage).Reduce(threadDataExp, sum);
if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;
}
__syncthreads();
T val = T(threadDataExp * normalizing_factor);
// top_k
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduceP::TempStorage tmpStorageP;
cub_kvp thread_kvp;
cub::ArgMax arg_max;
T weight_sum = static_cast<T>(0);
extern __shared__ char smem[];
T* row_outputs = reinterpret_cast<T*>(smem);
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
if (threadIdx.x < num_experts) {
cub_kvp inp_kvp;
int expert = threadIdx.x;
inp_kvp.key = expert;
inp_kvp.value = bias ? val + bias[expert] : val;
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const IdxT prior_winning_expert = indices[k * globalIdx + prior_k];
if (prior_winning_expert == expert) {
inp_kvp = thread_kvp;
}
}
thread_kvp = arg_max(inp_kvp, thread_kvp);
}
const cub_kvp result_kvp =
BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int cur_idx = k * globalIdx + k_idx;
T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
row_outputs[k_idx] = row_out;
weight_sum += row_out;
indices[cur_idx] = result_kvp.key;
source_rows[cur_idx] = k_idx * num_rows + globalIdx;
}
__syncthreads();
}
if (threadIdx.x < WARP_SIZE) {
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
}
if (threadIdx.x < k) {
output[k * globalIdx + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
}
}
namespace detail {
// Constructs some constants needed to partition the work across threads at
// compile time.
template <typename T, int EXPERTS, int BYTES_PER_LDG>
struct TopkConstants {
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 ||
EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0,
"");
static constexpr int VECs_PER_THREAD =
std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
};
} // namespace detail
template <typename T, typename IdxT = int>
void topk_gating_softmax_kernelLauncher(const T* input,
const T* gating_correction_bias,
T* output,
T* softmax,
IdxT* indices,
int* source_row,
T* softmax_max_prob,
const int64_t num_rows,
const int64_t num_experts,
const int64_t k,
const bool group_moe,
cudaStream_t stream,
const bool topk_only_mode = false) {
if (topk_only_mode) {
static constexpr int TPB = 256;
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
input, gating_correction_bias, output, indices, source_row, num_experts, k, num_rows);
return;
}
static constexpr int WARPS_PER_TB = 4;
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
input, output, indices, source_row, num_rows, num_experts, k, stream); \
break; \
}
int64_t tem_num_experts = num_experts;
if(gating_correction_bias != nullptr) tem_num_experts = 0;
switch (tem_num_experts) {
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16)
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32)
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64)
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128)
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256)
default: {
static constexpr int TPB = 256;
if (group_moe) {
const int group_experts = num_experts / k;
const int softmax_num_rows = num_rows * k;
const auto config_softmax = Get1DBlocksAnd2DGridsMoe(softmax_num_rows);
group_moe_softmax<T, TPB>
<<<config_softmax.block_per_grid, TPB, 0, stream>>>(
input,
softmax,
softmax_max_prob,
group_experts,
softmax_num_rows);
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
moe_top_k<T, TPB>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
output,
indices,
source_row,
softmax_max_prob,
num_experts,
k,
num_rows);
} else {
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
input, softmax, num_experts, num_rows);
moe_top_k<T, TPB>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
gating_correction_bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
}
}
}
}
// ========================== Permutation things
// =======================================
// Duplicated and permutes rows for MoE. In addition, reverse the permutation
// map to help with finalizing routing.
// "expanded_x_row" simply means that the number of values is num_rows x k. It
// is "expanded" since we will have to duplicate some rows in the input matrix
// to match the dimensions. Duplicates will always get routed to separate
// experts in the end.
// Note that the expanded_dest_row_to_expanded_source_row map referred to here
// has indices in the range (0, k*rows_in_input - 1). However, it is set up so
// that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input all map
// to row 0 in the original matrix. Thus, to know where to read in the source
// matrix, we simply take the modulus of the expanded index.
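// Worked example (editorial note): with rows_in_input = 4 and k = 2, expanded
// source rows 0 and 4 both map back to original row 0 (0 % 4 == 4 % 4 == 0),
// rows 1 and 5 to original row 1, and so on, which is why the kernel below
// only needs the modulus.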
template <typename T, int VecSize>
__global__ void initialize_moe_routing_kernel(
const T* unpermuted_input,
T* permuted_output,
const int* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
const int64_t num_rows,
const int64_t active_rows,
const int64_t cols,
const int64_t num_rows_k) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
// Reverse permutation map.
// I do this so that later, we can use the source -> dest map to do the k-way
// reduction and unpermuting. I need the reverse map for that reduction to
// allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
// thread block will be responsible for all k summations.
const int expanded_dest_row = blockIdx.x + blockIdx.y * gridDim.x;
if (expanded_dest_row >= num_rows_k) return;
const int expanded_source_row =
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
if (threadIdx.x == 0) {
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
expanded_dest_row;
}
if ((blockIdx.x + blockIdx.y * gridDim.x) < active_rows) {
// Duplicate and permute rows
const int source_row = expanded_source_row % num_rows;
const T* source_row_ptr = unpermuted_input + source_row * cols;
T* dest_row_ptr = permuted_output + expanded_dest_row * cols;
for (int tid = threadIdx.x * VecSize; tid < cols;
tid += blockDim.x * VecSize) {
// dest_row_ptr[tid] = source_row_ptr[tid];
Load<T, VecSize>(&source_row_ptr[tid], &src_vec);
Store<T, VecSize>(src_vec, &dest_row_ptr[tid]);
}
}
}
template <typename T>
void initialize_moe_routing_kernelLauncher(
const T* unpermuted_input,
T* permuted_output,
const int* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
const int64_t num_rows,
const int64_t active_rows,
const int64_t cols,
const int64_t k,
cudaStream_t stream) {
const int threads = std::min(cols, int64_t(1024));
constexpr int max_pack_size = 16 / sizeof(T);
const auto config_initialize = Get1DBlocksAnd2DGridsMoe(num_rows * k);
if (cols % max_pack_size == 0) {
initialize_moe_routing_kernel<T, max_pack_size>
<<<config_initialize.block_per_grid, threads, 0, stream>>>(
unpermuted_input,
permuted_output,
expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row,
num_rows,
k * active_rows,
cols,
num_rows * k);
} else {
initialize_moe_routing_kernel<T, 1>
<<<config_initialize.block_per_grid, threads, 0, stream>>>(
unpermuted_input,
permuted_output,
expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row,
num_rows,
k * active_rows,
cols,
num_rows * k);
}
}
// ============================== Infer GEMM sizes
// =================================
__device__ inline int find_total_elts_leq_target(int* sorted_indices,
const int64_t arr_length,
const int64_t target) {
int64_t low = 0, high = arr_length - 1, target_location = -1;
while (low <= high) {
int64_t mid = (low + high) / 2;
if (sorted_indices[mid] > target) {
high = mid - 1;
} else {
low = mid + 1;
target_location = mid;
}
}
return target_location + 1;
}
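// Example (editorial note): for sorted expert ids {0, 0, 1, 1, 1, 2} and
// target = 1 the search returns 5, i.e. the number of entries <= target, which
// is what total_rows_before_expert stores for each expert downstream.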
// Final kernel to unpermute and scale
// This kernel unpermutes the original data, does the k-way reduction and
// performs the final skip connection.
template <typename T, int RESIDUAL_NUM>
__global__ void finalize_moe_routing_kernel(
const T* expanded_permuted_rows,
T* reduced_unpermuted_output,
const T* bias,
const float* scales,
const int* expanded_source_row_to_expanded_dest_row,
const int* expert_for_source_row,
const int64_t cols,
const int64_t k,
const int64_t compute_bias,
const bool norm_topk_prob,
const float routed_scaling_factor,
const int64_t num_rows) {
const int original_row = blockIdx.x + blockIdx.y * gridDim.x;
// const int original_row = blockIdx.x;
// const int num_rows = gridDim.x;
if (original_row >= num_rows) return;
T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols;
for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) {
T thread_output{0.f};
float row_rescale{0.f};
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int expanded_original_row = original_row + k_idx * num_rows;
const int expanded_permuted_row =
expanded_source_row_to_expanded_dest_row[expanded_original_row];
const int64_t k_offset = original_row * k + k_idx;
const float row_scale = scales[k_offset];
row_rescale = row_rescale + row_scale;
const T* expanded_permuted_rows_row_ptr =
expanded_permuted_rows + expanded_permuted_row * cols;
const int expert_idx = expert_for_source_row[k_offset];
const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr;
const T bias_value = bias_ptr ? bias_ptr[tid] : T{0.f};
thread_output =
static_cast<float>(thread_output) +
row_scale * static_cast<float>(
expanded_permuted_rows_row_ptr[tid] +
bias_value *
static_cast<T>(static_cast<float>(compute_bias)));
}
thread_output = static_cast<float>(thread_output) /
(norm_topk_prob ? row_rescale : 1.0f) *
routed_scaling_factor;
reduced_row_ptr[tid] = thread_output;
}
}
template <typename T>
void finalize_moe_routing_kernelLauncher(
const T* expanded_permuted_rows,
T* reduced_unpermuted_output,
const T* bias,
const float* scales,
const int* expanded_source_row_to_expanded_dest_row,
const int* expert_for_source_row,
const int64_t num_rows,
const int64_t cols,
const int64_t k,
const int64_t compute_bias,
const bool norm_topk_prob,
const float routed_scaling_factor,
cudaStream_t stream) {
const int threads = std::min(cols, int64_t(1024));
const auto config_final = Get1DBlocksAnd2DGridsMoe(num_rows);
finalize_moe_routing_kernel<T, 1>
<<<config_final.block_per_grid, threads, 0, stream>>>(
expanded_permuted_rows,
reduced_unpermuted_output,
bias,
scales,
expanded_source_row_to_expanded_dest_row,
expert_for_source_row,
cols,
k,
compute_bias,
norm_topk_prob,
routed_scaling_factor,
num_rows);
}
// ========================= TopK Softmax specializations
// ===========================
template void topk_gating_softmax_kernelLauncher(const float*,
const float*,
float*,
float*,
int*,
int*,
float*,
const int64_t,
const int64_t,
const int64_t,
const bool,
cudaStream_t,
const bool);
template void topk_gating_softmax_kernelLauncher(const half*,
const half*,
half*,
half*,
int*,
int*,
half*,
const int64_t,
const int64_t,
const int64_t,
const bool,
cudaStream_t,
const bool);
#ifdef PADDLE_CUDA_BF16
template void topk_gating_softmax_kernelLauncher(const __nv_bfloat16*,
const __nv_bfloat16*,
__nv_bfloat16*,
__nv_bfloat16*,
int*,
int*,
__nv_bfloat16*,
const int64_t,
const int64_t,
const int64_t,
const bool,
cudaStream_t,
const bool);
#endif
// ===================== Specializations for init routing
// =========================
template void initialize_moe_routing_kernelLauncher(const float*,
float*,
const int*,
int*,
const int64_t,
const int64_t,
const int64_t,
const int64_t,
cudaStream_t);
template void initialize_moe_routing_kernelLauncher(const half*,
half*,
const int*,
int*,
const int64_t,
const int64_t,
const int64_t,
const int64_t,
cudaStream_t);
#ifdef PADDLE_CUDA_BF16
template void initialize_moe_routing_kernelLauncher(const __nv_bfloat16*,
__nv_bfloat16*,
const int*,
int*,
const int64_t,
const int64_t,
const int64_t,
const int64_t,
cudaStream_t);
#endif
// ==================== Specializations for final routing
// ===================================
template void finalize_moe_routing_kernelLauncher(const float*,
float*,
const float*,
const float*,
const int*,
const int*,
const int64_t,
const int64_t,
const int64_t,
const int64_t,
const bool,
const float,
cudaStream_t);
template void finalize_moe_routing_kernelLauncher(const half*,
half*,
const half*,
const float*,
const int*,
const int*,
const int64_t,
const int64_t,
const int64_t,
const int64_t,
const bool,
const float,
cudaStream_t);
#ifdef PADDLE_CUDA_BF16
template void finalize_moe_routing_kernelLauncher(const __nv_bfloat16*,
__nv_bfloat16*,
const __nv_bfloat16*,
const float*,
const int*,
const int*,
const int64_t,
const int64_t,
const int64_t,
const int64_t,
const bool,
const float,
cudaStream_t);
#endif
} // namespace phi

View File

@@ -0,0 +1,311 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Ignore CUTLASS warnings about type punning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wunused-function"
#pragma once
#include "fused_moe_helper.h"
#include "fused_moe_op.h"
#pragma GCC diagnostic pop
#include "helper.h"
__global__ void compute_total_rows_before_expert_kernel(
int* sorted_experts,
const int64_t sorted_experts_len,
const int64_t num_experts,
int64_t* total_rows_before_expert) {
const int expert = blockIdx.x * blockDim.x + threadIdx.x;
if (expert >= num_experts) return;
total_rows_before_expert[expert] =
phi::find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
}
void compute_total_rows_before_expert(int* sorted_indices,
const int64_t total_indices,
const int64_t num_experts,
int64_t* total_rows_before_expert,
cudaStream_t stream) {
const int threads = std::min(int64_t(1024), num_experts);
const int blocks = (num_experts + threads - 1) / threads;
compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
sorted_indices, total_indices, num_experts, total_rows_before_expert);
}
template <paddle::DataType T>
void MoeDispatchKernel(const paddle::Tensor& input,
const paddle::Tensor& gating_output,
const paddle::optional<paddle::Tensor>& gating_correction_bias,
const int moe_topk,
const bool group_moe,
const bool topk_only_mode,
const int num_rows,
const int hidden_size,
const int expert_num,
paddle::Tensor* permute_input,
paddle::Tensor* tokens_expert_prefix_sum,
paddle::Tensor* permute_indices_per_token,
paddle::Tensor* top_k_weight,
paddle::Tensor* top_k_indices) {
using namespace phi;
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto place = input.place();
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input.place()));
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
if (group_moe) {
// Check if expert_num is divisible by moe_topk, else throw an error
PADDLE_ENFORCE_EQ(expert_num % moe_topk,
0,
common::errors::InvalidArgument(
"The number of experts (expert_num) "
"must be divisible by moe_topk. "
"Got expert_num = %d and moe_topk = %d.",
expert_num,
moe_topk));
}
const int num_moe_inputs = AlignTo16(num_rows * moe_topk);
const int bytes = num_moe_inputs * sizeof(int);
CubKeyValueSorter sorter_;
sorter_.update_num_experts(expert_num);
const int sorter_ws_size_bytes =
AlignTo16(sorter_.getWorkspaceSize(moe_topk * num_rows));
const int sort_tmp_in_out_size = num_moe_inputs * 2 * sizeof(int);
paddle::Tensor ws_ptr_tensor =
GetEmptyTensor({bytes + sorter_ws_size_bytes + sort_tmp_in_out_size},
paddle::DataType::INT8,
place);
int8_t* ws_ptr = ws_ptr_tensor.data<int8_t>();
int* source_rows_ = reinterpret_cast<int*>(ws_ptr);
int8_t* sorter_ws_ptr = reinterpret_cast<int8_t*>(ws_ptr + bytes);
int* permuted_experts_ =
reinterpret_cast<int*>(sorter_ws_ptr + sorter_ws_size_bytes);
int* permuted_rows_ = permuted_experts_ + num_moe_inputs;
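// Workspace layout recap (descriptive note): ws_ptr holds
//   [source_rows_      : num_moe_inputs ints]
//   [sorter workspace  : sorter_ws_size_bytes]
//   [permuted_experts_ : num_moe_inputs ints]
//   [permuted_rows_    : num_moe_inputs ints]
// which matches the bytes + sorter_ws_size_bytes + sort_tmp_in_out_size
// allocation above.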
int* expert_for_source_row = top_k_indices->data<int>();
float* softmax_max_prob = nullptr;
if (group_moe) {
paddle::Tensor softmax_max_prob_tensor =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
// (TODO: check fill success?)
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
softmax_max_prob = softmax_max_prob_tensor.data<float>();
}
float* softmax_out_;
const bool is_pow_2 =
(expert_num != 0) && ((expert_num & (expert_num - 1)) == 0);
paddle::Tensor softmax_buffer;
if (!is_pow_2 || expert_num > 256 || group_moe || gating_correction_bias) {
softmax_buffer = GetEmptyTensor(
{num_rows * expert_num}, paddle::DataType::FLOAT32, place);
softmax_out_ = softmax_buffer.data<float>();
} else {
softmax_out_ = nullptr;
}
topk_gating_softmax_kernelLauncher<float>(gating_output.data<float>(),
gating_correction_bias ? gating_correction_bias.get().data<float>() : nullptr,
top_k_weight->data<float>(),
softmax_out_,
expert_for_source_row,
source_rows_,
softmax_max_prob,
num_rows,
expert_num,
moe_topk,
group_moe,
stream,
topk_only_mode);
sorter_.run(reinterpret_cast<void*>(sorter_ws_ptr),
sorter_ws_size_bytes,
expert_for_source_row,
permuted_experts_,
source_rows_,
permuted_rows_,
moe_topk * num_rows,
false,
stream);
initialize_moe_routing_kernelLauncher(
input.data<data_t>(),
permute_input->data<data_t>(),
permuted_rows_,
permute_indices_per_token->data<int32_t>(),
num_rows,
num_rows,
hidden_size,
moe_topk,
stream);
compute_total_rows_before_expert(
permuted_experts_,
moe_topk * num_rows,
expert_num,
tokens_expert_prefix_sum->data<int64_t>(),
stream);
}
std::vector<paddle::Tensor> MoeExpertDispatch(
const paddle::Tensor& input,
const paddle::Tensor& gating_output,
const paddle::optional<paddle::Tensor>& gating_correction_bias,
const paddle::optional<paddle::Tensor>& w4a8_in_scale,
const int moe_topk,
const bool group_moe,
const bool topk_only_mode) {
const auto input_type = input.dtype();
auto place = input.place();
int token_rows = 0;
auto input_dims = input.dims();
auto gating_dims = gating_output.dims();
const int expert_num = gating_dims[gating_dims.size() - 1];
if (input_dims.size() == 3) {
token_rows = input_dims[0] * input_dims[1];
} else {
token_rows = input_dims[0];
}
const int num_rows = token_rows;
const int hidden_size = input.dims()[input_dims.size() - 1];
auto permute_input =
GetEmptyTensor({moe_topk * num_rows, hidden_size}, input_type, place);
// corresponds to the weighting coefficients applied to the results from each expert.
auto top_k_weight =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
auto top_k_indices =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::INT32, place);
auto tokens_expert_prefix_sum =
GetEmptyTensor({expert_num}, paddle::DataType::INT64, place);
auto permute_indices_per_token =
GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
gating_output,
gating_correction_bias,
moe_topk,
group_moe,
topk_only_mode,
num_rows,
hidden_size,
expert_num,
&permute_input,
&tokens_expert_prefix_sum,
&permute_indices_per_token,
&top_k_weight,
&top_k_indices);
break;
case paddle::DataType::FLOAT16:
MoeDispatchKernel<paddle::DataType::FLOAT16>(input,
gating_output,
gating_correction_bias,
moe_topk,
group_moe,
topk_only_mode,
num_rows,
hidden_size,
expert_num,
&permute_input,
&tokens_expert_prefix_sum,
&permute_indices_per_token,
&top_k_weight,
&top_k_indices);
break;
default:
PD_THROW("Unsupported data type for MoeDispatchKernel");
}
return {permute_input,
tokens_expert_prefix_sum,
permute_indices_per_token,
top_k_weight,
top_k_indices,
top_k_indices};
}
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& gating_output_shape,
const paddle::optional<std::vector<int64_t>>& bias_shape,
const int moe_topk) {
int token_rows = -1;
if (input_shape.size() == 3) {
token_rows = input_shape[0] * input_shape[1];
} else {
token_rows = input_shape[0];
}
const int expert_num = gating_output_shape[gating_output_shape.size() - 1];
const int num_rows = token_rows;
const int hidden_size = input_shape[input_shape.size() - 1];
return {{moe_topk * num_rows, hidden_size},
{expert_num},
{moe_topk, num_rows},
{num_rows, moe_topk},
{num_rows, moe_topk},
{num_rows, moe_topk}};
}
std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
const paddle::DataType& input_dtype,
const paddle::DataType& gating_output_dtype,
const paddle::optional<paddle::DataType>& bias_type,
const int moe_topk) {
return {input_dtype,
paddle::DataType::INT64,
paddle::DataType::INT32,
paddle::DataType::FLOAT32,
paddle::DataType::INT32,
paddle::DataType::INT32};
}
PD_BUILD_STATIC_OP(moe_expert_dispatch)
.Inputs({"input", "gating_output", paddle::Optional("gating_correction_bias"),
paddle::Optional("w4a8_in_scale")})
.Outputs({"permute_input",
"tokens_expert_prefix_sum",
"permute_indices_per_token",
"top_k_weight",
"top_k_indices",
"expert_idx_per_token"})
.Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));

View File

@@ -0,0 +1,155 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Ignore CUTLASS warnings about type punning
#pragma once
#include "helper.h"
#include "fused_moe_helper.h"
#include "fused_moe_op.h"
template <paddle::DataType T>
void MoeReduceKernel(const paddle::Tensor& ffn_out,
const paddle::Tensor& top_k_weight,
const paddle::Tensor& permute_indices_per_token,
const paddle::Tensor& top_k_indices,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const bool norm_topk_prob,
const float routed_scaling_factor,
const int num_rows,
const int hidden_size,
const int topk,
paddle::Tensor* output) {
using namespace phi;
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(ffn_out.place()));
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
finalize_moe_routing_kernelLauncher(
ffn_out.data<data_t>(),
output->data<data_t>(),
ffn2_bias ? ffn2_bias->data<data_t>() : nullptr,
top_k_weight.data<float>(),
permute_indices_per_token.data<int32_t>(),
top_k_indices.data<int>(),
num_rows,
hidden_size,
topk,
static_cast<int>(1),
norm_topk_prob,
routed_scaling_factor,
stream);
}
paddle::Tensor MoeExpertReduceFunc(
const paddle::Tensor& ffn_out,
const paddle::Tensor& top_k_weight,
const paddle::Tensor& permute_indices_per_token,
const paddle::Tensor& top_k_indices,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const bool norm_topk_prob,
const float routed_scaling_factor) {
const auto input_type = ffn_out.dtype();
auto place = ffn_out.place();
const int topk = top_k_indices.dims()[1];
const int num_rows = ffn_out.dims()[0] / topk;
const int hidden_size = ffn_out.dims()[1];
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeReduceKernel<paddle::DataType::BFLOAT16>(
ffn_out,
top_k_weight,
permute_indices_per_token,
top_k_indices,
ffn2_bias,
norm_topk_prob,
routed_scaling_factor,
num_rows,
hidden_size,
topk,
&output);
break;
case paddle::DataType::FLOAT16:
MoeReduceKernel<paddle::DataType::FLOAT16>(
ffn_out,
top_k_weight,
permute_indices_per_token,
top_k_indices,
ffn2_bias,
norm_topk_prob,
routed_scaling_factor,
num_rows,
hidden_size,
topk,
&output);
break;
default:
PD_THROW("Unsupported data type for MoeDispatchKernel");
}
return output;
}
std::vector<paddle::Tensor> MoeExpertReduce(
const paddle::Tensor& ffn_out,
const paddle::Tensor& top_k_weight,
const paddle::Tensor& permute_indices_per_token,
const paddle::Tensor& top_k_indices,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const bool norm_topk_prob,
const float routed_scaling_factor) {
return {MoeExpertReduceFunc(ffn_out,
top_k_weight,
permute_indices_per_token,
top_k_indices,
ffn2_bias,
norm_topk_prob,
routed_scaling_factor)};
}
std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
const std::vector<int64_t>& ffn_out_shape,
const std::vector<int64_t>& top_k_weight_shape,
const std::vector<int64_t>& permute_indices_per_token_shape,
const std::vector<int64_t>& top_k_indices_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape) {
return {ffn_out_shape};
}
std::vector<paddle::DataType> MoeExpertReduceInferDtype(
const paddle::DataType& ffn_out_dtype,
const paddle::DataType& top_k_weight_dtype,
const paddle::DataType& permute_indices_per_token_dtype,
const paddle::DataType& top_k_indices_dtype,
const paddle::optional<paddle::DataType>& ffn2_bias_dtype) {
return {ffn_out_dtype};
}
PD_BUILD_STATIC_OP(moe_expert_reduce)
.Inputs({"ffn_out",
"top_k_weight",
"permute_indices_per_token",
"top_k_indices",
paddle::Optional("ffn2_bias")})
.Outputs({"output"})
.Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"})
.SetKernelFn(PD_KERNEL(MoeExpertReduce))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertReduceInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertReduceInferDtype));
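
For reference, the combine that moe_expert_reduce performs can be sketched on the host: each token occupies moe_topk rows of ffn_out (one per selected expert), and the op gathers those rows, weights them by top_k_weight, optionally renormalizes the weights (norm_topk_prob), and applies routed_scaling_factor. The row layout below is a simplifying assumption that ignores the real permute_indices_per_token bookkeeping; it illustrates the math, not the kernel's memory layout.

import numpy as np

def moe_reduce_reference(ffn_out, top_k_weight, norm_topk_prob, routed_scaling_factor):
    # ffn_out: [num_rows * topk, hidden] (assumed row order), top_k_weight: [num_rows, topk]
    num_rows, topk = top_k_weight.shape
    out = np.zeros((num_rows, ffn_out.shape[1]), dtype=np.float32)
    for t in range(num_rows):
        w = top_k_weight[t].astype(np.float32)
        if norm_topk_prob:
            w = w / max(w.sum(), 1e-9)          # renormalize the top-k routing weights
        for j in range(topk):
            out[t] += (w[j] * routed_scaling_factor) * ffn_out[t * topk + j]
    return out

# toy shapes: 4 tokens, top-2 routing, hidden size 16
out = moe_reduce_reference(np.random.rand(8, 16).astype(np.float32),
                           np.random.rand(4, 2).astype(np.float32),
                           norm_topk_prob=True, routed_scaling_factor=1.0)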

View File

@@ -0,0 +1,337 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "iluvatar_context.h"
#define CUINFER_CHECK(func) \
do { \
cuinferStatus_t status = (func); \
if (status != CUINFER_STATUS_SUCCESS) { \
std::cerr << "Error in file " << __FILE__ << " on line " << __LINE__ << ": " \
<< cuinferGetErrorString(status) << std::endl; \
throw std::runtime_error("CUINFER_CHECK ERROR"); \
} \
} while (0)
template <paddle::DataType T>
void PagedAttnKernel(const paddle::Tensor& q,
const paddle::Tensor& k_cache,
const paddle::Tensor& v_cache,
const paddle::Tensor& block_table,
const paddle::Tensor& seq_lens,
const paddle::optional<paddle::Tensor> &alibi_slopes,
const paddle::optional<paddle::Tensor> &k,
const paddle::optional<paddle::Tensor> &v,
int num_kv_heads,
float scale,
int block_size,
int max_context_len,
bool causal,
int window_left,
int window_right,
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi,
paddle::Tensor& out) {
if (alibi_slopes) {
PADDLE_ENFORCE_EQ(alibi_slopes.get().dtype(),
paddle::DataType::FLOAT32,
common::errors::InvalidArgument(
"paged_attention expects alibi_slopes float tensor"));
PADDLE_ENFORCE_EQ(alibi_slopes.get().is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects alibi_slopes is contiguous"));
}
// check dtype and contiguous
const auto& dtype = q.dtype();
cudaDataType_t data_type;
if (dtype == paddle::DataType::FLOAT16) {
data_type = CUDA_R_16F;
} else if (dtype == paddle::DataType::BFLOAT16) {
data_type = CUDA_R_16BF;
} else {
common::errors::InvalidArgument("paged_attention support half and bfloat16 now");
}
PADDLE_ENFORCE_EQ(k_cache.dtype(),
dtype,
common::errors::InvalidArgument(
"k_cache dtype must be the same as query dtype"));
PADDLE_ENFORCE_EQ(k_cache.is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects k_cache is contiguous"));
PADDLE_ENFORCE_EQ(v_cache.dtype(),
dtype,
common::errors::InvalidArgument(
"v_cache dtype must be the same as query dtype"));
PADDLE_ENFORCE_EQ(v_cache.is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects v_cache is contiguous"));
PADDLE_ENFORCE_EQ(block_table.dtype(),
paddle::DataType::INT32,
common::errors::InvalidArgument(
"block_table dtype must be int32"));
PADDLE_ENFORCE_EQ(block_table.is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects block_table is contiguous"));
PADDLE_ENFORCE_EQ(seq_lens.dtype(),
paddle::DataType::INT32,
common::errors::InvalidArgument(
"seq_lens dtype must be int32"));
PADDLE_ENFORCE_EQ(seq_lens.is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects seq_lens is contiguous"));
// check dim and shape
// out: [num_seqs, num_heads, head_size]
// q: [num_seqs, num_heads, head_size]
// k_cache: [num_blocks, kv_num_heads, block_size, head_size]
// v_cache: [num_blocks, kv_num_heads, block_size, head_size]
// block_table: [num_seqs, max_num_blocks_per_seq]
// seq_lens: [num_seqs]
const auto& q_dims = q.dims();
PADDLE_ENFORCE_EQ(q_dims.size(),
3,
common::errors::InvalidArgument(
"paged_attn receive query dims is "
"[num_seqs, num_heads, head_size]"));
PADDLE_ENFORCE_EQ(out.dims().size(),
3,
common::errors::InvalidArgument(
"paged_attn receive out dims is "
"[num_seqs, num_heads, head_size]"));
PADDLE_ENFORCE_EQ(k_cache.dims(),
v_cache.dims(),
common::errors::InvalidArgument(
"paged_attn requires k_cache size is the "
"same as v_cache"));
const auto& kv_cache_dims = k_cache.dims();
PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
4,
common::errors::InvalidArgument(
"paged_attn receive kv cache dims is "
"[num_blocks, kv_num_heads, block_size, head_size]"));
const auto& block_table_dims = block_table.dims();
PADDLE_ENFORCE_EQ(block_table_dims.size(),
2,
common::errors::InvalidArgument(
"paged_attn receive block_table dims is "
"[num_seqs, max_num_blocks_per_seq]"));
const auto& seq_lens_dims = seq_lens.dims();
PADDLE_ENFORCE_EQ(seq_lens_dims.size(),
1,
common::errors::InvalidArgument(
"paged_attn receive seq_lens dims is [num_seqs]"));
int num_seqs = q_dims[0];
int num_heads = q_dims[1];
int head_size = q_dims[2];
int max_num_blocks_per_seq = block_table_dims[1];
int q_stride = q.strides()[0];
int num_blocks = kv_cache_dims[0];
PADDLE_ENFORCE_EQ(kv_cache_dims[1],
num_kv_heads,
common::errors::InvalidArgument(
"kv_cache_dims[1] must be equal to num_kv_head"));
PADDLE_ENFORCE_EQ(kv_cache_dims[2],
block_size,
common::errors::InvalidArgument(
"kv_cache_dims[2] must be equal to block_size"));
PADDLE_ENFORCE_EQ(kv_cache_dims[3],
head_size,
common::errors::InvalidArgument(
"kv_cache_dims[3] must be equal to head_size"));
PADDLE_ENFORCE_EQ(block_table_dims[0],
num_seqs,
common::errors::InvalidArgument(
"block_table_dims[0] must be equal to num_seqs"));
PADDLE_ENFORCE_EQ(seq_lens_dims[0],
num_seqs,
common::errors::InvalidArgument(
"seq_lens_dims[0] must be equal to num_seqs"));
int kv_block_stride = k_cache.strides()[0];
int kv_head_stride = k_cache.strides()[1];
const float *alibi_slopes_ptr = alibi_slopes ? alibi_slopes.get().data<float>() : nullptr;
const void *key_ptr = k ? k.get().data() : nullptr;
const void *value_ptr = v ? v.get().data() : nullptr;
size_t workspace_size = 0;
void* workspace_ptr = nullptr;
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(
num_seqs, num_heads, num_kv_heads, head_size, block_size, max_context_len, &workspace_size));
CUDA_CHECK(cudaMalloc((void**)&workspace_ptr, workspace_size));
CUDA_CHECK(cudaMemset(workspace_ptr, 0xff, workspace_size));
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(q.place()));
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle();
PageAttentionWithKVCacheArguments args{
static_cast<float>(scale), 1.0, 1.0, static_cast<float>(softcap), window_left, window_right,
causal, use_sqrt_alibi, enable_cuda_graph, false, alibi_slopes_ptr, key_ptr, value_ptr, workspace_ptr};
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
out.data(),
data_type,
q.data(),
data_type,
num_seqs,
num_heads,
num_kv_heads,
head_size,
q_stride,
kv_block_stride,
kv_head_stride,
k_cache.data(),
data_type,
v_cache.data(),
data_type,
block_size,
max_num_blocks_per_seq,
max_context_len,
block_table.data<int32_t>(),
seq_lens.data<int32_t>(),
args));
CUDA_CHECK(cudaFree(workspace_ptr));
}
std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
const paddle::Tensor& k_cache,
const paddle::Tensor& v_cache,
const paddle::Tensor& block_table,
const paddle::Tensor& seq_lens,
const paddle::optional<paddle::Tensor> &alibi_slopes,
const paddle::optional<paddle::Tensor> &k,
const paddle::optional<paddle::Tensor> &v,
int num_kv_heads,
float scale,
int block_size,
int max_context_len,
bool causal,
int window_left,
int window_right,
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi) {
const auto dtype = q.dtype();
auto out = paddle::empty_like(q, dtype);
switch (dtype) {
case paddle::DataType::BFLOAT16:
PagedAttnKernel<paddle::DataType::BFLOAT16>(q,
k_cache,
v_cache,
block_table,
seq_lens,
alibi_slopes,
k,
v,
num_kv_heads,
scale,
block_size,
max_context_len,
causal,
window_left,
window_right,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
out);
break;
case paddle::DataType::FLOAT16:
PagedAttnKernel<paddle::DataType::FLOAT16>(q,
k_cache,
v_cache,
block_table,
seq_lens,
alibi_slopes,
k,
v,
num_kv_heads,
scale,
block_size,
max_context_len,
causal,
window_left,
window_right,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
out);
break;
default:
PD_THROW("Unsupported data type for Paged attn");
}
return {out};
}
std::vector<std::vector<int64_t>> PagedAttnInferShape(const std::vector<int64_t>& q_shape,
const std::vector<int64_t>& k_cache_shape,
const std::vector<int64_t>& v_cache_shape,
const std::vector<int64_t>& block_table_shape,
const std::vector<int64_t>& seq_lens_shape,
const std::vector<int64_t>& alibi_slopes_shape,
const std::vector<int64_t>& k_shape,
const std::vector<int64_t>& v_shape) {
return {q_shape};
}
std::vector<paddle::DataType> PagedAttnInferDtype(const paddle::DataType& q_dtype,
const paddle::DataType& k_cache_dtype,
const paddle::DataType& v_cache_dtype,
const paddle::DataType& block_table_dtype,
const paddle::DataType& seq_lens_dtype,
const paddle::DataType& alibi_slopes_dtype,
const paddle::DataType& k_dtype,
const paddle::DataType& v_dtype) {
return {q_dtype};
}
PD_BUILD_STATIC_OP(paged_attn)
.Inputs({"q", "k_cache", "v_cache", "block_table", "seq_lens", paddle::Optional("alibi_slopes"), paddle::Optional("k"), paddle::Optional("v")})
.Outputs({"out"})
.Attrs({"num_kv_heads:int",
"scale:float",
"block_size:int",
"max_context_len:int",
"causal:bool",
"window_left:int",
"window_right:int",
"softcap:float",
"enable_cuda_graph:bool",
"use_sqrt_alibi:bool"})
.SetKernelFn(PD_KERNEL(PagedAttn))
.SetInferShapeFn(PD_INFER_SHAPE(PagedAttnInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(PagedAttnInferDtype));
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("paged_attn", &PagedAttn, "paged attn function");
}
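
As a sanity reference for what the cuInfer call above is expected to compute, decode-time paged attention is single-query attention over K/V gathered through the block table. The NumPy sketch below ignores ALiBi, softcap, sliding windows and the optional in-flight k/v, and the GQA head mapping is an assumption, so it describes the intended semantics rather than the kernel itself.

import numpy as np

def paged_attn_reference(q, k_cache, v_cache, block_table, seq_lens, scale):
    # q: [num_seqs, num_heads, head_size]; k/v_cache: [num_blocks, kv_heads, block_size, head_size]
    num_seqs, num_heads, _ = q.shape
    kv_heads, block_size = k_cache.shape[1], k_cache.shape[2]
    out = np.zeros_like(q)
    for s in range(num_seqs):
        ctx = int(seq_lens[s])
        blocks = block_table[s][: (ctx + block_size - 1) // block_size]
        k = np.concatenate([k_cache[b] for b in blocks], axis=1)[:, :ctx]   # [kv_heads, ctx, head_size]
        v = np.concatenate([v_cache[b] for b in blocks], axis=1)[:, :ctx]
        for h in range(num_heads):
            kv_h = h * kv_heads // num_heads        # assumed GQA head mapping
            logits = (q[s, h] @ k[kv_h].T) * scale  # [ctx]
            probs = np.exp(logits - logits.max())
            probs = probs / probs.sum()
            out[s, h] = probs @ v[kv_h]
    return out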

View File

@@ -0,0 +1,37 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iluvatar_context.h"
#include <memory>
#include <mutex>
namespace iluvatar {
IluvatarContext::~IluvatarContext() {
if (ixinfer_handle_) {
cuinferDestroy(ixinfer_handle_);
}
}
cuinferHandle_t IluvatarContext::getIxInferHandle() {
if (!ixinfer_handle_) {
cuinferCreate(&ixinfer_handle_);
}
return ixinfer_handle_;
}
IluvatarContext* getContextInstance() {
static IluvatarContext context;
return &context;
}
} // namespace iluvatar

View File

@@ -0,0 +1,33 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ixinfer.h>
namespace iluvatar {
class IluvatarContext {
public:
IluvatarContext() = default;
~IluvatarContext();
cuinferHandle_t getIxInferHandle();
private:
cuinferHandle_t ixinfer_handle_{nullptr};
};
IluvatarContext* getContextInstance();
} // namespace iluvatar

View File

@@ -470,6 +470,36 @@ elif paddle.is_compiled_with_cuda():
)
elif paddle.is_compiled_with_xpu():
assert False, "In XPU, we should use setup_ops.py in xpu_ops/src, not this."
elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
setup(
name="fastdeploy_ops",
ext_modules=CUDAExtension(
extra_compile_args={
"nvcc": [
"-DPADDLE_DEV",
"-DPADDLE_WITH_CUSTOM_DEVICE",
]
},
sources=[
"gpu_ops/get_padding_offset.cu",
"gpu_ops/set_value_by_flags.cu",
"gpu_ops/stop_generation_multi_stop_seqs.cu",
"gpu_ops/rebuild_padding.cu",
"gpu_ops/update_inputs.cu",
"gpu_ops/stop_generation_multi_ends.cu",
"gpu_ops/step.cu",
"gpu_ops/token_penalty_multi_scores.cu",
"iluvatar_ops/moe_dispatch.cu",
"iluvatar_ops/moe_reduce.cu",
"iluvatar_ops/paged_attn.cu",
"iluvatar_ops/runtime/iluvatar_context.cc",
],
include_dirs=["iluvatar_ops/runtime", "gpu_ops"],
extra_link_args=[
"-lcuinfer",
],
),
)
else:
use_bf16 = envs.FD_CPU_USE_BF16 == "True"

View File

@@ -42,7 +42,7 @@ class ModelConfig:
model_name_or_path: str,
config_json_file: str = "config.json",
dynamic_load_weight: bool = False,
load_strategy: str="meta",
load_strategy: str = "meta",
quantization: str = None,
download_dir: Optional[str] = None):
"""
@@ -590,7 +590,7 @@ class Config:
self.nnode = 1
else:
self.nnode = len(self.pod_ips)
assert self.splitwise_role in ["mixed", "prefill", "decode"]
# TODO
@@ -608,8 +608,9 @@ class Config:
== 1), "TP and EP cannot be enabled at the same time"
num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
if num_ranks > 8:
self.worker_num_per_node = 8
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
if num_ranks > self.max_chips_per_node:
self.worker_num_per_node = self.max_chips_per_node
nnode = ceil_div(num_ranks, self.worker_num_per_node)
assert nnode == self.nnode, \
f"nnode: {nnode}, but got {self.nnode}"
@@ -679,8 +680,8 @@ class Config:
is_port_available('0.0.0.0', self.engine_worker_queue_port)
), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
assert (
8 >= self.tensor_parallel_size > 0
), f"tensor_parallel_size: {self.tensor_parallel_size} should be between 1 and 8"
self.max_chips_per_node >= self.tensor_parallel_size > 0
), f"tensor_parallel_size: {self.tensor_parallel_size} should be between 1 and {self.max_chips_per_node}"
assert (self.nnode >= 1), f"nnode: {self.nnode} should no less than 1"
assert (
self.max_model_len >= 16
@@ -816,7 +817,7 @@ class Config:
def _check_master(self):
return self.is_master
def _str_to_list(self, attr_name, default_type):
if hasattr(self, attr_name):
val = getattr(self, attr_name)

View File

@@ -63,7 +63,8 @@ class SiluAndMul(nn.Layer):
"""
super().__init__()
if current_platform.is_cuda() or current_platform.is_xpu():
if current_platform.is_cuda() or current_platform.is_xpu(
) or current_platform.is_iluvatar():
self.forward = self.forward_cuda
else:
raise NotImplementedError

View File

@@ -19,9 +19,10 @@ from .flash_attn_backend import FlashAttentionBackend
from .mla_attention_backend import MLAAttentionBackend
from .native_paddle_backend import PaddleNativeAttnBackend
from .xpu_attn_backend import XPUAttentionBackend
from .iluvatar_attn_backend import IluvatarAttnBackend
__all__ = [
"AttentionBackend", "PaddleNativeAttnBackend",
"get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend",
"MLAAttentionBackend", "FlashAttentionBackend"
"MLAAttentionBackend", "FlashAttentionBackend", "IluvatarAttnBackend"
]

View File

@@ -67,7 +67,7 @@ class Attention(nn.Layer):
self.num_heads: int = fd_config.model_config.num_attention_heads // fd_config.parallel_config.tensor_parallel_degree
self.head_dim: int = fd_config.model_config.head_dim
self.kv_num_heads: int = \
fd_config.model_config.num_key_value_heads // fd_config.parallel_config.tensor_parallel_degree
max(1, fd_config.model_config.num_key_value_heads // fd_config.parallel_config.tensor_parallel_degree)
self.layer_id: int = layer_id
self.v_head_dim: int = v_head_dim if v_head_dim > 0 else self.head_dim
self.rope_type: str = rope_type

View File

@@ -0,0 +1,613 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import os
import paddle
from dataclasses import dataclass
from typing import Optional
from math import sqrt
from paddle.nn.functional.flash_attention import flash_attn_unpadded
from fastdeploy.model_executor.ops.iluvatar import paged_attention
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta
@dataclass
class IluvatarAttentionMetadata(AttentionMetadata):
"""
IluvatarAttentionMetadata
"""
# flash_attn metadata
cu_seqlens_q: Optional[paddle.Tensor] = None
cu_seqlens_k: Optional[paddle.Tensor] = None
fixed_seed_offset: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
attn_mask_start_row_indices: Optional[paddle.Tensor] = None
dropout: float = 0.0
causal: bool = True
return_softmax: bool = False
rng_name: str = ""
# paged_attn metadata
block_tables: Optional[paddle.Tensor] = None
seq_lens: Optional[paddle.Tensor] = None
num_kv_heads: int = 1
scale: float = 1.0
block_size: int = 1
max_context_len: int = 1
alibi_slopes: Optional[paddle.Tensor] = None
# causal: bool = True
window_left: int = -1
window_right: int = -1
softcap: float = 0.0
use_cuda_graph: bool = False
use_sqrt_alibi: bool = False
# qk[seq, h, d], cos/sin [seq, 1, d]
def apply_rope(qk, cos, sin):
rotate_half = paddle.reshape(
paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1),
paddle.shape(qk),
)
out = paddle.add(paddle.multiply(qk, cos),
paddle.multiply(rotate_half, sin))
return paddle.cast(out, qk.dtype)
class IluvatarAttnBackend(AttentionBackend):
"""
The Iluvatar attention backend: prefill tokens go through flash_attn_unpadded and
decode tokens go through the paged_attention custom op.
"""
def __init__(self, llm_config: FDConfig, kv_num_heads: int, num_heads: int,
head_dim: int):
super().__init__()
self.attention_metadata = IluvatarAttentionMetadata()
self.attention_metadata.block_size = llm_config.parallel_config.block_size
assert llm_config.parallel_config.enc_dec_block_num == 0, "Iluvatar does not support enc_dec_block_num yet"
self.attention_metadata.max_context_len = llm_config.parallel_config.max_model_len
self.attention_metadata.causal = getattr(llm_config.model_config,
"causal", True)
self.speculate_method = getattr(llm_config.parallel_config,
"speculate_method", None)
self.use_speculate = self.speculate_method is not None
self.attention_metadata.num_kv_heads = kv_num_heads
self.attention_metadata.dropout = llm_config.model_config.hidden_dropout_prob
self.num_heads = num_heads
self.head_dim = head_dim
# note: scale need to change if using MLA
self.attention_metadata.scale = 1.0 / sqrt(head_dim)
self.num_layers = llm_config.model_config.num_layers
self.record_block_table_metadata = {}
self.only_use_flash_attn = int(
os.getenv("FD_ILUVATAR_ONLY_USE_FLASH_ATTN", 0)) == 1
self.do_check_kv_cache = int(
os.getenv("FD_ILUVATAR_CHECK_KV_CACHE_CORRECTNESS", 0)) == 1
if not self.only_use_flash_attn:
assert self.attention_metadata.block_size == 16, "Iluvatar paged attn requires block_size must be 16."
if self.do_check_kv_cache:
self.record_batched_k = [{} for _ in range(self.num_layers)]
self.record_batched_v = [{} for _ in range(self.num_layers)]
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
self.attention_metadata.block_tables = forward_meta.block_tables
self.attention_metadata.attn_mask = forward_meta.attn_mask
self.attention_metadata.seq_lens = forward_meta.seq_lens_decoder
self.attention_metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
self.attention_metadata.cu_seqlens_k = forward_meta.cu_seqlens_k
def get_attntion_meta(self):
"""get_attntion_meta"""
return self.attention_metadata
def get_kv_cache_shape(
self,
max_num_blocks: int,
):
"""
Calculate kv cache shape
"""
return (max_num_blocks, self.attention_metadata.num_kv_heads,
self.attention_metadata.block_size, self.head_dim)
def get_new_kv(self,
k,
v,
k_cache_id: int,
v_cache_id: int,
forward_meta: ForwardMeta,
debug_paged_attn=False):
new_k = []
new_v = []
tensor_start = 0
for batch_idx in range(forward_meta.block_tables.shape[0]):
seq_len = forward_meta.seq_lens_this_time[batch_idx]
if seq_len == 0:
continue
tensor_end = tensor_start + seq_len
slice_k = k[tensor_start:tensor_end, :, :]
slice_v = v[tensor_start:tensor_end, :, :]
if seq_len > 1:
# prefill
new_k.append(slice_k)
new_v.append(slice_v)
else:
# decode
assert seq_len == 1
cur_block_tables = forward_meta.block_tables[batch_idx]
cur_used_block_tables = cur_block_tables[cur_block_tables !=
-1]
assert batch_idx in self.record_block_table_metadata, \
f"Key error: {batch_idx} vs {self.record_block_table_metadata}."
cur_block_table_metadata = self.record_block_table_metadata[
batch_idx]
record_last_block_id = cur_block_table_metadata["block_id"]
assert record_last_block_id != -1
for block_id in cur_used_block_tables:
if block_id == record_last_block_id:
cache_end = cur_block_table_metadata["cache_end"]
block_k_cache = forward_meta.caches[k_cache_id][
block_id, :, 0:cache_end, :]
block_v_cache = forward_meta.caches[v_cache_id][
block_id, :, 0:cache_end, :]
else:
block_k_cache = forward_meta.caches[k_cache_id][
block_id]
block_v_cache = forward_meta.caches[v_cache_id][
block_id]
# [num_kv_heads, block_size, head_dim] -> [block_size, num_kv_heads, head_dim]
new_k.append(
block_k_cache.transpose([1, 0, 2]).contiguous())
new_v.append(
block_v_cache.transpose([1, 0, 2]).contiguous())
if block_id == record_last_block_id:
break
# as line 301 shows, record_block_table_metadata is updated when executing the last layer,
# so slice_k and slice_v have already been written into block_k_cache and block_v_cache
if not (debug_paged_attn and
(k_cache_id / 2 == self.num_layers - 1)):
new_k.append(slice_k)
new_v.append(slice_v)
tensor_start = tensor_end
if len(new_k) == 1:
return new_k[0], new_v[0]
else:
new_k = paddle.concat(new_k, axis=0)
new_v = paddle.concat(new_v, axis=0)
return new_k, new_v
def update_kv_cache(self,
k,
v,
k_cache_id: int,
v_cache_id: int,
layer_id: int,
forward_meta: ForwardMeta,
specific_batch_ids=None,
debug_paged_attn=False):
# [num_tokens, num_kv_heads, head_dim] -> [num_kv_heads, num_tokens, head_dim]
trans_k = k.transpose([1, 0, 2]).contiguous()
trans_v = v.transpose([1, 0, 2]).contiguous()
tensor_start = 0
for batch_idx in range(forward_meta.block_tables.shape[0]):
if specific_batch_ids is not None and batch_idx not in specific_batch_ids:
continue
seq_len = forward_meta.seq_lens_this_time[batch_idx]
if seq_len == 0:
continue
tensor_end = tensor_start + seq_len
slice_trans_k = trans_k[:, tensor_start:tensor_end, :]
slice_trans_v = trans_v[:, tensor_start:tensor_end, :]
cur_block_tables = forward_meta.block_tables[batch_idx]
cur_used_block_tables = cur_block_tables[cur_block_tables != -1]
# prefill
if seq_len > 1:
cache_start = 0
cur_used_num_blocks = cur_used_block_tables.shape[0]
for i, block_id in enumerate(cur_used_block_tables):
# last block: seq_len - cache_start <= block_size
if i == cur_used_num_blocks - 1:
cache_end = seq_len - cache_start
assert cache_end <= self.attention_metadata.block_size
forward_meta.caches[k_cache_id][
block_id, :,
0:cache_end, :] = slice_trans_k[:, cache_start:
seq_len, :]
forward_meta.caches[v_cache_id][
block_id, :,
0:cache_end, :] = slice_trans_v[:, cache_start:
seq_len, :]
if layer_id == self.num_layers - 1:
self.record_block_table_metadata[batch_idx] = {
"block_id": block_id.item(),
"cache_end": cache_end
}
# non last block: seq_lens_this_time > block_size
else:
assert seq_len > self.attention_metadata.block_size
cache_end = cache_start + self.attention_metadata.block_size
forward_meta.caches[k_cache_id][
block_id] = slice_trans_k[:,
cache_start:cache_end, :]
forward_meta.caches[v_cache_id][
block_id] = slice_trans_v[:,
cache_start:cache_end, :]
cache_start += self.attention_metadata.block_size
else:
# decode
assert seq_len == 1
cur_last_block_id = cur_used_block_tables[-1].item()
assert cur_last_block_id != -1
assert batch_idx in self.record_block_table_metadata, \
f"Key error: {batch_idx} vs {self.record_block_table_metadata}."
cur_block_table_metadata = self.record_block_table_metadata[
batch_idx]
record_last_block_id = cur_block_table_metadata["block_id"]
if cur_last_block_id == record_last_block_id:
# not alloc new block in decode stage
cache_start = cur_block_table_metadata["cache_end"]
else:
# alloc new block in decode stage
cache_start = 0
cache_end = cache_start + 1
assert cache_end <= self.attention_metadata.block_size
# the paged attn API updates the kv cache in place
if not debug_paged_attn:
forward_meta.caches[k_cache_id][
cur_last_block_id, :,
cache_start:cache_end, :] = slice_trans_k
forward_meta.caches[v_cache_id][
cur_last_block_id, :,
cache_start:cache_end, :] = slice_trans_v
# update record_block_table_metadata
if layer_id == self.num_layers - 1:
self.record_block_table_metadata[batch_idx][
"block_id"] = cur_last_block_id
self.record_block_table_metadata[batch_idx][
"cache_end"] = cache_end
tensor_start = tensor_end
def _check_new_kv_correctness(self, k, v, new_k, new_v, layer_id: int,
forward_meta: ForwardMeta):
tensor_start = 0
for batch_idx, seq_lens_this_time in enumerate(
forward_meta.seq_lens_this_time):
if seq_lens_this_time == 0:
continue
# note: in streaming inference mode the second request will also reuse batch_idx 0
# instead of 1, so seq_lens_this_time > 1 on an already-seen batch_idx indicates
# that a new request has arrived.
if seq_lens_this_time > 1 and batch_idx in self.record_batched_k[
layer_id]:
print(
f"clear self.record_batched_batched_k: "
f"layer_id={layer_id}, batch_id={batch_idx}, "
f"record_lens={len(self.record_batched_k[layer_id][batch_idx])}"
)
self.record_batched_k[layer_id][batch_idx].clear()
self.record_batched_v[layer_id][batch_idx].clear()
tensor_end = tensor_start + seq_lens_this_time
slice_k = k[tensor_start:tensor_end, :, :]
slice_v = v[tensor_start:tensor_end, :, :]
if batch_idx not in self.record_batched_k[layer_id]:
self.record_batched_k[layer_id][batch_idx] = []
self.record_batched_v[layer_id][batch_idx] = []
self.record_batched_k[layer_id][batch_idx].append(slice_k)
self.record_batched_v[layer_id][batch_idx].append(slice_v)
tensor_start = tensor_end
ref_k, ref_v = [], []
for batch_idx, seq_lens_this_time in enumerate(
forward_meta.seq_lens_this_time):
if seq_lens_this_time == 0:
continue
bached_k_list = self.record_batched_k[layer_id][batch_idx]
bached_v_list = self.record_batched_v[layer_id][batch_idx]
ref_k.extend(bached_k_list)
ref_v.extend(bached_v_list)
ref_k = paddle.concat(ref_k, axis=0)
ref_v = paddle.concat(ref_v, axis=0)
print(
f"_check_new_kv_correctness: layer_id={layer_id}, "
f"k.shape={k.shape}, v.shape={v.shape}, "
f"ref_k.shape={ref_k.shape}, ref_v.shape={ref_v.shape}, "
f"new_k.shape={new_k.shape}, new_v.shape={new_v.shape}, "
f"len(self.record_batched_k[layer_id])={len(self.record_batched_k[layer_id])}, "
f"len(self.record_batched_k[layer_id][0])={len(self.record_batched_k[layer_id][0])}, "
f"forward_meta.seq_lens_this_time={forward_meta.seq_lens_this_time}"
f"ref_k[-2:, 0:2, 0:2]={ref_k[-2:, 0:2, 0:2]}, "
f"ref_v[-2:, 0:2, 0:2]={ref_v[-2:, 0:2, 0:2]}, "
f"new_k[-2:, 0:2, 0:2]={new_k[-2:, 0:2, 0:2]}, "
f"new_v[-2:, 0:2, 0:2]={new_v[-2:, 0:2, 0:2]}")
assert paddle.allclose(
ref_k.to("cpu").to(paddle.float32),
new_k.to("cpu").to(paddle.float32))
assert paddle.allclose(
ref_v.to("cpu").to(paddle.float32),
new_v.to("cpu").to(paddle.float32))
def get_splited_qkv(self, qkv: paddle.Tensor, forward_meta: ForwardMeta):
q_end = self.num_heads * self.head_dim
k_end = q_end + self.attention_metadata.num_kv_heads * self.head_dim
v_end = k_end + self.attention_metadata.num_kv_heads * self.head_dim
assert v_end == qkv.shape[
-1], f"Shape mistach: {v_end} vs {qkv.shape[-1]}"
assert qkv.shape[0] == forward_meta.cu_seqlens_q[-1]
q = qkv[..., 0:q_end]
k = qkv[..., q_end:k_end]
v = qkv[..., k_end:v_end]
q = q.view([-1, self.num_heads, self.head_dim]).contiguous()
k = k.view([-1, self.attention_metadata.num_kv_heads,
self.head_dim]).contiguous()
v = v.view([-1, self.attention_metadata.num_kv_heads,
self.head_dim]).contiguous()
# forward_meta.seq_lens_this_time [max_batch,]
for batch_idx in range(forward_meta.seq_lens_this_time.shape[0]):
seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
if seq_len_i == 0:
continue
cached_kv_len = forward_meta.seq_lens_decoder[batch_idx][0]
cu_seq_start_q = forward_meta.cu_seqlens_q[batch_idx]
cu_seq_end_q = forward_meta.cu_seqlens_q[batch_idx + 1]
# forward_meta.rotary_embs is [2, 1, S, 1, D]
if forward_meta.rotary_embs is not None:
cos = forward_meta.rotary_embs[0, 0,
cached_kv_len:cached_kv_len +
seq_len_i, :, :]
sin = forward_meta.rotary_embs[1, 0,
cached_kv_len:cached_kv_len +
seq_len_i, :, :]
q[cu_seq_start_q:cu_seq_end_q] = apply_rope(
q[cu_seq_start_q:cu_seq_end_q], cos, sin)
k[cu_seq_start_q:cu_seq_end_q] = apply_rope(
k[cu_seq_start_q:cu_seq_end_q], cos, sin)
return q, k, v
def get_splited_info_by_stage(self, q, k, v, forward_meta: ForwardMeta):
prefill_info_dict = {"q": [], "k": [], "v": [], "batch_ids": []}
decode_info_dict = {"q": [], "k": [], "v": [], "batch_ids": []}
tensor_start = 0
for batch_idx, seq_lens_this_time in enumerate(
forward_meta.seq_lens_this_time):
if seq_lens_this_time == 0:
continue
tensor_end = tensor_start + seq_lens_this_time
slice_q = q[tensor_start:tensor_end, :, :]
slice_k = k[tensor_start:tensor_end, :, :]
slice_v = v[tensor_start:tensor_end, :, :]
if seq_lens_this_time > 1:
prefill_info_dict["q"].append(slice_q)
prefill_info_dict["k"].append(slice_k)
prefill_info_dict["v"].append(slice_v)
prefill_info_dict["batch_ids"].append(batch_idx)
else:
assert seq_lens_this_time == 1
decode_info_dict["q"].append(slice_q)
decode_info_dict["k"].append(slice_k)
decode_info_dict["v"].append(slice_v)
decode_info_dict["batch_ids"].append(batch_idx)
tensor_start = tensor_end
if len(prefill_info_dict["batch_ids"]) > 0:
prefill_info_dict["q"] = paddle.concat(prefill_info_dict["q"],
axis=0)
prefill_info_dict["k"] = paddle.concat(prefill_info_dict["k"],
axis=0)
prefill_info_dict["v"] = paddle.concat(prefill_info_dict["v"],
axis=0)
cu_seq_ids = list(
map(lambda x: x + 1, prefill_info_dict["batch_ids"]))
prefill_info_dict["cu_seq_ids"] = [0, *cu_seq_ids]
if len(decode_info_dict["batch_ids"]) > 0:
decode_info_dict["q"] = paddle.concat(decode_info_dict["q"],
axis=0)
decode_info_dict["k"] = paddle.concat(decode_info_dict["k"],
axis=0)
decode_info_dict["v"] = paddle.concat(decode_info_dict["v"],
axis=0)
return prefill_info_dict, decode_info_dict
def merge_output(self, prefill_out, decode_out, forward_meta: ForwardMeta):
assert not (prefill_out is None and decode_out
is None), "prefill and decode output cannot both be None"
if prefill_out is None:
return decode_out
elif decode_out is None:
return prefill_out
else:
merged_output = []
prefill_tensor_start = 0
decode_tensor_start = 0
for seq_lens_this_time in forward_meta.seq_lens_this_time:
if seq_lens_this_time == 0:
continue
if seq_lens_this_time > 1:
tensor_end = prefill_tensor_start + seq_lens_this_time
merged_output.append(
prefill_out[prefill_tensor_start:tensor_end, :, :])
prefill_tensor_start = tensor_end
else:
assert seq_lens_this_time == 1
tensor_end = decode_tensor_start + seq_lens_this_time
merged_output.append(
decode_out[decode_tensor_start:tensor_end, :, :])
decode_tensor_start = tensor_end
assert prefill_tensor_start == prefill_out.shape[0], \
f"prefill merged unfinished: {prefill_tensor_start} vs {prefill_out.shape[0]}"
assert decode_tensor_start == decode_out.shape[0], \
f"decode merged unfinished: {decode_tensor_start} vs {decode_out.shape[0]}"
merged_output = paddle.concat(merged_output, axis=0)
return merged_output
def forward_mixed(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
):
"""
forward_mixed
"""
assert not self.use_speculate, "IluvatarAttnBackend does not support speculative decoding yet"
layer_id = layer.layer_id
k_cache_id = layer_id * 2
v_cache_id = k_cache_id + 1
assert qkv is not None
q_dim = qkv.dim()
q, k, v = self.get_splited_qkv(qkv, forward_meta)
if self.only_use_flash_attn:
new_k, new_v = self.get_new_kv(k, v, k_cache_id, v_cache_id,
forward_meta)
if self.do_check_kv_cache:
self._check_new_kv_correctness(k, v, new_k, new_v, layer_id,
forward_meta)
out = flash_attn_unpadded(
q,
new_k,
new_v,
cu_seqlens_q=self.attention_metadata.cu_seqlens_q,
cu_seqlens_k=self.attention_metadata.cu_seqlens_k,
max_seqlen_q=self.attention_metadata.max_context_len,
max_seqlen_k=self.attention_metadata.max_context_len,
scale=self.attention_metadata.scale,
dropout=self.attention_metadata.dropout,
causal=self.attention_metadata.causal,
return_softmax=self.attention_metadata.return_softmax)[0]
self.update_kv_cache(k, v, k_cache_id, v_cache_id, layer_id,
forward_meta)
else:
prefill_info_dict, decode_info_dict = self.get_splited_info_by_stage(
q, k, v, forward_meta)
prefill_out, decode_out = None, None
if len(prefill_info_dict["batch_ids"]) > 0:
prefill_out = flash_attn_unpadded(
prefill_info_dict["q"],
prefill_info_dict["k"],
prefill_info_dict["v"],
cu_seqlens_q=forward_meta.cu_seqlens_q[
prefill_info_dict["cu_seq_ids"]],
cu_seqlens_k=forward_meta.cu_seqlens_k[
prefill_info_dict["cu_seq_ids"]],
max_seqlen_q=self.attention_metadata.max_context_len,
max_seqlen_k=self.attention_metadata.max_context_len,
scale=self.attention_metadata.scale,
dropout=self.attention_metadata.dropout,
causal=self.attention_metadata.causal,
return_softmax=self.attention_metadata.return_softmax)[0]
self.update_kv_cache(
prefill_info_dict["k"],
prefill_info_dict["v"],
k_cache_id,
v_cache_id,
layer_id,
forward_meta,
specific_batch_ids=prefill_info_dict['batch_ids'])
if len(decode_info_dict["batch_ids"]) > 0:
k_cache = forward_meta.caches[k_cache_id]
v_cache = forward_meta.caches[v_cache_id]
decode_out = paged_attention(
decode_info_dict["q"],
k_cache,
v_cache,
block_tables=forward_meta.block_tables[
decode_info_dict["batch_ids"], :],
seq_lens=forward_meta.seq_lens_decoder[
decode_info_dict["batch_ids"], 0] + 1,
num_kv_heads=self.attention_metadata.num_kv_heads,
scale=self.attention_metadata.scale,
block_size=self.attention_metadata.block_size,
max_context_len=self.attention_metadata.max_context_len,
alibi_slopes=self.attention_metadata.alibi_slopes,
causal=self.attention_metadata.causal,
window_left=self.attention_metadata.window_left,
window_right=self.attention_metadata.window_right,
softcap=self.attention_metadata.softcap,
use_cuda_graph=self.attention_metadata.use_cuda_graph,
use_sqrt_alibi=self.attention_metadata.use_sqrt_alibi,
k=decode_info_dict["k"],
v=decode_info_dict["v"])
if self.do_check_kv_cache:
self.update_kv_cache(
decode_info_dict['k'],
decode_info_dict['v'],
k_cache_id,
v_cache_id,
layer_id,
forward_meta,
specific_batch_ids=decode_info_dict['batch_ids'],
debug_paged_attn=True)
if self.do_check_kv_cache:
new_k, new_v = self.get_new_kv(k,
v,
k_cache_id,
v_cache_id,
forward_meta,
debug_paged_attn=True)
self._check_new_kv_correctness(k, v, new_k, new_v, layer_id,
forward_meta)
out = self.merge_output(prefill_out, decode_out, forward_meta)
if q_dim == 2:
out = out.view([-1, self.num_heads * self.head_dim])
return out
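
A worked example of how forward_mixed routes a mixed batch (hypothetical values): with seq_lens_this_time = [5, 1, 0, 3], batches 0 and 3 are prefill (seq_len > 1) and go through flash_attn_unpadded, batch 1 is decode (seq_len == 1) and goes through paged_attention, batch 2 is skipped, and merge_output re-interleaves the two partial outputs back into the original token order.

seq_lens_this_time = [5, 1, 0, 3]                                         # hypothetical batch
prefill_ids = [i for i, s in enumerate(seq_lens_this_time) if s > 1]      # [0, 3]
decode_ids = [i for i, s in enumerate(seq_lens_this_time) if s == 1]      # [1]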

View File

@@ -57,7 +57,8 @@ class LinearBase(nn.Layer):
NotImplementedError: Raised if the current platform is not a CUDA platform.
"""
super().__init__()
if current_platform.is_cuda() or current_platform.is_xpu():
if current_platform.is_cuda() or current_platform.is_xpu(
) or current_platform.is_iluvatar():
self.forward = self.forward_cuda
else:
raise NotImplementedError
@@ -411,9 +412,14 @@ class QKVParallelLinear(ColumnParallelLinear):
self.head_dim = fd_config.model_config.head_dim
self.nranks = fd_config.parallel_config.tensor_parallel_degree
self.num_heads_per_rank = divide(self.num_heads, self.nranks)
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
if self.kv_num_heads < self.nranks and self.nranks % self.kv_num_heads == 0:
self.kv_num_heads_per_rank = 1
output_size = (self.num_heads + 2 * self.nranks) * self.head_dim
else:
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
output_size = (self.num_heads +
2 * self.kv_num_heads) * self.head_dim
input_size = self.hidden_size
output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
super().__init__(fd_config=fd_config,
prefix=prefix,
input_size=input_size,

View File

@@ -30,6 +30,8 @@ from .fused_moe_backend_base import MoEMethodBase
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (moe_expert_dispatch,
moe_expert_reduce, noaux_tc)
elif current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import moe_expert_dispatch, moe_expert_reduce
# used for deepseek_v3
@@ -89,6 +91,23 @@ class CutlassMoEMethod(MoEMethodBase):
"""
Paddle Cutlass compute Fused MoE.
"""
if current_platform.is_iluvatar():
return fastdeploy.model_executor.ops.iluvatar.moe_expert_ffn(
permute_input,
token_nums_per_expert,
layer.moe_ffn1_weight,
layer.moe_ffn2_weight,
None,
(layer.moe_ffn1_weight_scale if hasattr(
layer, "moe_ffn1_weight_scale") else None),
(layer.moe_ffn2_weight_scale if hasattr(
layer, "moe_ffn2_weight_scale") else None),
(layer.moe_ffn2_in_scale
if hasattr(layer, "moe_ffn2_in_scale") else None),
expert_idx_per_token,
self.moe_quant_type,
used_in_ep_low_latency,
)
return fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
permute_input,
token_nums_per_expert,

View File

@@ -20,6 +20,7 @@ import numpy as np
import paddle
from paddle import nn
from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm
from fastdeploy.platforms import current_platform
from fastdeploy.config import FDConfig
@@ -265,6 +266,18 @@ class LayerNorm(nn.Layer):
The `residual_output` is the result of applying the normalization and possibly other
operations (like linear transformation) on the `residual_input`.
"""
if current_platform.is_iluvatar():
if self.ln_weight is None and self.ln_bias is None:
out = x
if self.linear_bias is not None:
out += self.linear_bias
if residual_input is not None:
out += residual_input
return out, out
else:
return out
else:
raise NotImplementedError("Iluvatar does not support this yet!")
norm_out = self.norm_func(
x,

View File

@@ -48,7 +48,8 @@ class ErnieRotaryEmbedding:
freqs = paddle.einsum("ij,k->ijk",
partial_rotary_position_ids.cast("float32"),
inv_freq)
if paddle.is_compiled_with_xpu():
if paddle.is_compiled_with_xpu(
) or paddle.is_compiled_with_custom_device("iluvatar_gpu"):
# shape: [B, S, D]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim),
dtype="float32")

View File

@@ -64,6 +64,21 @@ def apply_penalty_multi_scores(
min_dec_lens,
eos_token_ids,
)
elif current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import \
get_token_penalty_multi_scores
logits = get_token_penalty_multi_scores(
pre_token_ids,
logits,
repetition_penalties,
frequency_penalties,
presence_penalties,
temperature,
bad_words_token_ids,
step_idx,
min_dec_lens,
eos_token_ids,
)
else:
raise NotImplementedError()

View File

@@ -170,7 +170,8 @@ class Sampler(nn.Layer):
"""
"""
super().__init__()
if current_platform.is_cuda() or current_platform.is_xpu():
if current_platform.is_cuda() or current_platform.is_xpu(
) or current_platform.is_iluvatar():
self.forward = self.forward_cuda
else:
raise NotImplementedError()

View File

@@ -33,6 +33,8 @@ if current_platform.is_cuda() and current_platform.available():
"Verify environment consistency between compilation and FastDeploy installation. "
"And ensure the Paddle version supports FastDeploy's custom operators"
)
if current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import get_padding_offset
import re
from fastdeploy import envs
@@ -377,4 +379,4 @@ def create_and_set_parameter(layer: nn.Layer, name: str,
dtype=tensor.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
))
getattr(layer, name).set_value(tensor)
getattr(layer, name).set_value(tensor)

View File

@@ -213,8 +213,14 @@ def gqa_qkv_split_func(
return np.split(tensor, degree, axis=0)
q_list = split_tensor(q, tensor_parallel_degree)
k_list = split_tensor(k, tensor_parallel_degree)
v_list = split_tensor(v, tensor_parallel_degree)
repeat_kv = num_key_value_heads < tensor_parallel_degree and tensor_parallel_degree % num_key_value_heads == 0
repeat_num = tensor_parallel_degree // num_key_value_heads if repeat_kv else 1
if repeat_kv:
k_list = split_tensor(k, num_key_value_heads)
v_list = split_tensor(v, num_key_value_heads)
else:
k_list = split_tensor(k, tensor_parallel_degree)
v_list = split_tensor(v, tensor_parallel_degree)
if tensor_parallel_rank is None:
res = []
@@ -236,8 +242,8 @@ def gqa_qkv_split_func(
return paddle.concat(
[
q_list[tensor_parallel_rank],
k_list[tensor_parallel_rank],
v_list[tensor_parallel_rank],
k_list[tensor_parallel_rank // repeat_num],
v_list[tensor_parallel_rank // repeat_num],
],
axis=-1,
)
@@ -245,8 +251,8 @@ def gqa_qkv_split_func(
return paddle.concat(
[
q_list[tensor_parallel_rank],
k_list[tensor_parallel_rank],
v_list[tensor_parallel_rank],
k_list[tensor_parallel_rank // repeat_num],
v_list[tensor_parallel_rank // repeat_num],
],
axis=0,
)
@@ -255,8 +261,8 @@ def gqa_qkv_split_func(
return np.concatenate(
[
q_list[tensor_parallel_rank],
k_list[tensor_parallel_rank],
v_list[tensor_parallel_rank],
k_list[tensor_parallel_rank // repeat_num],
v_list[tensor_parallel_rank // repeat_num],
],
axis=-1,
)
@@ -264,8 +270,8 @@ def gqa_qkv_split_func(
return np.concatenate(
[
q_list[tensor_parallel_rank],
k_list[tensor_parallel_rank],
v_list[tensor_parallel_rank],
k_list[tensor_parallel_rank // repeat_num],
v_list[tensor_parallel_rank // repeat_num],
],
axis=0,
)
@@ -281,8 +287,8 @@ def gqa_qkv_merge_func(num_attention_heads, num_key_value_heads, head_dim):
def fn(weight_list, is_column=True):
"""fn"""
tensor_parallel_degree = len(weight_list)
num_attention_heads = num_attention_heads // tensor_parallel_degree
num_key_value_heads = num_key_value_heads // tensor_parallel_degree
num_attention_heads = num_attention_heads // tensor_parallel_degree # noqa: F823
num_key_value_heads = num_key_value_heads // tensor_parallel_degree # noqa: F823
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
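
A worked example for the repeat_kv path above (hypothetical sizes): with num_key_value_heads = 2 and tensor_parallel_degree = 8, repeat_num = 4, so k and v are split into only two chunks and each chunk is reused by four consecutive ranks via tensor_parallel_rank // repeat_num, while q is still split eight ways.

num_key_value_heads, tensor_parallel_degree = 2, 8     # hypothetical
repeat_kv = (num_key_value_heads < tensor_parallel_degree
             and tensor_parallel_degree % num_key_value_heads == 0)
repeat_num = tensor_parallel_degree // num_key_value_heads if repeat_kv else 1
kv_chunk_per_rank = [rank // repeat_num for rank in range(tensor_parallel_degree)]
print(kv_chunk_per_rank)                               # [0, 0, 0, 0, 1, 1, 1, 1]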

View File

@@ -196,6 +196,9 @@ def convert_ndarray_dtype(np_array: np.ndarray,
np.ndarray: converted numpy ndarray instance
"""
source_dtype = convert_dtype(np_array.dtype)
if source_dtype == "uint16" and target_dtype == "bfloat16" and paddle.is_compiled_with_custom_device(
"iluvatar_gpu"):
return np_array.view(dtype=target_dtype)
if source_dtype == "uint16" or target_dtype == "bfloat16":
if paddle.is_compiled_with_xpu():
# xpu not support bf16.

View File

@@ -16,5 +16,6 @@ from . import gpu
from . import cpu
from . import xpu
from . import npu
from . import iluvatar
__all__ = ["gpu", "cpu", "xpu", "npu"]
__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar"]

View File

@@ -0,0 +1,24 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""fastdeploy gpu ops"""
from fastdeploy.import_ops import import_custom_ops
PACKAGE = "fastdeploy.model_executor.ops.iluvatar"
import_custom_ops(PACKAGE, "..base.fastdeploy_base_ops", globals())
import_custom_ops(PACKAGE, ".fastdeploy_ops", globals())
from .moe_ops import iluvatar_moe_expert_ffn as moe_expert_ffn # noqa: E402, F401
from .paged_attention import paged_attention # noqa: E402, F401

View File

@@ -0,0 +1,101 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional
import paddle
from paddle.nn.quant import weight_only_linear
from paddle.incubate.nn.functional import swiglu
def group_gemm(
input: paddle.Tensor,
tokens_expert_prefix_sum: paddle.Tensor,
weight: paddle.Tensor,
scale: paddle.Tensor,
output: paddle.Tensor,
):
assert (input.dim() == 2 and tokens_expert_prefix_sum.dim() == 1
and weight.dim() == 3 and scale.dim() == 2 and output.dim() == 2)
num_tokens = input.shape[0]
dim_in = input.shape[1]
dim_out = weight.shape[1]
num_experts = weight.shape[0]
# check shape
assert tokens_expert_prefix_sum.shape == [
num_experts,
]
assert weight.shape == [num_experts, dim_out, dim_in]
assert scale.shape == [num_experts, dim_out]
assert output.shape == [num_tokens, dim_out]
# check dtype
assert input.dtype in (paddle.float16, paddle.bfloat16)
assert scale.dtype == input.dtype and output.dtype == input.dtype
assert tokens_expert_prefix_sum.dtype == paddle.int64
assert weight.dtype == paddle.int8
# check others
assert tokens_expert_prefix_sum.place.is_cpu_place()
assert tokens_expert_prefix_sum[-1] == num_tokens
for i in range(num_experts):
expert_start = 0 if i == 0 else tokens_expert_prefix_sum[i - 1]
expert_end = tokens_expert_prefix_sum[i]
if expert_start == expert_end:
continue
input_i = input[expert_start:expert_end]
weight_i = weight[i]
scale_i = scale[i]
# avoid d2d?
output[expert_start:expert_end] = weight_only_linear(
input_i,
weight_i,
weight_scale=scale_i,
weight_dtype="int8",
group_size=-1)
def iluvatar_moe_expert_ffn(
permute_input: paddle.Tensor,
tokens_expert_prefix_sum: paddle.Tensor,
ffn1_weight: paddle.Tensor,
ffn2_weight: paddle.Tensor,
ffn1_bias: Optional[paddle.Tensor],
ffn1_scale: Optional[paddle.Tensor],
ffn2_scale: Optional[paddle.Tensor],
ffn2_in_scale: Optional[paddle.Tensor],
expert_idx_per_token: Optional[paddle.Tensor],
quant_method: str,
used_in_ep_low_latency: bool,
):
assert ffn1_bias is None
assert ffn1_scale is not None
assert ffn2_scale is not None
assert ffn2_in_scale is None
assert expert_idx_per_token is None
assert quant_method in ("weight_only_int8", )
assert not used_in_ep_low_latency
tokens_expert_prefix_sum_cpu = tokens_expert_prefix_sum.to("cpu")
ffn1_output = paddle.empty([permute_input.shape[0], ffn1_weight.shape[1]],
dtype=permute_input.dtype)
group_gemm(permute_input, tokens_expert_prefix_sum_cpu, ffn1_weight,
ffn1_scale, ffn1_output)
act_out = swiglu(ffn1_output)
output = paddle.empty([act_out.shape[0], ffn2_weight.shape[1]],
dtype=act_out.dtype)
group_gemm(act_out, tokens_expert_prefix_sum_cpu, ffn2_weight, ffn2_scale,
output)
return output
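
The group_gemm helper above emulates a grouped GEMM by looping over experts and slicing token rows with the CPU-side prefix sums. A dense-float sketch of the same row bookkeeping (ignoring the int8 weight-only quantization and scales) for reference:

import numpy as np

def group_gemm_reference(x, tokens_expert_prefix_sum, weight):
    # x: [num_tokens, dim_in]; weight: [num_experts, dim_out, dim_in] (plain float here)
    out = np.empty((x.shape[0], weight.shape[1]), dtype=x.dtype)
    start = 0
    for i, end in enumerate(tokens_expert_prefix_sum):
        if end > start:                       # skip experts that received no tokens
            out[start:end] = x[start:end] @ weight[i].T
        start = end
    return out

x = np.random.rand(6, 4).astype(np.float32)
w = np.random.rand(3, 5, 4).astype(np.float32)          # 3 experts, dim_out=5, dim_in=4
prefix = np.array([2, 2, 6])                            # expert 1 receives no tokens
y = group_gemm_reference(x, prefix, w)                  # -> [6, 5]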

View File

@@ -0,0 +1,46 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
try:
from fastdeploy.model_executor.ops.iluvatar import paged_attn
except ImportError:
paged_attn = None
def paged_attention(q: paddle.Tensor,
k_cache: paddle.Tensor,
v_cache: paddle.Tensor,
block_tables: paddle.Tensor,
seq_lens: paddle.Tensor,
num_kv_heads: int,
scale: float,
block_size: int,
max_context_len: int,
alibi_slopes: paddle.Tensor = None,
causal: bool = True,
window_left: int = -1,
window_right: int = -1,
softcap: float = 0.0,
use_cuda_graph: bool = False,
use_sqrt_alibi: bool = False,
k: paddle.Tensor = None,
v: paddle.Tensor = None):
output = paged_attn(q, k_cache, v_cache, block_tables, seq_lens,
alibi_slopes, k, v, num_kv_heads, scale, block_size,
max_context_len, causal, window_left, window_right,
softcap, use_cuda_graph, use_sqrt_alibi)
return output[0] if isinstance(output, list) else output
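
A minimal call sketch for this wrapper, assuming an iluvatar_gpu build of the ops is installed (shapes follow the checks in paged_attn.cu: q is [num_seqs, num_heads, head_size], the caches are [num_blocks, kv_num_heads, block_size, head_size], block_tables and seq_lens are int32):

import paddle
from fastdeploy.model_executor.ops.iluvatar import paged_attention  # iluvatar build only

q = paddle.randn([2, 8, 64]).astype("float16")            # [num_seqs, num_heads, head_size]
k_cache = paddle.randn([16, 8, 16, 64]).astype("float16")
v_cache = paddle.randn([16, 8, 16, 64]).astype("float16")
block_tables = paddle.zeros([2, 4], dtype="int32")        # toy mapping: everything in block 0
seq_lens = paddle.to_tensor([5, 1], dtype="int32")
out = paged_attention(q, k_cache, v_cache, block_tables, seq_lens,
                      num_kv_heads=8, scale=64 ** -0.5,
                      block_size=16, max_context_len=64)   # -> [2, 8, 64]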

View File

@@ -19,18 +19,25 @@ import paddle
from fastdeploy import envs
from fastdeploy.engine.config import SpeculativeConfig
from fastdeploy.model_executor.ops.gpu import (
get_padding_offset, save_output, set_stop_value_multi_ends,
speculate_clear_accept_nums, speculate_get_output_padding_offset,
speculate_get_padding_offset, speculate_get_seq_lens_output,
speculate_save_output, speculate_set_value_by_flags_and_idx,
speculate_step_paddle, speculate_step_system_cache, speculate_update_v3,
step_paddle, step_system_cache, update_inputs, step_reschedule)
from fastdeploy.platforms import current_platform
if current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import (
get_padding_offset, save_output, set_stop_value_multi_ends,
step_paddle, update_inputs)
else:
from fastdeploy.model_executor.ops.gpu import (
get_padding_offset, save_output, set_stop_value_multi_ends,
speculate_clear_accept_nums, speculate_get_output_padding_offset,
speculate_get_padding_offset, speculate_get_seq_lens_output,
speculate_save_output, speculate_set_value_by_flags_and_idx,
speculate_step_paddle, speculate_step_system_cache,
speculate_update_v3, step_paddle, step_system_cache, update_inputs,
step_reschedule)
from fastdeploy.worker.output import ModelOutputData
DISABLE_RECOVER = (envs.FD_DISABLED_RECOVER == "1")
def pre_process(
max_len: int,
input_ids: paddle.Tensor,
@@ -151,6 +158,7 @@ def post_process_normal(sampled_token_ids: paddle.Tensor,
save_each_rank, # save_each_rank
)
def post_process_specualate(model_output, skip_save_output: bool = False):
""""""
speculate_update_v3(
@@ -217,7 +225,6 @@ def step_cuda(
TODO(gongshaotian): normalization name
"""
if speculative_config.method is not None:
if enable_prefix_caching:
speculate_step_system_cache(
@@ -373,6 +380,17 @@ def rebuild_padding(tmp_out: paddle.Tensor,
output_padding_offset,
max_input_length,
)
elif current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import rebuild_padding
hidden_states = rebuild_padding(
tmp_out,
cum_offsets,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length,
)
elif current_platform.is_cpu():
from fastdeploy.model_executor.ops.cpu import rebuild_padding_cpu
hidden_states = rebuild_padding_cpu(

View File

@@ -122,9 +122,13 @@ class TokenProcessor(object):
if current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import get_output
elif current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import get_output
else:
from fastdeploy.model_executor.ops.gpu import (
get_output, get_output_ep, speculate_get_output)
from fastdeploy.model_executor.ops.gpu import (get_output,
get_output_ep,
speculate_get_output
)
rank_id = self.cfg.parallel_config.local_data_parallel_id
while True:
@@ -413,9 +417,12 @@ class WarmUpTokenProcessor(TokenProcessor):
if current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import get_output
elif current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import get_output
else:
from fastdeploy.model_executor.ops.gpu import (
get_output, speculate_get_output)
from fastdeploy.model_executor.ops.gpu import (get_output,
speculate_get_output
)
while self._is_running:
try:

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
platform module
"""
@@ -22,7 +21,8 @@ from .cpu import CPUPlatform
from .xpu import XPUPlatform
from .npu import NPUPlatform
from .dcu import DCUPlatform
from .base import _Backend
from .iluvatar import IluvatarPlatform
from .base import _Backend # noqa: F401
_current_platform = None
@@ -40,10 +40,13 @@ def __getattr__(name: str):
_current_platform = NPUPlatform()
elif paddle.is_compiled_with_rocm():
_current_platform = DCUPlatform()
elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
_current_platform = IluvatarPlatform()
else:
_current_platform = CPUPlatform()
return _current_platform
elif name in globals():
return globals()[name]
else:
raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
raise AttributeError(
f"No attribute named '{name}' exists in {__name__}.")

View File

@@ -63,6 +63,12 @@ class Platform:
"""
return paddle.is_compiled_with_rocm()
def is_iluvatar(self) -> bool:
"""
whether platform is iluvatar gpu
"""
return paddle.is_compiled_with_custom_device("iluvatar_gpu")
@classmethod
def get_attention_backend_cls(self, selected_backend):
"""Get the attention backend"""

View File

@@ -0,0 +1,26 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import Platform
class IluvatarPlatform(Platform):
device_name = "iluvatar_gpu"
@classmethod
def get_attention_backend_cls(cls, selected_backend):
"""
get_attention_backend_cls
"""
return (
"fastdeploy.model_executor.layers.attention.IluvatarAttnBackend")

File diff suppressed because it is too large

View File

@@ -0,0 +1,143 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import gc
import os
from typing import List, Optional
import paddle
import paddle.nn as nn
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.utils import get_logger
from fastdeploy.worker.iluvatar_model_runner import IluvatarModelRunner
from fastdeploy.worker.output import ModelRunnerOutput
from fastdeploy.worker.worker_base import WorkerBase
logger = get_logger("iluvatar_worker", "iluvatar_worker.log")
class IluvatarWorker(WorkerBase):
""" """
def __init__(
self,
fd_config: FDConfig,
local_rank: int,
rank: int,
):
super().__init__(
fd_config=fd_config,
local_rank=local_rank,
rank=rank,
)
pass
def init_device(self):
""" Initialize device and Construct model runner
"""
if paddle.is_compiled_with_custom_device("iluvatar_gpu"):
# Bind this rank to its Iluvatar device and set the default dtype
self.device = f"iluvatar_gpu:{self.local_rank}"
paddle.device.set_device(self.device)
paddle.set_default_dtype(self.parallel_config.dtype)
self.device_ids = self.parallel_config.device_ids.split(",")
gc.collect()
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Construct model runner
self.model_runner: IluvatarModelRunner = IluvatarModelRunner(
fd_config=self.fd_config,
device=self.device,
device_id=self.device_ids[self.local_rank],
rank=self.rank,
local_rank=self.local_rank)
def prefill_finished(self):
"""
Check whether the prefill stage has finished.
"""
return self.model_runner.prefill_finished()
def determine_available_memory(self) -> int:
"""
Profiles the peak memory usage of the model to determine how much
memory can be used for KV cache without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
Tip:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Use a fixed KV-cache budget (in GiB) from FD_ILUVATAR_KVCACHE_MEM instead of profiling
return int(float(os.getenv("FD_ILUVATAR_KVCACHE_MEM", "3")) * 1024**3)
def load_model(self) -> None:
""" """
self.model_runner.load_model()
def get_model(self) -> nn.Layer:
""" """
return self.model_runner.get_model()
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
""" """
pass
def execute_model(
self,
model_forward_batch: Optional[List[Request]] = None,
) -> Optional[ModelRunnerOutput]:
""" """
output = self.model_runner.execute_model(model_forward_batch)
return output
def preprocess_new_task(self, req_dicts: List[Request]) -> None:
""" Process new requests and then start the decode loop
TODO(gongshaotian): The scheduler should schedule the handling of prefill,
so workers and model runners should not need to be aware of it.
"""
self.model_runner.insert_prefill_inputs(req_dicts=req_dicts)
def graph_optimize_and_warm_up_model(self) -> None:
"""
Perform the warm-up and the graph optimization
"""
# 1. Warm up model
# NOTE(gongshaotian): warm-up may not be needed at this point
# 2. Trigger CUDA graph capture
self.model_runner.capture_model()
def check_health(self) -> bool:
""" """
return True
def cal_theortical_kvcache(self) -> int:
""" """
return self.model_runner.cal_theortical_kvcache()
def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
""" """
self.model_runner.update_share_input_block_num(
num_gpu_blocks=num_gpu_blocks)
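
determine_available_memory() skips profiling on Iluvatar and reads a fixed KV-cache budget, in GiB, from FD_ILUVATAR_KVCACHE_MEM (defaulting to 3). A small sketch of the same conversion, with an illustrative 8 GiB setting:

import os

# Reserve 8 GiB of device memory for the KV cache before launching the worker.
os.environ["FD_ILUVATAR_KVCACHE_MEM"] = "8"


def kvcache_budget_bytes():
    # Same conversion as IluvatarWorker.determine_available_memory: GiB -> bytes.
    return int(float(os.getenv("FD_ILUVATAR_KVCACHE_MEM", "3")) * 1024**3)


assert kvcache_budget_bytes() == 8 * 1024**3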

View File

@@ -48,6 +48,11 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
if current_platform.is_xpu():
from fastdeploy.worker.xpu_worker import XpuWorker
return XpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
if current_platform.is_iluvatar():
from fastdeploy.worker.iluvatar_worker import IluvatarWorker
return IluvatarWorker(fd_config=fd_config,
local_rank=local_rank,
rank=rank)
class PaddleDisWorkerProc():
@@ -125,9 +130,9 @@ class PaddleDisWorkerProc():
model_weights_status:
"""
# init worker_ready_signal
max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
array_size = min(
8, self.parallel_config.tensor_parallel_degree *
max_chips_per_node, self.parallel_config.tensor_parallel_degree *
self.parallel_config.expert_parallel_degree)
workers_ready = np.zeros(shape=[array_size], dtype=np.int32)
self.worker_ready_signal = IPCSignal(
@@ -136,7 +141,8 @@ class PaddleDisWorkerProc():
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False)
self.worker_ready_signal.value[self.local_rank % 8] = 1
self.worker_ready_signal.value[self.local_rank %
max_chips_per_node] = 1
# init worker_healthy_live_signal
workers_alive = np.zeros(shape=[self.ranks], dtype=np.int32)
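
The ready-signal array now allows up to 16 chips per node on Iluvatar (8 otherwise), and each worker marks its slot at local_rank modulo that count. A worked sketch of the sizing and indexing with illustrative parallel degrees:

import numpy as np

max_chips_per_node = 16          # Iluvatar node; other devices use 8
tensor_parallel_degree = 16      # illustrative values
expert_parallel_degree = 1

array_size = min(max_chips_per_node,
                 tensor_parallel_degree * expert_parallel_degree)
workers_ready = np.zeros(shape=[array_size], dtype=np.int32)

local_rank = 9
workers_ready[local_rank % max_chips_per_node] = 1  # rank 9 marks slot 9 on a 16-chip node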

requirements_iluvatar.txt (new file, 29 lines)
View File

@@ -0,0 +1,29 @@
setuptools>=62.3.0,<80.0
pre-commit
yapf
flake8
ruamel.yaml
zmq
aiozmq
openai
tqdm
pynvml
uvicorn
fastapi
paddleformers
redis
etcd3
httpx
tool_helpers
pybind11[global]
tabulate
gradio
xlwt
visualdl
setuptools-scm>=8
prometheus-client
decord
moviepy
use-triton-in-paddle
crcmod
fastsafetensors==0.1.14

View File

@@ -22,6 +22,7 @@ import subprocess
from pathlib import Path
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext
from wheel.bdist_wheel import bdist_wheel
long_description = "FastDeploy: Large Language Model Serving.\n\n"
long_description += "GitHub: https://github.com/PaddlePaddle/FastDeploy\n"
@@ -35,7 +36,6 @@ PLAT_TO_CMAKE = {
"win-arm64": "ARM64",
}
from wheel.bdist_wheel import bdist_wheel
class CustomBdistWheel(bdist_wheel):
"""Custom wheel builder for pure Python packages."""
@@ -49,10 +49,14 @@ class CustomBdistWheel(bdist_wheel):
self.plat_name_supplied = True
self.plat_name = 'any'
class CMakeExtension(Extension):
"""A setuptools Extension for CMake-based builds."""
def __init__(self, name: str, sourcedir: str = "", version: str = None) -> None:
def __init__(self,
name: str,
sourcedir: str = "",
version: str = None) -> None:
"""
Initialize CMake extension.
@@ -83,16 +87,12 @@ class CMakeBuild(build_ext):
ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name)
extdir = ext_fullpath.parent.resolve()
cfg = "Debug" if int(os.environ.get("DEBUG", 0)) else "Release"
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
cmake_args = [
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
f"-DPYTHON_EXECUTABLE={sys.executable}",
f"-DCMAKE_BUILD_TYPE={cfg}",
f"-DVERSION_INFO=",
f"-DPYBIND11_PYTHON_VERSION=",
f"-DPYTHON_VERSION=",
f"-DCMAKE_BUILD_TYPE={cfg}", "-DVERSION_INFO=",
"-DPYBIND11_PYTHON_VERSION=", "-DPYTHON_VERSION=",
f"-DPYTHON_INCLUDE_DIR={sys.prefix}/include/python{sys.version_info.major}.{sys.version_info.minor}",
f"-DPYTHON_LIBRARY={sys.prefix}/lib/libpython{sys.version_info.major}.{sys.version_info.minor}.so"
]
@@ -134,22 +134,27 @@ class CMakeBuild(build_ext):
build_temp.mkdir(parents=True, exist_ok=True)
subprocess.run(["cmake", ext.sourcedir, *cmake_args],
cwd=build_temp,
check=True)
cwd=build_temp,
check=True)
subprocess.run(["cmake", "--build", ".", *build_args],
cwd=build_temp,
check=True)
cwd=build_temp,
check=True)
def load_requirements():
"""Load dependencies from requirements.txt"""
requirements_file_name = 'requirements.txt'
if paddle.is_compiled_with_custom_device('iluvatar_gpu'):
requirements_file_name = 'requirements_iluvatar.txt'
requirements_path = os.path.join(os.path.dirname(__file__),
'requirements.txt')
requirements_file_name)
with open(requirements_path, 'r') as f:
return [
line.strip() for line in f
if line.strip() and not line.startswith('#')
]
def get_device_type():
"""Get the device type (rocm/gpu/xpu/npu/cpu) that paddle is compiled with."""
if paddle.is_compiled_with_rocm():
@@ -160,13 +165,17 @@ def get_device_type():
return "xpu"
elif paddle.is_compiled_with_custom_device('npu'):
return "npu"
elif paddle.is_compiled_with_custom_device('iluvatar_gpu'):
return "iluvatar-gpu"
else:
return "cpu"
def get_name():
"""get package name"""
return "fastdeploy-" + get_device_type()
cmdclass_dict = {'bdist_wheel': CustomBdistWheel}
cmdclass_dict['build_ext'] = CMakeBuild
FASTDEPLOY_VERSION = os.environ.get("FASTDEPLOY_VERSION", "2.0.0")
@@ -187,8 +196,8 @@ setup(
"model_executor/ops/gpu/*",
"model_executor/ops/gpu/deep_gemm/include/**/*",
"model_executor/ops/cpu/*", "model_executor/ops/xpu/*",
"model_executor/ops/xpu/libs/*",
"model_executor/ops/npu/*", "model_executor/ops/base/*",
"model_executor/ops/xpu/libs/*", "model_executor/ops/npu/*",
"model_executor/ops/base/*", "model_executor/ops/iluvatar/*",
"model_executor/models/*", "model_executor/layers/*",
"input/mm_processor/utils/*",
"version.txt"
@@ -198,9 +207,10 @@ setup(
ext_modules=[
CMakeExtension(
"rdma_comm",
sourcedir="fastdeploy/cache_manager/transfer_factory/kvcache_transfer",
sourcedir=
"fastdeploy/cache_manager/transfer_factory/kvcache_transfer",
version=None)
],
] if os.getenv("ENABLE_FD_RDMA", "0") == "1" else [],
cmdclass=cmdclass_dict if os.getenv("ENABLE_FD_RDMA", "0") == "1" else {},
zip_safe=False,
classifiers=[