mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
This reverts commit 93fcf7e4ec.
This commit is contained in:
@@ -304,7 +304,6 @@ paddle::Tensor MoeExpertFFNFunc(
|
|||||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||||
const paddle::Tensor& up_gate_proj_weight,
|
const paddle::Tensor& up_gate_proj_weight,
|
||||||
const paddle::Tensor& down_proj_weight,
|
const paddle::Tensor& down_proj_weight,
|
||||||
const paddle::optional<paddle::Tensor>& up_proj_in_scale,
|
|
||||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||||
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||||
|
|||||||
@@ -1,37 +0,0 @@
|
|||||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include "helper.h"
|
|
||||||
|
|
||||||
template <typename T, typename OutT>
|
|
||||||
void MoeFastHardamardWrapper(const T *x_data,
|
|
||||||
const int64_t *expert_idx_per_token,
|
|
||||||
const int64_t *recv_expert_count,
|
|
||||||
const T *shift,
|
|
||||||
const T *smooth,
|
|
||||||
const float *quant_scales,
|
|
||||||
const int quant_round_type,
|
|
||||||
const float quant_max_bound,
|
|
||||||
const float quant_min_bound,
|
|
||||||
const int64_t token_num,
|
|
||||||
const int64_t dim,
|
|
||||||
const int num_max_tokens_per_expert,
|
|
||||||
bool used_in_ep_low_latency,
|
|
||||||
const int hadamard_block_size,
|
|
||||||
OutT *out,
|
|
||||||
cudaStream_t &stream);
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,34 +0,0 @@
|
|||||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "fast_hardamard_kernel.hpp"
|
|
||||||
|
|
||||||
template void
|
|
||||||
MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16>(
|
|
||||||
const phi::dtype::bfloat16 *x_data,
|
|
||||||
const int64_t *expert_idx_per_token,
|
|
||||||
const int64_t *recv_expert_count,
|
|
||||||
const phi::dtype::bfloat16 *shift,
|
|
||||||
const phi::dtype::bfloat16 *smooth,
|
|
||||||
const float *quant_scales,
|
|
||||||
const int quant_round_type,
|
|
||||||
const float quant_max_bound,
|
|
||||||
const float quant_min_bound,
|
|
||||||
const int64_t token_num,
|
|
||||||
const int64_t dim,
|
|
||||||
const int num_max_tokens_per_expert,
|
|
||||||
bool used_in_ep_low_latency,
|
|
||||||
const int hadamard_block_size,
|
|
||||||
phi::dtype::bfloat16 *out,
|
|
||||||
cudaStream_t &stream);
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "fast_hardamard_kernel.hpp"
|
|
||||||
|
|
||||||
template void
|
|
||||||
MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::float8_e4m3fn>(
|
|
||||||
const phi::dtype::bfloat16 *x_data,
|
|
||||||
const int64_t *expert_idx_per_token,
|
|
||||||
const int64_t *recv_expert_count,
|
|
||||||
const phi::dtype::bfloat16 *shift,
|
|
||||||
const phi::dtype::bfloat16 *smooth,
|
|
||||||
const float *quant_scales,
|
|
||||||
const int quant_round_type,
|
|
||||||
const float quant_max_bound,
|
|
||||||
const float quant_min_bound,
|
|
||||||
const int64_t token_num,
|
|
||||||
const int64_t dim,
|
|
||||||
const int num_max_tokens_per_expert,
|
|
||||||
bool used_in_ep_low_latency,
|
|
||||||
const int hadamard_block_size,
|
|
||||||
phi::dtype::float8_e4m3fn *out,
|
|
||||||
cudaStream_t &stream);
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "fast_hardamard_kernel.hpp"
|
|
||||||
|
|
||||||
template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
|
|
||||||
const phi::dtype::bfloat16 *x_data,
|
|
||||||
const int64_t *expert_idx_per_token,
|
|
||||||
const int64_t *recv_expert_count,
|
|
||||||
const phi::dtype::bfloat16 *shift,
|
|
||||||
const phi::dtype::bfloat16 *smooth,
|
|
||||||
const float *quant_scales,
|
|
||||||
const int quant_round_type,
|
|
||||||
const float quant_max_bound,
|
|
||||||
const float quant_min_bound,
|
|
||||||
const int64_t token_num,
|
|
||||||
const int64_t dim,
|
|
||||||
const int num_max_tokens_per_expert,
|
|
||||||
bool used_in_ep_low_latency,
|
|
||||||
const int hadamard_block_size,
|
|
||||||
int8_t *out,
|
|
||||||
cudaStream_t &stream);
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "fast_hardamard_kernel.hpp"
|
|
||||||
|
|
||||||
template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
|
|
||||||
const phi::dtype::float16 *x_data,
|
|
||||||
const int64_t *expert_idx_per_token,
|
|
||||||
const int64_t *recv_expert_count,
|
|
||||||
const phi::dtype::float16 *shift,
|
|
||||||
const phi::dtype::float16 *smooth,
|
|
||||||
const float *quant_scales,
|
|
||||||
const int quant_round_type,
|
|
||||||
const float quant_max_bound,
|
|
||||||
const float quant_min_bound,
|
|
||||||
const int64_t token_num,
|
|
||||||
const int64_t dim,
|
|
||||||
const int num_max_tokens_per_expert,
|
|
||||||
bool used_in_ep_low_latency,
|
|
||||||
const int hadamard_block_size,
|
|
||||||
phi::dtype::float16 *out,
|
|
||||||
cudaStream_t &stream);
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "fast_hardamard_kernel.hpp"
|
|
||||||
|
|
||||||
template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
|
|
||||||
const phi::dtype::float16 *x_data,
|
|
||||||
const int64_t *expert_idx_per_token,
|
|
||||||
const int64_t *recv_expert_count,
|
|
||||||
const phi::dtype::float16 *shift,
|
|
||||||
const phi::dtype::float16 *smooth,
|
|
||||||
const float *quant_scales,
|
|
||||||
const int quant_round_type,
|
|
||||||
const float quant_max_bound,
|
|
||||||
const float quant_min_bound,
|
|
||||||
const int64_t token_num,
|
|
||||||
const int64_t dim,
|
|
||||||
const int num_max_tokens_per_expert,
|
|
||||||
bool used_in_ep_low_latency,
|
|
||||||
const int hadamard_block_size,
|
|
||||||
int8_t *out,
|
|
||||||
cudaStream_t &stream);
|
|
||||||
@@ -25,8 +25,7 @@ template <typename T, int VecSize>
|
|||||||
__global__ void moe_token_type_ids_kernel(T *gating_output,
|
__global__ void moe_token_type_ids_kernel(T *gating_output,
|
||||||
const int *moe_token_type_ids_out,
|
const int *moe_token_type_ids_out,
|
||||||
const int num_rows,
|
const int num_rows,
|
||||||
const int num_experts,
|
const int num_experts, const int k) {
|
||||||
const int k) {
|
|
||||||
const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x;
|
const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
if (moe_token_index >= num_rows) {
|
if (moe_token_index >= num_rows) {
|
||||||
@@ -45,8 +44,7 @@ template <typename T>
|
|||||||
void moe_token_type_ids_kernelLauncher(T *gating_output,
|
void moe_token_type_ids_kernelLauncher(T *gating_output,
|
||||||
const int *moe_token_type_ids_out,
|
const int *moe_token_type_ids_out,
|
||||||
const int num_rows,
|
const int num_rows,
|
||||||
const int num_experts,
|
const int num_experts, const int k,
|
||||||
const int k,
|
|
||||||
cudaStream_t stream) {
|
cudaStream_t stream) {
|
||||||
const int blocks = num_rows * k / 512 + 1;
|
const int blocks = num_rows * k / 512 + 1;
|
||||||
const int threads = 512;
|
const int threads = 512;
|
||||||
@@ -54,35 +52,26 @@ void moe_token_type_ids_kernelLauncher(T *gating_output,
|
|||||||
gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
|
gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename NvType>
|
template <typename T, typename NvType> class MoeHelper {
|
||||||
class MoeHelper {
|
public:
|
||||||
public:
|
using Fp16Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kNone>;
|
||||||
using Fp16Traits =
|
using Int8Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kWeightOnlyInt8>;
|
||||||
cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kNone>;
|
using Int4Traits = cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kWeightOnlyInt4>;
|
||||||
using Int8Traits =
|
|
||||||
cutlass::WintQuantTraits<NvType,
|
|
||||||
cutlass::WintQuantMethod::kWeightOnlyInt8>;
|
|
||||||
using Int4Traits =
|
|
||||||
cutlass::WintQuantTraits<NvType,
|
|
||||||
cutlass::WintQuantMethod::kWeightOnlyInt4>;
|
|
||||||
|
|
||||||
MoeHelper(const std::string gemm_method,
|
MoeHelper(
|
||||||
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner,
|
const std::string gemm_method,
|
||||||
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner,
|
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner,
|
||||||
MoeGemmRunner<NvType, Int4Traits> *int4_moe_gemm_runner,
|
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner,
|
||||||
int layernum = 0)
|
MoeGemmRunner<NvType, Int4Traits> *int4_moe_gemm_runner,
|
||||||
: gemm_method_(gemm_method),
|
int layernum = 0)
|
||||||
fp16_moe_gemm_runner_(fp16_moe_gemm_runner),
|
: gemm_method_(gemm_method), fp16_moe_gemm_runner_(fp16_moe_gemm_runner),
|
||||||
int8_moe_gemm_runner_(int8_moe_gemm_runner),
|
int8_moe_gemm_runner_(int8_moe_gemm_runner),
|
||||||
int4_moe_gemm_runner_(int4_moe_gemm_runner),
|
int4_moe_gemm_runner_(int4_moe_gemm_runner), layernum_(layernum) {}
|
||||||
layernum_(layernum) {}
|
|
||||||
|
|
||||||
// -------- getWorkspaceSize -------- //
|
// -------- getWorkspaceSize -------- //
|
||||||
template <typename KeyT>
|
template <typename KeyT>
|
||||||
size_t getWorkspaceSize(const int64_t num_rows,
|
size_t getWorkspaceSize(const int64_t num_rows, const int64_t hidden_size,
|
||||||
const int64_t hidden_size,
|
const int64_t inter_size, const int64_t num_experts,
|
||||||
const int64_t inter_size,
|
|
||||||
const int64_t num_experts,
|
|
||||||
const int64_t k) {
|
const int64_t k) {
|
||||||
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||||
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||||
@@ -93,10 +82,10 @@ class MoeHelper {
|
|||||||
// FfnLayer forward.
|
// FfnLayer forward.
|
||||||
size_t total_ws_bytes =
|
size_t total_ws_bytes =
|
||||||
5 * num_moe_inputs *
|
5 * num_moe_inputs *
|
||||||
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
|
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
|
||||||
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
|
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
|
||||||
total_ws_bytes +=
|
total_ws_bytes +=
|
||||||
padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_
|
padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_
|
||||||
|
|
||||||
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
|
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
|
||||||
const size_t sorter_ws_size_bytes =
|
const size_t sorter_ws_size_bytes =
|
||||||
@@ -111,8 +100,8 @@ class MoeHelper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
total_ws_bytes +=
|
total_ws_bytes +=
|
||||||
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
|
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
|
||||||
// sorting workspace
|
// sorting workspace
|
||||||
|
|
||||||
int64_t num_softmax_outs = 0;
|
int64_t num_softmax_outs = 0;
|
||||||
const bool is_pow_2 =
|
const bool is_pow_2 =
|
||||||
@@ -126,27 +115,20 @@ class MoeHelper {
|
|||||||
return total_ws_bytes;
|
return total_ws_bytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ComputeFFN(const paddle::Tensor *input,
|
void
|
||||||
const paddle::Tensor *gate_weight,
|
ComputeFFN(const paddle::Tensor *input, const paddle::Tensor *gate_weight,
|
||||||
const paddle::Tensor *up_gate_proj_weight,
|
const paddle::Tensor *up_gate_proj_weight,
|
||||||
const paddle::Tensor *up_gate_proj_scale,
|
const paddle::Tensor *up_gate_proj_scale, const paddle::Tensor *up_gate_proj_bias,
|
||||||
const paddle::Tensor *up_gate_proj_bias,
|
const paddle::Tensor *down_proj_weight,
|
||||||
const paddle::Tensor *down_proj_weight,
|
const paddle::Tensor *down_proj_scale, const paddle::Tensor *down_proj_bias,
|
||||||
const paddle::Tensor *down_proj_scale,
|
const paddle::Tensor *moe_token_type_ids, const int moe_topk,
|
||||||
const paddle::Tensor *down_proj_bias,
|
const bool group_moe, const bool norm_topk_prob,
|
||||||
const paddle::Tensor *moe_token_type_ids,
|
const float routed_scaling_factor, const std::string moe_type,
|
||||||
const int moe_topk,
|
paddle::Tensor *output) {
|
||||||
const bool group_moe,
|
|
||||||
const bool norm_topk_prob,
|
|
||||||
const float routed_scaling_factor,
|
|
||||||
const std::string moe_type,
|
|
||||||
paddle::Tensor *output) {
|
|
||||||
auto *input_activations = input->data<T>();
|
auto *input_activations = input->data<T>();
|
||||||
auto *gating_weights = gate_weight->data<float>();
|
auto *gating_weights = gate_weight->data<float>();
|
||||||
const T *fc1_expert_biases =
|
const T *fc1_expert_biases = up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr;
|
||||||
up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr;
|
const T *fc2_expert_biases = down_proj_bias ? down_proj_bias->data<T>() : nullptr;
|
||||||
const T *fc2_expert_biases =
|
|
||||||
down_proj_bias ? down_proj_bias->data<T>() : nullptr;
|
|
||||||
|
|
||||||
auto *output_ = output->data<T>();
|
auto *output_ = output->data<T>();
|
||||||
auto stream = input->stream();
|
auto stream = input->stream();
|
||||||
@@ -166,8 +148,7 @@ class MoeHelper {
|
|||||||
const int64_t hidden_size = up_gate_proj_dims[1];
|
const int64_t hidden_size = up_gate_proj_dims[1];
|
||||||
int64_t inter_dim = 0;
|
int64_t inter_dim = 0;
|
||||||
if (moe_type == "qkv") {
|
if (moe_type == "qkv") {
|
||||||
inter_dim =
|
inter_dim = up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
|
||||||
up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
|
|
||||||
} else {
|
} else {
|
||||||
inter_dim = up_gate_proj_dims[2];
|
inter_dim = up_gate_proj_dims[2];
|
||||||
}
|
}
|
||||||
@@ -251,79 +232,44 @@ class MoeHelper {
|
|||||||
if (moe_token_type_ids) {
|
if (moe_token_type_ids) {
|
||||||
auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
|
auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
|
||||||
moe_token_type_ids_kernelLauncher<float>(gating_output,
|
moe_token_type_ids_kernelLauncher<float>(gating_output,
|
||||||
moe_token_type_ids_out,
|
moe_token_type_ids_out, num_rows,
|
||||||
num_rows,
|
num_experts, k, stream);
|
||||||
num_experts,
|
|
||||||
k,
|
|
||||||
stream);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
topk_gating_softmax_kernelLauncher<float, int>(gating_output,
|
topk_gating_softmax_kernelLauncher<float, int>(
|
||||||
nullptr,
|
gating_output, nullptr, expert_scales_float, softmax_out_,
|
||||||
expert_scales_float,
|
expert_for_source_row, source_rows_, softmax_max_prob, num_rows,
|
||||||
softmax_out_,
|
num_experts, k, group_moe, stream);
|
||||||
expert_for_source_row,
|
|
||||||
source_rows_,
|
|
||||||
softmax_max_prob,
|
|
||||||
num_rows,
|
|
||||||
num_experts,
|
|
||||||
k,
|
|
||||||
group_moe,
|
|
||||||
stream);
|
|
||||||
|
|
||||||
const int64_t sorter_ws_size_bytes =
|
const int64_t sorter_ws_size_bytes =
|
||||||
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
|
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
|
||||||
|
|
||||||
sorter_.run(fc1_result_,
|
sorter_.run(fc1_result_, sorter_ws_size_bytes, expert_for_source_row,
|
||||||
sorter_ws_size_bytes,
|
permuted_experts_, source_rows_, permuted_rows_, k * num_rows,
|
||||||
expert_for_source_row,
|
false, stream);
|
||||||
permuted_experts_,
|
|
||||||
source_rows_,
|
|
||||||
permuted_rows_,
|
|
||||||
k * num_rows,
|
|
||||||
false,
|
|
||||||
stream);
|
|
||||||
|
|
||||||
initialize_moe_routing_kernelLauncher(
|
initialize_moe_routing_kernelLauncher(
|
||||||
input_activations,
|
input_activations, permuted_data_, permuted_rows_, nullptr, nullptr,
|
||||||
permuted_data_,
|
expanded_source_row_to_expanded_dest_row, num_rows, num_rows,
|
||||||
permuted_rows_,
|
hidden_size, k, stream);
|
||||||
nullptr,
|
|
||||||
nullptr,
|
|
||||||
expanded_source_row_to_expanded_dest_row,
|
|
||||||
nullptr,
|
|
||||||
num_rows,
|
|
||||||
num_rows,
|
|
||||||
hidden_size,
|
|
||||||
k,
|
|
||||||
stream);
|
|
||||||
|
|
||||||
const int64_t expanded_active_expert_rows = k * num_rows;
|
const int64_t expanded_active_expert_rows = k * num_rows;
|
||||||
|
|
||||||
compute_total_rows_before_expert(permuted_experts_,
|
compute_total_rows_before_expert(permuted_experts_,
|
||||||
expanded_active_expert_rows,
|
expanded_active_expert_rows, num_experts,
|
||||||
num_experts,
|
total_rows_before_expert_, stream);
|
||||||
total_rows_before_expert_,
|
|
||||||
stream);
|
|
||||||
|
|
||||||
if (gemm_method_ == "weight_only_int8") {
|
if (gemm_method_ == "weight_only_int8") {
|
||||||
typename Int8Traits::Arguments up_gate_proj_quant_args;
|
typename Int8Traits::Arguments up_gate_proj_quant_args;
|
||||||
int8_moe_gemm_runner_->moe_gemm_bias_act(
|
int8_moe_gemm_runner_->moe_gemm_bias_act(
|
||||||
reinterpret_cast<NvType *>(permuted_data_),
|
reinterpret_cast<NvType *>(permuted_data_),
|
||||||
reinterpret_cast<const uint8_t *>(
|
reinterpret_cast<const uint8_t *>(up_gate_proj_weight->data<int8_t>()),
|
||||||
up_gate_proj_weight->data<int8_t>()),
|
|
||||||
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
|
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
|
||||||
reinterpret_cast<const NvType *>(fc1_expert_biases),
|
reinterpret_cast<const NvType *>(fc1_expert_biases),
|
||||||
reinterpret_cast<NvType *>(fc1_out),
|
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
|
||||||
total_rows_before_expert_,
|
-1, // useless
|
||||||
-1, // useless
|
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
|
||||||
expanded_active_expert_rows,
|
up_gate_proj_quant_args, "none", stream);
|
||||||
inter_size,
|
|
||||||
hidden_size,
|
|
||||||
num_experts,
|
|
||||||
up_gate_proj_quant_args,
|
|
||||||
"none",
|
|
||||||
stream);
|
|
||||||
} else if (gemm_method_ == "weight_only_int4") {
|
} else if (gemm_method_ == "weight_only_int4") {
|
||||||
typename Int4Traits::Arguments up_gate_proj_quant_args;
|
typename Int4Traits::Arguments up_gate_proj_quant_args;
|
||||||
int4_moe_gemm_runner_->moe_gemm_bias_act(
|
int4_moe_gemm_runner_->moe_gemm_bias_act(
|
||||||
@@ -332,33 +278,20 @@ class MoeHelper {
|
|||||||
up_gate_proj_weight->data<int8_t>()),
|
up_gate_proj_weight->data<int8_t>()),
|
||||||
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
|
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
|
||||||
reinterpret_cast<const NvType *>(fc1_expert_biases),
|
reinterpret_cast<const NvType *>(fc1_expert_biases),
|
||||||
reinterpret_cast<NvType *>(fc1_out),
|
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
|
||||||
total_rows_before_expert_,
|
-1, // useless
|
||||||
-1, // useless
|
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
|
||||||
expanded_active_expert_rows,
|
up_gate_proj_quant_args, "none", stream);
|
||||||
inter_size,
|
|
||||||
hidden_size,
|
|
||||||
num_experts,
|
|
||||||
up_gate_proj_quant_args,
|
|
||||||
"none",
|
|
||||||
stream);
|
|
||||||
} else {
|
} else {
|
||||||
typename Fp16Traits::Arguments up_gate_proj_quant_args;
|
typename Fp16Traits::Arguments up_gate_proj_quant_args;
|
||||||
fp16_moe_gemm_runner_->moe_gemm_bias_act(
|
fp16_moe_gemm_runner_->moe_gemm_bias_act(
|
||||||
reinterpret_cast<NvType *>(permuted_data_),
|
reinterpret_cast<NvType *>(permuted_data_),
|
||||||
reinterpret_cast<const NvType *>(up_gate_proj_weight->data<T>()),
|
reinterpret_cast<const NvType *>(up_gate_proj_weight->data<T>()), nullptr,
|
||||||
nullptr,
|
|
||||||
reinterpret_cast<const NvType *>(fc1_expert_biases),
|
reinterpret_cast<const NvType *>(fc1_expert_biases),
|
||||||
reinterpret_cast<NvType *>(fc1_out),
|
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
|
||||||
total_rows_before_expert_,
|
-1, // useless
|
||||||
-1, // useless
|
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
|
||||||
expanded_active_expert_rows,
|
up_gate_proj_quant_args, "none", stream);
|
||||||
inter_size,
|
|
||||||
hidden_size,
|
|
||||||
num_experts,
|
|
||||||
up_gate_proj_quant_args,
|
|
||||||
"none",
|
|
||||||
stream);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (moe_type == "ffn") {
|
if (moe_type == "ffn") {
|
||||||
@@ -376,15 +309,10 @@ class MoeHelper {
|
|||||||
reinterpret_cast<NvType *>(act_out),
|
reinterpret_cast<NvType *>(act_out),
|
||||||
reinterpret_cast<const uint8_t *>(down_proj_weight->data<int8_t>()),
|
reinterpret_cast<const uint8_t *>(down_proj_weight->data<int8_t>()),
|
||||||
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
|
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
|
||||||
reinterpret_cast<NvType *>(fc2_result),
|
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
|
||||||
total_rows_before_expert_,
|
-1, // useless
|
||||||
-1, // useless
|
expanded_active_expert_rows, hidden_size, inter_size / 2,
|
||||||
expanded_active_expert_rows,
|
num_experts, down_proj_quant_args, stream);
|
||||||
hidden_size,
|
|
||||||
inter_size / 2,
|
|
||||||
num_experts,
|
|
||||||
down_proj_quant_args,
|
|
||||||
stream);
|
|
||||||
} else if (gemm_method_ == "weight_only_int4") {
|
} else if (gemm_method_ == "weight_only_int4") {
|
||||||
typename Int4Traits::Arguments down_proj_quant_args;
|
typename Int4Traits::Arguments down_proj_quant_args;
|
||||||
int4_moe_gemm_runner_->moe_gemm(
|
int4_moe_gemm_runner_->moe_gemm(
|
||||||
@@ -392,66 +320,40 @@ class MoeHelper {
|
|||||||
reinterpret_cast<const cutlass::uint4b_t *>(
|
reinterpret_cast<const cutlass::uint4b_t *>(
|
||||||
down_proj_weight->data<int8_t>()),
|
down_proj_weight->data<int8_t>()),
|
||||||
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
|
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
|
||||||
reinterpret_cast<NvType *>(fc2_result),
|
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
|
||||||
total_rows_before_expert_,
|
-1, // useless
|
||||||
-1, // useless
|
expanded_active_expert_rows, hidden_size, inter_size / 2,
|
||||||
expanded_active_expert_rows,
|
num_experts, down_proj_quant_args, stream);
|
||||||
hidden_size,
|
|
||||||
inter_size / 2,
|
|
||||||
num_experts,
|
|
||||||
down_proj_quant_args,
|
|
||||||
stream);
|
|
||||||
} else {
|
} else {
|
||||||
typename Fp16Traits::Arguments down_proj_quant_args;
|
typename Fp16Traits::Arguments down_proj_quant_args;
|
||||||
fp16_moe_gemm_runner_->moe_gemm(
|
fp16_moe_gemm_runner_->moe_gemm(
|
||||||
reinterpret_cast<NvType *>(act_out),
|
reinterpret_cast<NvType *>(act_out),
|
||||||
reinterpret_cast<const NvType *>(down_proj_weight->data<T>()),
|
reinterpret_cast<const NvType *>(down_proj_weight->data<T>()), nullptr,
|
||||||
nullptr,
|
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
|
||||||
reinterpret_cast<NvType *>(fc2_result),
|
-1, // useless
|
||||||
total_rows_before_expert_,
|
expanded_active_expert_rows, hidden_size, inter_size / 2,
|
||||||
-1, // useless
|
num_experts, down_proj_quant_args, stream);
|
||||||
expanded_active_expert_rows,
|
|
||||||
hidden_size,
|
|
||||||
inter_size / 2,
|
|
||||||
num_experts,
|
|
||||||
down_proj_quant_args,
|
|
||||||
stream);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
finalize_moe_routing_kernelLauncher(
|
finalize_moe_routing_kernelLauncher(
|
||||||
fc2_result,
|
fc2_result, output_, fc2_expert_biases,
|
||||||
output_,
|
|
||||||
fc2_expert_biases,
|
|
||||||
reinterpret_cast<float *>(expert_scales_float),
|
reinterpret_cast<float *>(expert_scales_float),
|
||||||
expanded_source_row_to_expanded_dest_row,
|
expanded_source_row_to_expanded_dest_row, expert_for_source_row,
|
||||||
expert_for_source_row,
|
num_rows, hidden_size, k, static_cast<int>(1), norm_topk_prob,
|
||||||
num_rows,
|
routed_scaling_factor, stream);
|
||||||
hidden_size,
|
|
||||||
k,
|
|
||||||
static_cast<int>(1),
|
|
||||||
norm_topk_prob,
|
|
||||||
routed_scaling_factor,
|
|
||||||
stream);
|
|
||||||
} else {
|
} else {
|
||||||
finalize_moe_routing_kernelLauncher(
|
finalize_moe_routing_kernelLauncher(
|
||||||
// fc2_result,
|
// fc2_result,
|
||||||
fc1_out,
|
fc1_out, output_,
|
||||||
output_,
|
fc1_expert_biases, // fc2_expert_biases,
|
||||||
fc1_expert_biases, // fc2_expert_biases,
|
|
||||||
reinterpret_cast<float *>(expert_scales_float),
|
reinterpret_cast<float *>(expert_scales_float),
|
||||||
expanded_source_row_to_expanded_dest_row,
|
expanded_source_row_to_expanded_dest_row, expert_for_source_row,
|
||||||
expert_for_source_row,
|
num_rows, inter_size, k, static_cast<int>(0), norm_topk_prob,
|
||||||
num_rows,
|
routed_scaling_factor, stream);
|
||||||
inter_size,
|
|
||||||
k,
|
|
||||||
static_cast<int>(0),
|
|
||||||
norm_topk_prob,
|
|
||||||
routed_scaling_factor,
|
|
||||||
stream);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string gemm_method_;
|
std::string gemm_method_;
|
||||||
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner_;
|
MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner_;
|
||||||
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner_;
|
MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner_;
|
||||||
@@ -460,4 +362,4 @@ class MoeHelper {
|
|||||||
CubKeyValueSorter sorter_;
|
CubKeyValueSorter sorter_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace phi
|
} // namespace phi
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -26,26 +26,17 @@
|
|||||||
|
|
||||||
template <paddle::DataType T>
|
template <paddle::DataType T>
|
||||||
void MoeDispatchKernel(
|
void MoeDispatchKernel(
|
||||||
const paddle::Tensor &input,
|
const paddle::Tensor &input, const paddle::Tensor &gating_output,
|
||||||
const paddle::Tensor &gating_output,
|
|
||||||
const paddle::optional<paddle::Tensor> &gating_correction_bias,
|
const paddle::optional<paddle::Tensor> &gating_correction_bias,
|
||||||
const paddle::optional<paddle::Tensor> &w4a8_in_scale,
|
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
|
||||||
const int moe_topk,
|
const bool group_moe, const bool topk_only_mode, const int num_rows,
|
||||||
const bool group_moe,
|
const int hidden_size, const int expert_num, paddle::Tensor *permute_input,
|
||||||
const bool topk_only_mode,
|
|
||||||
const int num_rows,
|
|
||||||
const int hidden_size,
|
|
||||||
const int expert_num,
|
|
||||||
paddle::Tensor *permute_input,
|
|
||||||
paddle::Tensor *tokens_expert_prefix_sum,
|
paddle::Tensor *tokens_expert_prefix_sum,
|
||||||
paddle::Tensor *permute_indices_per_token,
|
paddle::Tensor *permute_indices_per_token, paddle::Tensor *topk_weight,
|
||||||
paddle::Tensor *topk_weight,
|
paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token) {
|
||||||
paddle::Tensor *topk_idx,
|
|
||||||
paddle::Tensor *expert_idx_per_token,
|
|
||||||
paddle::Tensor *dequant_scale) {
|
|
||||||
using namespace phi;
|
using namespace phi;
|
||||||
|
|
||||||
if (num_rows == 0) {
|
if (num_rows == 0){
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
typedef PDTraits<T> traits_;
|
typedef PDTraits<T> traits_;
|
||||||
@@ -57,14 +48,12 @@ void MoeDispatchKernel(
|
|||||||
|
|
||||||
if (group_moe) {
|
if (group_moe) {
|
||||||
// Check if expert_num is divisible by moe_topk, else throw an error
|
// Check if expert_num is divisible by moe_topk, else throw an error
|
||||||
PADDLE_ENFORCE_EQ(expert_num % moe_topk,
|
PADDLE_ENFORCE_EQ(expert_num % moe_topk, 0,
|
||||||
0,
|
|
||||||
common::errors::InvalidArgument(
|
common::errors::InvalidArgument(
|
||||||
"The number of experts (expert_num) "
|
"The number of experts (expert_num) "
|
||||||
"must be divisible by moe_topk. "
|
"must be divisible by moe_topk. "
|
||||||
"Got expert_num = %d and moe_topk = %d.",
|
"Got expert_num = %d and moe_topk = %d.",
|
||||||
expert_num,
|
expert_num, moe_topk));
|
||||||
moe_topk));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const int num_moe_inputs = AlignTo16(num_rows * moe_topk);
|
const int num_moe_inputs = AlignTo16(num_rows * moe_topk);
|
||||||
@@ -79,8 +68,7 @@ void MoeDispatchKernel(
|
|||||||
|
|
||||||
paddle::Tensor ws_ptr_tensor =
|
paddle::Tensor ws_ptr_tensor =
|
||||||
GetEmptyTensor({bytes + sorter_ws_size_bytes + sort_tmp_in_out_size},
|
GetEmptyTensor({bytes + sorter_ws_size_bytes + sort_tmp_in_out_size},
|
||||||
paddle::DataType::INT8,
|
paddle::DataType::INT8, place);
|
||||||
place);
|
|
||||||
|
|
||||||
int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();
|
int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();
|
||||||
int *source_rows_ = reinterpret_cast<int *>(ws_ptr);
|
int *source_rows_ = reinterpret_cast<int *>(ws_ptr);
|
||||||
@@ -108,8 +96,8 @@ void MoeDispatchKernel(
|
|||||||
paddle::Tensor softmax_buffer;
|
paddle::Tensor softmax_buffer;
|
||||||
|
|
||||||
if (!is_pow_2 || expert_num > 256 || group_moe || gating_correction_bias) {
|
if (!is_pow_2 || expert_num > 256 || group_moe || gating_correction_bias) {
|
||||||
softmax_buffer = GetEmptyTensor(
|
softmax_buffer = GetEmptyTensor({num_rows * expert_num},
|
||||||
{num_rows * expert_num}, paddle::DataType::FLOAT32, place);
|
paddle::DataType::FLOAT32, place);
|
||||||
softmax_out_ = softmax_buffer.data<float>();
|
softmax_out_ = softmax_buffer.data<float>();
|
||||||
} else {
|
} else {
|
||||||
softmax_out_ = nullptr;
|
softmax_out_ = nullptr;
|
||||||
@@ -119,106 +107,47 @@ void MoeDispatchKernel(
|
|||||||
gating_output.data<float>(),
|
gating_output.data<float>(),
|
||||||
gating_correction_bias ? gating_correction_bias.get().data<float>()
|
gating_correction_bias ? gating_correction_bias.get().data<float>()
|
||||||
: nullptr,
|
: nullptr,
|
||||||
topk_weight->data<float>(),
|
topk_weight->data<float>(), softmax_out_, topk_idx_ptr, source_rows_,
|
||||||
softmax_out_,
|
softmax_max_prob, num_rows, expert_num, moe_topk, group_moe, stream,
|
||||||
topk_idx_ptr,
|
|
||||||
source_rows_,
|
|
||||||
softmax_max_prob,
|
|
||||||
num_rows,
|
|
||||||
expert_num,
|
|
||||||
moe_topk,
|
|
||||||
group_moe,
|
|
||||||
stream,
|
|
||||||
topk_only_mode);
|
topk_only_mode);
|
||||||
|
|
||||||
sorter_.run(reinterpret_cast<void *>(sorter_ws_ptr),
|
sorter_.run(reinterpret_cast<void *>(sorter_ws_ptr), sorter_ws_size_bytes,
|
||||||
sorter_ws_size_bytes,
|
topk_idx_ptr, expert_idx_per_token->data<int32_t>(), source_rows_,
|
||||||
topk_idx_ptr,
|
permuted_rows_, moe_topk * num_rows, false, stream);
|
||||||
expert_idx_per_token->data<int32_t>(),
|
|
||||||
source_rows_,
|
|
||||||
permuted_rows_,
|
|
||||||
moe_topk * num_rows,
|
|
||||||
false,
|
|
||||||
stream);
|
|
||||||
|
|
||||||
if (w4a8_in_scale) {
|
if (w4a8_in_scale) {
|
||||||
if (permute_input->dtype() == paddle::DataType::INT8) {
|
if (permute_input->dtype() == paddle::DataType::INT8) {
|
||||||
initialize_moe_routing_kernelLauncher(
|
initialize_moe_routing_kernelLauncher(
|
||||||
input.data<data_t>(),
|
input.data<data_t>(), permute_input->data<int8_t>(), permuted_rows_,
|
||||||
permute_input->data<int8_t>(),
|
expert_idx_per_token->data<int32_t>(), w4a8_in_scale->data<float>(),
|
||||||
permuted_rows_,
|
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
|
||||||
expert_idx_per_token->data<int32_t>(),
|
hidden_size, moe_topk, stream);
|
||||||
w4a8_in_scale->data<float>(),
|
|
||||||
permute_indices_per_token->data<int32_t>(),
|
|
||||||
nullptr,
|
|
||||||
num_rows,
|
|
||||||
num_rows,
|
|
||||||
hidden_size,
|
|
||||||
moe_topk,
|
|
||||||
stream);
|
|
||||||
} else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
|
} else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
|
||||||
initialize_moe_routing_kernelLauncher(
|
initialize_moe_routing_kernelLauncher(
|
||||||
input.data<data_t>(),
|
input.data<data_t>(), permute_input->data<float8_e4m3fn>(),
|
||||||
permute_input->data<float8_e4m3fn>(),
|
permuted_rows_, expert_idx_per_token->data<int32_t>(),
|
||||||
permuted_rows_,
|
w4a8_in_scale->data<float>(),
|
||||||
expert_idx_per_token->data<int32_t>(),
|
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
|
||||||
w4a8_in_scale->data<float>(),
|
hidden_size, moe_topk, stream);
|
||||||
permute_indices_per_token->data<int32_t>(),
|
|
||||||
nullptr,
|
|
||||||
num_rows,
|
|
||||||
num_rows,
|
|
||||||
hidden_size,
|
|
||||||
moe_topk,
|
|
||||||
stream);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
|
initialize_moe_routing_kernelLauncher(
|
||||||
initialize_moe_routing_kernelLauncher(
|
input.data<data_t>(), permute_input->data<data_t>(), permuted_rows_,
|
||||||
input.data<data_t>(),
|
expert_idx_per_token->data<int32_t>(), nullptr,
|
||||||
permute_input->data<float8_e4m3fn>(),
|
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
|
||||||
permuted_rows_,
|
hidden_size, moe_topk, stream);
|
||||||
expert_idx_per_token->data<int32_t>(),
|
|
||||||
nullptr,
|
|
||||||
permute_indices_per_token->data<int32_t>(),
|
|
||||||
dequant_scale->data<float>(),
|
|
||||||
num_rows,
|
|
||||||
num_rows,
|
|
||||||
hidden_size,
|
|
||||||
moe_topk,
|
|
||||||
stream);
|
|
||||||
} else {
|
|
||||||
initialize_moe_routing_kernelLauncher(
|
|
||||||
input.data<data_t>(),
|
|
||||||
permute_input->data<data_t>(),
|
|
||||||
permuted_rows_,
|
|
||||||
expert_idx_per_token->data<int32_t>(),
|
|
||||||
nullptr,
|
|
||||||
permute_indices_per_token->data<int32_t>(),
|
|
||||||
nullptr,
|
|
||||||
num_rows,
|
|
||||||
num_rows,
|
|
||||||
hidden_size,
|
|
||||||
moe_topk,
|
|
||||||
stream);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
compute_total_rows_before_expert(expert_idx_per_token->data<int32_t>(),
|
compute_total_rows_before_expert(
|
||||||
moe_topk * num_rows,
|
expert_idx_per_token->data<int32_t>(), moe_topk * num_rows, expert_num,
|
||||||
expert_num,
|
tokens_expert_prefix_sum->data<int64_t>(), stream);
|
||||||
tokens_expert_prefix_sum->data<int64_t>(),
|
|
||||||
stream);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::Tensor> MoeExpertDispatch(
|
std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||||
const paddle::Tensor &input,
|
const paddle::Tensor &input, const paddle::Tensor &gating_output,
|
||||||
const paddle::Tensor &gating_output,
|
|
||||||
const paddle::optional<paddle::Tensor> &gating_correction_bias,
|
const paddle::optional<paddle::Tensor> &gating_correction_bias,
|
||||||
const paddle::optional<paddle::Tensor> &w4a8_in_scale,
|
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
|
||||||
const int moe_topk,
|
const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode) {
|
||||||
const bool group_moe,
|
|
||||||
const std::string &moe_quant_type,
|
|
||||||
const bool topk_only_mode) {
|
|
||||||
const auto input_type = input.dtype();
|
const auto input_type = input.dtype();
|
||||||
auto place = input.place();
|
auto place = input.place();
|
||||||
int token_rows = 0;
|
int token_rows = 0;
|
||||||
@@ -241,21 +170,10 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
|||||||
} else if (moe_quant_type == "w4afp8") {
|
} else if (moe_quant_type == "w4afp8") {
|
||||||
permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN;
|
permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN;
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
if (moe_quant_type == "w4afp8") {
|
|
||||||
permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto permute_input = GetEmptyTensor(
|
auto permute_input = GetEmptyTensor({moe_topk * num_rows, hidden_size},
|
||||||
{moe_topk * num_rows, hidden_size}, permute_input_dtype, place);
|
permute_input_dtype, place);
|
||||||
int dequant_scale_size = 1;
|
|
||||||
if (moe_quant_type == "w4afp8" && !w4a8_in_scale) {
|
|
||||||
dequant_scale_size = moe_topk * num_rows;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto dequant_scale =
|
|
||||||
GetEmptyTensor({dequant_scale_size}, paddle::DataType::FLOAT32, place);
|
|
||||||
// correspond to the weighted coefficients of the results from each expert.
|
// correspond to the weighted coefficients of the results from each expert.
|
||||||
auto topk_weight =
|
auto topk_weight =
|
||||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||||
@@ -270,65 +188,39 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
|||||||
auto expert_idx_per_token =
|
auto expert_idx_per_token =
|
||||||
GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place);
|
GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place);
|
||||||
|
|
||||||
if (token_rows == 0) {
|
if (token_rows == 0){
|
||||||
return {permute_input,
|
return {permute_input,
|
||||||
tokens_expert_prefix_sum,
|
tokens_expert_prefix_sum,
|
||||||
permute_indices_per_token,
|
permute_indices_per_token,
|
||||||
topk_weight,
|
topk_weight,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
expert_idx_per_token,
|
expert_idx_per_token};
|
||||||
dequant_scale};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (input_type) {
|
switch (input_type) {
|
||||||
case paddle::DataType::BFLOAT16:
|
case paddle::DataType::BFLOAT16:
|
||||||
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
|
MoeDispatchKernel<paddle::DataType::BFLOAT16>(
|
||||||
gating_output,
|
input, gating_output, gating_correction_bias, w4a8_in_scale, moe_topk,
|
||||||
gating_correction_bias,
|
group_moe, topk_only_mode, num_rows, hidden_size, expert_num,
|
||||||
w4a8_in_scale,
|
&permute_input, &tokens_expert_prefix_sum, &permute_indices_per_token,
|
||||||
moe_topk,
|
&topk_weight, &topk_idx, &expert_idx_per_token);
|
||||||
group_moe,
|
break;
|
||||||
topk_only_mode,
|
case paddle::DataType::FLOAT16:
|
||||||
num_rows,
|
MoeDispatchKernel<paddle::DataType::FLOAT16>(
|
||||||
hidden_size,
|
input, gating_output, gating_correction_bias, w4a8_in_scale, moe_topk,
|
||||||
expert_num,
|
group_moe, topk_only_mode, num_rows, hidden_size, expert_num,
|
||||||
&permute_input,
|
&permute_input, &tokens_expert_prefix_sum, &permute_indices_per_token,
|
||||||
&tokens_expert_prefix_sum,
|
&topk_weight, &topk_idx, &expert_idx_per_token);
|
||||||
&permute_indices_per_token,
|
break;
|
||||||
&topk_weight,
|
default:
|
||||||
&topk_idx,
|
PD_THROW("Unsupported data type for MoeDispatchKernel");
|
||||||
&expert_idx_per_token,
|
|
||||||
&dequant_scale);
|
|
||||||
break;
|
|
||||||
case paddle::DataType::FLOAT16:
|
|
||||||
MoeDispatchKernel<paddle::DataType::FLOAT16>(input,
|
|
||||||
gating_output,
|
|
||||||
gating_correction_bias,
|
|
||||||
w4a8_in_scale,
|
|
||||||
moe_topk,
|
|
||||||
group_moe,
|
|
||||||
topk_only_mode,
|
|
||||||
num_rows,
|
|
||||||
hidden_size,
|
|
||||||
expert_num,
|
|
||||||
&permute_input,
|
|
||||||
&tokens_expert_prefix_sum,
|
|
||||||
&permute_indices_per_token,
|
|
||||||
&topk_weight,
|
|
||||||
&topk_idx,
|
|
||||||
&expert_idx_per_token,
|
|
||||||
&dequant_scale);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
PD_THROW("Unsupported data type for MoeDispatchKernel");
|
|
||||||
}
|
}
|
||||||
return {permute_input,
|
return {permute_input,
|
||||||
tokens_expert_prefix_sum,
|
tokens_expert_prefix_sum,
|
||||||
permute_indices_per_token,
|
permute_indices_per_token,
|
||||||
topk_weight,
|
topk_weight,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
expert_idx_per_token,
|
expert_idx_per_token};
|
||||||
dequant_scale};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
|
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
|
||||||
@@ -353,22 +245,16 @@ std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
|
|||||||
{moe_topk, num_rows},
|
{moe_topk, num_rows},
|
||||||
{num_rows, moe_topk},
|
{num_rows, moe_topk},
|
||||||
{num_rows, moe_topk},
|
{num_rows, moe_topk},
|
||||||
{permuted_rows},
|
{permuted_rows}};
|
||||||
{num_rows}};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
std::vector<paddle::DataType>
|
||||||
const paddle::DataType &input_dtype,
|
MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
|
||||||
const paddle::DataType &gating_output_dtype,
|
const paddle::DataType &gating_output_dtype,
|
||||||
const paddle::optional<paddle::DataType> &bias_type,
|
const paddle::optional<paddle::DataType> &bias_type,
|
||||||
const int moe_topk) {
|
const int moe_topk) {
|
||||||
return {input_dtype,
|
return {input_dtype, paddle::DataType::INT64, paddle::DataType::INT32,
|
||||||
paddle::DataType::INT64,
|
paddle::DataType::FLOAT32, paddle::DataType::INT32, paddle::DataType::INT32};
|
||||||
paddle::DataType::INT32,
|
|
||||||
paddle::DataType::FLOAT32,
|
|
||||||
paddle::DataType::INT32,
|
|
||||||
paddle::DataType::INT32,
|
|
||||||
paddle::DataType::FLOAT32};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -376,8 +262,7 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
|||||||
*
|
*
|
||||||
* This operator performs the following key functions:
|
* This operator performs the following key functions:
|
||||||
* 1. Computes top-k experts for each input token based on gating scores
|
* 1. Computes top-k experts for each input token based on gating scores
|
||||||
* 2. Permutes input tokens according to their selected experts for efficient
|
* 2. Permutes input tokens according to their selected experts for efficient expert processing
|
||||||
* expert processing
|
|
||||||
* 3. Computes prefix sums of tokens per expert for group_gemm optimization
|
* 3. Computes prefix sums of tokens per expert for group_gemm optimization
|
||||||
*
|
*
|
||||||
* Inputs:
|
* Inputs:
|
||||||
@@ -387,17 +272,18 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
|||||||
* - gating_output: Gating network output scores for each token-expert pair
|
* - gating_output: Gating network output scores for each token-expert pair
|
||||||
* Shape: [total_tokens, expert_num]
|
* Shape: [total_tokens, expert_num]
|
||||||
* dtype: must be float32
|
* dtype: must be float32
|
||||||
* - gating_correction_bias: Optional bias term for gating correction
|
* - gating_correction_bias: Optional bias term for gating correction (expert_num)
|
||||||
* (expert_num)
|
|
||||||
*
|
*
|
||||||
* Outputs:
|
* Outputs:
|
||||||
* - permute_input: Permuted input tensor organized by expert
|
* - permute_input: Permuted input tensor organized by expert
|
||||||
* Shape: [moe_topk * total_tokens, hidden_size]
|
* Shape: [moe_topk * total_tokens, hidden_size]
|
||||||
* dtype: Same as input
|
* dtype: Same as input
|
||||||
* - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for
|
* - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm
|
||||||
* group_gemm Shape: [expert_num] dtype: int64
|
* Shape: [expert_num]
|
||||||
* - permute_indices_per_token: Indices mapping for reconstructing original
|
* dtype: int64
|
||||||
* order Shape: [moe_topk, total_tokens] dtype: int32
|
* - permute_indices_per_token: Indices mapping for reconstructing original order
|
||||||
|
* Shape: [moe_topk, total_tokens]
|
||||||
|
* dtype: int32
|
||||||
* - top_k_weight: Weight coefficients for combining expert outputs
|
* - top_k_weight: Weight coefficients for combining expert outputs
|
||||||
* Shape: [total_tokens, moe_topk]
|
* Shape: [total_tokens, moe_topk]
|
||||||
* dtype: float32
|
* dtype: float32
|
||||||
@@ -406,8 +292,7 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
|||||||
* dtype: int32
|
* dtype: int32
|
||||||
*
|
*
|
||||||
* Attributes:
|
* Attributes:
|
||||||
* - moe_topk: Number of experts to select for each token (k value in top-k
|
* - moe_topk: Number of experts to select for each token (k value in top-k routing)
|
||||||
* routing)
|
|
||||||
* - group_moe: Whether to perform group softmax within the operator
|
* - group_moe: Whether to perform group softmax within the operator
|
||||||
* (true: softmax is computed within groups of experts,
|
* (true: softmax is computed within groups of experts,
|
||||||
* false: standard softmax across all experts)
|
* false: standard softmax across all experts)
|
||||||
@@ -421,21 +306,13 @@ std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
|||||||
* - When group_moe is true, expert_num must be divisible by moe_topk
|
* - When group_moe is true, expert_num must be divisible by moe_topk
|
||||||
*/
|
*/
|
||||||
PD_BUILD_STATIC_OP(moe_expert_dispatch)
|
PD_BUILD_STATIC_OP(moe_expert_dispatch)
|
||||||
.Inputs({"input",
|
.Inputs({"input", "gating_output",
|
||||||
"gating_output",
|
|
||||||
paddle::Optional("gating_correction_bias"),
|
paddle::Optional("gating_correction_bias"),
|
||||||
paddle::Optional("w4a8_in_scale")})
|
paddle::Optional("w4a8_in_scale")})
|
||||||
.Outputs({"permute_input",
|
.Outputs({"permute_input", "tokens_expert_prefix_sum",
|
||||||
"tokens_expert_prefix_sum",
|
"permute_indices_per_token", "topk_weight", "topk_idx",
|
||||||
"permute_indices_per_token",
|
"expert_idx_per_token"})
|
||||||
"topk_weight",
|
.Attrs({"moe_topk:int", "group_moe:bool", "moe_quant_type:std::string", "topk_only_mode:bool"})
|
||||||
"topk_idx",
|
|
||||||
"expert_idx_per_token",
|
|
||||||
"dequant_scale"})
|
|
||||||
.Attrs({"moe_topk:int",
|
|
||||||
"group_moe:bool",
|
|
||||||
"moe_quant_type:std::string",
|
|
||||||
"topk_only_mode:bool"})
|
|
||||||
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
|
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
|
||||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
|
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
|
||||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));
|
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));
|
||||||
|
|||||||
@@ -17,7 +17,6 @@
|
|||||||
#include "cutlass/numeric_conversion.h"
|
#include "cutlass/numeric_conversion.h"
|
||||||
#include "group_swiglu_with_masked.h"
|
#include "group_swiglu_with_masked.h"
|
||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
#include "moe/fast_hardmard/fast_hardamard_kernel.h"
|
|
||||||
#include "moe/fused_moe_helper.h"
|
#include "moe/fused_moe_helper.h"
|
||||||
|
|
||||||
template <typename DataT,
|
template <typename DataT,
|
||||||
|
|||||||
@@ -18,8 +18,8 @@
|
|||||||
#include "cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h"
|
#include "cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h"
|
||||||
#include "group_swiglu_with_masked.h"
|
#include "group_swiglu_with_masked.h"
|
||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
#include "moe/fast_hardmard/fast_hardamard_kernel.h"
|
|
||||||
#include "moe/fused_moe_helper.h"
|
#include "moe/fused_moe_helper.h"
|
||||||
|
#include "moe/moe_fast_hardamard_kernel.h"
|
||||||
#include "swigluoai.h"
|
#include "swigluoai.h"
|
||||||
#include "w4afp8_gemm/w4afp8_gemm.h"
|
#include "w4afp8_gemm/w4afp8_gemm.h"
|
||||||
|
|
||||||
@@ -28,7 +28,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
|||||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||||
const paddle::Tensor& up_gate_proj_weight,
|
const paddle::Tensor& up_gate_proj_weight,
|
||||||
const paddle::Tensor& down_proj_weight,
|
const paddle::Tensor& down_proj_weight,
|
||||||
const paddle::optional<paddle::Tensor>& up_proj_in_scale,
|
|
||||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||||
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||||
@@ -198,20 +197,31 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
|||||||
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
|
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
|
||||||
typedef typename traits_fp8::DataType DataType_fp8;
|
typedef typename traits_fp8::DataType DataType_fp8;
|
||||||
typedef typename traits_fp8::data_t data_t_fp8;
|
typedef typename traits_fp8::data_t data_t_fp8;
|
||||||
paddle::Tensor weight_scale_tensor =
|
|
||||||
*const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr());
|
Allocator::AllocationPtr ffn1_input_row_sum;
|
||||||
const int weight_scale_group_size = weight_scale_tensor.dims().size() == 2
|
ffn1_input_row_sum =
|
||||||
? hidden_size
|
allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
|
||||||
: weight_scale_tensor.dims()[3];
|
|
||||||
const float* input_dequant_scale =
|
compute_row_sum(
|
||||||
up_proj_in_scale ? up_proj_in_scale.get().data<float>() : nullptr;
|
permute_input.data<data_t_fp8>(),
|
||||||
|
expanded_active_expert_rows,
|
||||||
|
hidden_size,
|
||||||
|
reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
|
||||||
|
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||||
|
num_max_tokens_per_expert,
|
||||||
|
used_in_ep_low_latency,
|
||||||
|
stream);
|
||||||
|
|
||||||
|
float* row_scale = nullptr;
|
||||||
DisPatchW4AFp8GemmWrapper(
|
DisPatchW4AFp8GemmWrapper(
|
||||||
reinterpret_cast<const DataType_fp8*>(permute_input.data<data_t_fp8>()),
|
reinterpret_cast<const DataType_fp8*>(permute_input.data<data_t_fp8>()),
|
||||||
reinterpret_cast<const DataType_fp8*>(
|
reinterpret_cast<const DataType_fp8*>(
|
||||||
up_gate_proj_weight.data<int8_t>()),
|
up_gate_proj_weight.data<int8_t>()),
|
||||||
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||||
input_dequant_scale,
|
reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
|
||||||
weight_scale_tensor.data<float>(),
|
row_scale,
|
||||||
|
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
|
||||||
|
->data<float>(),
|
||||||
reinterpret_cast<NvType*>(fc1_out),
|
reinterpret_cast<NvType*>(fc1_out),
|
||||||
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
|
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
|
||||||
used_in_ep_low_latency ? num_max_tokens_per_expert
|
used_in_ep_low_latency ? num_max_tokens_per_expert
|
||||||
@@ -219,7 +229,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
|||||||
num_experts,
|
num_experts,
|
||||||
inter_size,
|
inter_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
weight_scale_group_size,
|
|
||||||
stream);
|
stream);
|
||||||
} else {
|
} else {
|
||||||
typename cutlass::WintQuantTraits<
|
typename cutlass::WintQuantTraits<
|
||||||
@@ -346,84 +355,60 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
|||||||
} else if (quant_method == "w4afp8") {
|
} else if (quant_method == "w4afp8") {
|
||||||
data_t* ffn2_shift = nullptr;
|
data_t* ffn2_shift = nullptr;
|
||||||
data_t* ffn2_smooth = nullptr;
|
data_t* ffn2_smooth = nullptr;
|
||||||
float* input_dequant_scale = nullptr;
|
float* row_scale = nullptr;
|
||||||
Allocator::AllocationPtr fp8_act_out;
|
Allocator::AllocationPtr fp8_act_out;
|
||||||
fp8_act_out = allocator->Allocate(SizeOf(paddle::DataType::INT8) *
|
fp8_act_out = allocator->Allocate(SizeOf(paddle::DataType::INT8) *
|
||||||
act_out_tensor.numel());
|
act_out_tensor.numel());
|
||||||
|
Allocator::AllocationPtr ffn2_input_row_sum;
|
||||||
|
ffn2_input_row_sum =
|
||||||
|
allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
|
||||||
|
|
||||||
if (down_proj_in_scale) {
|
// note(yuanxiaolan): optimize this
|
||||||
MoeFastHardamardWrapper<data_t, data_t_fp8>(
|
MoeFastHardamardWrapper<data_t, data_t>(
|
||||||
act_out_tensor.data<data_t>(),
|
act_out_tensor.data<data_t>(),
|
||||||
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
|
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
|
||||||
: nullptr,
|
: nullptr,
|
||||||
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||||
ffn2_shift,
|
ffn2_shift, // ffn2_shift->data<T>(),
|
||||||
ffn2_smooth,
|
ffn2_smooth, // ffn2_smooth->data<T>(),
|
||||||
down_proj_in_scale
|
nullptr,
|
||||||
? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())
|
1,
|
||||||
->data<float>()
|
448.0f,
|
||||||
: nullptr,
|
-448.0f,
|
||||||
1,
|
expanded_active_expert_rows,
|
||||||
448.0f,
|
inter_size / 2,
|
||||||
-448.0f,
|
num_max_tokens_per_expert,
|
||||||
expanded_active_expert_rows,
|
used_in_ep_low_latency,
|
||||||
inter_size / 2,
|
hadamard_block_size,
|
||||||
num_max_tokens_per_expert,
|
act_out_tensor.data<data_t>(),
|
||||||
used_in_ep_low_latency,
|
stream);
|
||||||
hadamard_block_size,
|
|
||||||
reinterpret_cast<data_t_fp8*>(fp8_act_out->ptr()),
|
|
||||||
stream);
|
|
||||||
} else {
|
|
||||||
Allocator::AllocationPtr ffn2_input_dequant_scale;
|
|
||||||
ffn2_input_dequant_scale =
|
|
||||||
allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
|
|
||||||
input_dequant_scale =
|
|
||||||
reinterpret_cast<float*>(ffn2_input_dequant_scale->ptr());
|
|
||||||
MoeFastHardamardWrapper<data_t, data_t>(
|
|
||||||
act_out_tensor.data<data_t>(),
|
|
||||||
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
|
|
||||||
: nullptr,
|
|
||||||
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
|
||||||
ffn2_shift, // ffn2_shift->data<T>(),
|
|
||||||
ffn2_smooth, // ffn2_smooth->data<T>(),
|
|
||||||
nullptr,
|
|
||||||
1,
|
|
||||||
448.0f,
|
|
||||||
-448.0f,
|
|
||||||
expanded_active_expert_rows,
|
|
||||||
inter_size / 2,
|
|
||||||
num_max_tokens_per_expert,
|
|
||||||
used_in_ep_low_latency,
|
|
||||||
hadamard_block_size,
|
|
||||||
act_out_tensor.data<data_t>(),
|
|
||||||
stream);
|
|
||||||
|
|
||||||
quantize_moe_input<data_t, data_t_fp8>(
|
quantize_moe_input<data_t, data_t_fp8>(
|
||||||
act_out_tensor.data<data_t>(),
|
act_out_tensor.data<data_t>(),
|
||||||
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
|
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
|
||||||
: nullptr,
|
: nullptr,
|
||||||
expanded_active_expert_rows,
|
down_proj_in_scale
|
||||||
inter_size / 2,
|
? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())
|
||||||
input_dequant_scale,
|
->data<float>()
|
||||||
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
: nullptr,
|
||||||
num_max_tokens_per_expert,
|
448.0f,
|
||||||
used_in_ep_low_latency,
|
-448.0f,
|
||||||
reinterpret_cast<data_t_fp8*>(fp8_act_out->ptr()),
|
expanded_active_expert_rows,
|
||||||
stream);
|
inter_size / 2,
|
||||||
}
|
reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
|
||||||
|
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||||
paddle::Tensor weight_scale_tensor =
|
num_max_tokens_per_expert,
|
||||||
*const_cast<paddle::Tensor*>(down_proj_scale.get_ptr());
|
used_in_ep_low_latency,
|
||||||
const int weight_scale_group_size = weight_scale_tensor.dims().size() == 2
|
reinterpret_cast<data_t_fp8*>(fp8_act_out->ptr()),
|
||||||
? inter_size / 2
|
stream);
|
||||||
: weight_scale_tensor.dims()[3];
|
|
||||||
|
|
||||||
DisPatchW4AFp8GemmWrapper(
|
DisPatchW4AFp8GemmWrapper(
|
||||||
reinterpret_cast<const DataType_fp8*>(fp8_act_out->ptr()),
|
reinterpret_cast<const DataType_fp8*>(fp8_act_out->ptr()),
|
||||||
reinterpret_cast<const DataType_fp8*>(down_proj_weight.data<int8_t>()),
|
reinterpret_cast<const DataType_fp8*>(down_proj_weight.data<int8_t>()),
|
||||||
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||||
input_dequant_scale,
|
reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
|
||||||
weight_scale_tensor.data<float>(),
|
row_scale,
|
||||||
|
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())->data<float>(),
|
||||||
reinterpret_cast<NvType*>(ffn_out_data),
|
reinterpret_cast<NvType*>(ffn_out_data),
|
||||||
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
|
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
|
||||||
used_in_ep_low_latency ? num_max_tokens_per_expert
|
used_in_ep_low_latency ? num_max_tokens_per_expert
|
||||||
@@ -431,7 +416,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
|||||||
num_experts,
|
num_experts,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
inter_size / 2,
|
inter_size / 2,
|
||||||
weight_scale_group_size,
|
|
||||||
stream);
|
stream);
|
||||||
} else {
|
} else {
|
||||||
typename cutlass::WintQuantTraits<
|
typename cutlass::WintQuantTraits<
|
||||||
@@ -458,7 +442,6 @@ paddle::Tensor MoeExpertFFNFunc(
|
|||||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||||
const paddle::Tensor& up_gate_proj_weight,
|
const paddle::Tensor& up_gate_proj_weight,
|
||||||
const paddle::Tensor& down_proj_weight,
|
const paddle::Tensor& down_proj_weight,
|
||||||
const paddle::optional<paddle::Tensor>& up_proj_in_scale,
|
|
||||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||||
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||||
@@ -483,7 +466,6 @@ paddle::Tensor MoeExpertFFNFunc(
|
|||||||
tokens_expert_prefix_sum,
|
tokens_expert_prefix_sum,
|
||||||
up_gate_proj_weight,
|
up_gate_proj_weight,
|
||||||
down_proj_weight,
|
down_proj_weight,
|
||||||
up_proj_in_scale,
|
|
||||||
up_gate_proj_bias,
|
up_gate_proj_bias,
|
||||||
up_gate_proj_scale,
|
up_gate_proj_scale,
|
||||||
down_proj_scale,
|
down_proj_scale,
|
||||||
@@ -501,7 +483,6 @@ paddle::Tensor MoeExpertFFNFunc(
|
|||||||
tokens_expert_prefix_sum,
|
tokens_expert_prefix_sum,
|
||||||
up_gate_proj_weight,
|
up_gate_proj_weight,
|
||||||
down_proj_weight,
|
down_proj_weight,
|
||||||
up_proj_in_scale,
|
|
||||||
up_gate_proj_bias,
|
up_gate_proj_bias,
|
||||||
up_gate_proj_scale,
|
up_gate_proj_scale,
|
||||||
down_proj_scale,
|
down_proj_scale,
|
||||||
@@ -525,7 +506,6 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
|||||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||||
const paddle::Tensor& up_gate_proj_weight,
|
const paddle::Tensor& up_gate_proj_weight,
|
||||||
const paddle::Tensor& down_proj_weight,
|
const paddle::Tensor& down_proj_weight,
|
||||||
const paddle::optional<paddle::Tensor>& up_proj_in_scale,
|
|
||||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||||
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||||
@@ -540,7 +520,6 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
|||||||
tokens_expert_prefix_sum,
|
tokens_expert_prefix_sum,
|
||||||
up_gate_proj_weight,
|
up_gate_proj_weight,
|
||||||
down_proj_weight,
|
down_proj_weight,
|
||||||
up_proj_in_scale,
|
|
||||||
up_gate_proj_bias,
|
up_gate_proj_bias,
|
||||||
up_gate_proj_scale,
|
up_gate_proj_scale,
|
||||||
down_proj_scale,
|
down_proj_scale,
|
||||||
@@ -558,7 +537,6 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
|||||||
const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
|
const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
|
||||||
const std::vector<int64_t>& up_gate_proj_weight_shape,
|
const std::vector<int64_t>& up_gate_proj_weight_shape,
|
||||||
const std::vector<int64_t>& down_proj_weight_shape,
|
const std::vector<int64_t>& down_proj_weight_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& up_proj_in_scale_shape,
|
|
||||||
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
|
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
|
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
|
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
|
||||||
@@ -577,7 +555,6 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
|
|||||||
const paddle::DataType& tokens_expert_prefix_sum_dtype,
|
const paddle::DataType& tokens_expert_prefix_sum_dtype,
|
||||||
const paddle::DataType& up_gate_proj_weight_dtype,
|
const paddle::DataType& up_gate_proj_weight_dtype,
|
||||||
const paddle::DataType& down_proj_weight_dtype,
|
const paddle::DataType& down_proj_weight_dtype,
|
||||||
const paddle::optional<paddle::DataType>& up_proj_in_scale_dtype,
|
|
||||||
const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
|
const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
|
||||||
const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
|
const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
|
||||||
const paddle::optional<paddle::DataType>& down_proj_scale_dtype,
|
const paddle::optional<paddle::DataType>& down_proj_scale_dtype,
|
||||||
@@ -655,7 +632,6 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
|
|||||||
"tokens_expert_prefix_sum",
|
"tokens_expert_prefix_sum",
|
||||||
"up_gate_proj_weight",
|
"up_gate_proj_weight",
|
||||||
"down_proj_weight",
|
"down_proj_weight",
|
||||||
paddle::Optional("up_proj_in_scale"),
|
|
||||||
paddle::Optional("up_gate_proj_bias"),
|
paddle::Optional("up_gate_proj_bias"),
|
||||||
paddle::Optional("up_gate_proj_scale"),
|
paddle::Optional("up_gate_proj_scale"),
|
||||||
paddle::Optional("down_proj_scale"),
|
paddle::Optional("down_proj_scale"),
|
||||||
|
|||||||
@@ -23,142 +23,132 @@
|
|||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
|
|
||||||
template <int kStages,
|
template <int kStages, class GemmType, class OutputType, class SmemLayoutA,
|
||||||
class GemmType,
|
class SmemLayoutB, class SmemLayoutC>
|
||||||
class OutputType,
|
|
||||||
class SmemLayoutA,
|
|
||||||
class SmemLayoutB,
|
|
||||||
class SmemLayoutC,
|
|
||||||
class SmemLayoutScale>
|
|
||||||
struct SharedStorage {
|
struct SharedStorage {
|
||||||
union {
|
union {
|
||||||
struct {
|
struct {
|
||||||
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutA>> smem_a;
|
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutA>> smem_a;
|
||||||
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutB>> smem_b;
|
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutB>> smem_b;
|
||||||
cute::array_aligned<float, cute::cosize_v<SmemLayoutScale>> smem_scale;
|
};
|
||||||
|
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c;
|
||||||
};
|
};
|
||||||
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct {
|
struct {
|
||||||
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline;
|
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline;
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
template <int kBlockM_,
|
template<int kBlockM_, int kBlockN_, int kBlockK_,
|
||||||
int kBlockN_,
|
int kNWarps_, int kStages_,
|
||||||
int kBlockK_,
|
int kTiles_, int M_,
|
||||||
int kNWarps_,
|
int TokenPackSize_,
|
||||||
int kStages_,
|
int TAIL_N_ = 0,
|
||||||
int kTiles_,
|
int kClusterM_ = 1,
|
||||||
int M_,
|
typename elem_type=cutlass::float_e4m3_t,
|
||||||
int K_,
|
typename OutputType = cutlass::bfloat16_t>
|
||||||
int TokenPackSize_,
|
|
||||||
int WeightScaleGroup_,
|
|
||||||
int kClusterM_ = 1,
|
|
||||||
typename elem_type = cutlass::float_e4m3_t,
|
|
||||||
typename OutputType = cutlass::bfloat16_t>
|
|
||||||
struct Kernel_traits {
|
struct Kernel_traits {
|
||||||
using Element = elem_type;
|
using Element = elem_type;
|
||||||
using ElementOutput = OutputType;
|
using ElementAccum = float;
|
||||||
using ElementAccum = typename std::
|
using ElementOutput = OutputType;
|
||||||
conditional_t<WeightScaleGroup_ == K_, float, cutlass::half_t>;
|
static_assert(cutlass::sizeof_bits_v<Element> == 8);
|
||||||
static_assert(cutlass::sizeof_bits_v<Element> == 8);
|
|
||||||
|
|
||||||
static constexpr int kNWarps = kNWarps_;
|
static constexpr int kNWarps = kNWarps_;
|
||||||
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
|
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
|
||||||
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
|
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
|
||||||
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
|
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
|
||||||
|
|
||||||
static_assert(kNWarps_ == 12 || kNWarps_ == 16);
|
static_assert(kNWarps_ == 12 || kNWarps_ == 16);
|
||||||
|
|
||||||
static constexpr int kBlockM = kBlockM_;
|
static constexpr int kBlockM = kBlockM_;
|
||||||
static constexpr int kBlockN = kBlockN_;
|
static constexpr int kBlockN = kBlockN_;
|
||||||
static constexpr int kBlockK = kBlockK_;
|
static constexpr int kBlockK = kBlockK_;
|
||||||
static constexpr int kTiles = kTiles_;
|
static constexpr int kTiles = kTiles_;
|
||||||
static constexpr int TokenPackSize = TokenPackSize_;
|
static constexpr int TokenPackSize = TokenPackSize_;
|
||||||
static constexpr int M = M_;
|
static constexpr int M = M_;
|
||||||
static constexpr int K = K_;
|
static constexpr int TAIL_N = TAIL_N_;
|
||||||
static constexpr int WeightScaleGroup = WeightScaleGroup_;
|
|
||||||
|
|
||||||
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>;
|
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>;
|
||||||
|
using TileShape_MNK_TAIL = Shape<Int<kBlockM>, Int<TAIL_N>, Int<kBlockK>>;
|
||||||
|
|
||||||
static constexpr int kClusterM = kClusterM_;
|
static constexpr int kClusterM = kClusterM_;
|
||||||
using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
|
using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
|
||||||
|
|
||||||
static constexpr int kStages = kStages_;
|
static constexpr int kStages = kStages_;
|
||||||
static_assert(kStages > 1);
|
static_assert(kStages > 1);
|
||||||
|
|
||||||
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
|
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
|
||||||
|
|
||||||
using TiledMma = decltype(cute::make_tiled_mma(
|
using TiledMma = decltype(cute::make_tiled_mma(
|
||||||
cute::GMMA::
|
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
|
||||||
rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
|
AtomLayoutMNK{}));
|
||||||
AtomLayoutMNK{}));
|
|
||||||
|
|
||||||
using SmemLayoutAtomA =
|
using TiledMma_TAIL = decltype(cute::make_tiled_mma(
|
||||||
decltype(cutlass::gemm::collective::detail::rs_smem_selector<
|
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK_TAIL>(),
|
||||||
GMMA::Major::K,
|
AtomLayoutMNK{}));
|
||||||
Element,
|
|
||||||
Int<kBlockM>,
|
|
||||||
Int<kBlockK / 2>>());
|
|
||||||
|
|
||||||
using SmemLayoutA = decltype(tile_to_shape(
|
using SmemLayoutAtomA = decltype(
|
||||||
SmemLayoutAtomA{},
|
cutlass::gemm::collective::detail::rs_smem_selector<
|
||||||
make_shape(Int<kBlockM>{}, Int<kBlockK / 2>{}, Int<kStages>{})));
|
GMMA::Major::K, Element, Int<kBlockM>, Int<kBlockK / 2>>());
|
||||||
|
|
||||||
using SmemLayoutAtomB =
|
using SmemLayoutA = decltype(
|
||||||
decltype(cutlass::gemm::collective::detail::rs_smem_selector<
|
tile_to_shape(SmemLayoutAtomA{},
|
||||||
GMMA::Major::K,
|
make_shape(Int<kBlockM>{}, Int<kBlockK / 2>{}, Int<kStages>{})));
|
||||||
Element,
|
|
||||||
decltype(cute::get<1>(TileShape_MNK{})),
|
|
||||||
decltype(cute::get<2>(TileShape_MNK{}))>());
|
|
||||||
|
|
||||||
using SmemLayoutB =
|
using SmemLayoutAtomB = decltype(
|
||||||
decltype(tile_to_shape(SmemLayoutAtomB{},
|
cutlass::gemm::collective::detail::rs_smem_selector<
|
||||||
make_shape(shape<1>(TileShape_MNK{}),
|
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})),
|
||||||
shape<2>(TileShape_MNK{}),
|
decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||||
Int<kStages>{})));
|
|
||||||
using SmemLayoutAtomC =
|
|
||||||
decltype(cutlass::gemm::collective::detail::rs_smem_selector<
|
|
||||||
GMMA::Major::K,
|
|
||||||
ElementOutput,
|
|
||||||
decltype(cute::get<0>(TileShape_MNK{})),
|
|
||||||
decltype(cute::get<1>(TileShape_MNK{}))>());
|
|
||||||
|
|
||||||
using SmemLayoutC =
|
using SmemLayoutB = decltype(
|
||||||
decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));
|
tile_to_shape(SmemLayoutAtomB{},
|
||||||
|
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
|
||||||
|
|
||||||
using SmemCopyAtomAB = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
|
using SmemLayoutAtomB_TAIL = decltype(
|
||||||
using SmemCopyAtomC = Copy_Atom<cute::SM90_U32x4_STSM_N, ElementOutput>;
|
cutlass::gemm::collective::detail::rs_smem_selector<
|
||||||
|
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})),
|
||||||
|
decltype(cute::get<2>(TileShape_MNK_TAIL{}))>());
|
||||||
|
|
||||||
using SmemLayoutScale = Layout<Shape<Int<kBlockM>, Int<kStages>>>;
|
using SmemLayoutB_TAIL = decltype(
|
||||||
|
tile_to_shape(SmemLayoutAtomB_TAIL{},
|
||||||
|
make_shape(
|
||||||
|
shape<1>(TileShape_MNK_TAIL{}),
|
||||||
|
shape<2>(TileShape_MNK_TAIL{}),
|
||||||
|
Int<kStages>{})
|
||||||
|
));
|
||||||
|
|
||||||
using SharedStorage = SharedStorage<kStages,
|
using SmemLayoutAtomC = decltype(
|
||||||
Element,
|
cutlass::gemm::collective::detail::rs_smem_selector<
|
||||||
ElementOutput,
|
GMMA::Major::K, ElementOutput,
|
||||||
SmemLayoutA,
|
decltype(cute::get<0>(TileShape_MNK{})),
|
||||||
SmemLayoutB,
|
decltype(cute::get<1>(TileShape_MNK{}))>());
|
||||||
SmemLayoutC,
|
|
||||||
SmemLayoutScale>;
|
|
||||||
|
|
||||||
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
|
using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));
|
||||||
using PipelineState = typename cutlass::PipelineState<kStages>;
|
|
||||||
|
|
||||||
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>);
|
using SmemCopyAtomAB = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
|
||||||
static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;
|
using SmemCopyAtomC = Copy_Atom<cute::SM90_U32x4_STSM_N, ElementOutput>;
|
||||||
// static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
|
|
||||||
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
|
using SharedStorage = SharedStorage<
|
||||||
using TiledCopyCAtom =
|
kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutB, SmemLayoutC>;
|
||||||
cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, OutputType>;
|
|
||||||
using TiledCopyCThrLayout = decltype(cute::make_layout(
|
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
|
||||||
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
|
using PipelineState = typename cutlass::PipelineState<kStages>;
|
||||||
LayoutRight{}));
|
|
||||||
using TiledCopyCValLayout = decltype(cute::make_layout(
|
|
||||||
cute::make_shape(_1{}, Int<kNumVecElem>{}), LayoutRight{}));
|
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>);
|
||||||
using TiledCopyC =
|
static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;
|
||||||
decltype(make_tiled_copy(TiledCopyCAtom{},
|
// static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
|
||||||
TiledCopyCThrLayout{}, // Thr layout
|
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
|
||||||
TiledCopyCValLayout{} // Val layout
|
using TiledCopyCAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, OutputType>;
|
||||||
));
|
using TiledCopyCThrLayout = decltype(cute::make_layout(
|
||||||
|
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
|
||||||
|
LayoutRight{}));
|
||||||
|
using TiledCopyCValLayout = decltype(cute::make_layout(
|
||||||
|
cute::make_shape(_1{}, Int<kNumVecElem>{}),
|
||||||
|
LayoutRight{}));
|
||||||
|
using TiledCopyC = decltype(make_tiled_copy(
|
||||||
|
TiledCopyCAtom{},
|
||||||
|
TiledCopyCThrLayout{}, // Thr layout
|
||||||
|
TiledCopyCValLayout{} // Val layout
|
||||||
|
));
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -14,10 +14,10 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cutlass/array.h>
|
|
||||||
#include <cutlass/cutlass.h>
|
#include <cutlass/cutlass.h>
|
||||||
#include <cutlass/numeric_conversion.h>
|
#include <cutlass/array.h>
|
||||||
#include <cutlass/numeric_types.h>
|
#include <cutlass/numeric_types.h>
|
||||||
|
#include <cutlass/numeric_conversion.h>
|
||||||
#include "cutlass/pipeline/pipeline.hpp"
|
#include "cutlass/pipeline/pipeline.hpp"
|
||||||
|
|
||||||
#include "cute/tensor.hpp"
|
#include "cute/tensor.hpp"
|
||||||
@@ -27,544 +27,368 @@
|
|||||||
// #include "named_barrier.hpp"
|
// #include "named_barrier.hpp"
|
||||||
#include "utils.hpp"
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
template <typename Ktraits>
|
template <typename Ktraits>
|
||||||
struct CollectiveMainloopFwd {
|
struct CollectiveMainloopFwd {
|
||||||
using Element = typename Ktraits::Element;
|
|
||||||
using ElementOutput = typename Ktraits::ElementOutput;
|
|
||||||
using TileShape_MNK = typename Ktraits::TileShape_MNK;
|
|
||||||
using ClusterShape = typename Ktraits::ClusterShape_MNK;
|
|
||||||
using ElementAccum = typename Ktraits::ElementAccum;
|
|
||||||
|
|
||||||
static constexpr int kStages = Ktraits::kStages;
|
using Element = typename Ktraits::Element;
|
||||||
static constexpr int kBlockM = Ktraits::kBlockM;
|
using ElementOutput = typename Ktraits::ElementOutput;
|
||||||
static constexpr int kBlockN = Ktraits::kBlockN;
|
using TileShape_MNK = typename Ktraits::TileShape_MNK;
|
||||||
static constexpr int kBlockK = Ktraits::kBlockK;
|
using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL;
|
||||||
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
|
using ClusterShape = typename Ktraits::ClusterShape_MNK;
|
||||||
static constexpr int kTiles = Ktraits::kTiles;
|
using ElementAccum = typename Ktraits::ElementAccum;
|
||||||
static constexpr int M = Ktraits::M;
|
|
||||||
static constexpr int K = Ktraits::K;
|
|
||||||
static constexpr int TokenPackSize = Ktraits::TokenPackSize;
|
|
||||||
static constexpr int WeightScaleGroup = Ktraits::WeightScaleGroup;
|
|
||||||
|
|
||||||
using GmemTiledCopy = cute::SM90_TMA_LOAD;
|
static constexpr int kStages = Ktraits::kStages;
|
||||||
|
static constexpr int kBlockM = Ktraits::kBlockM;
|
||||||
|
static constexpr int kBlockN = Ktraits::kBlockN;
|
||||||
|
static constexpr int TAIL_N = Ktraits::TAIL_N;
|
||||||
|
static constexpr int kBlockK = Ktraits::kBlockK;
|
||||||
|
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
|
||||||
|
static constexpr int kTiles = Ktraits::kTiles;
|
||||||
|
static constexpr int M = Ktraits::M;
|
||||||
|
static constexpr int TokenPackSize = Ktraits::TokenPackSize;
|
||||||
|
|
||||||
using SmemLayoutA = typename Ktraits::SmemLayoutA;
|
using GmemTiledCopy = cute::SM90_TMA_LOAD;
|
||||||
using SmemLayoutB = typename Ktraits::SmemLayoutB;
|
|
||||||
using SmemLayoutC = typename Ktraits::SmemLayoutC;
|
|
||||||
using SmemLayoutScale = typename Ktraits::SmemLayoutScale;
|
|
||||||
|
|
||||||
using ShapeT = cute::Shape<int64_t, int64_t, int64_t>;
|
|
||||||
using StrideT = cute::Shape<int64_t, _1, int64_t>;
|
|
||||||
using LayoutT = cute::Layout<ShapeT, StrideT>;
|
|
||||||
|
|
||||||
using ShapeTScale = cute::Shape<int64_t, int64_t, int64_t>;
|
using SmemLayoutA = typename Ktraits::SmemLayoutA;
|
||||||
using StrideTScale = cute::Shape<_1, int64_t, int64_t>;
|
using SmemLayoutB = typename Ktraits::SmemLayoutB;
|
||||||
using LayoutTScale = cute::Layout<ShapeTScale, StrideTScale>;
|
using SmemLayoutC = typename Ktraits::SmemLayoutC;
|
||||||
|
using SmemLayoutB_TAIL = typename Ktraits::SmemLayoutB_TAIL;
|
||||||
|
|
||||||
using TMA_A = decltype(make_tma_copy(
|
using ShapeT = cute::Shape<int64_t, int64_t, int64_t>;
|
||||||
GmemTiledCopy{},
|
using StrideT = cute::Shape<int64_t, _1, int64_t>;
|
||||||
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)),
|
using LayoutT = cute::Layout<ShapeT, StrideT>;
|
||||||
ShapeT{},
|
|
||||||
StrideT{}),
|
|
||||||
SmemLayoutA{}(_, _, _0{}),
|
|
||||||
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
|
|
||||||
size<0>(ClusterShape{})));
|
|
||||||
|
|
||||||
using TMA_B = decltype(make_tma_copy(
|
using TMA_A = decltype(make_tma_copy(
|
||||||
GmemTiledCopy{},
|
GmemTiledCopy{},
|
||||||
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)),
|
make_tensor(
|
||||||
ShapeT{},
|
make_gmem_ptr(static_cast<Element const*>(nullptr)),
|
||||||
StrideT{}),
|
ShapeT{},
|
||||||
take<0, 2>(SmemLayoutB{}),
|
StrideT{}
|
||||||
select<1, 2>(TileShape_MNK{}),
|
),
|
||||||
size<0>(ClusterShape{})));
|
SmemLayoutA{}(_, _, _0{}),
|
||||||
|
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
|
||||||
|
size<0>(ClusterShape{})));
|
||||||
|
|
||||||
using TMA_Scale = decltype(make_tma_copy(
|
using TMA_B = decltype(make_tma_copy(
|
||||||
GmemTiledCopy{},
|
GmemTiledCopy{},
|
||||||
make_tensor(make_gmem_ptr(static_cast<float const*>(nullptr)),
|
make_tensor(
|
||||||
ShapeTScale{},
|
make_gmem_ptr(static_cast<Element const*>(nullptr)),
|
||||||
StrideTScale{}),
|
ShapeT{},
|
||||||
SmemLayoutScale{}(_, _0{}),
|
StrideT{}
|
||||||
select<0>(Shape<Int<kBlockM>>{}),
|
),
|
||||||
size<0>(ClusterShape{})));
|
take<0, 2>(SmemLayoutB{}),
|
||||||
|
select<1, 2>(TileShape_MNK{}),
|
||||||
|
size<0>(ClusterShape{})));
|
||||||
|
|
||||||
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
|
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
|
||||||
using MainloopPipeline = typename Ktraits::MainloopPipeline;
|
using MainloopPipeline = typename Ktraits::MainloopPipeline;
|
||||||
using PipelineParams = typename MainloopPipeline::Params;
|
using PipelineParams = typename MainloopPipeline::Params;
|
||||||
using PipelineState = typename MainloopPipeline::PipelineState;
|
using PipelineState = typename MainloopPipeline::PipelineState;
|
||||||
using SmemCopyAtomAB = typename Ktraits::SmemCopyAtomAB;
|
using SmemCopyAtomAB = typename Ktraits::SmemCopyAtomAB;
|
||||||
using SmemCopyAtomC = typename Ktraits::SmemCopyAtomC;
|
using SmemCopyAtomC = typename Ktraits::SmemCopyAtomC;
|
||||||
using TiledCopyC = typename Ktraits::TiledCopyC;
|
using TiledCopyC = typename Ktraits::TiledCopyC;
|
||||||
|
|
||||||
static constexpr uint32_t TmaTransactionBytesA = static_cast<uint32_t>(
|
static constexpr uint32_t TmaTransactionBytesA = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v<Element> / 8);
|
||||||
size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v<Element> / 8);
|
static constexpr uint32_t TmaTransactionBytesB = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v<Element> / 8);
|
||||||
static constexpr uint32_t TmaTransactionBytesB = static_cast<uint32_t>(
|
|
||||||
size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v<Element> / 8);
|
|
||||||
static constexpr uint32_t TmaTransactionBytesScale = static_cast<uint32_t>(
|
|
||||||
size(SmemLayoutScale{}(_, _0{})) * cutlass::sizeof_bits_v<float> / 8);
|
|
||||||
|
|
||||||
struct Arguments {
|
struct Arguments {
|
||||||
Element const* ptr_A;
|
Element const* ptr_A;
|
||||||
LayoutT layout_A;
|
LayoutT layout_A;
|
||||||
Element const* ptr_B;
|
Element const* ptr_B;
|
||||||
LayoutT layout_B;
|
LayoutT layout_B;
|
||||||
ElementOutput* ptr_C;
|
ElementOutput * ptr_C;
|
||||||
LayoutT layout_C;
|
LayoutT layout_C;
|
||||||
const float* weight_scale;
|
const float *weight_scale;
|
||||||
LayoutTScale layout_Scale;
|
const float *input_row_sum;
|
||||||
const float* input_scale;
|
const int64_t * tokens;
|
||||||
const int64_t* tokens;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct Params {
|
|
||||||
LayoutT layout_A;
|
|
||||||
LayoutT layout_B;
|
|
||||||
LayoutTScale layout_Scale;
|
|
||||||
TMA_A tma_load_A;
|
|
||||||
TMA_B tma_load_B;
|
|
||||||
TMA_Scale tma_load_Scale;
|
|
||||||
ElementOutput* ptr_C;
|
|
||||||
const float* weight_scale;
|
|
||||||
const float* input_scale;
|
|
||||||
const int64_t* tokens;
|
|
||||||
};
|
|
||||||
|
|
||||||
Params static to_underlying_arguments(Arguments const& args) {
|
|
||||||
Tensor mA = make_tensor(make_gmem_ptr(args.ptr_A), args.layout_A);
|
|
||||||
TMA_A tma_load_A =
|
|
||||||
make_tma_copy(GmemTiledCopy{},
|
|
||||||
mA,
|
|
||||||
SmemLayoutA{}(_, _, _0{}),
|
|
||||||
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
|
|
||||||
size<0>(ClusterShape{}));
|
|
||||||
Tensor mB = make_tensor(make_gmem_ptr(args.ptr_B), args.layout_B);
|
|
||||||
TMA_B tma_load_B = make_tma_copy(GmemTiledCopy{},
|
|
||||||
mB,
|
|
||||||
SmemLayoutB{}(_, _, _0{}),
|
|
||||||
select<1, 2>(TileShape_MNK{}),
|
|
||||||
size<0>(ClusterShape{}));
|
|
||||||
Tensor mScale =
|
|
||||||
make_tensor(make_gmem_ptr(args.weight_scale), args.layout_Scale);
|
|
||||||
TMA_Scale tma_load_Scale = make_tma_copy(GmemTiledCopy{},
|
|
||||||
mScale,
|
|
||||||
SmemLayoutScale{}(_, _0{}),
|
|
||||||
select<0>(Shape<Int<kBlockM>>{}),
|
|
||||||
size<0>(ClusterShape{}));
|
|
||||||
|
|
||||||
return {args.layout_A,
|
|
||||||
args.layout_B,
|
|
||||||
args.layout_Scale,
|
|
||||||
tma_load_A,
|
|
||||||
tma_load_B,
|
|
||||||
tma_load_Scale,
|
|
||||||
args.ptr_C,
|
|
||||||
args.weight_scale,
|
|
||||||
args.input_scale,
|
|
||||||
args.tokens};
|
|
||||||
}
|
|
||||||
|
|
||||||
CUTLASS_DEVICE
|
|
||||||
static void prefetch_tma_descriptors(Params const& mainloop_params) {
|
|
||||||
cute::prefetch_tma_descriptor(
|
|
||||||
mainloop_params.tma_load_A.get_tma_descriptor());
|
|
||||||
cute::prefetch_tma_descriptor(
|
|
||||||
mainloop_params.tma_load_B.get_tma_descriptor());
|
|
||||||
if constexpr (WeightScaleGroup < K) {
|
|
||||||
cute::prefetch_tma_descriptor(
|
|
||||||
mainloop_params.tma_load_Scale.get_tma_descriptor());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
|
|
||||||
CUTLASS_DEVICE void store(Params const& mainloop_params,
|
|
||||||
FrgTensorO& tOrO,
|
|
||||||
SharedStorage& shared_storage,
|
|
||||||
TiledMma tiled_mma,
|
|
||||||
const float* weight_scale,
|
|
||||||
const float* input_scale,
|
|
||||||
const int64_t tokens,
|
|
||||||
const int64_t pre_fix_tokens,
|
|
||||||
const int bidm,
|
|
||||||
const int bidn,
|
|
||||||
const int bidb,
|
|
||||||
const int tidx) {
|
|
||||||
using packHalf = typename PackedHalf<ElementOutput>::Type;
|
|
||||||
Tensor tOrO_out = make_tensor<ElementOutput>(tOrO.layout());
|
|
||||||
|
|
||||||
if (input_scale != nullptr) {
|
|
||||||
const int lane_id = tidx % 4 * 2;
|
|
||||||
if constexpr (WeightScaleGroup == K) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < size(tOrO); i += 4) {
|
|
||||||
const int scale_idx = i * 2 + lane_id;
|
|
||||||
tOrO[i] = tOrO[i] * weight_scale[0] * input_scale[scale_idx];
|
|
||||||
tOrO[i + 1] =
|
|
||||||
tOrO[i + 1] * weight_scale[0] * input_scale[scale_idx + 1];
|
|
||||||
tOrO[i + 2] = tOrO[i + 2] * weight_scale[1] * input_scale[scale_idx];
|
|
||||||
tOrO[i + 3] =
|
|
||||||
tOrO[i + 3] * weight_scale[1] * input_scale[scale_idx + 1];
|
|
||||||
*reinterpret_cast<packHalf*>(&tOrO_out[i]) =
|
|
||||||
packHalf(tOrO[i], tOrO[i + 2]);
|
|
||||||
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) =
|
|
||||||
packHalf(tOrO[i + 1], tOrO[i + 3]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < size(tOrO); i += 4) {
|
|
||||||
const int scale_idx = i * 2 + lane_id;
|
|
||||||
*reinterpret_cast<packHalf*>(&tOrO_out[i]) =
|
|
||||||
packHalf(float(tOrO[i]) * input_scale[scale_idx],
|
|
||||||
float(tOrO[i + 2]) * input_scale[scale_idx]);
|
|
||||||
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) =
|
|
||||||
packHalf(float(tOrO[i + 1]) * input_scale[scale_idx + 1],
|
|
||||||
float(tOrO[i + 3]) * input_scale[scale_idx + 1]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if constexpr (WeightScaleGroup == K) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < size(tOrO); i += 4) {
|
|
||||||
tOrO[i] = (tOrO[i]) * weight_scale[0];
|
|
||||||
tOrO[i + 1] = tOrO[i + 1] * weight_scale[0];
|
|
||||||
tOrO[i + 2] = tOrO[i + 2] * weight_scale[1];
|
|
||||||
tOrO[i + 3] = tOrO[i + 3] * weight_scale[1];
|
|
||||||
*reinterpret_cast<packHalf*>(&tOrO_out[i]) =
|
|
||||||
packHalf(tOrO[i], tOrO[i + 2]);
|
|
||||||
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) =
|
|
||||||
packHalf(tOrO[i + 1], tOrO[i + 3]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < size(tOrO); i += 4) {
|
|
||||||
*reinterpret_cast<packHalf*>(&tOrO_out[i]) =
|
|
||||||
packHalf(float(tOrO[i]), float(tOrO[i + 2]));
|
|
||||||
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) =
|
|
||||||
packHalf(float(tOrO[i + 1]), float(tOrO[i + 3]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
uint16_t* smem_c =
|
|
||||||
reinterpret_cast<uint16_t*>(shared_storage.smem_c.data());
|
|
||||||
|
|
||||||
uint32_t* reg_data = reinterpret_cast<uint32_t*>(tOrO_out.data());
|
|
||||||
|
|
||||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
|
|
||||||
|
|
||||||
constexpr int k_copy_times = kBlockN / 16;
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < k_copy_times; i++) {
|
|
||||||
uint32_t smem_ptr = cast_smem_ptr_to_uint(
|
|
||||||
reinterpret_cast<uint128_t*>(smem_c + i * 16 * 128) + tidx);
|
|
||||||
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
|
|
||||||
asm volatile(
|
|
||||||
"stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, "
|
|
||||||
"%4};\n" ::"r"(smem_ptr),
|
|
||||||
"r"(reg_data[4 * i + 0]),
|
|
||||||
"r"(reg_data[4 * i + 2]),
|
|
||||||
"r"(reg_data[4 * i + 1]),
|
|
||||||
"r"(reg_data[4 * i + 3]));
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
|
|
||||||
const int expert_idx =
|
|
||||||
TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize;
|
|
||||||
ElementOutput* store_c = mainloop_params.ptr_C + expert_idx +
|
|
||||||
bidn * (M * kBlockN) + bidm * kBlockM;
|
|
||||||
|
|
||||||
const int reamin_tokens = tokens - bidn * kBlockN;
|
|
||||||
|
|
||||||
const int col = tidx % 2;
|
|
||||||
|
|
||||||
constexpr int kPackSize = 16 / sizeof(ElementOutput);
|
|
||||||
constexpr int kNumVecElem = kBlockM / kPackSize;
|
|
||||||
constexpr int copy_len = kBlockN * kNumVecElem;
|
|
||||||
#pragma unroll
|
|
||||||
for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) {
|
|
||||||
const int idx_div2 = idx / 2;
|
|
||||||
const int store_idx = idx_div2 / 128 * 128 + idx_div2 % 8 * 16 +
|
|
||||||
idx_div2 % 128 / 16 + idx_div2 % 16 / 8 * 8;
|
|
||||||
const int store_global_idx = store_idx * 2 + col;
|
|
||||||
const int row = store_global_idx / kNumVecElem;
|
|
||||||
const int col = store_global_idx % kNumVecElem;
|
|
||||||
if (row >= reamin_tokens) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const int offset = row * (M / kPackSize) + col;
|
|
||||||
reinterpret_cast<uint4*>(store_c)[offset] =
|
|
||||||
reinterpret_cast<uint4*>(smem_c)[idx];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename MTensor>
|
|
||||||
CUTLASS_DEVICE auto get_local_no_packed_tensor(const MTensor& mB,
|
|
||||||
const int pre_fix_token,
|
|
||||||
const int actual_token,
|
|
||||||
const int bidn) const {
|
|
||||||
auto g_tensor = domain_offset(make_coord(pre_fix_token, _0{}), mB(_, _, 0));
|
|
||||||
|
|
||||||
Tensor gB = local_tile(
|
|
||||||
g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
|
|
||||||
return gB;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SharedStorage>
|
|
||||||
CUTLASS_DEVICE void load(Params const& mainloop_params,
|
|
||||||
MainloopPipeline pipeline,
|
|
||||||
PipelineState& smem_pipe_write,
|
|
||||||
SharedStorage& shared_storage,
|
|
||||||
const int tokens,
|
|
||||||
const int pre_fix_tokens,
|
|
||||||
const int bidm,
|
|
||||||
const int bidn,
|
|
||||||
const int bidb,
|
|
||||||
const int tidx) {
|
|
||||||
Tensor sA =
|
|
||||||
make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
|
|
||||||
Tensor sB =
|
|
||||||
make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
|
|
||||||
Tensor sScale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()),
|
|
||||||
SmemLayoutScale{});
|
|
||||||
|
|
||||||
Tensor mA = mainloop_params.tma_load_A.get_tma_tensor(
|
|
||||||
mainloop_params.layout_A.shape());
|
|
||||||
Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(
|
|
||||||
mainloop_params.layout_B.shape());
|
|
||||||
Tensor mScale = mainloop_params.tma_load_Scale.get_tma_tensor(
|
|
||||||
mainloop_params.layout_Scale.shape());
|
|
||||||
|
|
||||||
Tensor gA =
|
|
||||||
local_tile(mA(_, _, bidb),
|
|
||||||
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
|
|
||||||
make_coord(bidm, _));
|
|
||||||
Tensor gScale = local_tile(
|
|
||||||
mScale(_, bidm, bidb), select<0>(Shape<Int<kBlockM>>{}), make_coord(_));
|
|
||||||
|
|
||||||
auto [tAgA, tAsA] = tma_partition(mainloop_params.tma_load_A,
|
|
||||||
_0{},
|
|
||||||
Layout<ClusterShape>{},
|
|
||||||
group_modes<0, 2>(sA),
|
|
||||||
group_modes<0, 2>(gA));
|
|
||||||
|
|
||||||
if constexpr (TokenPackSize == 0) {
|
|
||||||
Tensor gB = get_local_no_packed_tensor(mB, pre_fix_tokens, tokens, bidn);
|
|
||||||
|
|
||||||
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B,
|
|
||||||
_0{},
|
|
||||||
Layout<ClusterShape>{},
|
|
||||||
group_modes<0, 2>(sB),
|
|
||||||
group_modes<0, 2>(gB));
|
|
||||||
|
|
||||||
if (tidx == 0) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int kiter = 0; kiter < kTiles; ++kiter) {
|
|
||||||
pipeline.producer_acquire(smem_pipe_write);
|
|
||||||
copy(mainloop_params.tma_load_A.with(
|
|
||||||
*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
|
||||||
tAgA(_, kiter),
|
|
||||||
tAsA(_, smem_pipe_write.index()));
|
|
||||||
|
|
||||||
copy(mainloop_params.tma_load_B.with(
|
|
||||||
*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
|
||||||
tBgB(_, kiter),
|
|
||||||
tBsB(_, smem_pipe_write.index()));
|
|
||||||
|
|
||||||
if constexpr (WeightScaleGroup < K) {
|
|
||||||
copy(mainloop_params.tma_load_Scale.with(
|
|
||||||
*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
|
||||||
gScale(_, kiter),
|
|
||||||
sScale(_, smem_pipe_write.index()));
|
|
||||||
}
|
|
||||||
|
|
||||||
++smem_pipe_write;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
auto mB_this_expert = make_tensor(
|
|
||||||
mB(_, _, bidb).data(),
|
|
||||||
make_layout(cute::make_shape(tokens, size<1>(mB)), mB.stride()));
|
|
||||||
Tensor gB = local_tile(
|
|
||||||
mB_this_expert, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
|
|
||||||
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B,
|
|
||||||
_0{},
|
|
||||||
Layout<ClusterShape>{},
|
|
||||||
group_modes<0, 2>(sB),
|
|
||||||
group_modes<0, 2>(gB));
|
|
||||||
|
|
||||||
if (tidx == 0) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int kiter = 0; kiter < kTiles; ++kiter) {
|
|
||||||
pipeline.producer_acquire(smem_pipe_write);
|
|
||||||
copy(mainloop_params.tma_load_A.with(
|
|
||||||
*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
|
||||||
tAgA(_, kiter),
|
|
||||||
tAsA(_, smem_pipe_write.index()));
|
|
||||||
|
|
||||||
copy(mainloop_params.tma_load_B.with(
|
|
||||||
*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
|
||||||
tBgB(_, kiter),
|
|
||||||
tBsB(_, smem_pipe_write.index()));
|
|
||||||
|
|
||||||
if constexpr (WeightScaleGroup < K) {
|
|
||||||
copy(mainloop_params.tma_load_Scale.with(
|
|
||||||
*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
|
||||||
gScale(_, kiter),
|
|
||||||
sScale(_, smem_pipe_write.index()));
|
|
||||||
}
|
|
||||||
|
|
||||||
++smem_pipe_write;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
|
|
||||||
CUTLASS_DEVICE void mma(Params const& mainloop_params,
|
|
||||||
TiledMma tiled_mma,
|
|
||||||
MainloopPipeline pipeline,
|
|
||||||
PipelineState& smem_pipe_read,
|
|
||||||
SharedStorage& shared_storage,
|
|
||||||
FrgTensorO& tSrS,
|
|
||||||
const int tidx) {
|
|
||||||
Tensor sA =
|
|
||||||
make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
|
|
||||||
Tensor sB =
|
|
||||||
make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
|
|
||||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
|
||||||
|
|
||||||
auto threadMma = tiled_mma.get_thread_slice(tidx);
|
|
||||||
|
|
||||||
auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomAB{}, tiled_mma);
|
|
||||||
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx);
|
|
||||||
|
|
||||||
Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0));
|
|
||||||
Tensor tSrB = threadMma.partition_fragment_B(sB);
|
|
||||||
|
|
||||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
|
||||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
|
||||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
|
||||||
};
|
|
||||||
#pragma unroll
|
|
||||||
for (int kiter = 0; kiter < kTiles; ++kiter) {
|
|
||||||
Tensor tSsA =
|
|
||||||
smem_thr_copy_A.partition_S(sA(_, _, smem_pipe_read.index()));
|
|
||||||
consumer_wait(pipeline, smem_pipe_read);
|
|
||||||
gemm</*wg_wait=*/0>(tiled_mma,
|
|
||||||
tSrA,
|
|
||||||
tSsA,
|
|
||||||
tSrB(_, _, _, smem_pipe_read.index()),
|
|
||||||
tSrS,
|
|
||||||
smem_tiled_copy_A,
|
|
||||||
smem_thr_copy_A);
|
|
||||||
pipeline.consumer_release(smem_pipe_read);
|
|
||||||
++smem_pipe_read;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
|
|
||||||
CUTLASS_DEVICE void mma_pipeline(Params const& mainloop_params,
|
|
||||||
TiledMma tiled_mma,
|
|
||||||
MainloopPipeline pipeline,
|
|
||||||
PipelineState& smem_pipe_read,
|
|
||||||
SharedStorage& shared_storage,
|
|
||||||
FrgTensorO& tSrS,
|
|
||||||
const int tidx) {
|
|
||||||
Tensor sA =
|
|
||||||
make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
|
|
||||||
Tensor sB =
|
|
||||||
make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
|
|
||||||
float2* weight_scale =
|
|
||||||
reinterpret_cast<float2*>(shared_storage.smem_scale.data()) + tidx / 4;
|
|
||||||
|
|
||||||
Tensor tSrS1 = make_fragment_like(tSrS);
|
|
||||||
Tensor tSrS2 = make_fragment_like(tSrS);
|
|
||||||
|
|
||||||
__half2* tSrS_data =
|
|
||||||
reinterpret_cast<__half2*>(raw_pointer_cast(tSrS.data()));
|
|
||||||
__half2* tSrS1_data =
|
|
||||||
reinterpret_cast<__half2*>(raw_pointer_cast(tSrS1.data()));
|
|
||||||
__half2* tSrS2_data =
|
|
||||||
reinterpret_cast<__half2*>(raw_pointer_cast(tSrS2.data()));
|
|
||||||
|
|
||||||
auto threadMma = tiled_mma.get_thread_slice(tidx);
|
|
||||||
|
|
||||||
auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomAB{}, tiled_mma);
|
|
||||||
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx);
|
|
||||||
|
|
||||||
Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0));
|
|
||||||
Tensor tSrB = threadMma.partition_fragment_B(sB);
|
|
||||||
|
|
||||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
|
||||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
|
||||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
__half2 scale1, scale2, scale3, scale4;
|
struct Params {
|
||||||
float2 scale_cur_k;
|
LayoutT layout_A;
|
||||||
#pragma unroll
|
LayoutT layout_B;
|
||||||
for (int kiter = 0; kiter < kTiles;) {
|
TMA_A tma_load_A;
|
||||||
Tensor tSsA1 =
|
TMA_B tma_load_B;
|
||||||
smem_thr_copy_A.partition_S(sA(_, _, smem_pipe_read.index()));
|
ElementOutput * ptr_C;
|
||||||
consumer_wait(pipeline, smem_pipe_read);
|
const float *weight_scale;
|
||||||
scale_cur_k = *(weight_scale + smem_pipe_read.index() * (kBlockM / 2));
|
const float *input_row_sum;
|
||||||
scale1 = __half2(scale_cur_k.x, scale_cur_k.x);
|
const int64_t * tokens;
|
||||||
scale2 = __half2(scale_cur_k.y, scale_cur_k.y);
|
};
|
||||||
|
|
||||||
gemm</*wg_wait=*/0>(tiled_mma,
|
|
||||||
tSrA,
|
|
||||||
tSsA1,
|
|
||||||
tSrB(_, _, _, smem_pipe_read.index()),
|
|
||||||
tSrS1,
|
|
||||||
smem_tiled_copy_A,
|
|
||||||
smem_thr_copy_A);
|
|
||||||
pipeline.consumer_release(smem_pipe_read);
|
|
||||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
|
||||||
|
|
||||||
if (kiter > 0) {
|
Params static
|
||||||
for (int i = 0; i < size(tSrS) / 2; i += 2) {
|
to_underlying_arguments(Arguments const& args) {
|
||||||
tSrS_data[i] = __hfma2(tSrS2_data[i], scale3, tSrS_data[i]);
|
Tensor mA = make_tensor(make_gmem_ptr(args.ptr_A), args.layout_A);
|
||||||
tSrS_data[i + 1] =
|
TMA_A tma_load_A = make_tma_copy(
|
||||||
__hfma2(tSrS2_data[i + 1], scale4, tSrS_data[i + 1]);
|
GmemTiledCopy{},
|
||||||
|
mA,
|
||||||
|
SmemLayoutA{}(_, _, _0{}),
|
||||||
|
select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}),
|
||||||
|
size<0>(ClusterShape{}));
|
||||||
|
Tensor mB = make_tensor(make_gmem_ptr(args.ptr_B), args.layout_B);
|
||||||
|
TMA_B tma_load_B = make_tma_copy(
|
||||||
|
GmemTiledCopy{},
|
||||||
|
mB,
|
||||||
|
SmemLayoutB{}(_, _, _0{}),
|
||||||
|
select<1, 2>(TileShape_MNK{}),
|
||||||
|
size<0>(ClusterShape{}));
|
||||||
|
|
||||||
|
return {args.layout_A, args.layout_B, tma_load_A, tma_load_B,
|
||||||
|
args.ptr_C, args.weight_scale, args.input_row_sum, args.tokens};
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
static void prefetch_tma_descriptors(Params const& mainloop_params) {
|
||||||
|
cute::prefetch_tma_descriptor(mainloop_params.tma_load_A.get_tma_descriptor());
|
||||||
|
cute::prefetch_tma_descriptor(mainloop_params.tma_load_B.get_tma_descriptor());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int CUR_N, typename SharedStorage, typename FrgTensorO, typename TiledMma>
|
||||||
|
CUTLASS_DEVICE void
|
||||||
|
store(Params const& mainloop_params,
|
||||||
|
FrgTensorO & tOrO,
|
||||||
|
SharedStorage& shared_storage,
|
||||||
|
TiledMma tiled_mma,
|
||||||
|
const float *input_row_sum,
|
||||||
|
const float *weight_scale,
|
||||||
|
const int64_t tokens,
|
||||||
|
const int64_t pre_fix_tokens,
|
||||||
|
const int bidm,
|
||||||
|
const int bidn,
|
||||||
|
const int bidb,
|
||||||
|
const int tidx) {
|
||||||
|
|
||||||
|
using packHalf = typename PackedHalf<ElementOutput>::Type;
|
||||||
|
Tensor tOrO_out = make_tensor<ElementOutput>(tOrO.layout());
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < size(tOrO); i+=4) {
|
||||||
|
const int sum_idx = i * 2;
|
||||||
|
tOrO[i] = (tOrO[i] + input_row_sum[sum_idx]) * weight_scale[0];
|
||||||
|
tOrO[i + 1] = (tOrO[i + 1] + input_row_sum[sum_idx + 1]) * weight_scale[0];
|
||||||
|
tOrO[i + 2] = (tOrO[i + 2] + input_row_sum[sum_idx]) * weight_scale[1];
|
||||||
|
tOrO[i + 3] = (tOrO[i + 3] + input_row_sum[sum_idx + 1]) * weight_scale[1];
|
||||||
|
*reinterpret_cast<packHalf*>(&tOrO_out[i]) = packHalf(tOrO[i], tOrO[i + 2]);
|
||||||
|
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) = packHalf(tOrO[i + 1], tOrO[i + 3]);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
++smem_pipe_read;
|
uint16_t *smem_c = reinterpret_cast<uint16_t *>(shared_storage.smem_c.data());
|
||||||
++kiter;
|
|
||||||
|
|
||||||
if (kiter < kTiles) {
|
uint32_t * reg_data = reinterpret_cast<uint32_t*>(tOrO_out.data());
|
||||||
Tensor tSsA2 =
|
|
||||||
smem_thr_copy_A.partition_S(sA(_, _, smem_pipe_read.index()));
|
|
||||||
consumer_wait(pipeline, smem_pipe_read);
|
|
||||||
scale_cur_k = *(weight_scale + smem_pipe_read.index() * (kBlockM / 2));
|
|
||||||
scale3 = __half2(scale_cur_k.x, scale_cur_k.x);
|
|
||||||
scale4 = __half2(scale_cur_k.y, scale_cur_k.y);
|
|
||||||
|
|
||||||
gemm</*wg_wait=*/0>(tiled_mma,
|
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
|
||||||
tSrA,
|
|
||||||
tSsA2,
|
|
||||||
tSrB(_, _, _, smem_pipe_read.index()),
|
|
||||||
tSrS2,
|
|
||||||
smem_tiled_copy_A,
|
|
||||||
smem_thr_copy_A);
|
|
||||||
pipeline.consumer_release(smem_pipe_read);
|
|
||||||
++smem_pipe_read;
|
|
||||||
++kiter;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < size(tSrS) / 2; i += 2) {
|
constexpr int k_copy_times = CUR_N / 16;
|
||||||
tSrS_data[i] = __hfma2(tSrS1_data[i], scale1, tSrS_data[i]);
|
|
||||||
tSrS_data[i + 1] = __hfma2(tSrS1_data[i + 1], scale2, tSrS_data[i + 1]);
|
#pragma unroll
|
||||||
}
|
for (int i = 0; i < k_copy_times; i++) {
|
||||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
uint32_t smem_ptr = cast_smem_ptr_to_uint(reinterpret_cast<uint128_t*>(smem_c + i * 16 * 128) + tidx);
|
||||||
|
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
|
||||||
|
asm volatile (
|
||||||
|
"stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
|
||||||
|
:: "r"(smem_ptr), "r"(reg_data[4 * i + 0]), "r"(reg_data[4 * i + 2]), "r"(reg_data[4 * i + 1]), "r"(reg_data[4 * i + 3]));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
|
||||||
|
const int batch_idx = TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize;
|
||||||
|
ElementOutput * store_c = mainloop_params.ptr_C + batch_idx + bidn * (M * kBlockN) + bidm * kBlockM;
|
||||||
|
|
||||||
|
const int reamin_tokens = tokens - bidn * kBlockN;
|
||||||
|
|
||||||
|
const int col = tidx % 2;
|
||||||
|
|
||||||
|
constexpr int kPackSize = 16 / sizeof(ElementOutput);
|
||||||
|
constexpr int kNumVecElem = kBlockM / kPackSize;
|
||||||
|
constexpr int copy_len = CUR_N * kNumVecElem;
|
||||||
|
#pragma unroll
|
||||||
|
for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) {
|
||||||
|
const int idx_div2 = idx / 2;
|
||||||
|
const int store_idx = idx_div2 / 128 * 128 + idx_div2 % 8 * 16 + idx_div2 % 128 / 16 + idx_div2 % 16 / 8 * 8;
|
||||||
|
const int store_global_idx = store_idx * 2 + col;
|
||||||
|
const int row = store_global_idx / kNumVecElem;
|
||||||
|
const int col = store_global_idx % kNumVecElem;
|
||||||
|
if (row >= reamin_tokens) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const int offset = row * (M / kPackSize) + col;
|
||||||
|
reinterpret_cast<uint4*>(store_c)[offset] = reinterpret_cast<uint4*>(smem_c)[idx];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if constexpr (kTiles % 2 == 0) {
|
|
||||||
for (int i = 0; i < size(tSrS) / 2; i += 2) {
|
template <typename MTensor>
|
||||||
tSrS_data[i] = __hfma2(tSrS2_data[i], scale3, tSrS_data[i]);
|
CUTLASS_DEVICE auto get_local_no_packed_tensor(
|
||||||
tSrS_data[i + 1] = __hfma2(tSrS2_data[i + 1], scale4, tSrS_data[i + 1]);
|
const MTensor &mB,
|
||||||
}
|
const int pre_fix_token,
|
||||||
|
const int actual_token,
|
||||||
|
const int bidn) const {
|
||||||
|
|
||||||
|
auto g_tensor = domain_offset(make_coord(pre_fix_token, _0{}), mB(_, _, 0));
|
||||||
|
|
||||||
|
Tensor gB = local_tile(g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
|
||||||
|
return gB;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SharedStorage>
|
||||||
|
CUTLASS_DEVICE void
|
||||||
|
load(Params const& mainloop_params,
|
||||||
|
MainloopPipeline pipeline,
|
||||||
|
PipelineState& smem_pipe_write,
|
||||||
|
SharedStorage &shared_storage,
|
||||||
|
const int tokens,
|
||||||
|
const int pre_fix_tokens,
|
||||||
|
const int bidm,
|
||||||
|
const int bidn,
|
||||||
|
const int bidb,
|
||||||
|
const int tidx) {
|
||||||
|
|
||||||
|
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
|
||||||
|
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
|
||||||
|
|
||||||
|
Tensor mA = mainloop_params.tma_load_A.get_tma_tensor(mainloop_params.layout_A.shape());
|
||||||
|
Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(mainloop_params.layout_B.shape());
|
||||||
|
|
||||||
|
Tensor gA = local_tile(mA(_, _, bidb), select<0, 1>(Shape<Int<kBlockM>, Int<kBlockK / 2>>{}), make_coord(bidm, _));
|
||||||
|
|
||||||
|
auto [tAgA, tAsA] = tma_partition(mainloop_params.tma_load_A, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sA), group_modes<0, 2>(gA));
|
||||||
|
|
||||||
|
const int kIters = kTiles / kStages;
|
||||||
|
|
||||||
|
if constexpr (TokenPackSize == 0) {
|
||||||
|
Tensor gB = get_local_no_packed_tensor(
|
||||||
|
mB,
|
||||||
|
pre_fix_tokens,
|
||||||
|
tokens,
|
||||||
|
bidn);
|
||||||
|
|
||||||
|
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));
|
||||||
|
|
||||||
|
if (tidx == 0) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int kiter = 0; kiter < kIters; ++kiter) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int s = 0; s < kStages; s++) {
|
||||||
|
const int i = kiter * kStages + s;
|
||||||
|
pipeline.producer_acquire(smem_pipe_write);
|
||||||
|
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
||||||
|
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
|
||||||
|
|
||||||
|
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
||||||
|
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
|
||||||
|
++smem_pipe_write;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = kIters * kStages; i < kTiles; ++i) {
|
||||||
|
pipeline.producer_acquire(smem_pipe_write);
|
||||||
|
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
||||||
|
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
|
||||||
|
|
||||||
|
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
||||||
|
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
|
||||||
|
++smem_pipe_write;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto mB_this_batch = make_tensor(
|
||||||
|
mB(_, _, bidb).data(),
|
||||||
|
make_layout(
|
||||||
|
cute::make_shape(tokens, size<1>(mB)),
|
||||||
|
mB.stride()
|
||||||
|
));
|
||||||
|
Tensor gB = local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
|
||||||
|
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));
|
||||||
|
|
||||||
|
if (tidx == 0) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int kiter = 0; kiter < kIters; ++kiter) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int s = 0; s < kStages; s++) {
|
||||||
|
const int i = kiter * kStages + s;
|
||||||
|
pipeline.producer_acquire(smem_pipe_write);
|
||||||
|
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
||||||
|
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
|
||||||
|
|
||||||
|
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
||||||
|
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
|
||||||
|
++smem_pipe_write;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = kIters * kStages; i < kTiles; ++i) {
|
||||||
|
pipeline.producer_acquire(smem_pipe_write);
|
||||||
|
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
||||||
|
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
|
||||||
|
|
||||||
|
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
|
||||||
|
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
|
||||||
|
++smem_pipe_write;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int CUR_N, typename SharedStorage, typename FrgTensorO, typename TiledMma>
|
||||||
|
CUTLASS_DEVICE void
|
||||||
|
mma(Params const& mainloop_params,
|
||||||
|
TiledMma tiled_mma,
|
||||||
|
MainloopPipeline pipeline,
|
||||||
|
PipelineState& smem_pipe_read,
|
||||||
|
SharedStorage& shared_storage,
|
||||||
|
FrgTensorO &tSrS,
|
||||||
|
const int tidx) {
|
||||||
|
|
||||||
|
using sMemBLayout = std::conditional_t<
|
||||||
|
CUR_N == kBlockN,
|
||||||
|
SmemLayoutB,
|
||||||
|
SmemLayoutB_TAIL
|
||||||
|
>;
|
||||||
|
|
||||||
|
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
|
||||||
|
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), sMemBLayout{});
|
||||||
|
|
||||||
|
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||||
|
|
||||||
|
auto threadMma = tiled_mma.get_thread_slice(tidx);
|
||||||
|
|
||||||
|
auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomAB{}, tiled_mma);
|
||||||
|
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx);
|
||||||
|
|
||||||
|
Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0));
|
||||||
|
Tensor tSrB = threadMma.partition_fragment_B(sB);
|
||||||
|
|
||||||
|
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||||
|
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||||
|
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||||
|
};
|
||||||
|
|
||||||
|
const int kIters = kTiles / kStages;
|
||||||
|
|
||||||
|
constexpr int B_STEPS = CUR_N == 0 ? 1 : (kBlockN / CUR_N);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int kiter = 0; kiter < kIters; ++kiter) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int s = 0; s < kStages; s++) {
|
||||||
|
Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, s));
|
||||||
|
consumer_wait(pipeline, smem_pipe_read);
|
||||||
|
gemm</*wg_wait=*/0>(tiled_mma, tSrA, tSsA, tSrB(_, _, _, s * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A);
|
||||||
|
pipeline.consumer_release(smem_pipe_read);
|
||||||
|
++smem_pipe_read;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < kTiles % kStages; ++i) {
|
||||||
|
Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, i));
|
||||||
|
consumer_wait(pipeline, smem_pipe_read);
|
||||||
|
|
||||||
|
gemm</*wg_wait=*/0>(tiled_mma, tSrA, tSsA, tSrB(_, _, _, i * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A);
|
||||||
|
pipeline.consumer_release(smem_pipe_read);
|
||||||
|
++smem_pipe_read;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -24,116 +24,91 @@
|
|||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include <cute/arch/cluster_sm90.hpp> // For cute::elect_one_sync()
|
|
||||||
#include <cute/tensor.hpp>
|
#include <cute/tensor.hpp>
|
||||||
|
#include <cute/arch/cluster_sm90.hpp> // For cute::elect_one_sync()
|
||||||
|
|
||||||
#include <cutlass/array.h>
|
#include <cutlass/array.h>
|
||||||
#include <cutlass/cutlass.h>
|
#include <cutlass/cutlass.h>
|
||||||
#include <cutlass/numeric_conversion.h>
|
#include <cutlass/numeric_conversion.h>
|
||||||
#include <cutlass/numeric_types.h>
|
#include <cutlass/numeric_types.h>
|
||||||
|
|
||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
|
|
||||||
template <typename T>
|
template<typename T>
|
||||||
struct PackedHalf;
|
struct PackedHalf;
|
||||||
|
|
||||||
template <>
|
template<>
|
||||||
struct PackedHalf<cutlass::half_t> {
|
struct PackedHalf<cutlass::half_t> {
|
||||||
using Type = __half2;
|
using Type = __half2;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template<>
|
||||||
struct PackedHalf<cutlass::bfloat16_t> {
|
struct PackedHalf<cutlass::bfloat16_t> {
|
||||||
using Type = nv_bfloat162;
|
using Type = nv_bfloat162;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
template <typename To_type, typename Engine, typename Layout>
|
template <typename To_type, typename Engine, typename Layout>
|
||||||
__forceinline__ __device__ auto convert_type(
|
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
|
||||||
Tensor<Engine, Layout> const &tensor) {
|
using From_type = typename Engine::value_type;
|
||||||
using From_type = typename Engine::value_type;
|
constexpr int numel = decltype(size(tensor))::value;
|
||||||
constexpr int numel = decltype(size(tensor))::value;
|
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
|
||||||
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
|
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
|
||||||
auto frag =
|
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
|
||||||
convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(
|
|
||||||
tensor.data()));
|
|
||||||
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int numel>
|
template <int numel>
|
||||||
__forceinline__ __device__ void convert_c4_2_fp8(const int32_t *src,
|
__forceinline__ __device__ void convert_c4_2_fp8(const int32_t * src, int32_t * dst1, int32_t * dst2) {
|
||||||
int32_t *dst1,
|
#pragma unroll
|
||||||
int32_t *dst2) {
|
for (int i = 0; i < numel; ++i) {
|
||||||
#pragma unroll
|
dst1[i] = (src[i] >> 4) & 0x0f0f0f0f;
|
||||||
for (int i = 0; i < numel; ++i) {
|
dst2[i] = src[i] & 0x0f0f0f0f;
|
||||||
uint32_t head1 = src[i] & 0x80808080;
|
|
||||||
dst1[i] = (src[i] >> 4) & 0x07070707;
|
|
||||||
dst1[i] = dst1[i] | head1;
|
|
||||||
uint32_t head2 = (src[i] & 0x08080808) << 4;
|
|
||||||
dst2[i] = src[i] & 0x07070707;
|
|
||||||
dst2[i] = dst2[i] | head2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int wg_wait = 0,
|
|
||||||
bool arrive = true,
|
|
||||||
bool commit = true,
|
|
||||||
typename Tensor0,
|
|
||||||
typename Tensor1,
|
|
||||||
typename Tensor2,
|
|
||||||
typename Tensor3,
|
|
||||||
typename TiledMma,
|
|
||||||
typename ThrCopyA,
|
|
||||||
typename TiledCopyA>
|
|
||||||
__forceinline__ __device__ void gemm(TiledMma &tiled_mma,
|
|
||||||
Tensor0 &tCrA,
|
|
||||||
Tensor1 &tCsA,
|
|
||||||
Tensor2 const &tCrB,
|
|
||||||
Tensor3 &tCrC,
|
|
||||||
TiledCopyA const &tiled_copy_A,
|
|
||||||
ThrCopyA const &thr_copy_A) {
|
|
||||||
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator,
|
|
||||||
typename TiledMma::FrgTypeA>::value;
|
|
||||||
Tensor tCrA1 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
|
|
||||||
Tensor tCrA2 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
|
|
||||||
if constexpr (Is_RS) {
|
|
||||||
warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
|
|
||||||
}
|
|
||||||
warpgroup_fence_operand(tCrC);
|
|
||||||
if constexpr (arrive) {
|
|
||||||
warpgroup_arrive();
|
|
||||||
}
|
|
||||||
constexpr int numel = decltype(size(tCrA(_, _, 0)))::value / 4;
|
|
||||||
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
|
|
||||||
cute::copy(tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
|
|
||||||
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
|
||||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
|
||||||
if (k_block < size<2>(tCrA) - 1) {
|
|
||||||
cute::copy(tiled_copy_A,
|
|
||||||
tCsA(_, _, k_block + 1),
|
|
||||||
tCrA_copy_view(_, _, k_block + 1));
|
|
||||||
}
|
}
|
||||||
int32_t *tCrA_data =
|
}
|
||||||
reinterpret_cast<int32_t *>(tCrA(_, _, k_block).data());
|
|
||||||
int32_t *tCrA1_data =
|
template <int wg_wait=0, bool arrive=true,
|
||||||
reinterpret_cast<int32_t *>(tCrA1(_, _, k_block).data());
|
bool commit=true, typename Tensor0, typename Tensor1,
|
||||||
int32_t *tCrA2_data =
|
typename Tensor2, typename Tensor3, typename TiledMma,
|
||||||
reinterpret_cast<int32_t *>(tCrA2(_, _, k_block).data());
|
typename ThrCopyA, typename TiledCopyA>
|
||||||
convert_c4_2_fp8<numel>(tCrA_data, tCrA1_data, tCrA2_data);
|
__forceinline__ __device__ void gemm(
|
||||||
|
TiledMma &tiled_mma,
|
||||||
cute::gemm(tiled_mma, tCrA1(_, _, k_block), tCrB(_, _, 2 * k_block), tCrC);
|
Tensor0 &tCrA,
|
||||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
Tensor1 &tCsA,
|
||||||
cute::gemm(
|
Tensor2 const &tCrB,
|
||||||
tiled_mma, tCrA2(_, _, k_block), tCrB(_, _, 2 * k_block + 1), tCrC);
|
Tensor3 &tCrC,
|
||||||
}
|
TiledCopyA const &tiled_copy_A,
|
||||||
if constexpr (commit) {
|
ThrCopyA const &thr_copy_A) {
|
||||||
warpgroup_commit_batch();
|
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
|
||||||
}
|
Tensor tCrA1 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
|
||||||
if constexpr (wg_wait >= 0) {
|
Tensor tCrA2 = make_tensor<cutlass::float_e4m3_t>(tCrA.layout());
|
||||||
warpgroup_wait<wg_wait>();
|
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
|
||||||
}
|
warpgroup_fence_operand(tCrC);
|
||||||
warpgroup_fence_operand(tCrC);
|
if constexpr (arrive) {
|
||||||
if constexpr (Is_RS) {
|
warpgroup_arrive();
|
||||||
warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
|
}
|
||||||
}
|
constexpr int numel = decltype(size(tCrA(_, _, 0)))::value / 4;
|
||||||
|
|
||||||
|
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
|
||||||
|
cute::copy(tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||||
|
if (k_block < size<2>(tCrA) - 1) {
|
||||||
|
cute::copy(tiled_copy_A, tCsA(_, _, k_block + 1), tCrA_copy_view(_, _, k_block + 1));
|
||||||
|
}
|
||||||
|
int32_t * tCrA_data = reinterpret_cast<int32_t *>(tCrA(_,_,k_block).data());
|
||||||
|
int32_t * tCrA1_data = reinterpret_cast<int32_t *>(tCrA1(_,_,k_block).data());
|
||||||
|
int32_t * tCrA2_data = reinterpret_cast<int32_t *>(tCrA2(_,_,k_block).data());
|
||||||
|
convert_c4_2_fp8<numel>(tCrA_data, tCrA1_data, tCrA2_data);
|
||||||
|
|
||||||
|
cute::gemm(tiled_mma, tCrA1(_,_,k_block), tCrB(_,_,2 * k_block), tCrC);
|
||||||
|
cute::gemm(tiled_mma, tCrA2(_,_,k_block), tCrB(_,_, 2 * k_block + 1), tCrC);
|
||||||
|
}
|
||||||
|
if constexpr (commit) {
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
}
|
||||||
|
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
|
||||||
|
warpgroup_fence_operand(tCrC);
|
||||||
|
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,179 +16,239 @@
|
|||||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "w4afp8_gemm.h"
|
|
||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
#include "paddle/extension.h"
|
#include "paddle/extension.h"
|
||||||
#include "w4afp8_gemm_template.h"
|
#include "w4afp8_gemm_template.h"
|
||||||
#include "weight_kernel.hpp"
|
#include "w4afp8_gemm.h"
|
||||||
#include "weight_scale_kernel.hpp"
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
class NVTraits;
|
|
||||||
|
|
||||||
template <>
|
void weight_convert(const uint8_t *weight, uint8_t *weight_new, int batch, int M, int K) {
|
||||||
class NVTraits<__nv_fp8_e4m3> {
|
assert(K % 64 == 0);
|
||||||
public:
|
for (int b = 0; b < batch; ++b) {
|
||||||
typedef cutlass::float_e4m3_t data_t;
|
for (int m = 0; m < M; ++m) {
|
||||||
|
for (int k = 0; k < K; k+=64) {
|
||||||
|
for (int k_inner = 0; k_inner < 32; ++k_inner) {
|
||||||
|
uint8_t temp = 0;
|
||||||
|
uint8_t left = weight[b * M * K + m * K + k + k_inner];
|
||||||
|
uint8_t right = weight[b * M * K + m * K + k + k_inner + 32];
|
||||||
|
temp |= left << 4;
|
||||||
|
temp |= right;
|
||||||
|
weight_new[b * M * K / 2 + m * K / 2 + k / 2 + k_inner] = *reinterpret_cast<uint8_t*>(&temp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T> class NVTraits;
|
||||||
|
|
||||||
|
template <> class NVTraits<__nv_fp8_e4m3> {
|
||||||
|
public:
|
||||||
|
typedef cutlass::float_e4m3_t data_t;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <> class NVTraits<__nv_bfloat16>{
|
||||||
class NVTraits<__nv_bfloat16> {
|
public:
|
||||||
public:
|
typedef cutlass::bfloat16_t data_t;
|
||||||
typedef cutlass::bfloat16_t data_t;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <> class NVTraits<half>{
|
||||||
class NVTraits<half> {
|
public:
|
||||||
public:
|
typedef cutlass::half_t data_t;
|
||||||
typedef cutlass::half_t data_t;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename OutputType>
|
template <typename OutputType>
|
||||||
void DisPatchW4AFp8Gemm(const cutlass::float_e4m3_t* input,
|
void DisPatchW4AFp8Gemm(
|
||||||
const cutlass::float_e4m3_t* weight,
|
const cutlass::float_e4m3_t* input,
|
||||||
const int64_t* tokens,
|
const cutlass::float_e4m3_t* weight,
|
||||||
const float* weight_scale,
|
const int64_t * tokens,
|
||||||
const float* input_dequant_scale,
|
const float * input_row_sum,
|
||||||
OutputType* out,
|
const float * weight_scale,
|
||||||
const int64_t token_padding_size,
|
OutputType * out,
|
||||||
const int64_t max_tokens,
|
const int64_t token_padding_size,
|
||||||
const int Experts,
|
const int64_t max_tokens,
|
||||||
const int64_t M,
|
const int batch_size,
|
||||||
const int64_t K,
|
const int64_t M,
|
||||||
const int WeightScaleGroup,
|
const int64_t K,
|
||||||
cudaStream_t stream) {
|
cudaStream_t stream) {
|
||||||
int kBlockN = 256;
|
|
||||||
if constexpr (std::is_same_v<OutputType, cutlass::bfloat16_t>) {
|
int kBlockN = 256;
|
||||||
GEMM_SWITCH_BF16(M,
|
int TailN = 0;
|
||||||
K,
|
if constexpr (std::is_same_v<OutputType, cutlass::bfloat16_t>) {
|
||||||
Experts,
|
GEMM_SWITCH_BF16(
|
||||||
token_padding_size,
|
M, K, batch_size, token_padding_size, kBlockN, TailN,
|
||||||
kBlockN,
|
weight,
|
||||||
WeightScaleGroup,
|
input,
|
||||||
weight,
|
out,
|
||||||
input,
|
weight_scale,
|
||||||
out,
|
input_row_sum,
|
||||||
weight_scale,
|
tokens,
|
||||||
input_dequant_scale,
|
max_tokens,
|
||||||
tokens,
|
stream)
|
||||||
max_tokens,
|
} else {
|
||||||
stream)
|
PD_THROW("Only supported dtype in ['BFLOAT16'].");
|
||||||
} else {
|
}
|
||||||
PD_THROW("Only supported dtype in ['BFLOAT16'].");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::Tensor> W4AFp8Gemm(
|
std::vector<paddle::Tensor> W4AFp8Gemm(
|
||||||
const paddle::Tensor& input,
|
const paddle::Tensor& input,
|
||||||
const paddle::Tensor& weight,
|
const paddle::Tensor& weight,
|
||||||
const paddle::Tensor&
|
const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group
|
||||||
tokens, // If tokenpadding=0, this tensor represents the prefix sum of
|
const paddle::Tensor& input_row_sum,
|
||||||
// tensors, otherwise it represents the number of tokens in
|
const paddle::Tensor& weight_scale,
|
||||||
// each group
|
const int64_t token_padding_size,
|
||||||
const paddle::Tensor& weight_scale,
|
const int64_t max_tokens,
|
||||||
const paddle::optional<paddle::Tensor>& input_dequant_scale,
|
const bool is_bfloat16) {
|
||||||
const int64_t token_padding_size,
|
|
||||||
const int64_t max_tokens,
|
|
||||||
const bool is_bfloat16) {
|
|
||||||
const int Experts = weight.dims()[0];
|
|
||||||
const int M = weight.dims()[1];
|
|
||||||
const int K = weight.dims()[2] * 2;
|
|
||||||
const int WeightScaleGroup =
|
|
||||||
weight_scale.dims().size() == 2 ? K : weight_scale.dims()[3];
|
|
||||||
|
|
||||||
if (input.dtype() != paddle::DataType::FLOAT8_E4M3FN) {
|
|
||||||
PD_THROW("Only supported dtype in ['FLOAT8_E4M3FN'].");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (token_padding_size == 0) {
|
const int batch_size = weight.dims()[0];
|
||||||
const int all_tokens = input.dims()[0];
|
const int M = weight.dims()[1];
|
||||||
if (is_bfloat16) {
|
const int K = weight.dims()[2] * 2;
|
||||||
paddle::Tensor out = paddle::empty(
|
|
||||||
{all_tokens, M}, paddle::DataType::BFLOAT16, input.place());
|
if (input.dtype() != paddle::DataType::FLOAT8_E4M3FN) {
|
||||||
phi::dtype::bfloat16* out_data = out.data<phi::dtype::bfloat16>();
|
PD_THROW("Only supported dtype in ['FLOAT8_E4M3FN'].");
|
||||||
DisPatchW4AFp8Gemm(
|
|
||||||
reinterpret_cast<const cutlass::float_e4m3_t*>(
|
|
||||||
input.data<phi::dtype::float8_e4m3fn>()),
|
|
||||||
reinterpret_cast<const cutlass::float_e4m3_t*>(
|
|
||||||
weight.data<uint8_t>()),
|
|
||||||
tokens.data<int64_t>(),
|
|
||||||
weight_scale.data<float>(),
|
|
||||||
input_dequant_scale
|
|
||||||
? const_cast<float*>(input_dequant_scale.get().data<float>())
|
|
||||||
: nullptr,
|
|
||||||
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
|
|
||||||
token_padding_size,
|
|
||||||
max_tokens,
|
|
||||||
Experts,
|
|
||||||
M,
|
|
||||||
K,
|
|
||||||
WeightScaleGroup,
|
|
||||||
input.stream());
|
|
||||||
return {out};
|
|
||||||
} else {
|
|
||||||
PD_THROW("Only supported dtype in ['BFLOAT16'].");
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
if (is_bfloat16) {
|
if (token_padding_size == 0) {
|
||||||
paddle::Tensor out = paddle::empty({Experts, token_padding_size, M},
|
const int all_tokens = input.dims()[0];
|
||||||
paddle::DataType::BFLOAT16,
|
if (is_bfloat16) {
|
||||||
input.place());
|
paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::BFLOAT16, input.place());
|
||||||
phi::dtype::bfloat16* out_data = out.data<phi::dtype::bfloat16>();
|
phi::dtype::bfloat16 *out_data = out.data<phi::dtype::bfloat16>();
|
||||||
DisPatchW4AFp8Gemm(
|
DisPatchW4AFp8Gemm(
|
||||||
reinterpret_cast<const cutlass::float_e4m3_t*>(
|
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
|
||||||
input.data<phi::dtype::float8_e4m3fn>()),
|
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
|
||||||
reinterpret_cast<const cutlass::float_e4m3_t*>(
|
tokens.data<int64_t>(),
|
||||||
weight.data<uint8_t>()),
|
input_row_sum.data<float>(),
|
||||||
tokens.data<int64_t>(),
|
weight_scale.data<float>(),
|
||||||
weight_scale.data<float>(),
|
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
|
||||||
input_dequant_scale
|
token_padding_size,
|
||||||
? const_cast<float*>(input_dequant_scale.get().data<float>())
|
max_tokens,
|
||||||
: nullptr,
|
batch_size,
|
||||||
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
|
M,
|
||||||
token_padding_size,
|
K,
|
||||||
max_tokens,
|
input.stream());
|
||||||
Experts,
|
return {out};
|
||||||
M,
|
} else {
|
||||||
K,
|
PD_THROW("Only supported dtype in ['BFLOAT16'].");
|
||||||
WeightScaleGroup,
|
}
|
||||||
input.stream());
|
|
||||||
return {out};
|
|
||||||
} else {
|
} else {
|
||||||
PD_THROW("Only supported dtype in ['BFLOAT16'].");
|
if (is_bfloat16) {
|
||||||
|
paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::BFLOAT16, input.place());
|
||||||
|
phi::dtype::bfloat16 * out_data = out.data<phi::dtype::bfloat16>();
|
||||||
|
DisPatchW4AFp8Gemm(
|
||||||
|
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
|
||||||
|
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
|
||||||
|
tokens.data<int64_t>(),
|
||||||
|
input_row_sum.data<float>(),
|
||||||
|
weight_scale.data<float>(),
|
||||||
|
reinterpret_cast<cutlass::bfloat16_t*>(out_data),
|
||||||
|
token_padding_size,
|
||||||
|
max_tokens,
|
||||||
|
batch_size,
|
||||||
|
M,
|
||||||
|
K,
|
||||||
|
input.stream());
|
||||||
|
return {out};
|
||||||
|
} else {
|
||||||
|
PD_THROW("Only supported dtype in ['BFLOAT16'].");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename InputType, typename OutputType>
|
template <typename InputType, typename OutputType>
|
||||||
void DisPatchW4AFp8GemmWrapper(const InputType* input,
|
void DisPatchW4AFp8GemmWrapper(
|
||||||
const InputType* weight,
|
const InputType* input,
|
||||||
const int64_t* total_rows_before_expert,
|
const InputType* weight,
|
||||||
const float* input_dequant_scale,
|
const int64_t* total_rows_before_expert,
|
||||||
const float* weight_scale,
|
const float* input_row_sum,
|
||||||
OutputType* out,
|
const float* row_scale,
|
||||||
const int64_t token_padding_size,
|
const float* weight_scale,
|
||||||
const int64_t max_tokens,
|
OutputType * out,
|
||||||
const int num_experts,
|
const int64_t token_padding_size,
|
||||||
const int64_t M,
|
const int64_t max_tokens,
|
||||||
const int64_t K,
|
const int num_experts,
|
||||||
const int WeightScaleGroup,
|
const int64_t M,
|
||||||
cudaStream_t stream) {
|
const int64_t K,
|
||||||
using InType = typename NVTraits<InputType>::data_t;
|
cudaStream_t stream) {
|
||||||
using OutType = typename NVTraits<OutputType>::data_t;
|
using InType = typename NVTraits<InputType>::data_t;
|
||||||
DisPatchW4AFp8Gemm(reinterpret_cast<const InType*>(input),
|
using OutType = typename NVTraits<OutputType>::data_t;
|
||||||
reinterpret_cast<const InType*>(weight),
|
DisPatchW4AFp8Gemm(
|
||||||
total_rows_before_expert,
|
reinterpret_cast<const InType*>(input),
|
||||||
weight_scale,
|
reinterpret_cast<const InType*>(weight),
|
||||||
input_dequant_scale,
|
total_rows_before_expert,
|
||||||
reinterpret_cast<OutType*>(out),
|
input_row_sum,
|
||||||
token_padding_size,
|
weight_scale,
|
||||||
max_tokens,
|
reinterpret_cast<OutType*>(out),
|
||||||
num_experts,
|
token_padding_size,
|
||||||
M,
|
max_tokens,
|
||||||
K,
|
num_experts,
|
||||||
WeightScaleGroup,
|
M,
|
||||||
stream);
|
K,
|
||||||
|
stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
std::vector<paddle::Tensor> W4AFp8GemmWeightConvert(const paddle::Tensor& weight) {
|
||||||
|
const int batch_size = weight.dims()[0];
|
||||||
|
const int M = weight.dims()[1];
|
||||||
|
const int K = weight.dims()[2];
|
||||||
|
paddle::Tensor weight_new = paddle::empty({batch_size, M, K / 2}, paddle::DataType::UINT8, weight.place());
|
||||||
|
weight_convert(weight.data<uint8_t>(), weight_new.data<uint8_t>(), batch_size, M, K);
|
||||||
|
return {weight_new};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int kPackSize>
|
||||||
|
__global__ void permute_scale_kernel(
|
||||||
|
T* input_data,
|
||||||
|
const int numel) {
|
||||||
|
using LoadT = AlignedVector<T, kPackSize>;
|
||||||
|
LoadT input_vec;
|
||||||
|
LoadT dst_vec;
|
||||||
|
const int load_idx = (blockIdx.x * blockDim.x + threadIdx.x) * kPackSize;
|
||||||
|
if (load_idx >= numel) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Load<T, kPackSize>(&input_data[load_idx], &input_vec);
|
||||||
|
|
||||||
|
for (int i = 0; i < kPackSize; i+=2) {
|
||||||
|
dst_vec[i] = input_vec[i / 2];
|
||||||
|
dst_vec[i + 1] = input_vec[i / 2 + 8];
|
||||||
|
}
|
||||||
|
|
||||||
|
Store<T, kPackSize>(dst_vec, &input_data[load_idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
void W4AFp8GemmScalePermute(const paddle::Tensor& scale) {
|
||||||
|
const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1;
|
||||||
|
const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0];
|
||||||
|
if (col % 16 != 0) {
|
||||||
|
PD_THROW("Only supported when col is divisible by 16.");
|
||||||
|
}
|
||||||
|
const int numel = row * col;
|
||||||
|
const int threads = 128;
|
||||||
|
const int kPackSize = 16;
|
||||||
|
const int grid_size = (numel / kPackSize + threads - 1) / threads;
|
||||||
|
|
||||||
|
if (scale.dtype() == paddle::DataType::BFLOAT16) {
|
||||||
|
permute_scale_kernel<phi::dtype::bfloat16, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
|
||||||
|
const_cast<phi::dtype::bfloat16*>(scale.data<phi::dtype::bfloat16>()),
|
||||||
|
numel
|
||||||
|
);
|
||||||
|
} else if (scale.dtype() == paddle::DataType::FLOAT16) {
|
||||||
|
permute_scale_kernel<phi::dtype::float16, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
|
||||||
|
const_cast<phi::dtype::float16*>(scale.data<phi::dtype::float16>()),
|
||||||
|
numel
|
||||||
|
);
|
||||||
|
} else if (scale.dtype() == paddle::DataType::FLOAT32) {
|
||||||
|
permute_scale_kernel<float, kPackSize><<<grid_size, threads, 0, scale.stream()>>>(
|
||||||
|
const_cast<float*>(scale.data<float>()),
|
||||||
|
numel
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(w4afp8_gemm_scale_permute)
|
PD_BUILD_STATIC_OP(w4afp8_gemm_scale_permute)
|
||||||
@@ -201,8 +261,8 @@ PD_BUILD_STATIC_OP(w4afp8_gemm)
|
|||||||
.Inputs({"input",
|
.Inputs({"input",
|
||||||
"weight",
|
"weight",
|
||||||
"tokens",
|
"tokens",
|
||||||
"weight_scale",
|
"input_row_sum",
|
||||||
paddle::Optional("input_dequant_scale")})
|
"weight_scale"})
|
||||||
.Outputs({"out"})
|
.Outputs({"out"})
|
||||||
.Attrs({"token_padding_size: int64_t",
|
.Attrs({"token_padding_size: int64_t",
|
||||||
"max_tokens: int64_t",
|
"max_tokens: int64_t",
|
||||||
@@ -215,31 +275,33 @@ PD_BUILD_STATIC_OP(w4afp8_gemm_weight_convert)
|
|||||||
.SetKernelFn(PD_KERNEL(W4AFp8GemmWeightConvert));
|
.SetKernelFn(PD_KERNEL(W4AFp8GemmWeightConvert));
|
||||||
|
|
||||||
template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, __nv_bfloat16>(
|
template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, __nv_bfloat16>(
|
||||||
const __nv_fp8_e4m3* input,
|
const __nv_fp8_e4m3* input,
|
||||||
const __nv_fp8_e4m3* weight,
|
const __nv_fp8_e4m3* weight,
|
||||||
const int64_t* tokens,
|
const int64_t * tokens,
|
||||||
const float* input_dequant_scale,
|
const float * input_row_sum,
|
||||||
const float* weight_scale,
|
const float * row_scale,
|
||||||
__nv_bfloat16* out,
|
const float * weight_scale,
|
||||||
const int64_t token_padding_size,
|
__nv_bfloat16 * out,
|
||||||
const int64_t max_tokens,
|
const int64_t token_padding_size,
|
||||||
const int num_experts,
|
const int64_t max_tokens,
|
||||||
const int64_t M,
|
const int num_experts,
|
||||||
const int64_t K,
|
const int64_t M,
|
||||||
const int WeightScaleGroup,
|
const int64_t K,
|
||||||
cudaStream_t stream);
|
cudaStream_t stream
|
||||||
|
);
|
||||||
|
|
||||||
template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, half>(
|
template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, half>(
|
||||||
const __nv_fp8_e4m3* input,
|
const __nv_fp8_e4m3* input,
|
||||||
const __nv_fp8_e4m3* weight,
|
const __nv_fp8_e4m3* weight,
|
||||||
const int64_t* tokens,
|
const int64_t * tokens,
|
||||||
const float* input_dequant_scale,
|
const float * input_row_sum,
|
||||||
const float* weight_scale,
|
const float * row_scale,
|
||||||
half* out,
|
const float * weight_scale,
|
||||||
const int64_t token_padding_size,
|
half * out,
|
||||||
const int64_t max_tokens,
|
const int64_t token_padding_size,
|
||||||
const int num_experts,
|
const int64_t max_tokens,
|
||||||
const int64_t M,
|
const int num_experts,
|
||||||
const int64_t K,
|
const int64_t M,
|
||||||
const int WeightScaleGroup,
|
const int64_t K,
|
||||||
cudaStream_t stream);
|
cudaStream_t stream
|
||||||
|
);
|
||||||
|
|||||||
@@ -18,30 +18,30 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
std::vector<paddle::Tensor> W4AFp8Gemm(
|
std::vector<paddle::Tensor> W4AFp8Gemm(
|
||||||
const paddle::Tensor& input,
|
const paddle::Tensor& input,
|
||||||
const paddle::Tensor& weight,
|
const paddle::Tensor& weight,
|
||||||
const paddle::Tensor&
|
const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group
|
||||||
tokens, // If tokenpadding=0, this tensor represents the prefix sum of
|
const paddle::Tensor& input_row_sum,
|
||||||
// tensors, otherwise it represents the number of tokens in
|
const paddle::Tensor& weight_scale,
|
||||||
// each group
|
const int64_t token_padding_size,
|
||||||
const paddle::Tensor& weight_scale,
|
const int64_t max_tokens,
|
||||||
const paddle::optional<paddle::Tensor>& input_dequant_scale,
|
const bool is_bfloat16);
|
||||||
const int64_t token_padding_size,
|
|
||||||
const int64_t max_tokens,
|
|
||||||
const bool is_bfloat16);
|
|
||||||
|
|
||||||
template <typename InputType, typename OutputType>
|
template <typename InputType, typename OutputType>
|
||||||
void DisPatchW4AFp8GemmWrapper(const InputType* input,
|
void DisPatchW4AFp8GemmWrapper(
|
||||||
const InputType* weight,
|
const InputType* input,
|
||||||
const int64_t* tokens,
|
const InputType* weight,
|
||||||
const float* input_dequant_scale,
|
const int64_t * tokens,
|
||||||
const float* weight_scale,
|
const float * input_row_sum,
|
||||||
OutputType* out,
|
const float * row_scale,
|
||||||
const int64_t token_padding_size,
|
const float * weight_scale,
|
||||||
const int64_t max_tokens,
|
OutputType * out,
|
||||||
const int num_experts,
|
const int64_t token_padding_size,
|
||||||
const int64_t M,
|
const int64_t max_tokens,
|
||||||
const int64_t K,
|
const int num_experts,
|
||||||
const int WeightScaleGroup,
|
const int64_t M,
|
||||||
cudaStream_t stream);
|
const int64_t K,
|
||||||
|
cudaStream_t stream);
|
||||||
|
|||||||
@@ -16,280 +16,237 @@
|
|||||||
#include "cute/atom/mma_atom.hpp"
|
#include "cute/atom/mma_atom.hpp"
|
||||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||||
|
|
||||||
#include "cutlass/arch/reg_reconfig.h"
|
|
||||||
#include "cutlass/cluster_launch.hpp"
|
|
||||||
#include "cutlass/cutlass.h"
|
#include "cutlass/cutlass.h"
|
||||||
#include "cutlass/layout/layout.h"
|
#include "cutlass/layout/layout.h"
|
||||||
#include "cutlass/numeric_types.h"
|
#include "cutlass/numeric_types.h"
|
||||||
#include "cutlass/pipeline/pipeline.hpp"
|
#include "cutlass/pipeline/pipeline.hpp"
|
||||||
|
#include "cutlass/cluster_launch.hpp"
|
||||||
|
#include "cutlass/arch/reg_reconfig.h"
|
||||||
|
|
||||||
#include "kernel_traits.h"
|
#include "kernel_traits.h"
|
||||||
#include "mainloop_fwd.h"
|
#include "mainloop_fwd.h"
|
||||||
|
|
||||||
template <typename Ktraits>
|
template <typename Ktraits>
|
||||||
void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
|
void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) w4afp8_gemm_kernel(
|
||||||
1)
|
CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits>::Params const mainloop_params) {
|
||||||
w4afp8_gemm_kernel(
|
|
||||||
CUTE_GRID_CONSTANT
|
|
||||||
typename CollectiveMainloopFwd<Ktraits>::Params const mainloop_params) {
|
|
||||||
using Element = typename Ktraits::Element;
|
|
||||||
static_assert(cutlass::sizeof_bits_v<Element> == 8);
|
|
||||||
|
|
||||||
using TileShape_MNK = typename Ktraits::TileShape_MNK;
|
using Element = typename Ktraits::Element;
|
||||||
using ClusterShape = typename Ktraits::ClusterShape_MNK;
|
static_assert(cutlass::sizeof_bits_v<Element> == 8);
|
||||||
|
|
||||||
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
|
using TileShape_MNK = typename Ktraits::TileShape_MNK;
|
||||||
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
|
using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL;
|
||||||
static constexpr int kBlockN = Ktraits::kBlockN;
|
using ClusterShape = typename Ktraits::ClusterShape_MNK;
|
||||||
static constexpr int kBlockM = Ktraits::kBlockM;
|
|
||||||
static constexpr int M = Ktraits::M;
|
|
||||||
static constexpr int K = Ktraits::K;
|
|
||||||
static constexpr int TokenPackSize = Ktraits::TokenPackSize;
|
|
||||||
static constexpr int WeightScaleGroup = Ktraits::WeightScaleGroup;
|
|
||||||
|
|
||||||
using CollectiveMainloop = CollectiveMainloopFwd<Ktraits>;
|
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
|
||||||
|
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
|
||||||
|
static constexpr int kBlockN = Ktraits::kBlockN;
|
||||||
|
static constexpr int kBlockM = Ktraits::kBlockM;
|
||||||
|
static constexpr int M = Ktraits::M;
|
||||||
|
static constexpr int TokenPackSize = Ktraits::TokenPackSize;
|
||||||
|
static constexpr int TAIL_N = Ktraits::TAIL_N;
|
||||||
|
|
||||||
using MainloopPipeline = typename Ktraits::MainloopPipeline;
|
using CollectiveMainloop = CollectiveMainloopFwd<Ktraits>;
|
||||||
using PipelineParams = typename MainloopPipeline::Params;
|
|
||||||
using PipelineState = typename MainloopPipeline::PipelineState;
|
|
||||||
using ElementOutput = typename Ktraits::ElementOutput;
|
|
||||||
|
|
||||||
extern __shared__ char shared_memory[];
|
using MainloopPipeline = typename Ktraits::MainloopPipeline;
|
||||||
auto &shared_storage =
|
using PipelineParams = typename MainloopPipeline::Params;
|
||||||
*reinterpret_cast<typename Ktraits::SharedStorage *>(shared_memory);
|
using PipelineState = typename MainloopPipeline::PipelineState;
|
||||||
|
using ElementOutput = typename Ktraits::ElementOutput;
|
||||||
|
|
||||||
const int bidm = blockIdx.x;
|
extern __shared__ char shared_memory[];
|
||||||
const int bidn = blockIdx.y;
|
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
|
||||||
const int bidb = blockIdx.z;
|
|
||||||
const int tidx = threadIdx.x;
|
|
||||||
|
|
||||||
if (tidx == 0) {
|
const int bidm = blockIdx.x;
|
||||||
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
|
const int bidn = blockIdx.y;
|
||||||
}
|
const int bidb = blockIdx.z;
|
||||||
|
const int tidx = threadIdx.x;
|
||||||
|
|
||||||
// Obtain warp index
|
if (tidx == 0) {
|
||||||
int const warp_group_thread_idx =
|
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
|
||||||
threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
|
||||||
|
|
||||||
PipelineParams pipeline_params;
|
|
||||||
if constexpr (WeightScaleGroup == K) {
|
|
||||||
pipeline_params.transaction_bytes =
|
|
||||||
CollectiveMainloop::TmaTransactionBytesA +
|
|
||||||
CollectiveMainloop::TmaTransactionBytesB;
|
|
||||||
} else {
|
|
||||||
pipeline_params.transaction_bytes =
|
|
||||||
CollectiveMainloop::TmaTransactionBytesA +
|
|
||||||
CollectiveMainloop::TmaTransactionBytesB +
|
|
||||||
CollectiveMainloop::TmaTransactionBytesScale;
|
|
||||||
}
|
|
||||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
|
||||||
pipeline_params.role = warp_group_idx == 0
|
|
||||||
? MainloopPipeline::ThreadCategory::Producer
|
|
||||||
: MainloopPipeline::ThreadCategory::Consumer;
|
|
||||||
pipeline_params.is_leader = warp_group_thread_idx == 0;
|
|
||||||
pipeline_params.num_consumers = NumMmaThreads;
|
|
||||||
|
|
||||||
MainloopPipeline pipeline(
|
|
||||||
shared_storage.pipeline, pipeline_params, ClusterShape{});
|
|
||||||
|
|
||||||
CollectiveMainloop collective_mainloop;
|
|
||||||
|
|
||||||
if constexpr (size(ClusterShape{}) > 1) {
|
|
||||||
cute::cluster_arrive_relaxed();
|
|
||||||
cute::cluster_wait();
|
|
||||||
} else {
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
const int pre_fix_tokens =
|
|
||||||
TokenPackSize == 0 ? (bidb == 0 ? 0 : mainloop_params.tokens[bidb - 1])
|
|
||||||
: 0;
|
|
||||||
|
|
||||||
const int tokens = TokenPackSize == 0
|
|
||||||
? mainloop_params.tokens[bidb] - pre_fix_tokens
|
|
||||||
: mainloop_params.tokens[bidb];
|
|
||||||
|
|
||||||
if (bidn * kBlockN >= tokens) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const bool is_need_input_scale = mainloop_params.input_scale != nullptr;
|
|
||||||
|
|
||||||
float *input_scale =
|
|
||||||
is_need_input_scale
|
|
||||||
? reinterpret_cast<float *>(shared_memory +
|
|
||||||
sizeof(typename Ktraits::SharedStorage))
|
|
||||||
: nullptr;
|
|
||||||
|
|
||||||
if (warp_group_idx == 0) {
|
|
||||||
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 40 : 32>();
|
|
||||||
PipelineState smem_pipe_write =
|
|
||||||
cutlass::make_producer_start_state<MainloopPipeline>();
|
|
||||||
collective_mainloop.load(mainloop_params,
|
|
||||||
pipeline,
|
|
||||||
smem_pipe_write,
|
|
||||||
shared_storage,
|
|
||||||
tokens,
|
|
||||||
pre_fix_tokens,
|
|
||||||
bidm,
|
|
||||||
bidn,
|
|
||||||
bidb,
|
|
||||||
tidx);
|
|
||||||
} else {
|
|
||||||
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 232 : 160>();
|
|
||||||
PipelineState smem_pipe_read;
|
|
||||||
|
|
||||||
typename Ktraits::TiledMma tiled_mma;
|
|
||||||
|
|
||||||
const int mma_tidx = tidx - NumCopyThreads;
|
|
||||||
|
|
||||||
if (is_need_input_scale) {
|
|
||||||
if constexpr (TokenPackSize == 0) {
|
|
||||||
const int input_scale_idx = pre_fix_tokens + bidn * kBlockN;
|
|
||||||
if (mma_tidx < tokens) {
|
|
||||||
reinterpret_cast<float *>(input_scale)[mma_tidx] =
|
|
||||||
reinterpret_cast<const float *>(mainloop_params.input_scale +
|
|
||||||
input_scale_idx)[mma_tidx];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
const int input_scale_idx = bidb * TokenPackSize + bidn * kBlockN;
|
|
||||||
if (mma_tidx < kBlockN / 4) {
|
|
||||||
reinterpret_cast<float4 *>(input_scale)[mma_tidx] =
|
|
||||||
reinterpret_cast<const float4 *>(mainloop_params.input_scale +
|
|
||||||
input_scale_idx)[mma_tidx];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
float2 weight_scale;
|
// Obtain warp index
|
||||||
|
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
||||||
|
|
||||||
if constexpr (WeightScaleGroup == K) {
|
PipelineParams pipeline_params;
|
||||||
weight_scale = reinterpret_cast<const float2 *>(
|
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesA + CollectiveMainloop::TmaTransactionBytesB;
|
||||||
mainloop_params.weight_scale + bidb * M +
|
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||||
bidm * kBlockM)[mma_tidx / 4];
|
pipeline_params.role = warp_group_idx == 0
|
||||||
}
|
? MainloopPipeline::ThreadCategory::Producer
|
||||||
Tensor tSrS =
|
: MainloopPipeline::ThreadCategory::Consumer;
|
||||||
partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{}));
|
pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||||
|
pipeline_params.num_consumers = NumMmaThreads;
|
||||||
|
|
||||||
if constexpr (WeightScaleGroup == K) {
|
MainloopPipeline pipeline(shared_storage.pipeline, pipeline_params, ClusterShape{});
|
||||||
collective_mainloop.mma(mainloop_params,
|
|
||||||
tiled_mma,
|
CollectiveMainloop collective_mainloop;
|
||||||
pipeline,
|
|
||||||
smem_pipe_read,
|
if constexpr (size(ClusterShape{}) > 1) {
|
||||||
shared_storage,
|
cute::cluster_arrive_relaxed();
|
||||||
tSrS,
|
cute::cluster_wait();
|
||||||
mma_tidx);
|
|
||||||
} else {
|
} else {
|
||||||
collective_mainloop.mma_pipeline(mainloop_params,
|
__syncthreads();
|
||||||
tiled_mma,
|
}
|
||||||
pipeline,
|
|
||||||
smem_pipe_read,
|
const int pre_fix_tokens = TokenPackSize == 0 ? (bidb == 0 ? 0 : mainloop_params.tokens[bidb - 1]) : 0;
|
||||||
shared_storage,
|
|
||||||
tSrS,
|
const int tokens = TokenPackSize == 0 ? mainloop_params.tokens[bidb] - pre_fix_tokens : mainloop_params.tokens[bidb];
|
||||||
mma_tidx);
|
|
||||||
|
|
||||||
|
if (bidn * kBlockN >= tokens) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
float* input_row_sum = reinterpret_cast<float*>(
|
||||||
|
shared_memory + sizeof(typename Ktraits::SharedStorage));
|
||||||
|
|
||||||
|
if (warp_group_idx == 0) {
|
||||||
|
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 40 : 32>();
|
||||||
|
PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||||
|
collective_mainloop.load(
|
||||||
|
mainloop_params,
|
||||||
|
pipeline,
|
||||||
|
smem_pipe_write,
|
||||||
|
shared_storage,
|
||||||
|
tokens,
|
||||||
|
pre_fix_tokens,
|
||||||
|
bidm,
|
||||||
|
bidn,
|
||||||
|
bidb,
|
||||||
|
tidx);
|
||||||
|
} else {
|
||||||
|
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 232 : 160>();
|
||||||
|
PipelineState smem_pipe_read;
|
||||||
|
|
||||||
|
typename Ktraits::TiledMma tiled_mma;
|
||||||
|
|
||||||
|
typename Ktraits::TiledMma_TAIL tiled_mma_tail;
|
||||||
|
|
||||||
|
const int mma_tidx = tidx - NumCopyThreads;
|
||||||
|
const int lane_id = mma_tidx % 4 * 2;
|
||||||
|
|
||||||
|
const float2 weight_scale = reinterpret_cast<const float2*>(mainloop_params.weight_scale + bidb * M + bidm * kBlockM)[mma_tidx / 4];
|
||||||
|
|
||||||
|
if constexpr (TokenPackSize == 0) {
|
||||||
|
const int input_sum_idx = pre_fix_tokens + bidn * kBlockN;
|
||||||
|
if (mma_tidx < kBlockN) {
|
||||||
|
reinterpret_cast<float*>(input_row_sum)[mma_tidx] = reinterpret_cast<const float*>(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const int input_sum_idx = bidb * TokenPackSize + bidn * kBlockN;
|
||||||
|
if (mma_tidx < kBlockN / 4) {
|
||||||
|
reinterpret_cast<float4*>(input_row_sum)[mma_tidx] = reinterpret_cast<const float4*>(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const int reamin_tokens = tokens - bidn * kBlockN;
|
||||||
|
|
||||||
|
if (TAIL_N > 0 && reamin_tokens < kBlockN) {
|
||||||
|
Tensor tSrS_tail = partition_fragment_C(tiled_mma_tail, select<0, 1>(TileShape_MNK_TAIL{}));
|
||||||
|
collective_mainloop.mma<TAIL_N>(
|
||||||
|
mainloop_params,
|
||||||
|
tiled_mma_tail,
|
||||||
|
pipeline,
|
||||||
|
smem_pipe_read,
|
||||||
|
shared_storage,
|
||||||
|
tSrS_tail,
|
||||||
|
mma_tidx);
|
||||||
|
collective_mainloop.store<TAIL_N>(
|
||||||
|
mainloop_params,
|
||||||
|
tSrS_tail,
|
||||||
|
shared_storage,
|
||||||
|
tiled_mma_tail,
|
||||||
|
input_row_sum + lane_id,
|
||||||
|
reinterpret_cast<const float*>(&weight_scale),
|
||||||
|
tokens,
|
||||||
|
pre_fix_tokens,
|
||||||
|
bidm,
|
||||||
|
bidn,
|
||||||
|
bidb,
|
||||||
|
mma_tidx);
|
||||||
|
} else {
|
||||||
|
Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{}));
|
||||||
|
collective_mainloop.mma<kBlockN>(
|
||||||
|
mainloop_params,
|
||||||
|
tiled_mma,
|
||||||
|
pipeline,
|
||||||
|
smem_pipe_read,
|
||||||
|
shared_storage,
|
||||||
|
tSrS,
|
||||||
|
mma_tidx);
|
||||||
|
collective_mainloop.store<kBlockN>(
|
||||||
|
mainloop_params,
|
||||||
|
tSrS,
|
||||||
|
shared_storage,
|
||||||
|
tiled_mma,
|
||||||
|
input_row_sum + lane_id,
|
||||||
|
reinterpret_cast<const float*>(&weight_scale),
|
||||||
|
tokens,
|
||||||
|
pre_fix_tokens,
|
||||||
|
bidm,
|
||||||
|
bidn,
|
||||||
|
bidb,
|
||||||
|
mma_tidx);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
collective_mainloop.store(mainloop_params,
|
|
||||||
tSrS,
|
|
||||||
shared_storage,
|
|
||||||
tiled_mma,
|
|
||||||
reinterpret_cast<const float *>(&weight_scale),
|
|
||||||
input_scale,
|
|
||||||
tokens,
|
|
||||||
pre_fix_tokens,
|
|
||||||
bidm,
|
|
||||||
bidn,
|
|
||||||
bidb,
|
|
||||||
mma_tidx);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int Experts>
|
template <int Batch>
|
||||||
auto get_gmem_layout(const int Rows, const int Cols) {
|
auto get_gmem_layout(const int Rows, const int Cols) {
|
||||||
return make_layout(make_shape(static_cast<int64_t>(Rows),
|
return make_layout(
|
||||||
static_cast<int64_t>(Cols),
|
make_shape(
|
||||||
static_cast<int64_t>(Experts)),
|
static_cast<int64_t>(Rows),
|
||||||
make_stride(static_cast<int64_t>(Cols),
|
static_cast<int64_t>(Cols),
|
||||||
cute::_1{},
|
static_cast<int64_t>(Batch)),
|
||||||
static_cast<int64_t>(Rows * Cols)));
|
make_stride(
|
||||||
|
static_cast<int64_t>(Cols),
|
||||||
|
cute::_1{},
|
||||||
|
static_cast<int64_t>(Rows * Cols)));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int Experts>
|
|
||||||
auto get_scale_layout(const int Rows, const int Cols) {
|
template <typename InputType, typename OutputType, typename Kernel_traits, int M, int K, int Batch, int TokenPackSize>
|
||||||
return make_layout(make_shape(static_cast<int64_t>(Cols),
|
void run_gemm(const InputType * A, const InputType * B, OutputType * C, const float *weight_scale,
|
||||||
static_cast<int64_t>(Rows),
|
const float *input_row_sum, const int64_t * tokens, const int64_t max_tokens, cudaStream_t stream) {
|
||||||
static_cast<int64_t>(Experts)),
|
|
||||||
make_stride(cute::_1{},
|
using ElementOutput = typename Kernel_traits::ElementOutput;
|
||||||
static_cast<int64_t>(Cols),
|
using Element = typename Kernel_traits::Element;
|
||||||
static_cast<int64_t>(Rows * Cols)));
|
using CollectiveMainloop = CollectiveMainloopFwd<Kernel_traits>;
|
||||||
}
|
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
|
||||||
|
|
||||||
template <typename InputType,
|
constexpr int M_nums = (M + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||||
typename OutputType,
|
const int N_nums = (max_tokens + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
|
||||||
typename Kernel_traits,
|
|
||||||
int M,
|
typename CollectiveMainloop::Params mainloop_params =
|
||||||
int K,
|
CollectiveMainloop::to_underlying_arguments({
|
||||||
int Experts,
|
static_cast<Element const*>(A),
|
||||||
int TokenPackSize,
|
get_gmem_layout<Batch>(M, K / 2),
|
||||||
int WeightScaleGroup>
|
static_cast<Element const*>(B),
|
||||||
void run_gemm(const InputType *A,
|
get_gmem_layout<Batch>(TokenPackSize == 0 ? max_tokens: TokenPackSize, K),
|
||||||
const InputType *B,
|
static_cast<ElementOutput*>(C),
|
||||||
OutputType *C,
|
get_gmem_layout<Batch>(M, TokenPackSize == 0 ? max_tokens : TokenPackSize),
|
||||||
const float *weight_scale,
|
weight_scale,
|
||||||
const float *input_dequant_scale,
|
input_row_sum,
|
||||||
const int64_t *tokens,
|
tokens
|
||||||
const int max_tokens,
|
});
|
||||||
cudaStream_t stream) {
|
|
||||||
using ElementOutput = typename Kernel_traits::ElementOutput;
|
void *kernel;
|
||||||
using Element = typename Kernel_traits::Element;
|
kernel = (void *)w4afp8_gemm_kernel<Kernel_traits>;
|
||||||
using CollectiveMainloop = CollectiveMainloopFwd<Kernel_traits>;
|
|
||||||
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
|
int smem_size = sizeof(typename Kernel_traits::SharedStorage) + sizeof(float) * Kernel_traits::kBlockN;
|
||||||
|
|
||||||
constexpr int M_nums =
|
if (smem_size >= 48 * 1024) {
|
||||||
(M + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||||
const int N_nums =
|
}
|
||||||
(max_tokens + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
|
|
||||||
constexpr int K_scale_nums = K / Kernel_traits::kBlockM;
|
dim3 grid_dims;
|
||||||
static_assert(K % WeightScaleGroup == 0);
|
grid_dims.x = M_nums;
|
||||||
static_assert(WeightScaleGroup == 128 || WeightScaleGroup == K);
|
grid_dims.y = N_nums;
|
||||||
|
grid_dims.z = Batch;
|
||||||
typename CollectiveMainloop::Params mainloop_params =
|
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
|
||||||
CollectiveMainloop::to_underlying_arguments(
|
dim3 block_dims(ctaSize);
|
||||||
{static_cast<Element const *>(A),
|
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
|
||||||
get_gmem_layout<Experts>(M, K / 2),
|
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
|
||||||
static_cast<Element const *>(B),
|
cutlass::launch_kernel_on_cluster(
|
||||||
get_gmem_layout<Experts>(
|
launch_params, kernel, mainloop_params);
|
||||||
TokenPackSize == 0 ? max_tokens : TokenPackSize, K),
|
|
||||||
static_cast<ElementOutput *>(C),
|
|
||||||
get_gmem_layout<Experts>(
|
|
||||||
M, TokenPackSize == 0 ? max_tokens : TokenPackSize),
|
|
||||||
weight_scale,
|
|
||||||
get_scale_layout<Experts>(M_nums,
|
|
||||||
K_scale_nums * Kernel_traits::kBlockM),
|
|
||||||
input_dequant_scale,
|
|
||||||
tokens});
|
|
||||||
|
|
||||||
void *kernel;
|
|
||||||
kernel = (void *)w4afp8_gemm_kernel<Kernel_traits>;
|
|
||||||
|
|
||||||
int smem_size = sizeof(typename Kernel_traits::SharedStorage) +
|
|
||||||
Kernel_traits::kBlockN * sizeof(float);
|
|
||||||
|
|
||||||
if (smem_size >= 48 * 1024) {
|
|
||||||
cudaFuncSetAttribute(
|
|
||||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
dim3 grid_dims;
|
|
||||||
grid_dims.x = M_nums;
|
|
||||||
grid_dims.y = N_nums;
|
|
||||||
grid_dims.z = Experts;
|
|
||||||
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
|
|
||||||
dim3 block_dims(ctaSize);
|
|
||||||
dim3 cluster_dims(size<0>(ClusterShape{}),
|
|
||||||
size<1>(ClusterShape{}),
|
|
||||||
size<2>(ClusterShape{}));
|
|
||||||
cutlass::ClusterLaunchParams launch_params{
|
|
||||||
grid_dims, block_dims, cluster_dims, smem_size, stream};
|
|
||||||
cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params);
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,131 +0,0 @@
|
|||||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
#include "helper.h"
|
|
||||||
#include "paddle/extension.h"
|
|
||||||
|
|
||||||
// Host-side repacking of 4-bit weight codes from one-code-per-byte into
// two-codes-per-byte form.
//
// `weight` is read as `experts * M` consecutive rows of K bytes and
// `weight_new` is written as the same number of rows of K / 2 bytes. Inside
// every run of 64 input bytes, byte j (0 <= j < 32) supplies the high nibble
// and byte j + 32 supplies the low nibble of output byte j of the matching
// 32-byte output run.
//
// Preconditions: K is a multiple of 64 (asserted in debug builds) and every
// input byte is expected to hold its code in the low nibble -- bits above the
// low nibble of the "high" byte are shifted out, bits above the low nibble of
// the "low" byte would be OR-ed into the result unchanged.
//
// Offsets are computed in 64 bits so that experts * M * K may exceed INT_MAX
// without overflowing the index arithmetic.
void weight_convert(
    const uint8_t* weight, uint8_t* weight_new, int experts, int M, int K) {
  assert(K % 64 == 0);
  const int64_t in_row_stride = static_cast<int64_t>(K);
  const int64_t out_row_stride = in_row_stride / 2;
  for (int b = 0; b < experts; ++b) {
    for (int m = 0; m < M; ++m) {
      // Flattened (expert, row) index; 64-bit to avoid int overflow.
      const int64_t row = static_cast<int64_t>(b) * M + m;
      const uint8_t* src_row = weight + row * in_row_stride;
      uint8_t* dst_row = weight_new + row * out_row_stride;
      for (int k = 0; k < K; k += 64) {
        const uint8_t* src = src_row + k;
        uint8_t* dst = dst_row + k / 2;
        for (int k_inner = 0; k_inner < 32; ++k_inner) {
          const uint8_t left = src[k_inner];
          const uint8_t right = src[k_inner + 32];
          // Truncation to 8 bits drops whatever `left << 4` pushed past bit 7.
          dst[k_inner] = static_cast<uint8_t>((left << 4) | right);
        }
      }
    }
  }
}
|
|
||||||
|
|
||||||
// Re-codes a single 4-bit value. `nibble_in_high_bits` must carry the nibble
// in bits 4..7 (i.e. be one of 0, 16, ..., 240).
//
// Steps: add 112 with 8-bit wrap-around, take the upper nibble of the result,
// subtract 7, and map any non-positive value c to 8 - c. In terms of the
// input nibble v this leaves 1..8 unchanged, sends 0 to 8, and sends 9..15 to
// 24 - v (15 down to 9).
__device__ __forceinline__ int8_t w4afp8_recode_nibble(
    const int nibble_in_high_bits) {
  const uint8_t biased = static_cast<uint8_t>(nibble_in_high_bits + 112);
  const int centered = static_cast<int>(biased >> 4) - 7;
  return static_cast<int8_t>(centered <= 0 ? 8 - centered : centered);
}

// Re-codes and re-packs 4-bit weights, two bytes in / two bytes out per
// thread.
//
// Launch: 1-D grid; threads with a global index at or beyond
// original_k * original_n / 4 exit immediately.
//
// Indexing as performed below:
//   * input_data is addressed with a row stride of original_n bytes; a thread
//     reads column n_id of row (k_group * 32 + k_in_group) and of the row 16
//     further down.
//   * output_data is addressed with a stride of original_k / 2 bytes per
//     n_id; a thread writes two adjacent bytes at offset
//     k_group * 32 + k_in_group * 2.
//   * dst[0] combines the re-coded LOW nibbles of the two input bytes (first
//     byte in the upper half), dst[1] combines the re-coded HIGH nibbles.
//
// NOTE(review): the addressing only stays in range when original_k is a
// multiple of 64 -- confirm the caller guarantees this.
__global__ void weight_permute_interleave_kernelw4afp8(const int8_t* input_data,
                                                       int8_t* output_data,
                                                       const int original_k,
                                                       const int original_n) {
  constexpr int kPackGroup = 64;
  constexpr int kThreadsPerGroup = kPackGroup / 4;  // 16

  const int total_threads = original_k * original_n / 4;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= total_threads) return;

  // Decompose the flat thread id into (n, k-group, position inside group).
  const int threads_per_n = original_k / 4;
  const int n_id = tid / threads_per_n;
  const int k_id = tid % threads_per_n;
  const int k_group = k_id / kThreadsPerGroup;
  const int k_in_group = k_id % kThreadsPerGroup;

  const int8_t* src = input_data + k_group * kPackGroup / 2 * original_n +
                      k_in_group * original_n + n_id;
  const int8_t first = src[0];
  const int8_t second = src[kPackGroup / 4 * original_n];

  // Isolate each nibble in bits 4..7, then re-code it.
  const int8_t first_hi = w4afp8_recode_nibble(first & 0xF0);
  const int8_t first_lo = w4afp8_recode_nibble((first << 4) & 0xF0);
  const int8_t second_hi = w4afp8_recode_nibble(second & 0xF0);
  const int8_t second_lo = w4afp8_recode_nibble((second << 4) & 0xF0);

  int8_t* dst = output_data + n_id * original_k / 2 +
                (k_group * kPackGroup / 2) + k_in_group * 2;
  dst[0] = static_cast<int8_t>((first_lo << 4) | second_lo);
  dst[1] = static_cast<int8_t>((first_hi << 4) | second_hi);
}
|
|
||||||
|
|
||||||
// Converts a W4AFP8 weight tensor into the packed layout used by the GEMM.
//
// CPU tensors: `weight` is read as [experts, M, K] uint8 and repacked by
// weight_convert() into a new [experts, M, K / 2] uint8 tensor.
//
// Non-CPU tensors: `weight` is read as a 2-D int8 tensor whose first dimension
// is original_k / 2 and whose second dimension is original_n; the result has
// the same shape and is filled by weight_permute_interleave_kernelw4afp8,
// launched on the input tensor's stream.
//
// Throws (PD_THROW) when the K extent is not a multiple of 64, which both
// code paths rely on, or when the kernel launch is rejected.
std::vector<paddle::Tensor> W4AFp8GemmWeightConvert(
    const paddle::Tensor& weight) {
  if (weight.place() == paddle::CPUPlace()) {
    const int experts = weight.dims()[0];
    const int M = weight.dims()[1];
    const int K = weight.dims()[2];
    if (K % 64 != 0) {
      // weight_convert() only asserts this, which vanishes in release builds.
      PD_THROW("W4AFp8GemmWeightConvert requires K to be divisible by 64.");
    }
    paddle::Tensor weight_new = paddle::empty(
        {experts, M, K / 2}, paddle::DataType::UINT8, weight.place());
    weight_convert(
        weight.data<uint8_t>(), weight_new.data<uint8_t>(), experts, M, K);
    return {weight_new};
  }

  const int original_k = weight.dims()[0] * 2;
  const int original_n = weight.dims()[1];
  if (original_k % 64 != 0) {
    // The kernel walks K in groups of 64 and would read past the buffer.
    PD_THROW("W4AFp8GemmWeightConvert requires K to be divisible by 64.");
  }
  paddle::Tensor weight_new =
      paddle::empty(weight.shape(), paddle::DataType::INT8, weight.place());

  // One thread handles two input bytes (four 4-bit values).
  const int num_work_items = original_k * original_n / 4;
  if (num_work_items == 0) {
    // A zero-block launch is an invalid configuration; nothing to convert.
    return {weight_new};
  }
  constexpr int block_dim = 256;
  const int grid_size = (num_work_items + block_dim - 1) / block_dim;

  // Launch on the tensor's stream so the conversion is ordered with the
  // producer and consumers of `weight`.
  weight_permute_interleave_kernelw4afp8<<<grid_size,
                                           block_dim,
                                           0,
                                           weight.stream()>>>(
      weight.data<int8_t>(),
      weight_new.data<int8_t>(),
      original_k,
      original_n);
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess) {
    PD_THROW("weight_permute_interleave_kernelw4afp8 launch failed: ",
             cudaGetErrorString(launch_status));
  }
  return {weight_new};
}
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
#include "helper.h"
|
|
||||||
#include "paddle/extension.h"
|
|
||||||
|
|
||||||
// In-place interleave of the two halves of every kPackSize-element pack.
//
// Each thread loads kPackSize consecutive elements starting at
// (global thread index) * kPackSize, produces
//   out[2 * j]     = in[j]
//   out[2 * j + 1] = in[j + kPackSize / 2]      for 0 <= j < kPackSize / 2
// and stores the result back to the same location. Threads whose first
// element index is >= numel do nothing.
//
// Preconditions: numel is a multiple of kPackSize (there is no partial-pack
// handling). AlignedVector / Load / Store are not defined in this file --
// presumably they come from helper.h and may impose alignment requirements on
// input_data; verify there.
//
// The half-pack offset was previously the literal 8, which is only valid for
// kPackSize == 16; the static_assert keeps any other instantiation from
// compiling into an out-of-range or incorrect permutation.
template <typename T, int kPackSize>
__global__ void permute_scale_kernel(T* input_data, const int numel) {
  static_assert(kPackSize == 16,
                "permute_scale_kernel interleaves two 8-element halves and "
                "therefore requires kPackSize == 16");
  constexpr int kHalfPack = kPackSize / 2;

  using LoadT = AlignedVector<T, kPackSize>;
  LoadT input_vec;
  LoadT dst_vec;
  const int load_idx = (blockIdx.x * blockDim.x + threadIdx.x) * kPackSize;
  if (load_idx >= numel) {
    return;
  }
  Load<T, kPackSize>(&input_data[load_idx], &input_vec);

#pragma unroll
  for (int j = 0; j < kHalfPack; ++j) {
    dst_vec[2 * j] = input_vec[j];
    dst_vec[2 * j + 1] = input_vec[j + kHalfPack];
  }

  Store<T, kPackSize>(dst_vec, &input_data[load_idx]);
}
|
|
||||||
|
|
||||||
// Permutes a weight-scale tensor in place on the tensor's stream.
//
// A 2-D tensor is treated as [row, col]; any other rank is treated as a
// single row of dims()[0] elements. Every group of 16 consecutive elements is
// rewritten by permute_scale_kernel. Although `scale` is passed by const
// reference, its device buffer is modified.
//
// Throws (PD_THROW) when col is not divisible by 16 or when the dtype is not
// one of bfloat16 / float16 / float32. Previously an unsupported dtype
// returned silently and left the scale unpermuted.
void W4AFp8GemmScalePermute(const paddle::Tensor& scale) {
  const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1;
  const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0];
  if (col % 16 != 0) {
    PD_THROW("Only supported when col is divisible by 16.");
  }
  const int numel = row * col;
  if (numel == 0) {
    // Nothing to permute; a zero-block launch would be an invalid
    // configuration.
    return;
  }
  constexpr int threads = 128;
  constexpr int kPackSize = 16;
  // One thread per pack of kPackSize elements.
  const int grid_size = (numel / kPackSize + threads - 1) / threads;

  if (scale.dtype() == paddle::DataType::BFLOAT16) {
    permute_scale_kernel<phi::dtype::bfloat16, kPackSize>
        <<<grid_size, threads, 0, scale.stream()>>>(
            const_cast<phi::dtype::bfloat16*>(
                scale.data<phi::dtype::bfloat16>()),
            numel);
  } else if (scale.dtype() == paddle::DataType::FLOAT16) {
    permute_scale_kernel<phi::dtype::float16, kPackSize>
        <<<grid_size, threads, 0, scale.stream()>>>(
            const_cast<phi::dtype::float16*>(scale.data<phi::dtype::float16>()),
            numel);
  } else if (scale.dtype() == paddle::DataType::FLOAT32) {
    permute_scale_kernel<float, kPackSize>
        <<<grid_size, threads, 0, scale.stream()>>>(
            const_cast<float*>(scale.data<float>()), numel);
  } else {
    PD_THROW(
        "W4AFp8GemmScalePermute only supports bfloat16, float16 and float32 "
        "scales.");
  }
}
|
|
||||||
@@ -11,8 +11,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import os
|
|
||||||
import re
|
|
||||||
|
|
||||||
file_dir = "./gpu_ops/w4afp8_gemm/"
|
file_dir = "./gpu_ops/w4afp8_gemm/"
|
||||||
|
|
||||||
@@ -32,12 +30,12 @@ gemm_template_head = """
|
|||||||
#include <cutlass/numeric_types.h>
|
#include <cutlass/numeric_types.h>
|
||||||
"""
|
"""
|
||||||
gemm_template_case = """
|
gemm_template_case = """
|
||||||
void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
|
void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
|
||||||
const cutlass::float_e4m3_t * weight,
|
const cutlass::float_e4m3_t * weight,
|
||||||
const cutlass::float_e4m3_t * input,
|
const cutlass::float_e4m3_t * input,
|
||||||
{cutlass_type} * out,
|
{cutlass_type} * out,
|
||||||
const float *weight_scale,
|
const float *weight_scale,
|
||||||
const float * input_dequant_scale,
|
const float *input_row_sum,
|
||||||
const int64_t *tokens,
|
const int64_t *tokens,
|
||||||
const int64_t max_tokens,
|
const int64_t max_tokens,
|
||||||
cudaStream_t stream);
|
cudaStream_t stream);
|
||||||
@@ -50,22 +48,22 @@ gemm_template_cu_head = """
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
gemm_template_cu_template = """
|
gemm_template_cu_template = """
|
||||||
void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
|
void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
|
||||||
const cutlass::float_e4m3_t * weight,
|
const cutlass::float_e4m3_t * weight,
|
||||||
const cutlass::float_e4m3_t * input,
|
const cutlass::float_e4m3_t * input,
|
||||||
{cutlass_type} * out,
|
{cutlass_type} * out,
|
||||||
const float *weight_scale,
|
const float *weight_scale,
|
||||||
const float * input_dequant_scale,
|
const float *input_row_sum,
|
||||||
const int64_t *tokens,
|
const int64_t *tokens,
|
||||||
const int64_t max_tokens,
|
const int64_t max_tokens,
|
||||||
cudaStream_t stream) {{
|
cudaStream_t stream) {{
|
||||||
|
|
||||||
constexpr static int M = {M};
|
constexpr static int M = {M};
|
||||||
constexpr static int K = {K};
|
constexpr static int K = {K};
|
||||||
constexpr static int EXPERTS = {EXPERTS};
|
constexpr static int Batch = {BATCH};
|
||||||
constexpr static int TokenPackSize = {PADDING};
|
constexpr static int TokenPackSize = {PADDING};
|
||||||
constexpr static int kBlockN = {N};
|
constexpr static int kBlockN = {N};
|
||||||
constexpr static int kGroupSize = {GROUPSIZE};
|
constexpr static int kBlockN_TAIL = {TAILN};
|
||||||
constexpr static int kBlockM = 128;
|
constexpr static int kBlockM = 128;
|
||||||
constexpr static int kBlockK = 128;
|
constexpr static int kBlockK = 128;
|
||||||
constexpr static int kNWarps = 4 + kBlockM / 16;
|
constexpr static int kNWarps = 4 + kBlockM / 16;
|
||||||
@@ -76,24 +74,22 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
|
|||||||
|
|
||||||
using Kernel_traits = Kernel_traits<
|
using Kernel_traits = Kernel_traits<
|
||||||
kBlockM, kBlockN, kBlockK, kNWarps, kStages, kTiles,
|
kBlockM, kBlockN, kBlockK, kNWarps, kStages, kTiles,
|
||||||
M, K, TokenPackSize, kGroupSize, kCluster, cutlass::float_e4m3_t,
|
M, TokenPackSize, kBlockN_TAIL, kCluster, cutlass::float_e4m3_t,
|
||||||
{cutlass_type}>;
|
{cutlass_type}>;
|
||||||
run_gemm<cutlass::float_e4m3_t, {cutlass_type},
|
run_gemm<cutlass::float_e4m3_t, {cutlass_type},
|
||||||
Kernel_traits, M, K, EXPERTS, TokenPackSize, kGroupSize>
|
Kernel_traits, M, K, Batch, TokenPackSize>
|
||||||
(weight, input, out, weight_scale, input_dequant_scale, tokens, max_tokens, stream);
|
(weight, input, out, weight_scale,
|
||||||
|
input_row_sum, tokens, max_tokens, stream);
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# [M, K, Number of experts, token Padding Size, weight K group size]
|
|
||||||
gemm_case = [
|
gemm_case = [
|
||||||
[8192, 3584, 16, 0, 128], # eb45T ffn1
|
[8192, 3584, 8, 0], # eb45T ffn1
|
||||||
[8192, 3584, 16, 512, 128], # eb45T ffn1
|
[8192, 3584, 8, 2048], # eb45T ffn1
|
||||||
[7168, 8192, 16, 0, 128], # eb45T ffn2
|
[7168, 8192, 8, 0], # eb45T ffn2
|
||||||
[7168, 8192, 16, 512, 128], # eb45T ffn2
|
[7168, 8192, 8, 2048], # eb45T ffn2
|
||||||
[1792, 8192, 64, 0, 8192], # eb45t ffn1
|
[1792, 8192, 64, 0], # eb45t ffn1
|
||||||
[8192, 896, 64, 0, 896], # eb45t ffn2
|
[8192, 896, 64, 0], # eb45t ffn2
|
||||||
[1792, 8192, 64, 0, 128], # eb45t ffn1
|
|
||||||
[8192, 896, 64, 0, 128], # eb45t ffn2
|
|
||||||
]
|
]
|
||||||
|
|
||||||
dtype = ["BF16"]
|
dtype = ["BF16"]
|
||||||
@@ -101,19 +97,6 @@ dtype = ["BF16"]
|
|||||||
use_fast_compile = True
|
use_fast_compile = True
|
||||||
n_range = [256] if use_fast_compile else [i for i in range(16, 257, 16)]
|
n_range = [256] if use_fast_compile else [i for i in range(16, 257, 16)]
|
||||||
|
|
||||||
all_cu_files = []
|
|
||||||
for type in dtype:
|
|
||||||
for case in gemm_case:
|
|
||||||
for n in n_range:
|
|
||||||
all_cu_files.append(f"w4afp8_gemm_M{case[0]}_N{n}_G{case[4]}_K{case[1]}_E{case[2]}_P{case[3]}_{type}.cu")
|
|
||||||
|
|
||||||
for file_path, empty_list, file_name_list in os.walk(file_dir):
|
|
||||||
for file_name in file_name_list:
|
|
||||||
if re.match(r"^w4afp8_gemm_M\d+_N\d+_.*\.cu$", file_name):
|
|
||||||
if file_name not in all_cu_files:
|
|
||||||
print("delete w4afp8 kernel file", file_path + file_name)
|
|
||||||
os.remove(file_path + file_name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_cutlass_type(type):
|
def get_cutlass_type(type):
|
||||||
if type == "BF16":
|
if type == "BF16":
|
||||||
@@ -133,16 +116,28 @@ for type in dtype:
|
|||||||
M=case[0],
|
M=case[0],
|
||||||
K=case[1],
|
K=case[1],
|
||||||
N=n,
|
N=n,
|
||||||
EXPERTS=case[2],
|
BATCH=case[2],
|
||||||
TYPE=type,
|
TYPE=type,
|
||||||
PADDING=case[3],
|
PADDING=case[3],
|
||||||
GROUPSIZE=case[4],
|
TAILN=0,
|
||||||
|
cutlass_type=get_cutlass_type(type),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
template_head_file.write(
|
||||||
|
gemm_template_case.format(
|
||||||
|
M=case[0],
|
||||||
|
K=case[1],
|
||||||
|
N=256,
|
||||||
|
BATCH=case[2],
|
||||||
|
TYPE=type,
|
||||||
|
PADDING=case[3],
|
||||||
|
TAILN=n - 16,
|
||||||
cutlass_type=get_cutlass_type(type),
|
cutlass_type=get_cutlass_type(type),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
template_cu_file = open(
|
template_cu_file = open(
|
||||||
f"{file_dir}w4afp8_gemm_M{case[0]}_N{n}_G{case[4]}_K{case[1]}_E{case[2]}_P{case[3]}_{type}.cu", "w"
|
f"{file_dir}w4afp8_gemm_M{case[0]}_N{n}_TAILN{0}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w"
|
||||||
)
|
)
|
||||||
template_cu_file.write(gemm_template_cu_head)
|
template_cu_file.write(gemm_template_cu_head)
|
||||||
template_cu_file.write(
|
template_cu_file.write(
|
||||||
@@ -150,10 +145,29 @@ for type in dtype:
|
|||||||
M=case[0],
|
M=case[0],
|
||||||
K=case[1],
|
K=case[1],
|
||||||
N=n,
|
N=n,
|
||||||
EXPERTS=case[2],
|
BATCH=case[2],
|
||||||
TYPE=type,
|
TYPE=type,
|
||||||
PADDING=case[3],
|
PADDING=case[3],
|
||||||
GROUPSIZE=case[4],
|
TAILN=0,
|
||||||
|
cutlass_type=get_cutlass_type(type),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
template_cu_file.close()
|
||||||
|
|
||||||
|
template_cu_file = open(
|
||||||
|
f"{file_dir}w4afp8_gemm_M{case[0]}_N{256}_TAILN{n-16}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w"
|
||||||
|
)
|
||||||
|
template_cu_file.write(gemm_template_cu_head)
|
||||||
|
template_cu_file.write(
|
||||||
|
gemm_template_cu_template.format(
|
||||||
|
M=case[0],
|
||||||
|
K=case[1],
|
||||||
|
N=256,
|
||||||
|
BATCH=case[2],
|
||||||
|
TYPE=type,
|
||||||
|
PADDING=case[3],
|
||||||
|
TAILN=n - 16,
|
||||||
cutlass_type=get_cutlass_type(type),
|
cutlass_type=get_cutlass_type(type),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -163,8 +177,8 @@ for type in dtype:
|
|||||||
for type in dtype:
|
for type in dtype:
|
||||||
template_head_file.write("\n")
|
template_head_file.write("\n")
|
||||||
template_head_file.write(
|
template_head_file.write(
|
||||||
"""#define GEMM_SWITCH_{TYPE}(_M, _K, _EXPERTS, _TokenPaddingSize, _kBlockN, _GROUPSIZE, ...) {{ \\
|
"""#define GEMM_SWITCH_{TYPE}(_M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN, ...) {{ \\
|
||||||
if (_M == 0 && _K == 0 && _EXPERTS == 0 && _TokenPaddingSize == 0 && _kBlockN == 0 && _GROUPSIZE == 0) {{ \\""".format(
|
if (_M == 0 && _K == 0 && _BATCH == 0 && _TokenPaddingSize == 0 && _kBlockN == 0 && _TailN == 0) {{ \\""".format(
|
||||||
TYPE=type
|
TYPE=type
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -174,16 +188,23 @@ for type in dtype:
|
|||||||
for case in gemm_case:
|
for case in gemm_case:
|
||||||
for n in n_range:
|
for n in n_range:
|
||||||
template_head_file.write(
|
template_head_file.write(
|
||||||
""" }} else if (_M == {M} && _K == {K} && _EXPERTS == {EXPERTS} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _GROUPSIZE == {GROUPSIZE}) {{ \\
|
""" }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
|
||||||
w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
|
w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
|
||||||
M=case[0], K=case[1], N=n, EXPERTS=case[2], TYPE=type, PADDING=case[3], GROUPSIZE=case[4]
|
M=case[0], K=case[1], N=n, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
template_head_file.write("\n")
|
||||||
|
template_head_file.write(
|
||||||
|
""" }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
|
||||||
|
w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
|
||||||
|
M=case[0], K=case[1], N=256, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=n - 16
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
template_head_file.write("\n")
|
template_head_file.write("\n")
|
||||||
|
|
||||||
template_head_file.write(
|
template_head_file.write(
|
||||||
""" } else { \\
|
""" } else { \\
|
||||||
PADDLE_THROW(phi::errors::Unimplemented("W4aFp8 not supported m=%d k=%d experts=%d token_padding_size=%d kBlockN=%d groupsize=%d\\n", _M, _K, _EXPERTS, _TokenPaddingSize, _kBlockN, _GROUPSIZE)); \\
|
PADDLE_THROW(phi::errors::Unimplemented("W4aFp8 not supported m=%d k=%d batch=%d token_padding_size=%d kBlockN=%d tailN=%d\\n", _M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN)); \\
|
||||||
} \\
|
} \\
|
||||||
}"""
|
}"""
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -30,10 +30,7 @@ if current_platform.is_cuda():
|
|||||||
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
|
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from fastdeploy.model_executor.ops.gpu import (
|
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
|
||||||
w4afp8_gemm_scale_permute,
|
|
||||||
w4afp8_gemm_weight_convert,
|
|
||||||
)
|
|
||||||
except:
|
except:
|
||||||
logger.warning("import w4afp8_gemm_scale_permute Failed!")
|
logger.warning("import w4afp8_gemm_scale_permute Failed!")
|
||||||
elif current_platform.is_iluvatar():
|
elif current_platform.is_iluvatar():
|
||||||
@@ -78,7 +75,6 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
expert_idx_per_token: paddle.Tensor,
|
expert_idx_per_token: paddle.Tensor,
|
||||||
used_in_ep_low_latency: bool = False,
|
used_in_ep_low_latency: bool = False,
|
||||||
estimate_total_token_nums: int = -1,
|
estimate_total_token_nums: int = -1,
|
||||||
dequant_scale: paddle.Tensor = None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Paddle Cutlass compute Fused MoE.
|
Paddle Cutlass compute Fused MoE.
|
||||||
@@ -104,7 +100,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
token_nums_per_expert,
|
token_nums_per_expert,
|
||||||
getattr(layer, self.added_weight_attrs[0]),
|
getattr(layer, self.added_weight_attrs[0]),
|
||||||
getattr(layer, self.added_weight_attrs[1]),
|
getattr(layer, self.added_weight_attrs[1]),
|
||||||
dequant_scale,
|
# None,
|
||||||
(layer.up_gate_proj_bias if hasattr(layer, "up_gate_proj_bias") else None),
|
(layer.up_gate_proj_bias if hasattr(layer, "up_gate_proj_bias") else None),
|
||||||
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
|
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
|
||||||
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
|
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
|
||||||
@@ -116,7 +112,6 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
getattr(layer.moe_quant_config, "hadamard_block_size", 128),
|
getattr(layer.moe_quant_config, "hadamard_block_size", 128),
|
||||||
layer.activation,
|
layer.activation,
|
||||||
)
|
)
|
||||||
|
|
||||||
if layer.with_bias:
|
if layer.with_bias:
|
||||||
down_proj_bias_expand = paddle.index_select(layer.down_proj_bias, expert_idx_per_token, axis=0)
|
down_proj_bias_expand = paddle.index_select(layer.down_proj_bias, expert_idx_per_token, axis=0)
|
||||||
ffn_out_without_down_proj_bias = paddle.add(ffn_out_without_down_proj_bias, down_proj_bias_expand)
|
ffn_out_without_down_proj_bias = paddle.add(ffn_out_without_down_proj_bias, down_proj_bias_expand)
|
||||||
@@ -265,7 +260,6 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
expert_idx_per_token,
|
expert_idx_per_token,
|
||||||
dequant_scale,
|
|
||||||
) = moe_expert_dispatch(
|
) = moe_expert_dispatch(
|
||||||
x,
|
x,
|
||||||
gate_out,
|
gate_out,
|
||||||
@@ -286,21 +280,19 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
expert_idx_per_token,
|
expert_idx_per_token,
|
||||||
dequant_scale,
|
|
||||||
) = moe_expert_dispatch(
|
) = moe_expert_dispatch(
|
||||||
x,
|
x,
|
||||||
gate_out,
|
gate_out,
|
||||||
layer.gate_correction_bias,
|
layer.gate_correction_bias,
|
||||||
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
|
(
|
||||||
|
layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None
|
||||||
|
), # if set, permute_input will be int8_t
|
||||||
layer.top_k,
|
layer.top_k,
|
||||||
False,
|
False,
|
||||||
self.moe_quant_type,
|
self.moe_quant_type,
|
||||||
topk_only_mode=False,
|
topk_only_mode=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(layer, "up_gate_proj_in_scale"):
|
|
||||||
dequant_scale = None
|
|
||||||
|
|
||||||
if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
|
if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
|
||||||
# only w4a8 need expert_idx_per_token
|
# only w4a8 need expert_idx_per_token
|
||||||
# Other need not this tensor, so we make it None.
|
# Other need not this tensor, so we make it None.
|
||||||
@@ -308,9 +300,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
else:
|
else:
|
||||||
expert_idx_per_token = expert_idx_per_token.cast("int64")
|
expert_idx_per_token = expert_idx_per_token.cast("int64")
|
||||||
|
|
||||||
ffn_out = self.compute_ffn(
|
ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, expert_idx_per_token)
|
||||||
layer, permute_input, token_nums_per_expert, expert_idx_per_token, False, -1, dequant_scale
|
|
||||||
)
|
|
||||||
|
|
||||||
# reduce 中会做 topk 个 weight 的 norm 和 routed_scaling_factor
|
# reduce 中会做 topk 个 weight 的 norm 和 routed_scaling_factor
|
||||||
fused_moe_out = moe_expert_reduce(
|
fused_moe_out = moe_expert_reduce(
|
||||||
@@ -855,7 +845,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
|||||||
weight_name = self.added_weight_attrs[idx]
|
weight_name = self.added_weight_attrs[idx]
|
||||||
weight_list = []
|
weight_list = []
|
||||||
for i in range(layer.num_local_experts):
|
for i in range(layer.num_local_experts):
|
||||||
quant_weight = w4afp8_gemm_weight_convert(weight_tensor[i])
|
quant_weight, scale = weight_quantize(weight_tensor[i], algo=self.moe_quant_type, arch=80)
|
||||||
weight_list.append(quant_weight)
|
weight_list.append(quant_weight)
|
||||||
quanted_weight = paddle.stack(weight_list, axis=0)
|
quanted_weight = paddle.stack(weight_list, axis=0)
|
||||||
getattr(layer, weight_name).set_value(quanted_weight)
|
getattr(layer, weight_name).set_value(quanted_weight)
|
||||||
@@ -885,29 +875,16 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# in_scales
|
# in_scales
|
||||||
if not layer.moe_quant_config.moe_dynamic_quant:
|
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
|
||||||
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
|
setattr(
|
||||||
setattr(
|
layer,
|
||||||
layer,
|
in_scale_name,
|
||||||
in_scale_name,
|
layer.create_parameter(
|
||||||
layer.create_parameter(
|
shape=[layer.num_local_experts],
|
||||||
shape=[layer.num_local_experts],
|
dtype="float32",
|
||||||
dtype="float32",
|
default_initializer=paddle.nn.initializer.Constant(0),
|
||||||
default_initializer=paddle.nn.initializer.Constant(0),
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
if layer.ep_size > 1:
|
|
||||||
for in_scale_name in ["up_gate_proj_in_scale"]:
|
|
||||||
setattr(
|
|
||||||
layer,
|
|
||||||
in_scale_name,
|
|
||||||
layer.create_parameter(
|
|
||||||
shape=[layer.num_local_experts],
|
|
||||||
dtype="float32",
|
|
||||||
default_initializer=paddle.nn.initializer.Constant(0),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# weight_scales
|
# weight_scales
|
||||||
setattr(
|
setattr(
|
||||||
@@ -965,57 +942,10 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
|||||||
return weight_scale
|
return weight_scale
|
||||||
|
|
||||||
def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], processed_in_scale: paddle.Tensor):
|
def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], processed_in_scale: paddle.Tensor):
|
||||||
if processed_in_scale is not None:
|
processed_weight_scale = (
|
||||||
processed_weight_scale = paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9))
|
paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) / processed_in_scale[:, None]
|
||||||
if len(processed_weight_scale.shape) == 3:
|
)
|
||||||
processed_weight_scale = (
|
processed_weight_scale = _permute_weight_scale(processed_weight_scale)
|
||||||
processed_weight_scale.transpose([0, 2, 1]) / processed_in_scale[:, None, None]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
processed_weight_scale = processed_weight_scale / processed_in_scale[:, None]
|
|
||||||
else:
|
|
||||||
processed_weight_scale = paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9))
|
|
||||||
|
|
||||||
if len(processed_weight_scale.shape) == 3:
|
|
||||||
if name == "up_gate_proj_weight_scale" and processed_weight_scale.shape[-1] * 128 != layer.hidden_size:
|
|
||||||
assert (
|
|
||||||
layer.hidden_size // 128 % processed_weight_scale.shape[-1] == 0
|
|
||||||
), "weight_scale_group_size must be a multiple of 128"
|
|
||||||
# If it is a multiple of 128, repeat to 128
|
|
||||||
processed_weight_scale = processed_weight_scale.repeat_interleave(
|
|
||||||
layer.hidden_size // 128 // processed_weight_scale.shape[-1], axis=-1
|
|
||||||
)
|
|
||||||
elif (
|
|
||||||
name == "down_proj_weight_scale"
|
|
||||||
and processed_weight_scale.shape[-1] * 128 != layer.moe_intermediate_size
|
|
||||||
):
|
|
||||||
assert (
|
|
||||||
layer.moe_intermediate_size // 128 % processed_weight_scale.shape[-1] == 0
|
|
||||||
), "weight_scale_group_size must be a multiple of 128"
|
|
||||||
# If it is a multiple of 128, repeat to 128
|
|
||||||
processed_weight_scale = processed_weight_scale.repeat_interleave(
|
|
||||||
layer.moe_intermediate_size // 128 // processed_weight_scale.shape[-1], axis=-1
|
|
||||||
)
|
|
||||||
|
|
||||||
origin_shape = processed_weight_scale.shape
|
|
||||||
processed_weight_scale = processed_weight_scale.transpose([0, 2, 1])
|
|
||||||
processed_weight_scale = processed_weight_scale.reshape([-1, processed_weight_scale.shape[-1]])
|
|
||||||
processed_weight_scale = _permute_weight_scale(processed_weight_scale)
|
|
||||||
processed_weight_scale = processed_weight_scale.reshape(
|
|
||||||
[origin_shape[0], origin_shape[2], origin_shape[1] // 128, 128]
|
|
||||||
)
|
|
||||||
processed_weight_scale = processed_weight_scale.transpose([0, 2, 1, 3])
|
|
||||||
setattr(
|
|
||||||
layer,
|
|
||||||
name,
|
|
||||||
layer.create_parameter(
|
|
||||||
shape=processed_weight_scale.shape,
|
|
||||||
dtype="float32",
|
|
||||||
default_initializer=paddle.nn.initializer.Constant(0),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
processed_weight_scale = _permute_weight_scale(processed_weight_scale)
|
|
||||||
getattr(layer, name).set_value(processed_weight_scale)
|
getattr(layer, name).set_value(processed_weight_scale)
|
||||||
|
|
||||||
# 1. Init scale containers and maps
|
# 1. Init scale containers and maps
|
||||||
@@ -1062,15 +992,16 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
|||||||
scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
|
scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
|
||||||
scale_weight_map[name].append(scale_tensor)
|
scale_weight_map[name].append(scale_tensor)
|
||||||
|
|
||||||
|
# 3. Process scale tensor and set to layer
|
||||||
|
in_scales = []
|
||||||
|
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
|
||||||
|
in_scales.append(_process_in_scale(in_scale_name, scale_weight_map[in_scale_name]))
|
||||||
|
|
||||||
for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]):
|
for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]):
|
||||||
in_scale_name = weight_scale_name.replace("_weight_scale", "_in_scale")
|
|
||||||
in_scale = None
|
|
||||||
if hasattr(layer, in_scale_name) and in_scale_name in scale_weight_map.keys():
|
|
||||||
in_scale = _process_in_scale(in_scale_name, scale_weight_map[in_scale_name])
|
|
||||||
_process_weight_scale(
|
_process_weight_scale(
|
||||||
weight_scale_name,
|
weight_scale_name,
|
||||||
scale_weight_map[weight_scale_name],
|
scale_weight_map[weight_scale_name],
|
||||||
in_scale,
|
in_scales[i],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -275,7 +275,6 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
expert_idx_per_token,
|
expert_idx_per_token,
|
||||||
dequant_scale,
|
|
||||||
) = moe_expert_dispatch(
|
) = moe_expert_dispatch(
|
||||||
x,
|
x,
|
||||||
gate_out,
|
gate_out,
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ class MixQuantConfig(QuantConfigBase):
|
|||||||
is_permuted: bool = True,
|
is_permuted: bool = True,
|
||||||
is_quantized: bool = False,
|
is_quantized: bool = False,
|
||||||
hadamard_block_size: int = 128,
|
hadamard_block_size: int = 128,
|
||||||
moe_dynamic_quant: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dense_quant_type = dense_quant_type
|
self.dense_quant_type = dense_quant_type
|
||||||
@@ -58,7 +57,6 @@ class MixQuantConfig(QuantConfigBase):
|
|||||||
self.is_checkpoint_bf16 = not is_quantized
|
self.is_checkpoint_bf16 = not is_quantized
|
||||||
self.is_quantized = is_quantized
|
self.is_quantized = is_quantized
|
||||||
self.hadamard_block_size = hadamard_block_size
|
self.hadamard_block_size = hadamard_block_size
|
||||||
self.moe_dynamic_quant = moe_dynamic_quant
|
|
||||||
|
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "mix_quant"
|
return "mix_quant"
|
||||||
@@ -75,7 +73,6 @@ class MixQuantConfig(QuantConfigBase):
|
|||||||
config.get("is_permuted", True),
|
config.get("is_permuted", True),
|
||||||
config.get("is_quantized", False),
|
config.get("is_quantized", False),
|
||||||
config.get("hadamard_block_size", 128),
|
config.get("hadamard_block_size", 128),
|
||||||
config.get("moe_dynamic_quant", False),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||||
|
|||||||
@@ -17,20 +17,16 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
|
|
||||||
from fastdeploy.model_executor.ops.gpu import (
|
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm, w4afp8_gemm_weight_convert
|
||||||
w4afp8_gemm,
|
|
||||||
w4afp8_gemm_scale_permute,
|
|
||||||
w4afp8_gemm_weight_convert,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestW4AFP8GEMM(unittest.TestCase):
|
class TestW4AFP8GEMM(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
paddle.seed(0)
|
paddle.seed(0)
|
||||||
self.tokens_per_group = 1
|
self.tokens_per_group = 256
|
||||||
self.N = 1792
|
self.N = 256
|
||||||
self.K = 8192
|
self.K = 256
|
||||||
self.BATCH = 64
|
self.BATCH = 1
|
||||||
self.TokenPadding = 0
|
self.TokenPadding = 0
|
||||||
|
|
||||||
tokens = [self.tokens_per_group] * self.BATCH
|
tokens = [self.tokens_per_group] * self.BATCH
|
||||||
@@ -42,15 +38,14 @@ class TestW4AFP8GEMM(unittest.TestCase):
|
|||||||
|
|
||||||
self.input_fp8 = paddle.randn([self.all_tokens, self.K], dtype="bfloat16").astype(paddle.float8_e4m3fn)
|
self.input_fp8 = paddle.randn([self.all_tokens, self.K], dtype="bfloat16").astype(paddle.float8_e4m3fn)
|
||||||
self.input_bf16 = self.input_fp8.astype("bfloat16")
|
self.input_bf16 = self.input_fp8.astype("bfloat16")
|
||||||
self.weight = paddle.randn([self.BATCH, self.N, self.K], dtype="bfloat16")
|
self.weight = paddle.randn([self.BATCH, self.N, self.K], dtype="bfloat16") / 10
|
||||||
|
|
||||||
self.weight_scale = 7 / self.weight.abs().max(axis=-1).reshape([self.BATCH, self.N, 1])
|
self.weight_scale = 7 / self.weight.abs().max(axis=-1).reshape([self.BATCH, self.N, 1])
|
||||||
self.weight_quant = (self.weight * self.weight_scale).astype("int")
|
self.weight_quant = (self.weight * self.weight_scale).astype("int") + 7
|
||||||
self.weight_quant = paddle.clip(self.weight_quant, -7, 7)
|
self.weight_quant = paddle.clip(self.weight_quant, 0, 14)
|
||||||
self.weight_quant_naive = self.weight_quant.astype("float32")
|
|
||||||
self.weight_quant = self.weight_quant.astype("bfloat16")
|
self.weight_quant = self.weight_quant.astype("bfloat16")
|
||||||
self.weight_quant = paddle.where(self.weight_quant > 0, self.weight_quant, 8 - self.weight_quant)
|
|
||||||
self.weight_dequant_scale = 1 / self.weight_scale.astype("float32")
|
self.weight_dequant_scale = 1 / self.weight_scale.astype("float32")
|
||||||
|
self.input_row_sum = self.input_bf16.sum(axis=1) * -7 / 512
|
||||||
self.max_tokens = int(self.tokens.max())
|
self.max_tokens = int(self.tokens.max())
|
||||||
|
|
||||||
def w4afp8_gemm_naive(self, input_bf16, weight_quant, tokens, weight_dequant_scale):
|
def w4afp8_gemm_naive(self, input_bf16, weight_quant, tokens, weight_dequant_scale):
|
||||||
@@ -59,7 +54,7 @@ class TestW4AFP8GEMM(unittest.TestCase):
|
|||||||
pre_fix_token = 0
|
pre_fix_token = 0
|
||||||
for i in range(self.BATCH):
|
for i in range(self.BATCH):
|
||||||
input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :]
|
input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :]
|
||||||
weight = weight_quant[i] * weight_dequant_scale[i]
|
weight = (weight_quant[i] - 7.0) * weight_dequant_scale[i]
|
||||||
out_i = paddle.matmul(input, weight.astype("bfloat16"), transpose_y=True)
|
out_i = paddle.matmul(input, weight.astype("bfloat16"), transpose_y=True)
|
||||||
out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i
|
out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i
|
||||||
pre_fix_token += tokens[i]
|
pre_fix_token += tokens[i]
|
||||||
@@ -76,53 +71,37 @@ class TestW4AFP8GEMM(unittest.TestCase):
|
|||||||
weight_scale[b, n + j + 1] = temp[j // 2 + 8]
|
weight_scale[b, n + j + 1] = temp[j // 2 + 8]
|
||||||
return weight_scale
|
return weight_scale
|
||||||
|
|
||||||
def get_per_group_scale(self, processed_weight_scale):
|
|
||||||
processed_weight_scale = processed_weight_scale.repeat_interleave(self.K // 128, axis=-1)
|
|
||||||
origin_shape = processed_weight_scale.shape
|
|
||||||
processed_weight_scale = processed_weight_scale.transpose([0, 2, 1])
|
|
||||||
processed_weight_scale = processed_weight_scale.reshape([-1, processed_weight_scale.shape[-1]])
|
|
||||||
|
|
||||||
processed_weight_scale = w4afp8_gemm_scale_permute(processed_weight_scale)
|
|
||||||
processed_weight_scale = processed_weight_scale.reshape(
|
|
||||||
[origin_shape[0], origin_shape[2], origin_shape[1] // 128, 128]
|
|
||||||
)
|
|
||||||
processed_weight_scale = processed_weight_scale.transpose([0, 2, 1, 3])
|
|
||||||
return processed_weight_scale
|
|
||||||
|
|
||||||
def test_w4afp8_gemm(self):
|
def test_w4afp8_gemm(self):
|
||||||
out_naive = self.w4afp8_gemm_naive(
|
out_naive = self.w4afp8_gemm_naive(self.input_bf16, self.weight_quant, self.tokens, self.weight_dequant_scale)
|
||||||
self.input_bf16, self.weight_quant_naive, self.tokens, self.weight_dequant_scale
|
|
||||||
)
|
|
||||||
|
|
||||||
# weight_dequant_scale = paddle.to_tensor(self.permute_scale(self.weight_dequant_scale) * 512)
|
weight_dequant_scale = paddle.to_tensor(self.permute_scale(self.weight_dequant_scale) * 512)
|
||||||
weight_dequant_scale = self.get_per_group_scale(self.weight_dequant_scale * 512)
|
weight_int4 = w4afp8_gemm_weight_convert(self.weight_quant.astype("uint8").cpu())
|
||||||
weight_int4 = w4afp8_gemm_weight_convert(self.weight_quant.astype("uint8").cpu()).cuda()
|
|
||||||
|
|
||||||
if self.TokenPadding == 0:
|
if self.TokenPadding == 0:
|
||||||
out_cuda = w4afp8_gemm(
|
out_cuda = w4afp8_gemm(
|
||||||
self.input_fp8,
|
self.input_fp8,
|
||||||
weight_int4,
|
weight_int4.cuda(),
|
||||||
self.tokens_prefix_sum,
|
self.tokens_prefix_sum,
|
||||||
|
self.input_row_sum.astype("float32"),
|
||||||
weight_dequant_scale.astype("float32"),
|
weight_dequant_scale.astype("float32"),
|
||||||
None,
|
|
||||||
int(self.TokenPadding),
|
int(self.TokenPadding),
|
||||||
self.all_tokens,
|
self.max_tokens,
|
||||||
True,
|
True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
out_cuda = w4afp8_gemm(
|
out_cuda = w4afp8_gemm(
|
||||||
self.input_fp8,
|
self.input_fp8,
|
||||||
weight_int4,
|
weight_int4.cuda(),
|
||||||
self.tokens,
|
self.tokens,
|
||||||
|
self.input_row_sum.astype("float32"),
|
||||||
weight_dequant_scale.astype("float32"),
|
weight_dequant_scale.astype("float32"),
|
||||||
None,
|
|
||||||
int(self.TokenPadding),
|
int(self.TokenPadding),
|
||||||
self.max_tokens,
|
self.max_tokens,
|
||||||
True,
|
True,
|
||||||
)
|
)
|
||||||
|
|
||||||
gap = (out_cuda - out_naive).abs()
|
gap = (out_cuda - out_naive).abs()
|
||||||
self.assertLess(float(gap.mean()), 0.11)
|
self.assertLess(float(gap.mean()), 0.07)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user