Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 16:48:03 +08:00
[Metax] support cutlass moe & optimize flash attention (#4208)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
@@ -14,6 +14,8 @@
#pragma once

#include <cuda_fp8.h>

#ifndef PADDLE_WITH_COREX
#include "glog/logging.h"
#endif
custom_ops/metax_ops/fused_moe.cu (new file, 181 lines)
@@ -0,0 +1,181 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "helper.h"
#include "mc_fused_moe_helper.h"
#include "fused_moe_op.h"

__global__ void compute_total_rows_before_expert_kernel(
    int* sorted_experts,
    const int64_t sorted_experts_len,
    const int64_t num_experts,
    int32_t* total_rows_before_expert) {
  const int expert = blockIdx.x * blockDim.x + threadIdx.x;
  if (expert >= num_experts) return;

  total_rows_before_expert[expert] =
      find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
}

void compute_total_rows_before_expert(int* sorted_indices,
                                      const int64_t total_indices,
                                      const int64_t num_experts,
                                      int32_t* total_rows_before_expert,
                                      cudaStream_t stream) {
  const int threads = std::min(int64_t(1024), num_experts);
  const int blocks = (num_experts + threads - 1) / threads;

  compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
      sorted_indices, total_indices, num_experts, total_rows_before_expert);
}

template <paddle::DataType T, typename ElementA, typename ElementB, typename ElementC>
void FusedMoeKernel(const paddle::Tensor& input,
                    const paddle::Tensor& gate_weight,
                    const paddle::Tensor& ffn1_weight,
                    const paddle::optional<paddle::Tensor>& ffn1_scale,
                    const paddle::optional<paddle::Tensor>& ffn1_bias,
                    const paddle::Tensor& ffn2_weight,
                    const paddle::optional<paddle::Tensor>& ffn2_scale,
                    const paddle::optional<paddle::Tensor>& ffn2_bias,
                    const std::string& quant_method,
                    const int moe_topk,
                    const bool group_moe,
                    const bool norm_topk_prob,
                    paddle::Tensor* output) {
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto* output_data = output->data<data_t>();

  auto moe_compute = McMoeHelper<data_t, ElementA, ElementB, ElementC>(quant_method);

  moe_compute.computeFFN(
      &input,
      &gate_weight,
      &ffn1_weight,
      ffn1_scale ? ffn1_scale.get_ptr() : nullptr,
      ffn1_bias ? ffn1_bias.get_ptr() : nullptr,
      &ffn2_weight,
      ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
      ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
      nullptr,
      moe_topk,
      group_moe,
      norm_topk_prob,
      1.0,  // ComputeFFN
      "ffn",
      output);
}


std::vector<paddle::Tensor> FusedExpertMoe(
    const paddle::Tensor& input,
    const paddle::Tensor& gate_weight,
    const paddle::Tensor& ffn1_weight,
    const paddle::Tensor& ffn2_weight,
    const paddle::optional<paddle::Tensor>& ffn1_bias,
    const paddle::optional<paddle::Tensor>& ffn1_scale,
    const paddle::optional<paddle::Tensor>& ffn2_bias,
    const paddle::optional<paddle::Tensor>& ffn2_scale,
    const std::string& quant_method,
    const int moe_topk,
    const bool norm_topk_prob,
    const bool group_moe) {
  const auto input_type = input.dtype();
  auto output = paddle::empty_like(input);

  switch (input_type) {
    case paddle::DataType::BFLOAT16:
      FusedMoeKernel<paddle::DataType::BFLOAT16, maca_bfloat16, int8_t, maca_bfloat16>(
          input, gate_weight, ffn1_weight, ffn1_scale, ffn1_bias, ffn2_weight,
          ffn2_scale, ffn2_bias, quant_method, moe_topk, group_moe,
          norm_topk_prob, &output);
      break;
    // case paddle::DataType::FLOAT16:
    //   FusedMoeKernel<paddle::DataType::FLOAT16>(
    //       input, gate_weight, ffn1_weight, ffn1_scale, ffn1_bias, ffn2_weight,
    //       ffn2_scale, ffn2_bias, quant_method, moe_topk, group_moe,
    //       norm_topk_prob, &output);
    //   break;
    default:
      PD_THROW("Only support bf16 for FusedMoeKernel");
  }
  return {output};
}

std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& gate_weight_shape,
    const std::vector<int64_t>& ffn1_weight_shape,
    const std::vector<int64_t>& ffn2_weight_shape,
    const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
    const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
    const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape,
    const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) {
  return {input_shape};
}

std::vector<paddle::DataType> FusedExpertMoeInferDtype(
    const paddle::DataType& input_dtype,
    const paddle::DataType& gate_weight_dtype,
    const paddle::DataType& ffn1_weight_dtype,
    const paddle::DataType& ffn2_weight_dtype,
    const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
    const paddle::optional<paddle::DataType>& ffn1_scale_dtype,
    const paddle::optional<paddle::DataType>& ffn2_bias_dtype,
    const paddle::optional<paddle::DataType>& ffn2_scale_dtype) {
  return {input_dtype};
}


PD_BUILD_OP(fused_expert_moe)
    .Inputs({"input",
             "gate_weight",
             "ffn1_weight",
             "ffn2_weight",
             paddle::Optional("ffn1_bias"),
             paddle::Optional("ffn1_scale"),
             paddle::Optional("ffn2_bias"),
             paddle::Optional("ffn2_scale")})
    .Outputs({"output"})
    .Attrs({"quant_method:std::string",
            "moe_topk:int",
            "norm_topk_prob:bool",
            "group_moe:bool"})
    .SetKernelFn(PD_KERNEL(FusedExpertMoe))
    .SetInferShapeFn(PD_INFER_SHAPE(FusedExpertMoeInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(FusedExpertMoeInferDtype));
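The kernel above fills total_rows_before_expert with, for each expert, the count of sorted rows whose expert id is less than or equal to that expert; the grouped GEMM later uses these values as segment end offsets. The following is a minimal host-side sketch of that bookkeeping, not part of the commit, written in plain C++ purely for illustration:

// Host-side illustration (assumption: not part of the commit) of what
// compute_total_rows_before_expert produces for a sorted expert-id array.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> sorted_experts = {0, 0, 1, 1, 1, 3};  // rows already sorted by expert id
  const int num_experts = 4;
  std::vector<int> total_rows_before_expert(num_experts);
  for (int e = 0; e < num_experts; ++e) {
    // number of entries <= e == end offset of expert e's contiguous segment
    total_rows_before_expert[e] =
        std::upper_bound(sorted_experts.begin(), sorted_experts.end(), e) -
        sorted_experts.begin();
  }
  // Expected output: 2 5 5 6 (expert 2 owns an empty segment).
  for (int v : total_rows_before_expert) std::printf("%d ", v);
  std::printf("\n");
  return 0;
}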
custom_ops/metax_ops/fused_moe_helper.h (new file, 53 lines)
@@ -0,0 +1,53 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
#include "fused_moe_op.h"

using namespace phi;

template <typename T, int VecSize>
__global__ void moe_token_type_ids_kernel(T *gating_output,
                                          const int *moe_token_type_ids_out,
                                          const int num_rows,
                                          const int num_experts,
                                          const int k) {
  const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x;

  if (moe_token_index >= num_rows) {
    return;
  }

  gating_output[moe_token_index * 2] =
      gating_output[moe_token_index * 2] +
      (moe_token_type_ids_out[moe_token_index]) * -1e10;
  gating_output[moe_token_index * 2 + 1] =
      gating_output[moe_token_index * 2 + 1] +
      (1 - moe_token_type_ids_out[moe_token_index]) * -1e10;
}

template <typename T>
void moe_token_type_ids_kernelLauncher(T *gating_output,
                                       const int *moe_token_type_ids_out,
                                       const int num_rows,
                                       const int num_experts,
                                       const int k,
                                       cudaStream_t stream) {
  const int blocks = num_rows * k / 512 + 1;
  const int threads = 512;
  moe_token_type_ids_kernel<T, 1><<<blocks, 512, 0, stream>>>(
      gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
}
custom_ops/metax_ops/fused_moe_imp_op.h (new file, 123 lines)
@@ -0,0 +1,123 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
 * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include <string>
#include <sstream>
#include "cub/cub.cuh"

static const float HALF_FLT_MAX = 65504.F;
static const float HALF_FLT_MIN = -65504.F;
static inline size_t AlignTo16(const size_t& input) {
  static constexpr int ALIGNMENT = 16;
  return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
}

class CubKeyValueSorter {
 public:
  CubKeyValueSorter() : num_experts_(0), num_bits_(sizeof(int) * 8) {}

  explicit CubKeyValueSorter(const int num_experts)
      : num_experts_(num_experts),
        num_bits_(static_cast<int>(log2(num_experts)) + 1) {}

  void update_num_experts(const int num_experts) {
    num_experts_ = num_experts;
    num_bits_ = static_cast<int>(log2(num_experts)) + 1;
  }

  size_t getWorkspaceSize(const size_t num_key_value_pairs,
                          bool descending = false) {
    num_key_value_pairs_ = num_key_value_pairs;
    size_t required_storage = 0;
    int* null_int = nullptr;
    if (descending) {
      cub::DeviceRadixSort::SortPairsDescending(
          NULL, required_storage, null_int, null_int, null_int, null_int,
          num_key_value_pairs, 0, 32);
    } else {
      cub::DeviceRadixSort::SortPairs(
          NULL, required_storage, null_int, null_int, null_int, null_int,
          num_key_value_pairs, 0, num_bits_);
    }
    return required_storage;
  }

  template <typename KeyT>
  void run(void* workspace,
           const size_t workspace_size,
           const KeyT* keys_in,
           KeyT* keys_out,
           const int* values_in,
           int* values_out,
           const size_t num_key_value_pairs,
           bool descending,
           cudaStream_t stream) {
    size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs);
    size_t actual_ws_size = workspace_size;

    if (expected_ws_size > workspace_size) {
      std::stringstream err_ss;
      err_ss << "[Error][CubKeyValueSorter::run]\n";
      err_ss << "Error. The allocated workspace is too small to run this "
                "problem.\n";
      err_ss << "Expected workspace size of at least " << expected_ws_size
             << " but got problem size " << workspace_size << "\n";
      throw std::runtime_error(err_ss.str());
    }
    if (descending) {
      cub::DeviceRadixSort::SortPairsDescending(
          workspace, actual_ws_size, keys_in, keys_out, values_in, values_out,
          num_key_value_pairs, 0, 32, stream);
    } else {
      cub::DeviceRadixSort::SortPairs(
          workspace, actual_ws_size, keys_in, keys_out, values_in, values_out,
          num_key_value_pairs, 0, num_bits_, stream);
    }
  }

 private:
  size_t num_key_value_pairs_;
  int num_experts_;
  int num_bits_;
};
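CubKeyValueSorter limits the radix sort to static_cast<int>(log2(num_experts)) + 1 key bits, which is just enough to cover expert ids up to and including num_experts (the sentinel value used for rows that are not processed). A small host-side sketch of that bit-width choice, not part of the commit, for illustration only:

// Host-side illustration (assumption: not part of the commit) of the radix
// sort bit width chosen by CubKeyValueSorter.
#include <cmath>
#include <cstdio>

int main() {
  for (int num_experts : {8, 64, 100, 256}) {
    const int num_bits = static_cast<int>(std::log2(num_experts)) + 1;
    std::printf("num_experts=%d -> sort %d key bits\n", num_experts, num_bits);
  }
  return 0;
}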
custom_ops/metax_ops/fused_moe_op.h (new file, 990 lines)
@@ -0,0 +1,990 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
 * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuda.h>
#include <cuda_fp16.h>
#include "fused_moe_imp_op.h"
#include "fused_moe_helper.h"
#include "mctlass/numeric_conversion.h"  // BUILD_MARK
// Ignore mctlass warnings about type punning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wunused-function"

// #include "paddle/phi/backends/gpu/gpu_info.h"
#pragma GCC diagnostic pop

#include "helper.h"

#define WARP_SIZE 32

struct GpuLaunchConfig {
  dim3 block_per_grid;
  dim3 thread_per_block;
};

inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
  int blocks_x = cols;
  int blocks_y = 1;
  int blocks_z = 1;
  if (blocks_x > 1024) {
    blocks_y = 256;
    blocks_x = (blocks_x + blocks_y - 1) / blocks_y;
  }

  GpuLaunchConfig config;
  config.block_per_grid.x = blocks_x;
  config.block_per_grid.y = blocks_y;
  config.block_per_grid.z = blocks_z;
  return config;
}
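Get1DBlocksAnd2DGridsMoe splits a potentially huge row count into a 2D grid, and the kernels below recover a flat row index as blockIdx.x + blockIdx.y * gridDim.x, early-exiting for out-of-range blocks. A minimal host-side sketch of the same arithmetic, not part of the commit, for illustration:

// Host-side illustration (assumption: not part of the commit) of the grid
// shape produced by Get1DBlocksAnd2DGridsMoe for a large row count.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t rows = 300000;           // e.g. num_rows * k
  int64_t blocks_x = rows, blocks_y = 1;
  if (blocks_x > 1024) {                 // same rule as Get1DBlocksAnd2DGridsMoe
    blocks_y = 256;
    blocks_x = (blocks_x + blocks_y - 1) / blocks_y;
  }
  // blocks_x * blocks_y >= rows, so every row maps to exactly one block index.
  std::printf("grid = (%lld, %lld), covered = %lld, rows = %lld\n",
              (long long)blocks_x, (long long)blocks_y,
              (long long)(blocks_x * blocks_y), (long long)rows);
  return 0;
}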
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing
// the output in the softmax kernel when we extend this module to support
// expert-choice routing.
template <typename T, int TPB>
__launch_bounds__(TPB) __global__
    void group_moe_softmax(const T* input,
                           T* output,
                           T* softmax_max_prob,
                           const int64_t num_cols,
                           const int64_t softmax_num_rows) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;

  __shared__ float normalizing_factor;
  __shared__ float float_max;
  __shared__ float max_out;

  int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
  if (globalIdx >= softmax_num_rows) {
    return;
  }
  const int64_t thread_row_offset = globalIdx * num_cols;

  cub::Sum sum;
  float threadData(-FLT_MAX);

  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int idx = thread_row_offset + ii;
    threadData = max(static_cast<float>(input[idx]), threadData);
  }

  const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
  if (threadIdx.x == 0) {
    float_max = maxElem;
  }
  __syncthreads();

  threadData = 0;

  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int idx = thread_row_offset + ii;
    threadData += exp((static_cast<float>(input[idx]) - float_max));
  }

  const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);

  if (threadIdx.x == 0) {
    normalizing_factor = 1.f / Z;
  }
  __syncthreads();

  threadData = 0;

  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int idx = thread_row_offset + ii;
    const float val =
        exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
    output[idx] = T(val);
    threadData = max(static_cast<float>(T(val)), threadData);
  }

  const float maxOut = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
  if (threadIdx.x == 0) {
    // group max probs
    max_out = 1.f / maxOut;
    softmax_max_prob[globalIdx] = T(max_out);
  }
  __syncthreads();

  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int idx = thread_row_offset + ii;
    // group softmax normalization
    output[idx] = output[idx] * static_cast<T>(max_out);
  }
}

template <typename T, int TPB>
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
                                                 T* output,
                                                 int* indices,
                                                 int* source_rows,
                                                 T* softmax_max_prob,
                                                 const int64_t num_experts,
                                                 const int64_t k,
                                                 const int64_t num_rows) {
  using cub_kvp = cub::KeyValuePair<int, T>;
  using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;

  cub_kvp thread_kvp;
  cub::ArgMax arg_max;

  const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
  if (block_row >= num_rows) {
    return;
  }

  const bool should_process_row = true;
  const int thread_read_offset = block_row * num_experts;

  for (int k_idx = 0; k_idx < k; ++k_idx) {
    thread_kvp.key = 0;
    thread_kvp.value = T(-1.f);  // This is OK because inputs are probabilities

    cub_kvp inp_kvp;
    for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
      const int idx = thread_read_offset + expert;
      inp_kvp.key = expert;
      inp_kvp.value = inputs_after_softmax[idx];

      for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
        const int prior_winning_expert = indices[k * block_row + prior_k];

        if (prior_winning_expert == expert) {
          inp_kvp = thread_kvp;
        }
      }

      thread_kvp = arg_max(inp_kvp, thread_kvp);
    }

    const cub_kvp result_kvp =
        BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
    if (threadIdx.x == 0) {
      const int idx = k * block_row + k_idx;
      // restore normalized probes
      output[idx] = result_kvp.value / T(softmax_max_prob[idx]);
      indices[idx] = should_process_row ? result_kvp.key : num_experts;
      source_rows[idx] = k_idx * num_rows + block_row;
    }
    __syncthreads();
  }
}
template <typename T, int TPB>
__launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
                                                   T* output,
                                                   const int64_t num_cols,
                                                   const int64_t num_rows) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;

  __shared__ float normalizing_factor;
  __shared__ float float_max;

  int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
  if (globalIdx >= num_rows) {
    return;
  }
  const int64_t thread_row_offset = globalIdx * num_cols;

  cub::Sum sum;
  float threadData(-FLT_MAX);

  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int idx = thread_row_offset + ii;
    threadData = max(static_cast<float>(input[idx]), threadData);
  }

  const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
  if (threadIdx.x == 0) {
    float_max = maxElem;
  }
  __syncthreads();

  threadData = 0;

  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int idx = thread_row_offset + ii;
    threadData += exp((static_cast<float>(input[idx]) - float_max));
  }

  const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);

  if (threadIdx.x == 0) {
    normalizing_factor = 1.f / Z;
  }
  __syncthreads();

  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int idx = thread_row_offset + ii;
    const float val =
        exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
    output[idx] = T(val);
  }
}

template <typename T, int TPB>
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
                                                 T* output,
                                                 int* indices,
                                                 int* source_rows,
                                                 const int64_t num_experts,
                                                 const int64_t k,
                                                 const int64_t num_rows) {
  using cub_kvp = cub::KeyValuePair<int, T>;
  using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;

  cub_kvp thread_kvp;
  cub::ArgMax arg_max;

  const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
  if (block_row >= num_rows) {
    return;
  }

  const bool should_process_row = true;
  const int thread_read_offset = block_row * num_experts;

  for (int k_idx = 0; k_idx < k; ++k_idx) {
    thread_kvp.key = 0;
    thread_kvp.value = T(-1.f);  // This is OK because inputs are probabilities

    cub_kvp inp_kvp;
    for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
      const int idx = thread_read_offset + expert;
      inp_kvp.key = expert;
      inp_kvp.value = inputs_after_softmax[idx];

      for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
        const int prior_winning_expert = indices[k * block_row + prior_k];

        if (prior_winning_expert == expert) {
          inp_kvp = thread_kvp;
        }
      }

      thread_kvp = arg_max(inp_kvp, thread_kvp);
    }

    const cub_kvp result_kvp =
        BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
    if (threadIdx.x == 0) {
      const int idx = k * block_row + k_idx;
      output[idx] = result_kvp.value;
      indices[idx] = should_process_row ? result_kvp.key : num_experts;
      source_rows[idx] = k_idx * num_rows + block_row;
    }
    __syncthreads();
  }
}
// ====================== TopK softmax things ===============================

/*
  A Top-K gating softmax written to exploit when the number of experts in the
  MoE layers are a small power of 2. This allows us to cleanly share the rows
  among the threads in a single warp and eliminate communication between warps
  (so no need to use shared mem).

  It fuses the softmax, max and argmax into a single kernel.

  Limitations:
  1) This implementation is intended for when the number of experts is a small
     power of 2.
  2) This implementation assumes k is small, but will work for any k.
*/

template <typename T,
          int VPT,
          int NUM_EXPERTS,
          int WARPS_PER_CTA,
          int BYTES_PER_LDG>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
    void topk_gating_softmax(const T* input,
                             T* output,
                             const int64_t num_rows,
                             int* indices,
                             int* source_rows,
                             const int64_t k) {
  // We begin by enforcing compile time assertions and setting up compile time
  // constants.
  static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
  static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS),
                "NUM_EXPERTS must be power of 2");
  static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG),
                "BYTES_PER_LDG must be power of 2");
  static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");

  // Number of bytes each thread pulls in per load
  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
  static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
  static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
  static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;

  // Restrictions based on previous section.
  static_assert(
      VPT % ELTS_PER_LDG == 0,
      "The elements per thread must be a multiple of the elements per ldg");
  static_assert(WARP_SIZE % THREADS_PER_ROW == 0,
                "The threads per row must cleanly divide the threads per warp");
  static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW),
                "THREADS_PER_ROW must be power of 2");
  static_assert(THREADS_PER_ROW <= WARP_SIZE,
                "THREADS_PER_ROW can be at most warp size");

  // We have NUM_EXPERTS elements per row. We specialize for small #experts
  static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
  static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
  static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;

  // Restrictions for previous section.
  static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0,
                "The elts per row must cleanly divide the total elt per warp");

  // ===================== From this point, we finally start computing run-time
  // variables. ========================

  // Compute CTA and warp rows. We pack multiple rows into a single warp, and a
  // block contains WARPS_PER_CTA warps. This, each block processes a chunk of
  // rows. We start by computing the start row for each block.
  const int cta_base_row = blockIdx.x * ROWS_PER_CTA;

  // Now, using the base row per thread block, we compute the base row per warp.
  const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;

  // The threads in a warp are split into sub-groups that will work on a row.
  // We compute row offset for each thread sub-group
  const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
  const int thread_row = warp_base_row + thread_row_in_warp;

  // Threads with indices out of bounds should early exit here.
  if (thread_row >= num_rows) return;
  const bool should_process_row = true;

  // We finally start setting up the read pointers for each thread. First, each
  // thread jumps to the start of the row it will read.
  const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW;

  // Now, we compute the group each thread belong to in order to determine the
  // first column to start loads.
  const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
  const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
  const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;

  // Determine the pointer type to use to read in the data depending on the
  // BYTES_PER_LDG template param. In theory, this can support all powers of 2
  // up to 16.
  using AccessType = mctlass::AlignedArray<T, ELTS_PER_LDG>;

  // Finally, we pull in the data from global mem
  mctlass::Array<T, VPT> row_chunk_input;
  AccessType* row_chunk_vec_ptr =
      reinterpret_cast<AccessType*>(&row_chunk_input);
  const AccessType* vec_thread_read_ptr =
      reinterpret_cast<const AccessType*>(thread_read_ptr);
#pragma unroll
  for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
    row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
  }

  using ComputeType = float;
  using Converter = mctlass::NumericArrayConverter<ComputeType, T, VPT>;
  Converter compute_type_converter;
  mctlass::Array<ComputeType, VPT> row_chunk =
      compute_type_converter(row_chunk_input);

  // First, we perform a max reduce within the thread. We can do the max in fp16
  // safely (I think) and just convert to float afterwards for the exp + sum
  // reduction.
  ComputeType thread_max = row_chunk[0];
#pragma unroll
  for (int ii = 1; ii < VPT; ++ii) {
    thread_max = max(thread_max, row_chunk[ii]);
  }

  // Now, we find the max within the thread group and distribute among the
  // threads. We use a butterfly reduce.
#pragma unroll
  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
    thread_max =
        max(thread_max,
            __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
  }

  // From this point, thread max in all the threads have the max within the row.
  // Now, we subtract the max from each element in the thread and take the exp.
  // We also compute the thread local sum.
  float row_sum = 0;
#pragma unroll
  for (int ii = 0; ii < VPT; ++ii) {
    row_chunk[ii] = expf(row_chunk[ii] - thread_max);
    row_sum += row_chunk[ii];
  }

  // Now, we perform the sum reduce within each thread group. Similar to the max
  // reduce, we use a bufferfly pattern.
#pragma unroll
  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
    row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
  }

  // From this point, all threads have the max and the sum for their rows in the
  // thread_max and thread_sum variables respectively. Finally, we can scale the
  // rows for the softmax. Technically, for top-k gating we don't need to
  // compute the entire softmax row. We can likely look at the maxes and only
  // compute for the top-k values in the row. However, this kernel will likely
  // not be a bottle neck and it seems better to closer match torch and find the
  // argmax after computing the softmax.
  const float reciprocal_row_sum = 1.f / row_sum;

#pragma unroll
  for (int ii = 0; ii < VPT; ++ii) {
    row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
  }

  // Now, softmax_res contains the softmax of the row chunk. Now, I want to find
  // the topk elements in each row, along with the max index.
  int start_col = first_elt_read_by_thread;
  static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;

  for (int k_idx = 0; k_idx < k; ++k_idx) {
    // First, each thread does the local argmax
    float max_val = row_chunk[0];
    int expert = start_col;
#pragma unroll
    for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD;
         ++ldg, col += COLS_PER_GROUP_LDG) {
#pragma unroll
      for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {
        float val = row_chunk[ldg * ELTS_PER_LDG + ii];

        // No check on the experts here since columns with the smallest index
        // are processed first and only updated if > (not >=)
        if (val > max_val) {
          max_val = val;
          expert = col + ii;
        }
      }
    }

    // Now, we perform the argmax reduce. We use the butterfly pattern so threads
    // reach consensus about the max. This will be useful for K > 1 so that the
    // threads can agree on "who" had the max value. That thread can then blank out
    // their max with -inf and the warp can run more iterations...
#pragma unroll
    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
      float other_max =
          __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
      int other_expert =
          __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);

      // We want lower indices to "win" in every thread so we break ties this
      // way
      if (other_max > max_val ||
          (other_max == max_val && other_expert < expert)) {
        max_val = other_max;
        expert = other_expert;
      }
    }

    // Write the max for this k iteration to global memory.
    if (thread_group_idx == 0) {
      // The lead thread from each sub-group will write out the final results to
      // global memory. (This will be a single) thread per row of the
      // input/output matrices.
      const int idx = k * thread_row + k_idx;
      output[idx] = T(max_val);
      indices[idx] = should_process_row ? expert : NUM_EXPERTS;
      source_rows[idx] = k_idx * num_rows + thread_row;
    }

    // Finally, we clear the value in the thread with the current max if there
    // is another iteration to run.
    if (k_idx + 1 < k) {
      const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
      const int thread_to_clear_in_group =
          (expert / ELTS_PER_LDG) % THREADS_PER_ROW;

      // Only the thread in the group which produced the max will reset the
      // "winning" value to -inf.
      if (thread_group_idx == thread_to_clear_in_group) {
        const int offset_for_expert = expert % ELTS_PER_LDG;
        // Safe to set to any negative value since row_chunk values must be
        // between 0 and 1.
        row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] =
            ComputeType(-10000.f);
      }
    }
  }
}
namespace detail {
// Constructs some constants needed to partition the work across threads at
// compile time.
template <typename T, int EXPERTS, int BYTES_PER_LDG>
struct TopkConstants {
  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
  static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 ||
                    EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0,
                "");
  static constexpr int VECs_PER_THREAD =
      std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
  static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
  static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
  static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
};
}  // namespace detail

template <typename T, int EXPERTS, int WARPS_PER_TB>
void topk_gating_softmax_launcher_helper(const T* input,
                                         T* output,
                                         int* indices,
                                         int* source_row,
                                         const int64_t num_rows,
                                         const int64_t num_experts,
                                         const int64_t k,
                                         cudaStream_t stream) {
  static constexpr uint64_t MAX_BYTES_PER_LDG = 16;
  static constexpr int BYTES_PER_LDG =
      std::min(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS);
  using Constants = detail::TopkConstants<T, EXPERTS, BYTES_PER_LDG>;
  static constexpr int VPT = Constants::VPT;
  static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
  const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
  const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

  dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
  topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG>
      <<<num_blocks, block_dim, 0, stream>>>(
          input, output, num_rows, indices, source_row, k);
}

template <typename T>
void topk_gating_softmax_kernelLauncher(const T* input,
                                        T* output,
                                        T* softmax,
                                        int* indices,
                                        int* source_row,
                                        T* softmax_max_prob,
                                        const int64_t num_rows,
                                        const int64_t num_experts,
                                        const int64_t k,
                                        const bool group_moe,
                                        cudaStream_t stream,
                                        const bool topk_only_mode = false) {
  if (topk_only_mode) {
    static constexpr int TPB = 256;
    const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
    moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
        input, output, indices, source_row, num_experts, k, num_rows);
    return;
  }
  static constexpr int WARPS_PER_TB = 4;

#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N)                                   \
  case N: {                                                                    \
    topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>(                   \
        input, output, indices, source_row, num_rows, num_experts, k, stream); \
    break;                                                                     \
  }
  switch (num_experts) {
    LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
    LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
    LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)
    LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16)
    LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32)
    LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64)
    LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128)
    LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256)

    default: {
      static constexpr int TPB = 256;
      if (group_moe) {
        const int group_experts = num_experts / k;
        const int softmax_num_rows = num_rows * k;
        const auto config_softmax = Get1DBlocksAnd2DGridsMoe(softmax_num_rows);
        group_moe_softmax<T, TPB>
            <<<config_softmax.block_per_grid, TPB, 0, stream>>>(
                input, softmax, softmax_max_prob, group_experts,
                softmax_num_rows);
        const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
        moe_top_k<T, TPB>
            <<<config_topk.block_per_grid, TPB, 0, stream>>>(
                softmax, output, indices, source_row, softmax_max_prob,
                num_experts, k, num_rows);
      } else {
        const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
        moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
            input, softmax, num_experts, num_rows);
        moe_top_k<T, TPB>
            <<<config_topk.block_per_grid, TPB, 0, stream>>>(
                softmax, output, indices, source_row, num_experts, k,
                num_rows);
      }
    }
  }
}
// ========================== Permutation things =======================================

// Duplicated and permutes rows for MoE. In addition, reverse the permutation
// map to help with finalizing routing.

// "expanded_x_row" simply means that the number of values is num_rows x k. It
// is "expanded" since we will have to duplicate some rows in the input matrix
// to match the dimensions. Duplicates will always get routed to separate
// experts in the end.

// Note that the expanded_dest_row_to_expanded_source_row map referred to here
// has indices in the range (0, k*rows_in_input - 1). However, it is set up so
// that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input all map
// to row 0 in the original matrix. Thus, to know where to read in the source
// matrix, we simply take the modulus of the expanded index.

template <typename T, int VecSize>
__global__ void initialize_moe_routing_kernel(
    const T* unpermuted_input,
    T* permuted_output,
    const int* expanded_dest_row_to_expanded_source_row,
    int* expanded_source_row_to_expanded_dest_row,
    const int64_t num_rows,
    const int64_t active_rows,
    const int64_t cols,
    const int64_t num_rows_k) {
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;

  // Reverse permutation map.
  // I do this so that later, we can use the source -> dest map to do the k-way
  // reduction and unpermuting. I need the reverse map for that reduction to
  // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
  // thread block will be responsible for all k summations.
  const int expanded_dest_row = blockIdx.x + blockIdx.y * gridDim.x;
  if (expanded_dest_row >= num_rows_k) return;
  const int expanded_source_row =
      expanded_dest_row_to_expanded_source_row[expanded_dest_row];
  if (threadIdx.x == 0) {
    expanded_source_row_to_expanded_dest_row[expanded_source_row] =
        expanded_dest_row;
  }

  if ((blockIdx.x + blockIdx.y * gridDim.x) < active_rows) {
    // Duplicate and permute rows
    const int source_row = expanded_source_row % num_rows;

    const T* source_row_ptr = unpermuted_input + source_row * cols;
    T* dest_row_ptr = permuted_output + expanded_dest_row * cols;

    for (int tid = threadIdx.x * VecSize; tid < cols;
         tid += blockDim.x * VecSize) {
      // dest_row_ptr[tid] = source_row_ptr[tid];
      Load<T, VecSize>(&source_row_ptr[tid], &src_vec);
      Store<T, VecSize>(src_vec, &dest_row_ptr[tid]);
    }
  }
}

template <typename T>
void initialize_moe_routing_kernelLauncher(
    const T* unpermuted_input,
    T* permuted_output,
    const int* expanded_dest_row_to_expanded_source_row,
    int* expanded_source_row_to_expanded_dest_row,
    const int64_t num_rows,
    const int64_t active_rows,
    const int64_t cols,
    const int64_t k,
    cudaStream_t stream) {
  const int threads = std::min(cols, int64_t(1024));
  constexpr int max_pack_size = 16 / sizeof(T);
  const auto config_initialize = Get1DBlocksAnd2DGridsMoe(num_rows * k);
  if (cols % max_pack_size == 0) {
    initialize_moe_routing_kernel<T, max_pack_size>
        <<<config_initialize.block_per_grid, threads, 0, stream>>>(
            unpermuted_input, permuted_output,
            expanded_dest_row_to_expanded_source_row,
            expanded_source_row_to_expanded_dest_row,
            num_rows, k * active_rows, cols, num_rows * k);
  } else {
    initialize_moe_routing_kernel<T, 1>
        <<<config_initialize.block_per_grid, threads, 0, stream>>>(
            unpermuted_input, permuted_output,
            expanded_dest_row_to_expanded_source_row,
            expanded_source_row_to_expanded_dest_row,
            num_rows, k * active_rows, cols, num_rows * k);
  }
}

// ============================== Infer GEMM sizes =================================
__device__ inline int find_total_elts_leq_target(int* sorted_indices,
                                                 const int64_t arr_length,
                                                 const int64_t target) {
  int64_t low = 0, high = arr_length - 1, target_location = -1;
  while (low <= high) {
    int64_t mid = (low + high) / 2;

    if (sorted_indices[mid] > target) {
      high = mid - 1;
    } else {
      low = mid + 1;
      target_location = mid;
    }
  }
  return target_location + 1;
}
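The comments above describe how expanded row indices are laid out so that the k duplicates of one token can later be reduced without atomics: expanded source row k_idx * num_rows + row maps back to the original row via the modulus. A tiny host-side sketch of that index arithmetic, not part of the commit, for illustration:

// Host-side illustration (assumption: not part of the commit) of the
// expanded-row to original-row mapping used by the routing kernels.
#include <cstdio>

int main() {
  const int num_rows = 4, k = 2;
  for (int k_idx = 0; k_idx < k; ++k_idx) {
    for (int row = 0; row < num_rows; ++row) {
      const int expanded_source_row = k_idx * num_rows + row;
      // Every k-th duplicate of a token folds back onto the same original row.
      std::printf("expanded %d -> original %d\n",
                  expanded_source_row, expanded_source_row % num_rows);
    }
  }
  return 0;
}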
void compute_total_rows_before_expert(int* sorted_indices,
                                      const int64_t total_indices,
                                      const int64_t num_experts,
                                      int32_t* total_rows_before_expert,
                                      cudaStream_t stream);

// Final kernel to unpermute and scale
// This kernel unpermutes the original data, does the k-way reduction and
// performs the final skip connection.
template <typename T, int RESIDUAL_NUM>
__global__ void finalize_moe_routing_kernel(
    const T* expanded_permuted_rows,
    T* reduced_unpermuted_output,
    const T* bias,
    const float* scales,
    const int* expanded_source_row_to_expanded_dest_row,
    const int* expert_for_source_row,
    const int64_t cols,
    const int64_t k,
    const int64_t compute_bias,
    const bool norm_topk_prob,
    const float routed_scaling_factor,
    const int64_t num_rows) {
  const int original_row = blockIdx.x + blockIdx.y * gridDim.x;
  // const int original_row = blockIdx.x;
  // const int num_rows = gridDim.x;
  if (original_row >= num_rows) return;
  T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols;

  for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) {
    T thread_output{0.f};
    float row_rescale{0.f};
    for (int k_idx = 0; k_idx < k; ++k_idx) {
      const int expanded_original_row = original_row + k_idx * num_rows;
      const int expanded_permuted_row =
          expanded_source_row_to_expanded_dest_row[expanded_original_row];

      const int64_t k_offset = original_row * k + k_idx;
      const float row_scale = scales[k_offset];
      row_rescale = row_rescale + row_scale;

      const T* expanded_permuted_rows_row_ptr =
          expanded_permuted_rows + expanded_permuted_row * cols;

      const int expert_idx = expert_for_source_row[k_offset];
      const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr;
      const T bias_value = bias_ptr ? bias_ptr[tid] : T{0.f};

      thread_output =
          static_cast<float>(thread_output) +
          row_scale * static_cast<float>(
                          expanded_permuted_rows_row_ptr[tid] +
                          bias_value *
                              static_cast<T>(static_cast<float>(compute_bias)));
    }

    thread_output = static_cast<float>(thread_output) /
                    (norm_topk_prob ? row_rescale : 1.0f) *
                    routed_scaling_factor;
    reduced_row_ptr[tid] = thread_output;
  }
}

template <typename T>
void finalize_moe_routing_kernelLauncher(
    const T* expanded_permuted_rows,
    T* reduced_unpermuted_output,
    const T* bias,
    const float* scales,
    const int* expanded_source_row_to_expanded_dest_row,
    const int* expert_for_source_row,
    const int64_t num_rows,
    const int64_t cols,
    const int64_t k,
    const int64_t compute_bias,
    const bool norm_topk_prob,
    const float routed_scaling_factor,
    cudaStream_t stream) {
  const int threads = std::min(cols, int64_t(1024));
  const auto config_final = Get1DBlocksAnd2DGridsMoe(num_rows);

  finalize_moe_routing_kernel<T, 1>
      <<<config_final.block_per_grid, threads, 0, stream>>>(
          expanded_permuted_rows, reduced_unpermuted_output, bias, scales,
          expanded_source_row_to_expanded_dest_row, expert_for_source_row,
          cols, k, compute_bias, norm_topk_prob, routed_scaling_factor,
          num_rows);
}
// ========================= TopK Softmax specializations ===========================
template void topk_gating_softmax_kernelLauncher(const float*, float*, float*, int*, int*, float*,
                                                 const int64_t, const int64_t, const int64_t,
                                                 const bool, cudaStream_t, const bool);
template void topk_gating_softmax_kernelLauncher(const half*, half*, half*, int*, int*, half*,
                                                 const int64_t, const int64_t, const int64_t,
                                                 const bool, cudaStream_t, const bool);
#ifdef PADDLE_CUDA_BF16
template void topk_gating_softmax_kernelLauncher(const __nv_bfloat16*, __nv_bfloat16*, __nv_bfloat16*,
                                                 int*, int*, __nv_bfloat16*,
                                                 const int64_t, const int64_t, const int64_t,
                                                 const bool, cudaStream_t, const bool);
#endif
// ===================== Specializations for init routing =========================
template void initialize_moe_routing_kernelLauncher(const float*, float*, const int*, int*,
                                                    const int64_t, const int64_t, const int64_t,
                                                    const int64_t, cudaStream_t);
template void initialize_moe_routing_kernelLauncher(const half*, half*, const int*, int*,
                                                    const int64_t, const int64_t, const int64_t,
                                                    const int64_t, cudaStream_t);
#ifdef PADDLE_CUDA_BF16
template void initialize_moe_routing_kernelLauncher(const __nv_bfloat16*, __nv_bfloat16*, const int*, int*,
                                                    const int64_t, const int64_t, const int64_t,
                                                    const int64_t, cudaStream_t);
#endif
// ==================== Specializations for final routing ===================================
template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*,
                                                  const int*, const int*,
                                                  const int64_t, const int64_t, const int64_t, const int64_t,
                                                  const bool, const float, cudaStream_t);
template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const float*,
                                                  const int*, const int*,
                                                  const int64_t, const int64_t, const int64_t, const int64_t,
                                                  const bool, const float, cudaStream_t);
#ifdef PADDLE_CUDA_BF16
template void finalize_moe_routing_kernelLauncher(const __nv_bfloat16*, __nv_bfloat16*, const __nv_bfloat16*,
                                                  const float*, const int*, const int*,
                                                  const int64_t, const int64_t, const int64_t, const int64_t,
                                                  const bool, const float, cudaStream_t);
#endif
custom_ops/metax_ops/mc_fused_moe_helper.h (new file, 417 lines)
@@ -0,0 +1,417 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#include "mctlass/numeric_conversion.h"
|
||||
#include "mctlassEx/mctlassEx.h"
|
||||
#include "fused_moe_helper.h"
|
||||
|
||||
|
||||
template <typename ElementA, typename ElementB, typename ElementC>
|
||||
void mc_grouped_gemm_basic_kernel(
|
||||
const ElementA* ptrA,
|
||||
mctlassExOrder_t majorA,
|
||||
const ElementB* ptrB,
|
||||
mctlassExOrder_t majorB,
|
||||
const ElementA* ptrScale,
|
||||
const ElementA* ptrBias,
|
||||
ElementC* ptrC,
|
||||
mctlassExOrder_t majorC,
|
||||
const int *ptrSegInd,
|
||||
int numExperts,
|
||||
int m, // expanded_active_expert_rows
|
||||
int n, // inter_dim
|
||||
int k, // hidden_size
|
||||
mcStream_t stream) {
|
||||
mctlassExHandle_t handle;
|
||||
mctlassExHandleCreate(&handle);
|
||||
|
||||
int* ptrMNumTilesInd;
|
||||
mcMallocAsync((void**)&ptrMNumTilesInd, sizeof(int) * numExperts, stream);
|
||||
|
||||
mctlassExMatrixLayout_t matLayoutA;
|
||||
mctlassExMatrixLayout_t matLayoutB;
|
||||
mctlassExMatrixLayout_t matLayoutC;
|
||||
|
||||
// mat A: (m, k)
|
||||
mctlassExMatrixLayoutCreate(&matLayoutA, mctlassExDataType::MCTLASS_EX_BF16, m, k, k);
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorA, sizeof(mctlassExOrder_t));
|
||||
// mat B: (num_experts, n, k)
|
||||
mctlassExMatrixLayoutCreate(&matLayoutB, mctlassExDataType::MCTLASS_EX_INT8, k, n, k);
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorB, sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts, sizeof(int));
|
||||
// mat C: (m, n)
|
||||
mctlassExMatrixLayoutCreate(&matLayoutC, mctlassExDataType::MCTLASS_EX_BF16, m, n, n);
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorC, sizeof(mctlassExOrder_t));
|
||||
// bias: (num_experts, n)
|
||||
// scale: (num, n)
|
||||
|
||||
mctlassExDesc_t mctlass_desc;
|
||||
mctlassExCreateDesc(&mctlass_desc);
|
||||
mctlassExDataType input_type = mctlassExDataType::MCTLASS_EX_BF16;
|
||||
mctlassExDataType scale_type = mctlassExDataType::MCTLASS_EX_INT8;
|
||||
mctlassExDataType compute_type = mctlassExDataType::MCTLASS_EX_FP32;
|
||||
mctlassExEpilogueType epilogue_type = mctlassExEpilogueType::MCTLASS_EX_GEMM_DEFAULT;
|
||||
if (ptrBias) {
|
||||
epilogue_type = mctlassExEpilogueType::MCTLASS_EX_GEMM_BIAS_PERGROUP;
|
||||
}
|
||||
// set scale
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_B_SCALE_POINTER,
|
||||
&ptrScale, sizeof(ptrScale));
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_B_SCALE_TYPE,
|
||||
&scale_type, sizeof(mctlassExDataType));
|
||||
// set bias
|
||||
if (ptrBias) {
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_BIAS_POINTER,
|
||||
&ptrBias, sizeof(ptrBias));
|
||||
}
|
||||
// set coumpute type
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_COMPUTE_TYPE,
|
||||
&compute_type, sizeof(mctlassExDataType));
|
||||
// set epilogue type
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_EPILOGUE_TYPE,
|
||||
&epilogue_type, sizeof(mctlassExEpilogueType));
|
||||
|
||||
const mctlassExContiguousGroupedGemmAlgo_t algo = mctlassExContiguousGroupedGemmAlgo_t::MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_SEGPTR;
|
||||
int blocksizeM = mctlassExContiguousGroupedGemmGetBlocksizeM(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo);
|
||||
mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, ptrSegInd, ptrMNumTilesInd, numExperts, blocksizeM);
|
||||
|
||||
mctlassExContiguousGroupedGemmBasic(handle, mctlass_desc,
|
||||
ptrA, matLayoutA,
|
||||
ptrB, matLayoutB,
|
||||
ptrC, matLayoutC,
|
||||
ptrSegInd, nullptr, ptrMNumTilesInd,
|
||||
&algo, nullptr, 0, stream);
|
||||
|
||||
mctlassExHandleDestroy(handle);
|
||||
mctlassExMatrixLayoutDestroy(matLayoutA);
|
||||
mctlassExMatrixLayoutDestroy(matLayoutB);
|
||||
mctlassExMatrixLayoutDestroy(matLayoutC);
|
||||
mctlassExDestroyDesc(mctlass_desc);
|
||||
mcFreeAsync(ptrMNumTilesInd, stream);
|
||||
}

template <typename T, typename ElementA, typename ElementB, typename ElementC>
class McMoeHelper {
 public:
  McMoeHelper(const std::string gemm_method) : gemm_method_(gemm_method) {}

  // -------- getWorkspaceSize -------- //
  template <typename KeyT>
  size_t getWorkspaceSize(const int64_t num_rows,
                          const int64_t hidden_size,
                          const int64_t inter_size,
                          const int64_t num_experts,
                          const int64_t k) {
    const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
    const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
    const size_t padded_experts = AlignTo16(num_experts);
    const size_t num_moe_inputs = AlignTo16(k * num_rows);
    // softmax output, permuted_rows and permuted_experts have moved to outside
    // of moe kernel, allocate them in Encoder or Decoder before invoking
    // FfnLayer forward.
    size_t total_ws_bytes =
        5 * num_moe_inputs *
        sizeof(int);  // source_rows_, permuted_rows_, permuted_experts_
    total_ws_bytes += buf_size * sizeof(KeyT);  // permuted_data
    total_ws_bytes +=
        padded_experts * sizeof(int32_t);  // Hold total_rows_before_expert_

    const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
    const size_t sorter_ws_size_bytes =
        AlignTo16(sorter_.getWorkspaceSize(num_rows));
    sorter_.update_num_experts(num_experts);

    int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
    if (sorter_ws_size_bytes > bytes_for_fc1_result) {
      int64_t remaining_bytes =
          AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
      bytes_for_intermediate_and_sorting += remaining_bytes;
    }

    total_ws_bytes +=
        bytes_for_intermediate_and_sorting;  // intermediate (fc1) output + cub
                                             // sorting workspace

    int64_t num_softmax_outs = 0;
    const bool is_pow_2 =
        (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
    if (!is_pow_2 || num_experts > 256) {
      num_softmax_outs = AlignTo16(num_rows * num_experts);
    }

    total_ws_bytes += num_softmax_outs * sizeof(float);

    return total_ws_bytes;
  }

  void computeFFN(const paddle::Tensor *input,
                  const paddle::Tensor *gate_weight,
                  const paddle::Tensor *ffn1_weight,
                  const paddle::Tensor *ffn1_scale,
                  const paddle::Tensor *ffn1_bias,
                  const paddle::Tensor *ffn2_weight,
                  const paddle::Tensor *ffn2_scale,
                  const paddle::Tensor *ffn2_bias,
                  const paddle::Tensor *moe_token_type_ids,
                  const int moe_topk,
                  const bool group_moe,
                  const bool norm_topk_prob,
                  const float routed_scaling_factor,
                  const std::string moe_type,
                  paddle::Tensor *output) {
    auto *input_activations = input->data<T>();
    auto *gating_weights = gate_weight->data<float>();
    const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
    const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;

    auto *output_ = output->data<T>();
    auto stream = input->stream();
    auto place = input->place();
    auto input_type = input->dtype();

    auto input_dims = input->dims();
    auto ffn1_dims = ffn1_weight->dims();
    int64_t token_num = 0;
    if (input_dims.size() == 3) {
      token_num = input_dims[0] * input_dims[1];
    } else {
      token_num = input_dims[0];
    }
    const int64_t num_rows = token_num;

    const int64_t hidden_size = ffn1_dims[2];
    int64_t inter_dim = 0;
    if (moe_type == "qkv") {
      inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4];
    } else {
      inter_dim = ffn1_dims[1];
    }

    // if (gemm_method == "weight_only_int4") {
    //   inter_dim = inter_dim * 2;
    // }

    const int64_t inter_size = inter_dim;
    const int64_t num_experts = ffn1_dims[0];
    const int64_t k = moe_topk;

    int64_t bytes =
        getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);

    // Pointers
    int *expert_for_source_row;
    int *source_rows_;
    int *permuted_rows_;
    int *permuted_experts_;
    int *expanded_source_row_to_expanded_dest_row;

    T *permuted_data_;
    int32_t *total_rows_before_expert_;
    T *fc1_result_;
    float *softmax_out_;

    paddle::Tensor ws_ptr_tensor =
        GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
    int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();

    const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
    const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
    const int64_t padded_experts = AlignTo16(num_experts);
    const int64_t num_moe_inputs = AlignTo16(k * num_rows);

    expert_for_source_row = reinterpret_cast<int *>(ws_ptr);
    source_rows_ = expert_for_source_row + num_moe_inputs;
    permuted_rows_ = source_rows_ + num_moe_inputs;
    permuted_experts_ = permuted_rows_ + num_moe_inputs;
    expanded_source_row_to_expanded_dest_row =
        permuted_experts_ + num_moe_inputs;
    permuted_data_ = reinterpret_cast<T *>(
        expanded_source_row_to_expanded_dest_row + num_moe_inputs);
    total_rows_before_expert_ =
        reinterpret_cast<int32_t *>(permuted_data_ + buf_size);
    fc1_result_ =
        reinterpret_cast<T *>(total_rows_before_expert_ + padded_experts);

    const bool is_pow_2 =
        (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
    if (!is_pow_2 || num_experts > 256) {
      softmax_out_ = reinterpret_cast<float *>(fc1_result_ + interbuf_size);
    } else {
      softmax_out_ = nullptr;
    }

    paddle::Tensor expert_scales_float_tensor =
        GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
    float *expert_scales_float = expert_scales_float_tensor.data<float>();

    float *softmax_max_prob = nullptr;
    if (group_moe) {
      paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
          {num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
      // (TODO: check fill success ?)
      paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
      softmax_max_prob = softmax_max_prob_tensor.data<float>();
    }

    paddle::Tensor fc1_out_tensor =
        GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
    T *fc1_out = fc1_out_tensor.data<T>();

    auto input_cast_tensor =
        paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
    auto gate_tensor =
        paddle::experimental::matmul(input_cast_tensor, *gate_weight);
    float *gating_output = gate_tensor.data<float>();

    if (moe_token_type_ids) {
      auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
      moe_token_type_ids_kernelLauncher<float>(
          gating_output, moe_token_type_ids_out, num_rows, num_experts, k, stream);
    }

    topk_gating_softmax_kernelLauncher<float>(
        gating_output, expert_scales_float, softmax_out_, expert_for_source_row,
        source_rows_, softmax_max_prob, num_rows, num_experts, k, group_moe, stream);

    const int64_t sorter_ws_size_bytes =
        AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));

    sorter_.run(fc1_result_, sorter_ws_size_bytes, expert_for_source_row,
                permuted_experts_, source_rows_, permuted_rows_, k * num_rows,
                false, stream);

    initialize_moe_routing_kernelLauncher(
        input_activations, permuted_data_, permuted_rows_,
        expanded_source_row_to_expanded_dest_row, num_rows, num_rows,
        hidden_size, k, stream);

    const int64_t expanded_active_expert_rows = k * num_rows;

    compute_total_rows_before_expert(permuted_experts_,
                                     expanded_active_expert_rows, num_experts,
                                     total_rows_before_expert_, stream);

    mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ROWMAJOR_ORDER;
    mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_COLUMNMAJOR_ORDER;

    mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
        reinterpret_cast<const ElementA *>(permuted_data_),
        row_major,
        reinterpret_cast<const ElementB *>(ffn1_weight->data<ElementB>()),
        column_major,
        reinterpret_cast<const ElementA *>(ffn1_scale->data<T>()),
        reinterpret_cast<const ElementA *>(fc1_expert_biases),
        reinterpret_cast<ElementC *>(fc1_out),
        row_major,
        total_rows_before_expert_,
        num_experts,
        expanded_active_expert_rows,
        inter_size,
        hidden_size,
        stream);

    if (moe_type == "ffn") {
      auto act_out_tensor =
          paddle::experimental::swiglu(fc1_out_tensor, nullptr);
      auto act_out = act_out_tensor.data<T>();

      paddle::Tensor fc2_output_tensor =
          GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
      T *fc2_result = fc2_output_tensor.data<T>();

      mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
          reinterpret_cast<const ElementA *>(act_out),
          row_major,
          reinterpret_cast<const ElementB *>(ffn2_weight->data<ElementB>()),
          column_major,
          reinterpret_cast<const ElementA *>(ffn2_scale->data<T>()),
          nullptr,
          reinterpret_cast<ElementC *>(fc2_result),
          row_major,
          total_rows_before_expert_,
          num_experts,
          expanded_active_expert_rows,
          hidden_size,
          inter_size / 2,
          stream);

      finalize_moe_routing_kernelLauncher(
          fc2_result, output_, fc2_expert_biases,
          reinterpret_cast<float *>(expert_scales_float),
          expanded_source_row_to_expanded_dest_row, expert_for_source_row,
          num_rows, hidden_size, k, static_cast<int>(1), norm_topk_prob,
          routed_scaling_factor, stream);
    } else {
      finalize_moe_routing_kernelLauncher(
          fc1_out,  // fc2_result,
          output_,
          fc1_expert_biases,  // fc2_expert_biases,
          reinterpret_cast<float *>(expert_scales_float),
          expanded_source_row_to_expanded_dest_row, expert_for_source_row,
          num_rows, inter_size, k, static_cast<int>(0), norm_topk_prob,
          routed_scaling_factor, stream);
    }
  }

 private:
  std::string gemm_method_;
  CubKeyValueSorter sorter_;
};
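For reference, a rough Python re-derivation of the workspace layout computed by `getWorkspaceSize` above (illustrative numbers; the softmax scratch for non power-of-two expert counts is omitted):

```python
def align16(x):
    return (x + 15) // 16 * 16

def moe_workspace_bytes(num_rows, hidden_size, inter_size, num_experts, k,
                        dtype_bytes=2, sorter_ws_bytes=0):
    num_moe_inputs = align16(k * num_rows)
    total = 5 * num_moe_inputs * 4                               # routing index arrays (int)
    total += align16(k * num_rows * hidden_size) * dtype_bytes   # permuted_data_
    total += align16(num_experts) * 4                            # total_rows_before_expert_
    fc1_bytes = align16(k * num_rows * inter_size) * dtype_bytes
    total += fc1_bytes + max(0, align16(sorter_ws_bytes - fc1_bytes))
    return total

print(moe_workspace_bytes(num_rows=128, hidden_size=4096, inter_size=11008,
                          num_experts=64, k=8))
```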
274
custom_ops/metax_ops/moe_dispatch.cu
Normal file
@@ -0,0 +1,274 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wunused-function"
#pragma once

#include "fused_moe_helper.h"
#include "fused_moe_op.h"
#pragma GCC diagnostic pop

#include "helper.h"


template <paddle::DataType T>
void MoeDispatchKernel(const paddle::Tensor& input,
                       const paddle::Tensor& gating_output,
                       const int moe_topk,
                       const bool group_moe,
                       const bool topk_only_mode,
                       const int num_rows,
                       const int hidden_size,
                       const int expert_num,
                       paddle::Tensor* permute_input,
                       paddle::Tensor* tokens_expert_prefix_sum,
                       paddle::Tensor* permute_indices_per_token,
                       paddle::Tensor* top_k_weight,
                       paddle::Tensor* top_k_indices) {
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto stream = input.stream();
  auto place = input.place();

  if (group_moe) {
    // Check if expert_num is divisible by moe_topk, else throw an error
    PADDLE_ENFORCE_EQ(expert_num % moe_topk,
                      0,
                      common::errors::InvalidArgument(
                          "The number of experts (expert_num) "
                          "must be divisible by moe_topk. "
                          "Got expert_num = %d and moe_topk = %d.",
                          expert_num,
                          moe_topk));
  }

  const int num_moe_inputs = AlignTo16(num_rows * moe_topk);
  const int bytes = num_moe_inputs * sizeof(int);

  CubKeyValueSorter sorter_;
  sorter_.update_num_experts(expert_num);

  const int sorter_ws_size_bytes =
      AlignTo16(sorter_.getWorkspaceSize(moe_topk * num_rows));
  const int sort_tmp_in_out_size = num_moe_inputs * 2 * sizeof(int);

  paddle::Tensor ws_ptr_tensor =
      GetEmptyTensor({bytes + sorter_ws_size_bytes + sort_tmp_in_out_size},
                     paddle::DataType::INT8,
                     place);

  int8_t* ws_ptr = ws_ptr_tensor.data<int8_t>();
  int* source_rows_ = reinterpret_cast<int*>(ws_ptr);
  int8_t* sorter_ws_ptr = reinterpret_cast<int8_t*>(ws_ptr + bytes);
  int* permuted_experts_ =
      reinterpret_cast<int*>(sorter_ws_ptr + sorter_ws_size_bytes);
  int* permuted_rows_ = permuted_experts_ + num_moe_inputs;

  int* expert_for_source_row = top_k_indices->data<int>();

  float* softmax_max_prob = nullptr;
  if (group_moe) {
    paddle::Tensor softmax_max_prob_tensor =
        GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
    paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
    softmax_max_prob = softmax_max_prob_tensor.data<float>();
  }

  float* softmax_out_;

  const bool is_pow_2 =
      (expert_num != 0) && ((expert_num & (expert_num - 1)) == 0);

  paddle::Tensor softmax_buffer;

  if (!is_pow_2 || expert_num > 256 || group_moe) {
    softmax_buffer = GetEmptyTensor(
        {num_rows * expert_num}, paddle::DataType::FLOAT32, place);
    softmax_out_ = softmax_buffer.data<float>();
  } else {
    softmax_out_ = nullptr;
  }

  topk_gating_softmax_kernelLauncher<float>(gating_output.data<float>(),
                                            top_k_weight->data<float>(),
                                            softmax_out_,
                                            expert_for_source_row,
                                            source_rows_,
                                            softmax_max_prob,
                                            num_rows,
                                            expert_num,
                                            moe_topk,
                                            group_moe,
                                            stream,
                                            topk_only_mode);

  sorter_.run(reinterpret_cast<void*>(sorter_ws_ptr),
              sorter_ws_size_bytes,
              expert_for_source_row,
              permuted_experts_,
              source_rows_,
              permuted_rows_,
              moe_topk * num_rows,
              false,
              stream);

  initialize_moe_routing_kernelLauncher(
      input.data<data_t>(),
      permute_input->data<data_t>(),
      permuted_rows_,
      permute_indices_per_token->data<int32_t>(),
      num_rows,
      num_rows,
      hidden_size,
      moe_topk,
      stream);

  compute_total_rows_before_expert(
      permuted_experts_,
      moe_topk * num_rows,
      expert_num,
      tokens_expert_prefix_sum->data<int32_t>(),
      stream);
}


std::vector<paddle::Tensor> MoeExpertDispatch(
    const paddle::Tensor& input,
    const paddle::Tensor& gating_output,
    const int moe_topk,
    const bool group_moe,
    const bool topk_only_mode) {
  const auto input_type = input.dtype();
  auto place = input.place();
  int token_rows = 0;
  auto input_dims = input.dims();
  auto gating_dims = gating_output.dims();
  const int expert_num = gating_dims[gating_dims.size() - 1];

  if (input_dims.size() == 3) {
    token_rows = input_dims[0] * input_dims[1];
  } else {
    token_rows = input_dims[0];
  }
  const int num_rows = token_rows;
  const int hidden_size = input.dims()[input_dims.size() - 1];

  auto permute_input =
      GetEmptyTensor({moe_topk * num_rows, hidden_size}, input_type, place);
  // correspond to the weighted coefficients of the results from each expert.
  auto top_k_weight =
      GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
  auto top_k_indices =
      GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::INT32, place);

  auto tokens_expert_prefix_sum =
      GetEmptyTensor({expert_num}, paddle::DataType::INT32, place);
  auto permute_indices_per_token =
      GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);

  switch (input_type) {
    case paddle::DataType::BFLOAT16:
      MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
                                                    gating_output,
                                                    moe_topk,
                                                    group_moe,
                                                    topk_only_mode,
                                                    num_rows,
                                                    hidden_size,
                                                    expert_num,
                                                    &permute_input,
                                                    &tokens_expert_prefix_sum,
                                                    &permute_indices_per_token,
                                                    &top_k_weight,
                                                    &top_k_indices);
      break;
    // case paddle::DataType::FLOAT16:
    //   MoeDispatchKernel<paddle::DataType::FLOAT16>(input, gating_output, moe_topk,
    //                                                group_moe, topk_only_mode, num_rows,
    //                                                hidden_size, expert_num, &permute_input,
    //                                                &tokens_expert_prefix_sum,
    //                                                &permute_indices_per_token,
    //                                                &top_k_weight, &top_k_indices);
    //   break;
    default:
      PD_THROW("Only support bf16 for MoeDispatchKernel");
  }
  return {permute_input,
          tokens_expert_prefix_sum,
          permute_indices_per_token,
          top_k_weight,
          top_k_indices};
}


std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& gating_output_shape,
    const int moe_topk) {
  int token_rows = -1;

  if (input_shape.size() == 3) {
    token_rows = input_shape[0] * input_shape[1];
  } else {
    token_rows = input_shape[0];
  }
  const int expert_num = gating_output_shape[gating_output_shape.size() - 1];
  const int num_rows = token_rows;
  const int hidden_size = input_shape[input_shape.size() - 1];

  return {{moe_topk * num_rows, hidden_size},
          {expert_num},
          {moe_topk, num_rows},
          {num_rows, moe_topk},
          {num_rows, moe_topk}};
}

std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
    const paddle::DataType& input_dtype,
    const paddle::DataType& gating_output_dtype,
    const int moe_topk) {
  return {input_dtype,
          paddle::DataType::INT64,
          paddle::DataType::INT32,
          paddle::DataType::FLOAT32,
          paddle::DataType::INT32};
}


PD_BUILD_OP(moe_expert_dispatch)
    .Inputs({"input", "gating_output"})
    .Outputs({"permute_input",
              "tokens_expert_prefix_sum",
              "permute_indices_per_token",
              "top_k_weight",
              "top_k_indices"})
    .Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
    .SetKernelFn(PD_KERNEL(MoeExpertDispatch))
    .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));
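A hedged usage sketch for the op registered above (the Python binding path mirrors the other `fastdeploy.model_executor.ops.gpu` ops used elsewhere in this patch; shapes follow `MoeExpertDispatchInferShape`, and all sizes below are made up):

```python
import paddle
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch

num_tokens, hidden_size, num_experts, top_k = 16, 4096, 64, 8
x = paddle.randn([num_tokens, hidden_size]).astype("bfloat16")
gating = paddle.randn([num_tokens, num_experts]).astype("float32")

(permute_input, tokens_expert_prefix_sum, permute_indices_per_token,
 top_k_weight, top_k_indices) = moe_expert_dispatch(x, gating, top_k, False, False)

# permute_input: [top_k * num_tokens, hidden_size], rows grouped by expert id
# tokens_expert_prefix_sum: [num_experts], rows-before-expert offsets for the grouped GEMM
```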
173
custom_ops/metax_ops/moe_ffn.cu
Normal file
@@ -0,0 +1,173 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#pragma once
#include "mc_fused_moe_helper.h"
#include "helper.h"

template <paddle::DataType T, typename ElementA, typename ElementB, typename ElementC>
void McMoeFFNKernel(const paddle::Tensor& permute_input,
                    const paddle::Tensor& tokens_expert_prefix_sum,
                    const paddle::Tensor& ffn1_weight,
                    const paddle::Tensor& ffn2_weight,
                    const paddle::optional<paddle::Tensor>& ffn1_bias,
                    const paddle::optional<paddle::Tensor>& ffn1_scale,
                    const paddle::optional<paddle::Tensor>& ffn2_scale,
                    const std::string& quant_method,
                    paddle::Tensor ffn_out) {
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto ffn_out_ptr = ffn_out.data<data_t>();
  auto permuted_input_ptr = permute_input.data<data_t>();
  auto place = permute_input.place();
  auto input_type = permute_input.dtype();
  auto stream = permute_input.stream();

  const int expanded_active_expert_rows = permute_input.dims()[0];  // permute_input.dims(): m, k
  const int num_experts = ffn1_weight.dims()[0];  // batchsize
  const int hidden_size = ffn1_weight.dims()[2];  // n
  int inter_dim = ffn1_weight.dims()[1];  // k

  const int64_t inter_size = inter_dim;  // since weight_only_int_8
  paddle::Tensor fc1_out_tensor = GetEmptyTensor(
      {expanded_active_expert_rows, inter_size}, input_type, place);
  auto fc1_out_ptr = fc1_out_tensor.data<data_t>();

  mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ROWMAJOR_ORDER;
  mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_COLUMNMAJOR_ORDER;

  // ffn1
  auto fc1_expert_biases =
      ffn1_bias
          ? const_cast<paddle::Tensor*>(ffn1_bias.get_ptr())->data<data_t>()
          : nullptr;
  auto fc1_expert_scales = const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())->data<data_t>();
  mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
      reinterpret_cast<const ElementA *>(permuted_input_ptr),
      row_major,
      reinterpret_cast<const ElementB *>(ffn1_weight.data<ElementB>()),
      column_major,
      reinterpret_cast<const ElementA *>(fc1_expert_scales),
      reinterpret_cast<const ElementA *>(fc1_expert_biases),
      reinterpret_cast<ElementC *>(fc1_out_ptr),
      row_major,
      tokens_expert_prefix_sum.data<int>(),
      num_experts,
      expanded_active_expert_rows,
      inter_dim,
      hidden_size,
      stream);

  // swiglu
  auto act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
  auto act_out = act_out_tensor.data<data_t>();

  auto fc2_expert_scales = const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())->data<data_t>();
  mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
      reinterpret_cast<const ElementA *>(act_out),
      row_major,
      reinterpret_cast<const ElementB *>(ffn2_weight.data<ElementB>()),
      column_major,
      reinterpret_cast<const ElementA *>(fc2_expert_scales),
      nullptr,
      reinterpret_cast<ElementC *>(ffn_out_ptr),
      row_major,
      tokens_expert_prefix_sum.data<int>(),
      num_experts,
      expanded_active_expert_rows,
      hidden_size,
      inter_dim / 2,
      stream);
}

std::vector<paddle::Tensor> MoeExpertFFN(
    const paddle::Tensor& permute_input,
    const paddle::Tensor& tokens_expert_prefix_sum,
    const paddle::Tensor& ffn1_weight,
    const paddle::Tensor& ffn2_weight,
    const paddle::optional<paddle::Tensor>& ffn1_bias,
    const paddle::optional<paddle::Tensor>& ffn1_scale,
    const paddle::optional<paddle::Tensor>& ffn2_scale,
    const std::string& quant_method) {
  assert(quant_method == "weight_only_int8");
  const auto input_type = permute_input.dtype();
  auto ffn_out = paddle::empty_like(permute_input);

  switch (input_type) {
    case paddle::DataType::BFLOAT16:
      McMoeFFNKernel<paddle::DataType::BFLOAT16, maca_bfloat16, int8_t, maca_bfloat16>(
          permute_input,
          tokens_expert_prefix_sum,
          ffn1_weight,
          ffn2_weight,
          ffn1_bias,
          ffn1_scale,
          ffn2_scale,
          quant_method,
          ffn_out);
      break;
    // case paddle::DataType::FLOAT16:
    //   MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input, tokens_expert_prefix_sum,
    //                                           ffn1_weight, ffn2_weight, ffn1_bias,
    //                                           ffn1_scale, ffn2_scale, quant_method, ffn_out);
    //   break;
    default:
      PD_THROW("Only support bf16 for MoeExpertFFN");
  }
  return {ffn_out};
}

std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
    const std::vector<int64_t>& permute_input_shape,
    const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
    const std::vector<int64_t>& ffn1_weight_shape,
    const std::vector<int64_t>& ffn2_weight_shape,
    const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
    const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
    const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) {
  return {permute_input_shape};
}

std::vector<paddle::DataType> MoeExpertFFNInferDtype(
    const paddle::DataType& permute_input_dtype,
    const paddle::DataType& tokens_expert_prefix_sum_dtype,
    const paddle::DataType& ffn1_weight_dtype,
    const paddle::DataType& ffn2_weight_dtype,
    const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
    const paddle::optional<paddle::DataType>& ffn1_scale_dtype,
    const paddle::optional<paddle::DataType>& ffn2_scale_dtype) {
  return {permute_input_dtype};
}

PD_BUILD_OP(moe_expert_ffn)
    .Inputs({"permute_input",
             "tokens_expert_prefix_sum",
             "ffn1_weight",
             "ffn2_weight",
             paddle::Optional("ffn1_bias"),
             paddle::Optional("ffn1_scale"),
             paddle::Optional("ffn2_scale")})
    .Outputs({"output_tensor"})
    .Attrs({"quant_method:std::string"})
    .SetKernelFn(PD_KERNEL(MoeExpertFFN))
    .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
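Continuing the dispatch sketch, a hedged call into `moe_expert_ffn` with the weight-only int8 layouts this kernel reads (`ffn1_weight: [num_experts, inter_dim, hidden]`, `ffn2_weight: [num_experts, hidden, inter_dim // 2]`); the tensor values are placeholders, not a working checkpoint:

```python
import paddle
from fastdeploy.model_executor.ops.gpu import moe_expert_ffn

E, H, I = 64, 4096, 11008                          # experts, hidden, packed gate+up width
ffn1_weight = paddle.zeros([E, I, H], dtype="int8")
ffn2_weight = paddle.zeros([E, H, I // 2], dtype="int8")
ffn1_scale = paddle.ones([E, I], dtype="bfloat16")
ffn2_scale = paddle.ones([E, H], dtype="bfloat16")

ffn_out = moe_expert_ffn(permute_input, tokens_expert_prefix_sum,
                         ffn1_weight, ffn2_weight, None,
                         ffn1_scale, ffn2_scale, "weight_only_int8")
# ffn_out keeps the permuted layout: [top_k * num_tokens, hidden_size]
```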
143
custom_ops/metax_ops/moe_reduce.cu
Normal file
@@ -0,0 +1,143 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#pragma once

#include "helper.h"
#include "fused_moe_helper.h"
#include "fused_moe_op.h"

template <paddle::DataType T>
void MoeReduceKernel(const paddle::Tensor& ffn_out,
                     const paddle::Tensor& top_k_weight,
                     const paddle::Tensor& permute_indices_per_token,
                     const paddle::Tensor& top_k_indices,
                     const paddle::optional<paddle::Tensor>& ffn2_bias,
                     const bool norm_topk_prob,
                     const float routed_scaling_factor,
                     const int num_rows,
                     const int hidden_size,
                     const int topk,
                     paddle::Tensor* output) {
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;
  auto stream = ffn_out.stream();

  finalize_moe_routing_kernelLauncher(
      ffn_out.data<data_t>(),
      output->data<data_t>(),
      ffn2_bias ? ffn2_bias->data<data_t>() : nullptr,
      top_k_weight.data<float>(),
      permute_indices_per_token.data<int32_t>(),
      top_k_indices.data<int>(),
      num_rows,
      hidden_size,
      topk,
      static_cast<int>(1),
      norm_topk_prob,
      routed_scaling_factor,
      stream);
}


std::vector<paddle::Tensor> MoeExpertReduce(
    const paddle::Tensor& ffn_out,
    const paddle::Tensor& top_k_weight,
    const paddle::Tensor& permute_indices_per_token,
    const paddle::Tensor& top_k_indices,
    const paddle::optional<paddle::Tensor>& ffn2_bias,
    const bool norm_topk_prob,
    const float routed_scaling_factor) {
  const auto input_type = ffn_out.dtype();
  auto place = ffn_out.place();

  const int topk = top_k_indices.dims()[1];
  const int num_rows = ffn_out.dims()[0] / topk;
  const int hidden_size = ffn_out.dims()[1];

  auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);

  // Avoids 'invalid configuration argument' when we launch the kernel.
  if (ffn_out.dims()[0] == 0) return {output};

  switch (input_type) {
    case paddle::DataType::BFLOAT16:
      MoeReduceKernel<paddle::DataType::BFLOAT16>(ffn_out,
                                                  top_k_weight,
                                                  permute_indices_per_token,
                                                  top_k_indices,
                                                  ffn2_bias,
                                                  norm_topk_prob,
                                                  routed_scaling_factor,
                                                  num_rows,
                                                  hidden_size,
                                                  topk,
                                                  &output);
      break;
    // case paddle::DataType::FLOAT16:
    //   MoeReduceKernel<paddle::DataType::FLOAT16>(ffn_out, top_k_weight,
    //                                              permute_indices_per_token, top_k_indices,
    //                                              ffn2_bias, norm_topk_prob,
    //                                              routed_scaling_factor, num_rows,
    //                                              hidden_size, topk, &output);
    //   break;
    default:
      PD_THROW("Only support bf16 for MoeExpertReduce");
  }
  return {output};
}


std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
    const std::vector<int64_t>& ffn_out_shape,
    const std::vector<int64_t>& top_k_weight_shape,
    const std::vector<int64_t>& permute_indices_per_token_shape,
    const std::vector<int64_t>& top_k_indices_shape,
    const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape) {
  const int topk = top_k_indices_shape[1];
  std::vector<int64_t> fused_moe_out_shape = {ffn_out_shape[0] / topk,
                                              ffn_out_shape[1]};

  return {fused_moe_out_shape};
}

std::vector<paddle::DataType> MoeExpertReduceInferDtype(
    const paddle::DataType& ffn_out_dtype,
    const paddle::DataType& top_k_weight_dtype,
    const paddle::DataType& permute_indices_per_token_dtype,
    const paddle::DataType& top_k_indices_dtype,
    const paddle::optional<paddle::DataType>& ffn2_bias_dtype) {
  return {ffn_out_dtype};
}


PD_BUILD_OP(moe_expert_reduce)
    .Inputs({"ffn_out",
             "top_k_weight",
             "permute_indices_per_token",
             "top_k_indices",
             paddle::Optional("ffn2_bias")})
    .Outputs({"output"})
    .Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"})
    .SetKernelFn(PD_KERNEL(MoeExpertReduce))
    .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertReduceInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertReduceInferDtype));
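And the final gather, again hedged and continuing the two sketches above; the argument order follows the `moe_expert_reduce` registration:

```python
from fastdeploy.model_executor.ops.gpu import moe_expert_reduce

out = moe_expert_reduce(ffn_out, top_k_weight, permute_indices_per_token,
                        top_k_indices, None,   # optional ffn2_bias
                        True,                  # norm_topk_prob
                        1.0)                   # routed_scaling_factor
# out: [num_tokens, hidden_size], the weighted sum of each token's top-k expert outputs
```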
@@ -597,6 +597,10 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
            "gpu_ops/moe/tritonmoe_preprocess.cu",
            "gpu_ops/moe/moe_topk_select.cu",
            "gpu_ops/recover_decode_task.cu",
            "metax_ops/moe_dispatch.cu",
            "metax_ops/moe_ffn.cu",
            "metax_ops/moe_reduce.cu",
            "metax_ops/fused_moe.cu",
        ]

        sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
@@ -617,7 +621,7 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
                ],
            },
            library_dirs=[os.path.join(maca_path, "lib")],
            extra_link_args=["-lruntime_cu"],
            extra_link_args=["-lruntime_cu", "-lmctlassEx"],
            include_dirs=[
                os.path.join(maca_path, "include"),
                os.path.join(maca_path, "include/mcr"),
@@ -19,8 +19,8 @@ docker login --username=cr_temp_user --password=eyJpbnN0YW5jZUlkIjoiY3JpLXpxYTIz
## 2. paddlepaddle and custom device installation

```shell
1)pip install paddlepaddle==3.0.0.dev20250729 -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
2)pip install paddle-metax-gpu==3.0.0.dev20250807 -i https://www.paddlepaddle.org.cn/packages/nightly/maca/
1)pip install paddlepaddle==3.0.0.dev20250825 -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
2)pip install paddle-metax-gpu==3.0.0.dev20250826 -i https://www.paddlepaddle.org.cn/packages/nightly/maca/
```

## 3. Build Wheel from Source
@@ -47,6 +47,8 @@ from fastdeploy.model_executor.ops.gpu import beam_search_softmax
If the above code executes successfully, the environment is ready.

## 5. Demo

```python
from fastdeploy import LLM, SamplingParams

prompts = [
@@ -68,7 +70,9 @@ for output in outputs:
    print(prompt)
    print(generated_text)
    print("-" * 50)
```

```
Output:
INFO 2025-08-18 10:54:18,455 416822 engine.py[line:202] Waiting worker processes ready...
Loading Weights: 100%|█████████████████████████████████████████████████████████████████████████| 100/100 [03:33<00:00, 2.14s/it]
@@ -81,3 +85,4 @@ Generated 1 outputs
Hello. My name is
Alice and I'm here to help you. What can I do for you today?
Hello Alice! I'm trying to organize a small party
```
@@ -13,9 +13,11 @@
# limitations under the License.

from .attention.flash_attn_backend import FlashAttentionBackend
from .moe.fused_moe_cutlass_metax_backend import MetaxCutlassWeightOnlyMoEMethod
from .moe.fused_moe_triton_metax_backend import MetaxTritonWeightOnlyMoEMethod

__all__ = [
    "FlashAttentionBackend",
    "MetaxTritonWeightOnlyMoEMethod",
    "MetaxCutlassWeightOnlyMoEMethod",
]
@@ -1,3 +1,17 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Optional, Tuple, Union

@@ -1,4 +1,3 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from __future__ import annotations

@@ -261,27 +259,9 @@ class FlashAttentionBackend(AttentionBackend):
        forward_meta: ForwardMeta,
        cu_seqlens_q: paddle.Tensor,
        batch_ids=None,
        is_decode=False,
    ):
        q_end = self.num_heads * self.head_dim
        k_end = q_end + self.kv_num_heads * self.head_dim
        v_end = k_end + self.kv_num_heads * self.head_dim
        assert v_end == qkv.shape[-1], f"Shape mismatch: {v_end} vs {qkv.shape[-1]}"
        assert qkv.shape[0] == cu_seqlens_q[-1], f"Shape mismatch: {qkv.shape[0]} vs {cu_seqlens_q[-1]}"

        if batch_ids is None:
            batch_ids = list(range(forward_meta.seq_lens_this_time.shape[0]))

        q = qkv[..., 0:q_end]
        k = qkv[..., q_end:k_end]
        v = qkv[..., k_end:v_end]

        q = q.view([-1, self.num_heads, self.head_dim])
        k = k.view([-1, self.kv_num_heads, self.head_dim])
        v = v.view([-1, self.kv_num_heads, self.head_dim])

        if is_decode:
            return q, k, v
        qkv = qkv.view([-1, self.num_heads + self.kv_num_heads * 2, self.head_dim])
        q, k, v = qkv.split(num_or_sections=[self.num_heads, self.kv_num_heads, self.kv_num_heads], axis=-2)

        for idx in range(len(cu_seqlens_q) - 1):
            batch_idx = batch_ids[idx]
@@ -375,41 +355,6 @@ class FlashAttentionBackend(AttentionBackend):
                cache_start += self.block_size
            tensor_start = tensor_end

    def merge_output(self, prefill_out, decode_out, forward_meta: ForwardMeta):
        assert not (prefill_out is None and decode_out is None), "prefill and decode output cannot both be None"
        if prefill_out is None:
            return decode_out
        elif decode_out is None:
            return prefill_out
        else:
            prefill_out = prefill_out
            decode_out = decode_out

            merged_output = []
            prefill_tensor_start = 0
            decode_tensor_start = 0
            for seq_lens_this_time in forward_meta.seq_lens_this_time:
                if seq_lens_this_time == 0:
                    continue
                if seq_lens_this_time > 1:
                    tensor_end = prefill_tensor_start + seq_lens_this_time
                    merged_output.append(prefill_out[prefill_tensor_start:tensor_end, :, :])
                    prefill_tensor_start = tensor_end
                else:
                    assert seq_lens_this_time == 1
                    tensor_end = decode_tensor_start + seq_lens_this_time
                    merged_output.append(decode_out[decode_tensor_start:tensor_end, :, :])
                    decode_tensor_start = tensor_end

            assert (
                prefill_tensor_start == prefill_out.shape[0]
            ), f"prefill merged unfinished: {prefill_tensor_start} vs {prefill_out.shape[0]}"
            assert (
                decode_tensor_start == decode_out.shape[0]
            ), f"decode merged unfinished: {decode_tensor_start} vs {decode_out.shape[0]}"
            merged_output = paddle.concat(merged_output, axis=0)
            return merged_output

    def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta: ForwardMeta):

        prefill_q, prefill_k, prefill_v = self.get_splited_qkv(
@@ -438,23 +383,17 @@ class FlashAttentionBackend(AttentionBackend):
        return prefill_out

    def forward_decode(self, decode_qkv, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
        cache_k = forward_meta.caches[k_cache_id]
        cache_v = forward_meta.caches[v_cache_id]
        cu_seq_lens = list(range(self.decode_len + 1))

        q, k, v = self.get_splited_qkv(decode_qkv, forward_meta, cu_seq_lens, self.batch_ids_decode, is_decode=True)
        decoder_q = q.view([self.decode_len, 1, self.num_heads, self.head_dim])
        decoder_k_ = k.view([self.decode_len, 1, self.kv_num_heads, self.head_dim])
        decoder_v_ = v.view([self.decode_len, 1, self.kv_num_heads, self.head_dim])
        qkv = decode_qkv.view([-1, 1, self.num_heads + self.kv_num_heads * 2, self.head_dim])
        q, k, v = qkv.split(num_or_sections=[self.num_heads, self.kv_num_heads, self.kv_num_heads], axis=-2)

        decode_out = flash_attn_kvcache_func(
            decoder_q,
            cache_k,
            cache_v,
            q,
            forward_meta.caches[k_cache_id],
            forward_meta.caches[v_cache_id],
            self.seq_lens_dec,
            self.block_table_dec,
            decoder_k_,
            decoder_v_,
            k,
            v,
            rotary_cos=forward_meta.rotary_embs[0, 0, :, 0, :].astype("bfloat16"),
            rotary_sin=forward_meta.rotary_embs[1, 0, :, 0, :].astype("bfloat16"),
            causal=self.causal,
@@ -0,0 +1,370 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import nn
from paddle.nn.quant import weight_quantize

import fastdeploy
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.gpu import fused_expert_moe
from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs


class MetaxCutlassMoEMethod(MoEMethodBase):
    """
    Use Cutlass Group Gemm to compute Fused MoE.
    This method is the oldest way to compute MoE in Paddle.
    """

    def process_loaded_weights(self, layer: nn.Layer, state_dict):
        up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
            layer.extract_moe_ffn_weights(state_dict)
        )
        stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
        stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)

        layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights)
        layer.down_proj_weight.set_value(stacked_down_proj_weights)

    def compute_ffn(
        self,
        layer: nn.Layer,
        permute_input: paddle.Tensor,
        token_nums_per_expert: paddle.Tensor,
        expert_idx_per_token: paddle.Tensor,
        used_in_ep_low_latency: bool = False,
        estimate_total_token_nums: int = -1,
    ):
        """
        Paddle Cutlass compute Fused MoE.
        """
        return fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
            permute_input,
            token_nums_per_expert,
            getattr(layer, self.added_weight_attrs[0]),
            getattr(layer, self.added_weight_attrs[1]),
            None,
            (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
            (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
            "weight_only_int8",
        )

    def apply_ep_prefill(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
    ) -> paddle.Tensor:
        """
        Apply the EP prefill method.
        """
        raise NotImplementedError

    def apply_ep_decode(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
    ) -> paddle.Tensor:
        """
        Apply the EP decoder method.
        """
        raise NotImplementedError

    def apply_tp(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
    ) -> paddle.Tensor:
        """
        Paddle Cutlass compute Fused MoE.
        """

        fused_moe_out = fused_expert_moe(
            x,
            gate.weight,
            getattr(layer, self.added_weight_attrs[0]),
            getattr(layer, self.added_weight_attrs[1]),
            None,
            (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
            None,
            (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
            "weight_only_int8",
            layer.top_k,
            True,
            False,
        )
        if layer.reduce_results and layer.tp_size > 1:
            tensor_model_parallel_all_reduce(fused_moe_out)

        return fused_moe_out


class MetaxCutlassWeightOnlyMoEMethod(MetaxCutlassMoEMethod):
    """
    weight only for moe
    """

    def __init__(self, quant_config=None):
        """
        weight only for moe
        """
        super().__init__(quant_config)
        # print(f"[DEBUG] quant_config: {quant_config}")
        self.quant_config = quant_config
        self.moe_quant_type = self.quant_config.algo
        self.pack_num = 1

    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
        """
        Paddle cutlass process prequanted weights.
        """
        up_gate_proj_expert_weight_key = layer.weight_key_map.get("up_gate_proj_expert_weight_key", None)
        down_proj_expert_weight_key = layer.weight_key_map.get("down_proj_expert_weight_key", None)
        up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
        down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)

        up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
            state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
        )
        # self.check(layer, up_gate_proj_weights, down_proj_weights)
        up_gate_proj_weight_scale = []
        down_proj_weight_scale = []
        for expert_idx in logical_expert_ids:
            up_gate_proj_weight_scale.append(
                get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
            )
            down_proj_weight_scale.append(
                get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)))
            )

        up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
        down_proj_weight = paddle.stack(down_proj_weights, axis=0)
        up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
        down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)

        name_tensor_map = {
            "up_gate_proj_weight": up_gate_proj_weight,
            "down_proj_weight": down_proj_weight,
            "up_gate_proj_weight_scale": up_gate_proj_weight_scale,
            "down_proj_weight_scale": down_proj_weight_scale,
        }
        for name, tensor in name_tensor_map.items():
            getattr(layer, name).set_value(tensor)

    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
        """
        Paddle cutlass create weight process.
        """
        self.default_dtype = layer._helper.get_default_dtype()
        if self.moe_quant_type == "weight_only_int4":
            self.up_gate_proj_weight_shape = [
                layer.num_local_experts,
                layer.moe_intermediate_size,
                layer.hidden_size,
            ]
        else:
            self.up_gate_proj_weight_shape = [
                layer.num_local_experts,
                layer.moe_intermediate_size * 2,
                layer.hidden_size,
            ]
        if self.moe_quant_type == "weight_only_int4":
            self.down_proj_weight_shape = [
                layer.num_local_experts,
                layer.hidden_size // 2,
                layer.moe_intermediate_size,
            ]
        else:
            self.down_proj_weight_shape = [
                layer.num_local_experts,
                layer.hidden_size,
                layer.moe_intermediate_size,
            ]
        self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
        self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size]

        if layer.fd_config.load_config.load_choices == "default_v1":
            layer.up_gate_proj_weight = layer.create_parameter(
                shape=[layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
                dtype=layer.weight_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            )

            layer.down_proj_weight = layer.create_parameter(
                shape=[layer.num_experts, layer.moe_intermediate_size, layer.hidden_size],
                dtype=layer.weight_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            )

            set_weight_attrs(
                layer.up_gate_proj_weight,
                {
                    **extra_weight_attrs,
                    "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
                },
            )
            set_weight_attrs(
                layer.down_proj_weight,
                {
                    **extra_weight_attrs,
                    "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
                },
            )
        else:
            self.weight_dtype = "int8"

            up_gate_proj_weight_name = self.added_weight_attrs[0]
            down_proj_weight_name = self.added_weight_attrs[1]
            up_gate_proj_scale_name = self.added_scale_attrs[0]
            down_proj_scale_name = self.added_scale_attrs[1]

            setattr(
                layer,
                up_gate_proj_weight_name,
                layer.create_parameter(
                    shape=self.up_gate_proj_weight_shape,
                    dtype=self.weight_dtype,
                    default_initializer=paddle.nn.initializer.Constant(0),
                ),
            )
            setattr(
                layer,
                down_proj_weight_name,
                layer.create_parameter(
                    shape=self.down_proj_weight_shape,
                    dtype=self.weight_dtype,
                    default_initializer=paddle.nn.initializer.Constant(0),
                ),
            )
            # weight_scale
            setattr(
                layer,
                up_gate_proj_scale_name,
                layer.create_parameter(
                    shape=self.up_gate_proj_scale_shape,
                    dtype=self.default_dtype,
                    default_initializer=paddle.nn.initializer.Constant(0),
                ),
            )
            setattr(
                layer,
                down_proj_scale_name,
                layer.create_parameter(
                    shape=self.down_proj_scale_shape,
                    dtype=self.default_dtype,
                    default_initializer=paddle.nn.initializer.Constant(0),
                ),
            )

            moe_extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
            set_weight_attrs(layer.up_gate_proj_weight, moe_extra_weight_attrs)
            set_weight_attrs(layer.down_proj_weight, moe_extra_weight_attrs)
            scale_extra_weight_attrs = {
                **extra_weight_attrs,
                "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "up": 0, "down": None},
            }
            set_weight_attrs(layer.up_gate_proj_weight_scale, scale_extra_weight_attrs)
            set_weight_attrs(layer.down_proj_weight_scale, scale_extra_weight_attrs)

    def process_weights_after_loading(self, layer):
        """ """
        if not layer.fd_config.load_config.load_choices == "default_v1":
            return
        weight_id_map = {"gate_up": 0, "down": 1}
        if (
            hasattr(layer.up_gate_proj_weight, "tensor_track")
            and layer.up_gate_proj_weight.tensor_track is not None
            and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
        ):
            weight_type = "gate_up"
        else:
            weight_type = "down"

        # 1. init shape and type
        # weight
        weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
        unquantized_weight_name = weight_name.replace("quant_weight", "weight")
        weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
        weight_dtype = "int8"
        # scale
        scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
        scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
        scale_dtype = self.default_dtype

        # 2. create tmp tensor
        weight = paddle.empty(weight_shape, dtype=weight_dtype)
        scale = paddle.empty(scale_shape, dtype=scale_dtype)

        # 3. quantize weight
        for expert_id in range(layer.num_experts):
            weight[expert_id], scale[expert_id] = weight_quantize(
                getattr(layer, unquantized_weight_name)[expert_id], algo=self.moe_quant_type, arch=80, group_size=-1
            )

        free_tensor(getattr(layer, unquantized_weight_name))

        # create weight
        setattr(
            layer,
            weight_name,
            layer.create_parameter(
                shape=weight_shape,
                dtype=weight_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )
        # create scale
        setattr(
            layer,
            scale_name,
            layer.create_parameter(
                shape=scale_shape,
                dtype=scale_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )
        getattr(layer, weight_name).copy_(weight, False)
        getattr(layer, scale_name).copy_(scale, False)

    def process_loaded_weights(self, layer: nn.Layer, state_dict):
        """
        Paddle cutlass load weight process.
        """
        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
        self.check(layer, up_gate_proj_weights, down_proj_weights)
        for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
            weight_name = self.added_weight_attrs[idx]
            scale_name = self.added_scale_attrs[idx]

            weight_list = []
            weight_scale_list = []
            for i in range(layer.num_local_experts):
                quant_weight, scale = weight_quantize(
                    weight_tensor[i], algo=self.moe_quant_type, arch=80, group_size=-1
                )
                quant_weight = paddle.transpose(quant_weight, [1, 0])
                weight_list.append(quant_weight)
                weight_scale_list.append(scale)
            quanted_weight = paddle.stack(weight_list, axis=0)
            getattr(layer, weight_name).set_value(quanted_weight)

            quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
            getattr(layer, scale_name).set_value(quanted_weight_scale)
@@ -1,4 +1,3 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import paddle
from paddle import nn

import fastdeploy
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
from fastdeploy.utils import ceil_div
@@ -153,7 +152,6 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
        Triton compute Fused MoE.
        """
        token_num = x.shape[0]
        top_k = layer.top_k
        num_local_experts = layer.num_local_experts
        top_k = layer.top_k
        moe_intermediate_size = layer.moe_intermediate_size
@@ -172,21 +170,12 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
            dtype=x.dtype,
        )

        if self.quant_config is not None:
            config = {
                "BLOCK_SIZE_M": 32,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 8,
                "GROUP_SIZE_M": 4,
            }
        else:
            config = {
                "BLOCK_SIZE_M": 32,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 8,
            }

        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
            topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
        )
@@ -292,4 +281,6 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):

        down_proj_out.reshape_([token_num, top_k, hidden_size])
        out = down_proj_out.sum(axis=1)
        if layer.tp_size > 1:
            tensor_model_parallel_all_reduce(out)
        return out
@@ -1,5 +1,4 @@
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import triton
import triton.language as tl
|
@@ -50,12 +50,7 @@ def get_moe_method():
|
||||
from fastdeploy.model_executor.layers.backends import GCUFusedMoeMethod
|
||||
|
||||
return GCUFusedMoeMethod(None)
|
||||
elif current_platform.is_maca():
|
||||
from fastdeploy.model_executor.layers.backends import (
|
||||
MetaxTritonWeightOnlyMoEMethod,
|
||||
)
|
||||
|
||||
return MetaxTritonWeightOnlyMoEMethod(None)
|
||||
elif current_platform.is_intel_hpu():
|
||||
from fastdeploy.model_executor.layers.backends import HpuMoEMethod
|
||||
|
||||
|
@@ -123,11 +123,19 @@ class WeightOnlyConfig(QuantConfigBase):
        elif current_platform.is_maca():
            if isinstance(layer, FusedMoE):
                from fastdeploy.model_executor.layers.backends import (
                    MetaxCutlassWeightOnlyMoEMethod,
                    MetaxTritonWeightOnlyMoEMethod,
                )

                if layer.use_method == "cutlass":
                    return MetaxCutlassWeightOnlyMoEMethod(self)
                elif layer.use_method == "triton":
                    return MetaxTritonWeightOnlyMoEMethod(self)
                else:
                    raise ValueError(f"Unsupported MOE backend {layer.use_method}")
            else:
                return GPUWeightOnlyLinearMethod(self)
        else:
@@ -8,9 +8,9 @@ aiozmq
openai>=1.93.0
tqdm
pynvml
uvicorn
uvicorn==0.29.0
fastapi
paddleformers
paddleformers>=0.2
redis
etcd3
httpx
@@ -30,11 +30,12 @@ use-triton-in-paddle
crcmod
fastsafetensors==0.1.14
msgpack
modelscope
opentelemetry-api>=1.24.0
opentelemetry-sdk>=1.24.0
opentelemetry-instrumentation-redis
opentelemetry-instrumentation-mysql
opentelemetry-distro
opentelemetry-distro
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi
partial_json_parser
|