Revert "【New Feature】W4afp8 supports per group quantization (#4272)" (#4854)

This reverts commit 93fcf7e4ec.
This commit is contained in:
YuBaoku
2025-11-06 17:48:28 +08:00
committed by GitHub
parent 3478d20262
commit 819b2dbbae
26 changed files with 1718 additions and 4378 deletions

View File

@@ -304,7 +304,6 @@ paddle::Tensor MoeExpertFFNFunc(
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_proj_in_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,

View File

@@ -1,37 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "helper.h"
template <typename T, typename OutT>
void MoeFastHardamardWrapper(const T *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const T *shift,
const T *smooth,
const float *quant_scales,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
OutT *out,
cudaStream_t &stream);

File diff suppressed because it is too large Load Diff

View File

@@ -1,34 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fast_hardamard_kernel.hpp"
template void
MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16>(
const phi::dtype::bfloat16 *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const phi::dtype::bfloat16 *shift,
const phi::dtype::bfloat16 *smooth,
const float *quant_scales,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
phi::dtype::bfloat16 *out,
cudaStream_t &stream);

View File

@@ -1,34 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fast_hardamard_kernel.hpp"
template void
MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::float8_e4m3fn>(
const phi::dtype::bfloat16 *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const phi::dtype::bfloat16 *shift,
const phi::dtype::bfloat16 *smooth,
const float *quant_scales,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
phi::dtype::float8_e4m3fn *out,
cudaStream_t &stream);

View File

@@ -1,33 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fast_hardamard_kernel.hpp"
template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
const phi::dtype::bfloat16 *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const phi::dtype::bfloat16 *shift,
const phi::dtype::bfloat16 *smooth,
const float *quant_scales,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
int8_t *out,
cudaStream_t &stream);

View File

@@ -1,33 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fast_hardamard_kernel.hpp"
template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
const phi::dtype::float16 *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const phi::dtype::float16 *shift,
const phi::dtype::float16 *smooth,
const float *quant_scales,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
phi::dtype::float16 *out,
cudaStream_t &stream);

View File

@@ -1,33 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fast_hardamard_kernel.hpp"
template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
const phi::dtype::float16 *x_data,
const int64_t *expert_idx_per_token,
const int64_t *recv_expert_count,
const phi::dtype::float16 *shift,
const phi::dtype::float16 *smooth,
const float *quant_scales,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
int8_t *out,
cudaStream_t &stream);

View File

@@ -25,8 +25,7 @@ template <typename T, int VecSize>
__global__ void moe_token_type_ids_kernel(T *gating_output,
const int *moe_token_type_ids_out,
const int num_rows,
const int num_experts,
const int k) {
const int num_experts, const int k) {
const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x;
if (moe_token_index >= num_rows) {
@@ -45,8 +44,7 @@ template <typename T>
void moe_token_type_ids_kernelLauncher(T *gating_output,
const int *moe_token_type_ids_out,
const int num_rows,
const int num_experts,
const int k,
const int num_experts, const int k,
cudaStream_t stream) {
const int blocks = num_rows * k / 512 + 1;
const int threads = 512;
@@ -54,35 +52,26 @@ void moe_token_type_ids_kernelLauncher(T *gating_output,
gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
}
template <typename T, typename NvType>
class MoeHelper {
public:
using Fp16Traits =
cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kNone>;
using Int8Traits =
cutlass::WintQuantTraits<NvType,
cutlass::WintQuantMethod::kWeightOnlyInt8>;
using Int4Traits =
cutlass::WintQuantTraits<NvType,
cutlass::WintQuantMethod::kWeightOnlyInt4>;
template <typename T, typename NvType> class MoeHelper {
public:
using Fp16Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kNone>;
using Int8Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kWeightOnlyInt8>;
using Int4Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kWeightOnlyInt4>;
MoeHelper(const std::string gemm_method,
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner,
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner,
MoeGemmRunner<NvType, Int4Traits> *int4_moe_gemm_runner,
int layernum = 0)
: gemm_method_(gemm_method),
fp16_moe_gemm_runner_(fp16_moe_gemm_runner),
MoeHelper(
const std::string gemm_method,
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner,
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner,
MoeGemmRunner<NvType, Int4Traits> *int4_moe_gemm_runner,
int layernum = 0)
: gemm_method_(gemm_method), fp16_moe_gemm_runner_(fp16_moe_gemm_runner),
int8_moe_gemm_runner_(int8_moe_gemm_runner),
int4_moe_gemm_runner_(int4_moe_gemm_runner),
layernum_(layernum) {}
int4_moe_gemm_runner_(int4_moe_gemm_runner), layernum_(layernum) {}
// -------- getWorkspaceSize -------- //
template <typename KeyT>
size_t getWorkspaceSize(const int64_t num_rows,
const int64_t hidden_size,
const int64_t inter_size,
const int64_t num_experts,
size_t getWorkspaceSize(const int64_t num_rows, const int64_t hidden_size,
const int64_t inter_size, const int64_t num_experts,
const int64_t k) {
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
@@ -93,10 +82,10 @@ class MoeHelper {
// FfnLayer forward.
size_t total_ws_bytes =
5 * num_moe_inputs *
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
total_ws_bytes +=
padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_
padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
const size_t sorter_ws_size_bytes =
@@ -111,8 +100,8 @@ class MoeHelper {
}
total_ws_bytes +=
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
// sorting workspace
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
// sorting workspace
int64_t num_softmax_outs = 0;
const bool is_pow_2 =
@@ -126,27 +115,20 @@ class MoeHelper {
return total_ws_bytes;
}
void ComputeFFN(const paddle::Tensor *input,
const paddle::Tensor *gate_weight,
const paddle::Tensor *up_gate_proj_weight,
const paddle::Tensor *up_gate_proj_scale,
const paddle::Tensor *up_gate_proj_bias,
const paddle::Tensor *down_proj_weight,
const paddle::Tensor *down_proj_scale,
const paddle::Tensor *down_proj_bias,
const paddle::Tensor *moe_token_type_ids,
const int moe_topk,
const bool group_moe,
const bool norm_topk_prob,
const float routed_scaling_factor,
const std::string moe_type,
paddle::Tensor *output) {
void
ComputeFFN(const paddle::Tensor *input, const paddle::Tensor *gate_weight,
const paddle::Tensor *up_gate_proj_weight,
const paddle::Tensor *up_gate_proj_scale, const paddle::Tensor *up_gate_proj_bias,
const paddle::Tensor *down_proj_weight,
const paddle::Tensor *down_proj_scale, const paddle::Tensor *down_proj_bias,
const paddle::Tensor *moe_token_type_ids, const int moe_topk,
const bool group_moe, const bool norm_topk_prob,
const float routed_scaling_factor, const std::string moe_type,
paddle::Tensor *output) {
auto *input_activations = input->data<T>();
auto *gating_weights = gate_weight->data<float>();
const T *fc1_expert_biases =
up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr;
const T *fc2_expert_biases =
down_proj_bias ? down_proj_bias->data<T>() : nullptr;
const T *fc1_expert_biases = up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr;
const T *fc2_expert_biases = down_proj_bias ? down_proj_bias->data<T>() : nullptr;
auto *output_ = output->data<T>();
auto stream = input->stream();
@@ -166,8 +148,7 @@ class MoeHelper {
const int64_t hidden_size = up_gate_proj_dims[1];
int64_t inter_dim = 0;
if (moe_type == "qkv") {
inter_dim =
up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
inter_dim = up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
} else {
inter_dim = up_gate_proj_dims[2];
}
@@ -251,79 +232,44 @@ class MoeHelper {
if (moe_token_type_ids) {
auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
moe_token_type_ids_kernelLauncher<float>(gating_output,
moe_token_type_ids_out,
num_rows,
num_experts,
k,
stream);
moe_token_type_ids_out, num_rows,
num_experts, k, stream);
}
topk_gating_softmax_kernelLauncher<float, int>(gating_output,
nullptr,
expert_scales_float,
softmax_out_,
expert_for_source_row,
source_rows_,
softmax_max_prob,
num_rows,
num_experts,
k,
group_moe,
stream);
topk_gating_softmax_kernelLauncher<float, int>(
gating_output, nullptr, expert_scales_float, softmax_out_,
expert_for_source_row, source_rows_, softmax_max_prob, num_rows,
num_experts, k, group_moe, stream);
const int64_t sorter_ws_size_bytes =
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
sorter_.run(fc1_result_,
sorter_ws_size_bytes,
expert_for_source_row,
permuted_experts_,
source_rows_,
permuted_rows_,
k * num_rows,
false,
stream);
sorter_.run(fc1_result_, sorter_ws_size_bytes, expert_for_source_row,
permuted_experts_, source_rows_, permuted_rows_, k * num_rows,
false, stream);
initialize_moe_routing_kernelLauncher(
input_activations,
permuted_data_,
permuted_rows_,
nullptr,
nullptr,
expanded_source_row_to_expanded_dest_row,
nullptr,
num_rows,
num_rows,
hidden_size,
k,
stream);
input_activations, permuted_data_, permuted_rows_, nullptr, nullptr,
expanded_source_row_to_expanded_dest_row, num_rows, num_rows,
hidden_size, k, stream);
const int64_t expanded_active_expert_rows = k * num_rows;
compute_total_rows_before_expert(permuted_experts_,
expanded_active_expert_rows,
num_experts,
total_rows_before_expert_,
stream);
expanded_active_expert_rows, num_experts,
total_rows_before_expert_, stream);
if (gemm_method_ == "weight_only_int8") {
typename Int8Traits::Arguments up_gate_proj_quant_args;
int8_moe_gemm_runner_->moe_gemm_bias_act(
reinterpret_cast<NvType *>(permuted_data_),
reinterpret_cast<const uint8_t *>(
up_gate_proj_weight->data<int8_t>()),
reinterpret_cast<const uint8_t *>(up_gate_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
reinterpret_cast<const NvType *>(fc1_expert_biases),
reinterpret_cast<NvType *>(fc1_out),
total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows,
inter_size,
hidden_size,
num_experts,
up_gate_proj_quant_args,
"none",
stream);
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
up_gate_proj_quant_args, "none", stream);
} else if (gemm_method_ == "weight_only_int4") {
typename Int4Traits::Arguments up_gate_proj_quant_args;
int4_moe_gemm_runner_->moe_gemm_bias_act(
@@ -332,33 +278,20 @@ class MoeHelper {
up_gate_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
reinterpret_cast<const NvType *>(fc1_expert_biases),
reinterpret_cast<NvType *>(fc1_out),
total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows,
inter_size,
hidden_size,
num_experts,
up_gate_proj_quant_args,
"none",
stream);
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
up_gate_proj_quant_args, "none", stream);
} else {
typename Fp16Traits::Arguments up_gate_proj_quant_args;
fp16_moe_gemm_runner_->moe_gemm_bias_act(
reinterpret_cast<NvType *>(permuted_data_),
reinterpret_cast<const NvType *>(up_gate_proj_weight->data<T>()),
nullptr,
reinterpret_cast<const NvType *>(up_gate_proj_weight->data<T>()), nullptr,
reinterpret_cast<const NvType *>(fc1_expert_biases),
reinterpret_cast<NvType *>(fc1_out),
total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows,
inter_size,
hidden_size,
num_experts,
up_gate_proj_quant_args,
"none",
stream);
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
up_gate_proj_quant_args, "none", stream);
}
if (moe_type == "ffn") {
@@ -376,15 +309,10 @@ class MoeHelper {
reinterpret_cast<NvType *>(act_out),
reinterpret_cast<const uint8_t *>(down_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
reinterpret_cast<NvType *>(fc2_result),
total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows,
hidden_size,
inter_size / 2,
num_experts,
down_proj_quant_args,
stream);
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, hidden_size, inter_size / 2,
num_experts, down_proj_quant_args, stream);
} else if (gemm_method_ == "weight_only_int4") {
typename Int4Traits::Arguments down_proj_quant_args;
int4_moe_gemm_runner_->moe_gemm(
@@ -392,66 +320,40 @@ class MoeHelper {
reinterpret_cast<const cutlass::uint4b_t *>(
down_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
reinterpret_cast<NvType *>(fc2_result),
total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows,
hidden_size,
inter_size / 2,
num_experts,
down_proj_quant_args,
stream);
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, hidden_size, inter_size / 2,
num_experts, down_proj_quant_args, stream);
} else {
typename Fp16Traits::Arguments down_proj_quant_args;
fp16_moe_gemm_runner_->moe_gemm(
reinterpret_cast<NvType *>(act_out),
reinterpret_cast<const NvType *>(down_proj_weight->data<T>()),
nullptr,
reinterpret_cast<NvType *>(fc2_result),
total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows,
hidden_size,
inter_size / 2,
num_experts,
down_proj_quant_args,
stream);
reinterpret_cast<const NvType *>(down_proj_weight->data<T>()), nullptr,
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, hidden_size, inter_size / 2,
num_experts, down_proj_quant_args, stream);
}
finalize_moe_routing_kernelLauncher(
fc2_result,
output_,
fc2_expert_biases,
fc2_result, output_, fc2_expert_biases,
reinterpret_cast<float *>(expert_scales_float),
expanded_source_row_to_expanded_dest_row,
expert_for_source_row,
num_rows,
hidden_size,
k,
static_cast<int>(1),
norm_topk_prob,
routed_scaling_factor,
stream);
expanded_source_row_to_expanded_dest_row, expert_for_source_row,
num_rows, hidden_size, k, static_cast<int>(1), norm_topk_prob,
routed_scaling_factor, stream);
} else {
finalize_moe_routing_kernelLauncher(
// fc2_result,
fc1_out,
output_,
fc1_expert_biases, // fc2_expert_biases,
fc1_out, output_,
fc1_expert_biases, // fc2_expert_biases,
reinterpret_cast<float *>(expert_scales_float),
expanded_source_row_to_expanded_dest_row,
expert_for_source_row,
num_rows,
inter_size,
k,
static_cast<int>(0),
norm_topk_prob,
routed_scaling_factor,
stream);
expanded_source_row_to_expanded_dest_row, expert_for_source_row,
num_rows, inter_size, k, static_cast<int>(0), norm_topk_prob,
routed_scaling_factor, stream);
}
}
private:
private:
std::string gemm_method_;
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner_;
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner_;
@@ -460,4 +362,4 @@ class MoeHelper {
CubKeyValueSorter sorter_;
};
} // namespace phi
} // namespace phi

File diff suppressed because it is too large Load Diff

View File

@@ -26,26 +26,17 @@
template <paddle::DataType T>
void MoeDispatchKernel(
const paddle::Tensor &input,
const paddle::Tensor &gating_output,
const paddle::Tensor &input, const paddle::Tensor &gating_output,
const paddle::optional<paddle::Tensor> &gating_correction_bias,
const paddle::optional<paddle::Tensor> &w4a8_in_scale,
const int moe_topk,
const bool group_moe,
const bool topk_only_mode,
const int num_rows,
const int hidden_size,
const int expert_num,
paddle::Tensor *permute_input,
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
const bool group_moe, const bool topk_only_mode, const int num_rows,
const int hidden_size, const int expert_num, paddle::Tensor *permute_input,
paddle::Tensor *tokens_expert_prefix_sum,
paddle::Tensor *permute_indices_per_token,
paddle::Tensor *topk_weight,
paddle::Tensor *topk_idx,
paddle::Tensor *expert_idx_per_token,
paddle::Tensor *dequant_scale) {
paddle::Tensor *permute_indices_per_token, paddle::Tensor *topk_weight,
paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token) {
using namespace phi;
if (num_rows == 0) {
if (num_rows == 0){
return;
}
typedef PDTraits<T> traits_;
@@ -57,14 +48,12 @@ void MoeDispatchKernel(
if (group_moe) {
// Check if expert_num is divisible by moe_topk, else throw an error
PADDLE_ENFORCE_EQ(expert_num % moe_topk,
0,
PADDLE_ENFORCE_EQ(expert_num % moe_topk, 0,
common::errors::InvalidArgument(
"The number of experts (expert_num) "
"must be divisible by moe_topk. "
"Got expert_num = %d and moe_topk = %d.",
expert_num,
moe_topk));
expert_num, moe_topk));
}
const int num_moe_inputs = AlignTo16(num_rows * moe_topk);
@@ -79,8 +68,7 @@ void MoeDispatchKernel(
paddle::Tensor ws_ptr_tensor =
GetEmptyTensor({bytes + sorter_ws_size_bytes + sort_tmp_in_out_size},
paddle::DataType::INT8,
place);
paddle::DataType::INT8, place);
int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();
int *source_rows_ = reinterpret_cast<int *>(ws_ptr);
@@ -108,8 +96,8 @@ void MoeDispatchKernel(
paddle::Tensor softmax_buffer;
if (!is_pow_2 || expert_num > 256 || group_moe || gating_correction_bias) {
softmax_buffer = GetEmptyTensor(
{num_rows * expert_num}, paddle::DataType::FLOAT32, place);
softmax_buffer = GetEmptyTensor({num_rows * expert_num},
paddle::DataType::FLOAT32, place);
softmax_out_ = softmax_buffer.data<float>();
} else {
softmax_out_ = nullptr;
@@ -119,106 +107,47 @@ void MoeDispatchKernel(
gating_output.data<float>(),
gating_correction_bias ? gating_correction_bias.get().data<float>()
: nullptr,
topk_weight->data<float>(),
softmax_out_,
topk_idx_ptr,
source_rows_,
softmax_max_prob,
num_rows,
expert_num,
moe_topk,
group_moe,
stream,
topk_weight->data<float>(), softmax_out_, topk_idx_ptr, source_rows_,
softmax_max_prob, num_rows, expert_num, moe_topk, group_moe, stream,
topk_only_mode);
sorter_.run(reinterpret_cast<void *>(sorter_ws_ptr),
sorter_ws_size_bytes,
topk_idx_ptr,
expert_idx_per_token->data<int32_t>(),
source_rows_,
permuted_rows_,
moe_topk * num_rows,
false,
stream);
sorter_.run(reinterpret_cast<void *>(sorter_ws_ptr), sorter_ws_size_bytes,
topk_idx_ptr, expert_idx_per_token->data<int32_t>(), source_rows_,
permuted_rows_, moe_topk * num_rows, false, stream);
if (w4a8_in_scale) {
if (permute_input->dtype() == paddle::DataType::INT8) {
initialize_moe_routing_kernelLauncher(
input.data<data_t>(),
permute_input->data<int8_t>(),
permuted_rows_,
expert_idx_per_token->data<int32_t>(),
w4a8_in_scale->data<float>(),
permute_indices_per_token->data<int32_t>(),
nullptr,
num_rows,
num_rows,
hidden_size,
moe_topk,
stream);
input.data<data_t>(), permute_input->data<int8_t>(), permuted_rows_,
expert_idx_per_token->data<int32_t>(), w4a8_in_scale->data<float>(),
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
hidden_size, moe_topk, stream);
} else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
initialize_moe_routing_kernelLauncher(
input.data<data_t>(),
permute_input->data<float8_e4m3fn>(),
permuted_rows_,
expert_idx_per_token->data<int32_t>(),
w4a8_in_scale->data<float>(),
permute_indices_per_token->data<int32_t>(),
nullptr,
num_rows,
num_rows,
hidden_size,
moe_topk,
stream);
input.data<data_t>(), permute_input->data<float8_e4m3fn>(),
permuted_rows_, expert_idx_per_token->data<int32_t>(),
w4a8_in_scale->data<float>(),
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
hidden_size, moe_topk, stream);
}
} else {
if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
initialize_moe_routing_kernelLauncher(
input.data<data_t>(),
permute_input->data<float8_e4m3fn>(),
permuted_rows_,
expert_idx_per_token->data<int32_t>(),
nullptr,
permute_indices_per_token->data<int32_t>(),
dequant_scale->data<float>(),
num_rows,
num_rows,
hidden_size,
moe_topk,
stream);
} else {
initialize_moe_routing_kernelLauncher(
input.data<data_t>(),
permute_input->data<data_t>(),
permuted_rows_,
expert_idx_per_token->data<int32_t>(),
nullptr,
permute_indices_per_token->data<int32_t>(),
nullptr,
num_rows,
num_rows,
hidden_size,
moe_topk,
stream);
}
initialize_moe_routing_kernelLauncher(
input.data<data_t>(), permute_input->data<data_t>(), permuted_rows_,
expert_idx_per_token->data<int32_t>(), nullptr,
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
hidden_size, moe_topk, stream);
}
compute_total_rows_before_expert(expert_idx_per_token->data<int32_t>(),
moe_topk * num_rows,
expert_num,
tokens_expert_prefix_sum->data<int64_t>(),
stream);
compute_total_rows_before_expert(
expert_idx_per_token->data<int32_t>(), moe_topk * num_rows, expert_num,
tokens_expert_prefix_sum->data<int64_t>(), stream);
}
std::vector<paddle::Tensor> MoeExpertDispatch(
const paddle::Tensor &input,
const paddle::Tensor &gating_output,
const paddle::Tensor &input, const paddle::Tensor &gating_output,
const paddle::optional<paddle::Tensor> &gating_correction_bias,
const paddle::optional<paddle::Tensor> &w4a8_in_scale,
const int moe_topk,
const bool group_moe,
const std::string &moe_quant_type,
const bool topk_only_mode) {
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode) {
const auto input_type = input.dtype();
auto place = input.place();
int token_rows = 0;
@@ -241,21 +170,10 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
} else if (moe_quant_type == "w4afp8") {
permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN;
}
} else {
if (moe_quant_type == "w4afp8") {
permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN;
}
}
auto permute_input = GetEmptyTensor(
{moe_topk * num_rows, hidden_size}, permute_input_dtype, place);
int dequant_scale_size = 1;
if (moe_quant_type == "w4afp8" && !w4a8_in_scale) {
dequant_scale_size = moe_topk * num_rows;
}
auto dequant_scale =
GetEmptyTensor({dequant_scale_size}, paddle::DataType::FLOAT32, place);
auto permute_input = GetEmptyTensor({moe_topk * num_rows, hidden_size},
permute_input_dtype, place);
// correspond to the weighted coefficients of the results from each expert.
auto topk_weight =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
@@ -270,65 +188,39 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
auto expert_idx_per_token =
GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place);
if (token_rows == 0) {
if (token_rows == 0){
return {permute_input,
tokens_expert_prefix_sum,
permute_indices_per_token,
topk_weight,
topk_idx,
expert_idx_per_token,
dequant_scale};
expert_idx_per_token};
}
switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
gating_output,
gating_correction_bias,
w4a8_in_scale,
moe_topk,
group_moe,
topk_only_mode,
num_rows,
hidden_size,
expert_num,
&permute_input,
&tokens_expert_prefix_sum,
&permute_indices_per_token,
&topk_weight,
&topk_idx,
&expert_idx_per_token,
&dequant_scale);
break;
case paddle::DataType::FLOAT16:
MoeDispatchKernel<paddle::DataType::FLOAT16>(input,
gating_output,
gating_correction_bias,
w4a8_in_scale,
moe_topk,
group_moe,
topk_only_mode,
num_rows,
hidden_size,
expert_num,
&permute_input,
&tokens_expert_prefix_sum,
&permute_indices_per_token,
&topk_weight,
&topk_idx,
&expert_idx_per_token,
&dequant_scale);
break;
default:
PD_THROW("Unsupported data type for MoeDispatchKernel");
case paddle::DataType::BFLOAT16:
MoeDispatchKernel<paddle::DataType::BFLOAT16>(
input, gating_output, gating_correction_bias, w4a8_in_scale, moe_topk,
group_moe, topk_only_mode, num_rows, hidden_size, expert_num,
&permute_input, &tokens_expert_prefix_sum, &permute_indices_per_token,
&topk_weight, &topk_idx, &expert_idx_per_token);
break;
case paddle::DataType::FLOAT16:
MoeDispatchKernel<paddle::DataType::FLOAT16>(
input, gating_output, gating_correction_bias, w4a8_in_scale, moe_topk,
group_moe, topk_only_mode, num_rows, hidden_size, expert_num,
&permute_input, &tokens_expert_prefix_sum, &permute_indices_per_token,
&topk_weight, &topk_idx, &expert_idx_per_token);
break;
default:
PD_THROW("Unsupported data type for MoeDispatchKernel");
}
return {permute_input,
tokens_expert_prefix_sum,
permute_indices_per_token,
topk_weight,
topk_idx,
expert_idx_per_token,
dequant_scale};
expert_idx_per_token};
}
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
@@ -353,22 +245,16 @@ std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
{moe_topk, num_rows},
{num_rows, moe_topk},
{num_rows, moe_topk},
{permuted_rows},
{num_rows}};
{permuted_rows}};
}
std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
const paddle::DataType &input_dtype,
const paddle::DataType &gating_output_dtype,
const paddle::optional<paddle::DataType> &bias_type,
const int moe_topk) {
return {input_dtype,
paddle::DataType::INT64,
paddle::DataType::INT32,
paddle::DataType::FLOAT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::FLOAT32};
std::vector<paddle::DataType>
MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
const paddle::DataType &gating_output_dtype,
const paddle::optional<paddle::DataType> &bias_type,
const int moe_topk) {
return {input_dtype, paddle::DataType::INT64, paddle::DataType::INT32,
paddle::DataType::FLOAT32, paddle::DataType::INT32, paddle::DataType::INT32};
}
/**
@@ -376,8 +262,7 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
*
* This operator performs the following key functions:
* 1. Computes top-k experts for each input token based on gating scores
* 2. Permutes input tokens according to their selected experts for efficient
* expert processing
* 2. Permutes input tokens according to their selected experts for efficient expert processing
* 3. Computes prefix sums of tokens per expert for group_gemm optimization
*
* Inputs:
@@ -387,17 +272,18 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
* - gating_output: Gating network output scores for each token-expert pair
* Shape: [total_tokens, expert_num]
* dtype: must be float32
* - gating_correction_bias: Optional bias term for gating correction
* (expert_num)
* - gating_correction_bias: Optional bias term for gating correction (expert_num)
*
* Outputs:
* - permute_input: Permuted input tensor organized by expert
* Shape: [moe_topk * total_tokens, hidden_size]
* dtype: Same as input
* - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for
* group_gemm Shape: [expert_num] dtype: int64
* - permute_indices_per_token: Indices mapping for reconstructing original
* order Shape: [moe_topk, total_tokens] dtype: int32
* - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm
* Shape: [expert_num]
* dtype: int64
* - permute_indices_per_token: Indices mapping for reconstructing original order
* Shape: [moe_topk, total_tokens]
* dtype: int32
* - top_k_weight: Weight coefficients for combining expert outputs
* Shape: [total_tokens, moe_topk]
* dtype: float32
@@ -406,8 +292,7 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
* dtype: int32
*
* Attributes:
* - moe_topk: Number of experts to select for each token (k value in top-k
* routing)
* - moe_topk: Number of experts to select for each token (k value in top-k routing)
* - group_moe: Whether to perform group softmax within the operator
* (true: softmax is computed within groups of experts,
* false: standard softmax across all experts)
@@ -421,21 +306,13 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
* - When group_moe is true, expert_num must be divisible by moe_topk
*/
PD_BUILD_STATIC_OP(moe_expert_dispatch)
.Inputs({"input",
"gating_output",
.Inputs({"input", "gating_output",
paddle::Optional("gating_correction_bias"),
paddle::Optional("w4a8_in_scale")})
.Outputs({"permute_input",
"tokens_expert_prefix_sum",
"permute_indices_per_token",
"topk_weight",
"topk_idx",
"expert_idx_per_token",
"dequant_scale"})
.Attrs({"moe_topk:int",
"group_moe:bool",
"moe_quant_type:std::string",
"topk_only_mode:bool"})
.Outputs({"permute_input", "tokens_expert_prefix_sum",
"permute_indices_per_token", "topk_weight", "topk_idx",
"expert_idx_per_token"})
.Attrs({"moe_topk:int", "group_moe:bool", "moe_quant_type:std::string", "topk_only_mode:bool"})
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));

View File

@@ -17,7 +17,6 @@
#include "cutlass/numeric_conversion.h"
#include "group_swiglu_with_masked.h"
#include "helper.h"
#include "moe/fast_hardmard/fast_hardamard_kernel.h"
#include "moe/fused_moe_helper.h"
template <typename DataT,

View File

@@ -18,8 +18,8 @@
#include "cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h"
#include "group_swiglu_with_masked.h"
#include "helper.h"
#include "moe/fast_hardmard/fast_hardamard_kernel.h"
#include "moe/fused_moe_helper.h"
#include "moe/moe_fast_hardamard_kernel.h"
#include "swigluoai.h"
#include "w4afp8_gemm/w4afp8_gemm.h"
@@ -28,7 +28,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_proj_in_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
@@ -198,20 +197,31 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
typedef typename traits_fp8::DataType DataType_fp8;
typedef typename traits_fp8::data_t data_t_fp8;
paddle::Tensor weight_scale_tensor =
*const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr());
const int weight_scale_group_size = weight_scale_tensor.dims().size() == 2
? hidden_size
: weight_scale_tensor.dims()[3];
const float* input_dequant_scale =
up_proj_in_scale ? up_proj_in_scale.get().data<float>() : nullptr;
Allocator::AllocationPtr ffn1_input_row_sum;
ffn1_input_row_sum =
allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
compute_row_sum(
permute_input.data<data_t_fp8>(),
expanded_active_expert_rows,
hidden_size,
reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
num_max_tokens_per_expert,
used_in_ep_low_latency,
stream);
float* row_scale = nullptr;
DisPatchW4AFp8GemmWrapper(
reinterpret_cast<const DataType_fp8*>(permute_input.data<data_t_fp8>()),
reinterpret_cast<const DataType_fp8*>(
up_gate_proj_weight.data<int8_t>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
input_dequant_scale,
weight_scale_tensor.data<float>(),
reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
row_scale,
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
->data<float>(),
reinterpret_cast<NvType*>(fc1_out),
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
used_in_ep_low_latency ? num_max_tokens_per_expert
@@ -219,7 +229,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
num_experts,
inter_size,
hidden_size,
weight_scale_group_size,
stream);
} else {
typename cutlass::WintQuantTraits<
@@ -346,84 +355,60 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
} else if (quant_method == "w4afp8") {
data_t* ffn2_shift = nullptr;
data_t* ffn2_smooth = nullptr;
float* input_dequant_scale = nullptr;
float* row_scale = nullptr;
Allocator::AllocationPtr fp8_act_out;
fp8_act_out = allocator->Allocate(SizeOf(paddle::DataType::INT8) *
act_out_tensor.numel());
Allocator::AllocationPtr ffn2_input_row_sum;
ffn2_input_row_sum =
allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
if (down_proj_in_scale) {
MoeFastHardamardWrapper<data_t, data_t_fp8>(
act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
: nullptr,
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
ffn2_shift,
ffn2_smooth,
down_proj_in_scale
? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())
->data<float>()
: nullptr,
1,
448.0f,
-448.0f,
expanded_active_expert_rows,
inter_size / 2,
num_max_tokens_per_expert,
used_in_ep_low_latency,
hadamard_block_size,
reinterpret_cast<data_t_fp8*>(fp8_act_out->ptr()),
stream);
} else {
Allocator::AllocationPtr ffn2_input_dequant_scale;
ffn2_input_dequant_scale =
allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
input_dequant_scale =
reinterpret_cast<float*>(ffn2_input_dequant_scale->ptr());
MoeFastHardamardWrapper<data_t, data_t>(
act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
: nullptr,
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
ffn2_shift, // ffn2_shift->data<T>(),
ffn2_smooth, // ffn2_smooth->data<T>(),
nullptr,
1,
448.0f,
-448.0f,
expanded_active_expert_rows,
inter_size / 2,
num_max_tokens_per_expert,
used_in_ep_low_latency,
hadamard_block_size,
act_out_tensor.data<data_t>(),
stream);
// note(yuanxiaolan): optimize this
MoeFastHardamardWrapper<data_t, data_t>(
act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
: nullptr,
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
ffn2_shift, // ffn2_shift->data<T>(),
ffn2_smooth, // ffn2_smooth->data<T>(),
nullptr,
1,
448.0f,
-448.0f,
expanded_active_expert_rows,
inter_size / 2,
num_max_tokens_per_expert,
used_in_ep_low_latency,
hadamard_block_size,
act_out_tensor.data<data_t>(),
stream);
quantize_moe_input<data_t, data_t_fp8>(
act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
: nullptr,
expanded_active_expert_rows,
inter_size / 2,
input_dequant_scale,
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
num_max_tokens_per_expert,
used_in_ep_low_latency,
reinterpret_cast<data_t_fp8*>(fp8_act_out->ptr()),
stream);
}
paddle::Tensor weight_scale_tensor =
*const_cast<paddle::Tensor*>(down_proj_scale.get_ptr());
const int weight_scale_group_size = weight_scale_tensor.dims().size() == 2
? inter_size / 2
: weight_scale_tensor.dims()[3];
quantize_moe_input<data_t, data_t_fp8>(
act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
: nullptr,
down_proj_in_scale
? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())
->data<float>()
: nullptr,
448.0f,
-448.0f,
expanded_active_expert_rows,
inter_size / 2,
reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
num_max_tokens_per_expert,
used_in_ep_low_latency,
reinterpret_cast<data_t_fp8*>(fp8_act_out->ptr()),
stream);
DisPatchW4AFp8GemmWrapper(
reinterpret_cast<const DataType_fp8*>(fp8_act_out->ptr()),
reinterpret_cast<const DataType_fp8*>(down_proj_weight.data<int8_t>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
input_dequant_scale,
weight_scale_tensor.data<float>(),
reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
row_scale,
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())->data<float>(),
reinterpret_cast<NvType*>(ffn_out_data),
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
used_in_ep_low_latency ? num_max_tokens_per_expert
@@ -431,7 +416,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
num_experts,
hidden_size,
inter_size / 2,
weight_scale_group_size,
stream);
} else {
typename cutlass::WintQuantTraits<
@@ -458,7 +442,6 @@ paddle::Tensor MoeExpertFFNFunc(
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_proj_in_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
@@ -483,7 +466,6 @@ paddle::Tensor MoeExpertFFNFunc(
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_proj_in_scale,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
@@ -501,7 +483,6 @@ paddle::Tensor MoeExpertFFNFunc(
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_proj_in_scale,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
@@ -525,7 +506,6 @@ std::vector<paddle::Tensor> MoeExpertFFN(
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_proj_in_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
@@ -540,7 +520,6 @@ std::vector<paddle::Tensor> MoeExpertFFN(
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_proj_in_scale,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
@@ -558,7 +537,6 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
const std::vector<int64_t>& up_gate_proj_weight_shape,
const std::vector<int64_t>& down_proj_weight_shape,
const paddle::optional<std::vector<int64_t>>& up_proj_in_scale_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
@@ -577,7 +555,6 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
const paddle::DataType& tokens_expert_prefix_sum_dtype,
const paddle::DataType& up_gate_proj_weight_dtype,
const paddle::DataType& down_proj_weight_dtype,
const paddle::optional<paddle::DataType>& up_proj_in_scale_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
const paddle::optional<paddle::DataType>& down_proj_scale_dtype,
@@ -655,7 +632,6 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
"tokens_expert_prefix_sum",
"up_gate_proj_weight",
"down_proj_weight",
paddle::Optional("up_proj_in_scale"),
paddle::Optional("up_gate_proj_bias"),
paddle::Optional("up_gate_proj_scale"),
paddle::Optional("down_proj_scale"),

View File

@@ -23,142 +23,132 @@
using namespace cute;
template <int kStages,
class GemmType,
class OutputType,
class SmemLayoutA,
class SmemLayoutB,
class SmemLayoutC,
class SmemLayoutScale>
template <int kStages, class GemmType, class OutputType, class SmemLayoutA,
class SmemLayoutB, class SmemLayoutC>
struct SharedStorage {
union {
struct {
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutA>> smem_a;
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutB>> smem_b;
cute::array_aligned<float, cute::cosize_v<SmemLayoutScale>> smem_scale;
union {
struct {
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutA>> smem_a;
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutB>> smem_b;
};
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c;
};
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c;
};
struct {
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline;
};
};
template <int kBlockM_,
int kBlockN_,
int kBlockK_,
int kNWarps_,
int kStages_,
int kTiles_,
int M_,
int K_,
int TokenPackSize_,
int WeightScaleGroup_,
int kClusterM_ = 1,
typename elem_type = cutlass::float_e4m3_t,
typename OutputType = cutlass::bfloat16_t>
template<int kBlockM_, int kBlockN_, int kBlockK_,
int kNWarps_, int kStages_,
int kTiles_, int M_,
int TokenPackSize_,
int TAIL_N_ = 0,
int kClusterM_ = 1,
typename elem_type=cutlass::float_e4m3_t,
typename OutputType = cutlass::bfloat16_t>
struct Kernel_traits {
using Element = elem_type;
using ElementOutput = OutputType;
using ElementAccum = typename std::
conditional_t<WeightScaleGroup_ == K_, float, cutlass::half_t>;
static_assert(cutlass::sizeof_bits_v<Element> == 8);
using Element = elem_type;
using ElementAccum = float;
using ElementOutput = OutputType;
static_assert(cutlass::sizeof_bits_v<Element> == 8);
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
static_assert(kNWarps_ == 12 || kNWarps_ == 16);
static_assert(kNWarps_ == 12 || kNWarps_ == 16);
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kBlockK = kBlockK_;
static constexpr int kTiles = kTiles_;
static constexpr int TokenPackSize = TokenPackSize_;
static constexpr int M = M_;
static constexpr int K = K_;
static constexpr int WeightScaleGroup = WeightScaleGroup_;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kBlockK = kBlockK_;
static constexpr int kTiles = kTiles_;
static constexpr int TokenPackSize = TokenPackSize_;
static constexpr int M = M_;
static constexpr int TAIL_N = TAIL_N_;
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>;
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>;
using TileShape_MNK_TAIL = Shape<Int<kBlockM>, Int<TAIL_N>, Int<kBlockK>>;
static constexpr int kClusterM = kClusterM_;
using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
static constexpr int kClusterM = kClusterM_;
using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
static constexpr int kStages = kStages_;
static_assert(kStages > 1);
static constexpr int kStages = kStages_;
static_assert(kStages > 1);
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::
rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
AtomLayoutMNK{}));
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
AtomLayoutMNK{}));
using SmemLayoutAtomA =
decltype(cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K,
Element,
Int<kBlockM>,
Int<kBlockK / 2>>());
using TiledMma_TAIL = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK_TAIL>(),
AtomLayoutMNK{}));
using SmemLayoutA = decltype(tile_to_shape(
SmemLayoutAtomA{},
make_shape(Int<kBlockM>{}, Int<kBlockK / 2>{}, Int<kStages>{})));
using SmemLayoutAtomA = decltype(
cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K, Element, Int<kBlockM>, Int<kBlockK / 2>>());
using SmemLayoutAtomB =
decltype(cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K,
Element,
decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutA = decltype(
tile_to_shape(SmemLayoutAtomA{},
make_shape(Int<kBlockM>{}, Int<kBlockK / 2>{}, Int<kStages>{})));
using SmemLayoutB =
decltype(tile_to_shape(SmemLayoutAtomB{},
make_shape(shape<1>(TileShape_MNK{}),
shape<2>(TileShape_MNK{}),
Int<kStages>{})));
using SmemLayoutAtomC =
decltype(cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K,
ElementOutput,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<1>(TileShape_MNK{}))>());
using SmemLayoutAtomB = decltype(
cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutC =
decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));
using SmemLayoutB = decltype(
tile_to_shape(SmemLayoutAtomB{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemCopyAtomAB = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
using SmemCopyAtomC = Copy_Atom<cute::SM90_U32x4_STSM_N, ElementOutput>;
using SmemLayoutAtomB_TAIL = decltype(
cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})),
decltype(cute::get<2>(TileShape_MNK_TAIL{}))>());
using SmemLayoutScale = Layout<Shape<Int<kBlockM>, Int<kStages>>>;
using SmemLayoutB_TAIL = decltype(
tile_to_shape(SmemLayoutAtomB_TAIL{},
make_shape(
shape<1>(TileShape_MNK_TAIL{}),
shape<2>(TileShape_MNK_TAIL{}),
Int<kStages>{})
));
using SharedStorage = SharedStorage<kStages,
Element,
ElementOutput,
SmemLayoutA,
SmemLayoutB,
SmemLayoutC,
SmemLayoutScale>;
using SmemLayoutAtomC = decltype(
cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K, ElementOutput,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<1>(TileShape_MNK{}))>());
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using PipelineState = typename cutlass::PipelineState<kStages>;
using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>);
static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;
// static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
using TiledCopyCAtom =
cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, OutputType>;
using TiledCopyCThrLayout = decltype(cute::make_layout(
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
LayoutRight{}));
using TiledCopyCValLayout = decltype(cute::make_layout(
cute::make_shape(_1{}, Int<kNumVecElem>{}), LayoutRight{}));
using TiledCopyC =
decltype(make_tiled_copy(TiledCopyCAtom{},
TiledCopyCThrLayout{}, // Thr layout
TiledCopyCValLayout{} // Val layout
));
using SmemCopyAtomAB = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
using SmemCopyAtomC = Copy_Atom<cute::SM90_U32x4_STSM_N, ElementOutput>;
using SharedStorage = SharedStorage<
kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutB, SmemLayoutC>;
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using PipelineState = typename cutlass::PipelineState<kStages>;
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>);
static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;
// static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
using TiledCopyCAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, OutputType>;
using TiledCopyCThrLayout = decltype(cute::make_layout(
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
LayoutRight{}));
using TiledCopyCValLayout = decltype(cute::make_layout(
cute::make_shape(_1{}, Int<kNumVecElem>{}),
LayoutRight{}));
using TiledCopyC = decltype(make_tiled_copy(
TiledCopyCAtom{},
TiledCopyCThrLayout{}, // Thr layout
TiledCopyCValLayout{} // Val layout
));
};

View File

@@ -14,10 +14,10 @@
#pragma once
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/tensor.hpp"
@@ -27,544 +27,368 @@
// #include "named_barrier.hpp"
#include "utils.hpp"
using namespace cute;
template <typename Ktraits>
struct CollectiveMainloopFwd {
using Element = typename Ktraits::Element;
using ElementOutput = typename Ktraits::ElementOutput;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
using ElementAccum = typename Ktraits::ElementAccum;
static constexpr int kStages = Ktraits::kStages;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int kBlockK = Ktraits::kBlockK;
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kTiles = Ktraits::kTiles;
static constexpr int M = Ktraits::M;
static constexpr int K = Ktraits::K;
static constexpr int TokenPackSize = Ktraits::TokenPackSize;
static constexpr int WeightScaleGroup = Ktraits::WeightScaleGroup;
using Element = typename Ktraits::Element;
using ElementOutput = typename Ktraits::ElementOutput;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
using ElementAccum = typename Ktraits::ElementAccum;
using GmemTiledCopy = cute::SM90_TMA_LOAD;
static constexpr int kStages = Ktraits::kStages;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int TAIL_N = Ktraits::TAIL_N;
static constexpr int kBlockK = Ktraits::kBlockK;
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kTiles = Ktraits::kTiles;
static constexpr int M = Ktraits::M;
static constexpr int TokenPackSize = Ktraits::TokenPackSize;
using SmemLayoutA = typename Ktraits::SmemLayoutA;
using SmemLayoutB = typename Ktraits::SmemLayoutB;
using SmemLayoutC = typename Ktraits::SmemLayoutC;
using SmemLayoutScale = typename Ktraits::SmemLayoutScale;
using GmemTiledCopy = cute::SM90_TMA_LOAD;
using ShapeT = cute::Shape<int64_t, int64_t, int64_t>;
using StrideT = cute::Shape<int64_t, _1, int64_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using ShapeTScale = cute::Shape<int64_t, int64_t, int64_t>;
using StrideTScale = cute::Shape<_1, int64_t, int64_t>;
using LayoutTScale = cute::Layout<ShapeTScale, StrideTScale>;
using SmemLayoutA = typename Ktraits::SmemLayoutA;
using SmemLayoutB = typename Ktraits::SmemLayoutB;
using SmemLayoutC = typename Ktraits::SmemLayoutC;
using SmemLayoutB_TAIL = typename Ktraits::SmemLayoutB_TAIL;
using TMA_A = decltype(make_tma_copy(
GmemTiledCopy{},
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)),
ShapeT{},
StrideT{}),
SmemLayoutA{}(_, _, _0{}),
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
size<0>(ClusterShape{})));
using ShapeT = cute::Shape<int64_t, int64_t, int64_t>;
using StrideT = cute::Shape<int64_t, _1, int64_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using TMA_B = decltype(make_tma_copy(
GmemTiledCopy{},
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)),
ShapeT{},
StrideT{}),
take<0, 2>(SmemLayoutB{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})));
using TMA_A = decltype(make_tma_copy(
GmemTiledCopy{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
ShapeT{},
StrideT{}
),
SmemLayoutA{}(_, _, _0{}),
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
size<0>(ClusterShape{})));
using TMA_Scale = decltype(make_tma_copy(
GmemTiledCopy{},
make_tensor(make_gmem_ptr(static_cast<float const*>(nullptr)),
ShapeTScale{},
StrideTScale{}),
SmemLayoutScale{}(_, _0{}),
select<0>(Shape<Int<kBlockM>>{}),
size<0>(ClusterShape{})));
using TMA_B = decltype(make_tma_copy(
GmemTiledCopy{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
ShapeT{},
StrideT{}
),
take<0, 2>(SmemLayoutB{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})));
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
using SmemCopyAtomAB = typename Ktraits::SmemCopyAtomAB;
using SmemCopyAtomC = typename Ktraits::SmemCopyAtomC;
using TiledCopyC = typename Ktraits::TiledCopyC;
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
using SmemCopyAtomAB = typename Ktraits::SmemCopyAtomAB;
using SmemCopyAtomC = typename Ktraits::SmemCopyAtomC;
using TiledCopyC = typename Ktraits::TiledCopyC;
static constexpr uint32_t TmaTransactionBytesA = static_cast<uint32_t>(
size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesB = static_cast<uint32_t>(
size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesScale = static_cast<uint32_t>(
size(SmemLayoutScale{}(_, _0{})) * cutlass::sizeof_bits_v<float> / 8);
static constexpr uint32_t TmaTransactionBytesA = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesB = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v<Element> / 8);
struct Arguments {
Element const* ptr_A;
LayoutT layout_A;
Element const* ptr_B;
LayoutT layout_B;
ElementOutput* ptr_C;
LayoutT layout_C;
const float* weight_scale;
LayoutTScale layout_Scale;
const float* input_scale;
const int64_t* tokens;
};
struct Params {
LayoutT layout_A;
LayoutT layout_B;
LayoutTScale layout_Scale;
TMA_A tma_load_A;
TMA_B tma_load_B;
TMA_Scale tma_load_Scale;
ElementOutput* ptr_C;
const float* weight_scale;
const float* input_scale;
const int64_t* tokens;
};
Params static to_underlying_arguments(Arguments const& args) {
Tensor mA = make_tensor(make_gmem_ptr(args.ptr_A), args.layout_A);
TMA_A tma_load_A =
make_tma_copy(GmemTiledCopy{},
mA,
SmemLayoutA{}(_, _, _0{}),
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
size<0>(ClusterShape{}));
Tensor mB = make_tensor(make_gmem_ptr(args.ptr_B), args.layout_B);
TMA_B tma_load_B = make_tma_copy(GmemTiledCopy{},
mB,
SmemLayoutB{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{}));
Tensor mScale =
make_tensor(make_gmem_ptr(args.weight_scale), args.layout_Scale);
TMA_Scale tma_load_Scale = make_tma_copy(GmemTiledCopy{},
mScale,
SmemLayoutScale{}(_, _0{}),
select<0>(Shape<Int<kBlockM>>{}),
size<0>(ClusterShape{}));
return {args.layout_A,
args.layout_B,
args.layout_Scale,
tma_load_A,
tma_load_B,
tma_load_Scale,
args.ptr_C,
args.weight_scale,
args.input_scale,
args.tokens};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_A.get_tma_descriptor());
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_B.get_tma_descriptor());
if constexpr (WeightScaleGroup < K) {
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_Scale.get_tma_descriptor());
}
}
template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
CUTLASS_DEVICE void store(Params const& mainloop_params,
FrgTensorO& tOrO,
SharedStorage& shared_storage,
TiledMma tiled_mma,
const float* weight_scale,
const float* input_scale,
const int64_t tokens,
const int64_t pre_fix_tokens,
const int bidm,
const int bidn,
const int bidb,
const int tidx) {
using packHalf = typename PackedHalf<ElementOutput>::Type;
Tensor tOrO_out = make_tensor<ElementOutput>(tOrO.layout());
if (input_scale != nullptr) {
const int lane_id = tidx % 4 * 2;
if constexpr (WeightScaleGroup == K) {
#pragma unroll
for (int i = 0; i < size(tOrO); i += 4) {
const int scale_idx = i * 2 + lane_id;
tOrO[i] = tOrO[i] * weight_scale[0] * input_scale[scale_idx];
tOrO[i + 1] =
tOrO[i + 1] * weight_scale[0] * input_scale[scale_idx + 1];
tOrO[i + 2] = tOrO[i + 2] * weight_scale[1] * input_scale[scale_idx];
tOrO[i + 3] =
tOrO[i + 3] * weight_scale[1] * input_scale[scale_idx + 1];
*reinterpret_cast<packHalf*>(&tOrO_out[i]) =
packHalf(tOrO[i], tOrO[i + 2]);
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) =
packHalf(tOrO[i + 1], tOrO[i + 3]);
}
} else {
#pragma unroll
for (int i = 0; i < size(tOrO); i += 4) {
const int scale_idx = i * 2 + lane_id;
*reinterpret_cast<packHalf*>(&tOrO_out[i]) =
packHalf(float(tOrO[i]) * input_scale[scale_idx],
float(tOrO[i + 2]) * input_scale[scale_idx]);
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) =
packHalf(float(tOrO[i + 1]) * input_scale[scale_idx + 1],
float(tOrO[i + 3]) * input_scale[scale_idx + 1]);
}
}
} else {
if constexpr (WeightScaleGroup == K) {
#pragma unroll
for (int i = 0; i < size(tOrO); i += 4) {
tOrO[i] = (tOrO[i]) * weight_scale[0];
tOrO[i + 1] = tOrO[i + 1] * weight_scale[0];
tOrO[i + 2] = tOrO[i + 2] * weight_scale[1];
tOrO[i + 3] = tOrO[i + 3] * weight_scale[1];
*reinterpret_cast<packHalf*>(&tOrO_out[i]) =
packHalf(tOrO[i], tOrO[i + 2]);
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) =
packHalf(tOrO[i + 1], tOrO[i + 3]);
}
} else {
#pragma unroll
for (int i = 0; i < size(tOrO); i += 4) {
*reinterpret_cast<packHalf*>(&tOrO_out[i]) =
packHalf(float(tOrO[i]), float(tOrO[i + 2]));
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) =
packHalf(float(tOrO[i + 1]), float(tOrO[i + 3]));
}
}
}
uint16_t* smem_c =
reinterpret_cast<uint16_t*>(shared_storage.smem_c.data());
uint32_t* reg_data = reinterpret_cast<uint32_t*>(tOrO_out.data());
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
constexpr int k_copy_times = kBlockN / 16;
#pragma unroll
for (int i = 0; i < k_copy_times; i++) {
uint32_t smem_ptr = cast_smem_ptr_to_uint(
reinterpret_cast<uint128_t*>(smem_c + i * 16 * 128) + tidx);
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
asm volatile(
"stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, "
"%4};\n" ::"r"(smem_ptr),
"r"(reg_data[4 * i + 0]),
"r"(reg_data[4 * i + 2]),
"r"(reg_data[4 * i + 1]),
"r"(reg_data[4 * i + 3]));
#endif
}
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
const int expert_idx =
TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize;
ElementOutput* store_c = mainloop_params.ptr_C + expert_idx +
bidn * (M * kBlockN) + bidm * kBlockM;
const int reamin_tokens = tokens - bidn * kBlockN;
const int col = tidx % 2;
constexpr int kPackSize = 16 / sizeof(ElementOutput);
constexpr int kNumVecElem = kBlockM / kPackSize;
constexpr int copy_len = kBlockN * kNumVecElem;
#pragma unroll
for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) {
const int idx_div2 = idx / 2;
const int store_idx = idx_div2 / 128 * 128 + idx_div2 % 8 * 16 +
idx_div2 % 128 / 16 + idx_div2 % 16 / 8 * 8;
const int store_global_idx = store_idx * 2 + col;
const int row = store_global_idx / kNumVecElem;
const int col = store_global_idx % kNumVecElem;
if (row >= reamin_tokens) {
continue;
}
const int offset = row * (M / kPackSize) + col;
reinterpret_cast<uint4*>(store_c)[offset] =
reinterpret_cast<uint4*>(smem_c)[idx];
}
}
template <typename MTensor>
CUTLASS_DEVICE auto get_local_no_packed_tensor(const MTensor& mB,
const int pre_fix_token,
const int actual_token,
const int bidn) const {
auto g_tensor = domain_offset(make_coord(pre_fix_token, _0{}), mB(_, _, 0));
Tensor gB = local_tile(
g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
return gB;
}
template <typename SharedStorage>
CUTLASS_DEVICE void load(Params const& mainloop_params,
MainloopPipeline pipeline,
PipelineState& smem_pipe_write,
SharedStorage& shared_storage,
const int tokens,
const int pre_fix_tokens,
const int bidm,
const int bidn,
const int bidb,
const int tidx) {
Tensor sA =
make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
Tensor sB =
make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
Tensor sScale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()),
SmemLayoutScale{});
Tensor mA = mainloop_params.tma_load_A.get_tma_tensor(
mainloop_params.layout_A.shape());
Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(
mainloop_params.layout_B.shape());
Tensor mScale = mainloop_params.tma_load_Scale.get_tma_tensor(
mainloop_params.layout_Scale.shape());
Tensor gA =
local_tile(mA(_, _, bidb),
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
make_coord(bidm, _));
Tensor gScale = local_tile(
mScale(_, bidm, bidb), select<0>(Shape<Int<kBlockM>>{}), make_coord(_));
auto [tAgA, tAsA] = tma_partition(mainloop_params.tma_load_A,
_0{},
Layout<ClusterShape>{},
group_modes<0, 2>(sA),
group_modes<0, 2>(gA));
if constexpr (TokenPackSize == 0) {
Tensor gB = get_local_no_packed_tensor(mB, pre_fix_tokens, tokens, bidn);
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B,
_0{},
Layout<ClusterShape>{},
group_modes<0, 2>(sB),
group_modes<0, 2>(gB));
if (tidx == 0) {
#pragma unroll
for (int kiter = 0; kiter < kTiles; ++kiter) {
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, kiter),
tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, kiter),
tBsB(_, smem_pipe_write.index()));
if constexpr (WeightScaleGroup < K) {
copy(mainloop_params.tma_load_Scale.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
gScale(_, kiter),
sScale(_, smem_pipe_write.index()));
}
++smem_pipe_write;
}
}
} else {
auto mB_this_expert = make_tensor(
mB(_, _, bidb).data(),
make_layout(cute::make_shape(tokens, size<1>(mB)), mB.stride()));
Tensor gB = local_tile(
mB_this_expert, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B,
_0{},
Layout<ClusterShape>{},
group_modes<0, 2>(sB),
group_modes<0, 2>(gB));
if (tidx == 0) {
#pragma unroll
for (int kiter = 0; kiter < kTiles; ++kiter) {
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, kiter),
tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, kiter),
tBsB(_, smem_pipe_write.index()));
if constexpr (WeightScaleGroup < K) {
copy(mainloop_params.tma_load_Scale.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
gScale(_, kiter),
sScale(_, smem_pipe_write.index()));
}
++smem_pipe_write;
}
}
}
}
template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
CUTLASS_DEVICE void mma(Params const& mainloop_params,
TiledMma tiled_mma,
MainloopPipeline pipeline,
PipelineState& smem_pipe_read,
SharedStorage& shared_storage,
FrgTensorO& tSrS,
const int tidx) {
Tensor sA =
make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
Tensor sB =
make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
auto threadMma = tiled_mma.get_thread_slice(tidx);
auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomAB{}, tiled_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx);
Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0));
Tensor tSrB = threadMma.partition_fragment_B(sB);
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
#pragma unroll
for (int kiter = 0; kiter < kTiles; ++kiter) {
Tensor tSsA =
smem_thr_copy_A.partition_S(sA(_, _, smem_pipe_read.index()));
consumer_wait(pipeline, smem_pipe_read);
gemm</*wg_wait=*/0>(tiled_mma,
tSrA,
tSsA,
tSrB(_, _, _, smem_pipe_read.index()),
tSrS,
smem_tiled_copy_A,
smem_thr_copy_A);
pipeline.consumer_release(smem_pipe_read);
++smem_pipe_read;
}
}
template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
CUTLASS_DEVICE void mma_pipeline(Params const& mainloop_params,
TiledMma tiled_mma,
MainloopPipeline pipeline,
PipelineState& smem_pipe_read,
SharedStorage& shared_storage,
FrgTensorO& tSrS,
const int tidx) {
Tensor sA =
make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
Tensor sB =
make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
float2* weight_scale =
reinterpret_cast<float2*>(shared_storage.smem_scale.data()) + tidx / 4;
Tensor tSrS1 = make_fragment_like(tSrS);
Tensor tSrS2 = make_fragment_like(tSrS);
__half2* tSrS_data =
reinterpret_cast<__half2*>(raw_pointer_cast(tSrS.data()));
__half2* tSrS1_data =
reinterpret_cast<__half2*>(raw_pointer_cast(tSrS1.data()));
__half2* tSrS2_data =
reinterpret_cast<__half2*>(raw_pointer_cast(tSrS2.data()));
auto threadMma = tiled_mma.get_thread_slice(tidx);
auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomAB{}, tiled_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx);
Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0));
Tensor tSrB = threadMma.partition_fragment_B(sB);
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
struct Arguments {
Element const* ptr_A;
LayoutT layout_A;
Element const* ptr_B;
LayoutT layout_B;
ElementOutput * ptr_C;
LayoutT layout_C;
const float *weight_scale;
const float *input_row_sum;
const int64_t * tokens;
};
__half2 scale1, scale2, scale3, scale4;
float2 scale_cur_k;
#pragma unroll
for (int kiter = 0; kiter < kTiles;) {
Tensor tSsA1 =
smem_thr_copy_A.partition_S(sA(_, _, smem_pipe_read.index()));
consumer_wait(pipeline, smem_pipe_read);
scale_cur_k = *(weight_scale + smem_pipe_read.index() * (kBlockM / 2));
scale1 = __half2(scale_cur_k.x, scale_cur_k.x);
scale2 = __half2(scale_cur_k.y, scale_cur_k.y);
struct Params {
LayoutT layout_A;
LayoutT layout_B;
TMA_A tma_load_A;
TMA_B tma_load_B;
ElementOutput * ptr_C;
const float *weight_scale;
const float *input_row_sum;
const int64_t * tokens;
};
gemm</*wg_wait=*/0>(tiled_mma,
tSrA,
tSsA1,
tSrB(_, _, _, smem_pipe_read.index()),
tSrS1,
smem_tiled_copy_A,
smem_thr_copy_A);
pipeline.consumer_release(smem_pipe_read);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
if (kiter > 0) {
for (int i = 0; i < size(tSrS) / 2; i += 2) {
tSrS_data[i] = __hfma2(tSrS2_data[i], scale3, tSrS_data[i]);
tSrS_data[i + 1] =
__hfma2(tSrS2_data[i + 1], scale4, tSrS_data[i + 1]);
Params static
to_underlying_arguments(Arguments const& args) {
Tensor mA = make_tensor(make_gmem_ptr(args.ptr_A), args.layout_A);
TMA_A tma_load_A = make_tma_copy(
GmemTiledCopy{},
mA,
SmemLayoutA{}(_, _, _0{}),
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
size<0>(ClusterShape{}));
Tensor mB = make_tensor(make_gmem_ptr(args.ptr_B), args.layout_B);
TMA_B tma_load_B = make_tma_copy(
GmemTiledCopy{},
mB,
SmemLayoutB{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{}));
return {args.layout_A, args.layout_B, tma_load_A, tma_load_B,
args.ptr_C, args.weight_scale, args.input_row_sum, args.tokens};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_A.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_B.get_tma_descriptor());
}
template <int CUR_N, typename SharedStorage, typename FrgTensorO, typename TiledMma>
CUTLASS_DEVICE void
store(Params const& mainloop_params,
FrgTensorO & tOrO,
SharedStorage& shared_storage,
TiledMma tiled_mma,
const float *input_row_sum,
const float *weight_scale,
const int64_t tokens,
const int64_t pre_fix_tokens,
const int bidm,
const int bidn,
const int bidb,
const int tidx) {
using packHalf = typename PackedHalf<ElementOutput>::Type;
Tensor tOrO_out = make_tensor<ElementOutput>(tOrO.layout());
#pragma unroll
for (int i = 0; i < size(tOrO); i+=4) {
const int sum_idx = i * 2;
tOrO[i] = (tOrO[i] + input_row_sum[sum_idx]) * weight_scale[0];
tOrO[i + 1] = (tOrO[i + 1] + input_row_sum[sum_idx + 1]) * weight_scale[0];
tOrO[i + 2] = (tOrO[i + 2] + input_row_sum[sum_idx]) * weight_scale[1];
tOrO[i + 3] = (tOrO[i + 3] + input_row_sum[sum_idx + 1]) * weight_scale[1];
*reinterpret_cast<packHalf*>(&tOrO_out[i]) = packHalf(tOrO[i], tOrO[i + 2]);
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) = packHalf(tOrO[i + 1], tOrO[i + 3]);
}
}
++smem_pipe_read;
++kiter;
uint16_t *smem_c = reinterpret_cast<uint16_t *>(shared_storage.smem_c.data());
if (kiter < kTiles) {
Tensor tSsA2 =
smem_thr_copy_A.partition_S(sA(_, _, smem_pipe_read.index()));
consumer_wait(pipeline, smem_pipe_read);
scale_cur_k = *(weight_scale + smem_pipe_read.index() * (kBlockM / 2));
scale3 = __half2(scale_cur_k.x, scale_cur_k.x);
scale4 = __half2(scale_cur_k.y, scale_cur_k.y);
uint32_t * reg_data = reinterpret_cast<uint32_t*>(tOrO_out.data());
gemm</*wg_wait=*/0>(tiled_mma,
tSrA,
tSsA2,
tSrB(_, _, _, smem_pipe_read.index()),
tSrS2,
smem_tiled_copy_A,
smem_thr_copy_A);
pipeline.consumer_release(smem_pipe_read);
++smem_pipe_read;
++kiter;
}
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
for (int i = 0; i < size(tSrS) / 2; i += 2) {
tSrS_data[i] = __hfma2(tSrS1_data[i], scale1, tSrS_data[i]);
tSrS_data[i + 1] = __hfma2(tSrS1_data[i + 1], scale2, tSrS_data[i + 1]);
}
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
constexpr int k_copy_times = CUR_N / 16;
#pragma unroll
for (int i = 0; i < k_copy_times; i++) {
uint32_t smem_ptr = cast_smem_ptr_to_uint(reinterpret_cast<uint128_t*>(smem_c + i * 16 * 128) + tidx);
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
asm volatile (
"stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
:: "r"(smem_ptr), "r"(reg_data[4 * i + 0]), "r"(reg_data[4 * i + 2]), "r"(reg_data[4 * i + 1]), "r"(reg_data[4 * i + 3]));
#endif
}
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
const int batch_idx = TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize;
ElementOutput * store_c = mainloop_params.ptr_C + batch_idx + bidn * (M * kBlockN) + bidm * kBlockM;
const int reamin_tokens = tokens - bidn * kBlockN;
const int col = tidx % 2;
constexpr int kPackSize = 16 / sizeof(ElementOutput);
constexpr int kNumVecElem = kBlockM / kPackSize;
constexpr int copy_len = CUR_N * kNumVecElem;
#pragma unroll
for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) {
const int idx_div2 = idx / 2;
const int store_idx = idx_div2 / 128 * 128 + idx_div2 % 8 * 16 + idx_div2 % 128 / 16 + idx_div2 % 16 / 8 * 8;
const int store_global_idx = store_idx * 2 + col;
const int row = store_global_idx / kNumVecElem;
const int col = store_global_idx % kNumVecElem;
if (row >= reamin_tokens) {
continue;
}
const int offset = row * (M / kPackSize) + col;
reinterpret_cast<uint4*>(store_c)[offset] = reinterpret_cast<uint4*>(smem_c)[idx];
}
}
if constexpr (kTiles % 2 == 0) {
for (int i = 0; i < size(tSrS) / 2; i += 2) {
tSrS_data[i] = __hfma2(tSrS2_data[i], scale3, tSrS_data[i]);
tSrS_data[i + 1] = __hfma2(tSrS2_data[i + 1], scale4, tSrS_data[i + 1]);
}
template <typename MTensor>
CUTLASS_DEVICE auto get_local_no_packed_tensor(
const MTensor &mB,
const int pre_fix_token,
const int actual_token,
const int bidn) const {
auto g_tensor = domain_offset(make_coord(pre_fix_token, _0{}), mB(_, _, 0));
Tensor gB = local_tile(g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
return gB;
}
template <typename SharedStorage>
CUTLASS_DEVICE void
load(Params const& mainloop_params,
MainloopPipeline pipeline,
PipelineState& smem_pipe_write,
SharedStorage &shared_storage,
const int tokens,
const int pre_fix_tokens,
const int bidm,
const int bidn,
const int bidb,
const int tidx) {
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
Tensor mA = mainloop_params.tma_load_A.get_tma_tensor(mainloop_params.layout_A.shape());
Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(mainloop_params.layout_B.shape());
Tensor gA = local_tile(mA(_, _, bidb), select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}), make_coord(bidm, _));
auto [tAgA, tAsA] = tma_partition(mainloop_params.tma_load_A, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sA), group_modes<0, 2>(gA));
const int kIters = kTiles / kStages;
if constexpr (TokenPackSize == 0) {
Tensor gB = get_local_no_packed_tensor(
mB,
pre_fix_tokens,
tokens,
bidn);
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));
if (tidx == 0) {
#pragma unroll
for (int kiter = 0; kiter < kIters; ++kiter) {
#pragma unroll
for (int s = 0; s < kStages; s++) {
const int i = kiter * kStages + s;
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
++smem_pipe_write;
}
}
#pragma unroll
for (int i = kIters * kStages; i < kTiles; ++i) {
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
++smem_pipe_write;
}
}
} else {
auto mB_this_batch = make_tensor(
mB(_, _, bidb).data(),
make_layout(
cute::make_shape(tokens, size<1>(mB)),
mB.stride()
));
Tensor gB = local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));
if (tidx == 0) {
#pragma unroll
for (int kiter = 0; kiter < kIters; ++kiter) {
#pragma unroll
for (int s = 0; s < kStages; s++) {
const int i = kiter * kStages + s;
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
++smem_pipe_write;
}
}
#pragma unroll
for (int i = kIters * kStages; i < kTiles; ++i) {
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
++smem_pipe_write;
}
}
}
}
template <int CUR_N, typename SharedStorage, typename FrgTensorO, typename TiledMma>
CUTLASS_DEVICE void
mma(Params const& mainloop_params,
TiledMma tiled_mma,
MainloopPipeline pipeline,
PipelineState& smem_pipe_read,
SharedStorage& shared_storage,
FrgTensorO &tSrS,
const int tidx) {
using sMemBLayout = std::conditional_t<
CUR_N == kBlockN,
SmemLayoutB,
SmemLayoutB_TAIL
>;
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), sMemBLayout{});
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
auto threadMma = tiled_mma.get_thread_slice(tidx);
auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomAB{}, tiled_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx);
Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0));
Tensor tSrB = threadMma.partition_fragment_B(sB);
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
const int kIters = kTiles / kStages;
constexpr int B_STEPS = CUR_N == 0 ? 1 : (kBlockN / CUR_N);
#pragma unroll
for (int kiter = 0; kiter < kIters; ++kiter) {
#pragma unroll
for (int s = 0; s < kStages; s++) {
Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, s));
consumer_wait(pipeline, smem_pipe_read);
gemm</*wg_wait=*/0>(tiled_mma, tSrA, tSsA, tSrB(_, _, _, s * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A);
pipeline.consumer_release(smem_pipe_read);
++smem_pipe_read;
}
}
#pragma unroll
for (int i = 0; i < kTiles % kStages; ++i) {
Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, i));
consumer_wait(pipeline, smem_pipe_read);
gemm</*wg_wait=*/0>(tiled_mma, tSrA, tSsA, tSrB(_, _, _, i * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A);
pipeline.consumer_release(smem_pipe_read);
++smem_pipe_read;
}
}
}
};

View File

@@ -24,116 +24,91 @@
#include <cuda_bf16.h>
#endif
#include <cute/arch/cluster_sm90.hpp> // For cute::elect_one_sync()
#include <cute/tensor.hpp>
#include <cute/arch/cluster_sm90.hpp> // For cute::elect_one_sync()
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
using namespace cute;
template <typename T>
template<typename T>
struct PackedHalf;
template <>
template<>
struct PackedHalf<cutlass::half_t> {
using Type = __half2;
using Type = __half2;
};
template <>
template<>
struct PackedHalf<cutlass::bfloat16_t> {
using Type = nv_bfloat162;
using Type = nv_bfloat162;
};
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(
Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag =
convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(
tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
template <int numel>
__forceinline__ __device__ void convert_c4_2_fp8(const int32_t *src,
int32_t *dst1,
int32_t *dst2) {
#pragma unroll
for (int i = 0; i < numel; ++i) {
uint32_t head1 = src[i] & 0x80808080;
dst1[i] = (src[i] >> 4) & 0x07070707;
dst1[i] = dst1[i] | head1;
uint32_t head2 = (src[i] & 0x08080808) << 4;
dst2[i] = src[i] & 0x07070707;
dst2[i] = dst2[i] | head2;
}
}
template <int wg_wait = 0,
bool arrive = true,
bool commit = true,
typename Tensor0,
typename Tensor1,
typename Tensor2,
typename Tensor3,
typename TiledMma,
typename ThrCopyA,
typename TiledCopyA>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma,
Tensor0 &tCrA,
Tensor1 &tCsA,
Tensor2 const &tCrB,
Tensor3 &tCrC,
TiledCopyA const &tiled_copy_A,
ThrCopyA const &thr_copy_A) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator,
typename TiledMma::FrgTypeA>::value;
Tensor tCrA1 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
Tensor tCrA2 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
if constexpr (Is_RS) {
warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
}
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
constexpr int numel = decltype(size(tCrA(_, _, 0)))::value / 4;
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
cute::copy(tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
if (k_block < size<2>(tCrA) - 1) {
cute::copy(tiled_copy_A,
tCsA(_, _, k_block + 1),
tCrA_copy_view(_, _, k_block + 1));
__forceinline__ __device__ void convert_c4_2_fp8(const int32_t * src, int32_t * dst1, int32_t * dst2) {
#pragma unroll
for (int i = 0; i < numel; ++i) {
dst1[i] = (src[i] >> 4) & 0x0f0f0f0f;
dst2[i] = src[i] & 0x0f0f0f0f;
}
int32_t *tCrA_data =
reinterpret_cast<int32_t *>(tCrA(_, _, k_block).data());
int32_t *tCrA1_data =
reinterpret_cast<int32_t *>(tCrA1(_, _, k_block).data());
int32_t *tCrA2_data =
reinterpret_cast<int32_t *>(tCrA2(_, _, k_block).data());
convert_c4_2_fp8<numel>(tCrA_data, tCrA1_data, tCrA2_data);
cute::gemm(tiled_mma, tCrA1(_, _, k_block), tCrB(_, _, 2 * k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
cute::gemm(
tiled_mma, tCrA2(_, _, k_block), tCrB(_, _, 2 * k_block + 1), tCrC);
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) {
warpgroup_wait<wg_wait>();
}
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) {
warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
}
}
template <int wg_wait=0, bool arrive=true,
bool commit=true, typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename TiledMma,
typename ThrCopyA, typename TiledCopyA>
__forceinline__ __device__ void gemm(
TiledMma &tiled_mma,
Tensor0 &tCrA,
Tensor1 &tCsA,
Tensor2 const &tCrB,
Tensor3 &tCrC,
TiledCopyA const &tiled_copy_A,
ThrCopyA const &thr_copy_A) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
Tensor tCrA1 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
Tensor tCrA2 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
constexpr int numel = decltype(size(tCrA(_, _, 0)))::value / 4;
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
cute::copy(tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
if (k_block < size<2>(tCrA) - 1) {
cute::copy(tiled_copy_A, tCsA(_, _, k_block + 1), tCrA_copy_view(_, _, k_block + 1));
}
int32_t * tCrA_data = reinterpret_cast<int32_t *>(tCrA(_,_,k_block).data());
int32_t * tCrA1_data = reinterpret_cast<int32_t *>(tCrA1(_,_,k_block).data());
int32_t * tCrA2_data = reinterpret_cast<int32_t *>(tCrA2(_,_,k_block).data());
convert_c4_2_fp8<numel>(tCrA_data, tCrA1_data, tCrA2_data);
cute::gemm(tiled_mma, tCrA1(_,_,k_block), tCrB(_,_,2 * k_block), tCrC);
cute::gemm(tiled_mma, tCrA2(_,_,k_block), tCrB(_,_, 2 * k_block + 1), tCrC);
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}

View File

@@ -16,179 +16,239 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#include "w4afp8_gemm.h"
#include "helper.h"
#include "paddle/extension.h"
#include "w4afp8_gemm_template.h"
#include "weight_kernel.hpp"
#include "weight_scale_kernel.hpp"
#include "w4afp8_gemm.h"
template <typename T>
class NVTraits;
template <>
class NVTraits<__nv_fp8_e4m3> {
public:
typedef cutlass::float_e4m3_t data_t;
void weight_convert(const uint8_t *weight, uint8_t *weight_new, int batch, int M, int K) {
assert(K % 64 == 0);
for (int b = 0; b < batch; ++b) {
for (int m = 0; m < M; ++m) {
for (int k = 0; k < K; k+=64) {
for (int k_inner = 0; k_inner < 32; ++k_inner) {
uint8_t temp = 0;
uint8_t left = weight[b * M * K + m * K + k + k_inner];
uint8_t right = weight[b * M * K + m * K + k + k_inner + 32];
temp |= left << 4;
temp |= right;
weight_new[b * M * K / 2 + m * K / 2 + k / 2 + k_inner] = *reinterpret_cast<uint8_t*>(&temp);
}
}
}
}
}
template <typename T> class NVTraits;
template <> class NVTraits<__nv_fp8_e4m3> {
public:
typedef cutlass::float_e4m3_t data_t;
};
template <>
class NVTraits<__nv_bfloat16> {
public:
typedef cutlass::bfloat16_t data_t;
template <> class NVTraits<__nv_bfloat16>{
public:
typedef cutlass::bfloat16_t data_t;
};
template <>
class NVTraits<half> {
public:
typedef cutlass::half_t data_t;
template <> class NVTraits<half>{
public:
typedef cutlass::half_t data_t;
};
template <typename OutputType>
void DisPatchW4AFp8Gemm(const cutlass::float_e4m3_t* input,
const cutlass::float_e4m3_t* weight,
const int64_t* tokens,
const float* weight_scale,
const float* input_dequant_scale,
OutputType* out,
const int64_t token_padding_size,
const int64_t max_tokens,
const int Experts,
const int64_t M,
const int64_t K,
const int WeightScaleGroup,
cudaStream_t stream) {
int kBlockN = 256;
if constexpr (std::is_same_v<OutputType, cutlass::bfloat16_t>) {
GEMM_SWITCH_BF16(M,
K,
Experts,
token_padding_size,
kBlockN,
WeightScaleGroup,
weight,
input,
out,
weight_scale,
input_dequant_scale,
tokens,
max_tokens,
stream)
} else {
PD_THROW("Only supported dtype in ['BFLOAT16'].");
}
void DisPatchW4AFp8Gemm(
const cutlass::float_e4m3_t* input,
const cutlass::float_e4m3_t* weight,
const int64_t * tokens,
const float * input_row_sum,
const float * weight_scale,
OutputType * out,
const int64_t token_padding_size,
const int64_t max_tokens,
const int batch_size,
const int64_t M,
const int64_t K,
cudaStream_t stream) {
int kBlockN = 256;
int TailN = 0;
if constexpr (std::is_same_v<OutputType, cutlass::bfloat16_t>) {
GEMM_SWITCH_BF16(
M, K, batch_size, token_padding_size, kBlockN, TailN,
weight,
input,
out,
weight_scale,
input_row_sum,
tokens,
max_tokens,
stream)
} else {
PD_THROW("Only supported dtype in ['BFLOAT16'].");
}
}
std::vector<paddle::Tensor> W4AFp8Gemm(
const paddle::Tensor& input,
const paddle::Tensor& weight,
const paddle::Tensor&
tokens, // If tokenpadding=0, this tensor represents the prefix sum of
// tensors, otherwise it represents the number of tokens in
// each group
const paddle::Tensor& weight_scale,
const paddle::optional<paddle::Tensor>& input_dequant_scale,
const int64_t token_padding_size,
const int64_t max_tokens,
const bool is_bfloat16) {
const int Experts = weight.dims()[0];
const int M = weight.dims()[1];
const int K = weight.dims()[2] * 2;
const int WeightScaleGroup =
weight_scale.dims().size() == 2 ? K : weight_scale.dims()[3];
const paddle::Tensor& input,
const paddle::Tensor& weight,
const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group
const paddle::Tensor& input_row_sum,
const paddle::Tensor& weight_scale,
const int64_t token_padding_size,
const int64_t max_tokens,
const bool is_bfloat16) {
if (input.dtype() != paddle::DataType::FLOAT8_E4M3FN) {
PD_THROW("Only supported dtype in ['FLOAT8_E4M3FN'].");
}
if (token_padding_size == 0) {
const int all_tokens = input.dims()[0];
if (is_bfloat16) {
paddle::Tensor out = paddle::empty(
{all_tokens, M}, paddle::DataType::BFLOAT16, input.place());
phi::dtype::bfloat16* out_data = out.data<phi::dtype::bfloat16>();
DisPatchW4AFp8Gemm(
reinterpret_cast<const cutlass::float_e4m3_t*>(
input.data<phi::dtype::float8_e4m3fn>()),
reinterpret_cast<const cutlass::float_e4m3_t*>(
weight.data<uint8_t>()),
tokens.data<int64_t>(),
weight_scale.data<float>(),
input_dequant_scale
? const_cast<float*>(input_dequant_scale.get().data<float>())
: nullptr,
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
token_padding_size,
max_tokens,
Experts,
M,
K,
WeightScaleGroup,
input.stream());
return {out};
} else {
PD_THROW("Only supported dtype in ['BFLOAT16'].");
const int batch_size = weight.dims()[0];
const int M = weight.dims()[1];
const int K = weight.dims()[2] * 2;
if (input.dtype() != paddle::DataType::FLOAT8_E4M3FN) {
PD_THROW("Only supported dtype in ['FLOAT8_E4M3FN'].");
}
} else {
if (is_bfloat16) {
paddle::Tensor out = paddle::empty({Experts, token_padding_size, M},
paddle::DataType::BFLOAT16,
input.place());
phi::dtype::bfloat16* out_data = out.data<phi::dtype::bfloat16>();
DisPatchW4AFp8Gemm(
reinterpret_cast<const cutlass::float_e4m3_t*>(
input.data<phi::dtype::float8_e4m3fn>()),
reinterpret_cast<const cutlass::float_e4m3_t*>(
weight.data<uint8_t>()),
tokens.data<int64_t>(),
weight_scale.data<float>(),
input_dequant_scale
? const_cast<float*>(input_dequant_scale.get().data<float>())
: nullptr,
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
token_padding_size,
max_tokens,
Experts,
M,
K,
WeightScaleGroup,
input.stream());
return {out};
if (token_padding_size == 0) {
const int all_tokens = input.dims()[0];
if (is_bfloat16) {
paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::BFLOAT16, input.place());
phi::dtype::bfloat16 *out_data = out.data<phi::dtype::bfloat16>();
DisPatchW4AFp8Gemm(
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
tokens.data<int64_t>(),
input_row_sum.data<float>(),
weight_scale.data<float>(),
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
token_padding_size,
max_tokens,
batch_size,
M,
K,
input.stream());
return {out};
} else {
PD_THROW("Only supported dtype in ['BFLOAT16'].");
}
} else {
PD_THROW("Only supported dtype in ['BFLOAT16'].");
if (is_bfloat16) {
paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::BFLOAT16, input.place());
phi::dtype::bfloat16 * out_data = out.data<phi::dtype::bfloat16>();
DisPatchW4AFp8Gemm(
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
tokens.data<int64_t>(),
input_row_sum.data<float>(),
weight_scale.data<float>(),
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
token_padding_size,
max_tokens,
batch_size,
M,
K,
input.stream());
return {out};
} else {
PD_THROW("Only supported dtype in ['BFLOAT16'].");
}
}
}
}
template <typename InputType, typename OutputType>
void DisPatchW4AFp8GemmWrapper(const InputType* input,
const InputType* weight,
const int64_t* total_rows_before_expert,
const float* input_dequant_scale,
const float* weight_scale,
OutputType* out,
const int64_t token_padding_size,
const int64_t max_tokens,
const int num_experts,
const int64_t M,
const int64_t K,
const int WeightScaleGroup,
cudaStream_t stream) {
using InType = typename NVTraits<InputType>::data_t;
using OutType = typename NVTraits<OutputType>::data_t;
DisPatchW4AFp8Gemm(reinterpret_cast<const InType*>(input),
reinterpret_cast<const InType*>(weight),
total_rows_before_expert,
weight_scale,
input_dequant_scale,
reinterpret_cast<OutType*>(out),
token_padding_size,
max_tokens,
num_experts,
M,
K,
WeightScaleGroup,
stream);
void DisPatchW4AFp8GemmWrapper(
const InputType* input,
const InputType* weight,
const int64_t* total_rows_before_expert,
const float* input_row_sum,
const float* row_scale,
const float* weight_scale,
OutputType * out,
const int64_t token_padding_size,
const int64_t max_tokens,
const int num_experts,
const int64_t M,
const int64_t K,
cudaStream_t stream) {
using InType = typename NVTraits<InputType>::data_t;
using OutType = typename NVTraits<OutputType>::data_t;
DisPatchW4AFp8Gemm(
reinterpret_cast<const InType*>(input),
reinterpret_cast<const InType*>(weight),
total_rows_before_expert,
input_row_sum,
weight_scale,
reinterpret_cast<OutType*>(out),
token_padding_size,
max_tokens,
num_experts,
M,
K,
stream);
}
std::vector<paddle::Tensor> W4AFp8GemmWeightConvert(const paddle::Tensor& weight) {
const int batch_size = weight.dims()[0];
const int M = weight.dims()[1];
const int K = weight.dims()[2];
paddle::Tensor weight_new = paddle::empty({batch_size, M, K / 2}, paddle::DataType::UINT8, weight.place());
weight_convert(weight.data<uint8_t>(), weight_new.data<uint8_t>(), batch_size, M, K);
return {weight_new};
}
template <typename T, int kPackSize>
__global__ void permute_scale_kernel(
T* input_data,
const int numel) {
using LoadT = AlignedVector<T, kPackSize>;
LoadT input_vec;
LoadT dst_vec;
const int load_idx = (blockIdx.x * blockDim.x + threadIdx.x) * kPackSize;
if (load_idx >= numel) {
return;
}
Load<T, kPackSize>(&input_data[load_idx], &input_vec);
for (int i = 0; i < kPackSize; i+=2) {
dst_vec[i] = input_vec[i / 2];
dst_vec[i + 1] = input_vec[i / 2 + 8];
}
Store<T, kPackSize>(dst_vec, &input_data[load_idx]);
}
void W4AFp8GemmScalePermute(const paddle::Tensor& scale) {
const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1;
const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0];
if (col % 16 != 0) {
PD_THROW("Only supported when col is divisible by 16.");
}
const int numel = row * col;
const int threads = 128;
const int kPackSize = 16;
const int grid_size = (numel / kPackSize + threads - 1) / threads;
if (scale.dtype() == paddle::DataType::BFLOAT16) {
permute_scale_kernel<phi::dtype::bfloat16, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
const_cast<phi::dtype::bfloat16*>(scale.data<phi::dtype::bfloat16>()),
numel
);
} else if (scale.dtype() == paddle::DataType::FLOAT16) {
permute_scale_kernel<phi::dtype::float16, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
const_cast<phi::dtype::float16*>(scale.data<phi::dtype::float16>()),
numel
);
} else if (scale.dtype() == paddle::DataType::FLOAT32) {
permute_scale_kernel<float, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
const_cast<float*>(scale.data<float>()),
numel
);
}
}
PD_BUILD_STATIC_OP(w4afp8_gemm_scale_permute)
@@ -201,8 +261,8 @@ PD_BUILD_STATIC_OP(w4afp8_gemm)
.Inputs({"input",
"weight",
"tokens",
"weight_scale",
paddle::Optional("input_dequant_scale")})
"input_row_sum",
"weight_scale"})
.Outputs({"out"})
.Attrs({"token_padding_size: int64_t",
"max_tokens: int64_t",
@@ -215,31 +275,33 @@ PD_BUILD_STATIC_OP(w4afp8_gemm_weight_convert)
.SetKernelFn(PD_KERNEL(W4AFp8GemmWeightConvert));
template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, __nv_bfloat16>(
const __nv_fp8_e4m3* input,
const __nv_fp8_e4m3* weight,
const int64_t* tokens,
const float* input_dequant_scale,
const float* weight_scale,
__nv_bfloat16* out,
const int64_t token_padding_size,
const int64_t max_tokens,
const int num_experts,
const int64_t M,
const int64_t K,
const int WeightScaleGroup,
cudaStream_t stream);
const __nv_fp8_e4m3* input,
const __nv_fp8_e4m3* weight,
const int64_t * tokens,
const float * input_row_sum,
const float * row_scale,
const float * weight_scale,
__nv_bfloat16 * out,
const int64_t token_padding_size,
const int64_t max_tokens,
const int num_experts,
const int64_t M,
const int64_t K,
cudaStream_t stream
);
template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, half>(
const __nv_fp8_e4m3* input,
const __nv_fp8_e4m3* weight,
const int64_t* tokens,
const float* input_dequant_scale,
const float* weight_scale,
half* out,
const int64_t token_padding_size,
const int64_t max_tokens,
const int num_experts,
const int64_t M,
const int64_t K,
const int WeightScaleGroup,
cudaStream_t stream);
const __nv_fp8_e4m3* input,
const __nv_fp8_e4m3* weight,
const int64_t * tokens,
const float * input_row_sum,
const float * row_scale,
const float * weight_scale,
half * out,
const int64_t token_padding_size,
const int64_t max_tokens,
const int num_experts,
const int64_t M,
const int64_t K,
cudaStream_t stream
);

View File

@@ -18,30 +18,30 @@
#include <vector>
#include "helper.h"
std::vector<paddle::Tensor> W4AFp8Gemm(
const paddle::Tensor& input,
const paddle::Tensor& weight,
const paddle::Tensor&
tokens, // If tokenpadding=0, this tensor represents the prefix sum of
// tensors, otherwise it represents the number of tokens in
// each group
const paddle::Tensor& weight_scale,
const paddle::optional<paddle::Tensor>& input_dequant_scale,
const int64_t token_padding_size,
const int64_t max_tokens,
const bool is_bfloat16);
const paddle::Tensor& input,
const paddle::Tensor& weight,
const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group
const paddle::Tensor& input_row_sum,
const paddle::Tensor& weight_scale,
const int64_t token_padding_size,
const int64_t max_tokens,
const bool is_bfloat16);
template <typename InputType, typename OutputType>
void DisPatchW4AFp8GemmWrapper(const InputType* input,
const InputType* weight,
const int64_t* tokens,
const float* input_dequant_scale,
const float* weight_scale,
OutputType* out,
const int64_t token_padding_size,
const int64_t max_tokens,
const int num_experts,
const int64_t M,
const int64_t K,
const int WeightScaleGroup,
cudaStream_t stream);
void DisPatchW4AFp8GemmWrapper(
const InputType* input,
const InputType* weight,
const int64_t * tokens,
const float * input_row_sum,
const float * row_scale,
const float * weight_scale,
OutputType * out,
const int64_t token_padding_size,
const int64_t max_tokens,
const int num_experts,
const int64_t M,
const int64_t K,
cudaStream_t stream);

View File

@@ -16,280 +16,237 @@
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/arch/reg_reconfig.h"
#include "kernel_traits.h"
#include "mainloop_fwd.h"
template <typename Ktraits>
void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
1)
w4afp8_gemm_kernel(
CUTE_GRID_CONSTANT
typename CollectiveMainloopFwd<Ktraits>::Params const mainloop_params) {
using Element = typename Ktraits::Element;
static_assert(cutlass::sizeof_bits_v<Element> == 8);
void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) w4afp8_gemm_kernel(
CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits>::Params const mainloop_params) {
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
using Element = typename Ktraits::Element;
static_assert(cutlass::sizeof_bits_v<Element> == 8);
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int M = Ktraits::M;
static constexpr int K = Ktraits::K;
static constexpr int TokenPackSize = Ktraits::TokenPackSize;
static constexpr int WeightScaleGroup = Ktraits::WeightScaleGroup;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
using CollectiveMainloop = CollectiveMainloopFwd<Ktraits>;
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int M = Ktraits::M;
static constexpr int TokenPackSize = Ktraits::TokenPackSize;
static constexpr int TAIL_N = Ktraits::TAIL_N;
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
using ElementOutput = typename Ktraits::ElementOutput;
using CollectiveMainloop = CollectiveMainloopFwd<Ktraits>;
extern __shared__ char shared_memory[];
auto &shared_storage =
*reinterpret_cast<typename Ktraits::SharedStorage *>(shared_memory);
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
using ElementOutput = typename Ktraits::ElementOutput;
const int bidm = blockIdx.x;
const int bidn = blockIdx.y;
const int bidb = blockIdx.z;
const int tidx = threadIdx.x;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
if (tidx == 0) {
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
}
const int bidm = blockIdx.x;
const int bidn = blockIdx.y;
const int bidb = blockIdx.z;
const int tidx = threadIdx.x;
// Obtain warp index
int const warp_group_thread_idx =
threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
if constexpr (WeightScaleGroup == K) {
pipeline_params.transaction_bytes =
CollectiveMainloop::TmaTransactionBytesA +
CollectiveMainloop::TmaTransactionBytesB;
} else {
pipeline_params.transaction_bytes =
CollectiveMainloop::TmaTransactionBytesA +
CollectiveMainloop::TmaTransactionBytesB +
CollectiveMainloop::TmaTransactionBytesScale;
}
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0
? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
MainloopPipeline pipeline(
shared_storage.pipeline, pipeline_params, ClusterShape{});
CollectiveMainloop collective_mainloop;
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
cute::cluster_wait();
} else {
__syncthreads();
}
const int pre_fix_tokens =
TokenPackSize == 0 ? (bidb == 0 ? 0 : mainloop_params.tokens[bidb - 1])
: 0;
const int tokens = TokenPackSize == 0
? mainloop_params.tokens[bidb] - pre_fix_tokens
: mainloop_params.tokens[bidb];
if (bidn * kBlockN >= tokens) {
return;
}
const bool is_need_input_scale = mainloop_params.input_scale != nullptr;
float *input_scale =
is_need_input_scale
? reinterpret_cast<float *>(shared_memory +
sizeof(typename Ktraits::SharedStorage))
: nullptr;
if (warp_group_idx == 0) {
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 40 : 32>();
PipelineState smem_pipe_write =
cutlass::make_producer_start_state<MainloopPipeline>();
collective_mainloop.load(mainloop_params,
pipeline,
smem_pipe_write,
shared_storage,
tokens,
pre_fix_tokens,
bidm,
bidn,
bidb,
tidx);
} else {
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 232 : 160>();
PipelineState smem_pipe_read;
typename Ktraits::TiledMma tiled_mma;
const int mma_tidx = tidx - NumCopyThreads;
if (is_need_input_scale) {
if constexpr (TokenPackSize == 0) {
const int input_scale_idx = pre_fix_tokens + bidn * kBlockN;
if (mma_tidx < tokens) {
reinterpret_cast<float *>(input_scale)[mma_tidx] =
reinterpret_cast<const float *>(mainloop_params.input_scale +
input_scale_idx)[mma_tidx];
}
} else {
const int input_scale_idx = bidb * TokenPackSize + bidn * kBlockN;
if (mma_tidx < kBlockN / 4) {
reinterpret_cast<float4 *>(input_scale)[mma_tidx] =
reinterpret_cast<const float4 *>(mainloop_params.input_scale +
input_scale_idx)[mma_tidx];
}
}
if (tidx == 0) {
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
}
float2 weight_scale;
// Obtain warp index
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
if constexpr (WeightScaleGroup == K) {
weight_scale = reinterpret_cast<const float2 *>(
mainloop_params.weight_scale + bidb * M +
bidm * kBlockM)[mma_tidx / 4];
}
Tensor tSrS =
partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{}));
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesA + CollectiveMainloop::TmaTransactionBytesB;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0
? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
if constexpr (WeightScaleGroup == K) {
collective_mainloop.mma(mainloop_params,
tiled_mma,
pipeline,
smem_pipe_read,
shared_storage,
tSrS,
mma_tidx);
MainloopPipeline pipeline(shared_storage.pipeline, pipeline_params, ClusterShape{});
CollectiveMainloop collective_mainloop;
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
cute::cluster_wait();
} else {
collective_mainloop.mma_pipeline(mainloop_params,
tiled_mma,
pipeline,
smem_pipe_read,
shared_storage,
tSrS,
mma_tidx);
__syncthreads();
}
const int pre_fix_tokens = TokenPackSize == 0 ? (bidb == 0 ? 0 : mainloop_params.tokens[bidb - 1]) : 0;
const int tokens = TokenPackSize == 0 ? mainloop_params.tokens[bidb] - pre_fix_tokens : mainloop_params.tokens[bidb];
if (bidn * kBlockN >= tokens) {
return;
}
float* input_row_sum = reinterpret_cast<float*>(
shared_memory + sizeof(typename Ktraits::SharedStorage));
if (warp_group_idx == 0) {
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 40 : 32>();
PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
collective_mainloop.load(
mainloop_params,
pipeline,
smem_pipe_write,
shared_storage,
tokens,
pre_fix_tokens,
bidm,
bidn,
bidb,
tidx);
} else {
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 232 : 160>();
PipelineState smem_pipe_read;
typename Ktraits::TiledMma tiled_mma;
typename Ktraits::TiledMma_TAIL tiled_mma_tail;
const int mma_tidx = tidx - NumCopyThreads;
const int lane_id = mma_tidx % 4 * 2;
const float2 weight_scale = reinterpret_cast<const float2*>(mainloop_params.weight_scale + bidb * M + bidm * kBlockM)[mma_tidx / 4];
if constexpr (TokenPackSize == 0) {
const int input_sum_idx = pre_fix_tokens + bidn * kBlockN;
if (mma_tidx < kBlockN) {
reinterpret_cast<float*>(input_row_sum)[mma_tidx] = reinterpret_cast<const float*>(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx];
}
} else {
const int input_sum_idx = bidb * TokenPackSize + bidn * kBlockN;
if (mma_tidx < kBlockN / 4) {
reinterpret_cast<float4*>(input_row_sum)[mma_tidx] = reinterpret_cast<const float4*>(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx];
}
}
const int reamin_tokens = tokens - bidn * kBlockN;
if (TAIL_N > 0 && reamin_tokens < kBlockN) {
Tensor tSrS_tail = partition_fragment_C(tiled_mma_tail, select<0, 1>(TileShape_MNK_TAIL{}));
collective_mainloop.mma<TAIL_N>(
mainloop_params,
tiled_mma_tail,
pipeline,
smem_pipe_read,
shared_storage,
tSrS_tail,
mma_tidx);
collective_mainloop.store<TAIL_N>(
mainloop_params,
tSrS_tail,
shared_storage,
tiled_mma_tail,
input_row_sum + lane_id,
reinterpret_cast<const float*>(&weight_scale),
tokens,
pre_fix_tokens,
bidm,
bidn,
bidb,
mma_tidx);
} else {
Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{}));
collective_mainloop.mma<kBlockN>(
mainloop_params,
tiled_mma,
pipeline,
smem_pipe_read,
shared_storage,
tSrS,
mma_tidx);
collective_mainloop.store<kBlockN>(
mainloop_params,
tSrS,
shared_storage,
tiled_mma,
input_row_sum + lane_id,
reinterpret_cast<const float*>(&weight_scale),
tokens,
pre_fix_tokens,
bidm,
bidn,
bidb,
mma_tidx);
}
}
collective_mainloop.store(mainloop_params,
tSrS,
shared_storage,
tiled_mma,
reinterpret_cast<const float *>(&weight_scale),
input_scale,
tokens,
pre_fix_tokens,
bidm,
bidn,
bidb,
mma_tidx);
}
}
template <int Experts>
template <int Batch>
auto get_gmem_layout(const int Rows, const int Cols) {
return make_layout(make_shape(static_cast<int64_t>(Rows),
static_cast<int64_t>(Cols),
static_cast<int64_t>(Experts)),
make_stride(static_cast<int64_t>(Cols),
cute::_1{},
static_cast<int64_t>(Rows * Cols)));
return make_layout(
make_shape(
static_cast<int64_t>(Rows),
static_cast<int64_t>(Cols),
static_cast<int64_t>(Batch)),
make_stride(
static_cast<int64_t>(Cols),
cute::_1{},
static_cast<int64_t>(Rows * Cols)));
}
template <int Experts>
auto get_scale_layout(const int Rows, const int Cols) {
return make_layout(make_shape(static_cast<int64_t>(Cols),
static_cast<int64_t>(Rows),
static_cast<int64_t>(Experts)),
make_stride(cute::_1{},
static_cast<int64_t>(Cols),
static_cast<int64_t>(Rows * Cols)));
}
template <typename InputType,
typename OutputType,
typename Kernel_traits,
int M,
int K,
int Experts,
int TokenPackSize,
int WeightScaleGroup>
void run_gemm(const InputType *A,
const InputType *B,
OutputType *C,
const float *weight_scale,
const float *input_dequant_scale,
const int64_t *tokens,
const int max_tokens,
cudaStream_t stream) {
using ElementOutput = typename Kernel_traits::ElementOutput;
using Element = typename Kernel_traits::Element;
using CollectiveMainloop = CollectiveMainloopFwd<Kernel_traits>;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
constexpr int M_nums =
(M + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
const int N_nums =
(max_tokens + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
constexpr int K_scale_nums = K / Kernel_traits::kBlockM;
static_assert(K % WeightScaleGroup == 0);
static_assert(WeightScaleGroup == 128 || WeightScaleGroup == K);
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments(
{static_cast<Element const *>(A),
get_gmem_layout<Experts>(M, K / 2),
static_cast<Element const *>(B),
get_gmem_layout<Experts>(
TokenPackSize == 0 ? max_tokens : TokenPackSize, K),
static_cast<ElementOutput *>(C),
get_gmem_layout<Experts>(
M, TokenPackSize == 0 ? max_tokens : TokenPackSize),
weight_scale,
get_scale_layout<Experts>(M_nums,
K_scale_nums * Kernel_traits::kBlockM),
input_dequant_scale,
tokens});
void *kernel;
kernel = (void *)w4afp8_gemm_kernel<Kernel_traits>;
int smem_size = sizeof(typename Kernel_traits::SharedStorage) +
Kernel_traits::kBlockN * sizeof(float);
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
dim3 grid_dims;
grid_dims.x = M_nums;
grid_dims.y = N_nums;
grid_dims.z = Experts;
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}),
size<1>(ClusterShape{}),
size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{
grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params);
template <typename InputType, typename OutputType, typename Kernel_traits, int M, int K, int Batch, int TokenPackSize>
void run_gemm(const InputType * A, const InputType * B, OutputType * C, const float *weight_scale,
const float *input_row_sum, const int64_t * tokens, const int64_t max_tokens, cudaStream_t stream) {
using ElementOutput = typename Kernel_traits::ElementOutput;
using Element = typename Kernel_traits::Element;
using CollectiveMainloop = CollectiveMainloopFwd<Kernel_traits>;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
constexpr int M_nums = (M + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
const int N_nums = (max_tokens + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments({
static_cast<Element const*>(A),
get_gmem_layout<Batch>(M, K / 2),
static_cast<Element const*>(B),
get_gmem_layout<Batch>(TokenPackSize == 0 ? max_tokens: TokenPackSize, K),
static_cast<ElementOutput*>(C),
get_gmem_layout<Batch>(M, TokenPackSize == 0 ? max_tokens : TokenPackSize),
weight_scale,
input_row_sum,
tokens
});
void *kernel;
kernel = (void *)w4afp8_gemm_kernel<Kernel_traits>;
int smem_size = sizeof(typename Kernel_traits::SharedStorage) + sizeof(float) * Kernel_traits::kBlockN;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
dim3 grid_dims;
grid_dims.x = M_nums;
grid_dims.y = N_nums;
grid_dims.z = Batch;
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(
launch_params, kernel, mainloop_params);
}

View File

@@ -1,131 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
void weight_convert(
const uint8_t* weight, uint8_t* weight_new, int experts, int M, int K) {
assert(K % 64 == 0);
for (int b = 0; b < experts; ++b) {
for (int m = 0; m < M; ++m) {
for (int k = 0; k < K; k += 64) {
for (int k_inner = 0; k_inner < 32; ++k_inner) {
uint8_t temp = 0;
uint8_t left = weight[b * M * K + m * K + k + k_inner];
uint8_t right = weight[b * M * K + m * K + k + k_inner + 32];
temp |= left << 4;
temp |= right;
weight_new[b * M * K / 2 + m * K / 2 + k / 2 + k_inner] =
*reinterpret_cast<uint8_t*>(&temp);
}
}
}
}
}
__global__ void weight_permute_interleave_kernelw4afp8(const int8_t* input_data,
int8_t* output_data,
const int original_k,
const int original_n) {
const int numel = original_k * original_n / 4;
const int pack_group_size = 64;
const int thread_group_size = pack_group_size / 4; // 16
const int thread_k_stride = original_k / 4;
const int linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (linear_idx >= numel) return;
const int n_id = linear_idx / thread_k_stride;
const int k_id = linear_idx % thread_k_stride;
const int k_group_idx = k_id / thread_group_size;
const int k_idx_in_group = k_id % thread_group_size;
const int8_t* src = input_data +
k_group_idx * pack_group_size / 2 * original_n +
k_idx_in_group * original_n + n_id;
int8_t tmp0 = src[0];
int8_t tmp1 = src[pack_group_size / 4 * original_n];
int8_t tmp00 = (tmp0 & 0xF0) + 112;
int8_t tmp01 = ((tmp0 << 4) & 0xF0) + 112;
int8_t tmp10 = (tmp1 & 0xF0) + 112;
int8_t tmp11 = ((tmp1 << 4) & 0xF0) + 112;
uint8_t utmp00 = *(reinterpret_cast<uint8_t*>(&tmp00));
uint8_t utmp01 = *(reinterpret_cast<uint8_t*>(&tmp01));
uint8_t utmp10 = *(reinterpret_cast<uint8_t*>(&tmp10));
uint8_t utmp11 = *(reinterpret_cast<uint8_t*>(&tmp11));
utmp00 = (utmp00 & 0xF0) >> 4;
utmp01 = (utmp01 & 0xF0) >> 4;
utmp10 = (utmp10 & 0xF0) >> 4;
utmp11 = (utmp11 & 0xF0) >> 4;
tmp00 = *(reinterpret_cast<int8_t*>(&utmp00)) - 7;
tmp01 = *(reinterpret_cast<int8_t*>(&utmp01)) - 7;
tmp10 = *(reinterpret_cast<int8_t*>(&utmp10)) - 7;
tmp11 = *(reinterpret_cast<int8_t*>(&utmp11)) - 7;
if (tmp00 <= 0) {
tmp00 = 8 - tmp00;
}
if (tmp01 <= 0) {
tmp01 = 8 - tmp01;
}
if (tmp10 <= 0) {
tmp10 = 8 - tmp10;
}
if (tmp11 <= 0) {
tmp11 = 8 - tmp11;
}
int8_t dst0 = (tmp01 << 4) | tmp11;
int8_t dst1 = (tmp00 << 4) | tmp10;
int8_t* dst = output_data + n_id * original_k / 2 +
(k_group_idx * pack_group_size / 2) + k_idx_in_group * 2;
dst[0] = dst0;
dst[1] = dst1;
}
std::vector<paddle::Tensor> W4AFp8GemmWeightConvert(
const paddle::Tensor& weight) {
if (weight.place() == paddle::CPUPlace()) {
const int experts = weight.dims()[0];
const int M = weight.dims()[1];
const int K = weight.dims()[2];
paddle::Tensor weight_new = paddle::empty(
{experts, M, K / 2}, paddle::DataType::UINT8, weight.place());
weight_convert(
weight.data<uint8_t>(), weight_new.data<uint8_t>(), experts, M, K);
return {weight_new};
} else {
const int original_k = weight.dims()[0] * 2;
const int original_n = weight.dims()[1];
paddle::Tensor weight_new =
paddle::empty(weight.shape(), paddle::DataType::INT8, weight.place());
const int block_dim = 256;
const int original_numel = original_k * original_n;
const int grid_size = (original_numel + block_dim - 1) / block_dim;
weight_permute_interleave_kernelw4afp8<<<grid_size, block_dim>>>(
weight.data<int8_t>(),
weight_new.data<int8_t>(),
original_k,
original_n);
return {weight_new};
}
}

View File

@@ -1,63 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
template <typename T, int kPackSize>
__global__ void permute_scale_kernel(T* input_data, const int numel) {
using LoadT = AlignedVector<T, kPackSize>;
LoadT input_vec;
LoadT dst_vec;
const int load_idx = (blockIdx.x * blockDim.x + threadIdx.x) * kPackSize;
if (load_idx >= numel) {
return;
}
Load<T, kPackSize>(&input_data[load_idx], &input_vec);
for (int i = 0; i < kPackSize; i += 2) {
dst_vec[i] = input_vec[i / 2];
dst_vec[i + 1] = input_vec[i / 2 + 8];
}
Store<T, kPackSize>(dst_vec, &input_data[load_idx]);
}
void W4AFp8GemmScalePermute(const paddle::Tensor& scale) {
const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1;
const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0];
if (col % 16 != 0) {
PD_THROW("Only supported when col is divisible by 16.");
}
const int numel = row * col;
const int threads = 128;
const int kPackSize = 16;
const int grid_size = (numel / kPackSize + threads - 1) / threads;
if (scale.dtype() == paddle::DataType::BFLOAT16) {
permute_scale_kernel<phi::dtype::bfloat16, kPackSize>
<<<grid_size, threads, 0, scale.stream()>>>(
const_cast<phi::dtype::bfloat16*>(
scale.data<phi::dtype::bfloat16>()),
numel);
} else if (scale.dtype() == paddle::DataType::FLOAT16) {
permute_scale_kernel<phi::dtype::float16, kPackSize>
<<<grid_size, threads, 0, scale.stream()>>>(
const_cast<phi::dtype::float16*>(scale.data<phi::dtype::float16>()),
numel);
} else if (scale.dtype() == paddle::DataType::FLOAT32) {
permute_scale_kernel<float, kPackSize>
<<<grid_size, threads, 0, scale.stream()>>>(
const_cast<float*>(scale.data<float>()), numel);
}
}