【Hackathon 9th No.86】autogen MoeFastHardamardImplWrapper template_instantiation (#4592)

* autogen MoeFastHardamardImplWrapper template_instantiation

* fix codestyle

* fix codestyle

* add impl cu files
This commit is contained in:
Zhenghai Zhang
2025-10-30 10:28:36 +08:00
committed by GitHub
parent e25c067f70
commit 1712e1351b
15 changed files with 2904 additions and 1862 deletions

View File

@@ -0,0 +1,142 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
#include "multiquery_decoder_attention_kernel.h"
#include "utils.cuh"
// Decode-phase MLA attention dispatcher (host side).
//
// Maps the runtime attention configuration — causal flag, query-heads-per-KV-
// head group size, QK/V head dims, KV-cache block size and the
// "deal each time" tile count — onto compile-time template parameters through
// the nested DISPATCH_* macros, then launches the matching
// MultiQueryDecoderAttention instantiation on `stream`, writing into `out`.
//
// NOTE(review): rope_scale and rope_theta are hard-coded to 0.0 and only
// forwarded to the kernel — confirm the decode path intentionally skips RoPE
// scaling here.
// NOTE(review): token_num, bsz, num_heads, num_stage and num_threads are
// computed but never referenced below — presumably kept for symmetry with
// related kernels; verify before removing.
// NOTE(review): this translation unit carries `#pragma once` (above) yet
// contains definitions and explicit instantiations — verify it is compiled
// directly and never #include'd from more than one place.
template <typename T>
void DecodeMLAAttentionKernel(
    const AppendAttnMetaData &meta_data,
    const paddle::Tensor &q,  // [token_num, num_heads, head_dim]
    const paddle::Tensor &cache_k,
    const paddle::Tensor &cache_v,
    const paddle::optional<paddle::Tensor> &attn_mask,
    const paddle::optional<paddle::Tensor> &shift_bias,
    const paddle::optional<paddle::Tensor> &smooth_weight,
    const paddle::Tensor &seq_lens_q,  // q_seq_len is 1
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    int max_seq_len,
    int max_dec_len,
    float softmax_scale,
    float in_scale,
    bool causal,
    cudaStream_t &stream,
    paddle::Tensor *out) {
  // Runtime shape/config values that drive the compile-time dispatch below.
  const auto token_num = meta_data.token_nums;
  const auto block_size = meta_data.block_size;
  const auto bsz = meta_data.batch_size;
  const auto num_heads = meta_data.q_num_heads;
  // Query heads served per KV head (grouped-query ratio).
  const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
  const auto head_dim_qk = meta_data.head_dims;
  const auto head_dim_v = meta_data.head_dims_v;
  const float rope_scale = 0.0;
  const float rope_theta = 0.0;
  const uint32_t deal_each_time = get_cascade_attention_deal_each_time();
  const uint32_t num_stage = get_cascade_attention_num_stages();
  const uint32_t num_threads = get_cascade_attention_num_threads();
  // Six nested dispatches: each macro converts one runtime value into a
  // constexpr template argument. The literal template arguments 2 and 16 are
  // fixed at this call site; their meaning is defined by
  // MultiQueryDecoderAttention (not visible here — confirm against its
  // declaration before changing them).
  DISPATCH_CAUSAL(
      causal,
      CAUSAL,
      {DISPATCH_MLA_GROUP_SIZE(
          group_size,
          GROUP_SIZE,
          {DISPATCH_MLA_HEAD_DIM(
              head_dim_qk,
              HEAD_DIM_QK,
              {DISPATCH_MLA_HEAD_DIM(
                  head_dim_v,
                  HEAD_DIM_V,
                  {DISPATCH_BLOCK_SIZE(
                      block_size,
                      BLOCK_SIZE,
                      {DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, {
                        MultiQueryDecoderAttention<T,
                                                   GROUP_SIZE,
                                                   HEAD_DIM_QK,
                                                   HEAD_DIM_V,
                                                   BLOCK_SIZE,
                                                   CAUSAL,
                                                   2,
                                                   16,
                                                   DEAL_EACH_TIME>(
                            meta_data,
                            stream,
                            q,
                            cache_k,
                            cache_v,
                            attn_mask,
                            shift_bias,
                            smooth_weight,
                            seq_lens_q,
                            seq_lens_kv,
                            batch_id_per_token,
                            cu_seqlens_q,
                            block_table,
                            max_seq_len,
                            max_dec_len,
                            rope_scale,
                            rope_theta,
                            softmax_scale,
                            in_scale,
                            out);
                      })})})})})});
}
// Explicit instantiations: emit the bfloat16 and float16 variants of
// DecodeMLAAttentionKernel in this translation unit so callers that only see
// the declaration can link against them.
template void DecodeMLAAttentionKernel<paddle::bfloat16>(
    const AppendAttnMetaData &meta_data,
    const paddle::Tensor &q,  // [token_num, num_heads, head_dim]
    const paddle::Tensor &cache_k,
    const paddle::Tensor &cache_v,
    const paddle::optional<paddle::Tensor> &attn_mask,
    const paddle::optional<paddle::Tensor> &shift_bias,
    const paddle::optional<paddle::Tensor> &smooth_weight,
    const paddle::Tensor &seq_lens_q,  // q_seq_len is 1
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    int max_seq_len,
    int max_dec_len,
    float softmax_scale,
    float in_scale,
    bool causal,
    cudaStream_t &stream,
    paddle::Tensor *out);
template void DecodeMLAAttentionKernel<paddle::float16>(
    const AppendAttnMetaData &meta_data,
    const paddle::Tensor &q,  // [token_num, num_heads, head_dim]
    const paddle::Tensor &cache_k,
    const paddle::Tensor &cache_v,
    const paddle::optional<paddle::Tensor> &attn_mask,
    const paddle::optional<paddle::Tensor> &shift_bias,
    const paddle::optional<paddle::Tensor> &smooth_weight,
    const paddle::Tensor &seq_lens_q,  // q_seq_len is 1
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    int max_seq_len,
    int max_dec_len,
    float softmax_scale,
    float in_scale,
    bool causal,
    cudaStream_t &stream,
    paddle::Tensor *out);

View File

@@ -0,0 +1,39 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
#include "utils.cuh"
// Declaration of the decode-phase MLA attention entry point.
//
// The template is defined out of line (with explicit bfloat16/float16
// instantiations) in the corresponding implementation file; this header only
// exposes the signature. Launches asynchronously on `stream`; results are
// written into `out`.
//
// Parameter notes grounded in the definition's usage:
//   meta_data    - shape/config metadata (token counts, head dims, block size)
//   q            - query tensor, [token_num, num_heads, head_dim]
//   cache_k/v    - KV cache tensors addressed through block_table
//   seq_lens_q   - per-request query lengths (decode: q_seq_len is 1)
//   in_scale     - output quantization scale — semantics defined by the
//                  underlying kernel; confirm there
template <typename T>
void DecodeMLAAttentionKernel(
    const AppendAttnMetaData &meta_data,
    const paddle::Tensor &q,  // [token_num, num_heads, head_dim]
    const paddle::Tensor &cache_k,
    const paddle::Tensor &cache_v,
    const paddle::optional<paddle::Tensor> &attn_mask,
    const paddle::optional<paddle::Tensor> &shift_bias,
    const paddle::optional<paddle::Tensor> &smooth_weight,
    const paddle::Tensor &seq_lens_q,  // q_seq_len is 1
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    int max_seq_len,
    int max_dec_len,
    float softmax_scale,
    float in_scale,
    bool causal,
    cudaStream_t &stream,
    paddle::Tensor *out);

View File

@@ -1,105 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
#include "utils.cuh"
#include "multiquery_decoder_attention_impl.cuh"
// Decode-phase MLA attention dispatcher: converts the runtime configuration
// (causal flag, group size, QK/V head dims, KV block size, deal-each-time)
// into compile-time template parameters via the nested DISPATCH_* macros and
// launches the matching MultiQueryDecoderAttention instantiation on `stream`.
// NOTE(review): token_num, bsz, num_heads, num_stage and num_threads are
// computed but unused below; rope_scale/rope_theta are hard-coded to 0.0 and
// only forwarded — confirm RoPE is intentionally disabled on this path.
template <typename T>
void DecodeMLAAttentionKernel(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor &q, // [token_num, num_heads, head_dim]
    const paddle::Tensor &cache_k,
    const paddle::Tensor &cache_v,
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& shift_bias,
    const paddle::optional<paddle::Tensor>& smooth_weight,
    const paddle::Tensor &seq_lens_q, // q_seq_len is 1
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    int max_seq_len,
    int max_dec_len,
    float softmax_scale,
    float in_scale,
    bool causal,
    cudaStream_t &stream,
    paddle::Tensor *out) {
  // Runtime values feeding the compile-time dispatch below.
  const auto token_num = meta_data.token_nums;
  const auto block_size = meta_data.block_size;
  const auto bsz = meta_data.batch_size;
  const auto num_heads = meta_data.q_num_heads;
  // Query heads served per KV head.
  const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
  const auto head_dim_qk = meta_data.head_dims;
  const auto head_dim_v = meta_data.head_dims_v;
  const float rope_scale = 0.0;
  const float rope_theta = 0.0;
  const uint32_t deal_each_time = get_cascade_attention_deal_each_time();
  const uint32_t num_stage = get_cascade_attention_num_stages();
  const uint32_t num_threads = get_cascade_attention_num_threads();
  // Literal template args 2 and 16 are fixed at this call site; their meaning
  // is defined by MultiQueryDecoderAttention — confirm before changing.
  DISPATCH_CAUSAL(causal, CAUSAL,
    {DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE,
      {DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK,
        {DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V,
          {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
            {DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
              {MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, 2, 16, DEAL_EACH_TIME>(
                meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, batch_id_per_token, cu_seqlens_q,
                block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})});
}
// Explicit instantiations for the two supported floating-point dtypes so the
// template definition above can be linked from other translation units.
template void DecodeMLAAttentionKernel<paddle::bfloat16>(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor &q, // [token_num, num_heads, head_dim]
    const paddle::Tensor &cache_k,
    const paddle::Tensor &cache_v,
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& shift_bias,
    const paddle::optional<paddle::Tensor>& smooth_weight,
    const paddle::Tensor &seq_lens_q, // q_seq_len is 1
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    int max_seq_len,
    int max_dec_len,
    float softmax_scale,
    float in_scale,
    bool causal,
    cudaStream_t &stream,
    paddle::Tensor *out);
template void DecodeMLAAttentionKernel<paddle::float16>(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor &q, // [token_num, num_heads, head_dim]
    const paddle::Tensor &cache_k,
    const paddle::Tensor &cache_v,
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& shift_bias,
    const paddle::optional<paddle::Tensor>& smooth_weight,
    const paddle::Tensor &seq_lens_q, // q_seq_len is 1
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    int max_seq_len,
    int max_dec_len,
    float softmax_scale,
    float in_scale,
    bool causal,
    cudaStream_t &stream,
    paddle::Tensor *out);

File diff suppressed because it is too large Load Diff

View File

@@ -1,37 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "helper.h"
// Host-side launcher declaration for the MoE fast-Hadamard pass.
// (The identifier spells "Hardamard"; presumably "Hadamard" — kept as-is
// because it is the established API name in this codebase.)
//
// Template parameters:
//   T    - element type of x_data / shift / smooth
//   OutT - output element type (may differ from T; the quant_* parameters
//          suggest an optionally quantized output — confirm in the definition)
//
// Parameter semantics below are inferred from names only — verify against the
// implementation before relying on them:
//   x_data                    - input activations, token_num x dim
//   expert_idx_per_token      - expert id assigned to each token
//   recv_expert_count         - per-expert received-token counts
//   shift / smooth            - optional per-channel adjustment factors
//   quant_scales, quant_round_type, quant_max_bound, quant_min_bound
//                             - output quantization controls
//   num_max_tokens_per_expert - per-expert capacity (EP low-latency layout)
//   used_in_ep_low_latency    - selects the low-latency expert-parallel path
//   hadamard_block_size       - Hadamard transform block size
//   out                       - output buffer; work is enqueued on `stream`
template <typename T, typename OutT>
void MoeFastHardamardWrapper(const T *x_data,
                             const int64_t *expert_idx_per_token,
                             const int64_t *recv_expert_count,
                             const T *shift,
                             const T *smooth,
                             const float* quant_scales,
                             const int quant_round_type,
                             const float quant_max_bound,
                             const float quant_min_bound,
                             const int64_t token_num,
                             const int64_t dim,
                             const int num_max_tokens_per_expert,
                             bool used_in_ep_low_latency,
                             const int hadamard_block_size,
                             OutT* out,
                             cudaStream_t &stream);

View File

@@ -17,169 +17,190 @@
#include "cutlass/numeric_conversion.h"
#include "group_swiglu_with_masked.h"
#include "helper.h"
#include "moe/fast_hardamard_kernel.h"
#include "moe/fused_moe_helper.h"
template <typename DataT, typename NvType, typename WeightSavedT, cutlass::WintQuantMethod QuantMethod>
template <typename DataT,
typename NvType,
typename WeightSavedT,
cutlass::WintQuantMethod QuantMethod>
void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::Tensor* up_gate_proj_bias,
const paddle::Tensor* up_gate_proj_super_scale,
const paddle::Tensor* down_proj_super_scale,
const paddle::Tensor* up_gate_proj_local_scale,
const paddle::Tensor* up_gate_proj_code_scale,
const paddle::Tensor* up_gate_proj_code_zp,
const paddle::Tensor* down_proj_local_scale,
const paddle::Tensor* down_proj_code_scale,
const paddle::Tensor* down_proj_code_zp,
paddle::Tensor fc1_out,
paddle::Tensor ffn_out,
const int64_t total_rows_in_ll_else_minus1,
const int64_t actual_total_rows,
const int64_t inter_size,
const int64_t hidden_size,
const int num_experts,
bool used_in_ep_low_latency) {
using namespace phi;
using WeightOnlyTraits = cutlass::WintQuantTraits<NvType, QuantMethod>;
using WeightType = typename WeightOnlyTraits::WeightType;
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::Tensor* up_gate_proj_bias,
const paddle::Tensor* up_gate_proj_super_scale,
const paddle::Tensor* down_proj_super_scale,
const paddle::Tensor* up_gate_proj_local_scale,
const paddle::Tensor* up_gate_proj_code_scale,
const paddle::Tensor* up_gate_proj_code_zp,
const paddle::Tensor* down_proj_local_scale,
const paddle::Tensor* down_proj_code_scale,
const paddle::Tensor* down_proj_code_zp,
paddle::Tensor fc1_out,
paddle::Tensor ffn_out,
const int64_t total_rows_in_ll_else_minus1,
const int64_t actual_total_rows,
const int64_t inter_size,
const int64_t hidden_size,
const int num_experts,
bool used_in_ep_low_latency) {
using namespace phi;
using WeightOnlyTraits = cutlass::WintQuantTraits<NvType, QuantMethod>;
using WeightType = typename WeightOnlyTraits::WeightType;
typename WeightOnlyTraits::Arguments up_gate_proj_quant_args;
typename WeightOnlyTraits::Arguments down_proj_quant_args;
if constexpr (QuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) {
up_gate_proj_quant_args.local_scale_ptr = const_cast<uint8_t*>(up_gate_proj_local_scale->data<uint8_t>());
up_gate_proj_quant_args.code_scale_ptr = const_cast<float*>(up_gate_proj_code_scale->data<float>());
up_gate_proj_quant_args.code_zp_ptr = const_cast<float*>(up_gate_proj_code_zp->data<float>());
typename WeightOnlyTraits::Arguments up_gate_proj_quant_args;
typename WeightOnlyTraits::Arguments down_proj_quant_args;
if constexpr (QuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) {
up_gate_proj_quant_args.local_scale_ptr =
const_cast<uint8_t*>(up_gate_proj_local_scale->data<uint8_t>());
up_gate_proj_quant_args.code_scale_ptr =
const_cast<float*>(up_gate_proj_code_scale->data<float>());
up_gate_proj_quant_args.code_zp_ptr =
const_cast<float*>(up_gate_proj_code_zp->data<float>());
down_proj_quant_args.local_scale_ptr = const_cast<uint8_t*>(down_proj_local_scale->data<uint8_t>());
down_proj_quant_args.code_scale_ptr = const_cast<float*>(down_proj_code_scale->data<float>());
down_proj_quant_args.code_zp_ptr = const_cast<float*>(down_proj_code_zp->data<float>());
}
down_proj_quant_args.local_scale_ptr =
const_cast<uint8_t*>(down_proj_local_scale->data<uint8_t>());
down_proj_quant_args.code_scale_ptr =
const_cast<float*>(down_proj_code_scale->data<float>());
down_proj_quant_args.code_zp_ptr =
const_cast<float*>(down_proj_code_zp->data<float>());
}
auto moe_gemm_runner = MoeGemmRunner<NvType, WeightOnlyTraits>();
auto stream = permute_input.stream();
auto moe_gemm_runner = MoeGemmRunner<NvType, WeightOnlyTraits>();
auto stream = permute_input.stream();
moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<DataT>()),
reinterpret_cast<const WeightType*>(up_gate_proj_weight.data<WeightSavedT>()),
reinterpret_cast<const NvType*>(up_gate_proj_super_scale ? up_gate_proj_super_scale->data<DataT>() : nullptr),
reinterpret_cast<const NvType*>(up_gate_proj_bias ? up_gate_proj_bias->data<DataT>() : nullptr),
reinterpret_cast<NvType*>(fc1_out.data<DataT>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
actual_total_rows,
inter_size,
hidden_size,
num_experts,
up_gate_proj_quant_args,
"none",
stream);
moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<DataT>()),
reinterpret_cast<const WeightType*>(
up_gate_proj_weight.data<WeightSavedT>()),
reinterpret_cast<const NvType*>(
up_gate_proj_super_scale ? up_gate_proj_super_scale->data<DataT>()
: nullptr),
reinterpret_cast<const NvType*>(
up_gate_proj_bias ? up_gate_proj_bias->data<DataT>() : nullptr),
reinterpret_cast<NvType*>(fc1_out.data<DataT>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
actual_total_rows,
inter_size,
hidden_size,
num_experts,
up_gate_proj_quant_args,
"none",
stream);
paddle::Tensor act_out;
if (used_in_ep_low_latency) {
act_out = GroupSwigluWithMasked(fc1_out, tokens_expert_prefix_sum);
} else {
act_out = paddle::experimental::swiglu(fc1_out, nullptr);
}
paddle::Tensor act_out;
if (used_in_ep_low_latency) {
act_out = GroupSwigluWithMasked(fc1_out, tokens_expert_prefix_sum);
} else {
act_out = paddle::experimental::swiglu(fc1_out, nullptr);
}
moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out.data<DataT>()),
reinterpret_cast<const WeightType*>(down_proj_weight.data<WeightSavedT>()),
reinterpret_cast<const NvType*>(down_proj_super_scale ? down_proj_super_scale->data<DataT>() : nullptr),
reinterpret_cast<NvType*>(ffn_out.data<DataT>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
actual_total_rows,
hidden_size,
inter_size / 2,
num_experts,
down_proj_quant_args,
stream);
moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out.data<DataT>()),
reinterpret_cast<const WeightType*>(
down_proj_weight.data<WeightSavedT>()),
reinterpret_cast<const NvType*>(down_proj_super_scale
? down_proj_super_scale->data<DataT>()
: nullptr),
reinterpret_cast<NvType*>(ffn_out.data<DataT>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
actual_total_rows,
hidden_size,
inter_size / 2,
num_experts,
down_proj_quant_args,
stream);
}
template <paddle::DataType T>
void MoeFFNWint2Kernel(const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
const paddle::optional<paddle::Tensor>& down_proj_local_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
paddle::Tensor ffn_out,
bool used_in_ep_low_latency) {
using namespace phi;
using data_t = typename PDTraits<T>::data_t;
using NvType = typename PDTraits<T>::DataType;
void MoeFFNWint2Kernel(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
const paddle::optional<paddle::Tensor>& down_proj_local_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
paddle::Tensor ffn_out,
bool used_in_ep_low_latency) {
using namespace phi;
using data_t = typename PDTraits<T>::data_t;
using NvType = typename PDTraits<T>::DataType;
auto place = permute_input.place();
auto place = permute_input.place();
assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2);
assert(up_gate_proj_weight.dims().size() == 3);
assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2);
assert(up_gate_proj_weight.dims().size() == 3);
const int num_experts = up_gate_proj_weight.dims()[0];
const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1];
const int num_experts = up_gate_proj_weight.dims()[0];
const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1];
int inter_dim = up_gate_proj_weight.dims()[1] * up_gate_proj_weight.dims()[2] / hidden_size;
int inter_dim = up_gate_proj_weight.dims()[1] *
up_gate_proj_weight.dims()[2] / hidden_size;
const int64_t inter_size = inter_dim * 4;
const int64_t inter_size = inter_dim * 4;
int num_experts_ = num_experts;
int num_max_tokens_per_expert = 0;
int expanded_active_expert_rows = 0;
int num_experts_ = num_experts;
int num_max_tokens_per_expert = 0;
int expanded_active_expert_rows = 0;
paddle::Tensor fc1_out_tensor;
if (permute_input.dims().size() == 3) {
num_experts_ = permute_input.dims()[0];
assert(num_experts == num_experts_);
paddle::Tensor fc1_out_tensor;
if (permute_input.dims().size() == 3) {
num_experts_ = permute_input.dims()[0];
assert(num_experts == num_experts_);
num_max_tokens_per_expert = permute_input.dims()[1];
expanded_active_expert_rows = num_experts_ * num_max_tokens_per_expert;
fc1_out_tensor = GetEmptyTensor(
{num_experts_, num_max_tokens_per_expert, inter_size}, T, place);
} else {
expanded_active_expert_rows = permute_input.dims()[0];
fc1_out_tensor = GetEmptyTensor(
{expanded_active_expert_rows, inter_size}, T, place);
}
num_max_tokens_per_expert = permute_input.dims()[1];
expanded_active_expert_rows = num_experts_ * num_max_tokens_per_expert;
fc1_out_tensor = GetEmptyTensor(
{num_experts_, num_max_tokens_per_expert, inter_size}, T, place);
} else {
expanded_active_expert_rows = permute_input.dims()[0];
fc1_out_tensor =
GetEmptyTensor({expanded_active_expert_rows, inter_size}, T, place);
}
// This is a trick.
// expanded_active_expert_rows is not needed in variable group gemm.
// but is needed in accommodating deepep low latency mode
const int64_t total_rows_in_ll_else_minus1 = used_in_ep_low_latency ? expanded_active_expert_rows : -1;
// This is a trick.
// expanded_active_expert_rows is not needed in variable group gemm.
// but is needed in accommodating deepep low latency mode
const int64_t total_rows_in_ll_else_minus1 =
used_in_ep_low_latency ? expanded_active_expert_rows : -1;
// When we tune the optimal configuration, we need the actual total_rows.
const int64_t actual_total_rows = expanded_active_expert_rows;
// When we tune the optimal configuration, we need the actual total_rows.
const int64_t actual_total_rows = expanded_active_expert_rows;
WeightOnlyMoeFFNKernel<data_t, NvType, uint8_t, cutlass::WintQuantMethod::kWeightOnlyInt2>(
permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
const_cast<paddle::Tensor*>(up_gate_proj_bias.get_ptr()),
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr()),
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr()),
const_cast<paddle::Tensor*>(up_gate_proj_local_scale.get_ptr()),
const_cast<paddle::Tensor*>(up_gate_proj_code_scale.get_ptr()),
const_cast<paddle::Tensor*>(up_gate_proj_code_zp.get_ptr()),
const_cast<paddle::Tensor*>(down_proj_local_scale.get_ptr()),
const_cast<paddle::Tensor*>(down_proj_code_scale.get_ptr()),
const_cast<paddle::Tensor*>(down_proj_code_zp.get_ptr()),
fc1_out_tensor,
ffn_out,
total_rows_in_ll_else_minus1,
actual_total_rows,
inter_size,
hidden_size,
num_experts,
used_in_ep_low_latency);
WeightOnlyMoeFFNKernel<data_t,
NvType,
uint8_t,
cutlass::WintQuantMethod::kWeightOnlyInt2>(
permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
const_cast<paddle::Tensor*>(up_gate_proj_bias.get_ptr()),
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr()),
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr()),
const_cast<paddle::Tensor*>(up_gate_proj_local_scale.get_ptr()),
const_cast<paddle::Tensor*>(up_gate_proj_code_scale.get_ptr()),
const_cast<paddle::Tensor*>(up_gate_proj_code_zp.get_ptr()),
const_cast<paddle::Tensor*>(down_proj_local_scale.get_ptr()),
const_cast<paddle::Tensor*>(down_proj_code_scale.get_ptr()),
const_cast<paddle::Tensor*>(down_proj_code_zp.get_ptr()),
fc1_out_tensor,
ffn_out,
total_rows_in_ll_else_minus1,
actual_total_rows,
inter_size,
hidden_size,
num_experts,
used_in_ep_low_latency);
}
paddle::Tensor MoeExpertFFNWint2Func(
@@ -197,49 +218,48 @@ paddle::Tensor MoeExpertFFNWint2Func(
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
const bool used_in_ep_low_latency) {
const auto dtype = permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input, dtype);
const auto dtype = permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input, dtype);
switch (dtype) {
case paddle::DataType::BFLOAT16:
MoeFFNWint2Kernel<paddle::DataType::BFLOAT16>(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
up_gate_proj_local_scale,
up_gate_proj_code_scale,
up_gate_proj_code_zp,
down_proj_local_scale,
down_proj_code_scale,
down_proj_code_zp,
ffn_out,
used_in_ep_low_latency);
break;
case paddle::DataType::FLOAT16:
MoeFFNWint2Kernel<paddle::DataType::FLOAT16>(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
up_gate_proj_local_scale,
up_gate_proj_code_scale,
up_gate_proj_code_zp,
down_proj_local_scale,
down_proj_code_scale,
down_proj_code_zp,
ffn_out,
used_in_ep_low_latency);
break;
default:
PD_THROW("Unsupported data type for MoeExpertFFN");
}
return ffn_out;
switch (dtype) {
case paddle::DataType::BFLOAT16:
MoeFFNWint2Kernel<paddle::DataType::BFLOAT16>(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
up_gate_proj_local_scale,
up_gate_proj_code_scale,
up_gate_proj_code_zp,
down_proj_local_scale,
down_proj_code_scale,
down_proj_code_zp,
ffn_out,
used_in_ep_low_latency);
break;
case paddle::DataType::FLOAT16:
MoeFFNWint2Kernel<paddle::DataType::FLOAT16>(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
up_gate_proj_local_scale,
up_gate_proj_code_scale,
up_gate_proj_code_zp,
down_proj_local_scale,
down_proj_code_scale,
down_proj_code_zp,
ffn_out,
used_in_ep_low_latency);
break;
default:
PD_THROW("Unsupported data type for MoeExpertFFN");
}
return ffn_out;
}
std::vector<paddle::Tensor> MoeExpertFFNWint2(
@@ -257,21 +277,20 @@ std::vector<paddle::Tensor> MoeExpertFFNWint2(
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
const bool used_in_ep_low_latency) {
return {MoeExpertFFNWint2Func(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
up_gate_proj_local_scale,
up_gate_proj_code_scale,
up_gate_proj_code_zp,
down_proj_local_scale,
down_proj_code_scale,
down_proj_code_zp,
used_in_ep_low_latency)};
return {MoeExpertFFNWint2Func(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
up_gate_proj_local_scale,
up_gate_proj_code_scale,
up_gate_proj_code_zp,
down_proj_local_scale,
down_proj_code_scale,
down_proj_code_zp,
used_in_ep_low_latency)};
}
std::vector<std::vector<int64_t>> MoeExpertFFNWint2InferShape(
@@ -282,53 +301,53 @@ std::vector<std::vector<int64_t>> MoeExpertFFNWint2InferShape(
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_local_scale_shape,
const paddle::optional<std::vector<int64_t>>&
up_gate_proj_local_scale_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_code_scale_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_code_zp_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_local_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_code_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_code_zp_shape,
const bool used_in_ep_low_latency) {
return {permute_input_shape};
return {permute_input_shape};
}
std::vector<paddle::DataType> MoeExpertFFNWint2InferDtype(
const paddle::DataType &permute_input_dtype,
const paddle::DataType &tokens_expert_prefix_sum_dtype,
const paddle::DataType &up_gate_proj_weight_dtype,
const paddle::DataType &down_proj_weight_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_bias_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_local_scale_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_code_scale_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_code_zp_dtype,
const paddle::optional<paddle::DataType> &down_proj_local_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_code_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_code_zp_dtype,
const paddle::DataType& permute_input_dtype,
const paddle::DataType& tokens_expert_prefix_sum_dtype,
const paddle::DataType& up_gate_proj_weight_dtype,
const paddle::DataType& down_proj_weight_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
const paddle::optional<paddle::DataType>& down_proj_scale_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_local_scale_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_code_scale_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_code_zp_dtype,
const paddle::optional<paddle::DataType>& down_proj_local_scale_dtype,
const paddle::optional<paddle::DataType>& down_proj_code_scale_dtype,
const paddle::optional<paddle::DataType>& down_proj_code_zp_dtype,
const bool used_in_ep_low_latency) {
return {permute_input_dtype};
return {permute_input_dtype};
}
/**
* @brief Weight-Only Quantized Mixture of Experts (MoE) Feed-Forward Network Operator
* @brief Weight-Only Quantized Mixture of Experts (MoE) Feed-Forward Network
* Operator
*
* This operator performs the expert computation in MoE architecture, including:
* 1. First linear transformation (up_gate_proj) with optional quantization
* 2. SwiGLU activation function
* 3. Second linear transformation (down_proj) with optional quantization
*
* Supports multiple quantization methods including weight-only int4/int8 and w4a8 quantization.
* Supports multiple quantization methods including weight-only int4/int8 and
* w4a8 quantization.
*
* Inputs:
* - permute_input: Permuted input tensor organized by expert
* Shape: [total_tokens * top_k, hidden_size]
* dtype: bfloat16/float16 (or int8 for w4a8)
* - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm
* Shape: [num_experts]
* dtype: int64
* - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for
* group_gemm Shape: [num_experts] dtype: int64
* - up_gate_proj_weight: First FFN layer weights
* Shape: [num_experts, inter_size * 2, hidden_size]
* dtype: Same as input (unquantized) or int8 (quantized)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,164 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "helper.h"
#define FULL_MASK 0xffffffff
// 32-byte aggregate built from two uint4 vectors; used as the widest
// vectorized access type (see BytesToType<32> below).
// NOTE(review): the name `uint8` shadows the common "8-bit unsigned" naming
// convention — confirm no clash with other headers in this TU.
struct uint8 {
  uint4 u;
  uint4 v;
};
// Maps a byte count to a plain type of exactly that size, so kernels can
// perform vectorized loads/stores via reinterpret casts. Each specialization
// static_asserts its size, so a wrong mapping fails at compile time.
template <int BYTES>
struct BytesToType {};
template <>
struct BytesToType<32> {
  using Type = uint8;  // two uint4 vectors (struct defined above)
  static_assert(sizeof(Type) == 32);
};
template <>
struct BytesToType<16> {
  using Type = uint4;
  static_assert(sizeof(Type) == 16);
};
template <>
struct BytesToType<8> {
  using Type = uint64_t;
  static_assert(sizeof(Type) == 8);
};
template <>
struct BytesToType<4> {
  using Type = uint32_t;
  static_assert(sizeof(Type) == 4);
};
template <>
struct BytesToType<2> {
  using Type = uint16_t;
  static_assert(sizeof(Type) == 2);
};
template <>
struct BytesToType<1> {
  using Type = uint8_t;
  static_assert(sizeof(Type) == 1);
};
// Maps a Paddle/phi element type to the corresponding native CUDA type
// (half / __nv_bfloat16); identity for types that are already native.
template <typename T>
struct nv_type_traits {
  using type = T;
};
template <>
struct nv_type_traits<phi::dtype::float16> {
  using type = half;
};
template <>
struct nv_type_traits<phi::dtype::bfloat16> {
  using type = __nv_bfloat16;
};
template <>
struct nv_type_traits<int8_t> {
  using type = int8_t;
};
// Maps a runtime logN in {7, 8, 9, 10} onto a constexpr `kLogN` and executes
// __VA_ARGS__ once with it; any other value raises Unimplemented.
// (All commentary is kept outside the macro: a // comment before a trailing
// backslash would break the line continuation.)
#define DISPATCH_SP_logN(logN, kLogN, ...)                              \
  if (logN == 10) {                                                     \
    constexpr int kLogN = 10;                                           \
    __VA_ARGS__                                                         \
  } else if (logN == 9) {                                               \
    constexpr int kLogN = 9;                                            \
    __VA_ARGS__                                                         \
  } else if (logN == 8) {                                               \
    constexpr int kLogN = 8;                                            \
    __VA_ARGS__                                                         \
  } else if (logN == 7) {                                               \
    constexpr int kLogN = 7;                                            \
    __VA_ARGS__                                                         \
  } else {                                                              \
    PADDLE_THROW(                                                       \
        phi::errors::Unimplemented("logN = %d is unsupported!", logN)); \
  }
// Maps a runtime vec_size in {1, 2, 4, 8, 16} onto a constexpr `VEC_SIZE`
// and executes __VA_ARGS__ once with it; any other value (including 0, which
// the caller can produce when hadamard_block_size < kThreads) raises
// Unimplemented.
#define DISPATCH_SP_VS(vec_size, VEC_SIZE, ...)                              \
  if (vec_size == 16) {                                                      \
    constexpr int VEC_SIZE = 16;                                             \
    __VA_ARGS__                                                              \
  } else if (vec_size == 8) {                                                \
    constexpr int VEC_SIZE = 8;                                              \
    __VA_ARGS__                                                              \
  } else if (vec_size == 4) {                                                \
    constexpr int VEC_SIZE = 4;                                              \
    __VA_ARGS__                                                              \
  } else if (vec_size == 2) {                                                \
    constexpr int VEC_SIZE = 2;                                              \
    __VA_ARGS__                                                              \
  } else if (vec_size == 1) {                                                \
    constexpr int VEC_SIZE = 1;                                              \
    __VA_ARGS__                                                              \
  } else {                                                                   \
    PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupported!", \
                                            vec_size));                      \
  }
// Maps a runtime logN in {11, 12, 13, 14} onto a constexpr `kLogN` and
// executes __VA_ARGS__ once with it (larger transforms than DISPATCH_SP_logN
// handles); any other value raises Unimplemented.
#define DISPATCH_logN(logN, kLogN, ...)                           \
  if (logN == 11) {                                               \
    constexpr int kLogN = 11;                                     \
    __VA_ARGS__                                                   \
  } else if (logN == 12) {                                        \
    constexpr int kLogN = 12;                                     \
    __VA_ARGS__                                                   \
  } else if (logN == 13) {                                        \
    constexpr int kLogN = 13;                                     \
    __VA_ARGS__                                                   \
  } else if (logN == 14) {                                        \
    constexpr int kLogN = 14;                                     \
    __VA_ARGS__                                                   \
  } else {                                                        \
    PADDLE_THROW(phi::errors::Unimplemented("unsupported logN")); \
  }
// Implementation entry for the MoE fast Hadamard transform, fully pinned by
// compile-time parameters: kLogN (log2 of the transform length), VecSize
// (elements per vectorized access), kNChunks (chunks per token), kThreads
// (presumably the CUDA block size — the caller always passes 128), and
// UseDiagonalBlockMatrix (diagonal-block variant toggle). quant_scales /
// quant_round_type / quant_*_bound drive optional quantization into OutT.
// Definitions are provided by autogenerated template-instantiation .cu files.
template <typename T,
          typename OutT,
          int kLogN,
          int VecSize,
          int kNChunks,
          int kThreads,
          bool UseDiagonalBlockMatrix>
void MoeFastHardamardImplWrapper(const T *x,
                                 const int64_t *expert_idx_per_token,
                                 const int64_t *recv_expert_count,
                                 const T *shift,
                                 const T *smooth,
                                 const float *quant_scales,
                                 const int quant_round_type,
                                 const float quant_max_bound,
                                 const float quant_min_bound,
                                 const int64_t token_num,
                                 const int64_t dim,
                                 const int num_max_tokens_per_expert,
                                 bool used_in_ep_low_latency,
                                 OutT *out,
                                 cudaStream_t stream);

View File

@@ -0,0 +1,230 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "helper.h"
#include "moe_fast_hardamard_impl_common.h"
// Host-side dispatcher for the MoE fast Hadamard transform: converts the
// runtime `dim` / `hadamard_block_size` into a compile-time
// (kLogN, VecSize, kNChunks) combination and forwards all arguments to the
// matching MoeFastHardamardImplWrapper instantiation, optionally quantizing
// the result into OutT.
template <typename T, typename OutT>
void MoeFastHardamardWrapper(const T *x_data,
                             const int64_t *expert_idx_per_token,
                             const int64_t *recv_expert_count,
                             const T *shift,
                             const T *smooth,
                             const float *quant_scales,
                             const int quant_round_type,
                             const float quant_max_bound,
                             const float quant_min_bound,
                             const int64_t token_num,
                             const int64_t dim,
                             const int num_max_tokens_per_expert,
                             bool used_in_ep_low_latency,
                             const int hadamard_block_size,
                             OutT *out,
                             cudaStream_t &stream) {
  // Local stand-in for a FLAGS_* switch: the diagonal-block path is always
  // taken, so the three branches in the `else` arm below are currently dead
  // code (kept for the non-diagonal, full-dim transform).
  bool FLAGS_hardamard_use_diagonal_block_matrix = true;
  constexpr int kThreads = 128;  // thread count used by every dispatch below
  if (FLAGS_hardamard_use_diagonal_block_matrix) {
    // Transform length is hadamard_block_size, one chunk per token.
    // NOTE(review): assumes hadamard_block_size is a positive multiple of
    // kThreads — otherwise VecSize is 0 (or truncated) and DISPATCH_SP_VS
    // throws Unimplemented.
    const int VecSize = hadamard_block_size / kThreads;
    const int logN = int(ceil(std::log2(kThreads * VecSize)));
    constexpr int kNChunks = 1;
    DISPATCH_SP_VS(VecSize, VEC_SIZE, {DISPATCH_SP_logN(logN, kLogN, {
                     MoeFastHardamardImplWrapper<T,
                                                 OutT,
                                                 kLogN,
                                                 VEC_SIZE,
                                                 kNChunks,
                                                 kThreads,
                                                 true>(
                         x_data,
                         expert_idx_per_token,
                         recv_expert_count,
                         shift,
                         smooth,
                         quant_scales,
                         quant_round_type,
                         quant_max_bound,
                         quant_min_bound,
                         token_num,
                         dim,
                         num_max_tokens_per_expert,
                         used_in_ep_low_latency,
                         out,
                         stream);
                   })});
  } else {
    // `!(x & (x - 1))` tests x for power-of-two. NOTE(review): it is also
    // true for x == 0, so dim < 28 would fall into the first branch —
    // confirm callers guarantee dim >= 28 on this path.
    if (!((dim / 28) & (dim / 28 - 1))) {
      VLOG(1) << "28 * 2^n";
      // dim factors as 28 * 2^n: 28 chunks, power-of-two chunk length.
      const int logN = int(ceil(std::log2(dim / 28)));
      constexpr int kNChunks = 28;
      DISPATCH_SP_logN(logN, kLogN, {
        constexpr int VecSize = (1 << kLogN) / kThreads;
        MoeFastHardamardImplWrapper<T,
                                    OutT,
                                    kLogN,
                                    VecSize,
                                    kNChunks,
                                    kThreads,
                                    false>(x_data,
                                           expert_idx_per_token,
                                           recv_expert_count,
                                           shift,
                                           smooth,
                                           quant_scales,
                                           quant_round_type,
                                           quant_max_bound,
                                           quant_min_bound,
                                           token_num,
                                           dim,
                                           num_max_tokens_per_expert,
                                           used_in_ep_low_latency,
                                           out,
                                           stream);
      });
    } else if (!((dim / 36) & (dim / 36 - 1))) {
      VLOG(1) << "36 * 2^n";
      // dim factors as 36 * 2^n: 36 chunks, power-of-two chunk length.
      const int logN = int(ceil(std::log2(dim / 36)));
      constexpr int kNChunks = 36;
      DISPATCH_SP_logN(logN, kLogN, {
        constexpr int VecSize = (1 << kLogN) / kThreads;
        MoeFastHardamardImplWrapper<T,
                                    OutT,
                                    kLogN,
                                    VecSize,
                                    kNChunks,
                                    kThreads,
                                    false>(x_data,
                                           expert_idx_per_token,
                                           recv_expert_count,
                                           shift,
                                           smooth,
                                           quant_scales,
                                           quant_round_type,
                                           quant_max_bound,
                                           quant_min_bound,
                                           token_num,
                                           dim,
                                           num_max_tokens_per_expert,
                                           used_in_ep_low_latency,
                                           out,
                                           stream);
      });
    } else {
      VLOG(1) << "2^n";
      // Pure power-of-two dim: fixed 16-byte vector width per element type,
      // chunk count derived from the dispatched kLogN.
      const int logN = int(ceil(std::log2(dim)));
      constexpr int VecSize = 16 / sizeof(T);
      DISPATCH_logN(logN, kLogN, {
        constexpr int kNChunks = (1 << kLogN) / (kThreads * VecSize);
        MoeFastHardamardImplWrapper<T,
                                    OutT,
                                    kLogN,
                                    VecSize,
                                    kNChunks,
                                    kThreads,
                                    false>(x_data,
                                           expert_idx_per_token,
                                           recv_expert_count,
                                           shift,
                                           smooth,
                                           quant_scales,
                                           quant_round_type,
                                           quant_max_bound,
                                           quant_min_bound,
                                           token_num,
                                           dim,
                                           num_max_tokens_per_expert,
                                           used_in_ep_low_latency,
                                           out,
                                           stream);
      });
    }
  }
}
// Explicit instantiations: {float16, bfloat16} input with either a
// pass-through (same-type) output or a quantized int8 output.
template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
    const phi::dtype::float16 *x_data,
    const int64_t *expert_idx_per_token,
    const int64_t *recv_expert_count,
    const phi::dtype::float16 *shift,
    const phi::dtype::float16 *smooth,
    const float *quant_scales,
    const int quant_round_type,
    const float quant_max_bound,
    const float quant_min_bound,
    const int64_t token_num,
    const int64_t dim,
    const int num_max_tokens_per_expert,
    bool used_in_ep_low_latency,
    const int hadamard_block_size,
    phi::dtype::float16 *out,
    cudaStream_t &stream);
template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
    const phi::dtype::float16 *x_data,
    const int64_t *expert_idx_per_token,
    const int64_t *recv_expert_count,
    const phi::dtype::float16 *shift,
    const phi::dtype::float16 *smooth,
    const float *quant_scales,
    const int quant_round_type,
    const float quant_max_bound,
    const float quant_min_bound,
    const int64_t token_num,
    const int64_t dim,
    const int num_max_tokens_per_expert,
    bool used_in_ep_low_latency,
    const int hadamard_block_size,
    int8_t *out,
    cudaStream_t &stream);
template void
MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16>(
    const phi::dtype::bfloat16 *x_data,
    const int64_t *expert_idx_per_token,
    const int64_t *recv_expert_count,
    const phi::dtype::bfloat16 *shift,
    const phi::dtype::bfloat16 *smooth,
    const float *quant_scales,
    const int quant_round_type,
    const float quant_max_bound,
    const float quant_min_bound,
    const int64_t token_num,
    const int64_t dim,
    const int num_max_tokens_per_expert,
    bool used_in_ep_low_latency,
    const int hadamard_block_size,
    phi::dtype::bfloat16 *out,
    cudaStream_t &stream);
template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
    const phi::dtype::bfloat16 *x_data,
    const int64_t *expert_idx_per_token,
    const int64_t *recv_expert_count,
    const phi::dtype::bfloat16 *shift,
    const phi::dtype::bfloat16 *smooth,
    const float *quant_scales,
    const int quant_round_type,
    const float quant_max_bound,
    const float quant_min_bound,
    const int64_t token_num,
    const int64_t dim,
    const int num_max_tokens_per_expert,
    bool used_in_ep_low_latency,
    const int hadamard_block_size,
    int8_t *out,
    cudaStream_t &stream);

View File

@@ -0,0 +1,35 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
// Public entry point for the MoE fast Hadamard transform over permuted
// expert activations. T is the input element type, OutT the output element
// type (OutT != T, e.g. int8, implies quantized output driven by
// quant_scales / quant_round_type / quant_*_bound). hadamard_block_size
// selects the per-token transform length; explicit instantiations live in
// autogenerated .cu files.
template <typename T, typename OutT>
void MoeFastHardamardWrapper(const T *x_data,
                             const int64_t *expert_idx_per_token,
                             const int64_t *recv_expert_count,
                             const T *shift,
                             const T *smooth,
                             const float *quant_scales,
                             const int quant_round_type,
                             const float quant_max_bound,
                             const float quant_min_bound,
                             const int64_t token_num,
                             const int64_t dim,
                             const int num_max_tokens_per_expert,
                             bool used_in_ep_low_latency,
                             const int hadamard_block_size,
                             OutT *out,
                             cudaStream_t &stream);

View File

@@ -18,10 +18,10 @@
#include "cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h"
#include "group_swiglu_with_masked.h"
#include "helper.h"
#include "moe/fast_hardamard_kernel.h"
#include "moe/fused_moe_helper.h"
#include "w4afp8_gemm/w4afp8_gemm.h"
#include "moe/moe_fast_hardamard_kernel.h"
#include "swigluoai.h"
#include "w4afp8_gemm/w4afp8_gemm.h"
template <paddle::DataType T>
void MoeFFNKernel(const paddle::Tensor& permute_input,
@@ -39,367 +39,402 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
const int estimate_total_token_nums,
const int hadamard_block_size,
const std::string& activation) {
using namespace phi;
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto quant_mode = cutlass::epilogue::QuantMode::PerChannelQuant;
using namespace phi;
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto quant_mode = cutlass::epilogue::QuantMode::PerChannelQuant;
auto ffn_out_data = ffn_out.data<data_t>();
auto place = permute_input.place();
auto stream = permute_input.stream();
auto ffn_out_data = ffn_out.data<data_t>();
auto place = permute_input.place();
auto stream = permute_input.stream();
auto fp16_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>>();
auto int8_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>>();
auto int4_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt4>>();
auto w4a8_moe_gemm_runner = W4A8MoeGemmRunner<DataType_, int8_t, cutlass::uint4b_t>();
auto fp16_moe_gemm_runner = MoeGemmRunner<
DataType_,
cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>>();
auto int8_moe_gemm_runner = MoeGemmRunner<
DataType_,
cutlass::WintQuantTraits<DataType_,
cutlass::WintQuantMethod::kWeightOnlyInt8>>();
auto int4_moe_gemm_runner = MoeGemmRunner<
DataType_,
cutlass::WintQuantTraits<DataType_,
cutlass::WintQuantMethod::kWeightOnlyInt4>>();
auto w4a8_moe_gemm_runner =
W4A8MoeGemmRunner<DataType_, int8_t, cutlass::uint4b_t>();
assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2);
assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2);
const int num_experts = up_gate_proj_weight.dims()[0];
const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1];
const int num_experts = up_gate_proj_weight.dims()[0];
const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1];
assert(up_gate_proj_weight.dims().size() == 3);
int inter_dim = up_gate_proj_weight.dims()[1] * up_gate_proj_weight.dims()[2] / hidden_size;
assert(up_gate_proj_weight.dims().size() == 3);
int inter_dim = up_gate_proj_weight.dims()[1] *
up_gate_proj_weight.dims()[2] / hidden_size;
constexpr size_t workspace_size = 1 * 1024 * 1024 * 1024; // for nf4 stream-k
Allocator* allocator = paddle::GetAllocator(place);
Allocator::AllocationPtr workspace;
if (quant_method == "weight_only_int4" || quant_method == "w4a8" || quant_method == "w4afp8") {
inter_dim = inter_dim * 2;
}
if (quant_method == "w4a8" || quant_method == "w4afp8") {
workspace = allocator->Allocate(
SizeOf(paddle::DataType::INT8) * workspace_size);
}
constexpr size_t workspace_size = 1 * 1024 * 1024 * 1024; // for nf4 stream-k
Allocator* allocator = paddle::GetAllocator(place);
Allocator::AllocationPtr workspace;
if (quant_method == "weight_only_int4" || quant_method == "w4a8" ||
quant_method == "w4afp8") {
inter_dim = inter_dim * 2;
}
if (quant_method == "w4a8" || quant_method == "w4afp8") {
workspace =
allocator->Allocate(SizeOf(paddle::DataType::INT8) * workspace_size);
}
const int64_t inter_size = inter_dim;
const int64_t inter_size = inter_dim;
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
typedef typename traits_fp8::DataType DataType_fp8;
typedef typename traits_fp8::data_t data_t_fp8;
int num_experts_ = num_experts;
int num_max_tokens_per_expert = 256;
int expanded_active_expert_rows;
paddle::Tensor fc1_out_tensor;
if (permute_input.dims().size() == 3) {
num_experts_ = permute_input.dims()[0];
assert(num_experts == num_experts_);
num_max_tokens_per_expert = permute_input.dims()[1];
expanded_active_expert_rows = num_experts_ * num_max_tokens_per_expert;
fc1_out_tensor = GetEmptyTensor(
{num_experts_, num_max_tokens_per_expert, inter_size}, T, place);
} else {
expanded_active_expert_rows = permute_input.dims()[0];
fc1_out_tensor =
GetEmptyTensor({expanded_active_expert_rows, inter_size}, T, place);
}
auto fc1_out = fc1_out_tensor.data<data_t>();
using NvType = typename traits_::DataType;
auto fc1_expert_biases =
up_gate_proj_bias
? const_cast<paddle::Tensor*>(up_gate_proj_bias.get_ptr())
->data<data_t>()
: nullptr;
// This is a trick.
// expanded_active_expert_rows is not needed in variable group gemm.
// but is needed in accommodating deepep low latency mode
const int64_t total_rows_in_ll_else_minus1 =
used_in_ep_low_latency ? expanded_active_expert_rows : -1;
// When we tune the optimal configuration, we need the actual total_rows.
const int64_t tune_total_rows = expanded_active_expert_rows;
if (quant_method == "weight_only_int8") {
typename cutlass::WintQuantTraits<
DataType_,
cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
int8_moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
reinterpret_cast<const uint8_t*>(up_gate_proj_weight.data<int8_t>()),
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
->data<data_t>()),
reinterpret_cast<const NvType*>(fc1_expert_biases),
reinterpret_cast<NvType*>(fc1_out),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
tune_total_rows,
inter_size,
hidden_size,
num_experts,
quant_args,
"none",
stream);
} else if (quant_method == "weight_only_int4") {
typename cutlass::WintQuantTraits<
DataType_,
cutlass::WintQuantMethod::kWeightOnlyInt4>::Arguments quant_args;
int4_moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
reinterpret_cast<const cutlass::uint4b_t*>(
up_gate_proj_weight.data<int8_t>()),
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
->data<data_t>()),
reinterpret_cast<const NvType*>(fc1_expert_biases),
reinterpret_cast<NvType*>(fc1_out),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
tune_total_rows,
inter_size,
hidden_size,
num_experts,
quant_args,
"none",
stream);
} else if (quant_method == "w4a8") {
w4a8_moe_gemm_runner.moe_gemm(
reinterpret_cast<const int8_t*>(permute_input.data<int8_t>()),
reinterpret_cast<const cutlass::uint4b_t*>(
up_gate_proj_weight.data<int8_t>()),
quant_mode,
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
->data<data_t>()),
nullptr, // up_gate_proj_scale_dyquant
nullptr, // nf4_look_up_table
reinterpret_cast<NvType*>(fc1_out),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
used_in_ep_low_latency ? estimate_total_token_nums : tune_total_rows,
inter_size,
hidden_size,
reinterpret_cast<char*>(workspace->ptr()),
workspace_size,
num_experts,
stream);
} else if (quant_method == "w4afp8") {
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
typedef typename traits_fp8::DataType DataType_fp8;
typedef typename traits_fp8::data_t data_t_fp8;
int num_experts_ = num_experts;
int num_max_tokens_per_expert = 256;
int expanded_active_expert_rows;
Allocator::AllocationPtr ffn1_input_row_sum;
ffn1_input_row_sum =
allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
paddle::Tensor fc1_out_tensor;
compute_row_sum(
permute_input.data<data_t_fp8>(),
expanded_active_expert_rows,
hidden_size,
reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
num_max_tokens_per_expert,
used_in_ep_low_latency,
stream);
if (permute_input.dims().size() == 3) {
num_experts_ = permute_input.dims()[0];
assert(num_experts == num_experts_);
float* row_scale = nullptr;
DisPatchW4AFp8GemmWrapper(
reinterpret_cast<const DataType_fp8*>(permute_input.data<data_t_fp8>()),
reinterpret_cast<const DataType_fp8*>(
up_gate_proj_weight.data<int8_t>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
row_scale,
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
->data<float>(),
reinterpret_cast<NvType*>(fc1_out),
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
used_in_ep_low_latency ? num_max_tokens_per_expert
: permute_input.dims()[0],
num_experts,
inter_size,
hidden_size,
stream);
} else {
typename cutlass::WintQuantTraits<
DataType_,
cutlass::WintQuantMethod::kNone>::Arguments quant_args;
fp16_moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
reinterpret_cast<const NvType*>(up_gate_proj_weight.data<data_t>()),
nullptr,
reinterpret_cast<const NvType*>(fc1_expert_biases),
reinterpret_cast<NvType*>(fc1_out),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
tune_total_rows,
inter_size,
hidden_size,
num_experts,
quant_args,
"none",
stream);
}
num_max_tokens_per_expert = permute_input.dims()[1];
expanded_active_expert_rows = num_experts_ * num_max_tokens_per_expert;
fc1_out_tensor = GetEmptyTensor(
{num_experts_, num_max_tokens_per_expert, inter_size}, T, place);
paddle::Tensor act_out_tensor;
if (used_in_ep_low_latency) {
act_out_tensor =
GroupSwigluWithMasked(fc1_out_tensor, tokens_expert_prefix_sum);
} else {
if (activation == "swigluoai") {
act_out_tensor = SwigluOAI(fc1_out_tensor, 1.702, 7.0, "interleave");
} else {
expanded_active_expert_rows = permute_input.dims()[0];
fc1_out_tensor = GetEmptyTensor(
{expanded_active_expert_rows, inter_size}, T, place);
act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
}
}
auto fc1_out = fc1_out_tensor.data<data_t>();
using NvType = typename traits_::DataType;
auto fc1_expert_biases =
up_gate_proj_bias
? const_cast<paddle::Tensor*>(up_gate_proj_bias.get_ptr())->data<data_t>()
: nullptr;
// This is a trick.
// expanded_active_expert_rows is not needed in variable group gemm.
// but is needed in accommodating deepep low latency mode
const int64_t total_rows_in_ll_else_minus1 = used_in_ep_low_latency ? expanded_active_expert_rows : -1;
// When we tune the optimal configuration, we need the actual total_rows.
const int64_t tune_total_rows = expanded_active_expert_rows;
if (quant_method == "weight_only_int8") {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
int8_moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
reinterpret_cast<const uint8_t*>(up_gate_proj_weight.data<int8_t>()),
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
->data<data_t>()),
reinterpret_cast<const NvType*>(fc1_expert_biases),
reinterpret_cast<NvType*>(fc1_out),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
tune_total_rows,
inter_size,
hidden_size,
num_experts,
quant_args,
"none",
stream);
} else if (quant_method == "weight_only_int4") {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt4>::Arguments quant_args;
int4_moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
reinterpret_cast<const cutlass::uint4b_t*>(
up_gate_proj_weight.data<int8_t>()),
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
->data<data_t>()),
reinterpret_cast<const NvType*>(fc1_expert_biases),
reinterpret_cast<NvType*>(fc1_out),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
tune_total_rows,
inter_size,
hidden_size,
num_experts,
quant_args,
"none",
stream);
} else if (quant_method == "w4a8") {
w4a8_moe_gemm_runner.moe_gemm(
reinterpret_cast<const int8_t *>(permute_input.data<int8_t>()),
reinterpret_cast<const cutlass::uint4b_t *>(
up_gate_proj_weight.data<int8_t>()),
quant_mode,
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
->data<data_t>()),
nullptr, // up_gate_proj_scale_dyquant
nullptr, // nf4_look_up_table
reinterpret_cast<NvType *>(fc1_out),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
used_in_ep_low_latency ? estimate_total_token_nums : tune_total_rows,
inter_size,
hidden_size,
reinterpret_cast<char*>(workspace->ptr()),
workspace_size,
num_experts,
stream);
} else if (quant_method == "w4afp8") {
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
typedef typename traits_fp8::DataType DataType_fp8;
typedef typename traits_fp8::data_t data_t_fp8;
Allocator::AllocationPtr ffn1_input_row_sum;
ffn1_input_row_sum = allocator->Allocate(
sizeof(float) * expanded_active_expert_rows);
compute_row_sum(
permute_input.data<data_t_fp8>(),
expanded_active_expert_rows,
hidden_size,
reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
num_max_tokens_per_expert,
used_in_ep_low_latency,
stream);
float* row_scale = nullptr;
DisPatchW4AFp8GemmWrapper(
reinterpret_cast<const DataType_fp8 *>(permute_input.data<data_t_fp8>()),
reinterpret_cast<const DataType_fp8 *>(up_gate_proj_weight.data<int8_t>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
row_scale,
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
->data<float>(),
reinterpret_cast<NvType *>(fc1_out),
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
used_in_ep_low_latency ? num_max_tokens_per_expert : permute_input.dims()[0],
num_experts,
inter_size,
hidden_size,
stream);
} else {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
fp16_moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
reinterpret_cast<const NvType*>(up_gate_proj_weight.data<data_t>()),
nullptr,
reinterpret_cast<const NvType*>(fc1_expert_biases),
reinterpret_cast<NvType*>(fc1_out),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
tune_total_rows,
inter_size,
hidden_size,
num_experts,
quant_args,
"none",
stream);
}
paddle::Tensor act_out_tensor;
if (used_in_ep_low_latency) {
act_out_tensor = GroupSwigluWithMasked(fc1_out_tensor, tokens_expert_prefix_sum);
} else {
if (activation == "swigluoai") {
act_out_tensor = SwigluOAI(fc1_out_tensor, 1.702, 7.0, "interleave");
} else {
act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
}
}
auto act_out = act_out_tensor.data<data_t>();
if (quant_method == "weight_only_int8") {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
int8_moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out),
reinterpret_cast<const uint8_t*>(down_proj_weight.data<int8_t>()),
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
->data<data_t>()),
reinterpret_cast<NvType*>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
tune_total_rows,
hidden_size,
inter_size / 2,
num_experts,
quant_args,
stream);
} else if (quant_method == "weight_only_int4") {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt4>::Arguments quant_args;
int4_moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out),
reinterpret_cast<const cutlass::uint4b_t*>(
down_proj_weight.data<int8_t>()),
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
->data<data_t>()),
reinterpret_cast<NvType*>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
tune_total_rows,
hidden_size,
inter_size / 2,
num_experts,
quant_args,
stream);
} else if (quant_method == "w4a8") {
data_t *down_proj_shift = nullptr;
data_t *down_proj_smooth = nullptr;
Allocator::AllocationPtr int8_act_out;
int8_act_out = allocator->Allocate(
SizeOf(paddle::DataType::INT8) * act_out_tensor.numel());
MoeFastHardamardWrapper<data_t, int8_t>(
act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() : nullptr,
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
down_proj_shift, // down_proj_shift->data<T>(),
down_proj_smooth, // down_proj_smooth->data<T>(),
down_proj_in_scale ? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())->data<float>() : nullptr,
1,
127.0,
-127.0,
expanded_active_expert_rows,
inter_size / 2,
num_max_tokens_per_expert,
used_in_ep_low_latency,
hadamard_block_size,
reinterpret_cast<int8_t *>(int8_act_out->ptr()),
stream
);
w4a8_moe_gemm_runner.moe_gemm(
reinterpret_cast<int8_t *>(int8_act_out->ptr()),
reinterpret_cast<const cutlass::uint4b_t *>(
down_proj_weight.data<int8_t>()),
quant_mode,
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
->data<data_t>()),
nullptr, // down_proj_scale_dyquant
nullptr, // reinterpret_cast<const int32_t*>(d_nf4_look_up_table), // nf4_look_up_table
reinterpret_cast<NvType *>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
used_in_ep_low_latency ? estimate_total_token_nums : tune_total_rows,
hidden_size,
inter_size / 2,
reinterpret_cast<char*>(workspace->ptr()),
workspace_size,
num_experts,
stream);
} else if (quant_method == "w4afp8") {
data_t *ffn2_shift = nullptr;
data_t *ffn2_smooth = nullptr;
float* row_scale = nullptr;
Allocator::AllocationPtr fp8_act_out;
fp8_act_out = allocator->Allocate(
SizeOf(paddle::DataType::INT8) * act_out_tensor.numel());
Allocator::AllocationPtr ffn2_input_row_sum;
ffn2_input_row_sum = allocator->Allocate(
sizeof(float) * expanded_active_expert_rows);
// note(yuanxiaolan): optimize this
MoeFastHardamardWrapper<data_t, data_t>(
act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() : nullptr,
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
ffn2_shift, // ffn2_shift->data<T>(),
ffn2_smooth, // ffn2_smooth->data<T>(),
nullptr,
1,
448.0f,
-448.0f,
expanded_active_expert_rows,
inter_size / 2,
num_max_tokens_per_expert,
used_in_ep_low_latency,
hadamard_block_size,
act_out_tensor.data<data_t>(),
stream
);
quantize_moe_input<data_t, data_t_fp8>(act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() : nullptr,
down_proj_in_scale ? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())->data<float>() : nullptr,
448.0f,
-448.0f,
expanded_active_expert_rows,
inter_size / 2,
reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
num_max_tokens_per_expert,
used_in_ep_low_latency,
reinterpret_cast<data_t_fp8 *>(fp8_act_out->ptr()),
stream
);
DisPatchW4AFp8GemmWrapper(
reinterpret_cast<const DataType_fp8 *>(fp8_act_out->ptr()),
reinterpret_cast<const DataType_fp8 *>(down_proj_weight.data<int8_t>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
row_scale,
auto act_out = act_out_tensor.data<data_t>();
if (quant_method == "weight_only_int8") {
typename cutlass::WintQuantTraits<
DataType_,
cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
int8_moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out),
reinterpret_cast<const uint8_t*>(down_proj_weight.data<int8_t>()),
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
->data<float>(),
reinterpret_cast<NvType*>(ffn_out_data),
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
used_in_ep_low_latency ? num_max_tokens_per_expert : act_out_tensor.dims()[0],
num_experts,
hidden_size,
inter_size / 2,
stream);
} else {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
fp16_moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out),
reinterpret_cast<const NvType*>(down_proj_weight.data<data_t>()),
nullptr,
reinterpret_cast<NvType*>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
tune_total_rows,
hidden_size,
inter_size / 2,
num_experts,
quant_args,
stream);
}
->data<data_t>()),
reinterpret_cast<NvType*>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
tune_total_rows,
hidden_size,
inter_size / 2,
num_experts,
quant_args,
stream);
} else if (quant_method == "weight_only_int4") {
typename cutlass::WintQuantTraits<
DataType_,
cutlass::WintQuantMethod::kWeightOnlyInt4>::Arguments quant_args;
int4_moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out),
reinterpret_cast<const cutlass::uint4b_t*>(
down_proj_weight.data<int8_t>()),
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
->data<data_t>()),
reinterpret_cast<NvType*>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
tune_total_rows,
hidden_size,
inter_size / 2,
num_experts,
quant_args,
stream);
} else if (quant_method == "w4a8") {
data_t* down_proj_shift = nullptr;
data_t* down_proj_smooth = nullptr;
Allocator::AllocationPtr int8_act_out;
int8_act_out = allocator->Allocate(SizeOf(paddle::DataType::INT8) *
act_out_tensor.numel());
MoeFastHardamardWrapper<data_t, int8_t>(
act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
: nullptr,
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
down_proj_shift, // down_proj_shift->data<T>(),
down_proj_smooth, // down_proj_smooth->data<T>(),
down_proj_in_scale
? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())
->data<float>()
: nullptr,
1,
127.0,
-127.0,
expanded_active_expert_rows,
inter_size / 2,
num_max_tokens_per_expert,
used_in_ep_low_latency,
hadamard_block_size,
reinterpret_cast<int8_t*>(int8_act_out->ptr()),
stream);
w4a8_moe_gemm_runner.moe_gemm(
reinterpret_cast<int8_t*>(int8_act_out->ptr()),
reinterpret_cast<const cutlass::uint4b_t*>(
down_proj_weight.data<int8_t>()),
quant_mode,
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
->data<data_t>()),
nullptr, // down_proj_scale_dyquant
nullptr, // reinterpret_cast<const int32_t*>(d_nf4_look_up_table), //
// nf4_look_up_table
reinterpret_cast<NvType*>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
used_in_ep_low_latency ? estimate_total_token_nums : tune_total_rows,
hidden_size,
inter_size / 2,
reinterpret_cast<char*>(workspace->ptr()),
workspace_size,
num_experts,
stream);
} else if (quant_method == "w4afp8") {
data_t* ffn2_shift = nullptr;
data_t* ffn2_smooth = nullptr;
float* row_scale = nullptr;
Allocator::AllocationPtr fp8_act_out;
fp8_act_out = allocator->Allocate(SizeOf(paddle::DataType::INT8) *
act_out_tensor.numel());
Allocator::AllocationPtr ffn2_input_row_sum;
ffn2_input_row_sum =
allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
// note(yuanxiaolan): optimize this
MoeFastHardamardWrapper<data_t, data_t>(
act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
: nullptr,
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
ffn2_shift, // ffn2_shift->data<T>(),
ffn2_smooth, // ffn2_smooth->data<T>(),
nullptr,
1,
448.0f,
-448.0f,
expanded_active_expert_rows,
inter_size / 2,
num_max_tokens_per_expert,
used_in_ep_low_latency,
hadamard_block_size,
act_out_tensor.data<data_t>(),
stream);
quantize_moe_input<data_t, data_t_fp8>(
act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
: nullptr,
down_proj_in_scale
? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())
->data<float>()
: nullptr,
448.0f,
-448.0f,
expanded_active_expert_rows,
inter_size / 2,
reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
num_max_tokens_per_expert,
used_in_ep_low_latency,
reinterpret_cast<data_t_fp8*>(fp8_act_out->ptr()),
stream);
DisPatchW4AFp8GemmWrapper(
reinterpret_cast<const DataType_fp8*>(fp8_act_out->ptr()),
reinterpret_cast<const DataType_fp8*>(down_proj_weight.data<int8_t>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
row_scale,
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())->data<float>(),
reinterpret_cast<NvType*>(ffn_out_data),
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
used_in_ep_low_latency ? num_max_tokens_per_expert
: act_out_tensor.dims()[0],
num_experts,
hidden_size,
inter_size / 2,
stream);
} else {
typename cutlass::WintQuantTraits<
DataType_,
cutlass::WintQuantMethod::kNone>::Arguments quant_args;
fp16_moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out),
reinterpret_cast<const NvType*>(down_proj_weight.data<data_t>()),
nullptr,
reinterpret_cast<NvType*>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
tune_total_rows,
hidden_size,
inter_size / 2,
num_experts,
quant_args,
stream);
}
}
paddle::Tensor MoeExpertFFNFunc(
@@ -414,55 +449,56 @@ paddle::Tensor MoeExpertFFNFunc(
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method,
const bool used_in_ep_low_latency,
const int estimate_total_token_nums, const int hadamard_block_size,
const int estimate_total_token_nums,
const int hadamard_block_size,
const std::string& activation) {
const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() :
(quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input, t_type);
if(permute_input.numel() == 0){
return ffn_out;
}
switch (t_type) {
case paddle::DataType::BFLOAT16:
MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
down_proj_in_scale,
expert_idx_per_token,
quant_method,
ffn_out,
used_in_ep_low_latency,
estimate_total_token_nums,
hadamard_block_size,
activation);
break;
case paddle::DataType::FLOAT16:
MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
down_proj_in_scale,
expert_idx_per_token,
quant_method,
ffn_out,
used_in_ep_low_latency,
estimate_total_token_nums,
hadamard_block_size,
activation);
break;
default:
PD_THROW("Unsupported data type for MoeExpertFFN");
}
const auto t_type = (quant_method == "w4a8")
? up_gate_proj_scale.get().dtype()
: (quant_method == "w4afp8") ? paddle::DataType::BFLOAT16
: permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input, t_type);
if (permute_input.numel() == 0) {
return ffn_out;
}
switch (t_type) {
case paddle::DataType::BFLOAT16:
MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
down_proj_in_scale,
expert_idx_per_token,
quant_method,
ffn_out,
used_in_ep_low_latency,
estimate_total_token_nums,
hadamard_block_size,
activation);
break;
case paddle::DataType::FLOAT16:
MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
down_proj_in_scale,
expert_idx_per_token,
quant_method,
ffn_out,
used_in_ep_low_latency,
estimate_total_token_nums,
hadamard_block_size,
activation);
break;
default:
PD_THROW("Unsupported data type for MoeExpertFFN");
}
return ffn_out;
}
std::vector<paddle::Tensor> MoeExpertFFN(
@@ -475,24 +511,25 @@ std::vector<paddle::Tensor> MoeExpertFFN(
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency,
const std::string& quant_method,
const bool used_in_ep_low_latency,
const int estimate_total_token_nums,
const int hadamard_block_size,
const std::string& activation) {
return {MoeExpertFFNFunc(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
down_proj_in_scale,
expert_idx_per_token,
quant_method,
used_in_ep_low_latency,
estimate_total_token_nums,
hadamard_block_size,
activation)};
return {MoeExpertFFNFunc(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
down_proj_in_scale,
expert_idx_per_token,
quant_method,
used_in_ep_low_latency,
estimate_total_token_nums,
hadamard_block_size,
activation)};
}
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
@@ -510,21 +547,23 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
const int estimate_total_token_nums,
const int hadamard_block_size,
const std::string& activation) {
return {permute_input_shape};
return {permute_input_shape};
}
std::vector<paddle::DataType> MoeExpertFFNInferDtype(
const paddle::DataType &permute_input_dtype,
const paddle::DataType &tokens_expert_prefix_sum_dtype,
const paddle::DataType &up_gate_proj_weight_dtype,
const paddle::DataType &down_proj_weight_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_bias_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype,
const std::string &quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums, const int hadamard_block_size,
const std::string &activation) {
const paddle::DataType& permute_input_dtype,
const paddle::DataType& tokens_expert_prefix_sum_dtype,
const paddle::DataType& up_gate_proj_weight_dtype,
const paddle::DataType& down_proj_weight_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
const paddle::optional<paddle::DataType>& down_proj_scale_dtype,
const paddle::optional<paddle::DataType>& down_proj_in_scale_dtype,
const std::string& quant_method,
const bool used_in_ep_low_latency,
const int estimate_total_token_nums,
const int hadamard_block_size,
const std::string& activation) {
if (quant_method == "w4a8" || quant_method == "w4afp8") {
return {up_gate_proj_scale_dtype.get()};
} else {
@@ -540,15 +579,15 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
* 2. SwiGLU activation function
* 3. Second linear transformation (down_proj) with optional quantization
*
* Supports multiple quantization methods including weight-only int4/int8 and w4a8 quantization.
* Supports multiple quantization methods including weight-only int4/int8 and
* w4a8 quantization.
*
* Inputs:
* - permute_input: Permuted input tensor organized by expert
* Shape: [total_tokens * top_k, hidden_size]
* dtype: bfloat16/float16 (or int8 for w4a8)
* - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm
* Shape: [num_experts]
* dtype: int64
* - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for
* group_gemm Shape: [num_experts] dtype: int64
* - up_gate_proj_weight: First FFN layer weights
* Shape: [num_experts, inter_size * 2, hidden_size]
* dtype: Same as input (unquantized) or int8 (quantized)
@@ -564,8 +603,8 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
* - down_proj_scale: Quantization scales for second FFN layer
* Shape: [num_experts, hidden_size]
* dtype: Same as input
* - down_proj_in_scale: Optional input scales for second FFN layer (w4a8 only)
* dtype: float32
* - down_proj_in_scale: Optional input scales for second FFN layer (w4a8
* only) dtype: float32
* - expert_idx_per_token: Optional expert indices per token (w4a8 only)
* Shape: [total_tokens]
* dtype: int64
@@ -577,7 +616,8 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
*
* Attributes:
* - quant_method: Quantization method to use
* Options: "none", "weight_only_int4", "weight_only_int8", "w4a8"
* Options: "none", "weight_only_int4", "weight_only_int8",
* "w4a8"
* - used_in_ep_low_latency: Whether running in low latency mode
* Affects activation function implementation
* - estimate_total_token_nums: estimate total token numbers
@@ -598,7 +638,11 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
paddle::Optional("down_proj_in_scale"),
paddle::Optional("expert_idx_per_token")})
.Outputs({"output_tensor"})
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int", "hadamard_block_size:int", "activation:std::string"})
.Attrs({"quant_method:std::string",
"used_in_ep_low_latency:bool",
"estimate_total_token_nums:int",
"hadamard_block_size:int",
"activation:std::string"})
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));

View File

@@ -0,0 +1,26 @@
{
"moe_fast_hardamard_impl": {
"name": "moe_fast_hardamard_impl",
"function_name": "MoeFastHardamardImplWrapper",
"impl_file": "moe_fast_hardamard_impl.cuh",
"template_params": [
"T",
"OutT",
"kLogN",
"VecSize",
"kNChunks",
"kThreads",
"UseDiagonalBlockMatrix"
],
"dispatch_params": {},
"data_types": [
["phi::dtype::float16", "phi::dtype::float16", "float16_float16"],
["phi::dtype::float16", "int8_t", "float16_int8"],
["phi::dtype::bfloat16", "phi::dtype::bfloat16", "bfloat16_bfloat16"],
["phi::dtype::bfloat16", "int8_t", "bfloat16_int8"]
],
"max_instances_per_file": 16,
"file_prefix": "moe_fast_hardamard_impl_",
"function_signature": "template void {function_name}{template_args}(\n const T *x,\n const int64_t *expert_idx_per_token,\n const int64_t *recv_expert_count,\n const T *shift,\n const T *smooth,\n const float* quant_scales,\n const int quant_round_type,\n const float quant_max_bound,\n const float quant_min_bound,\n const int64_t token_num,\n const int64_t dim,\n const int num_max_tokens_per_expert,\n bool used_in_ep_low_latency,\n OutT* out,\n cudaStream_t stream);\n\n"
}
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "append_attn/multi_head_latent_attention_kernel.h"
#include "append_attn/decoder_mla_attention_kernel.h"
#include "helper.h"
#include "mla_attn/batch_mla_with_paged_kv_cache.h"
@@ -66,10 +66,12 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
// int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
//
const bool mla_use_tensorcore = true; //get_mla_use_tensorcore();
const bool mla_use_tensorcore = true; // get_mla_use_tensorcore();
auto sm_version = GetSMVersion();
if ((speculate_decoder || mla_use_tensorcore) && sm_version < 90) {
PD_THROW("Please use speculate_decoder=0 and FLAGS_mla_use_tensorcore=0 when sm < 90.");
PD_THROW(
"Please use speculate_decoder=0 and FLAGS_mla_use_tensorcore=0 when sm "
"< 90.");
}
auto main_stream = query.stream();

View File

@@ -381,7 +381,9 @@ elif paddle.is_compiled_with_cuda():
if cc >= 80:
# append_attention
os.system("python gpu_ops/append_attn/autogen_template_instantiation.py")
os.system(
"python utils/auto_gen_template_instantiation.py --config gpu_ops/append_attn/template_config.json --output gpu_ops/append_attn/template_instantiation/autogen"
)
sources += ["gpu_ops/append_attention.cu"]
sources += find_end_files("gpu_ops/append_attn", ".cu")
# mla
@@ -394,6 +396,9 @@ elif paddle.is_compiled_with_cuda():
nvcc_compile_args += ["-DENABLE_BF16"]
# moe
os.system("python gpu_ops/moe/moe_wna16_marlin_utils/generate_kernels.py")
os.system(
"python utils/auto_gen_template_instantiation.py --config gpu_ops/moe/template_config.json --output gpu_ops/moe/template_instantiation/autogen"
)
sources += find_end_files("gpu_ops/cutlass_kernels/moe_gemm/", ".cu")
sources += find_end_files("gpu_ops/cutlass_kernels/w4a8_moe/", ".cu")
sources += find_end_files("gpu_ops/moe/", ".cu")

View File

@@ -15,6 +15,7 @@
import argparse
import json
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
@@ -65,6 +66,10 @@ class UniversalTemplateInstantiator:
f"Configuration '{config.name}' has T or OutT in template_params but no data_types configured"
)
# Skip validation for special handled functions
if config.name == "moe_fast_hardamard_impl":
return
special_params = {"T", "OutT", "NUM_WARP_Q"}
for param_name in config.template_params:
if param_name not in special_params and param_name not in config.dispatch_params:
@@ -112,10 +117,20 @@ class UniversalTemplateInstantiator:
return f"<{', '.join(template_args_parts)}>"
def _generate_function_signature(self, config: TemplateConfig, template_args: str) -> str:
def _generate_function_signature(
self, config: TemplateConfig, template_args: str, t_in: str = "", t_out: str = ""
) -> str:
"""Generate function signature."""
if config.function_signature:
return config.function_signature.format(function_name=config.function_name, template_args=template_args)
signature = config.function_signature.format(
function_name=config.function_name, template_args=template_args
)
# Replace T and OutT with actual types if provided
if t_in:
signature = signature.replace("const T *", f"const {t_in} *")
if t_out:
signature = signature.replace("OutT*", f"{t_out}*")
return signature
else:
raise ValueError(f"Function signature not found for {config.name}")
@@ -133,25 +148,73 @@ class UniversalTemplateInstantiator:
) -> str:
"""Generate template instantiation."""
template_args = self._build_template_args(config, t_in, t_out, params)
return self._generate_function_signature(config, template_args)
return self._generate_function_signature(config, template_args, t_in, t_out)
def _clean_output_directory(self, output_dir: str):
"""Clean output directory before generating new files."""
output_path = Path(output_dir)
if output_path.exists():
shutil.rmtree(output_path)
output_path.mkdir(parents=True, exist_ok=True)
def generate_combinations_for_type(self, config: TemplateConfig, t_in: str, t_out: str) -> List[Dict[str, Any]]:
"""Generate parameter combinations for specific type."""
combinations = []
def _generate_recursive(
params_dict: Dict[str, List[Any]], current_params: Dict[str, Any], param_names: List[str]
):
if not param_names:
combinations.append(current_params.copy())
return
if config.name == "moe_fast_hardamard_impl":
combinations = self._generate_moe_hardamard_combinations(config, t_in, t_out)
else:
param_name = param_names[0]
for value in params_dict[param_name]:
current_params[param_name] = value
_generate_recursive(params_dict, current_params, param_names[1:])
def _generate_recursive(
params_dict: Dict[str, List[Any]], current_params: Dict[str, Any], param_names: List[str]
):
if not param_names:
combinations.append(current_params.copy())
return
param_name = param_names[0]
for value in params_dict[param_name]:
current_params[param_name] = value
_generate_recursive(params_dict, current_params, param_names[1:])
_generate_recursive(config.dispatch_params, {}, list(config.dispatch_params.keys()))
return combinations
def _generate_moe_hardamard_combinations(
self, config: TemplateConfig, t_in: str, t_out: str
) -> List[Dict[str, Any]]:
"""Generate combinations for MoeFastHardamardImplWrapper based on code logic."""
combinations = []
for vec_size in [1, 2, 4, 8, 16]:
for log_n in [7, 8, 9, 10]:
combinations.append(
{"kLogN": log_n, "VecSize": vec_size, "kNChunks": 1, "kThreads": 128, "UseDiagonalBlockMatrix": 1}
)
for log_n in [7, 8, 9, 10]:
vec_size = (1 << log_n) // 128
combinations.append(
{"kLogN": log_n, "VecSize": vec_size, "kNChunks": 28, "kThreads": 128, "UseDiagonalBlockMatrix": 0}
)
combinations.append(
{"kLogN": log_n, "VecSize": vec_size, "kNChunks": 36, "kThreads": 128, "UseDiagonalBlockMatrix": 0}
)
for log_n in [11, 12, 13, 14]:
vec_size = 8
n_chunks = (1 << log_n) // (128 * vec_size)
combinations.append(
{
"kLogN": log_n,
"VecSize": vec_size,
"kNChunks": n_chunks,
"kThreads": 128,
"UseDiagonalBlockMatrix": 0,
}
)
_generate_recursive(config.dispatch_params, {}, list(config.dispatch_params.keys()))
return combinations
def split_combinations(self, combinations: List[Dict[str, Any]], max_per_file: int) -> List[List[Dict[str, Any]]]:
@@ -186,7 +249,7 @@ class UniversalTemplateInstantiator:
config = self.configs[function_name]
output_path = Path(output_dir)
output_path.mkdir(exist_ok=True)
output_path.mkdir(parents=True, exist_ok=True)
if not config.data_types:
data_types = [("", "", "")]
@@ -206,6 +269,7 @@ class UniversalTemplateInstantiator:
def generate_all(self, output_dir: str):
"""Generate all configured function types."""
self._clean_output_directory(output_dir)
for function_name in self.configs.keys():
print(f"Generating template instantiations for {function_name}...")
self.generate_for_function_type(function_name, output_dir)
@@ -219,14 +283,12 @@ def main():
"--config",
"-c",
type=str,
default="gpu_ops/append_attn/template_config.json",
help="Configuration file path (JSON format)",
)
parser.add_argument(
"--output",
"-o",
type=str,
default="gpu_ops/append_attn/template_instantiation/autogen",
help="Output directory",
)