mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
Adapt for iluvatar gpu (#2684)
This commit is contained in:
9
build.sh
9
build.sh
@@ -104,6 +104,15 @@ function copy_ops(){
|
|||||||
return
|
return
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if_corex=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device(\"iluvatar_gpu\"))"`
|
||||||
|
if [ "$if_corex" = "True" ]; then
|
||||||
|
DEVICE_TYPE="iluvatar-gpu"
|
||||||
|
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
|
||||||
|
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/iluvatar
|
||||||
|
echo -e "BASE and Iluvatar ops have been copy to fastdeploy"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
|
||||||
DEVICE_TYPE="cpu"
|
DEVICE_TYPE="cpu"
|
||||||
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
|
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
|
||||||
cd ../../../../
|
cd ../../../../
|
||||||
|
@@ -13,6 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "paddle/extension.h"
|
#include "paddle/extension.h"
|
||||||
|
#include "helper.h"
|
||||||
|
|
||||||
#ifndef PD_BUILD_STATIC_OP
|
#ifndef PD_BUILD_STATIC_OP
|
||||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||||
@@ -59,7 +60,12 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
|||||||
const paddle::Tensor &cum_offsets,
|
const paddle::Tensor &cum_offsets,
|
||||||
const paddle::Tensor &token_num,
|
const paddle::Tensor &token_num,
|
||||||
const paddle::Tensor &seq_len) {
|
const paddle::Tensor &seq_len) {
|
||||||
|
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||||
|
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
|
||||||
|
auto cu_stream = dev_ctx->stream();
|
||||||
|
#else
|
||||||
auto cu_stream = input_ids.stream();
|
auto cu_stream = input_ids.stream();
|
||||||
|
#endif
|
||||||
std::vector<int64_t> input_ids_shape = input_ids.shape();
|
std::vector<int64_t> input_ids_shape = input_ids.shape();
|
||||||
const int bsz = seq_len.shape()[0];
|
const int bsz = seq_len.shape()[0];
|
||||||
const int seq_length = input_ids_shape[1];
|
const int seq_length = input_ids_shape[1];
|
||||||
@@ -75,7 +81,11 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
|||||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||||
auto cu_seqlens_k =
|
auto cu_seqlens_k =
|
||||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||||
int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
|
#ifdef PADDLE_WITH_COREX
|
||||||
|
int blockSize = std::min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
|
||||||
|
#else
|
||||||
|
int blockSize = min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
|
||||||
|
#endif
|
||||||
GetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
|
GetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
|
||||||
padding_offset.data<int>(),
|
padding_offset.data<int>(),
|
||||||
cum_offsets_out.data<int>(),
|
cum_offsets_out.data<int>(),
|
||||||
|
@@ -14,7 +14,9 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#ifndef PADDLE_WITH_COREX
|
||||||
#include "glog/logging.h"
|
#include "glog/logging.h"
|
||||||
|
#endif
|
||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
@@ -35,22 +37,35 @@ namespace cub = hipcub;
|
|||||||
#else
|
#else
|
||||||
#include <cub/cub.cuh>
|
#include <cub/cub.cuh>
|
||||||
#endif
|
#endif
|
||||||
|
#ifndef PADDLE_WITH_COREX
|
||||||
#include "nlohmann/json.hpp"
|
#include "nlohmann/json.hpp"
|
||||||
|
#endif
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
#include "env.h"
|
#include "env.h"
|
||||||
#include "paddle/extension.h"
|
#include "paddle/extension.h"
|
||||||
#include "paddle/phi/core/allocator.h"
|
#include "paddle/phi/core/allocator.h"
|
||||||
|
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||||
|
#include "paddle/phi/backends/custom/custom_context.h"
|
||||||
|
#else
|
||||||
#include "paddle/phi/core/cuda_stream.h"
|
#include "paddle/phi/core/cuda_stream.h"
|
||||||
|
#endif
|
||||||
#include "paddle/phi/core/dense_tensor.h"
|
#include "paddle/phi/core/dense_tensor.h"
|
||||||
#include "paddle/phi/backends/gpu/gpu_info.h"
|
#include "paddle/phi/backends/gpu/gpu_info.h"
|
||||||
|
|
||||||
|
#ifdef PADDLE_WITH_COREX
|
||||||
|
#define WARP_SIZE 64
|
||||||
|
#else
|
||||||
|
#define WARP_SIZE 32
|
||||||
|
#endif
|
||||||
#ifndef PD_BUILD_STATIC_OP
|
#ifndef PD_BUILD_STATIC_OP
|
||||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifndef PADDLE_WITH_COREX
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
|
#endif
|
||||||
|
|
||||||
#define CUDA_CHECK(call) \
|
#define CUDA_CHECK(call) \
|
||||||
do { \
|
do { \
|
||||||
@@ -237,6 +252,7 @@ inline int GetBlockSize(int vocab_size) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifndef PADDLE_WITH_COREX
|
||||||
inline json readJsonFromFile(const std::string &filePath) {
|
inline json readJsonFromFile(const std::string &filePath) {
|
||||||
std::ifstream file(filePath);
|
std::ifstream file(filePath);
|
||||||
if (!file.is_open()) {
|
if (!file.is_open()) {
|
||||||
@@ -247,6 +263,7 @@ inline json readJsonFromFile(const std::string &filePath) {
|
|||||||
file >> j;
|
file >> j;
|
||||||
return j;
|
return j;
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
#define cudaCheckError() \
|
#define cudaCheckError() \
|
||||||
{ \
|
{ \
|
||||||
@@ -418,6 +435,7 @@ inline std::string base64_decode(const std::string &encoded_string) {
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifndef PADDLE_WITH_COREX
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline T get_relative_best(nlohmann::json *json_data,
|
inline T get_relative_best(nlohmann::json *json_data,
|
||||||
const std::string &target_key,
|
const std::string &target_key,
|
||||||
@@ -430,6 +448,7 @@ inline T get_relative_best(nlohmann::json *json_data,
|
|||||||
return default_value;
|
return default_value;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
__device__ inline bool is_in_end(const int64_t id, const int64_t *end_ids,
|
__device__ inline bool is_in_end(const int64_t id, const int64_t *end_ids,
|
||||||
int length) {
|
int length) {
|
||||||
|
@@ -18,7 +18,6 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
|
||||||
#include "helper.h"
|
|
||||||
#include "noauxtc_kernel.h"
|
#include "noauxtc_kernel.h"
|
||||||
|
|
||||||
std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
|
std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
|
||||||
|
@@ -17,11 +17,11 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
#include <cooperative_groups.h>
|
#include <cooperative_groups.h>
|
||||||
#include <cooperative_groups/reduce.h>
|
#include <cooperative_groups/reduce.h>
|
||||||
|
#include "helper.h"
|
||||||
|
|
||||||
namespace cg = cooperative_groups;
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
|
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
|
||||||
constexpr int32_t WARP_SIZE = 32;
|
|
||||||
constexpr int32_t BLOCK_SIZE = 512;
|
constexpr int32_t BLOCK_SIZE = 512;
|
||||||
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
|
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
|
||||||
|
|
||||||
|
@@ -91,7 +91,12 @@ std::vector<paddle::Tensor> rebuild_padding(
|
|||||||
typedef typename traits_::DataType DataType_;
|
typedef typename traits_::DataType DataType_;
|
||||||
typedef typename traits_::data_t data_t;
|
typedef typename traits_::data_t data_t;
|
||||||
|
|
||||||
|
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||||
|
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(tmp_out.place()));
|
||||||
|
auto cu_stream = dev_ctx->stream();
|
||||||
|
#else
|
||||||
auto cu_stream = tmp_out.stream();
|
auto cu_stream = tmp_out.stream();
|
||||||
|
#endif
|
||||||
std::vector<int64_t> tmp_out_shape = tmp_out.shape();
|
std::vector<int64_t> tmp_out_shape = tmp_out.shape();
|
||||||
const int token_num = tmp_out_shape[0];
|
const int token_num = tmp_out_shape[0];
|
||||||
const int dim_embed = tmp_out_shape[1];
|
const int dim_embed = tmp_out_shape[1];
|
||||||
@@ -125,7 +130,7 @@ std::vector<paddle::Tensor> rebuild_padding(
|
|||||||
|
|
||||||
if (output_padding_offset) {
|
if (output_padding_offset) {
|
||||||
RebuildAppendPaddingKernel<DataType_, PackSize>
|
RebuildAppendPaddingKernel<DataType_, PackSize>
|
||||||
<<<grid_size, blocksize, 0, tmp_out.stream()>>>(
|
<<<grid_size, blocksize, 0, cu_stream>>>(
|
||||||
reinterpret_cast<DataType_ *>(out.data<data_t>()),
|
reinterpret_cast<DataType_ *>(out.data<data_t>()),
|
||||||
reinterpret_cast<const DataType_ *>(tmp_out.data<data_t>()),
|
reinterpret_cast<const DataType_ *>(tmp_out.data<data_t>()),
|
||||||
cum_offsets.data<int>(),
|
cum_offsets.data<int>(),
|
||||||
@@ -138,7 +143,7 @@ std::vector<paddle::Tensor> rebuild_padding(
|
|||||||
elem_nums);
|
elem_nums);
|
||||||
} else {
|
} else {
|
||||||
RebuildPaddingKernel<DataType_, PackSize>
|
RebuildPaddingKernel<DataType_, PackSize>
|
||||||
<<<grid_size, blocksize, 0, tmp_out.stream()>>>(
|
<<<grid_size, blocksize, 0, cu_stream>>>(
|
||||||
reinterpret_cast<DataType_ *>(out.data<data_t>()),
|
reinterpret_cast<DataType_ *>(out.data<data_t>()),
|
||||||
reinterpret_cast<DataType_ *>(
|
reinterpret_cast<DataType_ *>(
|
||||||
const_cast<data_t *>(tmp_out.data<data_t>())),
|
const_cast<data_t *>(tmp_out.data<data_t>())),
|
||||||
|
@@ -376,7 +376,6 @@ __global__ void air_topp_sampling(Counter<T> *counters, T *histograms,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// scan/find
|
// scan/find
|
||||||
constexpr int WARP_SIZE = 32;
|
|
||||||
constexpr int WARP_COUNT = NumBuckets / WARP_SIZE;
|
constexpr int WARP_COUNT = NumBuckets / WARP_SIZE;
|
||||||
namespace cg = cooperative_groups;
|
namespace cg = cooperative_groups;
|
||||||
cg::thread_block block = cg::this_thread_block();
|
cg::thread_block block = cg::this_thread_block();
|
||||||
|
@@ -13,6 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "paddle/extension.h"
|
#include "paddle/extension.h"
|
||||||
|
#include "helper.h"
|
||||||
|
|
||||||
#ifndef PD_BUILD_STATIC_OP
|
#ifndef PD_BUILD_STATIC_OP
|
||||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||||
@@ -51,13 +52,18 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
|
|||||||
const paddle::Tensor &seq_lens_decoder,
|
const paddle::Tensor &seq_lens_decoder,
|
||||||
const paddle::Tensor &step_idx,
|
const paddle::Tensor &step_idx,
|
||||||
const paddle::Tensor &stop_flags) {
|
const paddle::Tensor &stop_flags) {
|
||||||
|
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||||
|
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(stop_flags.place()));
|
||||||
|
auto cu_stream = dev_ctx->stream();
|
||||||
|
#else
|
||||||
auto cu_stream = stop_flags.stream();
|
auto cu_stream = stop_flags.stream();
|
||||||
|
#endif
|
||||||
std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
|
std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
|
||||||
|
|
||||||
int bs = seq_lens_this_time.shape()[0];
|
int bs = seq_lens_this_time.shape()[0];
|
||||||
int length = pre_ids_all_shape[1];
|
int length = pre_ids_all_shape[1];
|
||||||
int length_input_ids = input_ids.shape()[1];
|
int length_input_ids = input_ids.shape()[1];
|
||||||
int block_size = (bs + 32 - 1) / 32 * 32;
|
int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
|
||||||
set_value_by_flag_and_id<<<1, block_size, 0, cu_stream>>>(
|
set_value_by_flag_and_id<<<1, block_size, 0, cu_stream>>>(
|
||||||
stop_flags.data<bool>(),
|
stop_flags.data<bool>(),
|
||||||
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
|
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
|
||||||
|
@@ -323,7 +323,12 @@ void StepPaddle(const paddle::Tensor &stop_flags,
|
|||||||
const paddle::Tensor &first_token_ids,
|
const paddle::Tensor &first_token_ids,
|
||||||
const int block_size,
|
const int block_size,
|
||||||
const int encoder_decoder_block_num) {
|
const int encoder_decoder_block_num) {
|
||||||
|
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||||
|
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(seq_lens_this_time.place()));
|
||||||
|
auto cu_stream = dev_ctx->stream();
|
||||||
|
#else
|
||||||
auto cu_stream = seq_lens_this_time.stream();
|
auto cu_stream = seq_lens_this_time.stream();
|
||||||
|
#endif
|
||||||
const int bsz = seq_lens_this_time.shape()[0];
|
const int bsz = seq_lens_this_time.shape()[0];
|
||||||
const int block_num_per_seq = block_tables.shape()[1];
|
const int block_num_per_seq = block_tables.shape()[1];
|
||||||
const int length = input_ids.shape()[1];
|
const int length = input_ids.shape()[1];
|
||||||
|
@@ -74,11 +74,16 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||||
|
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place()));
|
||||||
|
auto cu_stream = dev_ctx->stream();
|
||||||
|
#else
|
||||||
auto cu_stream = topk_ids.stream();
|
auto cu_stream = topk_ids.stream();
|
||||||
|
#endif
|
||||||
std::vector<int64_t> shape = topk_ids.shape();
|
std::vector<int64_t> shape = topk_ids.shape();
|
||||||
int64_t bs_now = shape[0];
|
int64_t bs_now = shape[0];
|
||||||
int64_t end_length = end_ids.shape()[0];
|
int64_t end_length = end_ids.shape()[0];
|
||||||
int block_size = (bs_now + 32 - 1) / 32 * 32;
|
int block_size = (bs_now + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
|
||||||
set_value_by_flags<<<1, block_size, 0, cu_stream>>>(
|
set_value_by_flags<<<1, block_size, 0, cu_stream>>>(
|
||||||
const_cast<bool *>(stop_flags.data<bool>()),
|
const_cast<bool *>(stop_flags.data<bool>()),
|
||||||
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
||||||
|
@@ -21,6 +21,7 @@
|
|||||||
#include <sys/types.h>
|
#include <sys/types.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
#include "paddle/extension.h"
|
#include "paddle/extension.h"
|
||||||
|
#include "helper.h"
|
||||||
|
|
||||||
#ifndef PD_BUILD_STATIC_OP
|
#ifndef PD_BUILD_STATIC_OP
|
||||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||||
@@ -88,7 +89,12 @@ void GetStopFlagsMultiSeqs(const paddle::Tensor &topk_ids,
|
|||||||
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
|
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
|
||||||
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
|
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
|
||||||
|
|
||||||
|
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||||
|
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place()));
|
||||||
|
auto cu_stream = dev_ctx->stream();
|
||||||
|
#else
|
||||||
auto cu_stream = topk_ids.stream();
|
auto cu_stream = topk_ids.stream();
|
||||||
|
#endif
|
||||||
std::vector<int64_t> shape = topk_ids.shape();
|
std::vector<int64_t> shape = topk_ids.shape();
|
||||||
std::vector<int64_t> stop_seqs_shape = stop_seqs.shape();
|
std::vector<int64_t> stop_seqs_shape = stop_seqs.shape();
|
||||||
int bs_now = shape[0];
|
int bs_now = shape[0];
|
||||||
@@ -96,7 +102,7 @@ void GetStopFlagsMultiSeqs(const paddle::Tensor &topk_ids,
|
|||||||
int stop_seqs_max_len = stop_seqs_shape[1];
|
int stop_seqs_max_len = stop_seqs_shape[1];
|
||||||
int pre_ids_len = pre_ids.shape()[1];
|
int pre_ids_len = pre_ids.shape()[1];
|
||||||
|
|
||||||
int block_size = (stop_seqs_bs + 31) / 32 * 32;
|
int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
|
||||||
set_value_by_stop_seqs<<<bs_now, block_size, 0, cu_stream>>>(
|
set_value_by_stop_seqs<<<bs_now, block_size, 0, cu_stream>>>(
|
||||||
const_cast<bool *>(stop_flags.data<bool>()),
|
const_cast<bool *>(stop_flags.data<bool>()),
|
||||||
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
||||||
|
@@ -132,7 +132,12 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
|
|||||||
typedef PDTraits<D> traits_;
|
typedef PDTraits<D> traits_;
|
||||||
typedef typename traits_::DataType DataType_;
|
typedef typename traits_::DataType DataType_;
|
||||||
typedef typename traits_::data_t data_t;
|
typedef typename traits_::data_t data_t;
|
||||||
|
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||||
|
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(logits.place()));
|
||||||
|
auto cu_stream = dev_ctx->stream();
|
||||||
|
#else
|
||||||
auto cu_stream = logits.stream();
|
auto cu_stream = logits.stream();
|
||||||
|
#endif
|
||||||
std::vector<int64_t> shape = logits.shape();
|
std::vector<int64_t> shape = logits.shape();
|
||||||
auto repeat_times =
|
auto repeat_times =
|
||||||
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
|
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
|
||||||
@@ -143,7 +148,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
|
|||||||
|
|
||||||
int64_t end_length = eos_token_id.shape()[0];
|
int64_t end_length = eos_token_id.shape()[0];
|
||||||
|
|
||||||
int block_size = (bs + 32 - 1) / 32 * 32;
|
int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
|
||||||
min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
|
min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
|
||||||
reinterpret_cast<DataType_ *>(
|
reinterpret_cast<DataType_ *>(
|
||||||
const_cast<data_t *>(logits.data<data_t>())),
|
const_cast<data_t *>(logits.data<data_t>())),
|
||||||
@@ -154,8 +159,12 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
|
|||||||
length,
|
length,
|
||||||
end_length);
|
end_length);
|
||||||
|
|
||||||
block_size = (length_id + 32 - 1) / 32 * 32;
|
block_size = (length_id + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
|
||||||
|
#ifdef PADDLE_WITH_COREX
|
||||||
|
block_size = std::min(block_size, 512);
|
||||||
|
#else
|
||||||
block_size = min(block_size, 512);
|
block_size = min(block_size, 512);
|
||||||
|
#endif
|
||||||
update_repeat_times<<<bs, block_size, 0, cu_stream>>>(
|
update_repeat_times<<<bs, block_size, 0, cu_stream>>>(
|
||||||
pre_ids.data<int64_t>(),
|
pre_ids.data<int64_t>(),
|
||||||
cur_len.data<int64_t>(),
|
cur_len.data<int64_t>(),
|
||||||
@@ -164,8 +173,12 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
|
|||||||
length,
|
length,
|
||||||
length_id);
|
length_id);
|
||||||
|
|
||||||
block_size = (length + 32 - 1) / 32 * 32;
|
block_size = (length + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
|
||||||
|
#ifdef PADDLE_WITH_COREX
|
||||||
|
block_size = std::min(block_size, 512);
|
||||||
|
#else
|
||||||
block_size = min(block_size, 512);
|
block_size = min(block_size, 512);
|
||||||
|
#endif
|
||||||
update_value_by_repeat_times<DataType_><<<bs, block_size, 0, cu_stream>>>(
|
update_value_by_repeat_times<DataType_><<<bs, block_size, 0, cu_stream>>>(
|
||||||
repeat_times.data<int>(),
|
repeat_times.data<int>(),
|
||||||
reinterpret_cast<DataType_ *>(
|
reinterpret_cast<DataType_ *>(
|
||||||
@@ -180,8 +193,12 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
|
|||||||
bs,
|
bs,
|
||||||
length);
|
length);
|
||||||
|
|
||||||
block_size = (length_bad_words + 32 - 1) / 32 * 32;
|
block_size = (length_bad_words + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
|
||||||
|
#ifdef PADDLE_WITH_COREX
|
||||||
|
block_size = std::min(block_size, 512);
|
||||||
|
#else
|
||||||
block_size = min(block_size, 512);
|
block_size = min(block_size, 512);
|
||||||
|
#endif
|
||||||
ban_bad_words<DataType_><<<bs, block_size, 0, cu_stream>>>(
|
ban_bad_words<DataType_><<<bs, block_size, 0, cu_stream>>>(
|
||||||
reinterpret_cast<DataType_ *>(
|
reinterpret_cast<DataType_ *>(
|
||||||
const_cast<data_t *>(logits.data<data_t>())),
|
const_cast<data_t *>(logits.data<data_t>())),
|
||||||
|
@@ -75,11 +75,17 @@ void UpdateInputes(const paddle::Tensor &stop_flags,
|
|||||||
const paddle::Tensor &stop_nums,
|
const paddle::Tensor &stop_nums,
|
||||||
const paddle::Tensor &next_tokens,
|
const paddle::Tensor &next_tokens,
|
||||||
const paddle::Tensor &is_block_step) {
|
const paddle::Tensor &is_block_step) {
|
||||||
|
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||||
|
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
|
||||||
|
auto cu_stream = dev_ctx->stream();
|
||||||
|
#else
|
||||||
|
auto cu_stream = input_ids.stream();
|
||||||
|
#endif
|
||||||
const int max_bsz = stop_flags.shape()[0];
|
const int max_bsz = stop_flags.shape()[0];
|
||||||
const int now_bsz = seq_lens_this_time.shape()[0];
|
const int now_bsz = seq_lens_this_time.shape()[0];
|
||||||
const int input_ids_stride = input_ids.shape()[1];
|
const int input_ids_stride = input_ids.shape()[1];
|
||||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||||
update_inputs_kernel<1024><<<1, 1024, 0, input_ids.stream()>>>(
|
update_inputs_kernel<1024><<<1, 1024, 0, cu_stream>>>(
|
||||||
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
||||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||||
|
55
custom_ops/iluvatar_ops/fused_moe_helper.h
Normal file
55
custom_ops/iluvatar_ops/fused_moe_helper.h
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
|
||||||
|
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "fused_moe_op.h"
|
||||||
|
|
||||||
|
namespace phi {
|
||||||
|
|
||||||
|
template <typename T, int VecSize>
|
||||||
|
__global__ void moe_token_type_ids_kernel(T *gating_output,
|
||||||
|
const int *moe_token_type_ids_out,
|
||||||
|
const int num_rows,
|
||||||
|
const int num_experts,
|
||||||
|
const int k) {
|
||||||
|
const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
if (moe_token_index >= num_rows) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
gating_output[moe_token_index * 2] =
|
||||||
|
gating_output[moe_token_index * 2] +
|
||||||
|
(moe_token_type_ids_out[moe_token_index]) * -1e10;
|
||||||
|
gating_output[moe_token_index * 2 + 1] =
|
||||||
|
gating_output[moe_token_index * 2 + 1] +
|
||||||
|
(1 - moe_token_type_ids_out[moe_token_index]) * -1e10;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void moe_token_type_ids_kernelLauncher(T *gating_output,
|
||||||
|
const int *moe_token_type_ids_out,
|
||||||
|
const int num_rows,
|
||||||
|
const int num_experts,
|
||||||
|
const int k,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
const int blocks = num_rows * k / 512 + 1;
|
||||||
|
const int threads = 512;
|
||||||
|
moe_token_type_ids_kernel<T, 1><<<blocks, 512, 0, stream>>>(
|
||||||
|
gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace phi
|
127
custom_ops/iluvatar_ops/fused_moe_imp_op.h
Normal file
127
custom_ops/iluvatar_ops/fused_moe_imp_op.h
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
/*
|
||||||
|
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
|
||||||
|
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <string>
|
||||||
|
#include <sstream>
|
||||||
|
#include "cub/cub.cuh"
|
||||||
|
|
||||||
|
namespace phi {
|
||||||
|
|
||||||
|
static const float HALF_FLT_MAX = 65504.F;
|
||||||
|
static const float HALF_FLT_MIN = -65504.F;
|
||||||
|
static inline size_t AlignTo16(const size_t& input) {
|
||||||
|
static constexpr int ALIGNMENT = 16;
|
||||||
|
return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
|
||||||
|
}
|
||||||
|
|
||||||
|
class CubKeyValueSorter {
|
||||||
|
public:
|
||||||
|
CubKeyValueSorter() : num_experts_(0), num_bits_(sizeof(int) * 8) {}
|
||||||
|
|
||||||
|
explicit CubKeyValueSorter(const int num_experts)
|
||||||
|
: num_experts_(num_experts),
|
||||||
|
num_bits_(static_cast<int>(log2(num_experts)) + 1) {}
|
||||||
|
|
||||||
|
void update_num_experts(const int num_experts) {
|
||||||
|
num_experts_ = num_experts;
|
||||||
|
num_bits_ = static_cast<int>(log2(num_experts)) + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t getWorkspaceSize(const size_t num_key_value_pairs,
|
||||||
|
bool descending = false) {
|
||||||
|
num_key_value_pairs_ = num_key_value_pairs;
|
||||||
|
size_t required_storage = 0;
|
||||||
|
int* null_int = nullptr;
|
||||||
|
if (descending) {
|
||||||
|
cub::DeviceRadixSort::SortPairsDescending(NULL,
|
||||||
|
required_storage,
|
||||||
|
null_int,
|
||||||
|
null_int,
|
||||||
|
null_int,
|
||||||
|
null_int,
|
||||||
|
num_key_value_pairs,
|
||||||
|
0,
|
||||||
|
32);
|
||||||
|
} else {
|
||||||
|
cub::DeviceRadixSort::SortPairs(NULL,
|
||||||
|
required_storage,
|
||||||
|
null_int,
|
||||||
|
null_int,
|
||||||
|
null_int,
|
||||||
|
null_int,
|
||||||
|
num_key_value_pairs,
|
||||||
|
0,
|
||||||
|
num_bits_);
|
||||||
|
}
|
||||||
|
return required_storage;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename KeyT>
|
||||||
|
void run(void* workspace,
|
||||||
|
const size_t workspace_size,
|
||||||
|
const KeyT* keys_in,
|
||||||
|
KeyT* keys_out,
|
||||||
|
const int* values_in,
|
||||||
|
int* values_out,
|
||||||
|
const size_t num_key_value_pairs,
|
||||||
|
bool descending,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs);
|
||||||
|
size_t actual_ws_size = workspace_size;
|
||||||
|
|
||||||
|
if (expected_ws_size > workspace_size) {
|
||||||
|
std::stringstream err_ss;
|
||||||
|
err_ss << "[Error][CubKeyValueSorter::run]\n";
|
||||||
|
err_ss << "Error. The allocated workspace is too small to run this "
|
||||||
|
"problem.\n";
|
||||||
|
err_ss << "Expected workspace size of at least " << expected_ws_size
|
||||||
|
<< " but got problem size " << workspace_size << "\n";
|
||||||
|
throw std::runtime_error(err_ss.str());
|
||||||
|
}
|
||||||
|
if (descending) {
|
||||||
|
cub::DeviceRadixSort::SortPairsDescending(workspace,
|
||||||
|
actual_ws_size,
|
||||||
|
keys_in,
|
||||||
|
keys_out,
|
||||||
|
values_in,
|
||||||
|
values_out,
|
||||||
|
num_key_value_pairs,
|
||||||
|
0,
|
||||||
|
32,
|
||||||
|
stream);
|
||||||
|
} else {
|
||||||
|
cub::DeviceRadixSort::SortPairs(workspace,
|
||||||
|
actual_ws_size,
|
||||||
|
keys_in,
|
||||||
|
keys_out,
|
||||||
|
values_in,
|
||||||
|
values_out,
|
||||||
|
num_key_value_pairs,
|
||||||
|
0,
|
||||||
|
num_bits_,
|
||||||
|
stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t num_key_value_pairs_;
|
||||||
|
int num_experts_;
|
||||||
|
int num_bits_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace phi
|
990
custom_ops/iluvatar_ops/fused_moe_op.h
Normal file
990
custom_ops/iluvatar_ops/fused_moe_op.h
Normal file
@@ -0,0 +1,990 @@
|
|||||||
|
// /*
|
||||||
|
// * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
|
||||||
|
// * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||||
|
// *
|
||||||
|
// * Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// * you may not use this file except in compliance with the License.
|
||||||
|
// * You may obtain a copy of the License at
|
||||||
|
// *
|
||||||
|
// * http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
// *
|
||||||
|
// * Unless required by applicable law or agreed to in writing, software
|
||||||
|
// * distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// * See the License for the specific language governing permissions and
|
||||||
|
// * limitations under the License.
|
||||||
|
// */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include "fused_moe_imp_op.h"
|
||||||
|
#include "fused_moe_helper.h"
|
||||||
|
// Ignore CUTLASS warnings about type punning
|
||||||
|
#pragma GCC diagnostic push
|
||||||
|
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||||
|
#pragma GCC diagnostic ignored "-Wunused-function"
|
||||||
|
|
||||||
|
// #include "paddle/phi/backends/gpu/gpu_info.h"
|
||||||
|
#pragma GCC diagnostic pop
|
||||||
|
|
||||||
|
#include "helper.h"
|
||||||
|
|
||||||
|
namespace phi {
|
||||||
|
|
||||||
|
struct GpuLaunchConfig {
|
||||||
|
dim3 block_per_grid;
|
||||||
|
dim3 thread_per_block;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
|
||||||
|
int blocks_x = cols;
|
||||||
|
int blocks_y = 1;
|
||||||
|
int blocks_z = 1;
|
||||||
|
if (blocks_x > 1024) {
|
||||||
|
blocks_y = 256;
|
||||||
|
blocks_x = (blocks_x + blocks_y - 1) / blocks_y;
|
||||||
|
}
|
||||||
|
|
||||||
|
GpuLaunchConfig config;
|
||||||
|
config.block_per_grid.x = blocks_x;
|
||||||
|
config.block_per_grid.y = blocks_y;
|
||||||
|
config.block_per_grid.z = blocks_z;
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ====================== Softmax things ===============================
|
||||||
|
// We have our own implementation of softmax here so we can support transposing
|
||||||
|
// the output in the softmax kernel when we extend this module to support
|
||||||
|
// expert-choice routing.
|
||||||
|
template <typename T, int TPB>
|
||||||
|
__launch_bounds__(TPB) __global__
|
||||||
|
void group_moe_softmax(const T* input,
|
||||||
|
T* output,
|
||||||
|
T* softmax_max_prob,
|
||||||
|
const int64_t num_cols,
|
||||||
|
const int64_t softmax_num_rows) {
|
||||||
|
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||||
|
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||||
|
|
||||||
|
__shared__ float normalizing_factor;
|
||||||
|
__shared__ float float_max;
|
||||||
|
__shared__ float max_out;
|
||||||
|
|
||||||
|
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
|
||||||
|
if (globalIdx >= softmax_num_rows) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const int64_t thread_row_offset = globalIdx * num_cols;
|
||||||
|
|
||||||
|
cub::Sum sum;
|
||||||
|
float threadData(-FLT_MAX);
|
||||||
|
|
||||||
|
|
||||||
|
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||||
|
const int idx = thread_row_offset + ii;
|
||||||
|
threadData = max(static_cast<float>(input[idx]), threadData);
|
||||||
|
}
|
||||||
|
|
||||||
|
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
float_max = maxElem;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
threadData = 0;
|
||||||
|
|
||||||
|
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||||
|
const int idx = thread_row_offset + ii;
|
||||||
|
threadData += exp((static_cast<float>(input[idx]) - float_max));
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
normalizing_factor = 1.f / Z;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
threadData = 0;
|
||||||
|
|
||||||
|
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||||
|
const int idx = thread_row_offset + ii;
|
||||||
|
const float val =
|
||||||
|
exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
|
||||||
|
output[idx] = T(val);
|
||||||
|
threadData = max(static_cast<float>(T(val)), threadData);
|
||||||
|
}
|
||||||
|
|
||||||
|
const float maxOut = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
// group max probs
|
||||||
|
max_out = 1.f / maxOut;
|
||||||
|
softmax_max_prob[globalIdx] = T(max_out);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||||
|
const int idx = thread_row_offset + ii;
|
||||||
|
// group softmax normalization
|
||||||
|
output[idx] = output[idx] * static_cast<T>(max_out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int TPB, typename IdxT = int>
|
||||||
|
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
|
||||||
|
T* output,
|
||||||
|
IdxT* indices,
|
||||||
|
int* source_rows,
|
||||||
|
T* softmax_max_prob,
|
||||||
|
const int64_t num_experts,
|
||||||
|
const int64_t k,
|
||||||
|
const int64_t num_rows) {
|
||||||
|
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||||
|
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
|
||||||
|
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||||
|
|
||||||
|
cub_kvp thread_kvp;
|
||||||
|
cub::ArgMax arg_max;
|
||||||
|
|
||||||
|
const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
|
||||||
|
if (block_row >= num_rows) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool should_process_row = true;
|
||||||
|
const int thread_read_offset = block_row * num_experts;
|
||||||
|
|
||||||
|
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||||
|
thread_kvp.key = 0;
|
||||||
|
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||||
|
|
||||||
|
cub_kvp inp_kvp;
|
||||||
|
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
|
||||||
|
const int idx = thread_read_offset + expert;
|
||||||
|
inp_kvp.key = expert;
|
||||||
|
inp_kvp.value = inputs_after_softmax[idx];
|
||||||
|
|
||||||
|
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
|
||||||
|
const IdxT prior_winning_expert = indices[k * block_row + prior_k];
|
||||||
|
|
||||||
|
if (prior_winning_expert == expert) {
|
||||||
|
inp_kvp = thread_kvp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
thread_kvp = arg_max(inp_kvp, thread_kvp);
|
||||||
|
}
|
||||||
|
|
||||||
|
const cub_kvp result_kvp =
|
||||||
|
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
const int idx = k * block_row + k_idx;
|
||||||
|
// restore normalized probes
|
||||||
|
output[idx] = result_kvp.value / T(softmax_max_prob[idx]);
|
||||||
|
indices[idx] = should_process_row ? result_kvp.key : num_experts;
|
||||||
|
source_rows[idx] = k_idx * num_rows + block_row;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int TPB>
|
||||||
|
__launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
|
||||||
|
T* output,
|
||||||
|
const int64_t num_cols,
|
||||||
|
const int64_t num_rows) {
|
||||||
|
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||||
|
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||||
|
|
||||||
|
__shared__ float normalizing_factor;
|
||||||
|
__shared__ float float_max;
|
||||||
|
|
||||||
|
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
|
||||||
|
if (globalIdx >= num_rows) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const int64_t thread_row_offset = globalIdx * num_cols;
|
||||||
|
|
||||||
|
cub::Sum sum;
|
||||||
|
float threadData(-FLT_MAX);
|
||||||
|
|
||||||
|
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||||
|
const int idx = thread_row_offset + ii;
|
||||||
|
threadData = max(static_cast<float>(input[idx]), threadData);
|
||||||
|
}
|
||||||
|
|
||||||
|
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
float_max = maxElem;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
threadData = 0;
|
||||||
|
|
||||||
|
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||||
|
const int idx = thread_row_offset + ii;
|
||||||
|
threadData += exp((static_cast<float>(input[idx]) - float_max));
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
normalizing_factor = 1.f / Z;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||||
|
const int idx = thread_row_offset + ii;
|
||||||
|
const float val =
|
||||||
|
exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
|
||||||
|
output[idx] = T(val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int TPB, typename IdxT = int>
|
||||||
|
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
|
||||||
|
const T* bias,
|
||||||
|
T* output,
|
||||||
|
IdxT* indices,
|
||||||
|
int* source_rows,
|
||||||
|
const int64_t num_experts,
|
||||||
|
const int64_t k,
|
||||||
|
const int64_t num_rows) {
|
||||||
|
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||||
|
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
|
||||||
|
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||||
|
|
||||||
|
cub_kvp thread_kvp;
|
||||||
|
cub::ArgMax arg_max;
|
||||||
|
|
||||||
|
const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
|
||||||
|
if (block_row >= num_rows) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool should_process_row = true;
|
||||||
|
const int thread_read_offset = block_row * num_experts;
|
||||||
|
|
||||||
|
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||||
|
thread_kvp.key = 0;
|
||||||
|
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||||
|
|
||||||
|
cub_kvp inp_kvp;
|
||||||
|
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
|
||||||
|
const int idx = thread_read_offset + expert;
|
||||||
|
inp_kvp.key = expert;
|
||||||
|
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
|
||||||
|
|
||||||
|
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
|
||||||
|
const IdxT prior_winning_expert = indices[k * block_row + prior_k];
|
||||||
|
|
||||||
|
if (prior_winning_expert == expert) {
|
||||||
|
inp_kvp = thread_kvp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
thread_kvp = arg_max(inp_kvp, thread_kvp);
|
||||||
|
}
|
||||||
|
|
||||||
|
const cub_kvp result_kvp =
|
||||||
|
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
const int idx = k * block_row + k_idx;
|
||||||
|
output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
|
||||||
|
indices[idx] = should_process_row ? result_kvp.key : num_experts;
|
||||||
|
source_rows[idx] = k_idx * num_rows + block_row;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int TPB, typename IdxT = int>
|
||||||
|
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
|
||||||
|
const T* bias,
|
||||||
|
T* output,
|
||||||
|
IdxT* indices,
|
||||||
|
int* source_rows,
|
||||||
|
const int64_t num_experts,
|
||||||
|
const int64_t k,
|
||||||
|
const int64_t num_rows) {
|
||||||
|
// softmax
|
||||||
|
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||||
|
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||||
|
|
||||||
|
__shared__ float normalizing_factor;
|
||||||
|
__shared__ float float_max;
|
||||||
|
|
||||||
|
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
|
||||||
|
if (globalIdx >= num_rows) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const int64_t thread_row_offset = globalIdx * num_experts;
|
||||||
|
const int64_t idx = thread_row_offset+threadIdx.x;
|
||||||
|
|
||||||
|
cub::Sum sum;
|
||||||
|
|
||||||
|
float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) :(-FLT_MAX);
|
||||||
|
|
||||||
|
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
float_max = maxElem;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
float threadDataSub = threadData - float_max;
|
||||||
|
float threadDataExp = exp(threadDataSub);
|
||||||
|
|
||||||
|
const auto Z = BlockReduce(tmpStorage).Reduce(threadDataExp, sum);
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
normalizing_factor = 1.f / Z;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
T val = T(threadDataExp * normalizing_factor);
|
||||||
|
|
||||||
|
// top_k
|
||||||
|
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||||
|
using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
|
||||||
|
__shared__ typename BlockReduceP::TempStorage tmpStorageP;
|
||||||
|
|
||||||
|
cub_kvp thread_kvp;
|
||||||
|
cub::ArgMax arg_max;
|
||||||
|
|
||||||
|
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||||
|
thread_kvp.key = 0;
|
||||||
|
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||||
|
|
||||||
|
if (threadIdx.x < num_experts) {
|
||||||
|
cub_kvp inp_kvp;
|
||||||
|
int expert = threadIdx.x;
|
||||||
|
inp_kvp.key = expert;
|
||||||
|
inp_kvp.value = bias ? val + bias[expert] : val;
|
||||||
|
|
||||||
|
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
|
||||||
|
const IdxT prior_winning_expert = indices[k * globalIdx + prior_k];
|
||||||
|
|
||||||
|
if (prior_winning_expert == expert) {
|
||||||
|
inp_kvp = thread_kvp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
thread_kvp = arg_max(inp_kvp, thread_kvp);
|
||||||
|
}
|
||||||
|
|
||||||
|
const cub_kvp result_kvp =
|
||||||
|
BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max);
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
const int cur_idx = k * globalIdx + k_idx;
|
||||||
|
output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
|
||||||
|
indices[cur_idx] = result_kvp.key;
|
||||||
|
source_rows[cur_idx] = k_idx * num_rows + globalIdx;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int TPB, typename IdxT = int>
|
||||||
|
__launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_softmax,
|
||||||
|
const T* bias,
|
||||||
|
T* output,
|
||||||
|
IdxT* indices,
|
||||||
|
int* source_rows,
|
||||||
|
const int64_t num_experts,
|
||||||
|
const int64_t k,
|
||||||
|
const int64_t num_rows) {
|
||||||
|
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||||
|
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
|
||||||
|
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||||
|
|
||||||
|
cub_kvp thread_kvp;
|
||||||
|
cub::ArgMax arg_max;
|
||||||
|
|
||||||
|
const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
|
||||||
|
if (block_row >= num_rows) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool should_process_row = true;
|
||||||
|
const int thread_read_offset = block_row * num_experts;
|
||||||
|
T weight_sum = static_cast<T>(0);
|
||||||
|
|
||||||
|
extern __shared__ char smem[];
|
||||||
|
|
||||||
|
T* row_outputs = reinterpret_cast<T*>(smem);
|
||||||
|
|
||||||
|
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||||
|
thread_kvp.key = 0;
|
||||||
|
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||||
|
|
||||||
|
cub_kvp inp_kvp;
|
||||||
|
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
|
||||||
|
const int idx = thread_read_offset + expert;
|
||||||
|
inp_kvp.key = expert;
|
||||||
|
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
|
||||||
|
|
||||||
|
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
|
||||||
|
const int prior_winning_expert = indices[k * block_row + prior_k];
|
||||||
|
|
||||||
|
if (prior_winning_expert == expert) {
|
||||||
|
inp_kvp = thread_kvp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
thread_kvp = arg_max(inp_kvp, thread_kvp);
|
||||||
|
}
|
||||||
|
|
||||||
|
const cub_kvp result_kvp =
|
||||||
|
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
const int idx = k * block_row + k_idx;
|
||||||
|
// output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
|
||||||
|
indices[idx] = should_process_row ? result_kvp.key : num_experts;
|
||||||
|
source_rows[idx] = k_idx * num_rows + block_row;
|
||||||
|
|
||||||
|
T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
|
||||||
|
row_outputs[k_idx] = row_out;
|
||||||
|
weight_sum += row_out;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
if (threadIdx.x < WARP_SIZE) {
|
||||||
|
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (threadIdx.x < k) {
|
||||||
|
output[k * block_row + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T, int TPB, typename IdxT = int>
|
||||||
|
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* input,
|
||||||
|
const T* bias,
|
||||||
|
T* output,
|
||||||
|
IdxT* indices,
|
||||||
|
int* source_rows,
|
||||||
|
const int64_t num_experts,
|
||||||
|
const int64_t k,
|
||||||
|
const int64_t num_rows) {
|
||||||
|
// softmax
|
||||||
|
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||||
|
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||||
|
|
||||||
|
__shared__ float normalizing_factor;
|
||||||
|
__shared__ float float_max;
|
||||||
|
|
||||||
|
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
|
||||||
|
if (globalIdx >= num_rows) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const int64_t thread_row_offset = globalIdx * num_experts;
|
||||||
|
const int64_t idx = thread_row_offset+threadIdx.x;
|
||||||
|
|
||||||
|
cub::Sum sum;
|
||||||
|
|
||||||
|
float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) :(-FLT_MAX);
|
||||||
|
|
||||||
|
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
float_max = maxElem;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
float threadDataSub = threadData - float_max;
|
||||||
|
float threadDataExp = exp(threadDataSub);
|
||||||
|
|
||||||
|
const auto Z = BlockReduce(tmpStorage).Reduce(threadDataExp, sum);
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
normalizing_factor = 1.f / Z;
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
T val = T(threadDataExp * normalizing_factor);
|
||||||
|
|
||||||
|
// top_k
|
||||||
|
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||||
|
using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
|
||||||
|
__shared__ typename BlockReduceP::TempStorage tmpStorageP;
|
||||||
|
|
||||||
|
cub_kvp thread_kvp;
|
||||||
|
cub::ArgMax arg_max;
|
||||||
|
|
||||||
|
T weight_sum = static_cast<T>(0);
|
||||||
|
extern __shared__ char smem[];
|
||||||
|
T* row_outputs = reinterpret_cast<T*>(smem);
|
||||||
|
|
||||||
|
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||||
|
thread_kvp.key = 0;
|
||||||
|
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||||
|
|
||||||
|
if (threadIdx.x < num_experts) {
|
||||||
|
cub_kvp inp_kvp;
|
||||||
|
int expert = threadIdx.x;
|
||||||
|
inp_kvp.key = expert;
|
||||||
|
inp_kvp.value = bias ? val + bias[expert] : val;
|
||||||
|
|
||||||
|
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
|
||||||
|
const IdxT prior_winning_expert = indices[k * globalIdx + prior_k];
|
||||||
|
|
||||||
|
if (prior_winning_expert == expert) {
|
||||||
|
inp_kvp = thread_kvp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
thread_kvp = arg_max(inp_kvp, thread_kvp);
|
||||||
|
}
|
||||||
|
|
||||||
|
const cub_kvp result_kvp =
|
||||||
|
BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max);
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
const int cur_idx = k * globalIdx + k_idx;
|
||||||
|
|
||||||
|
T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
|
||||||
|
row_outputs[k_idx] = row_out;
|
||||||
|
weight_sum += row_out;
|
||||||
|
|
||||||
|
indices[cur_idx] = result_kvp.key;
|
||||||
|
source_rows[cur_idx] = k_idx * num_rows + globalIdx;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (threadIdx.x < WARP_SIZE) {
|
||||||
|
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (threadIdx.x < k) {
|
||||||
|
output[k * globalIdx + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
// Constructs some constants needed to partition the work across threads at
|
||||||
|
// compile time.
|
||||||
|
template <typename T, int EXPERTS, int BYTES_PER_LDG>
|
||||||
|
struct TopkConstants {
|
||||||
|
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
|
||||||
|
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 ||
|
||||||
|
EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0,
|
||||||
|
"");
|
||||||
|
static constexpr int VECs_PER_THREAD =
|
||||||
|
std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
|
||||||
|
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
|
||||||
|
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
|
||||||
|
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
|
||||||
|
};
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
template <typename T, typename IdxT = int>
|
||||||
|
void topk_gating_softmax_kernelLauncher(const T* input,
|
||||||
|
const T* gating_correction_bias,
|
||||||
|
T* output,
|
||||||
|
T* softmax,
|
||||||
|
IdxT* indices,
|
||||||
|
int* source_row,
|
||||||
|
T* softmax_max_prob,
|
||||||
|
const int64_t num_rows,
|
||||||
|
const int64_t num_experts,
|
||||||
|
const int64_t k,
|
||||||
|
const bool group_moe,
|
||||||
|
cudaStream_t stream,
|
||||||
|
const bool topk_only_mode = false) {
|
||||||
|
if (topk_only_mode) {
|
||||||
|
static constexpr int TPB = 256;
|
||||||
|
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||||
|
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||||
|
input, gating_correction_bias, output, indices, source_row, num_experts, k, num_rows);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
static constexpr int WARPS_PER_TB = 4;
|
||||||
|
|
||||||
|
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
|
||||||
|
case N: { \
|
||||||
|
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
|
||||||
|
input, output, indices, source_row, num_rows, num_experts, k, stream); \
|
||||||
|
break; \
|
||||||
|
}
|
||||||
|
int64_t tem_num_experts = num_experts;
|
||||||
|
if(gating_correction_bias != nullptr) tem_num_experts = 0;
|
||||||
|
switch (tem_num_experts) {
|
||||||
|
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
|
||||||
|
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
|
||||||
|
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)
|
||||||
|
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16)
|
||||||
|
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32)
|
||||||
|
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64)
|
||||||
|
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128)
|
||||||
|
//LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256)
|
||||||
|
|
||||||
|
default: {
|
||||||
|
static constexpr int TPB = 256;
|
||||||
|
if (group_moe) {
|
||||||
|
const int group_experts = num_experts / k;
|
||||||
|
const int softmax_num_rows = num_rows * k;
|
||||||
|
const auto config_softmax = Get1DBlocksAnd2DGridsMoe(softmax_num_rows);
|
||||||
|
group_moe_softmax<T, TPB>
|
||||||
|
<<<config_softmax.block_per_grid, TPB, 0, stream>>>(
|
||||||
|
input,
|
||||||
|
softmax,
|
||||||
|
softmax_max_prob,
|
||||||
|
group_experts,
|
||||||
|
softmax_num_rows);
|
||||||
|
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||||
|
moe_top_k<T, TPB>
|
||||||
|
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
|
||||||
|
output,
|
||||||
|
indices,
|
||||||
|
source_row,
|
||||||
|
softmax_max_prob,
|
||||||
|
num_experts,
|
||||||
|
k,
|
||||||
|
num_rows);
|
||||||
|
} else {
|
||||||
|
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||||
|
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||||
|
input, softmax, num_experts, num_rows);
|
||||||
|
moe_top_k<T, TPB>
|
||||||
|
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
|
||||||
|
gating_correction_bias,
|
||||||
|
output,
|
||||||
|
indices,
|
||||||
|
source_row,
|
||||||
|
num_experts,
|
||||||
|
k,
|
||||||
|
num_rows);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========================== Permutation things
|
||||||
|
// =======================================
|
||||||
|
|
||||||
|
// Duplicated and permutes rows for MoE. In addition, reverse the permutation
|
||||||
|
// map to help with finalizing routing.
|
||||||
|
|
||||||
|
// "expanded_x_row" simply means that the number of values is num_rows x k. It
|
||||||
|
// is "expanded" since we will have to duplicate some rows in the input matrix
|
||||||
|
// to match the dimensions. Duplicates will always get routed to separate
|
||||||
|
// experts in the end.
|
||||||
|
|
||||||
|
// Note that the expanded_dest_row_to_expanded_source_row map referred to here
|
||||||
|
// has indices in the range (0, k*rows_in_input - 1). However, it is set up so
|
||||||
|
// that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input all map
|
||||||
|
// to row 0 in the original matrix. Thus, to know where to read in the source
|
||||||
|
// matrix, we simply take the modulus of the expanded index.
|
||||||
|
|
||||||
|
template <typename T, int VecSize>
|
||||||
|
__global__ void initialize_moe_routing_kernel(
|
||||||
|
const T* unpermuted_input,
|
||||||
|
T* permuted_output,
|
||||||
|
const int* expanded_dest_row_to_expanded_source_row,
|
||||||
|
int* expanded_source_row_to_expanded_dest_row,
|
||||||
|
const int64_t num_rows,
|
||||||
|
const int64_t active_rows,
|
||||||
|
const int64_t cols,
|
||||||
|
const int64_t num_rows_k) {
|
||||||
|
using LoadT = AlignedVector<T, VecSize>;
|
||||||
|
LoadT src_vec;
|
||||||
|
|
||||||
|
// Reverse permutation map.
|
||||||
|
// I do this so that later, we can use the source -> dest map to do the k-way
|
||||||
|
// reduction and unpermuting. I need the reverse map for that reduction to
|
||||||
|
// allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
|
||||||
|
// thread block will be responsible for all k summations.
|
||||||
|
const int expanded_dest_row = blockIdx.x + blockIdx.y * gridDim.x;
|
||||||
|
if (expanded_dest_row >= num_rows_k) return;
|
||||||
|
const int expanded_source_row =
|
||||||
|
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
|
||||||
|
expanded_dest_row;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((blockIdx.x + blockIdx.y * gridDim.x) < active_rows) {
|
||||||
|
// Duplicate and permute rows
|
||||||
|
const int source_row = expanded_source_row % num_rows;
|
||||||
|
|
||||||
|
const T* source_row_ptr = unpermuted_input + source_row * cols;
|
||||||
|
T* dest_row_ptr = permuted_output + expanded_dest_row * cols;
|
||||||
|
|
||||||
|
for (int tid = threadIdx.x * VecSize; tid < cols;
|
||||||
|
tid += blockDim.x * VecSize) {
|
||||||
|
// dest_row_ptr[tid] = source_row_ptr[tid];
|
||||||
|
Load<T, VecSize>(&source_row_ptr[tid], &src_vec);
|
||||||
|
Store<T, VecSize>(src_vec, &dest_row_ptr[tid]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void initialize_moe_routing_kernelLauncher(
|
||||||
|
const T* unpermuted_input,
|
||||||
|
T* permuted_output,
|
||||||
|
const int* expanded_dest_row_to_expanded_source_row,
|
||||||
|
int* expanded_source_row_to_expanded_dest_row,
|
||||||
|
const int64_t num_rows,
|
||||||
|
const int64_t active_rows,
|
||||||
|
const int64_t cols,
|
||||||
|
const int64_t k,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
const int threads = std::min(cols, int64_t(1024));
|
||||||
|
constexpr int max_pack_size = 16 / sizeof(T);
|
||||||
|
const auto config_initialize = Get1DBlocksAnd2DGridsMoe(num_rows * k);
|
||||||
|
if (cols % max_pack_size == 0) {
|
||||||
|
initialize_moe_routing_kernel<T, max_pack_size>
|
||||||
|
<<<config_initialize.block_per_grid, threads, 0, stream>>>(
|
||||||
|
unpermuted_input,
|
||||||
|
permuted_output,
|
||||||
|
expanded_dest_row_to_expanded_source_row,
|
||||||
|
expanded_source_row_to_expanded_dest_row,
|
||||||
|
num_rows,
|
||||||
|
k * active_rows,
|
||||||
|
cols,
|
||||||
|
num_rows * k);
|
||||||
|
} else {
|
||||||
|
initialize_moe_routing_kernel<T, 1>
|
||||||
|
<<<config_initialize.block_per_grid, threads, 0, stream>>>(
|
||||||
|
unpermuted_input,
|
||||||
|
permuted_output,
|
||||||
|
expanded_dest_row_to_expanded_source_row,
|
||||||
|
expanded_source_row_to_expanded_dest_row,
|
||||||
|
num_rows,
|
||||||
|
k * active_rows,
|
||||||
|
cols,
|
||||||
|
num_rows * k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================== Infer GEMM sizes
|
||||||
|
// =================================
|
||||||
|
__device__ inline int find_total_elts_leq_target(int* sorted_indices,
|
||||||
|
const int64_t arr_length,
|
||||||
|
const int64_t target) {
|
||||||
|
int64_t low = 0, high = arr_length - 1, target_location = -1;
|
||||||
|
while (low <= high) {
|
||||||
|
int64_t mid = (low + high) / 2;
|
||||||
|
|
||||||
|
if (sorted_indices[mid] > target) {
|
||||||
|
high = mid - 1;
|
||||||
|
} else {
|
||||||
|
low = mid + 1;
|
||||||
|
target_location = mid;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return target_location + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Final kernel to unpermute and scale
|
||||||
|
// This kernel unpermutes the original data, does the k-way reduction and
|
||||||
|
// performs the final skip connection.
|
||||||
|
template <typename T, int RESIDUAL_NUM>
|
||||||
|
__global__ void finalize_moe_routing_kernel(
|
||||||
|
const T* expanded_permuted_rows,
|
||||||
|
T* reduced_unpermuted_output,
|
||||||
|
const T* bias,
|
||||||
|
const float* scales,
|
||||||
|
const int* expanded_source_row_to_expanded_dest_row,
|
||||||
|
const int* expert_for_source_row,
|
||||||
|
const int64_t cols,
|
||||||
|
const int64_t k,
|
||||||
|
const int64_t compute_bias,
|
||||||
|
const bool norm_topk_prob,
|
||||||
|
const float routed_scaling_factor,
|
||||||
|
const int64_t num_rows) {
|
||||||
|
const int original_row = blockIdx.x + blockIdx.y * gridDim.x;
|
||||||
|
// const int original_row = blockIdx.x;
|
||||||
|
// const int num_rows = gridDim.x;
|
||||||
|
if (original_row >= num_rows) return;
|
||||||
|
T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols;
|
||||||
|
|
||||||
|
for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) {
|
||||||
|
T thread_output{0.f};
|
||||||
|
float row_rescale{0.f};
|
||||||
|
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||||
|
const int expanded_original_row = original_row + k_idx * num_rows;
|
||||||
|
const int expanded_permuted_row =
|
||||||
|
expanded_source_row_to_expanded_dest_row[expanded_original_row];
|
||||||
|
|
||||||
|
const int64_t k_offset = original_row * k + k_idx;
|
||||||
|
const float row_scale = scales[k_offset];
|
||||||
|
row_rescale = row_rescale + row_scale;
|
||||||
|
|
||||||
|
const T* expanded_permuted_rows_row_ptr =
|
||||||
|
expanded_permuted_rows + expanded_permuted_row * cols;
|
||||||
|
|
||||||
|
const int expert_idx = expert_for_source_row[k_offset];
|
||||||
|
const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr;
|
||||||
|
const T bias_value = bias_ptr ? bias_ptr[tid] : T{0.f};
|
||||||
|
|
||||||
|
thread_output =
|
||||||
|
static_cast<float>(thread_output) +
|
||||||
|
row_scale * static_cast<float>(
|
||||||
|
expanded_permuted_rows_row_ptr[tid] +
|
||||||
|
bias_value *
|
||||||
|
static_cast<T>(static_cast<float>(compute_bias)));
|
||||||
|
}
|
||||||
|
|
||||||
|
thread_output = static_cast<float>(thread_output) /
|
||||||
|
(norm_topk_prob ? row_rescale : 1.0f) *
|
||||||
|
routed_scaling_factor;
|
||||||
|
reduced_row_ptr[tid] = thread_output;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void finalize_moe_routing_kernelLauncher(
|
||||||
|
const T* expanded_permuted_rows,
|
||||||
|
T* reduced_unpermuted_output,
|
||||||
|
const T* bias,
|
||||||
|
const float* scales,
|
||||||
|
const int* expanded_source_row_to_expanded_dest_row,
|
||||||
|
const int* expert_for_source_row,
|
||||||
|
const int64_t num_rows,
|
||||||
|
const int64_t cols,
|
||||||
|
const int64_t k,
|
||||||
|
const int64_t compute_bias,
|
||||||
|
const bool norm_topk_prob,
|
||||||
|
const float routed_scaling_factor,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
const int threads = std::min(cols, int64_t(1024));
|
||||||
|
const auto config_final = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||||
|
|
||||||
|
finalize_moe_routing_kernel<T, 1>
|
||||||
|
<<<config_final.block_per_grid, threads, 0, stream>>>(
|
||||||
|
expanded_permuted_rows,
|
||||||
|
reduced_unpermuted_output,
|
||||||
|
bias,
|
||||||
|
scales,
|
||||||
|
expanded_source_row_to_expanded_dest_row,
|
||||||
|
expert_for_source_row,
|
||||||
|
cols,
|
||||||
|
k,
|
||||||
|
compute_bias,
|
||||||
|
norm_topk_prob,
|
||||||
|
routed_scaling_factor,
|
||||||
|
num_rows);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========================= TopK Softmax specializations
|
||||||
|
// ===========================
|
||||||
|
template void topk_gating_softmax_kernelLauncher(const float*,
|
||||||
|
const float*,
|
||||||
|
float*,
|
||||||
|
float*,
|
||||||
|
int*,
|
||||||
|
int*,
|
||||||
|
float*,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const bool,
|
||||||
|
cudaStream_t,
|
||||||
|
const bool);
|
||||||
|
template void topk_gating_softmax_kernelLauncher(const half*,
|
||||||
|
const half*,
|
||||||
|
half*,
|
||||||
|
half*,
|
||||||
|
int*,
|
||||||
|
int*,
|
||||||
|
half*,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const bool,
|
||||||
|
cudaStream_t,
|
||||||
|
const bool);
|
||||||
|
#ifdef PADDLE_CUDA_BF16
|
||||||
|
template void topk_gating_softmax_kernelLauncher(const __nv_bfloat16*,
|
||||||
|
const __nv_bfloat16*,
|
||||||
|
__nv_bfloat16*,
|
||||||
|
__nv_bfloat16*,
|
||||||
|
int*,
|
||||||
|
int*,
|
||||||
|
__nv_bfloat16*,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const bool,
|
||||||
|
cudaStream_t,
|
||||||
|
const bool);
|
||||||
|
#endif
|
||||||
|
// ===================== Specializations for init routing
|
||||||
|
// =========================
|
||||||
|
template void initialize_moe_routing_kernelLauncher(const float*,
|
||||||
|
float*,
|
||||||
|
const int*,
|
||||||
|
int*,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
cudaStream_t);
|
||||||
|
template void initialize_moe_routing_kernelLauncher(const half*,
|
||||||
|
half*,
|
||||||
|
const int*,
|
||||||
|
int*,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
cudaStream_t);
|
||||||
|
#ifdef PADDLE_CUDA_BF16
|
||||||
|
template void initialize_moe_routing_kernelLauncher(const __nv_bfloat16*,
|
||||||
|
__nv_bfloat16*,
|
||||||
|
const int*,
|
||||||
|
int*,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
cudaStream_t);
|
||||||
|
#endif
|
||||||
|
// ==================== Specializations for final routing
|
||||||
|
// ===================================
|
||||||
|
template void finalize_moe_routing_kernelLauncher(const float*,
|
||||||
|
float*,
|
||||||
|
const float*,
|
||||||
|
const float*,
|
||||||
|
const int*,
|
||||||
|
const int*,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const bool,
|
||||||
|
const float,
|
||||||
|
cudaStream_t);
|
||||||
|
template void finalize_moe_routing_kernelLauncher(const half*,
|
||||||
|
half*,
|
||||||
|
const half*,
|
||||||
|
const float*,
|
||||||
|
const int*,
|
||||||
|
const int*,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const bool,
|
||||||
|
const float,
|
||||||
|
cudaStream_t);
|
||||||
|
#ifdef PADDLE_CUDA_BF16
|
||||||
|
template void finalize_moe_routing_kernelLauncher(const __nv_bfloat16*,
|
||||||
|
__nv_bfloat16*,
|
||||||
|
const __nv_bfloat16*,
|
||||||
|
const float*,
|
||||||
|
const int*,
|
||||||
|
const int*,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const int64_t,
|
||||||
|
const bool,
|
||||||
|
const float,
|
||||||
|
cudaStream_t);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace phi
|
311
custom_ops/iluvatar_ops/moe_dispatch.cu
Normal file
311
custom_ops/iluvatar_ops/moe_dispatch.cu
Normal file
@@ -0,0 +1,311 @@
|
|||||||
|
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
// Ignore CUTLASS warnings about type punning
|
||||||
|
#pragma GCC diagnostic push
|
||||||
|
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||||
|
#pragma GCC diagnostic ignored "-Wunused-function"
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "fused_moe_helper.h"
|
||||||
|
#include "fused_moe_op.h"
|
||||||
|
#pragma GCC diagnostic pop
|
||||||
|
#include "helper.h"
|
||||||
|
|
||||||
|
__global__ void compute_total_rows_before_expert_kernel(
|
||||||
|
int* sorted_experts,
|
||||||
|
const int64_t sorted_experts_len,
|
||||||
|
const int64_t num_experts,
|
||||||
|
int64_t* total_rows_before_expert) {
|
||||||
|
const int expert = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
if (expert >= num_experts) return;
|
||||||
|
total_rows_before_expert[expert] =
|
||||||
|
phi::find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
|
||||||
|
}
|
||||||
|
|
||||||
|
void compute_total_rows_before_expert(int* sorted_indices,
|
||||||
|
const int64_t total_indices,
|
||||||
|
const int64_t num_experts,
|
||||||
|
int64_t* total_rows_before_expert,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
const int threads = std::min(int64_t(1024), num_experts);
|
||||||
|
const int blocks = (num_experts + threads - 1) / threads;
|
||||||
|
|
||||||
|
compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
|
||||||
|
sorted_indices, total_indices, num_experts, total_rows_before_expert);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <paddle::DataType T>
|
||||||
|
void MoeDispatchKernel(const paddle::Tensor& input,
|
||||||
|
const paddle::Tensor& gating_output,
|
||||||
|
const paddle::optional<paddle::Tensor>& gating_correction_bias,
|
||||||
|
const int moe_topk,
|
||||||
|
const bool group_moe,
|
||||||
|
const bool topk_only_mode,
|
||||||
|
const int num_rows,
|
||||||
|
const int hidden_size,
|
||||||
|
const int expert_num,
|
||||||
|
paddle::Tensor* permute_input,
|
||||||
|
paddle::Tensor* tokens_expert_prefix_sum,
|
||||||
|
paddle::Tensor* permute_indices_per_token,
|
||||||
|
paddle::Tensor* top_k_weight,
|
||||||
|
paddle::Tensor* top_k_indices) {
|
||||||
|
using namespace phi;
|
||||||
|
|
||||||
|
typedef PDTraits<T> traits_;
|
||||||
|
typedef typename traits_::DataType DataType_;
|
||||||
|
typedef typename traits_::data_t data_t;
|
||||||
|
auto place = input.place();
|
||||||
|
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input.place()));
|
||||||
|
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
||||||
|
if (group_moe) {
|
||||||
|
// Check if expert_num is divisible by moe_topk, else throw an error
|
||||||
|
PADDLE_ENFORCE_EQ(expert_num % moe_topk,
|
||||||
|
0,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"The number of experts (expert_num) "
|
||||||
|
"must be divisible by moe_topk. "
|
||||||
|
"Got expert_num = %d and moe_topk = %d.",
|
||||||
|
expert_num,
|
||||||
|
moe_topk));
|
||||||
|
}
|
||||||
|
|
||||||
|
const int num_moe_inputs = AlignTo16(num_rows * moe_topk);
|
||||||
|
const int bytes = num_moe_inputs * sizeof(int);
|
||||||
|
|
||||||
|
CubKeyValueSorter sorter_;
|
||||||
|
sorter_.update_num_experts(expert_num);
|
||||||
|
|
||||||
|
const int sorter_ws_size_bytes =
|
||||||
|
AlignTo16(sorter_.getWorkspaceSize(moe_topk * num_rows));
|
||||||
|
const int sort_tmp_in_out_size = num_moe_inputs * 2 * sizeof(int);
|
||||||
|
|
||||||
|
paddle::Tensor ws_ptr_tensor =
|
||||||
|
GetEmptyTensor({bytes + sorter_ws_size_bytes + sort_tmp_in_out_size},
|
||||||
|
paddle::DataType::INT8,
|
||||||
|
place);
|
||||||
|
|
||||||
|
int8_t* ws_ptr = ws_ptr_tensor.data<int8_t>();
|
||||||
|
int* source_rows_ = reinterpret_cast<int*>(ws_ptr);
|
||||||
|
int8_t* sorter_ws_ptr = reinterpret_cast<int8_t*>(ws_ptr + bytes);
|
||||||
|
int* permuted_experts_ =
|
||||||
|
reinterpret_cast<int*>(sorter_ws_ptr + sorter_ws_size_bytes);
|
||||||
|
int* permuted_rows_ = permuted_experts_ + num_moe_inputs;
|
||||||
|
|
||||||
|
int* expert_for_source_row = top_k_indices->data<int>();
|
||||||
|
|
||||||
|
float* softmax_max_prob = nullptr;
|
||||||
|
if (group_moe) {
|
||||||
|
paddle::Tensor softmax_max_prob_tensor =
|
||||||
|
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||||
|
// (TODO: check fill sucess ?)
|
||||||
|
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
|
||||||
|
softmax_max_prob = softmax_max_prob_tensor.data<float>();
|
||||||
|
}
|
||||||
|
|
||||||
|
float* softmax_out_;
|
||||||
|
|
||||||
|
const bool is_pow_2 =
|
||||||
|
(expert_num != 0) && ((expert_num & (expert_num - 1)) == 0);
|
||||||
|
|
||||||
|
paddle::Tensor softmax_buffer;
|
||||||
|
|
||||||
|
if (!is_pow_2 || expert_num > 256 || group_moe || gating_correction_bias) {
|
||||||
|
softmax_buffer = GetEmptyTensor(
|
||||||
|
{num_rows * expert_num}, paddle::DataType::FLOAT32, place);
|
||||||
|
softmax_out_ = softmax_buffer.data<float>();
|
||||||
|
} else {
|
||||||
|
softmax_out_ = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
topk_gating_softmax_kernelLauncher<float>(gating_output.data<float>(),
|
||||||
|
gating_correction_bias ? gating_correction_bias.get().data<float>() : nullptr,
|
||||||
|
top_k_weight->data<float>(),
|
||||||
|
softmax_out_,
|
||||||
|
expert_for_source_row,
|
||||||
|
source_rows_,
|
||||||
|
softmax_max_prob,
|
||||||
|
num_rows,
|
||||||
|
expert_num,
|
||||||
|
moe_topk,
|
||||||
|
group_moe,
|
||||||
|
stream,
|
||||||
|
topk_only_mode);
|
||||||
|
|
||||||
|
sorter_.run(reinterpret_cast<void*>(sorter_ws_ptr),
|
||||||
|
sorter_ws_size_bytes,
|
||||||
|
expert_for_source_row,
|
||||||
|
permuted_experts_,
|
||||||
|
source_rows_,
|
||||||
|
permuted_rows_,
|
||||||
|
moe_topk * num_rows,
|
||||||
|
false,
|
||||||
|
stream);
|
||||||
|
|
||||||
|
|
||||||
|
initialize_moe_routing_kernelLauncher(
|
||||||
|
input.data<data_t>(),
|
||||||
|
permute_input->data<data_t>(),
|
||||||
|
permuted_rows_,
|
||||||
|
permute_indices_per_token->data<int32_t>(),
|
||||||
|
num_rows,
|
||||||
|
num_rows,
|
||||||
|
hidden_size,
|
||||||
|
moe_topk,
|
||||||
|
stream);
|
||||||
|
|
||||||
|
|
||||||
|
compute_total_rows_before_expert(
|
||||||
|
permuted_experts_,
|
||||||
|
moe_topk * num_rows,
|
||||||
|
expert_num,
|
||||||
|
tokens_expert_prefix_sum->data<int64_t>(),
|
||||||
|
stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||||
|
const paddle::Tensor& input,
|
||||||
|
const paddle::Tensor& gating_output,
|
||||||
|
const paddle::optional<paddle::Tensor>& gating_correction_bias,
|
||||||
|
const paddle::optional<paddle::Tensor>& w4a8_in_scale,
|
||||||
|
const int moe_topk,
|
||||||
|
const bool group_moe,
|
||||||
|
const bool topk_only_mode) {
|
||||||
|
const auto input_type = input.dtype();
|
||||||
|
auto place = input.place();
|
||||||
|
int token_rows = 0;
|
||||||
|
auto input_dims = input.dims();
|
||||||
|
auto gating_dims = gating_output.dims();
|
||||||
|
const int expert_num = gating_dims[gating_dims.size() - 1];
|
||||||
|
|
||||||
|
if (input_dims.size() == 3) {
|
||||||
|
token_rows = input_dims[0] * input_dims[1];
|
||||||
|
} else {
|
||||||
|
token_rows = input_dims[0];
|
||||||
|
}
|
||||||
|
const int num_rows = token_rows;
|
||||||
|
const int hidden_size = input.dims()[input_dims.size() - 1];
|
||||||
|
|
||||||
|
auto permute_input =
|
||||||
|
GetEmptyTensor({moe_topk * num_rows, hidden_size}, input_type, place);
|
||||||
|
// correspond to the weighted coefficients of the results from each expert.
|
||||||
|
auto top_k_weight =
|
||||||
|
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||||
|
auto top_k_indices =
|
||||||
|
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::INT32, place);
|
||||||
|
|
||||||
|
auto tokens_expert_prefix_sum =
|
||||||
|
GetEmptyTensor({expert_num}, paddle::DataType::INT64, place);
|
||||||
|
auto permute_indices_per_token =
|
||||||
|
GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
|
||||||
|
|
||||||
|
|
||||||
|
switch (input_type) {
|
||||||
|
case paddle::DataType::BFLOAT16:
|
||||||
|
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
|
||||||
|
gating_output,
|
||||||
|
gating_correction_bias,
|
||||||
|
moe_topk,
|
||||||
|
group_moe,
|
||||||
|
topk_only_mode,
|
||||||
|
num_rows,
|
||||||
|
hidden_size,
|
||||||
|
expert_num,
|
||||||
|
&permute_input,
|
||||||
|
&tokens_expert_prefix_sum,
|
||||||
|
&permute_indices_per_token,
|
||||||
|
&top_k_weight,
|
||||||
|
&top_k_indices);
|
||||||
|
break;
|
||||||
|
case paddle::DataType::FLOAT16:
|
||||||
|
MoeDispatchKernel<paddle::DataType::FLOAT16>(input,
|
||||||
|
gating_output,
|
||||||
|
gating_correction_bias,
|
||||||
|
moe_topk,
|
||||||
|
group_moe,
|
||||||
|
topk_only_mode,
|
||||||
|
num_rows,
|
||||||
|
hidden_size,
|
||||||
|
expert_num,
|
||||||
|
&permute_input,
|
||||||
|
&tokens_expert_prefix_sum,
|
||||||
|
&permute_indices_per_token,
|
||||||
|
&top_k_weight,
|
||||||
|
&top_k_indices);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
PD_THROW("Unsupported data type for MoeDispatchKernel");
|
||||||
|
}
|
||||||
|
return {permute_input,
|
||||||
|
tokens_expert_prefix_sum,
|
||||||
|
permute_indices_per_token,
|
||||||
|
top_k_weight,
|
||||||
|
top_k_indices,
|
||||||
|
top_k_indices};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
|
||||||
|
const std::vector<int64_t>& input_shape,
|
||||||
|
const std::vector<int64_t>& gating_output_shape,
|
||||||
|
const paddle::optional<std::vector<int64_t>>& bias_shape,
|
||||||
|
const int moe_topk) {
|
||||||
|
int token_rows = -1;
|
||||||
|
|
||||||
|
if (input_shape.size() == 3) {
|
||||||
|
token_rows = input_shape[0] * input_shape[1];
|
||||||
|
} else {
|
||||||
|
token_rows = input_shape[0];
|
||||||
|
}
|
||||||
|
const int expert_num = gating_output_shape[gating_output_shape.size() - 1];
|
||||||
|
const int num_rows = token_rows;
|
||||||
|
const int hidden_size = input_shape[input_shape.size() - 1];
|
||||||
|
|
||||||
|
return {{moe_topk * num_rows, hidden_size},
|
||||||
|
{expert_num},
|
||||||
|
{moe_topk, num_rows},
|
||||||
|
{num_rows, moe_topk},
|
||||||
|
{num_rows, moe_topk},
|
||||||
|
{num_rows, moe_topk}};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
||||||
|
const paddle::DataType& input_dtype,
|
||||||
|
const paddle::DataType& gating_output_dtype,
|
||||||
|
const paddle::optional<paddle::DataType>& bias_type,
|
||||||
|
const int moe_topk) {
|
||||||
|
return {input_dtype,
|
||||||
|
paddle::DataType::INT64,
|
||||||
|
paddle::DataType::INT32,
|
||||||
|
paddle::DataType::FLOAT32,
|
||||||
|
paddle::DataType::INT32,
|
||||||
|
paddle::DataType::INT32};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
PD_BUILD_STATIC_OP(moe_expert_dispatch)
|
||||||
|
.Inputs({"input", "gating_output", paddle::Optional("gating_correction_bias"),
|
||||||
|
paddle::Optional("w4a8_in_scale")})
|
||||||
|
.Outputs({"permute_input",
|
||||||
|
"tokens_expert_prefix_sum",
|
||||||
|
"permute_indices_per_token",
|
||||||
|
"top_k_weight",
|
||||||
|
"top_k_indices",
|
||||||
|
"expert_idx_per_token"})
|
||||||
|
.Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
|
||||||
|
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
|
||||||
|
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
|
||||||
|
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));
|
155
custom_ops/iluvatar_ops/moe_reduce.cu
Normal file
155
custom_ops/iluvatar_ops/moe_reduce.cu
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// Ignore CUTLASS warnings about type punning
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "helper.h"
|
||||||
|
#include "fused_moe_helper.h"
|
||||||
|
#include "fused_moe_op.h"
|
||||||
|
|
||||||
|
template <paddle::DataType T>
|
||||||
|
void MoeReduceKernel(const paddle::Tensor& ffn_out,
|
||||||
|
const paddle::Tensor& top_k_weight,
|
||||||
|
const paddle::Tensor& permute_indices_per_token,
|
||||||
|
const paddle::Tensor& top_k_indices,
|
||||||
|
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
||||||
|
const bool norm_topk_prob,
|
||||||
|
const float routed_scaling_factor,
|
||||||
|
const int num_rows,
|
||||||
|
const int hidden_size,
|
||||||
|
const int topk,
|
||||||
|
paddle::Tensor* output) {
|
||||||
|
using namespace phi;
|
||||||
|
typedef PDTraits<T> traits_;
|
||||||
|
typedef typename traits_::DataType DataType_;
|
||||||
|
typedef typename traits_::data_t data_t;
|
||||||
|
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(ffn_out.place()));
|
||||||
|
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
||||||
|
|
||||||
|
finalize_moe_routing_kernelLauncher(
|
||||||
|
ffn_out.data<data_t>(),
|
||||||
|
output->data<data_t>(),
|
||||||
|
ffn2_bias ? ffn2_bias->data<data_t>() : nullptr,
|
||||||
|
top_k_weight.data<float>(),
|
||||||
|
permute_indices_per_token.data<int32_t>(),
|
||||||
|
top_k_indices.data<int>(),
|
||||||
|
num_rows,
|
||||||
|
hidden_size,
|
||||||
|
topk,
|
||||||
|
static_cast<int>(1),
|
||||||
|
norm_topk_prob,
|
||||||
|
routed_scaling_factor,
|
||||||
|
stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
paddle::Tensor MoeExpertReduceFunc(
|
||||||
|
const paddle::Tensor& ffn_out,
|
||||||
|
const paddle::Tensor& top_k_weight,
|
||||||
|
const paddle::Tensor& permute_indices_per_token,
|
||||||
|
const paddle::Tensor& top_k_indices,
|
||||||
|
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
||||||
|
const bool norm_topk_prob,
|
||||||
|
const float routed_scaling_factor) {
|
||||||
|
const auto input_type = ffn_out.dtype();
|
||||||
|
auto place = ffn_out.place();
|
||||||
|
|
||||||
|
const int topk = top_k_indices.dims()[1];
|
||||||
|
const int num_rows = ffn_out.dims()[0] / topk;
|
||||||
|
const int hidden_size = ffn_out.dims()[1];
|
||||||
|
|
||||||
|
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
|
||||||
|
|
||||||
|
switch (input_type) {
|
||||||
|
case paddle::DataType::BFLOAT16:
|
||||||
|
MoeReduceKernel<paddle::DataType::BFLOAT16>(
|
||||||
|
ffn_out,
|
||||||
|
top_k_weight,
|
||||||
|
permute_indices_per_token,
|
||||||
|
top_k_indices,
|
||||||
|
ffn2_bias,
|
||||||
|
norm_topk_prob,
|
||||||
|
routed_scaling_factor,
|
||||||
|
num_rows,
|
||||||
|
hidden_size,
|
||||||
|
topk,
|
||||||
|
&output);
|
||||||
|
break;
|
||||||
|
case paddle::DataType::FLOAT16:
|
||||||
|
MoeReduceKernel<paddle::DataType::BFLOAT16>(
|
||||||
|
ffn_out,
|
||||||
|
top_k_weight,
|
||||||
|
permute_indices_per_token,
|
||||||
|
top_k_indices,
|
||||||
|
ffn2_bias,
|
||||||
|
norm_topk_prob,
|
||||||
|
routed_scaling_factor,
|
||||||
|
num_rows,
|
||||||
|
hidden_size,
|
||||||
|
topk,
|
||||||
|
&output);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
PD_THROW("Unsupported data type for MoeDispatchKernel");
|
||||||
|
}
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<paddle::Tensor> MoeExpertReduce(
|
||||||
|
const paddle::Tensor& ffn_out,
|
||||||
|
const paddle::Tensor& top_k_weight,
|
||||||
|
const paddle::Tensor& permute_indices_per_token,
|
||||||
|
const paddle::Tensor& top_k_indices,
|
||||||
|
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
||||||
|
const bool norm_topk_prob,
|
||||||
|
const float routed_scaling_factor) {
|
||||||
|
return {MoeExpertReduceFunc(ffn_out,
|
||||||
|
top_k_weight,
|
||||||
|
permute_indices_per_token,
|
||||||
|
top_k_indices,
|
||||||
|
ffn2_bias,
|
||||||
|
norm_topk_prob,
|
||||||
|
routed_scaling_factor)};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
|
||||||
|
const std::vector<int64_t>& ffn_out_shape,
|
||||||
|
const std::vector<int64_t>& top_k_weight_shape,
|
||||||
|
const std::vector<int64_t>& permute_indices_per_token_shape,
|
||||||
|
const std::vector<int64_t>& top_k_indices_shape,
|
||||||
|
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape) {
|
||||||
|
return {ffn_out_shape};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<paddle::DataType> MoeExpertReduceInferDtype(
|
||||||
|
const paddle::DataType& ffn_out_dtype,
|
||||||
|
const paddle::DataType& top_k_weight_dtype,
|
||||||
|
const paddle::DataType& permute_indices_per_token_dtype,
|
||||||
|
const paddle::DataType& top_k_indices_dtype,
|
||||||
|
const paddle::optional<paddle::DataType>& ffn2_bias_dtype) {
|
||||||
|
return {ffn_out_dtype};
|
||||||
|
}
|
||||||
|
|
||||||
|
PD_BUILD_STATIC_OP(moe_expert_reduce)
|
||||||
|
.Inputs({"ffn_out",
|
||||||
|
"top_k_weight",
|
||||||
|
"permute_indices_per_token",
|
||||||
|
"top_k_indices",
|
||||||
|
paddle::Optional("ffn2_bias")})
|
||||||
|
.Outputs({"output"})
|
||||||
|
.Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"})
|
||||||
|
.SetKernelFn(PD_KERNEL(MoeExpertReduce))
|
||||||
|
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertReduceInferShape))
|
||||||
|
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertReduceInferDtype));
|
337
custom_ops/iluvatar_ops/paged_attn.cu
Normal file
337
custom_ops/iluvatar_ops/paged_attn.cu
Normal file
@@ -0,0 +1,337 @@
|
|||||||
|
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "helper.h"
|
||||||
|
#include "iluvatar_context.h"
|
||||||
|
|
||||||
|
#define CUINFER_CHECK(func) \
|
||||||
|
do { \
|
||||||
|
cuinferStatus_t status = (func); \
|
||||||
|
if (status != CUINFER_STATUS_SUCCESS) { \
|
||||||
|
std::cerr << "Error in file " << __FILE__ << " on line " << __LINE__ << ": " \
|
||||||
|
<< cuinferGetErrorString(status) << std::endl; \
|
||||||
|
throw std::runtime_error("CUINFER_CHECK ERROR"); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
template <paddle::DataType T>
|
||||||
|
void PagedAttnKernel(const paddle::Tensor& q,
|
||||||
|
const paddle::Tensor& k_cache,
|
||||||
|
const paddle::Tensor& v_cache,
|
||||||
|
const paddle::Tensor& block_table,
|
||||||
|
const paddle::Tensor& seq_lens,
|
||||||
|
const paddle::optional<paddle::Tensor> &alibi_slopes,
|
||||||
|
const paddle::optional<paddle::Tensor> &k,
|
||||||
|
const paddle::optional<paddle::Tensor> &v,
|
||||||
|
int num_kv_heads,
|
||||||
|
float scale,
|
||||||
|
int block_size,
|
||||||
|
int max_context_len,
|
||||||
|
bool causal,
|
||||||
|
int window_left,
|
||||||
|
int window_right,
|
||||||
|
float softcap,
|
||||||
|
bool enable_cuda_graph,
|
||||||
|
bool use_sqrt_alibi,
|
||||||
|
paddle::Tensor& out) {
|
||||||
|
if (alibi_slopes) {
|
||||||
|
PADDLE_ENFORCE_EQ(alibi_slopes.get().dtype(),
|
||||||
|
paddle::DataType::FLOAT32,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"paged_attention expects alibi_slopes float tensor"));
|
||||||
|
PADDLE_ENFORCE_EQ(alibi_slopes.get().is_contiguous(),
|
||||||
|
true,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"paged_attention expects alibi_slopes is contiguous"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// check dtype and contiguous
|
||||||
|
const auto& dtype = q.dtype();
|
||||||
|
cudaDataType_t data_type;
|
||||||
|
if (dtype == paddle::DataType::FLOAT16) {
|
||||||
|
data_type = CUDA_R_16F;
|
||||||
|
} else if (dtype == paddle::DataType::BFLOAT16) {
|
||||||
|
data_type = CUDA_R_16BF;
|
||||||
|
} else {
|
||||||
|
common::errors::InvalidArgument("paged_attention support half and bfloat16 now");
|
||||||
|
}
|
||||||
|
|
||||||
|
PADDLE_ENFORCE_EQ(k_cache.dtype(),
|
||||||
|
dtype,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"k_cache dtype must be the same as query dtype"));
|
||||||
|
PADDLE_ENFORCE_EQ(k_cache.is_contiguous(),
|
||||||
|
true,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"paged_attention expects k_cache is contiguous"));
|
||||||
|
PADDLE_ENFORCE_EQ(v_cache.dtype(),
|
||||||
|
dtype,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"v_cache dtype must be the same as query dtype"));
|
||||||
|
PADDLE_ENFORCE_EQ(v_cache.is_contiguous(),
|
||||||
|
true,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"paged_attention expects v_cache is contiguous"));
|
||||||
|
PADDLE_ENFORCE_EQ(block_table.dtype(),
|
||||||
|
paddle::DataType::INT32,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"block_table dtype must be int32"));
|
||||||
|
PADDLE_ENFORCE_EQ(block_table.is_contiguous(),
|
||||||
|
true,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"paged_attention expects block_table is contiguous"));
|
||||||
|
PADDLE_ENFORCE_EQ(seq_lens.dtype(),
|
||||||
|
paddle::DataType::INT32,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"seq_lens dtype must be int32"));
|
||||||
|
PADDLE_ENFORCE_EQ(seq_lens.is_contiguous(),
|
||||||
|
true,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"paged_attention expects seq_lens is contiguous"));
|
||||||
|
|
||||||
|
// check dim and shape
|
||||||
|
// out: [num_seqs, num_heads, head_size]
|
||||||
|
// q: [num_seqs, num_heads, head_size]
|
||||||
|
// k_chache: [num_blocks, kv_num_heads, block_size, head_size]
|
||||||
|
// v_chache: [num_blocks, kv_num_heads, block_size, head_size]
|
||||||
|
// block_table: [num_seqs, max_num_blocks_per_seq]
|
||||||
|
// seq_lens: [num_seqs]
|
||||||
|
|
||||||
|
const auto& q_dims = q.dims();
|
||||||
|
PADDLE_ENFORCE_EQ(q_dims.size(),
|
||||||
|
3,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"paged_attn receive query dims is "
|
||||||
|
"[num_seqs, num_heads, head_size]"));
|
||||||
|
PADDLE_ENFORCE_EQ(out.dims().size(),
|
||||||
|
3,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"paged_attn receive out dims is "
|
||||||
|
"[num_seqs, num_heads, head_size]"));
|
||||||
|
PADDLE_ENFORCE_EQ(k_cache.dims(),
|
||||||
|
v_cache.dims(),
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"paged_attn requires k_cache size is the "
|
||||||
|
"same as v_cache"));
|
||||||
|
|
||||||
|
const auto& kv_cache_dims = k_cache.dims();
|
||||||
|
PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
|
||||||
|
4,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"paged_attn receive kv cache dims is "
|
||||||
|
"[num_blocks, kv_num_heads, block_size, head_size]"));
|
||||||
|
|
||||||
|
const auto& block_table_dims = block_table.dims();
|
||||||
|
PADDLE_ENFORCE_EQ(block_table_dims.size(),
|
||||||
|
2,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"paged_attn receive block_table dims is "
|
||||||
|
"[num_seqs, max_num_blocks_per_seq]"));
|
||||||
|
|
||||||
|
const auto& seq_lens_dims = seq_lens.dims();
|
||||||
|
PADDLE_ENFORCE_EQ(seq_lens_dims.size(),
|
||||||
|
1,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"paged_attn receive seq_lens dims is [num_seqs]"));
|
||||||
|
|
||||||
|
int num_seqs = q_dims[0];
|
||||||
|
int num_heads = q_dims[1];
|
||||||
|
int head_size = q_dims[2];
|
||||||
|
int max_num_blocks_per_seq = block_table_dims[1];
|
||||||
|
int q_stride = q.strides()[0];
|
||||||
|
int num_blocks = kv_cache_dims[0];
|
||||||
|
|
||||||
|
PADDLE_ENFORCE_EQ(kv_cache_dims[1],
|
||||||
|
num_kv_heads,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"kv_cache_dims[1] must be equal to num_kv_head"));
|
||||||
|
PADDLE_ENFORCE_EQ(kv_cache_dims[2],
|
||||||
|
block_size,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"kv_cache_dims[2] must be equal to block_size"));
|
||||||
|
PADDLE_ENFORCE_EQ(kv_cache_dims[3],
|
||||||
|
head_size,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"kv_cache_dims[3] must be equal to head_size"));
|
||||||
|
PADDLE_ENFORCE_EQ(block_table_dims[0],
|
||||||
|
num_seqs,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"block_table_dims[0] must be equal to num_seqs"));
|
||||||
|
PADDLE_ENFORCE_EQ(seq_lens_dims[0],
|
||||||
|
num_seqs,
|
||||||
|
common::errors::InvalidArgument(
|
||||||
|
"seq_lens_dims[0] must be equal to num_seqs"));
|
||||||
|
|
||||||
|
int kv_block_stride = k_cache.strides()[0];
|
||||||
|
int kv_head_stride = k_cache.strides()[1];
|
||||||
|
const float *alibi_slopes_ptr = alibi_slopes ? alibi_slopes.get().data<float>() : nullptr;
|
||||||
|
const void *key_ptr = k ? k.get().data() : nullptr;
|
||||||
|
const void *value_ptr = v ? v.get().data() : nullptr;
|
||||||
|
|
||||||
|
size_t workspace_size = 0;
|
||||||
|
void* workspace_ptr = nullptr;
|
||||||
|
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(
|
||||||
|
num_seqs, num_heads, num_kv_heads, head_size, block_size, max_context_len, &workspace_size));
|
||||||
|
|
||||||
|
CUDA_CHECK(cudaMalloc((void**)&workspace_ptr, workspace_size));
|
||||||
|
CUDA_CHECK(cudaMemset(workspace_ptr, 0xff, workspace_size));
|
||||||
|
|
||||||
|
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(q.place()));
|
||||||
|
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
||||||
|
cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle();
|
||||||
|
|
||||||
|
PageAttentionWithKVCacheArguments args{
|
||||||
|
static_cast<float>(scale), 1.0, 1.0, static_cast<float>(softcap), window_left, window_right,
|
||||||
|
causal, use_sqrt_alibi, enable_cuda_graph, false, alibi_slopes_ptr, key_ptr, value_ptr, workspace_ptr};
|
||||||
|
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
|
||||||
|
out.data(),
|
||||||
|
data_type,
|
||||||
|
q.data(),
|
||||||
|
data_type,
|
||||||
|
num_seqs,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_size,
|
||||||
|
q_stride,
|
||||||
|
kv_block_stride,
|
||||||
|
kv_head_stride,
|
||||||
|
k_cache.data(),
|
||||||
|
data_type,
|
||||||
|
v_cache.data(),
|
||||||
|
data_type,
|
||||||
|
block_size,
|
||||||
|
max_num_blocks_per_seq,
|
||||||
|
max_context_len,
|
||||||
|
block_table.data<int32_t>(),
|
||||||
|
seq_lens.data<int32_t>(),
|
||||||
|
args));
|
||||||
|
|
||||||
|
CUDA_CHECK(cudaFree(workspace_ptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
|
||||||
|
const paddle::Tensor& k_cache,
|
||||||
|
const paddle::Tensor& v_cache,
|
||||||
|
const paddle::Tensor& block_table,
|
||||||
|
const paddle::Tensor& seq_lens,
|
||||||
|
const paddle::optional<paddle::Tensor> &alibi_slopes,
|
||||||
|
const paddle::optional<paddle::Tensor> &k,
|
||||||
|
const paddle::optional<paddle::Tensor> &v,
|
||||||
|
int num_kv_heads,
|
||||||
|
float scale,
|
||||||
|
int block_size,
|
||||||
|
int max_context_len,
|
||||||
|
bool causal,
|
||||||
|
int window_left,
|
||||||
|
int window_right,
|
||||||
|
float softcap,
|
||||||
|
bool enable_cuda_graph,
|
||||||
|
bool use_sqrt_alibi) {
|
||||||
|
|
||||||
|
const auto dtype = q.dtype();
|
||||||
|
auto out = paddle::empty_like(q, dtype);
|
||||||
|
|
||||||
|
switch (dtype) {
|
||||||
|
case paddle::DataType::BFLOAT16:
|
||||||
|
PagedAttnKernel<paddle::DataType::BFLOAT16>(q,
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
block_table,
|
||||||
|
seq_lens,
|
||||||
|
alibi_slopes,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
num_kv_heads,
|
||||||
|
scale,
|
||||||
|
block_size,
|
||||||
|
max_context_len,
|
||||||
|
causal,
|
||||||
|
window_left,
|
||||||
|
window_right,
|
||||||
|
softcap,
|
||||||
|
enable_cuda_graph,
|
||||||
|
use_sqrt_alibi,
|
||||||
|
out);
|
||||||
|
break;
|
||||||
|
case paddle::DataType::FLOAT16:
|
||||||
|
PagedAttnKernel<paddle::DataType::FLOAT16>(q,
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
block_table,
|
||||||
|
seq_lens,
|
||||||
|
alibi_slopes,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
num_kv_heads,
|
||||||
|
scale,
|
||||||
|
block_size,
|
||||||
|
max_context_len,
|
||||||
|
causal,
|
||||||
|
window_left,
|
||||||
|
window_right,
|
||||||
|
softcap,
|
||||||
|
enable_cuda_graph,
|
||||||
|
use_sqrt_alibi,
|
||||||
|
out);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
PD_THROW("Unsupported data type for Paged attn");
|
||||||
|
}
|
||||||
|
return {out};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<int64_t>> PagedAttnInferShape(const std::vector<int64_t>& q_shape,
|
||||||
|
const std::vector<int64_t>& k_cache_shape,
|
||||||
|
const std::vector<int64_t>& v_cache_shape,
|
||||||
|
const std::vector<int64_t>& block_table_shape,
|
||||||
|
const std::vector<int64_t>& seq_lens_shape,
|
||||||
|
const std::vector<int64_t>& alibi_slopes_shape,
|
||||||
|
const std::vector<int64_t>& k_shape,
|
||||||
|
const std::vector<int64_t>& v_shape) {
|
||||||
|
return {q_shape};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<paddle::DataType> PagedAttnInferDtype(const paddle::DataType& q_dtype,
|
||||||
|
const paddle::DataType& k_cache_dtype,
|
||||||
|
const paddle::DataType& v_cache_dtype,
|
||||||
|
const paddle::DataType& block_table_dtype,
|
||||||
|
const paddle::DataType& seq_lens_dtype,
|
||||||
|
const paddle::DataType& alibi_slopes_dtype,
|
||||||
|
const paddle::DataType& k_dtype,
|
||||||
|
const paddle::DataType& v_dtype) {
|
||||||
|
return {q_dtype};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
PD_BUILD_STATIC_OP(paged_attn)
|
||||||
|
.Inputs({"q", "k_cache", "v_cache", "block_table", "seq_lens", paddle::Optional("alibi_slopes"), paddle::Optional("k"), paddle::Optional("v")})
|
||||||
|
.Outputs({"out"})
|
||||||
|
.Attrs({"num_kv_heads:int",
|
||||||
|
"scale:float",
|
||||||
|
"block_size:int",
|
||||||
|
"max_context_len:int",
|
||||||
|
"causal:bool",
|
||||||
|
"window_left:int",
|
||||||
|
"window_right:int",
|
||||||
|
"softcap:float",
|
||||||
|
"enable_cuda_graph:bool",
|
||||||
|
"use_sqrt_alibi:bool"})
|
||||||
|
.SetKernelFn(PD_KERNEL(PagedAttn))
|
||||||
|
.SetInferShapeFn(PD_INFER_SHAPE(PagedAttnInferShape))
|
||||||
|
.SetInferDtypeFn(PD_INFER_DTYPE(PagedAttnInferDtype));
|
||||||
|
|
||||||
|
|
||||||
|
PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||||
|
m.def("paged_attn", &PagedAttn, "paged attn function");
|
||||||
|
}
|
37
custom_ops/iluvatar_ops/runtime/iluvatar_context.cc
Normal file
37
custom_ops/iluvatar_ops/runtime/iluvatar_context.cc
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
#include "iluvatar_context.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <mutex>
|
||||||
|
namespace iluvatar {
|
||||||
|
IluvatarContext::~IluvatarContext() {
|
||||||
|
if (ixinfer_handle_) {
|
||||||
|
cuinferDestroy(ixinfer_handle_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cuinferHandle_t IluvatarContext::getIxInferHandle() {
|
||||||
|
if (!ixinfer_handle_) {
|
||||||
|
cuinferCreate(&ixinfer_handle_);
|
||||||
|
}
|
||||||
|
return ixinfer_handle_;
|
||||||
|
}
|
||||||
|
|
||||||
|
IluvatarContext* getContextInstance() {
|
||||||
|
static IluvatarContext context;
|
||||||
|
return &context;
|
||||||
|
}
|
||||||
|
} // namespace iluvatar
|
33
custom_ops/iluvatar_ops/runtime/iluvatar_context.h
Normal file
33
custom_ops/iluvatar_ops/runtime/iluvatar_context.h
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <ixinfer.h>
|
||||||
|
|
||||||
|
namespace iluvatar {
|
||||||
|
|
||||||
|
class IluvatarContext {
|
||||||
|
public:
|
||||||
|
IluvatarContext() = default;
|
||||||
|
~IluvatarContext();
|
||||||
|
|
||||||
|
cuinferHandle_t getIxInferHandle();
|
||||||
|
|
||||||
|
private:
|
||||||
|
cuinferHandle_t ixinfer_handle_{nullptr};
|
||||||
|
};
|
||||||
|
IluvatarContext* getContextInstance();
|
||||||
|
|
||||||
|
} // namespace iluvatar
|
@@ -470,6 +470,36 @@ elif paddle.is_compiled_with_cuda():
|
|||||||
)
|
)
|
||||||
elif paddle.is_compiled_with_xpu():
|
elif paddle.is_compiled_with_xpu():
|
||||||
assert False, "In XPU, we should use setup_ops.py in xpu_ops/src, not this."
|
assert False, "In XPU, we should use setup_ops.py in xpu_ops/src, not this."
|
||||||
|
elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
||||||
|
setup(
|
||||||
|
name="fastdeploy_ops",
|
||||||
|
ext_modules=CUDAExtension(
|
||||||
|
extra_compile_args={
|
||||||
|
"nvcc": [
|
||||||
|
"-DPADDLE_DEV",
|
||||||
|
"-DPADDLE_WITH_CUSTOM_DEVICE",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
sources=[
|
||||||
|
"gpu_ops/get_padding_offset.cu",
|
||||||
|
"gpu_ops/set_value_by_flags.cu",
|
||||||
|
"gpu_ops/stop_generation_multi_stop_seqs.cu",
|
||||||
|
"gpu_ops/rebuild_padding.cu",
|
||||||
|
"gpu_ops/update_inputs.cu",
|
||||||
|
"gpu_ops/stop_generation_multi_ends.cu",
|
||||||
|
"gpu_ops/step.cu",
|
||||||
|
"gpu_ops/token_penalty_multi_scores.cu",
|
||||||
|
"iluvatar_ops/moe_dispatch.cu",
|
||||||
|
"iluvatar_ops/moe_reduce.cu",
|
||||||
|
"iluvatar_ops/paged_attn.cu",
|
||||||
|
"iluvatar_ops/runtime/iluvatar_context.cc",
|
||||||
|
],
|
||||||
|
include_dirs=["iluvatar_ops/runtime", "gpu_ops"],
|
||||||
|
extra_link_args=[
|
||||||
|
"-lcuinfer",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
use_bf16 = envs.FD_CPU_USE_BF16 == "True"
|
use_bf16 = envs.FD_CPU_USE_BF16 == "True"
|
||||||
|
|
||||||
|
@@ -42,7 +42,7 @@ class ModelConfig:
|
|||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
config_json_file: str = "config.json",
|
config_json_file: str = "config.json",
|
||||||
dynamic_load_weight: bool = False,
|
dynamic_load_weight: bool = False,
|
||||||
load_strategy: str="meta",
|
load_strategy: str = "meta",
|
||||||
quantization: str = None,
|
quantization: str = None,
|
||||||
download_dir: Optional[str] = None):
|
download_dir: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
@@ -608,8 +608,9 @@ class Config:
|
|||||||
== 1), "TP and EP cannot be enabled at the same time"
|
== 1), "TP and EP cannot be enabled at the same time"
|
||||||
|
|
||||||
num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
|
num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
|
||||||
if num_ranks > 8:
|
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||||
self.worker_num_per_node = 8
|
if num_ranks > self.max_chips_per_node:
|
||||||
|
self.worker_num_per_node = self.max_chips_per_node
|
||||||
nnode = ceil_div(num_ranks, self.worker_num_per_node)
|
nnode = ceil_div(num_ranks, self.worker_num_per_node)
|
||||||
assert nnode == self.nnode, \
|
assert nnode == self.nnode, \
|
||||||
f"nnode: {nnode}, but got {self.nnode}"
|
f"nnode: {nnode}, but got {self.nnode}"
|
||||||
@@ -679,8 +680,8 @@ class Config:
|
|||||||
is_port_available('0.0.0.0', self.engine_worker_queue_port)
|
is_port_available('0.0.0.0', self.engine_worker_queue_port)
|
||||||
), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
|
), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
|
||||||
assert (
|
assert (
|
||||||
8 >= self.tensor_parallel_size > 0
|
self.max_chips_per_node >= self.tensor_parallel_size > 0
|
||||||
), f"tensor_parallel_size: {self.tensor_parallel_size} should be between 1 and 8"
|
), f"tensor_parallel_size: {self.tensor_parallel_size} should be between 1 and {self.max_chips_per_node}"
|
||||||
assert (self.nnode >= 1), f"nnode: {self.nnode} should no less than 1"
|
assert (self.nnode >= 1), f"nnode: {self.nnode} should no less than 1"
|
||||||
assert (
|
assert (
|
||||||
self.max_model_len >= 16
|
self.max_model_len >= 16
|
||||||
|
@@ -63,7 +63,8 @@ class SiluAndMul(nn.Layer):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if current_platform.is_cuda() or current_platform.is_xpu():
|
if current_platform.is_cuda() or current_platform.is_xpu(
|
||||||
|
) or current_platform.is_iluvatar():
|
||||||
self.forward = self.forward_cuda
|
self.forward = self.forward_cuda
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@@ -19,9 +19,10 @@ from .flash_attn_backend import FlashAttentionBackend
|
|||||||
from .mla_attention_backend import MLAAttentionBackend
|
from .mla_attention_backend import MLAAttentionBackend
|
||||||
from .native_paddle_backend import PaddleNativeAttnBackend
|
from .native_paddle_backend import PaddleNativeAttnBackend
|
||||||
from .xpu_attn_backend import XPUAttentionBackend
|
from .xpu_attn_backend import XPUAttentionBackend
|
||||||
|
from .iluvatar_attn_backend import IluvatarAttnBackend
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AttentionBackend", "PaddleNativeAttnBackend",
|
"AttentionBackend", "PaddleNativeAttnBackend",
|
||||||
"get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend",
|
"get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend",
|
||||||
"MLAAttentionBackend", "FlashAttentionBackend"
|
"MLAAttentionBackend", "FlashAttentionBackend", "IluvatarAttnBackend"
|
||||||
]
|
]
|
||||||
|
@@ -67,7 +67,7 @@ class Attention(nn.Layer):
|
|||||||
self.num_heads: int = fd_config.model_config.num_attention_heads // fd_config.parallel_config.tensor_parallel_degree
|
self.num_heads: int = fd_config.model_config.num_attention_heads // fd_config.parallel_config.tensor_parallel_degree
|
||||||
self.head_dim: int = fd_config.model_config.head_dim
|
self.head_dim: int = fd_config.model_config.head_dim
|
||||||
self.kv_num_heads: int = \
|
self.kv_num_heads: int = \
|
||||||
fd_config.model_config.num_key_value_heads // fd_config.parallel_config.tensor_parallel_degree
|
max(1, fd_config.model_config.num_key_value_heads // fd_config.parallel_config.tensor_parallel_degree)
|
||||||
self.layer_id: int = layer_id
|
self.layer_id: int = layer_id
|
||||||
self.v_head_dim: int = v_head_dim if v_head_dim > 0 else self.head_dim
|
self.v_head_dim: int = v_head_dim if v_head_dim > 0 else self.head_dim
|
||||||
self.rope_type: str = rope_type
|
self.rope_type: str = rope_type
|
||||||
|
@@ -0,0 +1,613 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
from math import sqrt
|
||||||
|
|
||||||
|
from paddle.nn.functional.flash_attention import flash_attn_unpadded
|
||||||
|
from fastdeploy.model_executor.ops.iluvatar import paged_attention
|
||||||
|
|
||||||
|
from fastdeploy.config import FDConfig
|
||||||
|
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||||
|
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||||
|
AttentionBackend, AttentionMetadata)
|
||||||
|
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class IluvatarAttentionMetadata(AttentionMetadata):
|
||||||
|
"""
|
||||||
|
IluvatarAttentionMetadata
|
||||||
|
"""
|
||||||
|
# flash_attn metadata
|
||||||
|
cu_seqlens_q: Optional[paddle.Tensor] = None
|
||||||
|
cu_seqlens_k: Optional[paddle.Tensor] = None
|
||||||
|
fixed_seed_offset: Optional[paddle.Tensor] = None
|
||||||
|
attn_mask: Optional[paddle.Tensor] = None
|
||||||
|
attn_mask_start_row_indices: Optional[paddle.Tensor] = None
|
||||||
|
dropout: float = 0.0
|
||||||
|
causal: bool = True
|
||||||
|
return_softmax: bool = False
|
||||||
|
rng_name: str = ""
|
||||||
|
|
||||||
|
# paged_attn metadata
|
||||||
|
block_tables: Optional[paddle.Tensor] = None
|
||||||
|
seq_lens: Optional[paddle.Tensor] = None
|
||||||
|
num_kv_heads: int = 1
|
||||||
|
scale: float = 1.0
|
||||||
|
block_size: int = 1
|
||||||
|
max_context_len: int = 1
|
||||||
|
alibi_slopes: Optional[paddle.Tensor] = None
|
||||||
|
# causal: bool = True
|
||||||
|
window_left: int = -1
|
||||||
|
window_right: int = -1
|
||||||
|
softcap: float = 0.0
|
||||||
|
use_cuda_graph: bool = False
|
||||||
|
use_sqrt_alibi: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
# qk[seq, h, d], cos/sin [seq, 1, d]
|
||||||
|
def apply_rope(qk, cos, sin):
|
||||||
|
rotate_half = paddle.reshape(
|
||||||
|
paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1),
|
||||||
|
paddle.shape(qk),
|
||||||
|
)
|
||||||
|
out = paddle.add(paddle.multiply(qk, cos),
|
||||||
|
paddle.multiply(rotate_half, sin))
|
||||||
|
return paddle.cast(out, qk.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class IluvatarAttnBackend(AttentionBackend):
|
||||||
|
"""
|
||||||
|
The backend class that uses paddle native attention implementation.
|
||||||
|
Which is used only for testing purpose.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, llm_config: FDConfig, kv_num_heads: int, num_heads: int,
|
||||||
|
head_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.attention_metadata = IluvatarAttentionMetadata()
|
||||||
|
self.attention_metadata.block_size = llm_config.parallel_config.block_size
|
||||||
|
assert llm_config.parallel_config.enc_dec_block_num == 0, "Iluvatar does not support yet"
|
||||||
|
|
||||||
|
self.attention_metadata.max_context_len = llm_config.parallel_config.max_model_len
|
||||||
|
self.attention_metadata.causal = getattr(llm_config.model_config,
|
||||||
|
"causal", True)
|
||||||
|
self.speculate_method = getattr(llm_config.parallel_config,
|
||||||
|
"speculate_method", None)
|
||||||
|
self.use_speculate = self.speculate_method is not None
|
||||||
|
self.attention_metadata.num_kv_heads = kv_num_heads
|
||||||
|
self.attention_metadata.dropout = llm_config.model_config.hidden_dropout_prob
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
# note: scale need to change if using MLA
|
||||||
|
self.attention_metadata.scale = 1.0 / sqrt(head_dim)
|
||||||
|
self.num_layers = llm_config.model_config.num_layers
|
||||||
|
self.record_block_table_metadata = {}
|
||||||
|
self.only_use_flash_attn = int(
|
||||||
|
os.getenv("FD_ILUVATAR_ONLY_USE_FLASH_ATTN", 0)) == 1
|
||||||
|
self.do_check_kv_cache = int(
|
||||||
|
os.getenv("FD_ILUVATAR_CHECK_KV_CACHE_CORRECTNESS", 0)) == 1
|
||||||
|
if not self.only_use_flash_attn:
|
||||||
|
assert self.attention_metadata.block_size == 16, "Iluvatar paged attn requires block_size must be 16."
|
||||||
|
if self.do_check_kv_cache:
|
||||||
|
self.record_batched_k = [{} for _ in range(self.num_layers)]
|
||||||
|
self.record_batched_v = [{} for _ in range(self.num_layers)]
|
||||||
|
|
||||||
|
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
||||||
|
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
|
||||||
|
self.attention_metadata.block_tables = forward_meta.block_tables
|
||||||
|
self.attention_metadata.attn_mask = forward_meta.attn_mask
|
||||||
|
self.attention_metadata.seq_lens = forward_meta.seq_lens_decoder
|
||||||
|
self.attention_metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
|
||||||
|
self.attention_metadata.cu_seqlens_k = forward_meta.cu_seqlens_k
|
||||||
|
|
||||||
|
def get_attntion_meta(self):
|
||||||
|
"""get_attntion_meta"""
|
||||||
|
return self.attention_metadata
|
||||||
|
|
||||||
|
def get_kv_cache_shape(
|
||||||
|
self,
|
||||||
|
max_num_blocks: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Caculate kv cache shape
|
||||||
|
"""
|
||||||
|
return (max_num_blocks, self.attention_metadata.num_kv_heads,
|
||||||
|
self.attention_metadata.block_size, self.head_dim)
|
||||||
|
|
||||||
|
def get_new_kv(self,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
k_cache_id: int,
|
||||||
|
v_cache_id: int,
|
||||||
|
forward_meta: ForwardMeta,
|
||||||
|
debug_paged_attn=False):
|
||||||
|
new_k = []
|
||||||
|
new_v = []
|
||||||
|
tensor_start = 0
|
||||||
|
for batch_idx in range(forward_meta.block_tables.shape[0]):
|
||||||
|
seq_len = forward_meta.seq_lens_this_time[batch_idx]
|
||||||
|
if seq_len == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tensor_end = tensor_start + seq_len
|
||||||
|
slice_k = k[tensor_start:tensor_end, :, :]
|
||||||
|
slice_v = v[tensor_start:tensor_end, :, :]
|
||||||
|
|
||||||
|
if seq_len > 1:
|
||||||
|
# prefill
|
||||||
|
new_k.append(slice_k)
|
||||||
|
new_v.append(slice_v)
|
||||||
|
else:
|
||||||
|
# decode
|
||||||
|
assert seq_len == 1
|
||||||
|
cur_block_tables = forward_meta.block_tables[batch_idx]
|
||||||
|
cur_used_block_tables = cur_block_tables[cur_block_tables !=
|
||||||
|
-1]
|
||||||
|
assert batch_idx in self.record_block_table_metadata, \
|
||||||
|
f"Key error: {batch_idx} vs {self.record_block_table_metadata}."
|
||||||
|
cur_block_table_metadata = self.record_block_table_metadata[
|
||||||
|
batch_idx]
|
||||||
|
record_last_block_id = cur_block_table_metadata["block_id"]
|
||||||
|
assert record_last_block_id != -1
|
||||||
|
for block_id in cur_used_block_tables:
|
||||||
|
if block_id == record_last_block_id:
|
||||||
|
cache_end = cur_block_table_metadata["cache_end"]
|
||||||
|
block_k_cache = forward_meta.caches[k_cache_id][
|
||||||
|
block_id, :, 0:cache_end, :]
|
||||||
|
block_v_cache = forward_meta.caches[v_cache_id][
|
||||||
|
block_id, :, 0:cache_end, :]
|
||||||
|
else:
|
||||||
|
block_k_cache = forward_meta.caches[k_cache_id][
|
||||||
|
block_id]
|
||||||
|
block_v_cache = forward_meta.caches[v_cache_id][
|
||||||
|
block_id]
|
||||||
|
|
||||||
|
# [num_kv_heads, block_size, head_dim] -> [block_size, num_kv_heads, head_dim]
|
||||||
|
new_k.append(
|
||||||
|
block_k_cache.transpose([1, 0, 2]).contiguous())
|
||||||
|
new_v.append(
|
||||||
|
block_v_cache.transpose([1, 0, 2]).contiguous())
|
||||||
|
if block_id == record_last_block_id:
|
||||||
|
break
|
||||||
|
|
||||||
|
# as line 301 show, record_block_table_metadata updates when executing the last layer,
|
||||||
|
# so slice_k and slice_v has been updated in block_k_cache and block_v_cache
|
||||||
|
if not (debug_paged_attn and
|
||||||
|
(k_cache_id / 2 == self.num_layers - 1)):
|
||||||
|
new_k.append(slice_k)
|
||||||
|
new_v.append(slice_v)
|
||||||
|
|
||||||
|
tensor_start = tensor_end
|
||||||
|
|
||||||
|
if len(new_k) == 1:
|
||||||
|
return new_k[0], new_v[0]
|
||||||
|
else:
|
||||||
|
new_k = paddle.concat(new_k, axis=0)
|
||||||
|
new_v = paddle.concat(new_v, axis=0)
|
||||||
|
return new_k, new_v
|
||||||
|
|
||||||
|
def update_kv_cache(self,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
k_cache_id: int,
|
||||||
|
v_cache_id: int,
|
||||||
|
layer_id: int,
|
||||||
|
forward_meta: ForwardMeta,
|
||||||
|
specific_batch_ids=None,
|
||||||
|
debug_paged_attn=False):
|
||||||
|
# [num_tokens, num_kv_heads, head_dim] -> [num_kv_heads, num_tokens, head_dim]
|
||||||
|
trans_k = k.transpose([1, 0, 2]).contiguous()
|
||||||
|
trans_v = v.transpose([1, 0, 2]).contiguous()
|
||||||
|
tensor_start = 0
|
||||||
|
for batch_idx in range(forward_meta.block_tables.shape[0]):
|
||||||
|
if specific_batch_ids is not None and batch_idx not in specific_batch_ids:
|
||||||
|
continue
|
||||||
|
seq_len = forward_meta.seq_lens_this_time[batch_idx]
|
||||||
|
if seq_len == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tensor_end = tensor_start + seq_len
|
||||||
|
slice_trans_k = trans_k[:, tensor_start:tensor_end, :]
|
||||||
|
slice_trans_v = trans_v[:, tensor_start:tensor_end, :]
|
||||||
|
|
||||||
|
cur_block_tables = forward_meta.block_tables[batch_idx]
|
||||||
|
cur_used_block_tables = cur_block_tables[cur_block_tables != -1]
|
||||||
|
|
||||||
|
# prefill
|
||||||
|
if seq_len > 1:
|
||||||
|
cache_start = 0
|
||||||
|
cur_used_num_blocks = cur_used_block_tables.shape[0]
|
||||||
|
for i, block_id in enumerate(cur_used_block_tables):
|
||||||
|
# last block: seq_len - cache_start <= block_size
|
||||||
|
if i == cur_used_num_blocks - 1:
|
||||||
|
cache_end = seq_len - cache_start
|
||||||
|
assert cache_end <= self.attention_metadata.block_size
|
||||||
|
forward_meta.caches[k_cache_id][
|
||||||
|
block_id, :,
|
||||||
|
0:cache_end, :] = slice_trans_k[:, cache_start:
|
||||||
|
seq_len, :]
|
||||||
|
forward_meta.caches[v_cache_id][
|
||||||
|
block_id, :,
|
||||||
|
0:cache_end, :] = slice_trans_v[:, cache_start:
|
||||||
|
seq_len, :]
|
||||||
|
if layer_id == self.num_layers - 1:
|
||||||
|
self.record_block_table_metadata[batch_idx] = {
|
||||||
|
"block_id": block_id.item(),
|
||||||
|
"cache_end": cache_end
|
||||||
|
}
|
||||||
|
# non last block: seq_lens_this_time > block_size
|
||||||
|
else:
|
||||||
|
assert seq_len > self.attention_metadata.block_size
|
||||||
|
cache_end = cache_start + self.attention_metadata.block_size
|
||||||
|
forward_meta.caches[k_cache_id][
|
||||||
|
block_id] = slice_trans_k[:,
|
||||||
|
cache_start:cache_end, :]
|
||||||
|
forward_meta.caches[v_cache_id][
|
||||||
|
block_id] = slice_trans_v[:,
|
||||||
|
cache_start:cache_end, :]
|
||||||
|
cache_start += self.attention_metadata.block_size
|
||||||
|
else:
|
||||||
|
# decode
|
||||||
|
assert seq_len == 1
|
||||||
|
cur_last_block_id = cur_used_block_tables[-1].item()
|
||||||
|
assert cur_last_block_id != -1
|
||||||
|
assert batch_idx in self.record_block_table_metadata, \
|
||||||
|
f"Key error: {batch_idx} vs {self.record_block_table_metadata}."
|
||||||
|
cur_block_table_metadata = self.record_block_table_metadata[
|
||||||
|
batch_idx]
|
||||||
|
record_last_block_id = cur_block_table_metadata["block_id"]
|
||||||
|
|
||||||
|
if cur_last_block_id == record_last_block_id:
|
||||||
|
# not alloc new block in decode stage
|
||||||
|
cache_start = cur_block_table_metadata["cache_end"]
|
||||||
|
else:
|
||||||
|
# alloc new block in decode stage
|
||||||
|
cache_start = 0
|
||||||
|
|
||||||
|
cache_end = cache_start + 1
|
||||||
|
assert cache_end <= self.attention_metadata.block_size
|
||||||
|
|
||||||
|
# paged attn API will update kv cache with inplace mode
|
||||||
|
if not debug_paged_attn:
|
||||||
|
forward_meta.caches[k_cache_id][
|
||||||
|
cur_last_block_id, :,
|
||||||
|
cache_start:cache_end, :] = slice_trans_k
|
||||||
|
forward_meta.caches[v_cache_id][
|
||||||
|
cur_last_block_id, :,
|
||||||
|
cache_start:cache_end, :] = slice_trans_v
|
||||||
|
|
||||||
|
# update record_block_table_metadata
|
||||||
|
if layer_id == self.num_layers - 1:
|
||||||
|
self.record_block_table_metadata[batch_idx][
|
||||||
|
"block_id"] = cur_last_block_id
|
||||||
|
self.record_block_table_metadata[batch_idx][
|
||||||
|
"cache_end"] = cache_end
|
||||||
|
|
||||||
|
tensor_start = tensor_end
|
||||||
|
|
||||||
|
def _check_new_kv_correctness(self, k, v, new_k, new_v, layer_id: int,
|
||||||
|
forward_meta: ForwardMeta):
|
||||||
|
tensor_start = 0
|
||||||
|
for batch_idx, seq_lens_this_time in enumerate(
|
||||||
|
forward_meta.seq_lens_this_time):
|
||||||
|
if seq_lens_this_time == 0:
|
||||||
|
continue
|
||||||
|
# note: the second request will also use the batch_idx 0 instead of 1 in
|
||||||
|
# the streaming inference mode, so use seq_lens_this_time > 1 with the same
|
||||||
|
# batch_idx represents the second request comes.
|
||||||
|
if seq_lens_this_time > 1 and batch_idx in self.record_batched_k[
|
||||||
|
layer_id]:
|
||||||
|
print(
|
||||||
|
f"clear self.record_batched_batched_k: "
|
||||||
|
f"layer_id={layer_id}, batch_id={batch_idx}, "
|
||||||
|
f"record_lens={len(self.record_batched_k[layer_id][batch_idx])}"
|
||||||
|
)
|
||||||
|
self.record_batched_k[layer_id][batch_idx].clear()
|
||||||
|
self.record_batched_v[layer_id][batch_idx].clear()
|
||||||
|
tensor_end = tensor_start + seq_lens_this_time
|
||||||
|
slice_k = k[tensor_start:tensor_end, :, :]
|
||||||
|
slice_v = v[tensor_start:tensor_end, :, :]
|
||||||
|
if batch_idx not in self.record_batched_k[layer_id]:
|
||||||
|
self.record_batched_k[layer_id][batch_idx] = []
|
||||||
|
self.record_batched_v[layer_id][batch_idx] = []
|
||||||
|
self.record_batched_k[layer_id][batch_idx].append(slice_k)
|
||||||
|
self.record_batched_v[layer_id][batch_idx].append(slice_v)
|
||||||
|
tensor_start = tensor_end
|
||||||
|
|
||||||
|
ref_k, ref_v = [], []
|
||||||
|
for batch_idx, seq_lens_this_time in enumerate(
|
||||||
|
forward_meta.seq_lens_this_time):
|
||||||
|
if seq_lens_this_time == 0:
|
||||||
|
continue
|
||||||
|
bached_k_list = self.record_batched_k[layer_id][batch_idx]
|
||||||
|
bached_v_list = self.record_batched_v[layer_id][batch_idx]
|
||||||
|
ref_k.extend(bached_k_list)
|
||||||
|
ref_v.extend(bached_v_list)
|
||||||
|
|
||||||
|
ref_k = paddle.concat(ref_k, axis=0)
|
||||||
|
ref_v = paddle.concat(ref_v, axis=0)
|
||||||
|
print(
|
||||||
|
f"_check_new_kv_correctness: layer_id={layer_id}, "
|
||||||
|
f"k.shape={k.shape}, v.shape={v.shape}, "
|
||||||
|
f"ref_k.shape={ref_k.shape}, ref_v.shape={ref_v.shape}, "
|
||||||
|
f"new_k.shape={new_k.shape}, new_v.shape={new_v.shape}, "
|
||||||
|
f"len(self.record_batched_k[layer_id])={len(self.record_batched_k[layer_id])}, "
|
||||||
|
f"len(self.record_batched_k[layer_id][0])={len(self.record_batched_k[layer_id][0])}, "
|
||||||
|
f"forward_meta.seq_lens_this_time={forward_meta.seq_lens_this_time}"
|
||||||
|
f"ref_k[-2:, 0:2, 0:2]={ref_k[-2:, 0:2, 0:2]}, "
|
||||||
|
f"ref_v[-2:, 0:2, 0:2]={ref_v[-2:, 0:2, 0:2]}, "
|
||||||
|
f"new_k[-2:, 0:2, 0:2]={new_k[-2:, 0:2, 0:2]}, "
|
||||||
|
f"new_v[-2:, 0:2, 0:2]={new_v[-2:, 0:2, 0:2]}")
|
||||||
|
assert paddle.allclose(
|
||||||
|
ref_k.to("cpu").to(paddle.float32),
|
||||||
|
new_k.to("cpu").to(paddle.float32))
|
||||||
|
assert paddle.allclose(
|
||||||
|
ref_v.to("cpu").to(paddle.float32),
|
||||||
|
new_v.to("cpu").to(paddle.float32))
|
||||||
|
|
||||||
|
def get_splited_qkv(self, qkv: paddle.Tensor, forward_meta: ForwardMeta):
|
||||||
|
q_end = self.num_heads * self.head_dim
|
||||||
|
k_end = q_end + self.attention_metadata.num_kv_heads * self.head_dim
|
||||||
|
v_end = k_end + self.attention_metadata.num_kv_heads * self.head_dim
|
||||||
|
assert v_end == qkv.shape[
|
||||||
|
-1], f"Shape mistach: {v_end} vs {qkv.shape[-1]}"
|
||||||
|
assert qkv.shape[0] == forward_meta.cu_seqlens_q[-1]
|
||||||
|
|
||||||
|
q = qkv[..., 0:q_end]
|
||||||
|
k = qkv[..., q_end:k_end]
|
||||||
|
v = qkv[..., k_end:v_end]
|
||||||
|
q = q.view([-1, self.num_heads, self.head_dim]).contiguous()
|
||||||
|
k = k.view([-1, self.attention_metadata.num_kv_heads,
|
||||||
|
self.head_dim]).contiguous()
|
||||||
|
v = v.view([-1, self.attention_metadata.num_kv_heads,
|
||||||
|
self.head_dim]).contiguous()
|
||||||
|
# forward_meta.seq_lens_this_time [max_batch,]
|
||||||
|
for batch_idx in range(forward_meta.seq_lens_this_time.shape[0]):
|
||||||
|
seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
|
||||||
|
if seq_len_i == 0:
|
||||||
|
continue
|
||||||
|
cached_kv_len = forward_meta.seq_lens_decoder[batch_idx][0]
|
||||||
|
cu_seq_start_q = forward_meta.cu_seqlens_q[batch_idx]
|
||||||
|
cu_seq_end_q = forward_meta.cu_seqlens_q[batch_idx + 1]
|
||||||
|
# forward_meta.rotary_embs is [2, 1, S, 1, D]
|
||||||
|
if forward_meta.rotary_embs is not None:
|
||||||
|
cos = forward_meta.rotary_embs[0, 0,
|
||||||
|
cached_kv_len:cached_kv_len +
|
||||||
|
seq_len_i, :, :]
|
||||||
|
sin = forward_meta.rotary_embs[1, 0,
|
||||||
|
cached_kv_len:cached_kv_len +
|
||||||
|
seq_len_i, :, :]
|
||||||
|
q[cu_seq_start_q:cu_seq_end_q] = apply_rope(
|
||||||
|
q[cu_seq_start_q:cu_seq_end_q], cos, sin)
|
||||||
|
k[cu_seq_start_q:cu_seq_end_q] = apply_rope(
|
||||||
|
k[cu_seq_start_q:cu_seq_end_q], cos, sin)
|
||||||
|
|
||||||
|
return q, k, v
|
||||||
|
|
||||||
|
def get_splited_info_by_stage(self, q, k, v, forward_meta: ForwardMeta):
|
||||||
|
prefill_info_dict = {"q": [], "k": [], "v": [], "batch_ids": []}
|
||||||
|
decode_info_dict = {"q": [], "k": [], "v": [], "batch_ids": []}
|
||||||
|
tensor_start = 0
|
||||||
|
for batch_idx, seq_lens_this_time in enumerate(
|
||||||
|
forward_meta.seq_lens_this_time):
|
||||||
|
if seq_lens_this_time == 0:
|
||||||
|
continue
|
||||||
|
tensor_end = tensor_start + seq_lens_this_time
|
||||||
|
slice_q = q[tensor_start:tensor_end, :, :]
|
||||||
|
slice_k = k[tensor_start:tensor_end, :, :]
|
||||||
|
slice_v = v[tensor_start:tensor_end, :, :]
|
||||||
|
if seq_lens_this_time > 1:
|
||||||
|
prefill_info_dict["q"].append(slice_q)
|
||||||
|
prefill_info_dict["k"].append(slice_k)
|
||||||
|
prefill_info_dict["v"].append(slice_v)
|
||||||
|
prefill_info_dict["batch_ids"].append(batch_idx)
|
||||||
|
else:
|
||||||
|
assert seq_lens_this_time == 1
|
||||||
|
decode_info_dict["q"].append(slice_q)
|
||||||
|
decode_info_dict["k"].append(slice_k)
|
||||||
|
decode_info_dict["v"].append(slice_v)
|
||||||
|
decode_info_dict["batch_ids"].append(batch_idx)
|
||||||
|
tensor_start = tensor_end
|
||||||
|
|
||||||
|
if len(prefill_info_dict["batch_ids"]) > 0:
|
||||||
|
prefill_info_dict["q"] = paddle.concat(prefill_info_dict["q"],
|
||||||
|
axis=0)
|
||||||
|
prefill_info_dict["k"] = paddle.concat(prefill_info_dict["k"],
|
||||||
|
axis=0)
|
||||||
|
prefill_info_dict["v"] = paddle.concat(prefill_info_dict["v"],
|
||||||
|
axis=0)
|
||||||
|
cu_seq_ids = list(
|
||||||
|
map(lambda x: x + 1, prefill_info_dict["batch_ids"]))
|
||||||
|
prefill_info_dict["cu_seq_ids"] = [0, *cu_seq_ids]
|
||||||
|
|
||||||
|
if len(decode_info_dict["batch_ids"]) > 0:
|
||||||
|
decode_info_dict["q"] = paddle.concat(decode_info_dict["q"],
|
||||||
|
axis=0)
|
||||||
|
decode_info_dict["k"] = paddle.concat(decode_info_dict["k"],
|
||||||
|
axis=0)
|
||||||
|
decode_info_dict["v"] = paddle.concat(decode_info_dict["v"],
|
||||||
|
axis=0)
|
||||||
|
|
||||||
|
return prefill_info_dict, decode_info_dict
|
||||||
|
|
||||||
|
def merge_output(self, prefill_out, decode_out, forward_meta: ForwardMeta):
|
||||||
|
assert not (prefill_out is None and decode_out
|
||||||
|
is None), "prefill and decode output cannot both be None"
|
||||||
|
if prefill_out is None:
|
||||||
|
return decode_out
|
||||||
|
elif decode_out is None:
|
||||||
|
return prefill_out
|
||||||
|
else:
|
||||||
|
merged_output = []
|
||||||
|
prefill_tensor_start = 0
|
||||||
|
decode_tensor_start = 0
|
||||||
|
for seq_lens_this_time in forward_meta.seq_lens_this_time:
|
||||||
|
if seq_lens_this_time == 0:
|
||||||
|
continue
|
||||||
|
if seq_lens_this_time > 1:
|
||||||
|
tensor_end = prefill_tensor_start + seq_lens_this_time
|
||||||
|
merged_output.append(
|
||||||
|
prefill_out[prefill_tensor_start:tensor_end, :, :])
|
||||||
|
prefill_tensor_start = tensor_end
|
||||||
|
else:
|
||||||
|
assert seq_lens_this_time == 1
|
||||||
|
tensor_end = decode_tensor_start + seq_lens_this_time
|
||||||
|
merged_output.append(
|
||||||
|
decode_out[decode_tensor_start:tensor_end, :, :])
|
||||||
|
decode_tensor_start = tensor_end
|
||||||
|
|
||||||
|
assert prefill_tensor_start == prefill_out.shape[0], \
|
||||||
|
f"prefill merged unfinished: {prefill_tensor_start} vs {prefill_out.shape[0]}"
|
||||||
|
assert decode_tensor_start == decode_out.shape[0], \
|
||||||
|
f"decode merged unfinished: {decode_tensor_start} vs {decode_out.shape[0]}"
|
||||||
|
merged_output = paddle.concat(merged_output, axis=0)
|
||||||
|
return merged_output
|
||||||
|
|
||||||
|
def forward_mixed(
|
||||||
|
self,
|
||||||
|
q: paddle.Tensor,
|
||||||
|
k: paddle.Tensor,
|
||||||
|
v: paddle.Tensor,
|
||||||
|
qkv: paddle.Tensor,
|
||||||
|
compressed_kv: paddle.Tensor,
|
||||||
|
k_pe: paddle.Tensor,
|
||||||
|
layer: Attention,
|
||||||
|
forward_meta: ForwardMeta,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
forward_mixed
|
||||||
|
"""
|
||||||
|
assert not self.use_speculate, "IluvatarAttnBackend cannot support speculate now"
|
||||||
|
layer_id = layer.layer_id
|
||||||
|
k_cache_id = layer_id * 2
|
||||||
|
v_cache_id = k_cache_id + 1
|
||||||
|
|
||||||
|
assert qkv is not None
|
||||||
|
q_dim = qkv.dim()
|
||||||
|
q, k, v = self.get_splited_qkv(qkv, forward_meta)
|
||||||
|
|
||||||
|
if self.only_use_flash_attn:
|
||||||
|
new_k, new_v = self.get_new_kv(k, v, k_cache_id, v_cache_id,
|
||||||
|
forward_meta)
|
||||||
|
if self.do_check_kv_cache:
|
||||||
|
self._check_new_kv_correctness(k, v, new_k, new_v, layer_id,
|
||||||
|
forward_meta)
|
||||||
|
|
||||||
|
out = flash_attn_unpadded(
|
||||||
|
q,
|
||||||
|
new_k,
|
||||||
|
new_v,
|
||||||
|
cu_seqlens_q=self.attention_metadata.cu_seqlens_q,
|
||||||
|
cu_seqlens_k=self.attention_metadata.cu_seqlens_k,
|
||||||
|
max_seqlen_q=self.attention_metadata.max_context_len,
|
||||||
|
max_seqlen_k=self.attention_metadata.max_context_len,
|
||||||
|
scale=self.attention_metadata.scale,
|
||||||
|
dropout=self.attention_metadata.dropout,
|
||||||
|
causal=self.attention_metadata.causal,
|
||||||
|
return_softmax=self.attention_metadata.return_softmax)[0]
|
||||||
|
|
||||||
|
self.update_kv_cache(k, v, k_cache_id, v_cache_id, layer_id,
|
||||||
|
forward_meta)
|
||||||
|
else:
|
||||||
|
prefill_info_dict, decode_info_dict = self.get_splited_info_by_stage(
|
||||||
|
q, k, v, forward_meta)
|
||||||
|
prefill_out, decode_out = None, None
|
||||||
|
|
||||||
|
if len(prefill_info_dict["batch_ids"]) > 0:
|
||||||
|
prefill_out = flash_attn_unpadded(
|
||||||
|
prefill_info_dict["q"],
|
||||||
|
prefill_info_dict["k"],
|
||||||
|
prefill_info_dict["v"],
|
||||||
|
cu_seqlens_q=forward_meta.cu_seqlens_q[
|
||||||
|
prefill_info_dict["cu_seq_ids"]],
|
||||||
|
cu_seqlens_k=forward_meta.cu_seqlens_k[
|
||||||
|
prefill_info_dict["cu_seq_ids"]],
|
||||||
|
max_seqlen_q=self.attention_metadata.max_context_len,
|
||||||
|
max_seqlen_k=self.attention_metadata.max_context_len,
|
||||||
|
scale=self.attention_metadata.scale,
|
||||||
|
dropout=self.attention_metadata.dropout,
|
||||||
|
causal=self.attention_metadata.causal,
|
||||||
|
return_softmax=self.attention_metadata.return_softmax)[0]
|
||||||
|
self.update_kv_cache(
|
||||||
|
prefill_info_dict["k"],
|
||||||
|
prefill_info_dict["v"],
|
||||||
|
k_cache_id,
|
||||||
|
v_cache_id,
|
||||||
|
layer_id,
|
||||||
|
forward_meta,
|
||||||
|
specific_batch_ids=prefill_info_dict['batch_ids'])
|
||||||
|
|
||||||
|
if len(decode_info_dict["batch_ids"]) > 0:
|
||||||
|
k_cache = forward_meta.caches[k_cache_id]
|
||||||
|
v_cache = forward_meta.caches[v_cache_id]
|
||||||
|
|
||||||
|
decode_out = paged_attention(
|
||||||
|
decode_info_dict["q"],
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
block_tables=forward_meta.block_tables[
|
||||||
|
decode_info_dict["batch_ids"], :],
|
||||||
|
seq_lens=forward_meta.seq_lens_decoder[
|
||||||
|
decode_info_dict["batch_ids"], 0] + 1,
|
||||||
|
num_kv_heads=self.attention_metadata.num_kv_heads,
|
||||||
|
scale=self.attention_metadata.scale,
|
||||||
|
block_size=self.attention_metadata.block_size,
|
||||||
|
max_context_len=self.attention_metadata.max_context_len,
|
||||||
|
alibi_slopes=self.attention_metadata.alibi_slopes,
|
||||||
|
causal=self.attention_metadata.causal,
|
||||||
|
window_left=self.attention_metadata.window_left,
|
||||||
|
window_right=self.attention_metadata.window_right,
|
||||||
|
softcap=self.attention_metadata.softcap,
|
||||||
|
use_cuda_graph=self.attention_metadata.use_cuda_graph,
|
||||||
|
use_sqrt_alibi=self.attention_metadata.use_sqrt_alibi,
|
||||||
|
k=decode_info_dict["k"],
|
||||||
|
v=decode_info_dict["v"])
|
||||||
|
|
||||||
|
if self.do_check_kv_cache:
|
||||||
|
self.update_kv_cache(
|
||||||
|
decode_info_dict['k'],
|
||||||
|
decode_info_dict['v'],
|
||||||
|
k_cache_id,
|
||||||
|
v_cache_id,
|
||||||
|
layer_id,
|
||||||
|
forward_meta,
|
||||||
|
specific_batch_ids=decode_info_dict['batch_ids'],
|
||||||
|
debug_paged_attn=True)
|
||||||
|
|
||||||
|
if self.do_check_kv_cache:
|
||||||
|
new_k, new_v = self.get_new_kv(k,
|
||||||
|
v,
|
||||||
|
k_cache_id,
|
||||||
|
v_cache_id,
|
||||||
|
forward_meta,
|
||||||
|
debug_paged_attn=True)
|
||||||
|
self._check_new_kv_correctness(k, v, new_k, new_v, layer_id,
|
||||||
|
forward_meta)
|
||||||
|
|
||||||
|
out = self.merge_output(prefill_out, decode_out, forward_meta)
|
||||||
|
|
||||||
|
if q_dim == 2:
|
||||||
|
out = out.view([-1, self.num_heads * self.head_dim])
|
||||||
|
|
||||||
|
return out
|
@@ -57,7 +57,8 @@ class LinearBase(nn.Layer):
|
|||||||
NotImplementedError: Raised if the current platform is not a CUDA platform.
|
NotImplementedError: Raised if the current platform is not a CUDA platform.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if current_platform.is_cuda() or current_platform.is_xpu():
|
if current_platform.is_cuda() or current_platform.is_xpu(
|
||||||
|
) or current_platform.is_iluvatar():
|
||||||
self.forward = self.forward_cuda
|
self.forward = self.forward_cuda
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -411,9 +412,14 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
self.head_dim = fd_config.model_config.head_dim
|
self.head_dim = fd_config.model_config.head_dim
|
||||||
self.nranks = fd_config.parallel_config.tensor_parallel_degree
|
self.nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||||
self.num_heads_per_rank = divide(self.num_heads, self.nranks)
|
self.num_heads_per_rank = divide(self.num_heads, self.nranks)
|
||||||
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
|
if self.kv_num_heads < self.nranks and self.nranks % self.kv_num_heads == 0:
|
||||||
|
self.kv_num_heads_per_rank = 1
|
||||||
|
output_size = (self.num_heads + 2 * self.nranks) * self.head_dim
|
||||||
|
else:
|
||||||
|
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
|
||||||
|
output_size = (self.num_heads +
|
||||||
|
2 * self.kv_num_heads) * self.head_dim
|
||||||
input_size = self.hidden_size
|
input_size = self.hidden_size
|
||||||
output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
|
|
||||||
super().__init__(fd_config=fd_config,
|
super().__init__(fd_config=fd_config,
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
input_size=input_size,
|
input_size=input_size,
|
||||||
|
@@ -30,6 +30,8 @@ from .fused_moe_backend_base import MoEMethodBase
|
|||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
from fastdeploy.model_executor.ops.gpu import (moe_expert_dispatch,
|
from fastdeploy.model_executor.ops.gpu import (moe_expert_dispatch,
|
||||||
moe_expert_reduce, noaux_tc)
|
moe_expert_reduce, noaux_tc)
|
||||||
|
elif current_platform.is_iluvatar():
|
||||||
|
from fastdeploy.model_executor.ops.iluvatar import moe_expert_dispatch, moe_expert_reduce
|
||||||
|
|
||||||
|
|
||||||
# used for deepseek_v3
|
# used for deepseek_v3
|
||||||
@@ -89,6 +91,23 @@ class CutlassMoEMethod(MoEMethodBase):
|
|||||||
"""
|
"""
|
||||||
Paddle Cutlass compute Fused MoE.
|
Paddle Cutlass compute Fused MoE.
|
||||||
"""
|
"""
|
||||||
|
if current_platform.is_iluvatar():
|
||||||
|
return fastdeploy.model_executor.ops.iluvatar.moe_expert_ffn(
|
||||||
|
permute_input,
|
||||||
|
token_nums_per_expert,
|
||||||
|
layer.moe_ffn1_weight,
|
||||||
|
layer.moe_ffn2_weight,
|
||||||
|
None,
|
||||||
|
(layer.moe_ffn1_weight_scale if hasattr(
|
||||||
|
layer, "moe_ffn1_weight_scale") else None),
|
||||||
|
(layer.moe_ffn2_weight_scale if hasattr(
|
||||||
|
layer, "moe_ffn2_weight_scale") else None),
|
||||||
|
(layer.moe_ffn2_in_scale
|
||||||
|
if hasattr(layer, "moe_ffn2_in_scale") else None),
|
||||||
|
expert_idx_per_token,
|
||||||
|
self.moe_quant_type,
|
||||||
|
used_in_ep_low_latency,
|
||||||
|
)
|
||||||
return fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
|
return fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
|
||||||
permute_input,
|
permute_input,
|
||||||
token_nums_per_expert,
|
token_nums_per_expert,
|
||||||
|
@@ -20,6 +20,7 @@ import numpy as np
|
|||||||
import paddle
|
import paddle
|
||||||
from paddle import nn
|
from paddle import nn
|
||||||
from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm
|
from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm
|
||||||
|
from fastdeploy.platforms import current_platform
|
||||||
|
|
||||||
from fastdeploy.config import FDConfig
|
from fastdeploy.config import FDConfig
|
||||||
|
|
||||||
@@ -265,6 +266,18 @@ class LayerNorm(nn.Layer):
|
|||||||
The `residual_output` is the result of applying the normalization and possibly other
|
The `residual_output` is the result of applying the normalization and possibly other
|
||||||
operations (like linear transformation) on the `residual_input`.
|
operations (like linear transformation) on the `residual_input`.
|
||||||
"""
|
"""
|
||||||
|
if current_platform.is_iluvatar():
|
||||||
|
if self.ln_weight is None and self.ln_bias is None:
|
||||||
|
out = x
|
||||||
|
if self.linear_bias is not None:
|
||||||
|
out += self.linear_bias
|
||||||
|
if residual_input is not None:
|
||||||
|
out += residual_input
|
||||||
|
return out, out
|
||||||
|
else:
|
||||||
|
return out
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Iluvatar does not support yet!")
|
||||||
|
|
||||||
norm_out = self.norm_func(
|
norm_out = self.norm_func(
|
||||||
x,
|
x,
|
||||||
|
@@ -48,7 +48,8 @@ class ErnieRotaryEmbedding:
|
|||||||
freqs = paddle.einsum("ij,k->ijk",
|
freqs = paddle.einsum("ij,k->ijk",
|
||||||
partial_rotary_position_ids.cast("float32"),
|
partial_rotary_position_ids.cast("float32"),
|
||||||
inv_freq)
|
inv_freq)
|
||||||
if paddle.is_compiled_with_xpu():
|
if paddle.is_compiled_with_xpu(
|
||||||
|
) or paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
||||||
# shape: [B, S, D]
|
# shape: [B, S, D]
|
||||||
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim),
|
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim),
|
||||||
dtype="float32")
|
dtype="float32")
|
||||||
|
@@ -64,6 +64,21 @@ def apply_penalty_multi_scores(
|
|||||||
min_dec_lens,
|
min_dec_lens,
|
||||||
eos_token_ids,
|
eos_token_ids,
|
||||||
)
|
)
|
||||||
|
elif current_platform.is_iluvatar():
|
||||||
|
from fastdeploy.model_executor.ops.iluvatar import \
|
||||||
|
get_token_penalty_multi_scores
|
||||||
|
logits = get_token_penalty_multi_scores(
|
||||||
|
pre_token_ids,
|
||||||
|
logits,
|
||||||
|
repetition_penalties,
|
||||||
|
frequency_penalties,
|
||||||
|
presence_penalties,
|
||||||
|
temperature,
|
||||||
|
bad_words_token_ids,
|
||||||
|
step_idx,
|
||||||
|
min_dec_lens,
|
||||||
|
eos_token_ids,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@@ -170,7 +170,8 @@ class Sampler(nn.Layer):
|
|||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if current_platform.is_cuda() or current_platform.is_xpu():
|
if current_platform.is_cuda() or current_platform.is_xpu(
|
||||||
|
) or current_platform.is_iluvatar():
|
||||||
self.forward = self.forward_cuda
|
self.forward = self.forward_cuda
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
@@ -33,6 +33,8 @@ if current_platform.is_cuda() and current_platform.available():
|
|||||||
"Verify environment consistency between compilation and FastDeploy installation. "
|
"Verify environment consistency between compilation and FastDeploy installation. "
|
||||||
"And ensure the Paddle version supports FastDeploy's custom operators"
|
"And ensure the Paddle version supports FastDeploy's custom operators"
|
||||||
)
|
)
|
||||||
|
if current_platform.is_iluvatar():
|
||||||
|
from fastdeploy.model_executor.ops.iluvatar import get_padding_offset
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from fastdeploy import envs
|
from fastdeploy import envs
|
||||||
|
@@ -213,8 +213,14 @@ def gqa_qkv_split_func(
|
|||||||
return np.split(tensor, degree, axis=0)
|
return np.split(tensor, degree, axis=0)
|
||||||
|
|
||||||
q_list = split_tensor(q, tensor_parallel_degree)
|
q_list = split_tensor(q, tensor_parallel_degree)
|
||||||
k_list = split_tensor(k, tensor_parallel_degree)
|
repeat_kv = num_key_value_heads < tensor_parallel_degree and tensor_parallel_degree % num_key_value_heads == 0
|
||||||
v_list = split_tensor(v, tensor_parallel_degree)
|
repeat_num = tensor_parallel_degree // num_key_value_heads if repeat_kv else 1
|
||||||
|
if repeat_kv:
|
||||||
|
k_list = split_tensor(k, num_key_value_heads)
|
||||||
|
v_list = split_tensor(v, num_key_value_heads)
|
||||||
|
else:
|
||||||
|
k_list = split_tensor(k, tensor_parallel_degree)
|
||||||
|
v_list = split_tensor(v, tensor_parallel_degree)
|
||||||
|
|
||||||
if tensor_parallel_rank is None:
|
if tensor_parallel_rank is None:
|
||||||
res = []
|
res = []
|
||||||
@@ -236,8 +242,8 @@ def gqa_qkv_split_func(
|
|||||||
return paddle.concat(
|
return paddle.concat(
|
||||||
[
|
[
|
||||||
q_list[tensor_parallel_rank],
|
q_list[tensor_parallel_rank],
|
||||||
k_list[tensor_parallel_rank],
|
k_list[tensor_parallel_rank // repeat_num],
|
||||||
v_list[tensor_parallel_rank],
|
v_list[tensor_parallel_rank // repeat_num],
|
||||||
],
|
],
|
||||||
axis=-1,
|
axis=-1,
|
||||||
)
|
)
|
||||||
@@ -245,8 +251,8 @@ def gqa_qkv_split_func(
|
|||||||
return paddle.concat(
|
return paddle.concat(
|
||||||
[
|
[
|
||||||
q_list[tensor_parallel_rank],
|
q_list[tensor_parallel_rank],
|
||||||
k_list[tensor_parallel_rank],
|
k_list[tensor_parallel_rank // repeat_num],
|
||||||
v_list[tensor_parallel_rank],
|
v_list[tensor_parallel_rank // repeat_num],
|
||||||
],
|
],
|
||||||
axis=0,
|
axis=0,
|
||||||
)
|
)
|
||||||
@@ -255,8 +261,8 @@ def gqa_qkv_split_func(
|
|||||||
return np.concatenate(
|
return np.concatenate(
|
||||||
[
|
[
|
||||||
q_list[tensor_parallel_rank],
|
q_list[tensor_parallel_rank],
|
||||||
k_list[tensor_parallel_rank],
|
k_list[tensor_parallel_rank // repeat_num],
|
||||||
v_list[tensor_parallel_rank],
|
v_list[tensor_parallel_rank // repeat_num],
|
||||||
],
|
],
|
||||||
axis=-1,
|
axis=-1,
|
||||||
)
|
)
|
||||||
@@ -264,8 +270,8 @@ def gqa_qkv_split_func(
|
|||||||
return np.concatenate(
|
return np.concatenate(
|
||||||
[
|
[
|
||||||
q_list[tensor_parallel_rank],
|
q_list[tensor_parallel_rank],
|
||||||
k_list[tensor_parallel_rank],
|
k_list[tensor_parallel_rank // repeat_num],
|
||||||
v_list[tensor_parallel_rank],
|
v_list[tensor_parallel_rank // repeat_num],
|
||||||
],
|
],
|
||||||
axis=0,
|
axis=0,
|
||||||
)
|
)
|
||||||
@@ -281,8 +287,8 @@ def gqa_qkv_merge_func(num_attention_heads, num_key_value_heads, head_dim):
|
|||||||
def fn(weight_list, is_column=True):
|
def fn(weight_list, is_column=True):
|
||||||
"""fn"""
|
"""fn"""
|
||||||
tensor_parallel_degree = len(weight_list)
|
tensor_parallel_degree = len(weight_list)
|
||||||
num_attention_heads = num_attention_heads // tensor_parallel_degree
|
num_attention_heads = num_attention_heads // tensor_parallel_degree # noqa: F823
|
||||||
num_key_value_heads = num_key_value_heads // tensor_parallel_degree
|
num_key_value_heads = num_key_value_heads // tensor_parallel_degree # noqa: F823
|
||||||
|
|
||||||
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
|
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
|
||||||
|
|
||||||
|
@@ -196,6 +196,9 @@ def convert_ndarray_dtype(np_array: np.ndarray,
|
|||||||
np.ndarray: converted numpy ndarray instance
|
np.ndarray: converted numpy ndarray instance
|
||||||
"""
|
"""
|
||||||
source_dtype = convert_dtype(np_array.dtype)
|
source_dtype = convert_dtype(np_array.dtype)
|
||||||
|
if source_dtype == "uint16" and target_dtype == "bfloat16" and paddle.is_compiled_with_custom_device(
|
||||||
|
"iluvatar_gpu"):
|
||||||
|
return np_array.view(dtype=target_dtype)
|
||||||
if source_dtype == "uint16" or target_dtype == "bfloat16":
|
if source_dtype == "uint16" or target_dtype == "bfloat16":
|
||||||
if paddle.is_compiled_with_xpu():
|
if paddle.is_compiled_with_xpu():
|
||||||
# xpu not support bf16.
|
# xpu not support bf16.
|
||||||
|
@@ -16,5 +16,6 @@ from . import gpu
|
|||||||
from . import cpu
|
from . import cpu
|
||||||
from . import xpu
|
from . import xpu
|
||||||
from . import npu
|
from . import npu
|
||||||
|
from . import iluvatar
|
||||||
|
|
||||||
__all__ = ["gpu", "cpu", "xpu", "npu"]
|
__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar"]
|
||||||
|
24
fastdeploy/model_executor/ops/iluvatar/__init__.py
Normal file
24
fastdeploy/model_executor/ops/iluvatar/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""fastdeploy gpu ops"""
|
||||||
|
|
||||||
|
from fastdeploy.import_ops import import_custom_ops
|
||||||
|
|
||||||
|
PACKAGE = "fastdeploy.model_executor.ops.iluvatar"
|
||||||
|
|
||||||
|
import_custom_ops(PACKAGE, "..base.fastdeploy_base_ops", globals())
|
||||||
|
import_custom_ops(PACKAGE, ".fastdeploy_ops", globals())
|
||||||
|
|
||||||
|
from .moe_ops import iluvatar_moe_expert_ffn as moe_expert_ffn # noqa: E402, F401
|
||||||
|
from .paged_attention import paged_attention # noqa: E402, F401
|
101
fastdeploy/model_executor/ops/iluvatar/moe_ops.py
Normal file
101
fastdeploy/model_executor/ops/iluvatar/moe_ops.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
import paddle
|
||||||
|
from paddle.nn.quant import weight_only_linear
|
||||||
|
from paddle.incubate.nn.functional import swiglu
|
||||||
|
|
||||||
|
|
||||||
|
def group_gemm(
|
||||||
|
input: paddle.Tensor,
|
||||||
|
tokens_expert_prefix_sum: paddle.Tensor,
|
||||||
|
weight: paddle.Tensor,
|
||||||
|
scale: paddle.Tensor,
|
||||||
|
output: paddle.Tensor,
|
||||||
|
):
|
||||||
|
assert (input.dim() == 2 and tokens_expert_prefix_sum.dim() == 1
|
||||||
|
and weight.dim() == 3 and scale.dim() == 2 and output.dim() == 2)
|
||||||
|
num_tokens = input.shape[0]
|
||||||
|
dim_in = input.shape[1]
|
||||||
|
dim_out = weight.shape[1]
|
||||||
|
num_experts = weight.shape[0]
|
||||||
|
|
||||||
|
# check shape
|
||||||
|
assert tokens_expert_prefix_sum.shape == [
|
||||||
|
num_experts,
|
||||||
|
]
|
||||||
|
assert weight.shape == [num_experts, dim_out, dim_in]
|
||||||
|
assert scale.shape == [num_experts, dim_out]
|
||||||
|
assert output.shape == [num_tokens, dim_out]
|
||||||
|
|
||||||
|
# check dtype
|
||||||
|
assert input.dtype in (paddle.float16, paddle.bfloat16)
|
||||||
|
assert scale.dtype == input.dtype and output.dtype == input.dtype
|
||||||
|
assert tokens_expert_prefix_sum.dtype == paddle.int64
|
||||||
|
assert weight.dtype == paddle.int8
|
||||||
|
|
||||||
|
# check others
|
||||||
|
assert tokens_expert_prefix_sum.place.is_cpu_place()
|
||||||
|
assert tokens_expert_prefix_sum[-1] == num_tokens
|
||||||
|
for i in range(num_experts):
|
||||||
|
expert_start = 0 if i == 0 else tokens_expert_prefix_sum[i - 1]
|
||||||
|
expert_end = tokens_expert_prefix_sum[i]
|
||||||
|
if expert_start == expert_end:
|
||||||
|
continue
|
||||||
|
input_i = input[expert_start:expert_end]
|
||||||
|
weight_i = weight[i]
|
||||||
|
scale_i = scale[i]
|
||||||
|
# avoid d2d?
|
||||||
|
output[expert_start:expert_end] = weight_only_linear(
|
||||||
|
input_i,
|
||||||
|
weight_i,
|
||||||
|
weight_scale=scale_i,
|
||||||
|
weight_dtype="int8",
|
||||||
|
group_size=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def iluvatar_moe_expert_ffn(
|
||||||
|
permute_input: paddle.Tensor,
|
||||||
|
tokens_expert_prefix_sum: paddle.Tensor,
|
||||||
|
ffn1_weight: paddle.Tensor,
|
||||||
|
ffn2_weight: paddle.Tensor,
|
||||||
|
ffn1_bias: Optional[paddle.Tensor],
|
||||||
|
ffn1_scale: Optional[paddle.Tensor],
|
||||||
|
ffn2_scale: Optional[paddle.Tensor],
|
||||||
|
ffn2_in_scale: Optional[paddle.Tensor],
|
||||||
|
expert_idx_per_token: Optional[paddle.Tensor],
|
||||||
|
quant_method: str,
|
||||||
|
used_in_ep_low_latency: bool,
|
||||||
|
):
|
||||||
|
assert ffn1_bias is None
|
||||||
|
assert ffn1_scale is not None
|
||||||
|
assert ffn2_scale is not None
|
||||||
|
assert ffn2_in_scale is None
|
||||||
|
assert expert_idx_per_token is None
|
||||||
|
assert quant_method in ("weight_only_int8")
|
||||||
|
assert not used_in_ep_low_latency
|
||||||
|
tokens_expert_prefix_sum_cpu = tokens_expert_prefix_sum.to("cpu")
|
||||||
|
ffn1_output = paddle.empty([permute_input.shape[0], ffn1_weight.shape[1]],
|
||||||
|
dtype=permute_input.dtype)
|
||||||
|
group_gemm(permute_input, tokens_expert_prefix_sum_cpu, ffn1_weight,
|
||||||
|
ffn1_scale, ffn1_output)
|
||||||
|
act_out = swiglu(ffn1_output)
|
||||||
|
output = paddle.empty([act_out.shape[0], ffn2_weight.shape[1]],
|
||||||
|
dtype=act_out.dtype)
|
||||||
|
group_gemm(act_out, tokens_expert_prefix_sum_cpu, ffn2_weight, ffn2_scale,
|
||||||
|
output)
|
||||||
|
return output
|
46
fastdeploy/model_executor/ops/iluvatar/paged_attention.py
Normal file
46
fastdeploy/model_executor/ops/iluvatar/paged_attention.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
try:
|
||||||
|
from fastdeploy.model_executor.ops.iluvatar import paged_attn
|
||||||
|
except ImportError:
|
||||||
|
paged_attn = None
|
||||||
|
|
||||||
|
|
||||||
|
def paged_attention(q: paddle.Tensor,
|
||||||
|
k_cache: paddle.Tensor,
|
||||||
|
v_cache: paddle.Tensor,
|
||||||
|
block_tables: paddle.Tensor,
|
||||||
|
seq_lens: paddle.Tensor,
|
||||||
|
num_kv_heads: int,
|
||||||
|
scale: float,
|
||||||
|
block_size: int,
|
||||||
|
max_context_len: int,
|
||||||
|
alibi_slopes: paddle.Tensor = None,
|
||||||
|
causal: bool = True,
|
||||||
|
window_left: int = -1,
|
||||||
|
window_right: int = -1,
|
||||||
|
softcap: float = 0.0,
|
||||||
|
use_cuda_graph: bool = False,
|
||||||
|
use_sqrt_alibi: bool = False,
|
||||||
|
k: paddle.Tensor = None,
|
||||||
|
v: paddle.Tensor = None):
|
||||||
|
output = paged_attn(q, k_cache, v_cache, block_tables, seq_lens,
|
||||||
|
alibi_slopes, k, v, num_kv_heads, scale, block_size,
|
||||||
|
max_context_len, causal, window_left, window_right,
|
||||||
|
softcap, use_cuda_graph, use_sqrt_alibi)
|
||||||
|
return output[0] if isinstance(output, list) else output
|
@@ -19,18 +19,25 @@ import paddle
|
|||||||
|
|
||||||
from fastdeploy import envs
|
from fastdeploy import envs
|
||||||
from fastdeploy.engine.config import SpeculativeConfig
|
from fastdeploy.engine.config import SpeculativeConfig
|
||||||
from fastdeploy.model_executor.ops.gpu import (
|
|
||||||
get_padding_offset, save_output, set_stop_value_multi_ends,
|
|
||||||
speculate_clear_accept_nums, speculate_get_output_padding_offset,
|
|
||||||
speculate_get_padding_offset, speculate_get_seq_lens_output,
|
|
||||||
speculate_save_output, speculate_set_value_by_flags_and_idx,
|
|
||||||
speculate_step_paddle, speculate_step_system_cache, speculate_update_v3,
|
|
||||||
step_paddle, step_system_cache, update_inputs, step_reschedule)
|
|
||||||
from fastdeploy.platforms import current_platform
|
from fastdeploy.platforms import current_platform
|
||||||
|
if current_platform.is_iluvatar():
|
||||||
|
from fastdeploy.model_executor.ops.iluvatar import (
|
||||||
|
get_padding_offset, save_output, set_stop_value_multi_ends,
|
||||||
|
step_paddle, update_inputs)
|
||||||
|
else:
|
||||||
|
from fastdeploy.model_executor.ops.gpu import (
|
||||||
|
get_padding_offset, save_output, set_stop_value_multi_ends,
|
||||||
|
speculate_clear_accept_nums, speculate_get_output_padding_offset,
|
||||||
|
speculate_get_padding_offset, speculate_get_seq_lens_output,
|
||||||
|
speculate_save_output, speculate_set_value_by_flags_and_idx,
|
||||||
|
speculate_step_paddle, speculate_step_system_cache,
|
||||||
|
speculate_update_v3, step_paddle, step_system_cache, update_inputs,
|
||||||
|
step_reschedule)
|
||||||
from fastdeploy.worker.output import ModelOutputData
|
from fastdeploy.worker.output import ModelOutputData
|
||||||
|
|
||||||
DISABLE_RECOVER = (envs.FD_DISABLED_RECOVER == "1")
|
DISABLE_RECOVER = (envs.FD_DISABLED_RECOVER == "1")
|
||||||
|
|
||||||
|
|
||||||
def pre_process(
|
def pre_process(
|
||||||
max_len: int,
|
max_len: int,
|
||||||
input_ids: paddle.Tensor,
|
input_ids: paddle.Tensor,
|
||||||
@@ -151,6 +158,7 @@ def post_process_normal(sampled_token_ids: paddle.Tensor,
|
|||||||
save_each_rank, # save_each_rank
|
save_each_rank, # save_each_rank
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def post_process_specualate(model_output, skip_save_output: bool = False):
|
def post_process_specualate(model_output, skip_save_output: bool = False):
|
||||||
""""""
|
""""""
|
||||||
speculate_update_v3(
|
speculate_update_v3(
|
||||||
@@ -217,7 +225,6 @@ def step_cuda(
|
|||||||
TODO(gongshaotian): normalization name
|
TODO(gongshaotian): normalization name
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
if speculative_config.method is not None:
|
if speculative_config.method is not None:
|
||||||
if enable_prefix_caching:
|
if enable_prefix_caching:
|
||||||
speculate_step_system_cache(
|
speculate_step_system_cache(
|
||||||
@@ -373,6 +380,17 @@ def rebuild_padding(tmp_out: paddle.Tensor,
|
|||||||
output_padding_offset,
|
output_padding_offset,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
)
|
)
|
||||||
|
elif current_platform.is_iluvatar():
|
||||||
|
from fastdeploy.model_executor.ops.iluvatar import rebuild_padding
|
||||||
|
hidden_states = rebuild_padding(
|
||||||
|
tmp_out,
|
||||||
|
cum_offsets,
|
||||||
|
seq_len_this_time,
|
||||||
|
seq_lens_decoder,
|
||||||
|
seq_lens_encoder,
|
||||||
|
output_padding_offset,
|
||||||
|
max_input_length,
|
||||||
|
)
|
||||||
elif current_platform.is_cpu():
|
elif current_platform.is_cpu():
|
||||||
from fastdeploy.model_executor.ops.cpu import rebuild_padding_cpu
|
from fastdeploy.model_executor.ops.cpu import rebuild_padding_cpu
|
||||||
hidden_states = rebuild_padding_cpu(
|
hidden_states = rebuild_padding_cpu(
|
||||||
|
@@ -122,9 +122,13 @@ class TokenProcessor(object):
|
|||||||
|
|
||||||
if current_platform.is_xpu():
|
if current_platform.is_xpu():
|
||||||
from fastdeploy.model_executor.ops.xpu import get_output
|
from fastdeploy.model_executor.ops.xpu import get_output
|
||||||
|
elif current_platform.is_iluvatar():
|
||||||
|
from fastdeploy.model_executor.ops.iluvatar import get_output
|
||||||
else:
|
else:
|
||||||
from fastdeploy.model_executor.ops.gpu import (
|
from fastdeploy.model_executor.ops.gpu import (get_output,
|
||||||
get_output, get_output_ep, speculate_get_output)
|
get_output_ep,
|
||||||
|
speculate_get_output
|
||||||
|
)
|
||||||
rank_id = self.cfg.parallel_config.local_data_parallel_id
|
rank_id = self.cfg.parallel_config.local_data_parallel_id
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -413,9 +417,12 @@ class WarmUpTokenProcessor(TokenProcessor):
|
|||||||
|
|
||||||
if current_platform.is_xpu():
|
if current_platform.is_xpu():
|
||||||
from fastdeploy.model_executor.ops.xpu import get_output
|
from fastdeploy.model_executor.ops.xpu import get_output
|
||||||
|
elif current_platform.is_iluvatar():
|
||||||
|
from fastdeploy.model_executor.ops.iluvatar import get_output
|
||||||
else:
|
else:
|
||||||
from fastdeploy.model_executor.ops.gpu import (
|
from fastdeploy.model_executor.ops.gpu import (get_output,
|
||||||
get_output, speculate_get_output)
|
speculate_get_output
|
||||||
|
)
|
||||||
|
|
||||||
while self._is_running:
|
while self._is_running:
|
||||||
try:
|
try:
|
||||||
|
@@ -11,7 +11,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
platform module
|
platform module
|
||||||
"""
|
"""
|
||||||
@@ -22,7 +21,8 @@ from .cpu import CPUPlatform
|
|||||||
from .xpu import XPUPlatform
|
from .xpu import XPUPlatform
|
||||||
from .npu import NPUPlatform
|
from .npu import NPUPlatform
|
||||||
from .dcu import DCUPlatform
|
from .dcu import DCUPlatform
|
||||||
from .base import _Backend
|
from .iluvatar import IluvatarPlatform
|
||||||
|
from .base import _Backend # noqa: F401
|
||||||
|
|
||||||
_current_platform = None
|
_current_platform = None
|
||||||
|
|
||||||
@@ -40,10 +40,13 @@ def __getattr__(name: str):
|
|||||||
_current_platform = NPUPlatform()
|
_current_platform = NPUPlatform()
|
||||||
elif paddle.is_compiled_with_rocm():
|
elif paddle.is_compiled_with_rocm():
|
||||||
_current_platform = DCUPlatform()
|
_current_platform = DCUPlatform()
|
||||||
|
elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
||||||
|
_current_platform = IluvatarPlatform()
|
||||||
else:
|
else:
|
||||||
_current_platform = CPUPlatform()
|
_current_platform = CPUPlatform()
|
||||||
return _current_platform
|
return _current_platform
|
||||||
elif name in globals():
|
elif name in globals():
|
||||||
return globals()[name]
|
return globals()[name]
|
||||||
else:
|
else:
|
||||||
raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
|
raise AttributeError(
|
||||||
|
f"No attribute named '{name}' exists in {__name__}.")
|
||||||
|
@@ -63,6 +63,12 @@ class Platform:
|
|||||||
"""
|
"""
|
||||||
return paddle.is_compiled_with_rocm()
|
return paddle.is_compiled_with_rocm()
|
||||||
|
|
||||||
|
def is_iluvatar(self) -> bool:
|
||||||
|
"""
|
||||||
|
whether platform is iluvatar gpu
|
||||||
|
"""
|
||||||
|
return paddle.is_compiled_with_custom_device("iluvatar_gpu")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_attention_backend_cls(self, selected_backend):
|
def get_attention_backend_cls(self, selected_backend):
|
||||||
"""Get the attention backend"""
|
"""Get the attention backend"""
|
||||||
|
26
fastdeploy/platforms/iluvatar.py
Normal file
26
fastdeploy/platforms/iluvatar.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .base import Platform
|
||||||
|
|
||||||
|
|
||||||
|
class IluvatarPlatform(Platform):
|
||||||
|
device_name = "iluvatar_gpu"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_attention_backend_cls(cls, selected_backend):
|
||||||
|
"""
|
||||||
|
get_attention_backend_cls
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
"fastdeploy.model_executor.layers.attention.IluvatarAttnBackend")
|
1142
fastdeploy/worker/iluvatar_model_runner.py
Normal file
1142
fastdeploy/worker/iluvatar_model_runner.py
Normal file
File diff suppressed because it is too large
Load Diff
143
fastdeploy/worker/iluvatar_worker.py
Normal file
143
fastdeploy/worker/iluvatar_worker.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
import gc
|
||||||
|
import os
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
|
||||||
|
from fastdeploy.config import FDConfig
|
||||||
|
from fastdeploy.engine.request import Request
|
||||||
|
from fastdeploy.utils import get_logger
|
||||||
|
from fastdeploy.worker.iluvatar_model_runner import IluvatarModelRunner
|
||||||
|
from fastdeploy.worker.output import ModelRunnerOutput
|
||||||
|
from fastdeploy.worker.worker_base import WorkerBase
|
||||||
|
|
||||||
|
logger = get_logger("iluvatar_worker", "iluvatar_worker.log")
|
||||||
|
|
||||||
|
|
||||||
|
class IluvatarWorker(WorkerBase):
|
||||||
|
""" """
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fd_config: FDConfig,
|
||||||
|
local_rank: int,
|
||||||
|
rank: int,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
fd_config=fd_config,
|
||||||
|
local_rank=local_rank,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
pass
|
||||||
|
|
||||||
|
def init_device(self):
|
||||||
|
""" Initialize device and Construct model runner
|
||||||
|
"""
|
||||||
|
if paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
||||||
|
# Set evironment variable
|
||||||
|
self.device = f"iluvatar_gpu:{self.local_rank}"
|
||||||
|
paddle.device.set_device(self.device)
|
||||||
|
paddle.set_default_dtype(self.parallel_config.dtype)
|
||||||
|
self.device_ids = self.parallel_config.device_ids.split(",")
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Not support device type: {self.device_config.device}")
|
||||||
|
|
||||||
|
# Construct model runner
|
||||||
|
self.model_runner: IluvatarModelRunner = IluvatarModelRunner(
|
||||||
|
fd_config=self.fd_config,
|
||||||
|
device=self.device,
|
||||||
|
device_id=self.device_ids[self.local_rank],
|
||||||
|
rank=self.rank,
|
||||||
|
local_rank=self.local_rank)
|
||||||
|
|
||||||
|
def prefill_finished(self):
|
||||||
|
"""
|
||||||
|
check whether prefill stage finished
|
||||||
|
"""
|
||||||
|
return self.model_runner.prefill_finished()
|
||||||
|
|
||||||
|
def determine_available_memory(self) -> int:
|
||||||
|
"""
|
||||||
|
Profiles the peak memory usage of the model to determine how much
|
||||||
|
memory can be used for KV cache without OOMs.
|
||||||
|
|
||||||
|
The engine will first conduct a profiling of the existing memory usage.
|
||||||
|
Then, it calculate the maximum possible number of GPU and CPU blocks
|
||||||
|
that can be allocated with the remaining free memory.
|
||||||
|
|
||||||
|
Tip:
|
||||||
|
You may limit the usage of GPU memory
|
||||||
|
by adjusting the `gpu_memory_utilization` parameter.
|
||||||
|
"""
|
||||||
|
# 1. Record memory state before profile run
|
||||||
|
return int(float(os.getenv("FD_ILUVATAR_KVCACHE_MEM", "3")) * 1024**3)
|
||||||
|
|
||||||
|
def load_model(self) -> None:
|
||||||
|
""" """
|
||||||
|
self.model_runner.load_model()
|
||||||
|
|
||||||
|
def get_model(self) -> nn.Layer:
|
||||||
|
""" """
|
||||||
|
return self.model_runner.get_model()
|
||||||
|
|
||||||
|
def initialize_cache(self, num_gpu_blocks: int,
|
||||||
|
num_cpu_blocks: int) -> None:
|
||||||
|
""" """
|
||||||
|
pass
|
||||||
|
|
||||||
|
def execute_model(
|
||||||
|
self,
|
||||||
|
model_forward_batch: Optional[List[Request]] = None,
|
||||||
|
) -> Optional[ModelRunnerOutput]:
|
||||||
|
""" """
|
||||||
|
output = self.model_runner.execute_model(model_forward_batch)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def preprocess_new_task(self, req_dicts: List[Request]) -> None:
|
||||||
|
""" Process new requests and then start the decode loop
|
||||||
|
TODO(gongshaotian):The scheduler should schedule the handling of prefill,
|
||||||
|
and workers and modelrunners should not perceive it.
|
||||||
|
"""
|
||||||
|
self.model_runner.insert_prefill_inputs(req_dicts=req_dicts)
|
||||||
|
|
||||||
|
def graph_optimize_and_warm_up_model(self) -> None:
|
||||||
|
"""
|
||||||
|
Perform the warm-up and the graph optimization
|
||||||
|
"""
|
||||||
|
# 1. Warm up model
|
||||||
|
# NOTE(gongshaotian): may be not need warm_up at this place
|
||||||
|
|
||||||
|
# 2. Triger cuda grpah capture
|
||||||
|
self.model_runner.capture_model()
|
||||||
|
|
||||||
|
def check_health(self) -> bool:
|
||||||
|
""" """
|
||||||
|
return True
|
||||||
|
|
||||||
|
def cal_theortical_kvcache(self) -> int:
|
||||||
|
""" """
|
||||||
|
return self.model_runner.cal_theortical_kvcache()
|
||||||
|
|
||||||
|
def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
|
||||||
|
""" """
|
||||||
|
self.model_runner.update_share_input_block_num(
|
||||||
|
num_gpu_blocks=num_gpu_blocks)
|
@@ -48,6 +48,11 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
|
|||||||
if current_platform.is_xpu():
|
if current_platform.is_xpu():
|
||||||
from fastdeploy.worker.xpu_worker import XpuWorker
|
from fastdeploy.worker.xpu_worker import XpuWorker
|
||||||
return XpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
|
return XpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
|
||||||
|
if current_platform.is_iluvatar():
|
||||||
|
from fastdeploy.worker.iluvatar_worker import IluvatarWorker
|
||||||
|
return IluvatarWorker(fd_config=fd_config,
|
||||||
|
local_rank=local_rank,
|
||||||
|
rank=rank)
|
||||||
|
|
||||||
|
|
||||||
class PaddleDisWorkerProc():
|
class PaddleDisWorkerProc():
|
||||||
@@ -125,9 +130,9 @@ class PaddleDisWorkerProc():
|
|||||||
model_weights_status:
|
model_weights_status:
|
||||||
"""
|
"""
|
||||||
# init worker_ready_signal
|
# init worker_ready_signal
|
||||||
|
max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||||
array_size = min(
|
array_size = min(
|
||||||
8, self.parallel_config.tensor_parallel_degree *
|
max_chips_per_node, self.parallel_config.tensor_parallel_degree *
|
||||||
self.parallel_config.expert_parallel_degree)
|
self.parallel_config.expert_parallel_degree)
|
||||||
workers_ready = np.zeros(shape=[array_size], dtype=np.int32)
|
workers_ready = np.zeros(shape=[array_size], dtype=np.int32)
|
||||||
self.worker_ready_signal = IPCSignal(
|
self.worker_ready_signal = IPCSignal(
|
||||||
@@ -136,7 +141,8 @@ class PaddleDisWorkerProc():
|
|||||||
dtype=np.int32,
|
dtype=np.int32,
|
||||||
suffix=self.parallel_config.engine_pid,
|
suffix=self.parallel_config.engine_pid,
|
||||||
create=False)
|
create=False)
|
||||||
self.worker_ready_signal.value[self.local_rank % 8] = 1
|
self.worker_ready_signal.value[self.local_rank %
|
||||||
|
max_chips_per_node] = 1
|
||||||
|
|
||||||
# init worker_healthy_live_signal
|
# init worker_healthy_live_signal
|
||||||
workers_alive = np.zeros(shape=[self.ranks], dtype=np.int32)
|
workers_alive = np.zeros(shape=[self.ranks], dtype=np.int32)
|
||||||
|
29
requirements_iluvatar.txt
Normal file
29
requirements_iluvatar.txt
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
setuptools>=62.3.0,<80.0
|
||||||
|
pre-commit
|
||||||
|
yapf
|
||||||
|
flake8
|
||||||
|
ruamel.yaml
|
||||||
|
zmq
|
||||||
|
aiozmq
|
||||||
|
openai
|
||||||
|
tqdm
|
||||||
|
pynvml
|
||||||
|
uvicorn
|
||||||
|
fastapi
|
||||||
|
paddleformers
|
||||||
|
redis
|
||||||
|
etcd3
|
||||||
|
httpx
|
||||||
|
tool_helpers
|
||||||
|
pybind11[global]
|
||||||
|
tabulate
|
||||||
|
gradio
|
||||||
|
xlwt
|
||||||
|
visualdl
|
||||||
|
setuptools-scm>=8
|
||||||
|
prometheus-client
|
||||||
|
decord
|
||||||
|
moviepy
|
||||||
|
use-triton-in-paddle
|
||||||
|
crcmod
|
||||||
|
fastsafetensors==0.1.14
|
44
setup.py
44
setup.py
@@ -22,6 +22,7 @@ import subprocess
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from setuptools import Extension, find_packages, setup
|
from setuptools import Extension, find_packages, setup
|
||||||
from setuptools.command.build_ext import build_ext
|
from setuptools.command.build_ext import build_ext
|
||||||
|
from wheel.bdist_wheel import bdist_wheel
|
||||||
|
|
||||||
long_description = "FastDeploy: Large Language Model Serving.\n\n"
|
long_description = "FastDeploy: Large Language Model Serving.\n\n"
|
||||||
long_description += "GitHub: https://github.com/PaddlePaddle/FastDeploy\n"
|
long_description += "GitHub: https://github.com/PaddlePaddle/FastDeploy\n"
|
||||||
@@ -35,7 +36,6 @@ PLAT_TO_CMAKE = {
|
|||||||
"win-arm64": "ARM64",
|
"win-arm64": "ARM64",
|
||||||
}
|
}
|
||||||
|
|
||||||
from wheel.bdist_wheel import bdist_wheel
|
|
||||||
|
|
||||||
class CustomBdistWheel(bdist_wheel):
|
class CustomBdistWheel(bdist_wheel):
|
||||||
"""Custom wheel builder for pure Python packages."""
|
"""Custom wheel builder for pure Python packages."""
|
||||||
@@ -49,10 +49,14 @@ class CustomBdistWheel(bdist_wheel):
|
|||||||
self.plat_name_supplied = True
|
self.plat_name_supplied = True
|
||||||
self.plat_name = 'any'
|
self.plat_name = 'any'
|
||||||
|
|
||||||
|
|
||||||
class CMakeExtension(Extension):
|
class CMakeExtension(Extension):
|
||||||
"""A setuptools Extension for CMake-based builds."""
|
"""A setuptools Extension for CMake-based builds."""
|
||||||
|
|
||||||
def __init__(self, name: str, sourcedir: str = "", version: str = None) -> None:
|
def __init__(self,
|
||||||
|
name: str,
|
||||||
|
sourcedir: str = "",
|
||||||
|
version: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize CMake extension.
|
Initialize CMake extension.
|
||||||
|
|
||||||
@@ -84,15 +88,11 @@ class CMakeBuild(build_ext):
|
|||||||
extdir = ext_fullpath.parent.resolve()
|
extdir = ext_fullpath.parent.resolve()
|
||||||
cfg = "Debug" if int(os.environ.get("DEBUG", 0)) else "Release"
|
cfg = "Debug" if int(os.environ.get("DEBUG", 0)) else "Release"
|
||||||
|
|
||||||
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
|
|
||||||
|
|
||||||
cmake_args = [
|
cmake_args = [
|
||||||
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
|
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
|
||||||
f"-DPYTHON_EXECUTABLE={sys.executable}",
|
f"-DPYTHON_EXECUTABLE={sys.executable}",
|
||||||
f"-DCMAKE_BUILD_TYPE={cfg}",
|
f"-DCMAKE_BUILD_TYPE={cfg}", "-DVERSION_INFO=",
|
||||||
f"-DVERSION_INFO=",
|
"-DPYBIND11_PYTHON_VERSION=", "-DPYTHON_VERSION=",
|
||||||
f"-DPYBIND11_PYTHON_VERSION=",
|
|
||||||
f"-DPYTHON_VERSION=",
|
|
||||||
f"-DPYTHON_INCLUDE_DIR={sys.prefix}/include/python{sys.version_info.major}.{sys.version_info.minor}",
|
f"-DPYTHON_INCLUDE_DIR={sys.prefix}/include/python{sys.version_info.major}.{sys.version_info.minor}",
|
||||||
f"-DPYTHON_LIBRARY={sys.prefix}/lib/libpython{sys.version_info.major}.{sys.version_info.minor}.so"
|
f"-DPYTHON_LIBRARY={sys.prefix}/lib/libpython{sys.version_info.major}.{sys.version_info.minor}.so"
|
||||||
]
|
]
|
||||||
@@ -134,22 +134,27 @@ class CMakeBuild(build_ext):
|
|||||||
build_temp.mkdir(parents=True, exist_ok=True)
|
build_temp.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
subprocess.run(["cmake", ext.sourcedir, *cmake_args],
|
subprocess.run(["cmake", ext.sourcedir, *cmake_args],
|
||||||
cwd=build_temp,
|
cwd=build_temp,
|
||||||
check=True)
|
check=True)
|
||||||
subprocess.run(["cmake", "--build", ".", *build_args],
|
subprocess.run(["cmake", "--build", ".", *build_args],
|
||||||
cwd=build_temp,
|
cwd=build_temp,
|
||||||
check=True)
|
check=True)
|
||||||
|
|
||||||
|
|
||||||
def load_requirements():
|
def load_requirements():
|
||||||
"""Load dependencies from requirements.txt"""
|
"""Load dependencies from requirements.txt"""
|
||||||
|
requirements_file_name = 'requirements.txt'
|
||||||
|
if paddle.is_compiled_with_custom_device('iluvatar_gpu'):
|
||||||
|
requirements_file_name = 'requirements_iluvatar.txt'
|
||||||
requirements_path = os.path.join(os.path.dirname(__file__),
|
requirements_path = os.path.join(os.path.dirname(__file__),
|
||||||
'requirements.txt')
|
requirements_file_name)
|
||||||
with open(requirements_path, 'r') as f:
|
with open(requirements_path, 'r') as f:
|
||||||
return [
|
return [
|
||||||
line.strip() for line in f
|
line.strip() for line in f
|
||||||
if line.strip() and not line.startswith('#')
|
if line.strip() and not line.startswith('#')
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_device_type():
|
def get_device_type():
|
||||||
"""Get the device type (rocm/gpu/xpu/npu/cpu) that paddle is compiled with."""
|
"""Get the device type (rocm/gpu/xpu/npu/cpu) that paddle is compiled with."""
|
||||||
if paddle.is_compiled_with_rocm():
|
if paddle.is_compiled_with_rocm():
|
||||||
@@ -160,13 +165,17 @@ def get_device_type():
|
|||||||
return "xpu"
|
return "xpu"
|
||||||
elif paddle.is_compiled_with_custom_device('npu'):
|
elif paddle.is_compiled_with_custom_device('npu'):
|
||||||
return "npu"
|
return "npu"
|
||||||
|
elif paddle.is_compiled_with_custom_device('iluvatar_gpu'):
|
||||||
|
return "iluvatar-gpu"
|
||||||
else:
|
else:
|
||||||
return "cpu"
|
return "cpu"
|
||||||
|
|
||||||
|
|
||||||
def get_name():
|
def get_name():
|
||||||
"""get package name"""
|
"""get package name"""
|
||||||
return "fastdeploy-" + get_device_type()
|
return "fastdeploy-" + get_device_type()
|
||||||
|
|
||||||
|
|
||||||
cmdclass_dict = {'bdist_wheel': CustomBdistWheel}
|
cmdclass_dict = {'bdist_wheel': CustomBdistWheel}
|
||||||
cmdclass_dict['build_ext'] = CMakeBuild
|
cmdclass_dict['build_ext'] = CMakeBuild
|
||||||
FASTDEPLOY_VERSION = os.environ.get("FASTDEPLOY_VERSION", "2.0.0")
|
FASTDEPLOY_VERSION = os.environ.get("FASTDEPLOY_VERSION", "2.0.0")
|
||||||
@@ -187,8 +196,8 @@ setup(
|
|||||||
"model_executor/ops/gpu/*",
|
"model_executor/ops/gpu/*",
|
||||||
"model_executor/ops/gpu/deep_gemm/include/**/*",
|
"model_executor/ops/gpu/deep_gemm/include/**/*",
|
||||||
"model_executor/ops/cpu/*", "model_executor/ops/xpu/*",
|
"model_executor/ops/cpu/*", "model_executor/ops/xpu/*",
|
||||||
"model_executor/ops/xpu/libs/*",
|
"model_executor/ops/xpu/libs/*", "model_executor/ops/npu/*",
|
||||||
"model_executor/ops/npu/*", "model_executor/ops/base/*",
|
"model_executor/ops/base/*", "model_executor/ops/iluvatar/*",
|
||||||
"model_executor/models/*", "model_executor/layers/*",
|
"model_executor/models/*", "model_executor/layers/*",
|
||||||
"input/mm_processor/utils/*",
|
"input/mm_processor/utils/*",
|
||||||
"version.txt"
|
"version.txt"
|
||||||
@@ -198,9 +207,10 @@ setup(
|
|||||||
ext_modules=[
|
ext_modules=[
|
||||||
CMakeExtension(
|
CMakeExtension(
|
||||||
"rdma_comm",
|
"rdma_comm",
|
||||||
sourcedir="fastdeploy/cache_manager/transfer_factory/kvcache_transfer",
|
sourcedir=
|
||||||
|
"fastdeploy/cache_manager/transfer_factory/kvcache_transfer",
|
||||||
version=None)
|
version=None)
|
||||||
],
|
] if os.getenv("ENABLE_FD_RDMA", "0") == "1" else [],
|
||||||
cmdclass=cmdclass_dict if os.getenv("ENABLE_FD_RDMA", "0") == "1" else {},
|
cmdclass=cmdclass_dict if os.getenv("ENABLE_FD_RDMA", "0") == "1" else {},
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
classifiers=[
|
classifiers=[
|
||||||
|
Reference in New Issue
Block a user