mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
This reverts commit 93fcf7e4ec.
This commit is contained in:
@@ -1,37 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "helper.h"
|
||||
|
||||
template <typename T, typename OutT>
|
||||
void MoeFastHardamardWrapper(const T *x_data,
|
||||
const int64_t *expert_idx_per_token,
|
||||
const int64_t *recv_expert_count,
|
||||
const T *shift,
|
||||
const T *smooth,
|
||||
const float *quant_scales,
|
||||
const int quant_round_type,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
OutT *out,
|
||||
cudaStream_t &stream);
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,34 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fast_hardamard_kernel.hpp"
|
||||
|
||||
template void
|
||||
MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16>(
|
||||
const phi::dtype::bfloat16 *x_data,
|
||||
const int64_t *expert_idx_per_token,
|
||||
const int64_t *recv_expert_count,
|
||||
const phi::dtype::bfloat16 *shift,
|
||||
const phi::dtype::bfloat16 *smooth,
|
||||
const float *quant_scales,
|
||||
const int quant_round_type,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
phi::dtype::bfloat16 *out,
|
||||
cudaStream_t &stream);
|
||||
@@ -1,34 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fast_hardamard_kernel.hpp"
|
||||
|
||||
template void
|
||||
MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::float8_e4m3fn>(
|
||||
const phi::dtype::bfloat16 *x_data,
|
||||
const int64_t *expert_idx_per_token,
|
||||
const int64_t *recv_expert_count,
|
||||
const phi::dtype::bfloat16 *shift,
|
||||
const phi::dtype::bfloat16 *smooth,
|
||||
const float *quant_scales,
|
||||
const int quant_round_type,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
phi::dtype::float8_e4m3fn *out,
|
||||
cudaStream_t &stream);
|
||||
@@ -1,33 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fast_hardamard_kernel.hpp"
|
||||
|
||||
template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
|
||||
const phi::dtype::bfloat16 *x_data,
|
||||
const int64_t *expert_idx_per_token,
|
||||
const int64_t *recv_expert_count,
|
||||
const phi::dtype::bfloat16 *shift,
|
||||
const phi::dtype::bfloat16 *smooth,
|
||||
const float *quant_scales,
|
||||
const int quant_round_type,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
int8_t *out,
|
||||
cudaStream_t &stream);
|
||||
@@ -1,33 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fast_hardamard_kernel.hpp"
|
||||
|
||||
template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
|
||||
const phi::dtype::float16 *x_data,
|
||||
const int64_t *expert_idx_per_token,
|
||||
const int64_t *recv_expert_count,
|
||||
const phi::dtype::float16 *shift,
|
||||
const phi::dtype::float16 *smooth,
|
||||
const float *quant_scales,
|
||||
const int quant_round_type,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
phi::dtype::float16 *out,
|
||||
cudaStream_t &stream);
|
||||
@@ -1,33 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fast_hardamard_kernel.hpp"
|
||||
|
||||
template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
|
||||
const phi::dtype::float16 *x_data,
|
||||
const int64_t *expert_idx_per_token,
|
||||
const int64_t *recv_expert_count,
|
||||
const phi::dtype::float16 *shift,
|
||||
const phi::dtype::float16 *smooth,
|
||||
const float *quant_scales,
|
||||
const int quant_round_type,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
int8_t *out,
|
||||
cudaStream_t &stream);
|
||||
@@ -25,8 +25,7 @@ template <typename T, int VecSize>
|
||||
__global__ void moe_token_type_ids_kernel(T *gating_output,
|
||||
const int *moe_token_type_ids_out,
|
||||
const int num_rows,
|
||||
const int num_experts,
|
||||
const int k) {
|
||||
const int num_experts, const int k) {
|
||||
const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (moe_token_index >= num_rows) {
|
||||
@@ -45,8 +44,7 @@ template <typename T>
|
||||
void moe_token_type_ids_kernelLauncher(T *gating_output,
|
||||
const int *moe_token_type_ids_out,
|
||||
const int num_rows,
|
||||
const int num_experts,
|
||||
const int k,
|
||||
const int num_experts, const int k,
|
||||
cudaStream_t stream) {
|
||||
const int blocks = num_rows * k / 512 + 1;
|
||||
const int threads = 512;
|
||||
@@ -54,35 +52,26 @@ void moe_token_type_ids_kernelLauncher(T *gating_output,
|
||||
gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
|
||||
}
|
||||
|
||||
template <typename T, typename NvType>
|
||||
class MoeHelper {
|
||||
public:
|
||||
using Fp16Traits =
|
||||
cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kNone>;
|
||||
using Int8Traits =
|
||||
cutlass::WintQuantTraits<NvType,
|
||||
cutlass::WintQuantMethod::kWeightOnlyInt8>;
|
||||
using Int4Traits =
|
||||
cutlass::WintQuantTraits<NvType,
|
||||
cutlass::WintQuantMethod::kWeightOnlyInt4>;
|
||||
template <typename T, typename NvType> class MoeHelper {
|
||||
public:
|
||||
using Fp16Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kNone>;
|
||||
using Int8Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kWeightOnlyInt8>;
|
||||
using Int4Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kWeightOnlyInt4>;
|
||||
|
||||
MoeHelper(const std::string gemm_method,
|
||||
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner,
|
||||
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner,
|
||||
MoeGemmRunner<NvType, Int4Traits> *int4_moe_gemm_runner,
|
||||
int layernum = 0)
|
||||
: gemm_method_(gemm_method),
|
||||
fp16_moe_gemm_runner_(fp16_moe_gemm_runner),
|
||||
MoeHelper(
|
||||
const std::string gemm_method,
|
||||
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner,
|
||||
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner,
|
||||
MoeGemmRunner<NvType, Int4Traits> *int4_moe_gemm_runner,
|
||||
int layernum = 0)
|
||||
: gemm_method_(gemm_method), fp16_moe_gemm_runner_(fp16_moe_gemm_runner),
|
||||
int8_moe_gemm_runner_(int8_moe_gemm_runner),
|
||||
int4_moe_gemm_runner_(int4_moe_gemm_runner),
|
||||
layernum_(layernum) {}
|
||||
int4_moe_gemm_runner_(int4_moe_gemm_runner), layernum_(layernum) {}
|
||||
|
||||
// -------- getWorkspaceSize -------- //
|
||||
template <typename KeyT>
|
||||
size_t getWorkspaceSize(const int64_t num_rows,
|
||||
const int64_t hidden_size,
|
||||
const int64_t inter_size,
|
||||
const int64_t num_experts,
|
||||
size_t getWorkspaceSize(const int64_t num_rows, const int64_t hidden_size,
|
||||
const int64_t inter_size, const int64_t num_experts,
|
||||
const int64_t k) {
|
||||
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||
@@ -93,10 +82,10 @@ class MoeHelper {
|
||||
// FfnLayer forward.
|
||||
size_t total_ws_bytes =
|
||||
5 * num_moe_inputs *
|
||||
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
|
||||
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
|
||||
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
|
||||
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
|
||||
total_ws_bytes +=
|
||||
padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_
|
||||
padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_
|
||||
|
||||
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
|
||||
const size_t sorter_ws_size_bytes =
|
||||
@@ -111,8 +100,8 @@ class MoeHelper {
|
||||
}
|
||||
|
||||
total_ws_bytes +=
|
||||
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
|
||||
// sorting workspace
|
||||
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
|
||||
// sorting workspace
|
||||
|
||||
int64_t num_softmax_outs = 0;
|
||||
const bool is_pow_2 =
|
||||
@@ -126,27 +115,20 @@ class MoeHelper {
|
||||
return total_ws_bytes;
|
||||
}
|
||||
|
||||
void ComputeFFN(const paddle::Tensor *input,
|
||||
const paddle::Tensor *gate_weight,
|
||||
const paddle::Tensor *up_gate_proj_weight,
|
||||
const paddle::Tensor *up_gate_proj_scale,
|
||||
const paddle::Tensor *up_gate_proj_bias,
|
||||
const paddle::Tensor *down_proj_weight,
|
||||
const paddle::Tensor *down_proj_scale,
|
||||
const paddle::Tensor *down_proj_bias,
|
||||
const paddle::Tensor *moe_token_type_ids,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor,
|
||||
const std::string moe_type,
|
||||
paddle::Tensor *output) {
|
||||
void
|
||||
ComputeFFN(const paddle::Tensor *input, const paddle::Tensor *gate_weight,
|
||||
const paddle::Tensor *up_gate_proj_weight,
|
||||
const paddle::Tensor *up_gate_proj_scale, const paddle::Tensor *up_gate_proj_bias,
|
||||
const paddle::Tensor *down_proj_weight,
|
||||
const paddle::Tensor *down_proj_scale, const paddle::Tensor *down_proj_bias,
|
||||
const paddle::Tensor *moe_token_type_ids, const int moe_topk,
|
||||
const bool group_moe, const bool norm_topk_prob,
|
||||
const float routed_scaling_factor, const std::string moe_type,
|
||||
paddle::Tensor *output) {
|
||||
auto *input_activations = input->data<T>();
|
||||
auto *gating_weights = gate_weight->data<float>();
|
||||
const T *fc1_expert_biases =
|
||||
up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr;
|
||||
const T *fc2_expert_biases =
|
||||
down_proj_bias ? down_proj_bias->data<T>() : nullptr;
|
||||
const T *fc1_expert_biases = up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr;
|
||||
const T *fc2_expert_biases = down_proj_bias ? down_proj_bias->data<T>() : nullptr;
|
||||
|
||||
auto *output_ = output->data<T>();
|
||||
auto stream = input->stream();
|
||||
@@ -166,8 +148,7 @@ class MoeHelper {
|
||||
const int64_t hidden_size = up_gate_proj_dims[1];
|
||||
int64_t inter_dim = 0;
|
||||
if (moe_type == "qkv") {
|
||||
inter_dim =
|
||||
up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
|
||||
inter_dim = up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
|
||||
} else {
|
||||
inter_dim = up_gate_proj_dims[2];
|
||||
}
|
||||
@@ -251,79 +232,44 @@ class MoeHelper {
|
||||
if (moe_token_type_ids) {
|
||||
auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
|
||||
moe_token_type_ids_kernelLauncher<float>(gating_output,
|
||||
moe_token_type_ids_out,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
stream);
|
||||
moe_token_type_ids_out, num_rows,
|
||||
num_experts, k, stream);
|
||||
}
|
||||
|
||||
topk_gating_softmax_kernelLauncher<float, int>(gating_output,
|
||||
nullptr,
|
||||
expert_scales_float,
|
||||
softmax_out_,
|
||||
expert_for_source_row,
|
||||
source_rows_,
|
||||
softmax_max_prob,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
group_moe,
|
||||
stream);
|
||||
topk_gating_softmax_kernelLauncher<float, int>(
|
||||
gating_output, nullptr, expert_scales_float, softmax_out_,
|
||||
expert_for_source_row, source_rows_, softmax_max_prob, num_rows,
|
||||
num_experts, k, group_moe, stream);
|
||||
|
||||
const int64_t sorter_ws_size_bytes =
|
||||
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
|
||||
|
||||
sorter_.run(fc1_result_,
|
||||
sorter_ws_size_bytes,
|
||||
expert_for_source_row,
|
||||
permuted_experts_,
|
||||
source_rows_,
|
||||
permuted_rows_,
|
||||
k * num_rows,
|
||||
false,
|
||||
stream);
|
||||
sorter_.run(fc1_result_, sorter_ws_size_bytes, expert_for_source_row,
|
||||
permuted_experts_, source_rows_, permuted_rows_, k * num_rows,
|
||||
false, stream);
|
||||
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input_activations,
|
||||
permuted_data_,
|
||||
permuted_rows_,
|
||||
nullptr,
|
||||
nullptr,
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
nullptr,
|
||||
num_rows,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
stream);
|
||||
input_activations, permuted_data_, permuted_rows_, nullptr, nullptr,
|
||||
expanded_source_row_to_expanded_dest_row, num_rows, num_rows,
|
||||
hidden_size, k, stream);
|
||||
|
||||
const int64_t expanded_active_expert_rows = k * num_rows;
|
||||
|
||||
compute_total_rows_before_expert(permuted_experts_,
|
||||
expanded_active_expert_rows,
|
||||
num_experts,
|
||||
total_rows_before_expert_,
|
||||
stream);
|
||||
expanded_active_expert_rows, num_experts,
|
||||
total_rows_before_expert_, stream);
|
||||
|
||||
if (gemm_method_ == "weight_only_int8") {
|
||||
typename Int8Traits::Arguments up_gate_proj_quant_args;
|
||||
int8_moe_gemm_runner_->moe_gemm_bias_act(
|
||||
reinterpret_cast<NvType *>(permuted_data_),
|
||||
reinterpret_cast<const uint8_t *>(
|
||||
up_gate_proj_weight->data<int8_t>()),
|
||||
reinterpret_cast<const uint8_t *>(up_gate_proj_weight->data<int8_t>()),
|
||||
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
|
||||
reinterpret_cast<const NvType *>(fc1_expert_biases),
|
||||
reinterpret_cast<NvType *>(fc1_out),
|
||||
total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
num_experts,
|
||||
up_gate_proj_quant_args,
|
||||
"none",
|
||||
stream);
|
||||
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
|
||||
up_gate_proj_quant_args, "none", stream);
|
||||
} else if (gemm_method_ == "weight_only_int4") {
|
||||
typename Int4Traits::Arguments up_gate_proj_quant_args;
|
||||
int4_moe_gemm_runner_->moe_gemm_bias_act(
|
||||
@@ -332,33 +278,20 @@ class MoeHelper {
|
||||
up_gate_proj_weight->data<int8_t>()),
|
||||
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
|
||||
reinterpret_cast<const NvType *>(fc1_expert_biases),
|
||||
reinterpret_cast<NvType *>(fc1_out),
|
||||
total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
num_experts,
|
||||
up_gate_proj_quant_args,
|
||||
"none",
|
||||
stream);
|
||||
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
|
||||
up_gate_proj_quant_args, "none", stream);
|
||||
} else {
|
||||
typename Fp16Traits::Arguments up_gate_proj_quant_args;
|
||||
fp16_moe_gemm_runner_->moe_gemm_bias_act(
|
||||
reinterpret_cast<NvType *>(permuted_data_),
|
||||
reinterpret_cast<const NvType *>(up_gate_proj_weight->data<T>()),
|
||||
nullptr,
|
||||
reinterpret_cast<const NvType *>(up_gate_proj_weight->data<T>()), nullptr,
|
||||
reinterpret_cast<const NvType *>(fc1_expert_biases),
|
||||
reinterpret_cast<NvType *>(fc1_out),
|
||||
total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
num_experts,
|
||||
up_gate_proj_quant_args,
|
||||
"none",
|
||||
stream);
|
||||
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
|
||||
up_gate_proj_quant_args, "none", stream);
|
||||
}
|
||||
|
||||
if (moe_type == "ffn") {
|
||||
@@ -376,15 +309,10 @@ class MoeHelper {
|
||||
reinterpret_cast<NvType *>(act_out),
|
||||
reinterpret_cast<const uint8_t *>(down_proj_weight->data<int8_t>()),
|
||||
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
|
||||
reinterpret_cast<NvType *>(fc2_result),
|
||||
total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_size / 2,
|
||||
num_experts,
|
||||
down_proj_quant_args,
|
||||
stream);
|
||||
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, hidden_size, inter_size / 2,
|
||||
num_experts, down_proj_quant_args, stream);
|
||||
} else if (gemm_method_ == "weight_only_int4") {
|
||||
typename Int4Traits::Arguments down_proj_quant_args;
|
||||
int4_moe_gemm_runner_->moe_gemm(
|
||||
@@ -392,66 +320,40 @@ class MoeHelper {
|
||||
reinterpret_cast<const cutlass::uint4b_t *>(
|
||||
down_proj_weight->data<int8_t>()),
|
||||
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
|
||||
reinterpret_cast<NvType *>(fc2_result),
|
||||
total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_size / 2,
|
||||
num_experts,
|
||||
down_proj_quant_args,
|
||||
stream);
|
||||
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, hidden_size, inter_size / 2,
|
||||
num_experts, down_proj_quant_args, stream);
|
||||
} else {
|
||||
typename Fp16Traits::Arguments down_proj_quant_args;
|
||||
fp16_moe_gemm_runner_->moe_gemm(
|
||||
reinterpret_cast<NvType *>(act_out),
|
||||
reinterpret_cast<const NvType *>(down_proj_weight->data<T>()),
|
||||
nullptr,
|
||||
reinterpret_cast<NvType *>(fc2_result),
|
||||
total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_size / 2,
|
||||
num_experts,
|
||||
down_proj_quant_args,
|
||||
stream);
|
||||
reinterpret_cast<const NvType *>(down_proj_weight->data<T>()), nullptr,
|
||||
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
|
||||
-1, // useless
|
||||
expanded_active_expert_rows, hidden_size, inter_size / 2,
|
||||
num_experts, down_proj_quant_args, stream);
|
||||
}
|
||||
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
fc2_result,
|
||||
output_,
|
||||
fc2_expert_biases,
|
||||
fc2_result, output_, fc2_expert_biases,
|
||||
reinterpret_cast<float *>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
static_cast<int>(1),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
expanded_source_row_to_expanded_dest_row, expert_for_source_row,
|
||||
num_rows, hidden_size, k, static_cast<int>(1), norm_topk_prob,
|
||||
routed_scaling_factor, stream);
|
||||
} else {
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
// fc2_result,
|
||||
fc1_out,
|
||||
output_,
|
||||
fc1_expert_biases, // fc2_expert_biases,
|
||||
fc1_out, output_,
|
||||
fc1_expert_biases, // fc2_expert_biases,
|
||||
reinterpret_cast<float *>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
inter_size,
|
||||
k,
|
||||
static_cast<int>(0),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
expanded_source_row_to_expanded_dest_row, expert_for_source_row,
|
||||
num_rows, inter_size, k, static_cast<int>(0), norm_topk_prob,
|
||||
routed_scaling_factor, stream);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
std::string gemm_method_;
|
||||
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner_;
|
||||
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner_;
|
||||
@@ -460,4 +362,4 @@ class MoeHelper {
|
||||
CubKeyValueSorter sorter_;
|
||||
};
|
||||
|
||||
} // namespace phi
|
||||
} // namespace phi
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -26,26 +26,17 @@
|
||||
|
||||
template <paddle::DataType T>
|
||||
void MoeDispatchKernel(
|
||||
const paddle::Tensor &input,
|
||||
const paddle::Tensor &gating_output,
|
||||
const paddle::Tensor &input, const paddle::Tensor &gating_output,
|
||||
const paddle::optional<paddle::Tensor> &gating_correction_bias,
|
||||
const paddle::optional<paddle::Tensor> &w4a8_in_scale,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const bool topk_only_mode,
|
||||
const int num_rows,
|
||||
const int hidden_size,
|
||||
const int expert_num,
|
||||
paddle::Tensor *permute_input,
|
||||
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
|
||||
const bool group_moe, const bool topk_only_mode, const int num_rows,
|
||||
const int hidden_size, const int expert_num, paddle::Tensor *permute_input,
|
||||
paddle::Tensor *tokens_expert_prefix_sum,
|
||||
paddle::Tensor *permute_indices_per_token,
|
||||
paddle::Tensor *topk_weight,
|
||||
paddle::Tensor *topk_idx,
|
||||
paddle::Tensor *expert_idx_per_token,
|
||||
paddle::Tensor *dequant_scale) {
|
||||
paddle::Tensor *permute_indices_per_token, paddle::Tensor *topk_weight,
|
||||
paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token) {
|
||||
using namespace phi;
|
||||
|
||||
if (num_rows == 0) {
|
||||
if (num_rows == 0){
|
||||
return;
|
||||
}
|
||||
typedef PDTraits<T> traits_;
|
||||
@@ -57,14 +48,12 @@ void MoeDispatchKernel(
|
||||
|
||||
if (group_moe) {
|
||||
// Check if expert_num is divisible by moe_topk, else throw an error
|
||||
PADDLE_ENFORCE_EQ(expert_num % moe_topk,
|
||||
0,
|
||||
PADDLE_ENFORCE_EQ(expert_num % moe_topk, 0,
|
||||
common::errors::InvalidArgument(
|
||||
"The number of experts (expert_num) "
|
||||
"must be divisible by moe_topk. "
|
||||
"Got expert_num = %d and moe_topk = %d.",
|
||||
expert_num,
|
||||
moe_topk));
|
||||
expert_num, moe_topk));
|
||||
}
|
||||
|
||||
const int num_moe_inputs = AlignTo16(num_rows * moe_topk);
|
||||
@@ -79,8 +68,7 @@ void MoeDispatchKernel(
|
||||
|
||||
paddle::Tensor ws_ptr_tensor =
|
||||
GetEmptyTensor({bytes + sorter_ws_size_bytes + sort_tmp_in_out_size},
|
||||
paddle::DataType::INT8,
|
||||
place);
|
||||
paddle::DataType::INT8, place);
|
||||
|
||||
int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();
|
||||
int *source_rows_ = reinterpret_cast<int *>(ws_ptr);
|
||||
@@ -108,8 +96,8 @@ void MoeDispatchKernel(
|
||||
paddle::Tensor softmax_buffer;
|
||||
|
||||
if (!is_pow_2 || expert_num > 256 || group_moe || gating_correction_bias) {
|
||||
softmax_buffer = GetEmptyTensor(
|
||||
{num_rows * expert_num}, paddle::DataType::FLOAT32, place);
|
||||
softmax_buffer = GetEmptyTensor({num_rows * expert_num},
|
||||
paddle::DataType::FLOAT32, place);
|
||||
softmax_out_ = softmax_buffer.data<float>();
|
||||
} else {
|
||||
softmax_out_ = nullptr;
|
||||
@@ -119,106 +107,47 @@ void MoeDispatchKernel(
|
||||
gating_output.data<float>(),
|
||||
gating_correction_bias ? gating_correction_bias.get().data<float>()
|
||||
: nullptr,
|
||||
topk_weight->data<float>(),
|
||||
softmax_out_,
|
||||
topk_idx_ptr,
|
||||
source_rows_,
|
||||
softmax_max_prob,
|
||||
num_rows,
|
||||
expert_num,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
stream,
|
||||
topk_weight->data<float>(), softmax_out_, topk_idx_ptr, source_rows_,
|
||||
softmax_max_prob, num_rows, expert_num, moe_topk, group_moe, stream,
|
||||
topk_only_mode);
|
||||
|
||||
sorter_.run(reinterpret_cast<void *>(sorter_ws_ptr),
|
||||
sorter_ws_size_bytes,
|
||||
topk_idx_ptr,
|
||||
expert_idx_per_token->data<int32_t>(),
|
||||
source_rows_,
|
||||
permuted_rows_,
|
||||
moe_topk * num_rows,
|
||||
false,
|
||||
stream);
|
||||
sorter_.run(reinterpret_cast<void *>(sorter_ws_ptr), sorter_ws_size_bytes,
|
||||
topk_idx_ptr, expert_idx_per_token->data<int32_t>(), source_rows_,
|
||||
permuted_rows_, moe_topk * num_rows, false, stream);
|
||||
|
||||
if (w4a8_in_scale) {
|
||||
if (permute_input->dtype() == paddle::DataType::INT8) {
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input.data<data_t>(),
|
||||
permute_input->data<int8_t>(),
|
||||
permuted_rows_,
|
||||
expert_idx_per_token->data<int32_t>(),
|
||||
w4a8_in_scale->data<float>(),
|
||||
permute_indices_per_token->data<int32_t>(),
|
||||
nullptr,
|
||||
num_rows,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
moe_topk,
|
||||
stream);
|
||||
input.data<data_t>(), permute_input->data<int8_t>(), permuted_rows_,
|
||||
expert_idx_per_token->data<int32_t>(), w4a8_in_scale->data<float>(),
|
||||
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
|
||||
hidden_size, moe_topk, stream);
|
||||
} else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input.data<data_t>(),
|
||||
permute_input->data<float8_e4m3fn>(),
|
||||
permuted_rows_,
|
||||
expert_idx_per_token->data<int32_t>(),
|
||||
w4a8_in_scale->data<float>(),
|
||||
permute_indices_per_token->data<int32_t>(),
|
||||
nullptr,
|
||||
num_rows,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
moe_topk,
|
||||
stream);
|
||||
input.data<data_t>(), permute_input->data<float8_e4m3fn>(),
|
||||
permuted_rows_, expert_idx_per_token->data<int32_t>(),
|
||||
w4a8_in_scale->data<float>(),
|
||||
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
|
||||
hidden_size, moe_topk, stream);
|
||||
}
|
||||
} else {
|
||||
if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input.data<data_t>(),
|
||||
permute_input->data<float8_e4m3fn>(),
|
||||
permuted_rows_,
|
||||
expert_idx_per_token->data<int32_t>(),
|
||||
nullptr,
|
||||
permute_indices_per_token->data<int32_t>(),
|
||||
dequant_scale->data<float>(),
|
||||
num_rows,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
moe_topk,
|
||||
stream);
|
||||
} else {
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input.data<data_t>(),
|
||||
permute_input->data<data_t>(),
|
||||
permuted_rows_,
|
||||
expert_idx_per_token->data<int32_t>(),
|
||||
nullptr,
|
||||
permute_indices_per_token->data<int32_t>(),
|
||||
nullptr,
|
||||
num_rows,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
moe_topk,
|
||||
stream);
|
||||
}
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input.data<data_t>(), permute_input->data<data_t>(), permuted_rows_,
|
||||
expert_idx_per_token->data<int32_t>(), nullptr,
|
||||
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
|
||||
hidden_size, moe_topk, stream);
|
||||
}
|
||||
|
||||
compute_total_rows_before_expert(expert_idx_per_token->data<int32_t>(),
|
||||
moe_topk * num_rows,
|
||||
expert_num,
|
||||
tokens_expert_prefix_sum->data<int64_t>(),
|
||||
stream);
|
||||
compute_total_rows_before_expert(
|
||||
expert_idx_per_token->data<int32_t>(), moe_topk * num_rows, expert_num,
|
||||
tokens_expert_prefix_sum->data<int64_t>(), stream);
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
const paddle::Tensor &input,
|
||||
const paddle::Tensor &gating_output,
|
||||
const paddle::Tensor &input, const paddle::Tensor &gating_output,
|
||||
const paddle::optional<paddle::Tensor> &gating_correction_bias,
|
||||
const paddle::optional<paddle::Tensor> &w4a8_in_scale,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const std::string &moe_quant_type,
|
||||
const bool topk_only_mode) {
|
||||
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
|
||||
const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode) {
|
||||
const auto input_type = input.dtype();
|
||||
auto place = input.place();
|
||||
int token_rows = 0;
|
||||
@@ -241,21 +170,10 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
} else if (moe_quant_type == "w4afp8") {
|
||||
permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN;
|
||||
}
|
||||
} else {
|
||||
if (moe_quant_type == "w4afp8") {
|
||||
permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN;
|
||||
}
|
||||
}
|
||||
|
||||
auto permute_input = GetEmptyTensor(
|
||||
{moe_topk * num_rows, hidden_size}, permute_input_dtype, place);
|
||||
int dequant_scale_size = 1;
|
||||
if (moe_quant_type == "w4afp8" && !w4a8_in_scale) {
|
||||
dequant_scale_size = moe_topk * num_rows;
|
||||
}
|
||||
|
||||
auto dequant_scale =
|
||||
GetEmptyTensor({dequant_scale_size}, paddle::DataType::FLOAT32, place);
|
||||
auto permute_input = GetEmptyTensor({moe_topk * num_rows, hidden_size},
|
||||
permute_input_dtype, place);
|
||||
// correspond to the weighted coefficients of the results from each expert.
|
||||
auto topk_weight =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
@@ -270,65 +188,39 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
auto expert_idx_per_token =
|
||||
GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place);
|
||||
|
||||
if (token_rows == 0) {
|
||||
if (token_rows == 0){
|
||||
return {permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
permute_indices_per_token,
|
||||
topk_weight,
|
||||
topk_idx,
|
||||
expert_idx_per_token,
|
||||
dequant_scale};
|
||||
expert_idx_per_token};
|
||||
}
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
|
||||
gating_output,
|
||||
gating_correction_bias,
|
||||
w4a8_in_scale,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
topk_only_mode,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
expert_num,
|
||||
&permute_input,
|
||||
&tokens_expert_prefix_sum,
|
||||
&permute_indices_per_token,
|
||||
&topk_weight,
|
||||
&topk_idx,
|
||||
&expert_idx_per_token,
|
||||
&dequant_scale);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
MoeDispatchKernel<paddle::DataType::FLOAT16>(input,
|
||||
gating_output,
|
||||
gating_correction_bias,
|
||||
w4a8_in_scale,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
topk_only_mode,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
expert_num,
|
||||
&permute_input,
|
||||
&tokens_expert_prefix_sum,
|
||||
&permute_indices_per_token,
|
||||
&topk_weight,
|
||||
&topk_idx,
|
||||
&expert_idx_per_token,
|
||||
&dequant_scale);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for MoeDispatchKernel");
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeDispatchKernel<paddle::DataType::BFLOAT16>(
|
||||
input, gating_output, gating_correction_bias, w4a8_in_scale, moe_topk,
|
||||
group_moe, topk_only_mode, num_rows, hidden_size, expert_num,
|
||||
&permute_input, &tokens_expert_prefix_sum, &permute_indices_per_token,
|
||||
&topk_weight, &topk_idx, &expert_idx_per_token);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
MoeDispatchKernel<paddle::DataType::FLOAT16>(
|
||||
input, gating_output, gating_correction_bias, w4a8_in_scale, moe_topk,
|
||||
group_moe, topk_only_mode, num_rows, hidden_size, expert_num,
|
||||
&permute_input, &tokens_expert_prefix_sum, &permute_indices_per_token,
|
||||
&topk_weight, &topk_idx, &expert_idx_per_token);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for MoeDispatchKernel");
|
||||
}
|
||||
return {permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
permute_indices_per_token,
|
||||
topk_weight,
|
||||
topk_idx,
|
||||
expert_idx_per_token,
|
||||
dequant_scale};
|
||||
expert_idx_per_token};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
|
||||
@@ -353,22 +245,16 @@ std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
|
||||
{moe_topk, num_rows},
|
||||
{num_rows, moe_topk},
|
||||
{num_rows, moe_topk},
|
||||
{permuted_rows},
|
||||
{num_rows}};
|
||||
{permuted_rows}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
||||
const paddle::DataType &input_dtype,
|
||||
const paddle::DataType &gating_output_dtype,
|
||||
const paddle::optional<paddle::DataType> &bias_type,
|
||||
const int moe_topk) {
|
||||
return {input_dtype,
|
||||
paddle::DataType::INT64,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::FLOAT32,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::FLOAT32};
|
||||
std::vector<paddle::DataType>
|
||||
MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
|
||||
const paddle::DataType &gating_output_dtype,
|
||||
const paddle::optional<paddle::DataType> &bias_type,
|
||||
const int moe_topk) {
|
||||
return {input_dtype, paddle::DataType::INT64, paddle::DataType::INT32,
|
||||
paddle::DataType::FLOAT32, paddle::DataType::INT32, paddle::DataType::INT32};
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -376,8 +262,7 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
||||
*
|
||||
* This operator performs the following key functions:
|
||||
* 1. Computes top-k experts for each input token based on gating scores
|
||||
* 2. Permutes input tokens according to their selected experts for efficient
|
||||
* expert processing
|
||||
* 2. Permutes input tokens according to their selected experts for efficient expert processing
|
||||
* 3. Computes prefix sums of tokens per expert for group_gemm optimization
|
||||
*
|
||||
* Inputs:
|
||||
@@ -387,17 +272,18 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
||||
* - gating_output: Gating network output scores for each token-expert pair
|
||||
* Shape: [total_tokens, expert_num]
|
||||
* dtype: must be float32
|
||||
* - gating_correction_bias: Optional bias term for gating correction
|
||||
* (expert_num)
|
||||
* - gating_correction_bias: Optional bias term for gating correction (expert_num)
|
||||
*
|
||||
* Outputs:
|
||||
* - permute_input: Permuted input tensor organized by expert
|
||||
* Shape: [moe_topk * total_tokens, hidden_size]
|
||||
* dtype: Same as input
|
||||
* - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for
|
||||
* group_gemm Shape: [expert_num] dtype: int64
|
||||
* - permute_indices_per_token: Indices mapping for reconstructing original
|
||||
* order Shape: [moe_topk, total_tokens] dtype: int32
|
||||
* - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm
|
||||
* Shape: [expert_num]
|
||||
* dtype: int64
|
||||
* - permute_indices_per_token: Indices mapping for reconstructing original order
|
||||
* Shape: [moe_topk, total_tokens]
|
||||
* dtype: int32
|
||||
* - top_k_weight: Weight coefficients for combining expert outputs
|
||||
* Shape: [total_tokens, moe_topk]
|
||||
* dtype: float32
|
||||
@@ -406,8 +292,7 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
||||
* dtype: int32
|
||||
*
|
||||
* Attributes:
|
||||
* - moe_topk: Number of experts to select for each token (k value in top-k
|
||||
* routing)
|
||||
* - moe_topk: Number of experts to select for each token (k value in top-k routing)
|
||||
* - group_moe: Whether to perform group softmax within the operator
|
||||
* (true: softmax is computed within groups of experts,
|
||||
* false: standard softmax across all experts)
|
||||
@@ -421,21 +306,13 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
||||
* - When group_moe is true, expert_num must be divisible by moe_topk
|
||||
*/
|
||||
PD_BUILD_STATIC_OP(moe_expert_dispatch)
|
||||
.Inputs({"input",
|
||||
"gating_output",
|
||||
.Inputs({"input", "gating_output",
|
||||
paddle::Optional("gating_correction_bias"),
|
||||
paddle::Optional("w4a8_in_scale")})
|
||||
.Outputs({"permute_input",
|
||||
"tokens_expert_prefix_sum",
|
||||
"permute_indices_per_token",
|
||||
"topk_weight",
|
||||
"topk_idx",
|
||||
"expert_idx_per_token",
|
||||
"dequant_scale"})
|
||||
.Attrs({"moe_topk:int",
|
||||
"group_moe:bool",
|
||||
"moe_quant_type:std::string",
|
||||
"topk_only_mode:bool"})
|
||||
.Outputs({"permute_input", "tokens_expert_prefix_sum",
|
||||
"permute_indices_per_token", "topk_weight", "topk_idx",
|
||||
"expert_idx_per_token"})
|
||||
.Attrs({"moe_topk:int", "group_moe:bool", "moe_quant_type:std::string", "topk_only_mode:bool"})
|
||||
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "group_swiglu_with_masked.h"
|
||||
#include "helper.h"
|
||||
#include "moe/fast_hardmard/fast_hardamard_kernel.h"
|
||||
#include "moe/fused_moe_helper.h"
|
||||
|
||||
template <typename DataT,
|
||||
|
||||
@@ -18,8 +18,8 @@
|
||||
#include "cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h"
|
||||
#include "group_swiglu_with_masked.h"
|
||||
#include "helper.h"
|
||||
#include "moe/fast_hardmard/fast_hardamard_kernel.h"
|
||||
#include "moe/fused_moe_helper.h"
|
||||
#include "moe/moe_fast_hardamard_kernel.h"
|
||||
#include "swigluoai.h"
|
||||
#include "w4afp8_gemm/w4afp8_gemm.h"
|
||||
|
||||
@@ -28,7 +28,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& up_gate_proj_weight,
|
||||
const paddle::Tensor& down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor>& up_proj_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||
@@ -198,20 +197,31 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
|
||||
typedef typename traits_fp8::DataType DataType_fp8;
|
||||
typedef typename traits_fp8::data_t data_t_fp8;
|
||||
paddle::Tensor weight_scale_tensor =
|
||||
*const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr());
|
||||
const int weight_scale_group_size = weight_scale_tensor.dims().size() == 2
|
||||
? hidden_size
|
||||
: weight_scale_tensor.dims()[3];
|
||||
const float* input_dequant_scale =
|
||||
up_proj_in_scale ? up_proj_in_scale.get().data<float>() : nullptr;
|
||||
|
||||
Allocator::AllocationPtr ffn1_input_row_sum;
|
||||
ffn1_input_row_sum =
|
||||
allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
|
||||
|
||||
compute_row_sum(
|
||||
permute_input.data<data_t_fp8>(),
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
|
||||
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||
num_max_tokens_per_expert,
|
||||
used_in_ep_low_latency,
|
||||
stream);
|
||||
|
||||
float* row_scale = nullptr;
|
||||
DisPatchW4AFp8GemmWrapper(
|
||||
reinterpret_cast<const DataType_fp8*>(permute_input.data<data_t_fp8>()),
|
||||
reinterpret_cast<const DataType_fp8*>(
|
||||
up_gate_proj_weight.data<int8_t>()),
|
||||
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||
input_dequant_scale,
|
||||
weight_scale_tensor.data<float>(),
|
||||
reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
|
||||
row_scale,
|
||||
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
|
||||
->data<float>(),
|
||||
reinterpret_cast<NvType*>(fc1_out),
|
||||
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
|
||||
used_in_ep_low_latency ? num_max_tokens_per_expert
|
||||
@@ -219,7 +229,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
num_experts,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
weight_scale_group_size,
|
||||
stream);
|
||||
} else {
|
||||
typename cutlass::WintQuantTraits<
|
||||
@@ -346,84 +355,60 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
} else if (quant_method == "w4afp8") {
|
||||
data_t* ffn2_shift = nullptr;
|
||||
data_t* ffn2_smooth = nullptr;
|
||||
float* input_dequant_scale = nullptr;
|
||||
float* row_scale = nullptr;
|
||||
Allocator::AllocationPtr fp8_act_out;
|
||||
fp8_act_out = allocator->Allocate(SizeOf(paddle::DataType::INT8) *
|
||||
act_out_tensor.numel());
|
||||
Allocator::AllocationPtr ffn2_input_row_sum;
|
||||
ffn2_input_row_sum =
|
||||
allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
|
||||
|
||||
if (down_proj_in_scale) {
|
||||
MoeFastHardamardWrapper<data_t, data_t_fp8>(
|
||||
act_out_tensor.data<data_t>(),
|
||||
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
|
||||
: nullptr,
|
||||
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||
ffn2_shift,
|
||||
ffn2_smooth,
|
||||
down_proj_in_scale
|
||||
? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())
|
||||
->data<float>()
|
||||
: nullptr,
|
||||
1,
|
||||
448.0f,
|
||||
-448.0f,
|
||||
expanded_active_expert_rows,
|
||||
inter_size / 2,
|
||||
num_max_tokens_per_expert,
|
||||
used_in_ep_low_latency,
|
||||
hadamard_block_size,
|
||||
reinterpret_cast<data_t_fp8*>(fp8_act_out->ptr()),
|
||||
stream);
|
||||
} else {
|
||||
Allocator::AllocationPtr ffn2_input_dequant_scale;
|
||||
ffn2_input_dequant_scale =
|
||||
allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
|
||||
input_dequant_scale =
|
||||
reinterpret_cast<float*>(ffn2_input_dequant_scale->ptr());
|
||||
MoeFastHardamardWrapper<data_t, data_t>(
|
||||
act_out_tensor.data<data_t>(),
|
||||
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
|
||||
: nullptr,
|
||||
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||
ffn2_shift, // ffn2_shift->data<T>(),
|
||||
ffn2_smooth, // ffn2_smooth->data<T>(),
|
||||
nullptr,
|
||||
1,
|
||||
448.0f,
|
||||
-448.0f,
|
||||
expanded_active_expert_rows,
|
||||
inter_size / 2,
|
||||
num_max_tokens_per_expert,
|
||||
used_in_ep_low_latency,
|
||||
hadamard_block_size,
|
||||
act_out_tensor.data<data_t>(),
|
||||
stream);
|
||||
// note(yuanxiaolan): optimize this
|
||||
MoeFastHardamardWrapper<data_t, data_t>(
|
||||
act_out_tensor.data<data_t>(),
|
||||
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
|
||||
: nullptr,
|
||||
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||
ffn2_shift, // ffn2_shift->data<T>(),
|
||||
ffn2_smooth, // ffn2_smooth->data<T>(),
|
||||
nullptr,
|
||||
1,
|
||||
448.0f,
|
||||
-448.0f,
|
||||
expanded_active_expert_rows,
|
||||
inter_size / 2,
|
||||
num_max_tokens_per_expert,
|
||||
used_in_ep_low_latency,
|
||||
hadamard_block_size,
|
||||
act_out_tensor.data<data_t>(),
|
||||
stream);
|
||||
|
||||
quantize_moe_input<data_t, data_t_fp8>(
|
||||
act_out_tensor.data<data_t>(),
|
||||
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
|
||||
: nullptr,
|
||||
expanded_active_expert_rows,
|
||||
inter_size / 2,
|
||||
input_dequant_scale,
|
||||
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||
num_max_tokens_per_expert,
|
||||
used_in_ep_low_latency,
|
||||
reinterpret_cast<data_t_fp8*>(fp8_act_out->ptr()),
|
||||
stream);
|
||||
}
|
||||
|
||||
paddle::Tensor weight_scale_tensor =
|
||||
*const_cast<paddle::Tensor*>(down_proj_scale.get_ptr());
|
||||
const int weight_scale_group_size = weight_scale_tensor.dims().size() == 2
|
||||
? inter_size / 2
|
||||
: weight_scale_tensor.dims()[3];
|
||||
quantize_moe_input<data_t, data_t_fp8>(
|
||||
act_out_tensor.data<data_t>(),
|
||||
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
|
||||
: nullptr,
|
||||
down_proj_in_scale
|
||||
? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())
|
||||
->data<float>()
|
||||
: nullptr,
|
||||
448.0f,
|
||||
-448.0f,
|
||||
expanded_active_expert_rows,
|
||||
inter_size / 2,
|
||||
reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
|
||||
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||
num_max_tokens_per_expert,
|
||||
used_in_ep_low_latency,
|
||||
reinterpret_cast<data_t_fp8*>(fp8_act_out->ptr()),
|
||||
stream);
|
||||
|
||||
DisPatchW4AFp8GemmWrapper(
|
||||
reinterpret_cast<const DataType_fp8*>(fp8_act_out->ptr()),
|
||||
reinterpret_cast<const DataType_fp8*>(down_proj_weight.data<int8_t>()),
|
||||
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||
input_dequant_scale,
|
||||
weight_scale_tensor.data<float>(),
|
||||
reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
|
||||
row_scale,
|
||||
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())->data<float>(),
|
||||
reinterpret_cast<NvType*>(ffn_out_data),
|
||||
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
|
||||
used_in_ep_low_latency ? num_max_tokens_per_expert
|
||||
@@ -431,7 +416,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
num_experts,
|
||||
hidden_size,
|
||||
inter_size / 2,
|
||||
weight_scale_group_size,
|
||||
stream);
|
||||
} else {
|
||||
typename cutlass::WintQuantTraits<
|
||||
@@ -458,7 +442,6 @@ paddle::Tensor MoeExpertFFNFunc(
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& up_gate_proj_weight,
|
||||
const paddle::Tensor& down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor>& up_proj_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||
@@ -483,7 +466,6 @@ paddle::Tensor MoeExpertFFNFunc(
|
||||
tokens_expert_prefix_sum,
|
||||
up_gate_proj_weight,
|
||||
down_proj_weight,
|
||||
up_proj_in_scale,
|
||||
up_gate_proj_bias,
|
||||
up_gate_proj_scale,
|
||||
down_proj_scale,
|
||||
@@ -501,7 +483,6 @@ paddle::Tensor MoeExpertFFNFunc(
|
||||
tokens_expert_prefix_sum,
|
||||
up_gate_proj_weight,
|
||||
down_proj_weight,
|
||||
up_proj_in_scale,
|
||||
up_gate_proj_bias,
|
||||
up_gate_proj_scale,
|
||||
down_proj_scale,
|
||||
@@ -525,7 +506,6 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& up_gate_proj_weight,
|
||||
const paddle::Tensor& down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor>& up_proj_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||
@@ -540,7 +520,6 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
tokens_expert_prefix_sum,
|
||||
up_gate_proj_weight,
|
||||
down_proj_weight,
|
||||
up_proj_in_scale,
|
||||
up_gate_proj_bias,
|
||||
up_gate_proj_scale,
|
||||
down_proj_scale,
|
||||
@@ -558,7 +537,6 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
||||
const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
|
||||
const std::vector<int64_t>& up_gate_proj_weight_shape,
|
||||
const std::vector<int64_t>& down_proj_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& up_proj_in_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
|
||||
@@ -577,7 +555,6 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
|
||||
const paddle::DataType& tokens_expert_prefix_sum_dtype,
|
||||
const paddle::DataType& up_gate_proj_weight_dtype,
|
||||
const paddle::DataType& down_proj_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& up_proj_in_scale_dtype,
|
||||
const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
|
||||
const paddle::optional<paddle::DataType>& down_proj_scale_dtype,
|
||||
@@ -655,7 +632,6 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
|
||||
"tokens_expert_prefix_sum",
|
||||
"up_gate_proj_weight",
|
||||
"down_proj_weight",
|
||||
paddle::Optional("up_proj_in_scale"),
|
||||
paddle::Optional("up_gate_proj_bias"),
|
||||
paddle::Optional("up_gate_proj_scale"),
|
||||
paddle::Optional("down_proj_scale"),
|
||||
|
||||
Reference in New Issue
Block a user