Revert "【New Feature】W4afp8 supports per group quantization (#4272)" (#4854)

This reverts commit 93fcf7e4ec.
This commit is contained in:
YuBaoku
2025-11-06 17:48:28 +08:00
committed by GitHub
parent 3478d20262
commit 819b2dbbae
26 changed files with 1718 additions and 4378 deletions

View File

@@ -304,7 +304,6 @@ paddle::Tensor MoeExpertFFNFunc(
const paddle::Tensor& tokens_expert_prefix_sum, const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight, const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_proj_in_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias, const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale, const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale, const paddle::optional<paddle::Tensor>& down_proj_scale,

View File

@@ -1,37 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "helper.h"
template <typename T, typename OutT>
void MoeFastHardamardWrapper(const T *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const T *shift,
const T *smooth,
const float *quant_scales,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
OutT *out,
cudaStream_t &stream);

File diff suppressed because it is too large Load Diff

View File

@@ -1,34 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fast_hardamard_kernel.hpp"
template void
MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16>(
const phi::dtype::bfloat16 *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const phi::dtype::bfloat16 *shift,
const phi::dtype::bfloat16 *smooth,
const float *quant_scales,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
phi::dtype::bfloat16 *out,
cudaStream_t &stream);

View File

@@ -1,34 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fast_hardamard_kernel.hpp"
template void
MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::float8_e4m3fn>(
const phi::dtype::bfloat16 *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const phi::dtype::bfloat16 *shift,
const phi::dtype::bfloat16 *smooth,
const float *quant_scales,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
phi::dtype::float8_e4m3fn *out,
cudaStream_t &stream);

View File

@@ -1,33 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fast_hardamard_kernel.hpp"
template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
const phi::dtype::bfloat16 *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const phi::dtype::bfloat16 *shift,
const phi::dtype::bfloat16 *smooth,
const float *quant_scales,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
int8_t *out,
cudaStream_t &stream);

View File

@@ -1,33 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fast_hardamard_kernel.hpp"
template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
const phi::dtype::float16 *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const phi::dtype::float16 *shift,
const phi::dtype::float16 *smooth,
const float *quant_scales,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
phi::dtype::float16 *out,
cudaStream_t &stream);

View File

@@ -1,33 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fast_hardamard_kernel.hpp"
template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
const phi::dtype::float16 *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const phi::dtype::float16 *shift,
const phi::dtype::float16 *smooth,
const float *quant_scales,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
int8_t *out,
cudaStream_t &stream);

View File

@@ -25,8 +25,7 @@ template <typename T, int VecSize>
__global__ void moe_token_type_ids_kernel(T *gating_output, __global__ void moe_token_type_ids_kernel(T *gating_output,
const int *moe_token_type_ids_out, const int *moe_token_type_ids_out,
const int num_rows, const int num_rows,
const int num_experts, const int num_experts, const int k) {
const int k) {
const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x; const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x;
if (moe_token_index >= num_rows) { if (moe_token_index >= num_rows) {
@@ -45,8 +44,7 @@ template <typename T>
void moe_token_type_ids_kernelLauncher(T *gating_output, void moe_token_type_ids_kernelLauncher(T *gating_output,
const int *moe_token_type_ids_out, const int *moe_token_type_ids_out,
const int num_rows, const int num_rows,
const int num_experts, const int num_experts, const int k,
const int k,
cudaStream_t stream) { cudaStream_t stream) {
const int blocks = num_rows * k / 512 + 1; const int blocks = num_rows * k / 512 + 1;
const int threads = 512; const int threads = 512;
@@ -54,35 +52,26 @@ void moe_token_type_ids_kernelLauncher(T *gating_output,
gating_output, moe_token_type_ids_out, num_rows, num_experts, k); gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
} }
template <typename T, typename NvType> template <typename T, typename NvType> class MoeHelper {
class MoeHelper {
public: public:
using Fp16Traits = using Fp16Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kNone>;
cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kNone>; using Int8Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kWeightOnlyInt8>;
using Int8Traits = using Int4Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kWeightOnlyInt4>;
cutlass::WintQuantTraits<NvType,
cutlass::WintQuantMethod::kWeightOnlyInt8>;
using Int4Traits =
cutlass::WintQuantTraits<NvType,
cutlass::WintQuantMethod::kWeightOnlyInt4>;
MoeHelper(const std::string gemm_method, MoeHelper(
const std::string gemm_method,
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner, MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner,
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner, MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner,
MoeGemmRunner<NvType, Int4Traits> *int4_moe_gemm_runner, MoeGemmRunner<NvType, Int4Traits> *int4_moe_gemm_runner,
int layernum = 0) int layernum = 0)
: gemm_method_(gemm_method), : gemm_method_(gemm_method), fp16_moe_gemm_runner_(fp16_moe_gemm_runner),
fp16_moe_gemm_runner_(fp16_moe_gemm_runner),
int8_moe_gemm_runner_(int8_moe_gemm_runner), int8_moe_gemm_runner_(int8_moe_gemm_runner),
int4_moe_gemm_runner_(int4_moe_gemm_runner), int4_moe_gemm_runner_(int4_moe_gemm_runner), layernum_(layernum) {}
layernum_(layernum) {}
// -------- getWorkspaceSize -------- // // -------- getWorkspaceSize -------- //
template <typename KeyT> template <typename KeyT>
size_t getWorkspaceSize(const int64_t num_rows, size_t getWorkspaceSize(const int64_t num_rows, const int64_t hidden_size,
const int64_t hidden_size, const int64_t inter_size, const int64_t num_experts,
const int64_t inter_size,
const int64_t num_experts,
const int64_t k) { const int64_t k) {
const size_t buf_size = AlignTo16(k * num_rows * hidden_size); const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size); const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
@@ -126,27 +115,20 @@ class MoeHelper {
return total_ws_bytes; return total_ws_bytes;
} }
void ComputeFFN(const paddle::Tensor *input, void
const paddle::Tensor *gate_weight, ComputeFFN(const paddle::Tensor *input, const paddle::Tensor *gate_weight,
const paddle::Tensor *up_gate_proj_weight, const paddle::Tensor *up_gate_proj_weight,
const paddle::Tensor *up_gate_proj_scale, const paddle::Tensor *up_gate_proj_scale, const paddle::Tensor *up_gate_proj_bias,
const paddle::Tensor *up_gate_proj_bias,
const paddle::Tensor *down_proj_weight, const paddle::Tensor *down_proj_weight,
const paddle::Tensor *down_proj_scale, const paddle::Tensor *down_proj_scale, const paddle::Tensor *down_proj_bias,
const paddle::Tensor *down_proj_bias, const paddle::Tensor *moe_token_type_ids, const int moe_topk,
const paddle::Tensor *moe_token_type_ids, const bool group_moe, const bool norm_topk_prob,
const int moe_topk, const float routed_scaling_factor, const std::string moe_type,
const bool group_moe,
const bool norm_topk_prob,
const float routed_scaling_factor,
const std::string moe_type,
paddle::Tensor *output) { paddle::Tensor *output) {
auto *input_activations = input->data<T>(); auto *input_activations = input->data<T>();
auto *gating_weights = gate_weight->data<float>(); auto *gating_weights = gate_weight->data<float>();
const T *fc1_expert_biases = const T *fc1_expert_biases = up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr;
up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr; const T *fc2_expert_biases = down_proj_bias ? down_proj_bias->data<T>() : nullptr;
const T *fc2_expert_biases =
down_proj_bias ? down_proj_bias->data<T>() : nullptr;
auto *output_ = output->data<T>(); auto *output_ = output->data<T>();
auto stream = input->stream(); auto stream = input->stream();
@@ -166,8 +148,7 @@ class MoeHelper {
const int64_t hidden_size = up_gate_proj_dims[1]; const int64_t hidden_size = up_gate_proj_dims[1];
int64_t inter_dim = 0; int64_t inter_dim = 0;
if (moe_type == "qkv") { if (moe_type == "qkv") {
inter_dim = inter_dim = up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
} else { } else {
inter_dim = up_gate_proj_dims[2]; inter_dim = up_gate_proj_dims[2];
} }
@@ -251,79 +232,44 @@ class MoeHelper {
if (moe_token_type_ids) { if (moe_token_type_ids) {
auto *moe_token_type_ids_out = moe_token_type_ids->data<int>(); auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
moe_token_type_ids_kernelLauncher<float>(gating_output, moe_token_type_ids_kernelLauncher<float>(gating_output,
moe_token_type_ids_out, moe_token_type_ids_out, num_rows,
num_rows, num_experts, k, stream);
num_experts,
k,
stream);
} }
topk_gating_softmax_kernelLauncher<float, int>(gating_output, topk_gating_softmax_kernelLauncher<float, int>(
nullptr, gating_output, nullptr, expert_scales_float, softmax_out_,
expert_scales_float, expert_for_source_row, source_rows_, softmax_max_prob, num_rows,
softmax_out_, num_experts, k, group_moe, stream);
expert_for_source_row,
source_rows_,
softmax_max_prob,
num_rows,
num_experts,
k,
group_moe,
stream);
const int64_t sorter_ws_size_bytes = const int64_t sorter_ws_size_bytes =
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows))); AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
sorter_.run(fc1_result_, sorter_.run(fc1_result_, sorter_ws_size_bytes, expert_for_source_row,
sorter_ws_size_bytes, permuted_experts_, source_rows_, permuted_rows_, k * num_rows,
expert_for_source_row, false, stream);
permuted_experts_,
source_rows_,
permuted_rows_,
k * num_rows,
false,
stream);
initialize_moe_routing_kernelLauncher( initialize_moe_routing_kernelLauncher(
input_activations, input_activations, permuted_data_, permuted_rows_, nullptr, nullptr,
permuted_data_, expanded_source_row_to_expanded_dest_row, num_rows, num_rows,
permuted_rows_, hidden_size, k, stream);
nullptr,
nullptr,
expanded_source_row_to_expanded_dest_row,
nullptr,
num_rows,
num_rows,
hidden_size,
k,
stream);
const int64_t expanded_active_expert_rows = k * num_rows; const int64_t expanded_active_expert_rows = k * num_rows;
compute_total_rows_before_expert(permuted_experts_, compute_total_rows_before_expert(permuted_experts_,
expanded_active_expert_rows, expanded_active_expert_rows, num_experts,
num_experts, total_rows_before_expert_, stream);
total_rows_before_expert_,
stream);
if (gemm_method_ == "weight_only_int8") { if (gemm_method_ == "weight_only_int8") {
typename Int8Traits::Arguments up_gate_proj_quant_args; typename Int8Traits::Arguments up_gate_proj_quant_args;
int8_moe_gemm_runner_->moe_gemm_bias_act( int8_moe_gemm_runner_->moe_gemm_bias_act(
reinterpret_cast<NvType *>(permuted_data_), reinterpret_cast<NvType *>(permuted_data_),
reinterpret_cast<const uint8_t *>( reinterpret_cast<const uint8_t *>(up_gate_proj_weight->data<int8_t>()),
up_gate_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()), reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
reinterpret_cast<const NvType *>(fc1_expert_biases), reinterpret_cast<const NvType *>(fc1_expert_biases),
reinterpret_cast<NvType *>(fc1_out), reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
total_rows_before_expert_,
-1, // useless -1, // useless
expanded_active_expert_rows, expanded_active_expert_rows, inter_size, hidden_size, num_experts,
inter_size, up_gate_proj_quant_args, "none", stream);
hidden_size,
num_experts,
up_gate_proj_quant_args,
"none",
stream);
} else if (gemm_method_ == "weight_only_int4") { } else if (gemm_method_ == "weight_only_int4") {
typename Int4Traits::Arguments up_gate_proj_quant_args; typename Int4Traits::Arguments up_gate_proj_quant_args;
int4_moe_gemm_runner_->moe_gemm_bias_act( int4_moe_gemm_runner_->moe_gemm_bias_act(
@@ -332,33 +278,20 @@ class MoeHelper {
up_gate_proj_weight->data<int8_t>()), up_gate_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()), reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
reinterpret_cast<const NvType *>(fc1_expert_biases), reinterpret_cast<const NvType *>(fc1_expert_biases),
reinterpret_cast<NvType *>(fc1_out), reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
total_rows_before_expert_,
-1, // useless -1, // useless
expanded_active_expert_rows, expanded_active_expert_rows, inter_size, hidden_size, num_experts,
inter_size, up_gate_proj_quant_args, "none", stream);
hidden_size,
num_experts,
up_gate_proj_quant_args,
"none",
stream);
} else { } else {
typename Fp16Traits::Arguments up_gate_proj_quant_args; typename Fp16Traits::Arguments up_gate_proj_quant_args;
fp16_moe_gemm_runner_->moe_gemm_bias_act( fp16_moe_gemm_runner_->moe_gemm_bias_act(
reinterpret_cast<NvType *>(permuted_data_), reinterpret_cast<NvType *>(permuted_data_),
reinterpret_cast<const NvType *>(up_gate_proj_weight->data<T>()), reinterpret_cast<const NvType *>(up_gate_proj_weight->data<T>()), nullptr,
nullptr,
reinterpret_cast<const NvType *>(fc1_expert_biases), reinterpret_cast<const NvType *>(fc1_expert_biases),
reinterpret_cast<NvType *>(fc1_out), reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
total_rows_before_expert_,
-1, // useless -1, // useless
expanded_active_expert_rows, expanded_active_expert_rows, inter_size, hidden_size, num_experts,
inter_size, up_gate_proj_quant_args, "none", stream);
hidden_size,
num_experts,
up_gate_proj_quant_args,
"none",
stream);
} }
if (moe_type == "ffn") { if (moe_type == "ffn") {
@@ -376,15 +309,10 @@ class MoeHelper {
reinterpret_cast<NvType *>(act_out), reinterpret_cast<NvType *>(act_out),
reinterpret_cast<const uint8_t *>(down_proj_weight->data<int8_t>()), reinterpret_cast<const uint8_t *>(down_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()), reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
reinterpret_cast<NvType *>(fc2_result), reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
total_rows_before_expert_,
-1, // useless -1, // useless
expanded_active_expert_rows, expanded_active_expert_rows, hidden_size, inter_size / 2,
hidden_size, num_experts, down_proj_quant_args, stream);
inter_size / 2,
num_experts,
down_proj_quant_args,
stream);
} else if (gemm_method_ == "weight_only_int4") { } else if (gemm_method_ == "weight_only_int4") {
typename Int4Traits::Arguments down_proj_quant_args; typename Int4Traits::Arguments down_proj_quant_args;
int4_moe_gemm_runner_->moe_gemm( int4_moe_gemm_runner_->moe_gemm(
@@ -392,62 +320,36 @@ class MoeHelper {
reinterpret_cast<const cutlass::uint4b_t *>( reinterpret_cast<const cutlass::uint4b_t *>(
down_proj_weight->data<int8_t>()), down_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()), reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
reinterpret_cast<NvType *>(fc2_result), reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
total_rows_before_expert_,
-1, // useless -1, // useless
expanded_active_expert_rows, expanded_active_expert_rows, hidden_size, inter_size / 2,
hidden_size, num_experts, down_proj_quant_args, stream);
inter_size / 2,
num_experts,
down_proj_quant_args,
stream);
} else { } else {
typename Fp16Traits::Arguments down_proj_quant_args; typename Fp16Traits::Arguments down_proj_quant_args;
fp16_moe_gemm_runner_->moe_gemm( fp16_moe_gemm_runner_->moe_gemm(
reinterpret_cast<NvType *>(act_out), reinterpret_cast<NvType *>(act_out),
reinterpret_cast<const NvType *>(down_proj_weight->data<T>()), reinterpret_cast<const NvType *>(down_proj_weight->data<T>()), nullptr,
nullptr, reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
reinterpret_cast<NvType *>(fc2_result),
total_rows_before_expert_,
-1, // useless -1, // useless
expanded_active_expert_rows, expanded_active_expert_rows, hidden_size, inter_size / 2,
hidden_size, num_experts, down_proj_quant_args, stream);
inter_size / 2,
num_experts,
down_proj_quant_args,
stream);
} }
finalize_moe_routing_kernelLauncher( finalize_moe_routing_kernelLauncher(
fc2_result, fc2_result, output_, fc2_expert_biases,
output_,
fc2_expert_biases,
reinterpret_cast<float *>(expert_scales_float), reinterpret_cast<float *>(expert_scales_float),
expanded_source_row_to_expanded_dest_row, expanded_source_row_to_expanded_dest_row, expert_for_source_row,
expert_for_source_row, num_rows, hidden_size, k, static_cast<int>(1), norm_topk_prob,
num_rows, routed_scaling_factor, stream);
hidden_size,
k,
static_cast<int>(1),
norm_topk_prob,
routed_scaling_factor,
stream);
} else { } else {
finalize_moe_routing_kernelLauncher( finalize_moe_routing_kernelLauncher(
// fc2_result, // fc2_result,
fc1_out, fc1_out, output_,
output_,
fc1_expert_biases, // fc2_expert_biases, fc1_expert_biases, // fc2_expert_biases,
reinterpret_cast<float *>(expert_scales_float), reinterpret_cast<float *>(expert_scales_float),
expanded_source_row_to_expanded_dest_row, expanded_source_row_to_expanded_dest_row, expert_for_source_row,
expert_for_source_row, num_rows, inter_size, k, static_cast<int>(0), norm_topk_prob,
num_rows, routed_scaling_factor, stream);
inter_size,
k,
static_cast<int>(0),
norm_topk_prob,
routed_scaling_factor,
stream);
} }
} }

View File

@@ -19,9 +19,9 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include "cutlass/numeric_conversion.h"
#include "moe/fused_moe_helper.h"
#include "moe/fused_moe_imp_op.h" #include "moe/fused_moe_imp_op.h"
#include "moe/fused_moe_helper.h"
#include "cutlass/numeric_conversion.h"
// Ignore CUTLASS warnings about type punning // Ignore CUTLASS warnings about type punning
#pragma GCC diagnostic push #pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing" #pragma GCC diagnostic ignored "-Wstrict-aliasing"
@@ -59,12 +59,14 @@ inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
constexpr static int FINALIZE_THREADS_PER_BLOCK = 256; constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;
template <class T, class U> template <class T, class U>
__host__ __device__ constexpr static U arrayConvert(T const& input) { __host__ __device__ constexpr static U arrayConvert(T const& input)
{
using Type = typename U::Element; using Type = typename U::Element;
static_assert(T::kElements == U::kElements); static_assert(T::kElements == U::kElements);
U u; U u;
#pragma unroll #pragma unroll
for (int i = 0; i < U::kElements; i++) { for (int i = 0; i < U::kElements; i++)
{
u[i] = static_cast<Type>(input[i]); u[i] = static_cast<Type>(input[i]);
} }
return u; return u;
@@ -75,8 +77,7 @@ struct uint8 {
uint4 v; uint4 v;
}; };
template <int BYTES> template<int BYTES> struct BytesToType {};
struct BytesToType {};
template<> template<>
struct BytesToType<32> { struct BytesToType<32> {
@@ -84,32 +85,27 @@ struct BytesToType<32> {
static_assert(sizeof(Type) == 32); static_assert(sizeof(Type) == 32);
}; };
template <> template<> struct BytesToType<16> {
struct BytesToType<16> {
using Type = uint4; using Type = uint4;
static_assert(sizeof(Type) == 16); static_assert(sizeof(Type) == 16);
}; };
template <> template<> struct BytesToType<8> {
struct BytesToType<8> {
using Type = uint64_t; using Type = uint64_t;
static_assert(sizeof(Type) == 8); static_assert(sizeof(Type) == 8);
}; };
template <> template<> struct BytesToType<4> {
struct BytesToType<4> {
using Type = uint32_t; using Type = uint32_t;
static_assert(sizeof(Type) == 4); static_assert(sizeof(Type) == 4);
}; };
template <> template<> struct BytesToType<2> {
struct BytesToType<2> {
using Type = uint16_t; using Type = uint16_t;
static_assert(sizeof(Type) == 2); static_assert(sizeof(Type) == 2);
}; };
template <> template<> struct BytesToType<1> {
struct BytesToType<1> {
using Type = uint8_t; using Type = uint8_t;
static_assert(sizeof(Type) == 1); static_assert(sizeof(Type) == 1);
}; };
@@ -129,23 +125,7 @@ __inline__ __device__ T BlockAllReduce(T val) {
template <typename T> template <typename T>
struct SumOp { struct SumOp {
__device__ __forceinline__ T operator()(T const& x, T const& y) { __device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; }
return x + y;
}
};
template <typename T>
struct MaxOp {
__device__ inline T operator()(T const& x, T const& y) {
return x > y ? x : y;
}
};
template <>
struct MaxOp<float> {
__device__ inline float operator()(float const& x, float const& y) {
return fmax(x, y);
}
}; };
template <typename InType, typename OutType> template <typename InType, typename OutType>
@@ -154,17 +134,18 @@ __forceinline__ __device__ OutType QuantHelperFunc(const InType input,
const float max_bound, const float max_bound,
const float min_bound) { const float min_bound) {
float quant_value = max_bound * scale * static_cast<float>(input); float quant_value = max_bound * scale * static_cast<float>(input);
return static_cast<OutType>( return static_cast<OutType>(ClipFunc<float>(quant_value, min_bound, max_bound));
ClipFunc<float>(quant_value, min_bound, max_bound));
} }
template <typename T, typename OutT, int VecSize, int Kthread> template <typename T, typename OutT, int VecSize, int Kthread>
__global__ void masked_quantize_moe_input_kernel( __global__ void masked_quantize_moe_input_kernel(const T* permuted_inputs,
const T* permuted_inputs,
const int64_t* expert_idx_per_token, const int64_t* expert_idx_per_token,
const float* quant_scales,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num, const int64_t token_num,
const int64_t dim, const int64_t dim,
float* input_dequant_scale, float* permuted_input_row_sum,
const int64_t* recv_expert_count, const int64_t* recv_expert_count,
const int num_max_tokens_per_expert, const int num_max_tokens_per_expert,
OutT* out) { OutT* out) {
@@ -172,54 +153,44 @@ __global__ void masked_quantize_moe_input_kernel(
using LoadOutT = AlignedVector<OutT, VecSize>; using LoadOutT = AlignedVector<OutT, VecSize>;
LoadT input_vec; LoadT input_vec;
LoadOutT output_vec; LoadOutT output_vec;
float scale_factor = -7.0f / 512.0f;
using vec_t = typename BytesToType<sizeof(OutT) * VecSize>::Type; using vec_t = typename BytesToType<sizeof(OutT) * VecSize>::Type;
extern __shared__ char smem_[]; for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
for (int token_idx = blockIdx.x; token_idx < token_num;
token_idx += gridDim.x) {
const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert; const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert;
const auto expert_id = token_idx / num_max_tokens_per_expert; const auto expert_id = token_idx / num_max_tokens_per_expert;
if (token_idx_in_expert >= recv_expert_count[expert_id]) { if (token_idx_in_expert >= recv_expert_count[expert_id]) {
auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert; auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
auto num_iters_to_next_expert = auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x;
(next_expert_start_idx - token_idx - 1) / gridDim.x;
token_idx += num_iters_to_next_expert * gridDim.x; token_idx += num_iters_to_next_expert * gridDim.x;
continue; continue;
} }
int64_t expert_idx = expert_idx_per_token[token_idx]; int64_t expert_idx = expert_idx_per_token[token_idx];
float abs_max = 0.0f; float quant_scale = quant_scales[expert_idx];
float thread_row_sum = 0.0f;
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
int64_t offset = token_idx * dim + idx * VecSize; int64_t offset = token_idx * dim + idx * VecSize;
Load<T, VecSize>(&permuted_inputs[offset], &input_vec); Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
#pragma unroll #pragma unroll
for (int i = 0; i < VecSize; i++) { for (int i = 0; i < VecSize; i++) {
float res = static_cast<float>(input_vec[i]); output_vec[i] = QuantHelperFunc<T, OutT>(input_vec[i], quant_scale, quant_max_bound, quant_min_bound);
abs_max = fmax(abs_max, fabs(res)); thread_row_sum += static_cast<float>(output_vec[i]);
} }
Store<T, VecSize>(input_vec, reinterpret_cast<T*>(smem_) + idx * VecSize); *(reinterpret_cast<vec_t*>(&out[offset])) = *(reinterpret_cast<const vec_t*>(&output_vec));
}
abs_max = BlockAllReduce<MaxOp, float, Kthread>(abs_max);
input_dequant_scale[token_idx] = abs_max;
float quant_scale = 440.0f / abs_max;
for (int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
int64_t offset = token_idx * dim + idx * VecSize;
Load<T, VecSize>(reinterpret_cast<T*>(smem_) + idx * VecSize, &input_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
float res = static_cast<float>(input_vec[i]);
output_vec[i] = static_cast<OutT>(res * quant_scale);
}
*(reinterpret_cast<vec_t*>(&out[offset])) =
*(reinterpret_cast<const vec_t*>(&output_vec));
} }
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
} }
} }
template <typename T, typename OutT, int VecSize, int Kthread> template <typename T, typename OutT, int VecSize, int Kthread>
__global__ void quantize_moe_input_kernel(const T* permuted_inputs, __global__ void quantize_moe_input_kernel(const T* permuted_inputs,
const int64_t* expert_idx_per_token, const int64_t* expert_idx_per_token,
const float* quant_scales,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num, const int64_t token_num,
const int64_t dim, const int64_t dim,
float* input_dequant_scale, float* permuted_input_row_sum,
const int64_t* recv_expert_count, const int64_t* recv_expert_count,
const int num_max_tokens_per_expert, const int num_max_tokens_per_expert,
OutT* out) { OutT* out) {
@@ -228,47 +199,36 @@ __global__ void quantize_moe_input_kernel(const T* permuted_inputs,
LoadT input_vec; LoadT input_vec;
LoadOutT output_vec; LoadOutT output_vec;
using vec_t = typename BytesToType<sizeof(OutT) * VecSize>::Type; using vec_t = typename BytesToType<sizeof(OutT) * VecSize>::Type;
float scale_factor = -7.0f / 512.0f;
extern __shared__ char smem_[]; for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
for (int token_idx = blockIdx.x; token_idx < token_num;
token_idx += gridDim.x) {
int64_t expert_idx = expert_idx_per_token[token_idx]; int64_t expert_idx = expert_idx_per_token[token_idx];
float abs_max = 0.0f; float quant_scale = quant_scales[expert_idx];
float thread_row_sum = 0.0f;
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
int64_t offset = token_idx * dim + idx * VecSize; int64_t offset = token_idx * dim + idx * VecSize;
Load<T, VecSize>(&permuted_inputs[offset], &input_vec); Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
#pragma unroll #pragma unroll
for (int i = 0; i < VecSize; i++) { for (int i = 0; i < VecSize; i++) {
float res = static_cast<float>(input_vec[i]); output_vec[i] = QuantHelperFunc<T, OutT>(input_vec[i], quant_scale, quant_max_bound, quant_min_bound);
abs_max = fmax(abs_max, fabs(res)); thread_row_sum += static_cast<float>(output_vec[i]);
} }
Store<T, VecSize>(input_vec, reinterpret_cast<T*>(smem_) + idx * VecSize); *(reinterpret_cast<vec_t*>(&out[offset])) = *(reinterpret_cast<const vec_t*>(&output_vec));
}
abs_max = BlockAllReduce<MaxOp, float, Kthread>(abs_max);
input_dequant_scale[token_idx] = abs_max;
float quant_scale = 440.0f / abs_max;
for (int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
int64_t offset = token_idx * dim + idx * VecSize;
Load<T, VecSize>(reinterpret_cast<T*>(smem_) + idx * VecSize, &input_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
float res = static_cast<float>(input_vec[i]);
output_vec[i] = static_cast<OutT>(res * quant_scale);
}
*(reinterpret_cast<vec_t*>(&out[offset])) =
*(reinterpret_cast<const vec_t*>(&output_vec));
} }
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
} }
} }
template <typename T, typename OutT> template <typename T, typename OutT>
void quantize_moe_input(const T* permuted_inputs, void quantize_moe_input(
const T* permuted_inputs,
const int64_t* expert_idx_per_token, const int64_t* expert_idx_per_token,
const float* quant_scales,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num, const int64_t token_num,
const int64_t dim, const int64_t dim,
float* input_quant_scale, float* permuted_input_row_sum,
const int64_t* recv_expert_count, const int64_t* recv_expert_count,
const int num_max_tokens_per_expert, const int num_max_tokens_per_expert,
bool used_in_ep_low_latency, bool used_in_ep_low_latency,
@@ -281,34 +241,118 @@ void quantize_moe_input(const T* permuted_inputs,
int act_blocks_per_sm; int act_blocks_per_sm;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
assert(dim % VecSize == 0); assert(dim % VecSize == 0);
auto kernel = auto kernel = used_in_ep_low_latency ? masked_quantize_moe_input_kernel<T, OutT, VecSize, threads_per_block> : quantize_moe_input_kernel<T, OutT, VecSize, threads_per_block>;
used_in_ep_low_latency
? masked_quantize_moe_input_kernel<T,
OutT,
VecSize,
threads_per_block>
: quantize_moe_input_kernel<T, OutT, VecSize, threads_per_block>;
cudaOccupancyMaxActiveBlocksPerMultiprocessor( cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, threads_per_block, 0); &act_blocks_per_sm, kernel, threads_per_block, 0);
const int num_blocks_per_wave = sm_count * act_blocks_per_sm; const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
dim3 grid; dim3 grid;
grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num); grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
const int smem_size = dim * sizeof(T); kernel<<<grid, threads_per_block, 0, stream>>>(
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
kernel<<<grid, threads_per_block, smem_size, stream>>>(
permuted_inputs, permuted_inputs,
expert_idx_per_token, expert_idx_per_token,
quant_scales,
quant_max_bound,
quant_min_bound,
token_num, token_num,
dim, dim,
input_quant_scale, permuted_input_row_sum,
recv_expert_count, recv_expert_count,
num_max_tokens_per_expert, num_max_tokens_per_expert,
out); out);
} }
template <typename T, int VecSize, int Kthread>
__global__ void masked_compute_row_sum_kernel(
const T* permuted_inputs,
const int64_t token_num,
const int64_t dim,
float* permuted_input_row_sum,
const int64_t* recv_expert_count,
const int num_max_tokens_per_expert) {
using LoadT = AlignedVector<T, VecSize>;
LoadT input_vec;
float scale_factor = -7.0f / 512.0f;
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert;
const auto expert_id = token_idx / num_max_tokens_per_expert;
if (token_idx_in_expert >= recv_expert_count[expert_id]) {
auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x;
token_idx += num_iters_to_next_expert * gridDim.x;
continue;
}
float thread_row_sum = 0.0f;
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
int64_t offset = token_idx * dim + idx * VecSize;
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
thread_row_sum += static_cast<float>(input_vec[i]);
}
}
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
}
}
template <typename T, int VecSize, int Kthread>
__global__ void compute_row_sum_kernel(
const T* permuted_inputs,
const int64_t token_num,
const int64_t dim,
float* permuted_input_row_sum,
const int64_t* recv_expert_count,
const int num_max_tokens_per_expert) {
using LoadT = AlignedVector<T, VecSize>;
LoadT input_vec;
float scale_factor = -7.0f / 512.0f;
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
float thread_row_sum = 0.0f;
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
int64_t offset = token_idx * dim + idx * VecSize;
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
thread_row_sum += static_cast<float>(input_vec[i]);
}
}
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
}
}
template <typename T>
void compute_row_sum(
const T* permuted_inputs,
const int64_t token_num,
const int64_t dim,
float* permuted_input_row_sum,
const int64_t* recv_expert_count,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
cudaStream_t stream) {
constexpr int VecSize = 16 / sizeof(T);
constexpr int threads_per_block = 128;
const int dev_id = 0;
int sm_count;
int act_blocks_per_sm;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
assert(dim % VecSize == 0);
auto kernel = used_in_ep_low_latency ? masked_compute_row_sum_kernel<T, VecSize, threads_per_block> : compute_row_sum_kernel<T, VecSize, threads_per_block>;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, threads_per_block, 0);
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
dim3 grid;
grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
kernel<<<grid, threads_per_block, 0, stream>>>(
permuted_inputs,
token_num,
dim,
permuted_input_row_sum,
recv_expert_count,
num_max_tokens_per_expert);
}
// ====================== Softmax things =============================== // ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing // We have our own implementation of softmax here so we can support transposing
// the output in the softmax kernel when we extend this module to support // the output in the softmax kernel when we extend this module to support
@@ -336,6 +380,7 @@ __launch_bounds__(TPB) __global__
cub::Sum sum; cub::Sum sum;
float threadData(-FLT_MAX); float threadData(-FLT_MAX);
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii; const int idx = thread_row_offset + ii;
threadData = max(static_cast<float>(input[idx]), threadData); threadData = max(static_cast<float>(input[idx]), threadData);
@@ -440,8 +485,7 @@ __launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
} }
template <typename T, int TPB, typename IdxT = int> template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ __launch_bounds__(TPB) __global__ void group_moe_top_k(const T* inputs_after_softmax,
void group_moe_top_k(const T* inputs_after_softmax,
T* output, T* output,
IdxT* indices, IdxT* indices,
int* source_rows, int* source_rows,
@@ -537,8 +581,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert; const int idx = thread_read_offset + expert;
inp_kvp.key = expert; inp_kvp.key = expert;
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
: inputs_after_softmax[idx];
for (int prior_k = 0; prior_k < k_idx; ++prior_k) { for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const int prior_winning_expert = indices[k * block_row + prior_k]; const int prior_winning_expert = indices[k * block_row + prior_k];
@@ -559,15 +602,12 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
source_rows[idx] = k_idx * num_rows + block_row; source_rows[idx] = k_idx * num_rows + block_row;
if constexpr (NormWeights){ if constexpr (NormWeights){
T row_out = T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]
: result_kvp.value;
row_outputs[k_idx] = row_out; row_outputs[k_idx] = row_out;
weight_sum += row_out; weight_sum += row_out;
} else { }
output[idx] = else{
bias ? inputs_after_softmax[thread_read_offset + result_kvp.key] output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
: result_kvp.value;
} }
} }
__syncthreads(); __syncthreads();
@@ -577,15 +617,13 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0); weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
} }
if (threadIdx.x < k) { if (threadIdx.x < k) {
output[k * block_row + threadIdx.x] = output[k * block_row + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
row_outputs[threadIdx.x] / weight_sum;
} }
} }
} }
template <typename T, int TPB, bool NormWeights = false, typename IdxT = int> template <typename T, int TPB, bool NormWeights = false, typename IdxT = int>
__launch_bounds__(TPB) __global__ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
void moe_softmax_top_k_fused(const T* input,
const T* bias, const T* bias,
T* output, T* output,
IdxT* indices, IdxT* indices,
@@ -609,8 +647,7 @@ __launch_bounds__(TPB) __global__
cub::Sum sum; cub::Sum sum;
float threadData = float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) :(-FLT_MAX);
(threadIdx.x < num_experts) ? static_cast<float>(input[idx]) : (-FLT_MAX);
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
@@ -675,13 +712,12 @@ __launch_bounds__(TPB) __global__
source_rows[cur_idx] = k_idx * num_rows + globalIdx; source_rows[cur_idx] = k_idx * num_rows + globalIdx;
if constexpr (NormWeights) { if constexpr (NormWeights) {
T row_out = T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
row_outputs[k_idx] = row_out; row_outputs[k_idx] = row_out;
weight_sum += row_out; weight_sum += row_out;
} else { }
output[cur_idx] = else {
bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value; output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
} }
} }
__syncthreads(); __syncthreads();
@@ -692,8 +728,7 @@ __launch_bounds__(TPB) __global__
} }
if (threadIdx.x < k) { if (threadIdx.x < k) {
output[k * globalIdx + threadIdx.x] = output[k * globalIdx + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
row_outputs[threadIdx.x] / weight_sum;
} }
} }
} }
@@ -706,8 +741,7 @@ inline __device__ unsigned int xorwow_moe(unsigned int& state) {
} }
template <typename T, int TPB, typename IdxT = int> template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ __launch_bounds__(TPB) __global__ void moe_redundant_top_k_normed(const T* inputs_after_softmax,
void moe_redundant_top_k_normed(const T* inputs_after_softmax,
const T* bias, const T* bias,
const int* expert_id_to_ep_rank_array, const int* expert_id_to_ep_rank_array,
const int* expert_in_rank_num_list, const int* expert_in_rank_num_list,
@@ -728,8 +762,7 @@ __launch_bounds__(TPB) __global__
cub::ArgMax arg_max; cub::ArgMax arg_max;
const int block_row = blockIdx.x + blockIdx.y * gridDim.x; const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
// unsigned int state = block_row + blockIdx.x * blockDim.x + // unsigned int state = block_row + blockIdx.x * blockDim.x + *kernel_call_num;
// *kernel_call_num;
unsigned int state = block_row + blockIdx.x * blockDim.x; unsigned int state = block_row + blockIdx.x * blockDim.x;
if (block_row >= num_rows) { if (block_row >= num_rows) {
@@ -752,8 +785,7 @@ __launch_bounds__(TPB) __global__
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert; const int idx = thread_read_offset + expert;
inp_kvp.key = expert; inp_kvp.key = expert;
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
: inputs_after_softmax[idx];
for (int prior_k = 0; prior_k < k_idx; ++prior_k) { for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const int prior_winning_expert = indices_tmp[k * block_row + prior_k]; const int prior_winning_expert = indices_tmp[k * block_row + prior_k];
@@ -770,26 +802,20 @@ __launch_bounds__(TPB) __global__
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
const int idx = k * block_row + k_idx; const int idx = k * block_row + k_idx;
// output[idx] = bias ? inputs_after_softmax[thread_read_offset + // output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
// result_kvp.key]: result_kvp.value;
source_rows[idx] = k_idx * num_rows + block_row; source_rows[idx] = k_idx * num_rows + block_row;
int expert_topk = should_process_row ? result_kvp.key : num_experts; int expert_topk = should_process_row ? result_kvp.key : num_experts;
// runduncy // runduncy
int len = expert_in_rank_num_list[expert_topk]; int len = expert_in_rank_num_list[expert_topk];
int select = (int)xorwow_moe(state) % len; int select = (int)xorwow_moe(state) % len;
int selected_rank = int selected_rank = expert_id_to_ep_rank_array[expert_topk * redundant_ep_rank_num_plus_one + select];
expert_id_to_ep_rank_array[expert_topk *
redundant_ep_rank_num_plus_one +
select];
indices[idx] = (IdxT)selected_rank; indices[idx] = (IdxT)selected_rank;
indices_tmp[idx] = result_kvp.key; indices_tmp[idx] = result_kvp.key;
atomicAdd(&tokens_per_expert_stats_list[result_kvp.key], 1); atomicAdd(&tokens_per_expert_stats_list[result_kvp.key], 1);
T row_out = T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]
: result_kvp.value;
row_outputs[k_idx] = row_out; row_outputs[k_idx] = row_out;
weight_sum += row_out; weight_sum += row_out;
} }
@@ -972,9 +998,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll #pragma unroll
for (int ii = 0; ii < VPT; ++ii) { for (int ii = 0; ii < VPT; ++ii) {
row_chunk[ii] = bias ? row_chunk[ii] * reciprocal_row_sum + row_chunk[ii] = bias ? row_chunk[ii] * reciprocal_row_sum + bias[first_elt_read_by_thread + ii] : row_chunk[ii] * reciprocal_row_sum;
bias[first_elt_read_by_thread + ii]
: row_chunk[ii] * reciprocal_row_sum;
} }
// Now, softmax_res contains the softmax of the row chunk. Now, I want to find // Now, softmax_res contains the softmax of the row chunk. Now, I want to find
@@ -1033,7 +1057,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
const int idx_in_cta = k * thread_row_in_cta + k_idx; const int idx_in_cta = k * thread_row_in_cta + k_idx;
row_output[idx_in_cta] = final_val; row_output[idx_in_cta] = final_val;
weight_sum += final_val; weight_sum += final_val;
} else { }
else {
output[idx] = final_val; output[idx] = final_val;
} }
indices[idx] = should_process_row ? expert : NUM_EXPERTS; indices[idx] = should_process_row ? expert : NUM_EXPERTS;
@@ -1087,11 +1112,7 @@ struct TopkConstants {
}; };
} // namespace detail } // namespace detail
template <typename T, template <typename T, int EXPERTS, int WARPS_PER_TB, bool Norm_Weights = false, typename IdxT = int>
int EXPERTS,
int WARPS_PER_TB,
bool Norm_Weights = false,
typename IdxT = int>
void topk_gating_softmax_launcher_helper(const T* input, void topk_gating_softmax_launcher_helper(const T* input,
const T* bias, const T* bias,
T* output, T* output,
@@ -1112,12 +1133,7 @@ void topk_gating_softmax_launcher_helper(const T* input,
dim3 block_dim(WARP_SIZE, WARPS_PER_TB); dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
static constexpr int ROWS_PER_CTA = WARPS_PER_TB * ROWS_PER_WARP; static constexpr int ROWS_PER_CTA = WARPS_PER_TB * ROWS_PER_WARP;
topk_gating_softmax<T, topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, Norm_Weights>
VPT,
EXPERTS,
WARPS_PER_TB,
BYTES_PER_LDG,
Norm_Weights>
<<<num_blocks, block_dim, ROWS_PER_CTA * k * sizeof(T), stream>>>( <<<num_blocks, block_dim, ROWS_PER_CTA * k * sizeof(T), stream>>>(
input, bias, output, num_rows, indices, source_row, k); input, bias, output, num_rows, indices, source_row, k);
} }
@@ -1139,15 +1155,8 @@ void topk_gating_softmax_kernelLauncher(const T* input,
if (topk_only_mode) { if (topk_only_mode) {
static constexpr int TPB = 256; static constexpr int TPB = 256;
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows); const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
moe_top_k<T, TPB> moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
<<<config_topk.block_per_grid, TPB, 0, stream>>>(input, input, gating_correction_bias, output, indices, source_row, num_experts, k, num_rows);
gating_correction_bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
return; return;
} }
static constexpr int WARPS_PER_TB = 4; static constexpr int WARPS_PER_TB = 4;
@@ -1155,15 +1164,7 @@ void topk_gating_softmax_kernelLauncher(const T* input,
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \ #define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \ case N: { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \ topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
input, \ input, gating_correction_bias, output, indices, source_row, num_rows, num_experts, k, stream); \
gating_correction_bias, \
output, \
indices, \
source_row, \
num_rows, \
num_experts, \
k, \
stream); \
break; \ break; \
} }
int64_t tem_num_experts = num_experts; int64_t tem_num_experts = num_experts;
@@ -1205,8 +1206,8 @@ void topk_gating_softmax_kernelLauncher(const T* input,
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows); const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>( moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
input, softmax, num_experts, num_rows); input, softmax, num_experts, num_rows);
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>( moe_top_k<T, TPB>
softmax, <<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
gating_correction_bias, gating_correction_bias,
output, output,
indices, indices,
@@ -1236,7 +1237,7 @@ void topk_gating_softmax_kernelLauncher(const T* input,
// to row 0 in the original matrix. Thus, to know where to read in the source // to row 0 in the original matrix. Thus, to know where to read in the source
// matrix, we simply take the modulus of the expanded index. // matrix, we simply take the modulus of the expanded index.
template <typename T, int VecSize, int Kthread, typename OutT = T> template <typename T, int VecSize, typename OutT=T>
__global__ void initialize_moe_routing_kernel( __global__ void initialize_moe_routing_kernel(
const T* unpermuted_input, const T* unpermuted_input,
OutT* permuted_output, OutT* permuted_output,
@@ -1244,7 +1245,6 @@ __global__ void initialize_moe_routing_kernel(
const int *expert_idx_per_token, const int *expert_idx_per_token,
const float *w4a8_in_scale, const float *w4a8_in_scale,
int* expanded_source_row_to_expanded_dest_row, int* expanded_source_row_to_expanded_dest_row,
float* dequant_scale,
const int64_t num_rows, const int64_t num_rows,
const int64_t active_rows, const int64_t active_rows,
const int64_t cols, const int64_t cols,
@@ -1266,50 +1266,15 @@ __global__ void initialize_moe_routing_kernel(
expanded_dest_row; expanded_dest_row;
} }
extern __shared__ char smem_[];
T* data_smem = reinterpret_cast<T*>(smem_);
if (expanded_dest_row < active_rows) { if (expanded_dest_row < active_rows) {
const int expert_idx = expert_idx_per_token[expanded_dest_row]; const int expert_idx = expert_idx_per_token[expanded_dest_row];
float scale; const float scale = w4a8_in_scale ? w4a8_in_scale[expert_idx] : -1;
const int source_row = expanded_source_row % num_rows; const int source_row = expanded_source_row % num_rows;
const T* source_row_ptr = unpermuted_input + source_row * cols; const T* source_row_ptr = unpermuted_input + source_row * cols;
OutT *dest_row_ptr = permuted_output + expanded_dest_row * cols; OutT *dest_row_ptr = permuted_output + expanded_dest_row * cols;
if constexpr (std::is_same<OutT, phi::dtype::float8_e4m3fn>::value) {
if (dequant_scale != nullptr) {
float abs_max = 0.f;
for (int tid = threadIdx.x * VecSize; tid < cols;
tid += blockDim.x * VecSize) {
Load<T, VecSize>(&source_row_ptr[tid], &src_vec);
Store<T, VecSize>(src_vec, &data_smem[tid]);
for (int j = 0; j < VecSize; j++) {
abs_max = fmaxf(abs_max, fabsf(static_cast<float>(src_vec[j])));
}
}
abs_max = BlockAllReduce<MaxOp, float, Kthread>(abs_max);
scale = 440.0f / abs_max;
dequant_scale[expanded_dest_row] = abs_max;
for (int tid = threadIdx.x * VecSize; tid < cols;
tid += blockDim.x * VecSize) {
Load<T, VecSize>(&data_smem[tid], &src_vec);
using StoreT = AlignedVector<OutT, VecSize>;
StoreT dest_vec;
for (int j = 0; j < VecSize; j++) {
float quant_value = scale * static_cast<float>(src_vec[j]);
dest_vec[j] = static_cast<OutT>(quant_value);
}
Store<OutT, VecSize>(dest_vec, &dest_row_ptr[tid]);
}
return;
} else {
scale = w4a8_in_scale ? w4a8_in_scale[expert_idx] : -1;
}
} else {
scale = w4a8_in_scale ? w4a8_in_scale[expert_idx] : -1;
}
for (int tid = threadIdx.x * VecSize; tid < cols; for (int tid = threadIdx.x * VecSize; tid < cols;
tid += blockDim.x * VecSize) { tid += blockDim.x * VecSize) {
// dest_row_ptr[tid] = source_row_ptr[tid]; // dest_row_ptr[tid] = source_row_ptr[tid];
@@ -1328,15 +1293,13 @@ __global__ void initialize_moe_routing_kernel(
dest_vec[j] = static_cast<int8_t>(round(quant_value)); dest_vec[j] = static_cast<int8_t>(round(quant_value));
} }
Store<OutT, VecSize>(dest_vec, &dest_row_ptr[tid]); Store<OutT, VecSize>(dest_vec, &dest_row_ptr[tid]);
} else if constexpr (std::is_same<OutT, } else if constexpr (std::is_same<OutT, phi::dtype::float8_e4m3fn>::value) {
phi::dtype::float8_e4m3fn>::value) {
using StoreT = AlignedVector<OutT, VecSize>; using StoreT = AlignedVector<OutT, VecSize>;
StoreT dest_vec; StoreT dest_vec;
const float max_bound = 448.f; const float max_bound = 448.f;
const float min_bound = -448.f; const float min_bound = -448.f;
for (int j = 0; j < VecSize; j++) { for (int j = 0; j < VecSize; j++) {
float quant_value = float quant_value = max_bound * scale * static_cast<float>(src_vec[j]);
max_bound * scale * static_cast<float>(src_vec[j]);
quant_value = quant_value > max_bound ? max_bound : quant_value; quant_value = quant_value > max_bound ? max_bound : quant_value;
quant_value = quant_value < min_bound ? min_bound : quant_value; quant_value = quant_value < min_bound ? min_bound : quant_value;
dest_vec[j] = static_cast<phi::dtype::float8_e4m3fn>(quant_value); dest_vec[j] = static_cast<phi::dtype::float8_e4m3fn>(quant_value);
@@ -1357,36 +1320,41 @@ void initialize_moe_routing_kernelLauncher(
const int *expert_idx_per_token, const int *expert_idx_per_token,
const float *w4a8_in_scale, const float *w4a8_in_scale,
int* expanded_source_row_to_expanded_dest_row, int* expanded_source_row_to_expanded_dest_row,
float* dequant_scale,
const int64_t num_rows, const int64_t num_rows,
const int64_t active_rows, const int64_t active_rows,
const int64_t cols, const int64_t cols,
const int64_t k, const int64_t k,
cudaStream_t stream) { cudaStream_t stream) {
constexpr int threads = 256; const int threads = std::min(cols, int64_t(1024));
constexpr int max_pack_size = 16 / sizeof(T); constexpr int max_pack_size = 16 / sizeof(T);
const auto config_initialize = Get1DBlocksAnd2DGridsMoe(num_rows * k); const auto config_initialize = Get1DBlocksAnd2DGridsMoe(num_rows * k);
const int smem_size = cols * sizeof(float); if (cols % max_pack_size == 0) {
auto kernel = &initialize_moe_routing_kernel<T, max_pack_size, threads, OutT>; initialize_moe_routing_kernel<T, max_pack_size>
if (cols % max_pack_size != 0) { <<<config_initialize.block_per_grid, threads, 0, stream>>>(
kernel = &initialize_moe_routing_kernel<T, 1, threads, OutT>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
kernel<<<config_initialize.block_per_grid, threads, smem_size, stream>>>(
unpermuted_input, unpermuted_input,
permuted_output, permuted_output,
expanded_dest_row_to_expanded_source_row, expanded_dest_row_to_expanded_source_row,
expert_idx_per_token, expert_idx_per_token,
w4a8_in_scale, w4a8_in_scale,
expanded_source_row_to_expanded_dest_row, expanded_source_row_to_expanded_dest_row,
dequant_scale,
num_rows, num_rows,
k * active_rows, k * active_rows,
cols, cols,
num_rows * k); num_rows * k);
} else {
initialize_moe_routing_kernel<T, 1>
<<<config_initialize.block_per_grid, threads, 0, stream>>>(
unpermuted_input,
permuted_output,
expanded_dest_row_to_expanded_source_row,
expert_idx_per_token,
w4a8_in_scale,
expanded_source_row_to_expanded_dest_row,
num_rows,
k * active_rows,
cols,
num_rows * k);
}
} }
// ============================== Infer GEMM sizes // ============================== Infer GEMM sizes
@@ -1435,8 +1403,8 @@ __global__ void finalize_moe_routing_kernel(
auto const offset = original_row * cols; auto const offset = original_row * cols;
T* reduced_row_ptr = reduced_unpermuted_output + offset; T* reduced_row_ptr = reduced_unpermuted_output + offset;
constexpr int64_t FINALIZE_ELEM_PER_THREAD = constexpr int64_t FINALIZE_ELEM_PER_THREAD
128 / cutlass::sizeof_bits<T>::value; = 128 / cutlass::sizeof_bits<T>::value;
int64_t const start_offset = threadIdx.x; int64_t const start_offset = threadIdx.x;
int64_t const stride = FINALIZE_THREADS_PER_BLOCK; int64_t const stride = FINALIZE_THREADS_PER_BLOCK;
int64_t const num_elems_in_col = cols / FINALIZE_ELEM_PER_THREAD; int64_t const num_elems_in_col = cols / FINALIZE_ELEM_PER_THREAD;
@@ -1448,48 +1416,51 @@ __global__ void finalize_moe_routing_kernel(
using SharedOutputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>; using SharedOutputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
auto const* bias_v = reinterpret_cast<BiasElem const*>(bias); auto const* bias_v = reinterpret_cast<BiasElem const*>(bias);
auto const* expanded_permuted_rows_v = auto const* expanded_permuted_rows_v = reinterpret_cast<InputElem const*>(expanded_permuted_rows);
reinterpret_cast<InputElem const*>(expanded_permuted_rows);
auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr); auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);
#pragma unroll #pragma unroll
for (int elem_index = start_offset; elem_index < num_elems_in_col; for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride)
elem_index += stride) { {
ComputeElem thread_output; ComputeElem thread_output;
thread_output.fill(0); thread_output.fill(0);
float row_rescale{0.f}; float row_rescale{0.f};
for (int k_idx = 0; k_idx < k; ++k_idx) { for (int k_idx = 0; k_idx < k; ++k_idx)
{
int64_t const expanded_original_row = original_row + k_idx * num_rows; int64_t const expanded_original_row = original_row + k_idx * num_rows;
int64_t const expanded_permuted_row = int64_t const expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row];
expanded_source_row_to_expanded_dest_row[expanded_original_row];
int64_t const k_offset = original_row * k + k_idx; int64_t const k_offset = original_row * k + k_idx;
const float row_scale = scales[k_offset]; const float row_scale = scales[k_offset];
row_rescale = row_rescale + row_scale; row_rescale = row_rescale + row_scale;
auto const* expanded_permuted_rows_row_ptr = auto const* expanded_permuted_rows_row_ptr
expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; = expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;
int const expert_idx = expert_for_source_row[k_offset]; int const expert_idx = expert_for_source_row[k_offset];
auto const* bias_ptr = bias_v + expert_idx * num_elems_in_col; auto const* bias_ptr = bias_v + expert_idx * num_elems_in_col;
ComputeElem bias_value; ComputeElem bias_value;
if (bias) { if (bias)
{
bias_value = arrayConvert<BiasElem, ComputeElem>(bias_ptr[elem_index]); bias_value = arrayConvert<BiasElem, ComputeElem>(bias_ptr[elem_index]);
} else { }
else
{
bias_value.fill(0); bias_value.fill(0);
} }
ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>( ComputeElem expert_result
expanded_permuted_rows_row_ptr[elem_index]); = arrayConvert<InputElem, ComputeElem>(expanded_permuted_rows_row_ptr[elem_index]);
thread_output = thread_output + row_scale * (expert_result + bias_value); thread_output = thread_output + row_scale * (expert_result + bias_value);
} }
for (auto& elem : thread_output) { for (auto& elem : thread_output)
elem = {
elem / (norm_topk_prob ? row_rescale : 1.0f) * routed_scaling_factor; elem = elem / (norm_topk_prob ? row_rescale : 1.0f) * routed_scaling_factor;
} }
OutputElem output_elem = OutputElem output_elem = arrayConvert<ComputeElem, OutputElem>(thread_output);
arrayConvert<ComputeElem, OutputElem>(thread_output);
reduced_row_ptr_v[elem_index] = output_elem; reduced_row_ptr_v[elem_index] = output_elem;
} }
} }
@@ -1513,7 +1484,8 @@ void finalize_moe_routing_kernelLauncher(
const int threads = FINALIZE_THREADS_PER_BLOCK; const int threads = FINALIZE_THREADS_PER_BLOCK;
finalize_moe_routing_kernel<T, 1> finalize_moe_routing_kernel<T, 1>
<<<blocks, threads, 0, stream>>>(expanded_permuted_rows, <<<blocks, threads, 0, stream>>>(
expanded_permuted_rows,
reduced_unpermuted_output, reduced_unpermuted_output,
bias, bias,
scales, scales,

View File

@@ -26,23 +26,14 @@
template <paddle::DataType T> template <paddle::DataType T>
void MoeDispatchKernel( void MoeDispatchKernel(
const paddle::Tensor &input, const paddle::Tensor &input, const paddle::Tensor &gating_output,
const paddle::Tensor &gating_output,
const paddle::optional<paddle::Tensor> &gating_correction_bias, const paddle::optional<paddle::Tensor> &gating_correction_bias,
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
const int moe_topk, const bool group_moe, const bool topk_only_mode, const int num_rows,
const bool group_moe, const int hidden_size, const int expert_num, paddle::Tensor *permute_input,
const bool topk_only_mode,
const int num_rows,
const int hidden_size,
const int expert_num,
paddle::Tensor *permute_input,
paddle::Tensor *tokens_expert_prefix_sum, paddle::Tensor *tokens_expert_prefix_sum,
paddle::Tensor *permute_indices_per_token, paddle::Tensor *permute_indices_per_token, paddle::Tensor *topk_weight,
paddle::Tensor *topk_weight, paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token) {
paddle::Tensor *topk_idx,
paddle::Tensor *expert_idx_per_token,
paddle::Tensor *dequant_scale) {
using namespace phi; using namespace phi;
if (num_rows == 0){ if (num_rows == 0){
@@ -57,14 +48,12 @@ void MoeDispatchKernel(
if (group_moe) { if (group_moe) {
// Check if expert_num is divisible by moe_topk, else throw an error // Check if expert_num is divisible by moe_topk, else throw an error
PADDLE_ENFORCE_EQ(expert_num % moe_topk, PADDLE_ENFORCE_EQ(expert_num % moe_topk, 0,
0,
common::errors::InvalidArgument( common::errors::InvalidArgument(
"The number of experts (expert_num) " "The number of experts (expert_num) "
"must be divisible by moe_topk. " "must be divisible by moe_topk. "
"Got expert_num = %d and moe_topk = %d.", "Got expert_num = %d and moe_topk = %d.",
expert_num, expert_num, moe_topk));
moe_topk));
} }
const int num_moe_inputs = AlignTo16(num_rows * moe_topk); const int num_moe_inputs = AlignTo16(num_rows * moe_topk);
@@ -79,8 +68,7 @@ void MoeDispatchKernel(
paddle::Tensor ws_ptr_tensor = paddle::Tensor ws_ptr_tensor =
GetEmptyTensor({bytes + sorter_ws_size_bytes + sort_tmp_in_out_size}, GetEmptyTensor({bytes + sorter_ws_size_bytes + sort_tmp_in_out_size},
paddle::DataType::INT8, paddle::DataType::INT8, place);
place);
int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>(); int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();
int *source_rows_ = reinterpret_cast<int *>(ws_ptr); int *source_rows_ = reinterpret_cast<int *>(ws_ptr);
@@ -108,8 +96,8 @@ void MoeDispatchKernel(
paddle::Tensor softmax_buffer; paddle::Tensor softmax_buffer;
if (!is_pow_2 || expert_num > 256 || group_moe || gating_correction_bias) { if (!is_pow_2 || expert_num > 256 || group_moe || gating_correction_bias) {
softmax_buffer = GetEmptyTensor( softmax_buffer = GetEmptyTensor({num_rows * expert_num},
{num_rows * expert_num}, paddle::DataType::FLOAT32, place); paddle::DataType::FLOAT32, place);
softmax_out_ = softmax_buffer.data<float>(); softmax_out_ = softmax_buffer.data<float>();
} else { } else {
softmax_out_ = nullptr; softmax_out_ = nullptr;
@@ -119,106 +107,47 @@ void MoeDispatchKernel(
gating_output.data<float>(), gating_output.data<float>(),
gating_correction_bias ? gating_correction_bias.get().data<float>() gating_correction_bias ? gating_correction_bias.get().data<float>()
: nullptr, : nullptr,
topk_weight->data<float>(), topk_weight->data<float>(), softmax_out_, topk_idx_ptr, source_rows_,
softmax_out_, softmax_max_prob, num_rows, expert_num, moe_topk, group_moe, stream,
topk_idx_ptr,
source_rows_,
softmax_max_prob,
num_rows,
expert_num,
moe_topk,
group_moe,
stream,
topk_only_mode); topk_only_mode);
sorter_.run(reinterpret_cast<void *>(sorter_ws_ptr), sorter_.run(reinterpret_cast<void *>(sorter_ws_ptr), sorter_ws_size_bytes,
sorter_ws_size_bytes, topk_idx_ptr, expert_idx_per_token->data<int32_t>(), source_rows_,
topk_idx_ptr, permuted_rows_, moe_topk * num_rows, false, stream);
expert_idx_per_token->data<int32_t>(),
source_rows_,
permuted_rows_,
moe_topk * num_rows,
false,
stream);
if (w4a8_in_scale) { if (w4a8_in_scale) {
if (permute_input->dtype() == paddle::DataType::INT8) { if (permute_input->dtype() == paddle::DataType::INT8) {
initialize_moe_routing_kernelLauncher( initialize_moe_routing_kernelLauncher(
input.data<data_t>(), input.data<data_t>(), permute_input->data<int8_t>(), permuted_rows_,
permute_input->data<int8_t>(), expert_idx_per_token->data<int32_t>(), w4a8_in_scale->data<float>(),
permuted_rows_, permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
expert_idx_per_token->data<int32_t>(), hidden_size, moe_topk, stream);
w4a8_in_scale->data<float>(),
permute_indices_per_token->data<int32_t>(),
nullptr,
num_rows,
num_rows,
hidden_size,
moe_topk,
stream);
} else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) { } else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
initialize_moe_routing_kernelLauncher( initialize_moe_routing_kernelLauncher(
input.data<data_t>(), input.data<data_t>(), permute_input->data<float8_e4m3fn>(),
permute_input->data<float8_e4m3fn>(), permuted_rows_, expert_idx_per_token->data<int32_t>(),
permuted_rows_,
expert_idx_per_token->data<int32_t>(),
w4a8_in_scale->data<float>(), w4a8_in_scale->data<float>(),
permute_indices_per_token->data<int32_t>(), permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
nullptr, hidden_size, moe_topk, stream);
num_rows,
num_rows,
hidden_size,
moe_topk,
stream);
} }
} else { } else {
if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
initialize_moe_routing_kernelLauncher( initialize_moe_routing_kernelLauncher(
input.data<data_t>(), input.data<data_t>(), permute_input->data<data_t>(), permuted_rows_,
permute_input->data<float8_e4m3fn>(), expert_idx_per_token->data<int32_t>(), nullptr,
permuted_rows_, permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
expert_idx_per_token->data<int32_t>(), hidden_size, moe_topk, stream);
nullptr,
permute_indices_per_token->data<int32_t>(),
dequant_scale->data<float>(),
num_rows,
num_rows,
hidden_size,
moe_topk,
stream);
} else {
initialize_moe_routing_kernelLauncher(
input.data<data_t>(),
permute_input->data<data_t>(),
permuted_rows_,
expert_idx_per_token->data<int32_t>(),
nullptr,
permute_indices_per_token->data<int32_t>(),
nullptr,
num_rows,
num_rows,
hidden_size,
moe_topk,
stream);
}
} }
compute_total_rows_before_expert(expert_idx_per_token->data<int32_t>(), compute_total_rows_before_expert(
moe_topk * num_rows, expert_idx_per_token->data<int32_t>(), moe_topk * num_rows, expert_num,
expert_num, tokens_expert_prefix_sum->data<int64_t>(), stream);
tokens_expert_prefix_sum->data<int64_t>(),
stream);
} }
std::vector<paddle::Tensor> MoeExpertDispatch( std::vector<paddle::Tensor> MoeExpertDispatch(
const paddle::Tensor &input, const paddle::Tensor &input, const paddle::Tensor &gating_output,
const paddle::Tensor &gating_output,
const paddle::optional<paddle::Tensor> &gating_correction_bias, const paddle::optional<paddle::Tensor> &gating_correction_bias,
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
const int moe_topk, const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode) {
const bool group_moe,
const std::string &moe_quant_type,
const bool topk_only_mode) {
const auto input_type = input.dtype(); const auto input_type = input.dtype();
auto place = input.place(); auto place = input.place();
int token_rows = 0; int token_rows = 0;
@@ -241,21 +170,10 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
} else if (moe_quant_type == "w4afp8") { } else if (moe_quant_type == "w4afp8") {
permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN; permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN;
} }
} else {
if (moe_quant_type == "w4afp8") {
permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN;
}
} }
auto permute_input = GetEmptyTensor( auto permute_input = GetEmptyTensor({moe_topk * num_rows, hidden_size},
{moe_topk * num_rows, hidden_size}, permute_input_dtype, place); permute_input_dtype, place);
int dequant_scale_size = 1;
if (moe_quant_type == "w4afp8" && !w4a8_in_scale) {
dequant_scale_size = moe_topk * num_rows;
}
auto dequant_scale =
GetEmptyTensor({dequant_scale_size}, paddle::DataType::FLOAT32, place);
// correspond to the weighted coefficients of the results from each expert. // correspond to the weighted coefficients of the results from each expert.
auto topk_weight = auto topk_weight =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place); GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
@@ -276,48 +194,23 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
permute_indices_per_token, permute_indices_per_token,
topk_weight, topk_weight,
topk_idx, topk_idx,
expert_idx_per_token, expert_idx_per_token};
dequant_scale};
} }
switch (input_type) { switch (input_type) {
case paddle::DataType::BFLOAT16: case paddle::DataType::BFLOAT16:
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input, MoeDispatchKernel<paddle::DataType::BFLOAT16>(
gating_output, input, gating_output, gating_correction_bias, w4a8_in_scale, moe_topk,
gating_correction_bias, group_moe, topk_only_mode, num_rows, hidden_size, expert_num,
w4a8_in_scale, &permute_input, &tokens_expert_prefix_sum, &permute_indices_per_token,
moe_topk, &topk_weight, &topk_idx, &expert_idx_per_token);
group_moe,
topk_only_mode,
num_rows,
hidden_size,
expert_num,
&permute_input,
&tokens_expert_prefix_sum,
&permute_indices_per_token,
&topk_weight,
&topk_idx,
&expert_idx_per_token,
&dequant_scale);
break; break;
case paddle::DataType::FLOAT16: case paddle::DataType::FLOAT16:
MoeDispatchKernel<paddle::DataType::FLOAT16>(input, MoeDispatchKernel<paddle::DataType::FLOAT16>(
gating_output, input, gating_output, gating_correction_bias, w4a8_in_scale, moe_topk,
gating_correction_bias, group_moe, topk_only_mode, num_rows, hidden_size, expert_num,
w4a8_in_scale, &permute_input, &tokens_expert_prefix_sum, &permute_indices_per_token,
moe_topk, &topk_weight, &topk_idx, &expert_idx_per_token);
group_moe,
topk_only_mode,
num_rows,
hidden_size,
expert_num,
&permute_input,
&tokens_expert_prefix_sum,
&permute_indices_per_token,
&topk_weight,
&topk_idx,
&expert_idx_per_token,
&dequant_scale);
break; break;
default: default:
PD_THROW("Unsupported data type for MoeDispatchKernel"); PD_THROW("Unsupported data type for MoeDispatchKernel");
@@ -327,8 +220,7 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
permute_indices_per_token, permute_indices_per_token,
topk_weight, topk_weight,
topk_idx, topk_idx,
expert_idx_per_token, expert_idx_per_token};
dequant_scale};
} }
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape( std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
@@ -353,22 +245,16 @@ std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
{moe_topk, num_rows}, {moe_topk, num_rows},
{num_rows, moe_topk}, {num_rows, moe_topk},
{num_rows, moe_topk}, {num_rows, moe_topk},
{permuted_rows}, {permuted_rows}};
{num_rows}};
} }
std::vector<paddle::DataType> MoeExpertDispatchInferDtype( std::vector<paddle::DataType>
const paddle::DataType &input_dtype, MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
const paddle::DataType &gating_output_dtype, const paddle::DataType &gating_output_dtype,
const paddle::optional<paddle::DataType> &bias_type, const paddle::optional<paddle::DataType> &bias_type,
const int moe_topk) { const int moe_topk) {
return {input_dtype, return {input_dtype, paddle::DataType::INT64, paddle::DataType::INT32,
paddle::DataType::INT64, paddle::DataType::FLOAT32, paddle::DataType::INT32, paddle::DataType::INT32};
paddle::DataType::INT32,
paddle::DataType::FLOAT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::FLOAT32};
} }
/** /**
@@ -376,8 +262,7 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
* *
* This operator performs the following key functions: * This operator performs the following key functions:
* 1. Computes top-k experts for each input token based on gating scores * 1. Computes top-k experts for each input token based on gating scores
* 2. Permutes input tokens according to their selected experts for efficient * 2. Permutes input tokens according to their selected experts for efficient expert processing
* expert processing
* 3. Computes prefix sums of tokens per expert for group_gemm optimization * 3. Computes prefix sums of tokens per expert for group_gemm optimization
* *
* Inputs: * Inputs:
@@ -387,17 +272,18 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
* - gating_output: Gating network output scores for each token-expert pair * - gating_output: Gating network output scores for each token-expert pair
* Shape: [total_tokens, expert_num] * Shape: [total_tokens, expert_num]
* dtype: must be float32 * dtype: must be float32
* - gating_correction_bias: Optional bias term for gating correction * - gating_correction_bias: Optional bias term for gating correction (expert_num)
* (expert_num)
* *
* Outputs: * Outputs:
* - permute_input: Permuted input tensor organized by expert * - permute_input: Permuted input tensor organized by expert
* Shape: [moe_topk * total_tokens, hidden_size] * Shape: [moe_topk * total_tokens, hidden_size]
* dtype: Same as input * dtype: Same as input
* - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for * - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm
* group_gemm Shape: [expert_num] dtype: int64 * Shape: [expert_num]
* - permute_indices_per_token: Indices mapping for reconstructing original * dtype: int64
* order Shape: [moe_topk, total_tokens] dtype: int32 * - permute_indices_per_token: Indices mapping for reconstructing original order
* Shape: [moe_topk, total_tokens]
* dtype: int32
* - top_k_weight: Weight coefficients for combining expert outputs * - top_k_weight: Weight coefficients for combining expert outputs
* Shape: [total_tokens, moe_topk] * Shape: [total_tokens, moe_topk]
* dtype: float32 * dtype: float32
@@ -406,8 +292,7 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
* dtype: int32 * dtype: int32
* *
* Attributes: * Attributes:
* - moe_topk: Number of experts to select for each token (k value in top-k * - moe_topk: Number of experts to select for each token (k value in top-k routing)
* routing)
* - group_moe: Whether to perform group softmax within the operator * - group_moe: Whether to perform group softmax within the operator
* (true: softmax is computed within groups of experts, * (true: softmax is computed within groups of experts,
* false: standard softmax across all experts) * false: standard softmax across all experts)
@@ -421,21 +306,13 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
* - When group_moe is true, expert_num must be divisible by moe_topk * - When group_moe is true, expert_num must be divisible by moe_topk
*/ */
PD_BUILD_STATIC_OP(moe_expert_dispatch) PD_BUILD_STATIC_OP(moe_expert_dispatch)
.Inputs({"input", .Inputs({"input", "gating_output",
"gating_output",
paddle::Optional("gating_correction_bias"), paddle::Optional("gating_correction_bias"),
paddle::Optional("w4a8_in_scale")}) paddle::Optional("w4a8_in_scale")})
.Outputs({"permute_input", .Outputs({"permute_input", "tokens_expert_prefix_sum",
"tokens_expert_prefix_sum", "permute_indices_per_token", "topk_weight", "topk_idx",
"permute_indices_per_token", "expert_idx_per_token"})
"topk_weight", .Attrs({"moe_topk:int", "group_moe:bool", "moe_quant_type:std::string", "topk_only_mode:bool"})
"topk_idx",
"expert_idx_per_token",
"dequant_scale"})
.Attrs({"moe_topk:int",
"group_moe:bool",
"moe_quant_type:std::string",
"topk_only_mode:bool"})
.SetKernelFn(PD_KERNEL(MoeExpertDispatch)) .SetKernelFn(PD_KERNEL(MoeExpertDispatch))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape)) .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype)); .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));

View File

@@ -17,7 +17,6 @@
#include "cutlass/numeric_conversion.h" #include "cutlass/numeric_conversion.h"
#include "group_swiglu_with_masked.h" #include "group_swiglu_with_masked.h"
#include "helper.h" #include "helper.h"
#include "moe/fast_hardmard/fast_hardamard_kernel.h"
#include "moe/fused_moe_helper.h" #include "moe/fused_moe_helper.h"
template <typename DataT, template <typename DataT,

View File

@@ -18,8 +18,8 @@
#include "cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h" #include "cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h"
#include "group_swiglu_with_masked.h" #include "group_swiglu_with_masked.h"
#include "helper.h" #include "helper.h"
#include "moe/fast_hardmard/fast_hardamard_kernel.h"
#include "moe/fused_moe_helper.h" #include "moe/fused_moe_helper.h"
#include "moe/moe_fast_hardamard_kernel.h"
#include "swigluoai.h" #include "swigluoai.h"
#include "w4afp8_gemm/w4afp8_gemm.h" #include "w4afp8_gemm/w4afp8_gemm.h"
@@ -28,7 +28,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum, const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight, const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_proj_in_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias, const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale, const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale, const paddle::optional<paddle::Tensor>& down_proj_scale,
@@ -198,20 +197,31 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8; typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
typedef typename traits_fp8::DataType DataType_fp8; typedef typename traits_fp8::DataType DataType_fp8;
typedef typename traits_fp8::data_t data_t_fp8; typedef typename traits_fp8::data_t data_t_fp8;
paddle::Tensor weight_scale_tensor =
*const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr()); Allocator::AllocationPtr ffn1_input_row_sum;
const int weight_scale_group_size = weight_scale_tensor.dims().size() == 2 ffn1_input_row_sum =
? hidden_size allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
: weight_scale_tensor.dims()[3];
const float* input_dequant_scale = compute_row_sum(
up_proj_in_scale ? up_proj_in_scale.get().data<float>() : nullptr; permute_input.data<data_t_fp8>(),
expanded_active_expert_rows,
hidden_size,
reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
num_max_tokens_per_expert,
used_in_ep_low_latency,
stream);
float* row_scale = nullptr;
DisPatchW4AFp8GemmWrapper( DisPatchW4AFp8GemmWrapper(
reinterpret_cast<const DataType_fp8*>(permute_input.data<data_t_fp8>()), reinterpret_cast<const DataType_fp8*>(permute_input.data<data_t_fp8>()),
reinterpret_cast<const DataType_fp8*>( reinterpret_cast<const DataType_fp8*>(
up_gate_proj_weight.data<int8_t>()), up_gate_proj_weight.data<int8_t>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()), const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
input_dequant_scale, reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
weight_scale_tensor.data<float>(), row_scale,
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
->data<float>(),
reinterpret_cast<NvType*>(fc1_out), reinterpret_cast<NvType*>(fc1_out),
used_in_ep_low_latency ? num_max_tokens_per_expert : 0, used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
used_in_ep_low_latency ? num_max_tokens_per_expert used_in_ep_low_latency ? num_max_tokens_per_expert
@@ -219,7 +229,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
num_experts, num_experts,
inter_size, inter_size,
hidden_size, hidden_size,
weight_scale_group_size,
stream); stream);
} else { } else {
typename cutlass::WintQuantTraits< typename cutlass::WintQuantTraits<
@@ -346,39 +355,15 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
} else if (quant_method == "w4afp8") { } else if (quant_method == "w4afp8") {
data_t* ffn2_shift = nullptr; data_t* ffn2_shift = nullptr;
data_t* ffn2_smooth = nullptr; data_t* ffn2_smooth = nullptr;
float* input_dequant_scale = nullptr; float* row_scale = nullptr;
Allocator::AllocationPtr fp8_act_out; Allocator::AllocationPtr fp8_act_out;
fp8_act_out = allocator->Allocate(SizeOf(paddle::DataType::INT8) * fp8_act_out = allocator->Allocate(SizeOf(paddle::DataType::INT8) *
act_out_tensor.numel()); act_out_tensor.numel());
Allocator::AllocationPtr ffn2_input_row_sum;
if (down_proj_in_scale) { ffn2_input_row_sum =
MoeFastHardamardWrapper<data_t, data_t_fp8>(
act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
: nullptr,
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
ffn2_shift,
ffn2_smooth,
down_proj_in_scale
? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())
->data<float>()
: nullptr,
1,
448.0f,
-448.0f,
expanded_active_expert_rows,
inter_size / 2,
num_max_tokens_per_expert,
used_in_ep_low_latency,
hadamard_block_size,
reinterpret_cast<data_t_fp8*>(fp8_act_out->ptr()),
stream);
} else {
Allocator::AllocationPtr ffn2_input_dequant_scale;
ffn2_input_dequant_scale =
allocator->Allocate(sizeof(float) * expanded_active_expert_rows); allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
input_dequant_scale =
reinterpret_cast<float*>(ffn2_input_dequant_scale->ptr()); // note(yuanxiaolan): optimize this
MoeFastHardamardWrapper<data_t, data_t>( MoeFastHardamardWrapper<data_t, data_t>(
act_out_tensor.data<data_t>(), act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
@@ -402,28 +387,28 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
act_out_tensor.data<data_t>(), act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
: nullptr, : nullptr,
down_proj_in_scale
? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())
->data<float>()
: nullptr,
448.0f,
-448.0f,
expanded_active_expert_rows, expanded_active_expert_rows,
inter_size / 2, inter_size / 2,
input_dequant_scale, reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()), const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
num_max_tokens_per_expert, num_max_tokens_per_expert,
used_in_ep_low_latency, used_in_ep_low_latency,
reinterpret_cast<data_t_fp8*>(fp8_act_out->ptr()), reinterpret_cast<data_t_fp8*>(fp8_act_out->ptr()),
stream); stream);
}
paddle::Tensor weight_scale_tensor =
*const_cast<paddle::Tensor*>(down_proj_scale.get_ptr());
const int weight_scale_group_size = weight_scale_tensor.dims().size() == 2
? inter_size / 2
: weight_scale_tensor.dims()[3];
DisPatchW4AFp8GemmWrapper( DisPatchW4AFp8GemmWrapper(
reinterpret_cast<const DataType_fp8*>(fp8_act_out->ptr()), reinterpret_cast<const DataType_fp8*>(fp8_act_out->ptr()),
reinterpret_cast<const DataType_fp8*>(down_proj_weight.data<int8_t>()), reinterpret_cast<const DataType_fp8*>(down_proj_weight.data<int8_t>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()), const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
input_dequant_scale, reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
weight_scale_tensor.data<float>(), row_scale,
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())->data<float>(),
reinterpret_cast<NvType*>(ffn_out_data), reinterpret_cast<NvType*>(ffn_out_data),
used_in_ep_low_latency ? num_max_tokens_per_expert : 0, used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
used_in_ep_low_latency ? num_max_tokens_per_expert used_in_ep_low_latency ? num_max_tokens_per_expert
@@ -431,7 +416,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
num_experts, num_experts,
hidden_size, hidden_size,
inter_size / 2, inter_size / 2,
weight_scale_group_size,
stream); stream);
} else { } else {
typename cutlass::WintQuantTraits< typename cutlass::WintQuantTraits<
@@ -458,7 +442,6 @@ paddle::Tensor MoeExpertFFNFunc(
const paddle::Tensor& tokens_expert_prefix_sum, const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight, const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_proj_in_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias, const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale, const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale, const paddle::optional<paddle::Tensor>& down_proj_scale,
@@ -483,7 +466,6 @@ paddle::Tensor MoeExpertFFNFunc(
tokens_expert_prefix_sum, tokens_expert_prefix_sum,
up_gate_proj_weight, up_gate_proj_weight,
down_proj_weight, down_proj_weight,
up_proj_in_scale,
up_gate_proj_bias, up_gate_proj_bias,
up_gate_proj_scale, up_gate_proj_scale,
down_proj_scale, down_proj_scale,
@@ -501,7 +483,6 @@ paddle::Tensor MoeExpertFFNFunc(
tokens_expert_prefix_sum, tokens_expert_prefix_sum,
up_gate_proj_weight, up_gate_proj_weight,
down_proj_weight, down_proj_weight,
up_proj_in_scale,
up_gate_proj_bias, up_gate_proj_bias,
up_gate_proj_scale, up_gate_proj_scale,
down_proj_scale, down_proj_scale,
@@ -525,7 +506,6 @@ std::vector<paddle::Tensor> MoeExpertFFN(
const paddle::Tensor& tokens_expert_prefix_sum, const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight, const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_proj_in_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias, const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale, const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale, const paddle::optional<paddle::Tensor>& down_proj_scale,
@@ -540,7 +520,6 @@ std::vector<paddle::Tensor> MoeExpertFFN(
tokens_expert_prefix_sum, tokens_expert_prefix_sum,
up_gate_proj_weight, up_gate_proj_weight,
down_proj_weight, down_proj_weight,
up_proj_in_scale,
up_gate_proj_bias, up_gate_proj_bias,
up_gate_proj_scale, up_gate_proj_scale,
down_proj_scale, down_proj_scale,
@@ -558,7 +537,6 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
const std::vector<int64_t>& tokens_expert_prefix_sum_shape, const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
const std::vector<int64_t>& up_gate_proj_weight_shape, const std::vector<int64_t>& up_gate_proj_weight_shape,
const std::vector<int64_t>& down_proj_weight_shape, const std::vector<int64_t>& down_proj_weight_shape,
const paddle::optional<std::vector<int64_t>>& up_proj_in_scale_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape, const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape, const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape, const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
@@ -577,7 +555,6 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
const paddle::DataType& tokens_expert_prefix_sum_dtype, const paddle::DataType& tokens_expert_prefix_sum_dtype,
const paddle::DataType& up_gate_proj_weight_dtype, const paddle::DataType& up_gate_proj_weight_dtype,
const paddle::DataType& down_proj_weight_dtype, const paddle::DataType& down_proj_weight_dtype,
const paddle::optional<paddle::DataType>& up_proj_in_scale_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype, const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype, const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
const paddle::optional<paddle::DataType>& down_proj_scale_dtype, const paddle::optional<paddle::DataType>& down_proj_scale_dtype,
@@ -655,7 +632,6 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
"tokens_expert_prefix_sum", "tokens_expert_prefix_sum",
"up_gate_proj_weight", "up_gate_proj_weight",
"down_proj_weight", "down_proj_weight",
paddle::Optional("up_proj_in_scale"),
paddle::Optional("up_gate_proj_bias"), paddle::Optional("up_gate_proj_bias"),
paddle::Optional("up_gate_proj_scale"), paddle::Optional("up_gate_proj_scale"),
paddle::Optional("down_proj_scale"), paddle::Optional("down_proj_scale"),

View File

@@ -23,19 +23,13 @@
using namespace cute; using namespace cute;
template <int kStages, template <int kStages, class GemmType, class OutputType, class SmemLayoutA,
class GemmType, class SmemLayoutB, class SmemLayoutC>
class OutputType,
class SmemLayoutA,
class SmemLayoutB,
class SmemLayoutC,
class SmemLayoutScale>
struct SharedStorage { struct SharedStorage {
union { union {
struct { struct {
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutA>> smem_a; cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutA>> smem_a;
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutB>> smem_b; cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutB>> smem_b;
cute::array_aligned<float, cute::cosize_v<SmemLayoutScale>> smem_scale;
}; };
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c; cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c;
}; };
@@ -45,24 +39,18 @@ struct SharedStorage {
}; };
}; };
template <int kBlockM_, template<int kBlockM_, int kBlockN_, int kBlockK_,
int kBlockN_, int kNWarps_, int kStages_,
int kBlockK_, int kTiles_, int M_,
int kNWarps_,
int kStages_,
int kTiles_,
int M_,
int K_,
int TokenPackSize_, int TokenPackSize_,
int WeightScaleGroup_, int TAIL_N_ = 0,
int kClusterM_ = 1, int kClusterM_ = 1,
typename elem_type=cutlass::float_e4m3_t, typename elem_type=cutlass::float_e4m3_t,
typename OutputType = cutlass::bfloat16_t> typename OutputType = cutlass::bfloat16_t>
struct Kernel_traits { struct Kernel_traits {
using Element = elem_type; using Element = elem_type;
using ElementAccum = float;
using ElementOutput = OutputType; using ElementOutput = OutputType;
using ElementAccum = typename std::
conditional_t<WeightScaleGroup_ == K_, float, cutlass::half_t>;
static_assert(cutlass::sizeof_bits_v<Element> == 8); static_assert(cutlass::sizeof_bits_v<Element> == 8);
static constexpr int kNWarps = kNWarps_; static constexpr int kNWarps = kNWarps_;
@@ -78,10 +66,10 @@ struct Kernel_traits {
static constexpr int kTiles = kTiles_; static constexpr int kTiles = kTiles_;
static constexpr int TokenPackSize = TokenPackSize_; static constexpr int TokenPackSize = TokenPackSize_;
static constexpr int M = M_; static constexpr int M = M_;
static constexpr int K = K_; static constexpr int TAIL_N = TAIL_N_;
static constexpr int WeightScaleGroup = WeightScaleGroup_;
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>; using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>;
using TileShape_MNK_TAIL = Shape<Int<kBlockM>, Int<TAIL_N>, Int<kBlockK>>;
static constexpr int kClusterM = kClusterM_; static constexpr int kClusterM = kClusterM_;
using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>; using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
@@ -92,72 +80,74 @@ struct Kernel_traits {
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>; using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
using TiledMma = decltype(cute::make_tiled_mma( using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA:: cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
AtomLayoutMNK{})); AtomLayoutMNK{}));
using SmemLayoutAtomA = using TiledMma_TAIL = decltype(cute::make_tiled_mma(
decltype(cutlass::gemm::collective::detail::rs_smem_selector< cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK_TAIL>(),
GMMA::Major::K, AtomLayoutMNK{}));
Element,
Int<kBlockM>,
Int<kBlockK / 2>>());
using SmemLayoutA = decltype(tile_to_shape( using SmemLayoutAtomA = decltype(
SmemLayoutAtomA{}, cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K, Element, Int<kBlockM>, Int<kBlockK / 2>>());
using SmemLayoutA = decltype(
tile_to_shape(SmemLayoutAtomA{},
make_shape(Int<kBlockM>{}, Int<kBlockK / 2>{}, Int<kStages>{}))); make_shape(Int<kBlockM>{}, Int<kBlockK / 2>{}, Int<kStages>{})));
using SmemLayoutAtomB = using SmemLayoutAtomB = decltype(
decltype(cutlass::gemm::collective::detail::rs_smem_selector< cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K, GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})),
Element,
decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>()); decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutB = using SmemLayoutB = decltype(
decltype(tile_to_shape(SmemLayoutAtomB{}, tile_to_shape(SmemLayoutAtomB{},
make_shape(shape<1>(TileShape_MNK{}), make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
shape<2>(TileShape_MNK{}),
Int<kStages>{}))); using SmemLayoutAtomB_TAIL = decltype(
using SmemLayoutAtomC = cutlass::gemm::collective::detail::rs_smem_selector<
decltype(cutlass::gemm::collective::detail::rs_smem_selector< GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})),
GMMA::Major::K, decltype(cute::get<2>(TileShape_MNK_TAIL{}))>());
ElementOutput,
using SmemLayoutB_TAIL = decltype(
tile_to_shape(SmemLayoutAtomB_TAIL{},
make_shape(
shape<1>(TileShape_MNK_TAIL{}),
shape<2>(TileShape_MNK_TAIL{}),
Int<kStages>{})
));
using SmemLayoutAtomC = decltype(
cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K, ElementOutput,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<1>(TileShape_MNK{}))>()); decltype(cute::get<1>(TileShape_MNK{}))>());
using SmemLayoutC = using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));
decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));
using SmemCopyAtomAB = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>; using SmemCopyAtomAB = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
using SmemCopyAtomC = Copy_Atom<cute::SM90_U32x4_STSM_N, ElementOutput>; using SmemCopyAtomC = Copy_Atom<cute::SM90_U32x4_STSM_N, ElementOutput>;
using SmemLayoutScale = Layout<Shape<Int<kBlockM>, Int<kStages>>>; using SharedStorage = SharedStorage<
kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutB, SmemLayoutC>;
using SharedStorage = SharedStorage<kStages,
Element,
ElementOutput,
SmemLayoutA,
SmemLayoutB,
SmemLayoutC,
SmemLayoutScale>;
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>; using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using PipelineState = typename cutlass::PipelineState<kStages>; using PipelineState = typename cutlass::PipelineState<kStages>;
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>); static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>);
static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem; static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;
// static_assert(NumMmaThreads % kNumThreadsPerRow == 0); // static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow; static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
using TiledCopyCAtom = using TiledCopyCAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, OutputType>;
cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, OutputType>;
using TiledCopyCThrLayout = decltype(cute::make_layout( using TiledCopyCThrLayout = decltype(cute::make_layout(
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}), cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
LayoutRight{})); LayoutRight{}));
using TiledCopyCValLayout = decltype(cute::make_layout( using TiledCopyCValLayout = decltype(cute::make_layout(
cute::make_shape(_1{}, Int<kNumVecElem>{}), LayoutRight{})); cute::make_shape(_1{}, Int<kNumVecElem>{}),
using TiledCopyC = LayoutRight{}));
decltype(make_tiled_copy(TiledCopyCAtom{}, using TiledCopyC = decltype(make_tiled_copy(
TiledCopyCAtom{},
TiledCopyCThrLayout{}, // Thr layout TiledCopyCThrLayout{}, // Thr layout
TiledCopyCValLayout{} // Val layout TiledCopyCValLayout{} // Val layout
)); ));

View File

@@ -14,10 +14,10 @@
#pragma once #pragma once
#include <cutlass/array.h>
#include <cutlass/cutlass.h> #include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h> #include <cutlass/array.h>
#include <cutlass/numeric_types.h> #include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp" #include "cutlass/pipeline/pipeline.hpp"
#include "cute/tensor.hpp" #include "cute/tensor.hpp"
@@ -27,68 +27,62 @@
// #include "named_barrier.hpp" // #include "named_barrier.hpp"
#include "utils.hpp" #include "utils.hpp"
using namespace cute; using namespace cute;
template <typename Ktraits> template <typename Ktraits>
struct CollectiveMainloopFwd { struct CollectiveMainloopFwd {
using Element = typename Ktraits::Element; using Element = typename Ktraits::Element;
using ElementOutput = typename Ktraits::ElementOutput; using ElementOutput = typename Ktraits::ElementOutput;
using TileShape_MNK = typename Ktraits::TileShape_MNK; using TileShape_MNK = typename Ktraits::TileShape_MNK;
using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL;
using ClusterShape = typename Ktraits::ClusterShape_MNK; using ClusterShape = typename Ktraits::ClusterShape_MNK;
using ElementAccum = typename Ktraits::ElementAccum; using ElementAccum = typename Ktraits::ElementAccum;
static constexpr int kStages = Ktraits::kStages; static constexpr int kStages = Ktraits::kStages;
static constexpr int kBlockM = Ktraits::kBlockM; static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN; static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int TAIL_N = Ktraits::TAIL_N;
static constexpr int kBlockK = Ktraits::kBlockK; static constexpr int kBlockK = Ktraits::kBlockK;
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup; static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kTiles = Ktraits::kTiles; static constexpr int kTiles = Ktraits::kTiles;
static constexpr int M = Ktraits::M; static constexpr int M = Ktraits::M;
static constexpr int K = Ktraits::K;
static constexpr int TokenPackSize = Ktraits::TokenPackSize; static constexpr int TokenPackSize = Ktraits::TokenPackSize;
static constexpr int WeightScaleGroup = Ktraits::WeightScaleGroup;
using GmemTiledCopy = cute::SM90_TMA_LOAD; using GmemTiledCopy = cute::SM90_TMA_LOAD;
using SmemLayoutA = typename Ktraits::SmemLayoutA; using SmemLayoutA = typename Ktraits::SmemLayoutA;
using SmemLayoutB = typename Ktraits::SmemLayoutB; using SmemLayoutB = typename Ktraits::SmemLayoutB;
using SmemLayoutC = typename Ktraits::SmemLayoutC; using SmemLayoutC = typename Ktraits::SmemLayoutC;
using SmemLayoutScale = typename Ktraits::SmemLayoutScale; using SmemLayoutB_TAIL = typename Ktraits::SmemLayoutB_TAIL;
using ShapeT = cute::Shape<int64_t, int64_t, int64_t>; using ShapeT = cute::Shape<int64_t, int64_t, int64_t>;
using StrideT = cute::Shape<int64_t, _1, int64_t>; using StrideT = cute::Shape<int64_t, _1, int64_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>; using LayoutT = cute::Layout<ShapeT, StrideT>;
using ShapeTScale = cute::Shape<int64_t, int64_t, int64_t>;
using StrideTScale = cute::Shape<_1, int64_t, int64_t>;
using LayoutTScale = cute::Layout<ShapeTScale, StrideTScale>;
using TMA_A = decltype(make_tma_copy( using TMA_A = decltype(make_tma_copy(
GmemTiledCopy{}, GmemTiledCopy{},
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
ShapeT{}, ShapeT{},
StrideT{}), StrideT{}
),
SmemLayoutA{}(_, _, _0{}), SmemLayoutA{}(_, _, _0{}),
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}), select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
size<0>(ClusterShape{}))); size<0>(ClusterShape{})));
using TMA_B = decltype(make_tma_copy( using TMA_B = decltype(make_tma_copy(
GmemTiledCopy{}, GmemTiledCopy{},
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
ShapeT{}, ShapeT{},
StrideT{}), StrideT{}
),
take<0, 2>(SmemLayoutB{}), take<0, 2>(SmemLayoutB{}),
select<1, 2>(TileShape_MNK{}), select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{}))); size<0>(ClusterShape{})));
using TMA_Scale = decltype(make_tma_copy(
GmemTiledCopy{},
make_tensor(make_gmem_ptr(static_cast<float const*>(nullptr)),
ShapeTScale{},
StrideTScale{}),
SmemLayoutScale{}(_, _0{}),
select<0>(Shape<Int<kBlockM>>{}),
size<0>(ClusterShape{})));
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{}); static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
using MainloopPipeline = typename Ktraits::MainloopPipeline; using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params; using PipelineParams = typename MainloopPipeline::Params;
@@ -97,12 +91,8 @@ struct CollectiveMainloopFwd {
using SmemCopyAtomC = typename Ktraits::SmemCopyAtomC; using SmemCopyAtomC = typename Ktraits::SmemCopyAtomC;
using TiledCopyC = typename Ktraits::TiledCopyC; using TiledCopyC = typename Ktraits::TiledCopyC;
static constexpr uint32_t TmaTransactionBytesA = static_cast<uint32_t>( static constexpr uint32_t TmaTransactionBytesA = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v<Element> / 8);
size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v<Element> / 8); static constexpr uint32_t TmaTransactionBytesB = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesB = static_cast<uint32_t>(
size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesScale = static_cast<uint32_t>(
size(SmemLayoutScale{}(_, _0{})) * cutlass::sizeof_bits_v<float> / 8);
struct Arguments { struct Arguments {
Element const* ptr_A; Element const* ptr_A;
@@ -112,168 +102,99 @@ struct CollectiveMainloopFwd {
ElementOutput * ptr_C; ElementOutput * ptr_C;
LayoutT layout_C; LayoutT layout_C;
const float *weight_scale; const float *weight_scale;
LayoutTScale layout_Scale; const float *input_row_sum;
const float* input_scale;
const int64_t * tokens; const int64_t * tokens;
}; };
struct Params { struct Params {
LayoutT layout_A; LayoutT layout_A;
LayoutT layout_B; LayoutT layout_B;
LayoutTScale layout_Scale;
TMA_A tma_load_A; TMA_A tma_load_A;
TMA_B tma_load_B; TMA_B tma_load_B;
TMA_Scale tma_load_Scale;
ElementOutput * ptr_C; ElementOutput * ptr_C;
const float *weight_scale; const float *weight_scale;
const float* input_scale; const float *input_row_sum;
const int64_t * tokens; const int64_t * tokens;
}; };
Params static to_underlying_arguments(Arguments const& args) {
Params static
to_underlying_arguments(Arguments const& args) {
Tensor mA = make_tensor(make_gmem_ptr(args.ptr_A), args.layout_A); Tensor mA = make_tensor(make_gmem_ptr(args.ptr_A), args.layout_A);
TMA_A tma_load_A = TMA_A tma_load_A = make_tma_copy(
make_tma_copy(GmemTiledCopy{}, GmemTiledCopy{},
mA, mA,
SmemLayoutA{}(_, _, _0{}), SmemLayoutA{}(_, _, _0{}),
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}), select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
size<0>(ClusterShape{})); size<0>(ClusterShape{}));
Tensor mB = make_tensor(make_gmem_ptr(args.ptr_B), args.layout_B); Tensor mB = make_tensor(make_gmem_ptr(args.ptr_B), args.layout_B);
TMA_B tma_load_B = make_tma_copy(GmemTiledCopy{}, TMA_B tma_load_B = make_tma_copy(
GmemTiledCopy{},
mB, mB,
SmemLayoutB{}(_, _, _0{}), SmemLayoutB{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}), select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})); size<0>(ClusterShape{}));
Tensor mScale =
make_tensor(make_gmem_ptr(args.weight_scale), args.layout_Scale);
TMA_Scale tma_load_Scale = make_tma_copy(GmemTiledCopy{},
mScale,
SmemLayoutScale{}(_, _0{}),
select<0>(Shape<Int<kBlockM>>{}),
size<0>(ClusterShape{}));
return {args.layout_A, return {args.layout_A, args.layout_B, tma_load_A, tma_load_B,
args.layout_B, args.ptr_C, args.weight_scale, args.input_row_sum, args.tokens};
args.layout_Scale,
tma_load_A,
tma_load_B,
tma_load_Scale,
args.ptr_C,
args.weight_scale,
args.input_scale,
args.tokens};
} }
CUTLASS_DEVICE CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) { static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor( cute::prefetch_tma_descriptor(mainloop_params.tma_load_A.get_tma_descriptor());
mainloop_params.tma_load_A.get_tma_descriptor()); cute::prefetch_tma_descriptor(mainloop_params.tma_load_B.get_tma_descriptor());
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_B.get_tma_descriptor());
if constexpr (WeightScaleGroup < K) {
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_Scale.get_tma_descriptor());
}
} }
template <typename SharedStorage, typename FrgTensorO, typename TiledMma> template <int CUR_N, typename SharedStorage, typename FrgTensorO, typename TiledMma>
CUTLASS_DEVICE void store(Params const& mainloop_params, CUTLASS_DEVICE void
store(Params const& mainloop_params,
FrgTensorO & tOrO, FrgTensorO & tOrO,
SharedStorage& shared_storage, SharedStorage& shared_storage,
TiledMma tiled_mma, TiledMma tiled_mma,
const float *input_row_sum,
const float *weight_scale, const float *weight_scale,
const float* input_scale,
const int64_t tokens, const int64_t tokens,
const int64_t pre_fix_tokens, const int64_t pre_fix_tokens,
const int bidm, const int bidm,
const int bidn, const int bidn,
const int bidb, const int bidb,
const int tidx) { const int tidx) {
using packHalf = typename PackedHalf<ElementOutput>::Type; using packHalf = typename PackedHalf<ElementOutput>::Type;
Tensor tOrO_out = make_tensor<ElementOutput>(tOrO.layout()); Tensor tOrO_out = make_tensor<ElementOutput>(tOrO.layout());
if (input_scale != nullptr) {
const int lane_id = tidx % 4 * 2;
if constexpr (WeightScaleGroup == K) {
#pragma unroll #pragma unroll
for (int i = 0; i < size(tOrO); i+=4) { for (int i = 0; i < size(tOrO); i+=4) {
const int scale_idx = i * 2 + lane_id; const int sum_idx = i * 2;
tOrO[i] = tOrO[i] * weight_scale[0] * input_scale[scale_idx]; tOrO[i] = (tOrO[i] + input_row_sum[sum_idx]) * weight_scale[0];
tOrO[i + 1] = tOrO[i + 1] = (tOrO[i + 1] + input_row_sum[sum_idx + 1]) * weight_scale[0];
tOrO[i + 1] * weight_scale[0] * input_scale[scale_idx + 1]; tOrO[i + 2] = (tOrO[i + 2] + input_row_sum[sum_idx]) * weight_scale[1];
tOrO[i + 2] = tOrO[i + 2] * weight_scale[1] * input_scale[scale_idx]; tOrO[i + 3] = (tOrO[i + 3] + input_row_sum[sum_idx + 1]) * weight_scale[1];
tOrO[i + 3] = *reinterpret_cast<packHalf*>(&tOrO_out[i]) = packHalf(tOrO[i], tOrO[i + 2]);
tOrO[i + 3] * weight_scale[1] * input_scale[scale_idx + 1]; *reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) = packHalf(tOrO[i + 1], tOrO[i + 3]);
*reinterpret_cast<packHalf*>(&tOrO_out[i]) =
packHalf(tOrO[i], tOrO[i + 2]);
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) =
packHalf(tOrO[i + 1], tOrO[i + 3]);
}
} else {
#pragma unroll
for (int i = 0; i < size(tOrO); i += 4) {
const int scale_idx = i * 2 + lane_id;
*reinterpret_cast<packHalf*>(&tOrO_out[i]) =
packHalf(float(tOrO[i]) * input_scale[scale_idx],
float(tOrO[i + 2]) * input_scale[scale_idx]);
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) =
packHalf(float(tOrO[i + 1]) * input_scale[scale_idx + 1],
float(tOrO[i + 3]) * input_scale[scale_idx + 1]);
}
}
} else {
if constexpr (WeightScaleGroup == K) {
#pragma unroll
for (int i = 0; i < size(tOrO); i += 4) {
tOrO[i] = (tOrO[i]) * weight_scale[0];
tOrO[i + 1] = tOrO[i + 1] * weight_scale[0];
tOrO[i + 2] = tOrO[i + 2] * weight_scale[1];
tOrO[i + 3] = tOrO[i + 3] * weight_scale[1];
*reinterpret_cast<packHalf*>(&tOrO_out[i]) =
packHalf(tOrO[i], tOrO[i + 2]);
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) =
packHalf(tOrO[i + 1], tOrO[i + 3]);
}
} else {
#pragma unroll
for (int i = 0; i < size(tOrO); i += 4) {
*reinterpret_cast<packHalf*>(&tOrO_out[i]) =
packHalf(float(tOrO[i]), float(tOrO[i + 2]));
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) =
packHalf(float(tOrO[i + 1]), float(tOrO[i + 3]));
}
}
} }
uint16_t* smem_c = uint16_t *smem_c = reinterpret_cast<uint16_t *>(shared_storage.smem_c.data());
reinterpret_cast<uint16_t*>(shared_storage.smem_c.data());
uint32_t * reg_data = reinterpret_cast<uint32_t*>(tOrO_out.data()); uint32_t * reg_data = reinterpret_cast<uint32_t*>(tOrO_out.data());
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0); cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
constexpr int k_copy_times = kBlockN / 16; constexpr int k_copy_times = CUR_N / 16;
#pragma unroll #pragma unroll
for (int i = 0; i < k_copy_times; i++) { for (int i = 0; i < k_copy_times; i++) {
uint32_t smem_ptr = cast_smem_ptr_to_uint( uint32_t smem_ptr = cast_smem_ptr_to_uint(reinterpret_cast<uint128_t*>(smem_c + i * 16 * 128) + tidx);
reinterpret_cast<uint128_t*>(smem_c + i * 16 * 128) + tidx);
#if defined(CUTE_ARCH_STSM_SM90_ENABLED) #if defined(CUTE_ARCH_STSM_SM90_ENABLED)
asm volatile ( asm volatile (
"stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, " "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
"%4};\n" ::"r"(smem_ptr), :: "r"(smem_ptr), "r"(reg_data[4 * i + 0]), "r"(reg_data[4 * i + 2]), "r"(reg_data[4 * i + 1]), "r"(reg_data[4 * i + 3]));
"r"(reg_data[4 * i + 0]),
"r"(reg_data[4 * i + 2]),
"r"(reg_data[4 * i + 1]),
"r"(reg_data[4 * i + 3]));
#endif #endif
} }
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0); cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
const int expert_idx = const int batch_idx = TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize;
TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize; ElementOutput * store_c = mainloop_params.ptr_C + batch_idx + bidn * (M * kBlockN) + bidm * kBlockM;
ElementOutput* store_c = mainloop_params.ptr_C + expert_idx +
bidn * (M * kBlockN) + bidm * kBlockM;
const int reamin_tokens = tokens - bidn * kBlockN; const int reamin_tokens = tokens - bidn * kBlockN;
@@ -281,12 +202,11 @@ struct CollectiveMainloopFwd {
constexpr int kPackSize = 16 / sizeof(ElementOutput); constexpr int kPackSize = 16 / sizeof(ElementOutput);
constexpr int kNumVecElem = kBlockM / kPackSize; constexpr int kNumVecElem = kBlockM / kPackSize;
constexpr int copy_len = kBlockN * kNumVecElem; constexpr int copy_len = CUR_N * kNumVecElem;
#pragma unroll #pragma unroll
for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) { for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) {
const int idx_div2 = idx / 2; const int idx_div2 = idx / 2;
const int store_idx = idx_div2 / 128 * 128 + idx_div2 % 8 * 16 + const int store_idx = idx_div2 / 128 * 128 + idx_div2 % 8 * 16 + idx_div2 % 128 / 16 + idx_div2 % 16 / 8 * 8;
idx_div2 % 128 / 16 + idx_div2 % 16 / 8 * 8;
const int store_global_idx = store_idx * 2 + col; const int store_global_idx = store_idx * 2 + col;
const int row = store_global_idx / kNumVecElem; const int row = store_global_idx / kNumVecElem;
const int col = store_global_idx % kNumVecElem; const int col = store_global_idx % kNumVecElem;
@@ -294,25 +214,26 @@ struct CollectiveMainloopFwd {
continue; continue;
} }
const int offset = row * (M / kPackSize) + col; const int offset = row * (M / kPackSize) + col;
reinterpret_cast<uint4*>(store_c)[offset] = reinterpret_cast<uint4*>(store_c)[offset] = reinterpret_cast<uint4*>(smem_c)[idx];
reinterpret_cast<uint4*>(smem_c)[idx];
} }
} }
template <typename MTensor> template <typename MTensor>
CUTLASS_DEVICE auto get_local_no_packed_tensor(const MTensor& mB, CUTLASS_DEVICE auto get_local_no_packed_tensor(
const MTensor &mB,
const int pre_fix_token, const int pre_fix_token,
const int actual_token, const int actual_token,
const int bidn) const { const int bidn) const {
auto g_tensor = domain_offset(make_coord(pre_fix_token, _0{}), mB(_, _, 0)); auto g_tensor = domain_offset(make_coord(pre_fix_token, _0{}), mB(_, _, 0));
Tensor gB = local_tile( Tensor gB = local_tile(g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
return gB; return gB;
} }
template <typename SharedStorage> template <typename SharedStorage>
CUTLASS_DEVICE void load(Params const& mainloop_params, CUTLASS_DEVICE void
load(Params const& mainloop_params,
MainloopPipeline pipeline, MainloopPipeline pipeline,
PipelineState& smem_pipe_write, PipelineState& smem_pipe_write,
SharedStorage &shared_storage, SharedStorage &shared_storage,
@@ -322,117 +243,114 @@ struct CollectiveMainloopFwd {
const int bidn, const int bidn,
const int bidb, const int bidb,
const int tidx) { const int tidx) {
Tensor sA =
make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
Tensor sB =
make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
Tensor sScale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()),
SmemLayoutScale{});
Tensor mA = mainloop_params.tma_load_A.get_tma_tensor( Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
mainloop_params.layout_A.shape()); Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(
mainloop_params.layout_B.shape());
Tensor mScale = mainloop_params.tma_load_Scale.get_tma_tensor(
mainloop_params.layout_Scale.shape());
Tensor gA = Tensor mA = mainloop_params.tma_load_A.get_tma_tensor(mainloop_params.layout_A.shape());
local_tile(mA(_, _, bidb), Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(mainloop_params.layout_B.shape());
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
make_coord(bidm, _));
Tensor gScale = local_tile(
mScale(_, bidm, bidb), select<0>(Shape<Int<kBlockM>>{}), make_coord(_));
auto [tAgA, tAsA] = tma_partition(mainloop_params.tma_load_A, Tensor gA = local_tile(mA(_, _, bidb), select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}), make_coord(bidm, _));
_0{},
Layout<ClusterShape>{}, auto [tAgA, tAsA] = tma_partition(mainloop_params.tma_load_A, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sA), group_modes<0, 2>(gA));
group_modes<0, 2>(sA),
group_modes<0, 2>(gA)); const int kIters = kTiles / kStages;
if constexpr (TokenPackSize == 0) { if constexpr (TokenPackSize == 0) {
Tensor gB = get_local_no_packed_tensor(mB, pre_fix_tokens, tokens, bidn); Tensor gB = get_local_no_packed_tensor(
mB,
pre_fix_tokens,
tokens,
bidn);
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));
_0{},
Layout<ClusterShape>{},
group_modes<0, 2>(sB),
group_modes<0, 2>(gB));
if (tidx == 0) { if (tidx == 0) {
#pragma unroll #pragma unroll
for (int kiter = 0; kiter < kTiles; ++kiter) { for (int kiter = 0; kiter < kIters; ++kiter) {
#pragma unroll
for (int s = 0; s < kStages; s++) {
const int i = kiter * kStages + s;
pipeline.producer_acquire(smem_pipe_write); pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with( copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
*pipeline.producer_get_barrier(smem_pipe_write), 0), tAgA(_, i), tAsA(_, smem_pipe_write.index()));
tAgA(_, kiter),
tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with( copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
*pipeline.producer_get_barrier(smem_pipe_write), 0), tBgB(_, i), tBsB(_, smem_pipe_write.index()));
tBgB(_, kiter), ++smem_pipe_write;
tBsB(_, smem_pipe_write.index())); }
if constexpr (WeightScaleGroup < K) {
copy(mainloop_params.tma_load_Scale.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
gScale(_, kiter),
sScale(_, smem_pipe_write.index()));
} }
#pragma unroll
for (int i = kIters * kStages; i < kTiles; ++i) {
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
++smem_pipe_write; ++smem_pipe_write;
} }
} }
} else { } else {
auto mB_this_expert = make_tensor( auto mB_this_batch = make_tensor(
mB(_, _, bidb).data(), mB(_, _, bidb).data(),
make_layout(cute::make_shape(tokens, size<1>(mB)), mB.stride())); make_layout(
Tensor gB = local_tile( cute::make_shape(tokens, size<1>(mB)),
mB_this_expert, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _)); mB.stride()
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, ));
_0{}, Tensor gB = local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
Layout<ClusterShape>{}, auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));
group_modes<0, 2>(sB),
group_modes<0, 2>(gB));
if (tidx == 0) { if (tidx == 0) {
#pragma unroll #pragma unroll
for (int kiter = 0; kiter < kTiles; ++kiter) { for (int kiter = 0; kiter < kIters; ++kiter) {
#pragma unroll
for (int s = 0; s < kStages; s++) {
const int i = kiter * kStages + s;
pipeline.producer_acquire(smem_pipe_write); pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with( copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
*pipeline.producer_get_barrier(smem_pipe_write), 0), tAgA(_, i), tAsA(_, smem_pipe_write.index()));
tAgA(_, kiter),
tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with( copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
*pipeline.producer_get_barrier(smem_pipe_write), 0), tBgB(_, i), tBsB(_, smem_pipe_write.index()));
tBgB(_, kiter), ++smem_pipe_write;
tBsB(_, smem_pipe_write.index())); }
if constexpr (WeightScaleGroup < K) {
copy(mainloop_params.tma_load_Scale.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
gScale(_, kiter),
sScale(_, smem_pipe_write.index()));
} }
#pragma unroll
for (int i = kIters * kStages; i < kTiles; ++i) {
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
++smem_pipe_write; ++smem_pipe_write;
} }
} }
} }
} }
template <typename SharedStorage, typename FrgTensorO, typename TiledMma> template <int CUR_N, typename SharedStorage, typename FrgTensorO, typename TiledMma>
CUTLASS_DEVICE void mma(Params const& mainloop_params, CUTLASS_DEVICE void
mma(Params const& mainloop_params,
TiledMma tiled_mma, TiledMma tiled_mma,
MainloopPipeline pipeline, MainloopPipeline pipeline,
PipelineState& smem_pipe_read, PipelineState& smem_pipe_read,
SharedStorage& shared_storage, SharedStorage& shared_storage,
FrgTensorO &tSrS, FrgTensorO &tSrS,
const int tidx) { const int tidx) {
Tensor sA =
make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{}); using sMemBLayout = std::conditional_t<
Tensor sB = CUR_N == kBlockN,
make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{}); SmemLayoutB,
SmemLayoutB_TAIL
>;
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), sMemBLayout{});
tiled_mma.accumulate_ = GMMA::ScaleOut::One; tiled_mma.accumulate_ = GMMA::ScaleOut::One;
auto threadMma = tiled_mma.get_thread_slice(tidx); auto threadMma = tiled_mma.get_thread_slice(tidx);
@@ -447,124 +365,30 @@ struct CollectiveMainloopFwd {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token); pipeline.consumer_wait(smem_pipe_read, barrier_token);
}; };
const int kIters = kTiles / kStages;
constexpr int B_STEPS = CUR_N == 0 ? 1 : (kBlockN / CUR_N);
#pragma unroll #pragma unroll
for (int kiter = 0; kiter < kTiles; ++kiter) { for (int kiter = 0; kiter < kIters; ++kiter) {
Tensor tSsA =
smem_thr_copy_A.partition_S(sA(_, _, smem_pipe_read.index()));
consumer_wait(pipeline, smem_pipe_read);
gemm</*wg_wait=*/0>(tiled_mma,
tSrA,
tSsA,
tSrB(_, _, _, smem_pipe_read.index()),
tSrS,
smem_tiled_copy_A,
smem_thr_copy_A);
pipeline.consumer_release(smem_pipe_read);
++smem_pipe_read;
}
}
template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
CUTLASS_DEVICE void mma_pipeline(Params const& mainloop_params,
TiledMma tiled_mma,
MainloopPipeline pipeline,
PipelineState& smem_pipe_read,
SharedStorage& shared_storage,
FrgTensorO& tSrS,
const int tidx) {
Tensor sA =
make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
Tensor sB =
make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
float2* weight_scale =
reinterpret_cast<float2*>(shared_storage.smem_scale.data()) + tidx / 4;
Tensor tSrS1 = make_fragment_like(tSrS);
Tensor tSrS2 = make_fragment_like(tSrS);
__half2* tSrS_data =
reinterpret_cast<__half2*>(raw_pointer_cast(tSrS.data()));
__half2* tSrS1_data =
reinterpret_cast<__half2*>(raw_pointer_cast(tSrS1.data()));
__half2* tSrS2_data =
reinterpret_cast<__half2*>(raw_pointer_cast(tSrS2.data()));
auto threadMma = tiled_mma.get_thread_slice(tidx);
auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomAB{}, tiled_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx);
Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0));
Tensor tSrB = threadMma.partition_fragment_B(sB);
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
__half2 scale1, scale2, scale3, scale4;
float2 scale_cur_k;
#pragma unroll #pragma unroll
for (int kiter = 0; kiter < kTiles;) { for (int s = 0; s < kStages; s++) {
Tensor tSsA1 = Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, s));
smem_thr_copy_A.partition_S(sA(_, _, smem_pipe_read.index()));
consumer_wait(pipeline, smem_pipe_read); consumer_wait(pipeline, smem_pipe_read);
scale_cur_k = *(weight_scale + smem_pipe_read.index() * (kBlockM / 2)); gemm</*wg_wait=*/0>(tiled_mma, tSrA, tSsA, tSrB(_, _, _, s * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A);
scale1 = __half2(scale_cur_k.x, scale_cur_k.x);
scale2 = __half2(scale_cur_k.y, scale_cur_k.y);
gemm</*wg_wait=*/0>(tiled_mma,
tSrA,
tSsA1,
tSrB(_, _, _, smem_pipe_read.index()),
tSrS1,
smem_tiled_copy_A,
smem_thr_copy_A);
pipeline.consumer_release(smem_pipe_read);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
if (kiter > 0) {
for (int i = 0; i < size(tSrS) / 2; i += 2) {
tSrS_data[i] = __hfma2(tSrS2_data[i], scale3, tSrS_data[i]);
tSrS_data[i + 1] =
__hfma2(tSrS2_data[i + 1], scale4, tSrS_data[i + 1]);
}
}
++smem_pipe_read;
++kiter;
if (kiter < kTiles) {
Tensor tSsA2 =
smem_thr_copy_A.partition_S(sA(_, _, smem_pipe_read.index()));
consumer_wait(pipeline, smem_pipe_read);
scale_cur_k = *(weight_scale + smem_pipe_read.index() * (kBlockM / 2));
scale3 = __half2(scale_cur_k.x, scale_cur_k.x);
scale4 = __half2(scale_cur_k.y, scale_cur_k.y);
gemm</*wg_wait=*/0>(tiled_mma,
tSrA,
tSsA2,
tSrB(_, _, _, smem_pipe_read.index()),
tSrS2,
smem_tiled_copy_A,
smem_thr_copy_A);
pipeline.consumer_release(smem_pipe_read); pipeline.consumer_release(smem_pipe_read);
++smem_pipe_read; ++smem_pipe_read;
++kiter;
} }
}
#pragma unroll
for (int i = 0; i < kTiles % kStages; ++i) {
Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, i));
consumer_wait(pipeline, smem_pipe_read);
for (int i = 0; i < size(tSrS) / 2; i += 2) { gemm</*wg_wait=*/0>(tiled_mma, tSrA, tSsA, tSrB(_, _, _, i * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A);
tSrS_data[i] = __hfma2(tSrS1_data[i], scale1, tSrS_data[i]); pipeline.consumer_release(smem_pipe_read);
tSrS_data[i + 1] = __hfma2(tSrS1_data[i + 1], scale2, tSrS_data[i + 1]); ++smem_pipe_read;
}
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
if constexpr (kTiles % 2 == 0) {
for (int i = 0; i < size(tSrS) / 2; i += 2) {
tSrS_data[i] = __hfma2(tSrS2_data[i], scale3, tSrS_data[i]);
tSrS_data[i + 1] = __hfma2(tSrS2_data[i + 1], scale4, tSrS_data[i + 1]);
}
} }
} }
}; };

View File

@@ -24,14 +24,15 @@
#include <cuda_bf16.h> #include <cuda_bf16.h>
#endif #endif
#include <cute/arch/cluster_sm90.hpp> // For cute::elect_one_sync()
#include <cute/tensor.hpp> #include <cute/tensor.hpp>
#include <cute/arch/cluster_sm90.hpp> // For cute::elect_one_sync()
#include <cutlass/array.h> #include <cutlass/array.h>
#include <cutlass/cutlass.h> #include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h> #include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h> #include <cutlass/numeric_types.h>
using namespace cute; using namespace cute;
template<typename T> template<typename T>
@@ -47,93 +48,67 @@ struct PackedHalf<cutlass::bfloat16_t> {
using Type = nv_bfloat162; using Type = nv_bfloat162;
}; };
template <typename To_type, typename Engine, typename Layout> template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type( __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type; using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value; constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op; cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag = auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(
tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout()); return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
} }
template <int numel> template <int numel>
__forceinline__ __device__ void convert_c4_2_fp8(const int32_t *src, __forceinline__ __device__ void convert_c4_2_fp8(const int32_t * src, int32_t * dst1, int32_t * dst2) {
int32_t *dst1,
int32_t *dst2) {
#pragma unroll #pragma unroll
for (int i = 0; i < numel; ++i) { for (int i = 0; i < numel; ++i) {
uint32_t head1 = src[i] & 0x80808080; dst1[i] = (src[i] >> 4) & 0x0f0f0f0f;
dst1[i] = (src[i] >> 4) & 0x07070707; dst2[i] = src[i] & 0x0f0f0f0f;
dst1[i] = dst1[i] | head1;
uint32_t head2 = (src[i] & 0x08080808) << 4;
dst2[i] = src[i] & 0x07070707;
dst2[i] = dst2[i] | head2;
} }
} }
template <int wg_wait = 0, template <int wg_wait=0, bool arrive=true,
bool arrive = true, bool commit=true, typename Tensor0, typename Tensor1,
bool commit = true, typename Tensor2, typename Tensor3, typename TiledMma,
typename Tensor0, typename ThrCopyA, typename TiledCopyA>
typename Tensor1, __forceinline__ __device__ void gemm(
typename Tensor2, TiledMma &tiled_mma,
typename Tensor3,
typename TiledMma,
typename ThrCopyA,
typename TiledCopyA>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma,
Tensor0 &tCrA, Tensor0 &tCrA,
Tensor1 &tCsA, Tensor1 &tCsA,
Tensor2 const &tCrB, Tensor2 const &tCrB,
Tensor3 &tCrC, Tensor3 &tCrC,
TiledCopyA const &tiled_copy_A, TiledCopyA const &tiled_copy_A,
ThrCopyA const &thr_copy_A) { ThrCopyA const &thr_copy_A) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
typename TiledMma::FrgTypeA>::value;
Tensor tCrA1 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout()); Tensor tCrA1 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
Tensor tCrA2 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout()); Tensor tCrA2 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
if constexpr (Is_RS) { if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
}
warpgroup_fence_operand(tCrC); warpgroup_fence_operand(tCrC);
if constexpr (arrive) { if constexpr (arrive) {
warpgroup_arrive(); warpgroup_arrive();
} }
constexpr int numel = decltype(size(tCrA(_, _, 0)))::value / 4; constexpr int numel = decltype(size(tCrA(_, _, 0)))::value / 4;
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
cute::copy(tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); cute::copy(tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
if (k_block < size<2>(tCrA) - 1) { if (k_block < size<2>(tCrA) - 1) {
cute::copy(tiled_copy_A, cute::copy(tiled_copy_A, tCsA(_, _, k_block + 1), tCrA_copy_view(_, _, k_block + 1));
tCsA(_, _, k_block + 1),
tCrA_copy_view(_, _, k_block + 1));
} }
int32_t *tCrA_data = int32_t * tCrA_data = reinterpret_cast<int32_t *>(tCrA(_,_,k_block).data());
reinterpret_cast<int32_t *>(tCrA(_, _, k_block).data()); int32_t * tCrA1_data = reinterpret_cast<int32_t *>(tCrA1(_,_,k_block).data());
int32_t *tCrA1_data = int32_t * tCrA2_data = reinterpret_cast<int32_t *>(tCrA2(_,_,k_block).data());
reinterpret_cast<int32_t *>(tCrA1(_, _, k_block).data());
int32_t *tCrA2_data =
reinterpret_cast<int32_t *>(tCrA2(_, _, k_block).data());
convert_c4_2_fp8<numel>(tCrA_data, tCrA1_data, tCrA2_data); convert_c4_2_fp8<numel>(tCrA_data, tCrA1_data, tCrA2_data);
cute::gemm(tiled_mma, tCrA1(_,_,k_block), tCrB(_,_,2 * k_block), tCrC); cute::gemm(tiled_mma, tCrA1(_,_,k_block), tCrB(_,_,2 * k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One; cute::gemm(tiled_mma, tCrA2(_,_,k_block), tCrB(_,_, 2 * k_block + 1), tCrC);
cute::gemm(
tiled_mma, tCrA2(_, _, k_block), tCrB(_, _, 2 * k_block + 1), tCrC);
} }
if constexpr (commit) { if constexpr (commit) {
warpgroup_commit_batch(); warpgroup_commit_batch();
} }
if constexpr (wg_wait >= 0) { if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_wait<wg_wait>();
}
warpgroup_fence_operand(tCrC); warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
}
} }

View File

@@ -16,61 +16,75 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif #endif
#include "w4afp8_gemm.h"
#include "helper.h" #include "helper.h"
#include "paddle/extension.h" #include "paddle/extension.h"
#include "w4afp8_gemm_template.h" #include "w4afp8_gemm_template.h"
#include "weight_kernel.hpp" #include "w4afp8_gemm.h"
#include "weight_scale_kernel.hpp"
template <typename T>
class NVTraits;
template <> void weight_convert(const uint8_t *weight, uint8_t *weight_new, int batch, int M, int K) {
class NVTraits<__nv_fp8_e4m3> { assert(K % 64 == 0);
for (int b = 0; b < batch; ++b) {
for (int m = 0; m < M; ++m) {
for (int k = 0; k < K; k+=64) {
for (int k_inner = 0; k_inner < 32; ++k_inner) {
uint8_t temp = 0;
uint8_t left = weight[b * M * K + m * K + k + k_inner];
uint8_t right = weight[b * M * K + m * K + k + k_inner + 32];
temp |= left << 4;
temp |= right;
weight_new[b * M * K / 2 + m * K / 2 + k / 2 + k_inner] = *reinterpret_cast<uint8_t*>(&temp);
}
}
}
}
}
template <typename T> class NVTraits;
template <> class NVTraits<__nv_fp8_e4m3> {
public: public:
typedef cutlass::float_e4m3_t data_t; typedef cutlass::float_e4m3_t data_t;
}; };
template <> template <> class NVTraits<__nv_bfloat16>{
class NVTraits<__nv_bfloat16> {
public: public:
typedef cutlass::bfloat16_t data_t; typedef cutlass::bfloat16_t data_t;
}; };
template <> template <> class NVTraits<half>{
class NVTraits<half> {
public: public:
typedef cutlass::half_t data_t; typedef cutlass::half_t data_t;
}; };
template <typename OutputType> template <typename OutputType>
void DisPatchW4AFp8Gemm(const cutlass::float_e4m3_t* input, void DisPatchW4AFp8Gemm(
const cutlass::float_e4m3_t* input,
const cutlass::float_e4m3_t* weight, const cutlass::float_e4m3_t* weight,
const int64_t * tokens, const int64_t * tokens,
const float * input_row_sum,
const float * weight_scale, const float * weight_scale,
const float* input_dequant_scale,
OutputType * out, OutputType * out,
const int64_t token_padding_size, const int64_t token_padding_size,
const int64_t max_tokens, const int64_t max_tokens,
const int Experts, const int batch_size,
const int64_t M, const int64_t M,
const int64_t K, const int64_t K,
const int WeightScaleGroup,
cudaStream_t stream) { cudaStream_t stream) {
int kBlockN = 256; int kBlockN = 256;
int TailN = 0;
if constexpr (std::is_same_v<OutputType, cutlass::bfloat16_t>) { if constexpr (std::is_same_v<OutputType, cutlass::bfloat16_t>) {
GEMM_SWITCH_BF16(M, GEMM_SWITCH_BF16(
K, M, K, batch_size, token_padding_size, kBlockN, TailN,
Experts,
token_padding_size,
kBlockN,
WeightScaleGroup,
weight, weight,
input, input,
out, out,
weight_scale, weight_scale,
input_dequant_scale, input_row_sum,
tokens, tokens,
max_tokens, max_tokens,
stream) stream)
@@ -82,20 +96,17 @@ void DisPatchW4AFp8Gemm(const cutlass::float_e4m3_t* input,
std::vector<paddle::Tensor> W4AFp8Gemm( std::vector<paddle::Tensor> W4AFp8Gemm(
const paddle::Tensor& input, const paddle::Tensor& input,
const paddle::Tensor& weight, const paddle::Tensor& weight,
const paddle::Tensor& const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group
tokens, // If tokenpadding=0, this tensor represents the prefix sum of const paddle::Tensor& input_row_sum,
// tensors, otherwise it represents the number of tokens in
// each group
const paddle::Tensor& weight_scale, const paddle::Tensor& weight_scale,
const paddle::optional<paddle::Tensor>& input_dequant_scale,
const int64_t token_padding_size, const int64_t token_padding_size,
const int64_t max_tokens, const int64_t max_tokens,
const bool is_bfloat16) { const bool is_bfloat16) {
const int Experts = weight.dims()[0];
const int batch_size = weight.dims()[0];
const int M = weight.dims()[1]; const int M = weight.dims()[1];
const int K = weight.dims()[2] * 2; const int K = weight.dims()[2] * 2;
const int WeightScaleGroup =
weight_scale.dims().size() == 2 ? K : weight_scale.dims()[3];
if (input.dtype() != paddle::DataType::FLOAT8_E4M3FN) { if (input.dtype() != paddle::DataType::FLOAT8_E4M3FN) {
PD_THROW("Only supported dtype in ['FLOAT8_E4M3FN']."); PD_THROW("Only supported dtype in ['FLOAT8_E4M3FN'].");
@@ -104,26 +115,20 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
if (token_padding_size == 0) { if (token_padding_size == 0) {
const int all_tokens = input.dims()[0]; const int all_tokens = input.dims()[0];
if (is_bfloat16) { if (is_bfloat16) {
paddle::Tensor out = paddle::empty( paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::BFLOAT16, input.place());
{all_tokens, M}, paddle::DataType::BFLOAT16, input.place());
phi::dtype::bfloat16 *out_data = out.data<phi::dtype::bfloat16>(); phi::dtype::bfloat16 *out_data = out.data<phi::dtype::bfloat16>();
DisPatchW4AFp8Gemm( DisPatchW4AFp8Gemm(
reinterpret_cast<const cutlass::float_e4m3_t*>( reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
input.data<phi::dtype::float8_e4m3fn>()), reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
reinterpret_cast<const cutlass::float_e4m3_t*>(
weight.data<uint8_t>()),
tokens.data<int64_t>(), tokens.data<int64_t>(),
input_row_sum.data<float>(),
weight_scale.data<float>(), weight_scale.data<float>(),
input_dequant_scale
? const_cast<float*>(input_dequant_scale.get().data<float>())
: nullptr,
reinterpret_cast<cutlass::bfloat16_t*>(out_data), reinterpret_cast<cutlass::bfloat16_t*>(out_data),
token_padding_size, token_padding_size,
max_tokens, max_tokens,
Experts, batch_size,
M, M,
K, K,
WeightScaleGroup,
input.stream()); input.stream());
return {out}; return {out};
} else { } else {
@@ -131,27 +136,20 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
} }
} else { } else {
if (is_bfloat16) { if (is_bfloat16) {
paddle::Tensor out = paddle::empty({Experts, token_padding_size, M}, paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::BFLOAT16, input.place());
paddle::DataType::BFLOAT16,
input.place());
phi::dtype::bfloat16 * out_data = out.data<phi::dtype::bfloat16>(); phi::dtype::bfloat16 * out_data = out.data<phi::dtype::bfloat16>();
DisPatchW4AFp8Gemm( DisPatchW4AFp8Gemm(
reinterpret_cast<const cutlass::float_e4m3_t*>( reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
input.data<phi::dtype::float8_e4m3fn>()), reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
reinterpret_cast<const cutlass::float_e4m3_t*>(
weight.data<uint8_t>()),
tokens.data<int64_t>(), tokens.data<int64_t>(),
input_row_sum.data<float>(),
weight_scale.data<float>(), weight_scale.data<float>(),
input_dequant_scale
? const_cast<float*>(input_dequant_scale.get().data<float>())
: nullptr,
reinterpret_cast<cutlass::bfloat16_t*>(out_data), reinterpret_cast<cutlass::bfloat16_t*>(out_data),
token_padding_size, token_padding_size,
max_tokens, max_tokens,
Experts, batch_size,
M, M,
K, K,
WeightScaleGroup,
input.stream()); input.stream());
return {out}; return {out};
} else { } else {
@@ -161,10 +159,12 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
} }
template <typename InputType, typename OutputType> template <typename InputType, typename OutputType>
void DisPatchW4AFp8GemmWrapper(const InputType* input, void DisPatchW4AFp8GemmWrapper(
const InputType* input,
const InputType* weight, const InputType* weight,
const int64_t* total_rows_before_expert, const int64_t* total_rows_before_expert,
const float* input_dequant_scale, const float* input_row_sum,
const float* row_scale,
const float* weight_scale, const float* weight_scale,
OutputType * out, OutputType * out,
const int64_t token_padding_size, const int64_t token_padding_size,
@@ -172,25 +172,85 @@ void DisPatchW4AFp8GemmWrapper(const InputType* input,
const int num_experts, const int num_experts,
const int64_t M, const int64_t M,
const int64_t K, const int64_t K,
const int WeightScaleGroup,
cudaStream_t stream) { cudaStream_t stream) {
using InType = typename NVTraits<InputType>::data_t; using InType = typename NVTraits<InputType>::data_t;
using OutType = typename NVTraits<OutputType>::data_t; using OutType = typename NVTraits<OutputType>::data_t;
DisPatchW4AFp8Gemm(reinterpret_cast<const InType*>(input), DisPatchW4AFp8Gemm(
reinterpret_cast<const InType*>(input),
reinterpret_cast<const InType*>(weight), reinterpret_cast<const InType*>(weight),
total_rows_before_expert, total_rows_before_expert,
input_row_sum,
weight_scale, weight_scale,
input_dequant_scale,
reinterpret_cast<OutType*>(out), reinterpret_cast<OutType*>(out),
token_padding_size, token_padding_size,
max_tokens, max_tokens,
num_experts, num_experts,
M, M,
K, K,
WeightScaleGroup,
stream); stream);
} }
std::vector<paddle::Tensor> W4AFp8GemmWeightConvert(const paddle::Tensor& weight) {
const int batch_size = weight.dims()[0];
const int M = weight.dims()[1];
const int K = weight.dims()[2];
paddle::Tensor weight_new = paddle::empty({batch_size, M, K / 2}, paddle::DataType::UINT8, weight.place());
weight_convert(weight.data<uint8_t>(), weight_new.data<uint8_t>(), batch_size, M, K);
return {weight_new};
}
template <typename T, int kPackSize>
__global__ void permute_scale_kernel(
T* input_data,
const int numel) {
using LoadT = AlignedVector<T, kPackSize>;
LoadT input_vec;
LoadT dst_vec;
const int load_idx = (blockIdx.x * blockDim.x + threadIdx.x) * kPackSize;
if (load_idx >= numel) {
return;
}
Load<T, kPackSize>(&input_data[load_idx], &input_vec);
for (int i = 0; i < kPackSize; i+=2) {
dst_vec[i] = input_vec[i / 2];
dst_vec[i + 1] = input_vec[i / 2 + 8];
}
Store<T, kPackSize>(dst_vec, &input_data[load_idx]);
}
void W4AFp8GemmScalePermute(const paddle::Tensor& scale) {
const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1;
const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0];
if (col % 16 != 0) {
PD_THROW("Only supported when col is divisible by 16.");
}
const int numel = row * col;
const int threads = 128;
const int kPackSize = 16;
const int grid_size = (numel / kPackSize + threads - 1) / threads;
if (scale.dtype() == paddle::DataType::BFLOAT16) {
permute_scale_kernel<phi::dtype::bfloat16, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
const_cast<phi::dtype::bfloat16*>(scale.data<phi::dtype::bfloat16>()),
numel
);
} else if (scale.dtype() == paddle::DataType::FLOAT16) {
permute_scale_kernel<phi::dtype::float16, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
const_cast<phi::dtype::float16*>(scale.data<phi::dtype::float16>()),
numel
);
} else if (scale.dtype() == paddle::DataType::FLOAT32) {
permute_scale_kernel<float, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
const_cast<float*>(scale.data<float>()),
numel
);
}
}
PD_BUILD_STATIC_OP(w4afp8_gemm_scale_permute) PD_BUILD_STATIC_OP(w4afp8_gemm_scale_permute)
.Inputs({"weight_scale"}) .Inputs({"weight_scale"})
.Outputs({"permute_scale"}) .Outputs({"permute_scale"})
@@ -201,8 +261,8 @@ PD_BUILD_STATIC_OP(w4afp8_gemm)
.Inputs({"input", .Inputs({"input",
"weight", "weight",
"tokens", "tokens",
"weight_scale", "input_row_sum",
paddle::Optional("input_dequant_scale")}) "weight_scale"})
.Outputs({"out"}) .Outputs({"out"})
.Attrs({"token_padding_size: int64_t", .Attrs({"token_padding_size: int64_t",
"max_tokens: int64_t", "max_tokens: int64_t",
@@ -218,7 +278,8 @@ template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, __nv_bfloat16>(
const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* input,
const __nv_fp8_e4m3* weight, const __nv_fp8_e4m3* weight,
const int64_t * tokens, const int64_t * tokens,
const float* input_dequant_scale, const float * input_row_sum,
const float * row_scale,
const float * weight_scale, const float * weight_scale,
__nv_bfloat16 * out, __nv_bfloat16 * out,
const int64_t token_padding_size, const int64_t token_padding_size,
@@ -226,14 +287,15 @@ template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, __nv_bfloat16>(
const int num_experts, const int num_experts,
const int64_t M, const int64_t M,
const int64_t K, const int64_t K,
const int WeightScaleGroup, cudaStream_t stream
cudaStream_t stream); );
template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, half>( template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, half>(
const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* input,
const __nv_fp8_e4m3* weight, const __nv_fp8_e4m3* weight,
const int64_t * tokens, const int64_t * tokens,
const float* input_dequant_scale, const float * input_row_sum,
const float * row_scale,
const float * weight_scale, const float * weight_scale,
half * out, half * out,
const int64_t token_padding_size, const int64_t token_padding_size,
@@ -241,5 +303,5 @@ template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, half>(
const int num_experts, const int num_experts,
const int64_t M, const int64_t M,
const int64_t K, const int64_t K,
const int WeightScaleGroup, cudaStream_t stream
cudaStream_t stream); );

View File

@@ -18,24 +18,25 @@
#include <vector> #include <vector>
#include "helper.h" #include "helper.h"
std::vector<paddle::Tensor> W4AFp8Gemm( std::vector<paddle::Tensor> W4AFp8Gemm(
const paddle::Tensor& input, const paddle::Tensor& input,
const paddle::Tensor& weight, const paddle::Tensor& weight,
const paddle::Tensor& const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group
tokens, // If tokenpadding=0, this tensor represents the prefix sum of const paddle::Tensor& input_row_sum,
// tensors, otherwise it represents the number of tokens in
// each group
const paddle::Tensor& weight_scale, const paddle::Tensor& weight_scale,
const paddle::optional<paddle::Tensor>& input_dequant_scale,
const int64_t token_padding_size, const int64_t token_padding_size,
const int64_t max_tokens, const int64_t max_tokens,
const bool is_bfloat16); const bool is_bfloat16);
template <typename InputType, typename OutputType> template <typename InputType, typename OutputType>
void DisPatchW4AFp8GemmWrapper(const InputType* input, void DisPatchW4AFp8GemmWrapper(
const InputType* input,
const InputType* weight, const InputType* weight,
const int64_t * tokens, const int64_t * tokens,
const float* input_dequant_scale, const float * input_row_sum,
const float * row_scale,
const float * weight_scale, const float * weight_scale,
OutputType * out, OutputType * out,
const int64_t token_padding_size, const int64_t token_padding_size,
@@ -43,5 +44,4 @@ void DisPatchW4AFp8GemmWrapper(const InputType* input,
const int num_experts, const int num_experts,
const int64_t M, const int64_t M,
const int64_t K, const int64_t K,
const int WeightScaleGroup,
cudaStream_t stream); cudaStream_t stream);

View File

@@ -16,26 +16,25 @@
#include "cute/atom/mma_atom.hpp" #include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h" #include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h" #include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp" #include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/arch/reg_reconfig.h"
#include "kernel_traits.h" #include "kernel_traits.h"
#include "mainloop_fwd.h" #include "mainloop_fwd.h"
template <typename Ktraits> template <typename Ktraits>
void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp, void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) w4afp8_gemm_kernel(
1) CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits>::Params const mainloop_params) {
w4afp8_gemm_kernel(
CUTE_GRID_CONSTANT
typename CollectiveMainloopFwd<Ktraits>::Params const mainloop_params) {
using Element = typename Ktraits::Element; using Element = typename Ktraits::Element;
static_assert(cutlass::sizeof_bits_v<Element> == 8); static_assert(cutlass::sizeof_bits_v<Element> == 8);
using TileShape_MNK = typename Ktraits::TileShape_MNK; using TileShape_MNK = typename Ktraits::TileShape_MNK;
using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL;
using ClusterShape = typename Ktraits::ClusterShape_MNK; using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{}); static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
@@ -43,9 +42,8 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
static constexpr int kBlockN = Ktraits::kBlockN; static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int kBlockM = Ktraits::kBlockM; static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int M = Ktraits::M; static constexpr int M = Ktraits::M;
static constexpr int K = Ktraits::K;
static constexpr int TokenPackSize = Ktraits::TokenPackSize; static constexpr int TokenPackSize = Ktraits::TokenPackSize;
static constexpr int WeightScaleGroup = Ktraits::WeightScaleGroup; static constexpr int TAIL_N = Ktraits::TAIL_N;
using CollectiveMainloop = CollectiveMainloopFwd<Ktraits>; using CollectiveMainloop = CollectiveMainloopFwd<Ktraits>;
@@ -55,8 +53,7 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
using ElementOutput = typename Ktraits::ElementOutput; using ElementOutput = typename Ktraits::ElementOutput;
extern __shared__ char shared_memory[]; extern __shared__ char shared_memory[];
auto &shared_storage = auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
*reinterpret_cast<typename Ktraits::SharedStorage *>(shared_memory);
const int bidm = blockIdx.x; const int bidm = blockIdx.x;
const int bidn = blockIdx.y; const int bidn = blockIdx.y;
@@ -68,20 +65,10 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
} }
// Obtain warp index // Obtain warp index
int const warp_group_thread_idx = int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params; PipelineParams pipeline_params;
if constexpr (WeightScaleGroup == K) { pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesA + CollectiveMainloop::TmaTransactionBytesB;
pipeline_params.transaction_bytes =
CollectiveMainloop::TmaTransactionBytesA +
CollectiveMainloop::TmaTransactionBytesB;
} else {
pipeline_params.transaction_bytes =
CollectiveMainloop::TmaTransactionBytesA +
CollectiveMainloop::TmaTransactionBytesB +
CollectiveMainloop::TmaTransactionBytesScale;
}
int warp_group_idx = cutlass::canonical_warp_group_idx(); int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0 pipeline_params.role = warp_group_idx == 0
? MainloopPipeline::ThreadCategory::Producer ? MainloopPipeline::ThreadCategory::Producer
@@ -89,8 +76,7 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
pipeline_params.is_leader = warp_group_thread_idx == 0; pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads; pipeline_params.num_consumers = NumMmaThreads;
MainloopPipeline pipeline( MainloopPipeline pipeline(shared_storage.pipeline, pipeline_params, ClusterShape{});
shared_storage.pipeline, pipeline_params, ClusterShape{});
CollectiveMainloop collective_mainloop; CollectiveMainloop collective_mainloop;
@@ -101,31 +87,23 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
__syncthreads(); __syncthreads();
} }
const int pre_fix_tokens = const int pre_fix_tokens = TokenPackSize == 0 ? (bidb == 0 ? 0 : mainloop_params.tokens[bidb - 1]) : 0;
TokenPackSize == 0 ? (bidb == 0 ? 0 : mainloop_params.tokens[bidb - 1])
: 0; const int tokens = TokenPackSize == 0 ? mainloop_params.tokens[bidb] - pre_fix_tokens : mainloop_params.tokens[bidb];
const int tokens = TokenPackSize == 0
? mainloop_params.tokens[bidb] - pre_fix_tokens
: mainloop_params.tokens[bidb];
if (bidn * kBlockN >= tokens) { if (bidn * kBlockN >= tokens) {
return; return;
} }
const bool is_need_input_scale = mainloop_params.input_scale != nullptr; float* input_row_sum = reinterpret_cast<float*>(
shared_memory + sizeof(typename Ktraits::SharedStorage));
float *input_scale =
is_need_input_scale
? reinterpret_cast<float *>(shared_memory +
sizeof(typename Ktraits::SharedStorage))
: nullptr;
if (warp_group_idx == 0) { if (warp_group_idx == 0) {
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 40 : 32>(); cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 40 : 32>();
PipelineState smem_pipe_write = PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
cutlass::make_producer_start_state<MainloopPipeline>(); collective_mainloop.load(
collective_mainloop.load(mainloop_params, mainloop_params,
pipeline, pipeline,
smem_pipe_write, smem_pipe_write,
shared_storage, shared_storage,
@@ -141,60 +119,67 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
typename Ktraits::TiledMma tiled_mma; typename Ktraits::TiledMma tiled_mma;
typename Ktraits::TiledMma_TAIL tiled_mma_tail;
const int mma_tidx = tidx - NumCopyThreads; const int mma_tidx = tidx - NumCopyThreads;
const int lane_id = mma_tidx % 4 * 2;
const float2 weight_scale = reinterpret_cast<const float2*>(mainloop_params.weight_scale + bidb * M + bidm * kBlockM)[mma_tidx / 4];
if (is_need_input_scale) {
if constexpr (TokenPackSize == 0) { if constexpr (TokenPackSize == 0) {
const int input_scale_idx = pre_fix_tokens + bidn * kBlockN; const int input_sum_idx = pre_fix_tokens + bidn * kBlockN;
if (mma_tidx < tokens) { if (mma_tidx < kBlockN) {
reinterpret_cast<float *>(input_scale)[mma_tidx] = reinterpret_cast<float*>(input_row_sum)[mma_tidx] = reinterpret_cast<const float*>(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx];
reinterpret_cast<const float *>(mainloop_params.input_scale +
input_scale_idx)[mma_tidx];
} }
} else { } else {
const int input_scale_idx = bidb * TokenPackSize + bidn * kBlockN; const int input_sum_idx = bidb * TokenPackSize + bidn * kBlockN;
if (mma_tidx < kBlockN / 4) { if (mma_tidx < kBlockN / 4) {
reinterpret_cast<float4 *>(input_scale)[mma_tidx] = reinterpret_cast<float4*>(input_row_sum)[mma_tidx] = reinterpret_cast<const float4*>(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx];
reinterpret_cast<const float4 *>(mainloop_params.input_scale +
input_scale_idx)[mma_tidx];
}
} }
} }
float2 weight_scale; const int reamin_tokens = tokens - bidn * kBlockN;
if constexpr (WeightScaleGroup == K) { if (TAIL_N > 0 && reamin_tokens < kBlockN) {
weight_scale = reinterpret_cast<const float2 *>( Tensor tSrS_tail = partition_fragment_C(tiled_mma_tail, select<0, 1>(TileShape_MNK_TAIL{}));
mainloop_params.weight_scale + bidb * M + collective_mainloop.mma<TAIL_N>(
bidm * kBlockM)[mma_tidx / 4]; mainloop_params,
} tiled_mma_tail,
Tensor tSrS =
partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{}));
if constexpr (WeightScaleGroup == K) {
collective_mainloop.mma(mainloop_params,
tiled_mma,
pipeline, pipeline,
smem_pipe_read, smem_pipe_read,
shared_storage, shared_storage,
tSrS, tSrS_tail,
mma_tidx);
collective_mainloop.store<TAIL_N>(
mainloop_params,
tSrS_tail,
shared_storage,
tiled_mma_tail,
input_row_sum + lane_id,
reinterpret_cast<const float*>(&weight_scale),
tokens,
pre_fix_tokens,
bidm,
bidn,
bidb,
mma_tidx); mma_tidx);
} else { } else {
collective_mainloop.mma_pipeline(mainloop_params, Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{}));
collective_mainloop.mma<kBlockN>(
mainloop_params,
tiled_mma, tiled_mma,
pipeline, pipeline,
smem_pipe_read, smem_pipe_read,
shared_storage, shared_storage,
tSrS, tSrS,
mma_tidx); mma_tidx);
} collective_mainloop.store<kBlockN>(
mainloop_params,
collective_mainloop.store(mainloop_params,
tSrS, tSrS,
shared_storage, shared_storage,
tiled_mma, tiled_mma,
input_row_sum + lane_id,
reinterpret_cast<const float*>(&weight_scale), reinterpret_cast<const float*>(&weight_scale),
input_scale,
tokens, tokens,
pre_fix_tokens, pre_fix_tokens,
bidm, bidm,
@@ -204,92 +189,64 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
} }
} }
template <int Experts> }
template <int Batch>
auto get_gmem_layout(const int Rows, const int Cols) { auto get_gmem_layout(const int Rows, const int Cols) {
return make_layout(make_shape(static_cast<int64_t>(Rows), return make_layout(
make_shape(
static_cast<int64_t>(Rows),
static_cast<int64_t>(Cols),
static_cast<int64_t>(Batch)),
make_stride(
static_cast<int64_t>(Cols), static_cast<int64_t>(Cols),
static_cast<int64_t>(Experts)),
make_stride(static_cast<int64_t>(Cols),
cute::_1{}, cute::_1{},
static_cast<int64_t>(Rows * Cols))); static_cast<int64_t>(Rows * Cols)));
} }
template <int Experts>
auto get_scale_layout(const int Rows, const int Cols) {
return make_layout(make_shape(static_cast<int64_t>(Cols),
static_cast<int64_t>(Rows),
static_cast<int64_t>(Experts)),
make_stride(cute::_1{},
static_cast<int64_t>(Cols),
static_cast<int64_t>(Rows * Cols)));
}
template <typename InputType, template <typename InputType, typename OutputType, typename Kernel_traits, int M, int K, int Batch, int TokenPackSize>
typename OutputType, void run_gemm(const InputType * A, const InputType * B, OutputType * C, const float *weight_scale,
typename Kernel_traits, const float *input_row_sum, const int64_t * tokens, const int64_t max_tokens, cudaStream_t stream) {
int M,
int K,
int Experts,
int TokenPackSize,
int WeightScaleGroup>
void run_gemm(const InputType *A,
const InputType *B,
OutputType *C,
const float *weight_scale,
const float *input_dequant_scale,
const int64_t *tokens,
const int max_tokens,
cudaStream_t stream) {
using ElementOutput = typename Kernel_traits::ElementOutput; using ElementOutput = typename Kernel_traits::ElementOutput;
using Element = typename Kernel_traits::Element; using Element = typename Kernel_traits::Element;
using CollectiveMainloop = CollectiveMainloopFwd<Kernel_traits>; using CollectiveMainloop = CollectiveMainloopFwd<Kernel_traits>;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK; using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
constexpr int M_nums = constexpr int M_nums = (M + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
(M + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; const int N_nums = (max_tokens + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
const int N_nums =
(max_tokens + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
constexpr int K_scale_nums = K / Kernel_traits::kBlockM;
static_assert(K % WeightScaleGroup == 0);
static_assert(WeightScaleGroup == 128 || WeightScaleGroup == K);
typename CollectiveMainloop::Params mainloop_params = typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments( CollectiveMainloop::to_underlying_arguments({
{static_cast<Element const *>(A), static_cast<Element const*>(A),
get_gmem_layout<Experts>(M, K / 2), get_gmem_layout<Batch>(M, K / 2),
static_cast<Element const*>(B), static_cast<Element const*>(B),
get_gmem_layout<Experts>( get_gmem_layout<Batch>(TokenPackSize == 0 ? max_tokens: TokenPackSize, K),
TokenPackSize == 0 ? max_tokens : TokenPackSize, K),
static_cast<ElementOutput*>(C), static_cast<ElementOutput*>(C),
get_gmem_layout<Experts>( get_gmem_layout<Batch>(M, TokenPackSize == 0 ? max_tokens : TokenPackSize),
M, TokenPackSize == 0 ? max_tokens : TokenPackSize),
weight_scale, weight_scale,
get_scale_layout<Experts>(M_nums, input_row_sum,
K_scale_nums * Kernel_traits::kBlockM), tokens
input_dequant_scale, });
tokens});
void *kernel; void *kernel;
kernel = (void *)w4afp8_gemm_kernel<Kernel_traits>; kernel = (void *)w4afp8_gemm_kernel<Kernel_traits>;
int smem_size = sizeof(typename Kernel_traits::SharedStorage) + int smem_size = sizeof(typename Kernel_traits::SharedStorage) + sizeof(float) * Kernel_traits::kBlockN;
Kernel_traits::kBlockN * sizeof(float);
if (smem_size >= 48 * 1024) { if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
} }
dim3 grid_dims; dim3 grid_dims;
grid_dims.x = M_nums; grid_dims.x = M_nums;
grid_dims.y = N_nums; grid_dims.y = N_nums;
grid_dims.z = Experts; grid_dims.z = Batch;
static constexpr int ctaSize = Kernel_traits::kNWarps * 32; static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
dim3 block_dims(ctaSize); dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}), dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
size<1>(ClusterShape{}), cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
size<2>(ClusterShape{})); cutlass::launch_kernel_on_cluster(
cutlass::ClusterLaunchParams launch_params{ launch_params, kernel, mainloop_params);
grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params);
} }

View File

@@ -1,131 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
void weight_convert(
const uint8_t* weight, uint8_t* weight_new, int experts, int M, int K) {
assert(K % 64 == 0);
for (int b = 0; b < experts; ++b) {
for (int m = 0; m < M; ++m) {
for (int k = 0; k < K; k += 64) {
for (int k_inner = 0; k_inner < 32; ++k_inner) {
uint8_t temp = 0;
uint8_t left = weight[b * M * K + m * K + k + k_inner];
uint8_t right = weight[b * M * K + m * K + k + k_inner + 32];
temp |= left << 4;
temp |= right;
weight_new[b * M * K / 2 + m * K / 2 + k / 2 + k_inner] =
*reinterpret_cast<uint8_t*>(&temp);
}
}
}
}
}
__global__ void weight_permute_interleave_kernelw4afp8(const int8_t* input_data,
int8_t* output_data,
const int original_k,
const int original_n) {
const int numel = original_k * original_n / 4;
const int pack_group_size = 64;
const int thread_group_size = pack_group_size / 4; // 16
const int thread_k_stride = original_k / 4;
const int linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (linear_idx >= numel) return;
const int n_id = linear_idx / thread_k_stride;
const int k_id = linear_idx % thread_k_stride;
const int k_group_idx = k_id / thread_group_size;
const int k_idx_in_group = k_id % thread_group_size;
const int8_t* src = input_data +
k_group_idx * pack_group_size / 2 * original_n +
k_idx_in_group * original_n + n_id;
int8_t tmp0 = src[0];
int8_t tmp1 = src[pack_group_size / 4 * original_n];
int8_t tmp00 = (tmp0 & 0xF0) + 112;
int8_t tmp01 = ((tmp0 << 4) & 0xF0) + 112;
int8_t tmp10 = (tmp1 & 0xF0) + 112;
int8_t tmp11 = ((tmp1 << 4) & 0xF0) + 112;
uint8_t utmp00 = *(reinterpret_cast<uint8_t*>(&tmp00));
uint8_t utmp01 = *(reinterpret_cast<uint8_t*>(&tmp01));
uint8_t utmp10 = *(reinterpret_cast<uint8_t*>(&tmp10));
uint8_t utmp11 = *(reinterpret_cast<uint8_t*>(&tmp11));
utmp00 = (utmp00 & 0xF0) >> 4;
utmp01 = (utmp01 & 0xF0) >> 4;
utmp10 = (utmp10 & 0xF0) >> 4;
utmp11 = (utmp11 & 0xF0) >> 4;
tmp00 = *(reinterpret_cast<int8_t*>(&utmp00)) - 7;
tmp01 = *(reinterpret_cast<int8_t*>(&utmp01)) - 7;
tmp10 = *(reinterpret_cast<int8_t*>(&utmp10)) - 7;
tmp11 = *(reinterpret_cast<int8_t*>(&utmp11)) - 7;
if (tmp00 <= 0) {
tmp00 = 8 - tmp00;
}
if (tmp01 <= 0) {
tmp01 = 8 - tmp01;
}
if (tmp10 <= 0) {
tmp10 = 8 - tmp10;
}
if (tmp11 <= 0) {
tmp11 = 8 - tmp11;
}
int8_t dst0 = (tmp01 << 4) | tmp11;
int8_t dst1 = (tmp00 << 4) | tmp10;
int8_t* dst = output_data + n_id * original_k / 2 +
(k_group_idx * pack_group_size / 2) + k_idx_in_group * 2;
dst[0] = dst0;
dst[1] = dst1;
}
std::vector<paddle::Tensor> W4AFp8GemmWeightConvert(
const paddle::Tensor& weight) {
if (weight.place() == paddle::CPUPlace()) {
const int experts = weight.dims()[0];
const int M = weight.dims()[1];
const int K = weight.dims()[2];
paddle::Tensor weight_new = paddle::empty(
{experts, M, K / 2}, paddle::DataType::UINT8, weight.place());
weight_convert(
weight.data<uint8_t>(), weight_new.data<uint8_t>(), experts, M, K);
return {weight_new};
} else {
const int original_k = weight.dims()[0] * 2;
const int original_n = weight.dims()[1];
paddle::Tensor weight_new =
paddle::empty(weight.shape(), paddle::DataType::INT8, weight.place());
const int block_dim = 256;
const int original_numel = original_k * original_n;
const int grid_size = (original_numel + block_dim - 1) / block_dim;
weight_permute_interleave_kernelw4afp8<<<grid_size, block_dim>>>(
weight.data<int8_t>(),
weight_new.data<int8_t>(),
original_k,
original_n);
return {weight_new};
}
}

View File

@@ -1,63 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
template <typename T, int kPackSize>
__global__ void permute_scale_kernel(T* input_data, const int numel) {
using LoadT = AlignedVector<T, kPackSize>;
LoadT input_vec;
LoadT dst_vec;
const int load_idx = (blockIdx.x * blockDim.x + threadIdx.x) * kPackSize;
if (load_idx >= numel) {
return;
}
Load<T, kPackSize>(&input_data[load_idx], &input_vec);
for (int i = 0; i < kPackSize; i += 2) {
dst_vec[i] = input_vec[i / 2];
dst_vec[i + 1] = input_vec[i / 2 + 8];
}
Store<T, kPackSize>(dst_vec, &input_data[load_idx]);
}
void W4AFp8GemmScalePermute(const paddle::Tensor& scale) {
const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1;
const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0];
if (col % 16 != 0) {
PD_THROW("Only supported when col is divisible by 16.");
}
const int numel = row * col;
const int threads = 128;
const int kPackSize = 16;
const int grid_size = (numel / kPackSize + threads - 1) / threads;
if (scale.dtype() == paddle::DataType::BFLOAT16) {
permute_scale_kernel<phi::dtype::bfloat16, kPackSize>
<<<grid_size, threads, 0, scale.stream()>>>(
const_cast<phi::dtype::bfloat16*>(
scale.data<phi::dtype::bfloat16>()),
numel);
} else if (scale.dtype() == paddle::DataType::FLOAT16) {
permute_scale_kernel<phi::dtype::float16, kPackSize>
<<<grid_size, threads, 0, scale.stream()>>>(
const_cast<phi::dtype::float16*>(scale.data<phi::dtype::float16>()),
numel);
} else if (scale.dtype() == paddle::DataType::FLOAT32) {
permute_scale_kernel<float, kPackSize>
<<<grid_size, threads, 0, scale.stream()>>>(
const_cast<float*>(scale.data<float>()), numel);
}
}

View File

@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import re
file_dir = "./gpu_ops/w4afp8_gemm/" file_dir = "./gpu_ops/w4afp8_gemm/"
@@ -32,12 +30,12 @@ gemm_template_head = """
#include <cutlass/numeric_types.h> #include <cutlass/numeric_types.h>
""" """
gemm_template_case = """ gemm_template_case = """
void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}( void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
const cutlass::float_e4m3_t * weight, const cutlass::float_e4m3_t * weight,
const cutlass::float_e4m3_t * input, const cutlass::float_e4m3_t * input,
{cutlass_type} * out, {cutlass_type} * out,
const float *weight_scale, const float *weight_scale,
const float * input_dequant_scale, const float *input_row_sum,
const int64_t *tokens, const int64_t *tokens,
const int64_t max_tokens, const int64_t max_tokens,
cudaStream_t stream); cudaStream_t stream);
@@ -50,22 +48,22 @@ gemm_template_cu_head = """
""" """
gemm_template_cu_template = """ gemm_template_cu_template = """
void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}( void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
const cutlass::float_e4m3_t * weight, const cutlass::float_e4m3_t * weight,
const cutlass::float_e4m3_t * input, const cutlass::float_e4m3_t * input,
{cutlass_type} * out, {cutlass_type} * out,
const float *weight_scale, const float *weight_scale,
const float * input_dequant_scale, const float *input_row_sum,
const int64_t *tokens, const int64_t *tokens,
const int64_t max_tokens, const int64_t max_tokens,
cudaStream_t stream) {{ cudaStream_t stream) {{
constexpr static int M = {M}; constexpr static int M = {M};
constexpr static int K = {K}; constexpr static int K = {K};
constexpr static int EXPERTS = {EXPERTS}; constexpr static int Batch = {BATCH};
constexpr static int TokenPackSize = {PADDING}; constexpr static int TokenPackSize = {PADDING};
constexpr static int kBlockN = {N}; constexpr static int kBlockN = {N};
constexpr static int kGroupSize = {GROUPSIZE}; constexpr static int kBlockN_TAIL = {TAILN};
constexpr static int kBlockM = 128; constexpr static int kBlockM = 128;
constexpr static int kBlockK = 128; constexpr static int kBlockK = 128;
constexpr static int kNWarps = 4 + kBlockM / 16; constexpr static int kNWarps = 4 + kBlockM / 16;
@@ -76,24 +74,22 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
using Kernel_traits = Kernel_traits< using Kernel_traits = Kernel_traits<
kBlockM, kBlockN, kBlockK, kNWarps, kStages, kTiles, kBlockM, kBlockN, kBlockK, kNWarps, kStages, kTiles,
M, K, TokenPackSize, kGroupSize, kCluster, cutlass::float_e4m3_t, M, TokenPackSize, kBlockN_TAIL, kCluster, cutlass::float_e4m3_t,
{cutlass_type}>; {cutlass_type}>;
run_gemm<cutlass::float_e4m3_t, {cutlass_type}, run_gemm<cutlass::float_e4m3_t, {cutlass_type},
Kernel_traits, M, K, EXPERTS, TokenPackSize, kGroupSize> Kernel_traits, M, K, Batch, TokenPackSize>
(weight, input, out, weight_scale, input_dequant_scale, tokens, max_tokens, stream); (weight, input, out, weight_scale,
input_row_sum, tokens, max_tokens, stream);
}} }}
""" """
# [M, K, Number of experts, token Padding Size, weight K group size]
gemm_case = [ gemm_case = [
[8192, 3584, 16, 0, 128], # eb45T ffn1 [8192, 3584, 8, 0], # eb45T ffn1
[8192, 3584, 16, 512, 128], # eb45T ffn1 [8192, 3584, 8, 2048], # eb45T ffn1
[7168, 8192, 16, 0, 128], # eb45T ffn2 [7168, 8192, 8, 0], # eb45T ffn2
[7168, 8192, 16, 512, 128], # eb45T ffn2 [7168, 8192, 8, 2048], # eb45T ffn2
[1792, 8192, 64, 0, 8192], # eb45t ffn1 [1792, 8192, 64, 0], # eb45t ffn1
[8192, 896, 64, 0, 896], # eb45t ffn2 [8192, 896, 64, 0], # eb45t ffn2
[1792, 8192, 64, 0, 128], # eb45t ffn1
[8192, 896, 64, 0, 128], # eb45t ffn2
] ]
dtype = ["BF16"] dtype = ["BF16"]
@@ -101,19 +97,6 @@ dtype = ["BF16"]
use_fast_compile = True use_fast_compile = True
n_range = [256] if use_fast_compile else [i for i in range(16, 257, 16)] n_range = [256] if use_fast_compile else [i for i in range(16, 257, 16)]
all_cu_files = []
for type in dtype:
for case in gemm_case:
for n in n_range:
all_cu_files.append(f"w4afp8_gemm_M{case[0]}_N{n}_G{case[4]}_K{case[1]}_E{case[2]}_P{case[3]}_{type}.cu")
for file_path, empty_list, file_name_list in os.walk(file_dir):
for file_name in file_name_list:
if re.match(r"^w4afp8_gemm_M\d+_N\d+_.*\.cu$", file_name):
if file_name not in all_cu_files:
print("delete w4afp8 kernel file", file_path + file_name)
os.remove(file_path + file_name)
def get_cutlass_type(type): def get_cutlass_type(type):
if type == "BF16": if type == "BF16":
@@ -133,16 +116,28 @@ for type in dtype:
M=case[0], M=case[0],
K=case[1], K=case[1],
N=n, N=n,
EXPERTS=case[2], BATCH=case[2],
TYPE=type, TYPE=type,
PADDING=case[3], PADDING=case[3],
GROUPSIZE=case[4], TAILN=0,
cutlass_type=get_cutlass_type(type),
)
)
template_head_file.write(
gemm_template_case.format(
M=case[0],
K=case[1],
N=256,
BATCH=case[2],
TYPE=type,
PADDING=case[3],
TAILN=n - 16,
cutlass_type=get_cutlass_type(type), cutlass_type=get_cutlass_type(type),
) )
) )
template_cu_file = open( template_cu_file = open(
f"{file_dir}w4afp8_gemm_M{case[0]}_N{n}_G{case[4]}_K{case[1]}_E{case[2]}_P{case[3]}_{type}.cu", "w" f"{file_dir}w4afp8_gemm_M{case[0]}_N{n}_TAILN{0}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w"
) )
template_cu_file.write(gemm_template_cu_head) template_cu_file.write(gemm_template_cu_head)
template_cu_file.write( template_cu_file.write(
@@ -150,10 +145,29 @@ for type in dtype:
M=case[0], M=case[0],
K=case[1], K=case[1],
N=n, N=n,
EXPERTS=case[2], BATCH=case[2],
TYPE=type, TYPE=type,
PADDING=case[3], PADDING=case[3],
GROUPSIZE=case[4], TAILN=0,
cutlass_type=get_cutlass_type(type),
)
)
template_cu_file.close()
template_cu_file = open(
f"{file_dir}w4afp8_gemm_M{case[0]}_N{256}_TAILN{n-16}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w"
)
template_cu_file.write(gemm_template_cu_head)
template_cu_file.write(
gemm_template_cu_template.format(
M=case[0],
K=case[1],
N=256,
BATCH=case[2],
TYPE=type,
PADDING=case[3],
TAILN=n - 16,
cutlass_type=get_cutlass_type(type), cutlass_type=get_cutlass_type(type),
) )
) )
@@ -163,8 +177,8 @@ for type in dtype:
for type in dtype: for type in dtype:
template_head_file.write("\n") template_head_file.write("\n")
template_head_file.write( template_head_file.write(
"""#define GEMM_SWITCH_{TYPE}(_M, _K, _EXPERTS, _TokenPaddingSize, _kBlockN, _GROUPSIZE, ...) {{ \\ """#define GEMM_SWITCH_{TYPE}(_M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN, ...) {{ \\
if (_M == 0 && _K == 0 && _EXPERTS == 0 && _TokenPaddingSize == 0 && _kBlockN == 0 && _GROUPSIZE == 0) {{ \\""".format( if (_M == 0 && _K == 0 && _BATCH == 0 && _TokenPaddingSize == 0 && _kBlockN == 0 && _TailN == 0) {{ \\""".format(
TYPE=type TYPE=type
) )
) )
@@ -174,16 +188,23 @@ for type in dtype:
for case in gemm_case: for case in gemm_case:
for n in n_range: for n in n_range:
template_head_file.write( template_head_file.write(
""" }} else if (_M == {M} && _K == {K} && _EXPERTS == {EXPERTS} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _GROUPSIZE == {GROUPSIZE}) {{ \\ """ }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format( w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
M=case[0], K=case[1], N=n, EXPERTS=case[2], TYPE=type, PADDING=case[3], GROUPSIZE=case[4] M=case[0], K=case[1], N=n, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=0
)
)
template_head_file.write("\n")
template_head_file.write(
""" }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
M=case[0], K=case[1], N=256, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=n - 16
) )
) )
template_head_file.write("\n") template_head_file.write("\n")
template_head_file.write( template_head_file.write(
""" } else { \\ """ } else { \\
PADDLE_THROW(phi::errors::Unimplemented("W4aFp8 not supported m=%d k=%d experts=%d token_padding_size=%d kBlockN=%d groupsize=%d\\n", _M, _K, _EXPERTS, _TokenPaddingSize, _kBlockN, _GROUPSIZE)); \\ PADDLE_THROW(phi::errors::Unimplemented("W4aFp8 not supported m=%d k=%d batch=%d token_padding_size=%d kBlockN=%d tailN=%d\\n", _M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN)); \\
} \\ } \\
}""" }"""
) )

View File

@@ -30,10 +30,7 @@ if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
try: try:
from fastdeploy.model_executor.ops.gpu import ( from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
w4afp8_gemm_scale_permute,
w4afp8_gemm_weight_convert,
)
except: except:
logger.warning("import w4afp8_gemm_scale_permute Failed!") logger.warning("import w4afp8_gemm_scale_permute Failed!")
elif current_platform.is_iluvatar(): elif current_platform.is_iluvatar():
@@ -78,7 +75,6 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
expert_idx_per_token: paddle.Tensor, expert_idx_per_token: paddle.Tensor,
used_in_ep_low_latency: bool = False, used_in_ep_low_latency: bool = False,
estimate_total_token_nums: int = -1, estimate_total_token_nums: int = -1,
dequant_scale: paddle.Tensor = None,
): ):
""" """
Paddle Cutlass compute Fused MoE. Paddle Cutlass compute Fused MoE.
@@ -104,7 +100,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
token_nums_per_expert, token_nums_per_expert,
getattr(layer, self.added_weight_attrs[0]), getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]), getattr(layer, self.added_weight_attrs[1]),
dequant_scale, # None,
(layer.up_gate_proj_bias if hasattr(layer, "up_gate_proj_bias") else None), (layer.up_gate_proj_bias if hasattr(layer, "up_gate_proj_bias") else None),
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
@@ -116,7 +112,6 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
getattr(layer.moe_quant_config, "hadamard_block_size", 128), getattr(layer.moe_quant_config, "hadamard_block_size", 128),
layer.activation, layer.activation,
) )
if layer.with_bias: if layer.with_bias:
down_proj_bias_expand = paddle.index_select(layer.down_proj_bias, expert_idx_per_token, axis=0) down_proj_bias_expand = paddle.index_select(layer.down_proj_bias, expert_idx_per_token, axis=0)
ffn_out_without_down_proj_bias = paddle.add(ffn_out_without_down_proj_bias, down_proj_bias_expand) ffn_out_without_down_proj_bias = paddle.add(ffn_out_without_down_proj_bias, down_proj_bias_expand)
@@ -265,7 +260,6 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
topk_weights, topk_weights,
topk_idx, topk_idx,
expert_idx_per_token, expert_idx_per_token,
dequant_scale,
) = moe_expert_dispatch( ) = moe_expert_dispatch(
x, x,
gate_out, gate_out,
@@ -286,21 +280,19 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
topk_weights, topk_weights,
topk_idx, topk_idx,
expert_idx_per_token, expert_idx_per_token,
dequant_scale,
) = moe_expert_dispatch( ) = moe_expert_dispatch(
x, x,
gate_out, gate_out,
layer.gate_correction_bias, layer.gate_correction_bias,
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None), (
layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None
), # if set, permute_input will be int8_t
layer.top_k, layer.top_k,
False, False,
self.moe_quant_type, self.moe_quant_type,
topk_only_mode=False, topk_only_mode=False,
) )
if hasattr(layer, "up_gate_proj_in_scale"):
dequant_scale = None
if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8": if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
# only w4a8 need expert_idx_per_token # only w4a8 need expert_idx_per_token
# Other need not this tensor, so we make it None. # Other need not this tensor, so we make it None.
@@ -308,9 +300,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
else: else:
expert_idx_per_token = expert_idx_per_token.cast("int64") expert_idx_per_token = expert_idx_per_token.cast("int64")
ffn_out = self.compute_ffn( ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, expert_idx_per_token)
layer, permute_input, token_nums_per_expert, expert_idx_per_token, False, -1, dequant_scale
)
# reduce 中会做 topk 个 weight 的 norm 和 routed_scaling_factor # reduce 中会做 topk 个 weight 的 norm 和 routed_scaling_factor
fused_moe_out = moe_expert_reduce( fused_moe_out = moe_expert_reduce(
@@ -855,7 +845,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
weight_name = self.added_weight_attrs[idx] weight_name = self.added_weight_attrs[idx]
weight_list = [] weight_list = []
for i in range(layer.num_local_experts): for i in range(layer.num_local_experts):
quant_weight = w4afp8_gemm_weight_convert(weight_tensor[i]) quant_weight, scale = weight_quantize(weight_tensor[i], algo=self.moe_quant_type, arch=80)
weight_list.append(quant_weight) weight_list.append(quant_weight)
quanted_weight = paddle.stack(weight_list, axis=0) quanted_weight = paddle.stack(weight_list, axis=0)
getattr(layer, weight_name).set_value(quanted_weight) getattr(layer, weight_name).set_value(quanted_weight)
@@ -885,7 +875,6 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
) )
# in_scales # in_scales
if not layer.moe_quant_config.moe_dynamic_quant:
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]: for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
setattr( setattr(
layer, layer,
@@ -896,18 +885,6 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
default_initializer=paddle.nn.initializer.Constant(0), default_initializer=paddle.nn.initializer.Constant(0),
), ),
) )
else:
if layer.ep_size > 1:
for in_scale_name in ["up_gate_proj_in_scale"]:
setattr(
layer,
in_scale_name,
layer.create_parameter(
shape=[layer.num_local_experts],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scales # weight_scales
setattr( setattr(
@@ -965,56 +942,9 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
return weight_scale return weight_scale
def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], processed_in_scale: paddle.Tensor): def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], processed_in_scale: paddle.Tensor):
if processed_in_scale is not None:
processed_weight_scale = paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9))
if len(processed_weight_scale.shape) == 3:
processed_weight_scale = ( processed_weight_scale = (
processed_weight_scale.transpose([0, 2, 1]) / processed_in_scale[:, None, None] paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) / processed_in_scale[:, None]
) )
else:
processed_weight_scale = processed_weight_scale / processed_in_scale[:, None]
else:
processed_weight_scale = paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9))
if len(processed_weight_scale.shape) == 3:
if name == "up_gate_proj_weight_scale" and processed_weight_scale.shape[-1] * 128 != layer.hidden_size:
assert (
layer.hidden_size // 128 % processed_weight_scale.shape[-1] == 0
), "weight_scale_group_size must be a multiple of 128"
# If it is a multiple of 128, repeat to 128
processed_weight_scale = processed_weight_scale.repeat_interleave(
layer.hidden_size // 128 // processed_weight_scale.shape[-1], axis=-1
)
elif (
name == "down_proj_weight_scale"
and processed_weight_scale.shape[-1] * 128 != layer.moe_intermediate_size
):
assert (
layer.moe_intermediate_size // 128 % processed_weight_scale.shape[-1] == 0
), "weight_scale_group_size must be a multiple of 128"
# If it is a multiple of 128, repeat to 128
processed_weight_scale = processed_weight_scale.repeat_interleave(
layer.moe_intermediate_size // 128 // processed_weight_scale.shape[-1], axis=-1
)
origin_shape = processed_weight_scale.shape
processed_weight_scale = processed_weight_scale.transpose([0, 2, 1])
processed_weight_scale = processed_weight_scale.reshape([-1, processed_weight_scale.shape[-1]])
processed_weight_scale = _permute_weight_scale(processed_weight_scale)
processed_weight_scale = processed_weight_scale.reshape(
[origin_shape[0], origin_shape[2], origin_shape[1] // 128, 128]
)
processed_weight_scale = processed_weight_scale.transpose([0, 2, 1, 3])
setattr(
layer,
name,
layer.create_parameter(
shape=processed_weight_scale.shape,
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
else:
processed_weight_scale = _permute_weight_scale(processed_weight_scale) processed_weight_scale = _permute_weight_scale(processed_weight_scale)
getattr(layer, name).set_value(processed_weight_scale) getattr(layer, name).set_value(processed_weight_scale)
@@ -1062,15 +992,16 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx) scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
scale_weight_map[name].append(scale_tensor) scale_weight_map[name].append(scale_tensor)
# 3. Process scale tensor and set to layer
in_scales = []
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
in_scales.append(_process_in_scale(in_scale_name, scale_weight_map[in_scale_name]))
for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]): for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]):
in_scale_name = weight_scale_name.replace("_weight_scale", "_in_scale")
in_scale = None
if hasattr(layer, in_scale_name) and in_scale_name in scale_weight_map.keys():
in_scale = _process_in_scale(in_scale_name, scale_weight_map[in_scale_name])
_process_weight_scale( _process_weight_scale(
weight_scale_name, weight_scale_name,
scale_weight_map[weight_scale_name], scale_weight_map[weight_scale_name],
in_scale, in_scales[i],
) )

View File

@@ -275,7 +275,6 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
topk_weights, topk_weights,
topk_idx, topk_idx,
expert_idx_per_token, expert_idx_per_token,
dequant_scale,
) = moe_expert_dispatch( ) = moe_expert_dispatch(
x, x,
gate_out, gate_out,

View File

@@ -39,7 +39,6 @@ class MixQuantConfig(QuantConfigBase):
is_permuted: bool = True, is_permuted: bool = True,
is_quantized: bool = False, is_quantized: bool = False,
hadamard_block_size: int = 128, hadamard_block_size: int = 128,
moe_dynamic_quant: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.dense_quant_type = dense_quant_type self.dense_quant_type = dense_quant_type
@@ -58,7 +57,6 @@ class MixQuantConfig(QuantConfigBase):
self.is_checkpoint_bf16 = not is_quantized self.is_checkpoint_bf16 = not is_quantized
self.is_quantized = is_quantized self.is_quantized = is_quantized
self.hadamard_block_size = hadamard_block_size self.hadamard_block_size = hadamard_block_size
self.moe_dynamic_quant = moe_dynamic_quant
def name(self) -> str: def name(self) -> str:
return "mix_quant" return "mix_quant"
@@ -75,7 +73,6 @@ class MixQuantConfig(QuantConfigBase):
config.get("is_permuted", True), config.get("is_permuted", True),
config.get("is_quantized", False), config.get("is_quantized", False),
config.get("hadamard_block_size", 128), config.get("hadamard_block_size", 128),
config.get("moe_dynamic_quant", False),
) )
def get_quant_method(self, layer) -> Optional[QuantMethodBase]: def get_quant_method(self, layer) -> Optional[QuantMethodBase]:

View File

@@ -17,20 +17,16 @@ import unittest
import numpy as np import numpy as np
import paddle import paddle
from fastdeploy.model_executor.ops.gpu import ( from fastdeploy.model_executor.ops.gpu import w4afp8_gemm, w4afp8_gemm_weight_convert
w4afp8_gemm,
w4afp8_gemm_scale_permute,
w4afp8_gemm_weight_convert,
)
class TestW4AFP8GEMM(unittest.TestCase): class TestW4AFP8GEMM(unittest.TestCase):
def setUp(self): def setUp(self):
paddle.seed(0) paddle.seed(0)
self.tokens_per_group = 1 self.tokens_per_group = 256
self.N = 1792 self.N = 256
self.K = 8192 self.K = 256
self.BATCH = 64 self.BATCH = 1
self.TokenPadding = 0 self.TokenPadding = 0
tokens = [self.tokens_per_group] * self.BATCH tokens = [self.tokens_per_group] * self.BATCH
@@ -42,15 +38,14 @@ class TestW4AFP8GEMM(unittest.TestCase):
self.input_fp8 = paddle.randn([self.all_tokens, self.K], dtype="bfloat16").astype(paddle.float8_e4m3fn) self.input_fp8 = paddle.randn([self.all_tokens, self.K], dtype="bfloat16").astype(paddle.float8_e4m3fn)
self.input_bf16 = self.input_fp8.astype("bfloat16") self.input_bf16 = self.input_fp8.astype("bfloat16")
self.weight = paddle.randn([self.BATCH, self.N, self.K], dtype="bfloat16") self.weight = paddle.randn([self.BATCH, self.N, self.K], dtype="bfloat16") / 10
self.weight_scale = 7 / self.weight.abs().max(axis=-1).reshape([self.BATCH, self.N, 1]) self.weight_scale = 7 / self.weight.abs().max(axis=-1).reshape([self.BATCH, self.N, 1])
self.weight_quant = (self.weight * self.weight_scale).astype("int") self.weight_quant = (self.weight * self.weight_scale).astype("int") + 7
self.weight_quant = paddle.clip(self.weight_quant, -7, 7) self.weight_quant = paddle.clip(self.weight_quant, 0, 14)
self.weight_quant_naive = self.weight_quant.astype("float32")
self.weight_quant = self.weight_quant.astype("bfloat16") self.weight_quant = self.weight_quant.astype("bfloat16")
self.weight_quant = paddle.where(self.weight_quant > 0, self.weight_quant, 8 - self.weight_quant)
self.weight_dequant_scale = 1 / self.weight_scale.astype("float32") self.weight_dequant_scale = 1 / self.weight_scale.astype("float32")
self.input_row_sum = self.input_bf16.sum(axis=1) * -7 / 512
self.max_tokens = int(self.tokens.max()) self.max_tokens = int(self.tokens.max())
def w4afp8_gemm_naive(self, input_bf16, weight_quant, tokens, weight_dequant_scale): def w4afp8_gemm_naive(self, input_bf16, weight_quant, tokens, weight_dequant_scale):
@@ -59,7 +54,7 @@ class TestW4AFP8GEMM(unittest.TestCase):
pre_fix_token = 0 pre_fix_token = 0
for i in range(self.BATCH): for i in range(self.BATCH):
input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :] input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :]
weight = weight_quant[i] * weight_dequant_scale[i] weight = (weight_quant[i] - 7.0) * weight_dequant_scale[i]
out_i = paddle.matmul(input, weight.astype("bfloat16"), transpose_y=True) out_i = paddle.matmul(input, weight.astype("bfloat16"), transpose_y=True)
out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i
pre_fix_token += tokens[i] pre_fix_token += tokens[i]
@@ -76,53 +71,37 @@ class TestW4AFP8GEMM(unittest.TestCase):
weight_scale[b, n + j + 1] = temp[j // 2 + 8] weight_scale[b, n + j + 1] = temp[j // 2 + 8]
return weight_scale return weight_scale
def get_per_group_scale(self, processed_weight_scale):
processed_weight_scale = processed_weight_scale.repeat_interleave(self.K // 128, axis=-1)
origin_shape = processed_weight_scale.shape
processed_weight_scale = processed_weight_scale.transpose([0, 2, 1])
processed_weight_scale = processed_weight_scale.reshape([-1, processed_weight_scale.shape[-1]])
processed_weight_scale = w4afp8_gemm_scale_permute(processed_weight_scale)
processed_weight_scale = processed_weight_scale.reshape(
[origin_shape[0], origin_shape[2], origin_shape[1] // 128, 128]
)
processed_weight_scale = processed_weight_scale.transpose([0, 2, 1, 3])
return processed_weight_scale
def test_w4afp8_gemm(self): def test_w4afp8_gemm(self):
out_naive = self.w4afp8_gemm_naive( out_naive = self.w4afp8_gemm_naive(self.input_bf16, self.weight_quant, self.tokens, self.weight_dequant_scale)
self.input_bf16, self.weight_quant_naive, self.tokens, self.weight_dequant_scale
)
# weight_dequant_scale = paddle.to_tensor(self.permute_scale(self.weight_dequant_scale) * 512) weight_dequant_scale = paddle.to_tensor(self.permute_scale(self.weight_dequant_scale) * 512)
weight_dequant_scale = self.get_per_group_scale(self.weight_dequant_scale * 512) weight_int4 = w4afp8_gemm_weight_convert(self.weight_quant.astype("uint8").cpu())
weight_int4 = w4afp8_gemm_weight_convert(self.weight_quant.astype("uint8").cpu()).cuda()
if self.TokenPadding == 0: if self.TokenPadding == 0:
out_cuda = w4afp8_gemm( out_cuda = w4afp8_gemm(
self.input_fp8, self.input_fp8,
weight_int4, weight_int4.cuda(),
self.tokens_prefix_sum, self.tokens_prefix_sum,
self.input_row_sum.astype("float32"),
weight_dequant_scale.astype("float32"), weight_dequant_scale.astype("float32"),
None,
int(self.TokenPadding), int(self.TokenPadding),
self.all_tokens, self.max_tokens,
True, True,
) )
else: else:
out_cuda = w4afp8_gemm( out_cuda = w4afp8_gemm(
self.input_fp8, self.input_fp8,
weight_int4, weight_int4.cuda(),
self.tokens, self.tokens,
self.input_row_sum.astype("float32"),
weight_dequant_scale.astype("float32"), weight_dequant_scale.astype("float32"),
None,
int(self.TokenPadding), int(self.TokenPadding),
self.max_tokens, self.max_tokens,
True, True,
) )
gap = (out_cuda - out_naive).abs() gap = (out_cuda - out_naive).abs()
self.assertLess(float(gap.mean()), 0.11) self.assertLess(float(gap.mean()), 0.07)
if __name__ == "__main__": if __name__ == "__main__":