mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Optimize the performance of moe_expert_ffn_wint2 (#2990)
* Change wint2 to ColumnMajor. Change-Id: I6b44d02946a685f8fe24d9f2c7be258b51e16da2 * Unify default_wint2x_mma. Change-Id: I9e77b0e8e6cecab01fedc0b24b536ee0a1a89ff7 * Change wint2 to ColumnMajorTileInterleave. Change-Id: I593cbe36f991c0c5044989d65f0014087587c624 * Enable async copy for B. Change-Id: Ia3ac37ad162a8cf3ccce4f268e81bd06c8ac3c46 * Add wint2x Dequantizer * Remove TileDequanterB related codes. Change-Id: Id8e65703b72a8984d367f584ff41b7726017fbb8 * Implement FastInterleavedAndBiasedNumericArrayConverter for wint2. Change-Id: I438f2b18ab964a04ae1cdb09d9e7d9f7b95eafca * Implement Wint2ParamsAccessor to load extra quant params from global memory. Change-Id: Ic3750cd9b767df8893501820880c3342a4b47233 * Implement FastInterleavedAndBiasedNumericArrayConverter for wint2. Change-Id: I438f2b18ab964a04ae1cdb09d9e7d9f7b95eafca * Use async copy for local_scale. Change-Id: Ib882ba41c3d2354bda4d25b40e2408ad3b2f7893 * Check and correct the load and dequantize of weights. Change-Id: Ie8dca505b39987144964fe6407d465b3b5953790 * Change for performance tuning. Change-Id: I1da026fb1d1533a9d70350c7ba23c27e896cfc29 * Optimize the global memory access size of local_scale reading. Change-Id: I4cbe3a2ef5951723d415c2d3252ce912394beaf5 * Specialize mma_tensor_op for wint2 to enable fine-grained pipeline. Change-Id: Icbb4d48f90a41136f42d6ffff42d68de32f408da * Minor fix. Change-Id: I14d4ac9d267ee05442a3b47f00c26bee13d79e6f * optimizing dequant performance with LOP3 * optimizing dequant performance with LOP3 * Avoid redundant dequantization of local_scale and use bf16 as computing type. Change-Id: I63239ebc8f8e4a92d6281af59840ba50600b4334 * Add Multiplier and remove some logs. Change-Id: Ifa199d81e6aeb472d2247c63f85ef30213684bcd * optimizing dequant performance with LOP3 * Use __byte_perm to implement int8 to float32 conversion for performance improvement. * Use lop3 to optimize the dequantize of local_scale. Change-Id: I6189759970cb5b8dcbef769724784b8a7533b63c * Minor fix and remove some logs. Change-Id: I6279ba9926d5041093b1c6aea200acf2e4c49d46 * Fix stages for test. Change-Id: I6f7b7cac612ef2c678e9d49f5ffa60eb53d3ae29 * Fix stages for test and add clock64 to profile. Change-Id: Iffaf7324beaa910ce9ee56f47ae289de98f1a267 * Use __byte_perm to replace shift-and-or operations for faster integer merging. * Split the uint2b convert. Change-Id: I78da672ce8968e21f685285140ba546a161521b4 * Optimize convert of unscale. Change-Id: I6795da1cdf5e8ab38ddaa9836240921b5312913a * Minor optimization. Change-Id: I1800aec34c3f4621abb02658208108f54da44d88 * Optimize mma pipeline and refine codes. Change-Id: Id3075cf7b88f2813a11ccd1d3b49c62c978f36b8 * Add missing support. Change-Id: Id65b7bc2c25fbb1a5b232c6bc9fb8c9093f691a8 * Accelerate FP16 dequantization performance * Support tile shape as Xx64x64. Change-Id: Ib8fd37e1ba1d06f7d11f2956e7f1367b0a92bcac * Remove debugging codes and minor optimization. Change-Id: I6b79bd56a6e8dd823efc169967ecd3cc9a43baf4 * Fix offset bug. Change-Id: Id7aeb91e99d6f51836f2aff22187b4f79607395e * Fix typo. Change-Id: I19dde93fc1c1f7e19605905c90dc46298e203952 * Restore some codes and remove some debugging logs. Change-Id: I8d44daf82ad1c6f8174134d195e7b3fe9a3afdfb --------- Co-authored-by: baoqiwen <baoqiwen@baidu.com>
This commit is contained in:
@@ -43,7 +43,6 @@
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
|
||||
#include "cutlass_extensions/tile_interleaved_layout.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -775,17 +774,54 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
|
||||
template <WintQuantMethod QuantMethod, typename dummy>
|
||||
struct KernelRunner<QuantMethod, true, dummy> {
|
||||
using WeightQuantTraits = WintQuantTraits<ElementA, QuantMethod>;
|
||||
using QuantArguments = typename WeightQuantTraits::Arguments;
|
||||
using MmaQuantArguments = typename Mma::QuantParamsAccessor::Arguments;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static QuantArguments get_quant_args(Params const& params, int32_t problem_idx, const int64_t gemm_k, const int64_t gemm_n) {
|
||||
QuantArguments quant_args;
|
||||
if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
|
||||
quant_args.local_scale_ptr = params.local_scale + problem_idx * gemm_k * gemm_n / 128;
|
||||
quant_args.code_scale_ptr = params.code_scale + problem_idx * gemm_n;
|
||||
quant_args.code_zp_ptr = params.code_zp + problem_idx * gemm_n;
|
||||
}
|
||||
return quant_args;
|
||||
static MmaQuantArguments prepare_quant_args(
|
||||
Params const& params, cutlass::gemm::GemmCoord const& threadblock_offset,
|
||||
int64_t problem_idx, const int32_t gemm_k, const int32_t gemm_n, const int thread_idx) {
|
||||
// the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
|
||||
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
|
||||
cutlass::MatrixCoord tb_offset_local_scale{0, threadblock_offset.n() * 2};
|
||||
|
||||
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * gemm_n;
|
||||
typename Mma::QuantParamsAccessor::IteratorSuperScale iterator_super_scale(
|
||||
Mma::QuantParamsAccessor::LayoutSuperScale(gemm_n),
|
||||
weight_scale_ptr,
|
||||
{1, gemm_n},
|
||||
thread_idx,
|
||||
tb_offset_scale);
|
||||
|
||||
int local_scale_pointer_offset = ((ThreadblockShape::kK + 127) / 128) * (gemm_n * 2);
|
||||
int64_t offset_in_bytes = problem_idx * gemm_k * gemm_n / 128;
|
||||
uint4b_t *local_scale_ptr = reinterpret_cast<uint4b_t *>(params.local_scale + offset_in_bytes);
|
||||
|
||||
typename Mma::QuantParamsAccessor::IteratorLocalScale iterator_local_scale(
|
||||
Mma::QuantParamsAccessor::LayoutLocalScale(gemm_n * 2),
|
||||
local_scale_ptr,
|
||||
{(gemm_k + 127) / 128, gemm_n * 2},
|
||||
thread_idx,
|
||||
tb_offset_local_scale);
|
||||
|
||||
float* code_scale_ptr = params.code_scale + problem_idx * gemm_n;
|
||||
typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_scale(
|
||||
Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n),
|
||||
code_scale_ptr,
|
||||
{1, gemm_n},
|
||||
thread_idx,
|
||||
tb_offset_scale);
|
||||
|
||||
float* code_zp_ptr = params.code_zp + problem_idx * gemm_n;
|
||||
typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_zp(
|
||||
Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n),
|
||||
code_zp_ptr,
|
||||
{1, gemm_n},
|
||||
thread_idx,
|
||||
tb_offset_scale);
|
||||
|
||||
MmaQuantArguments mma_quant_args(
|
||||
iterator_super_scale, iterator_local_scale, iterator_code_scale, iterator_code_zp, local_scale_pointer_offset);
|
||||
return mma_quant_args;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
@@ -814,9 +850,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
|
||||
kInterleave >= 1,
|
||||
"B must be row major/col major OR col major interleaved.");
|
||||
|
||||
// LayoutB should be RowMajor
|
||||
using TileDequanterB = cutlass::gemm::threadblock::TileDequanter<ElementA, ElementScale, ThreadblockShape::kK, ThreadblockShape::kN, kStages, kThreadCount, QuantMethod>;
|
||||
|
||||
//
|
||||
// Problem visitor.
|
||||
//
|
||||
@@ -843,12 +876,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
|
||||
int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT
|
||||
0);
|
||||
|
||||
// begin address offset for weight_scale.
|
||||
ElementScale* weight_scale_ptr =
|
||||
params.weight_scales ? params.weight_scales + problem_idx * problem_size.n() : nullptr;
|
||||
// the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
|
||||
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
|
||||
|
||||
// Load element pointers. Exchange pointers and strides if working on
|
||||
// the transpose
|
||||
int64_t rows_to_jump = 0;
|
||||
@@ -866,42 +893,20 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
// the begin threadblock_offset of A, which holds the same row id with C
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_offset.m(),
|
||||
0,
|
||||
};
|
||||
cutlass::MatrixCoord tb_offset_A{threadblock_offset.m(), 0};
|
||||
|
||||
// begin address offset for B for current problem_idx, totally num_experts problems
|
||||
char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT
|
||||
problem_idx * bytes_per_expert_matrix; // NOLINT
|
||||
|
||||
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
|
||||
typename LayoutB::LongIndex ldm_B =
|
||||
platform::is_same<layout::RowMajor, LayoutB>::value
|
||||
? gemm_n
|
||||
: gemm_k * kInterleave;
|
||||
typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns;
|
||||
|
||||
// the begin threadblock_offset of B, which holds the same column id with C
|
||||
cutlass::MatrixCoord tb_offset_B{0,
|
||||
threadblock_offset.n() / kInterleave};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
|
||||
cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave};
|
||||
cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns};
|
||||
|
||||
MmaElementB* smem_unzip_B_ptr = nullptr;
|
||||
if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
|
||||
smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr();
|
||||
}
|
||||
QuantArguments quant_args = get_quant_args(params, problem_idx, gemm_k, gemm_n);
|
||||
TileDequanterB tile_dequanter_B(smem_unzip_B_ptr,
|
||||
byte_ptr_B,
|
||||
ldm_B,
|
||||
extent_B,
|
||||
tb_offset_B,
|
||||
weight_scale_ptr,
|
||||
tb_offset_scale,
|
||||
quant_args);
|
||||
MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr();
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
@@ -914,20 +919,21 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
|
||||
tb_offset_A);
|
||||
|
||||
typename Mma::IteratorB iterator_B(
|
||||
LayoutB(TileDequanterB::kUseSharedMemory ? ldm_B_shared : ldm_B),
|
||||
LayoutB(ldm_B),
|
||||
ptr_B,
|
||||
TileDequanterB::kUseSharedMemory ? extent_B_shared : extent_B,
|
||||
extent_B,
|
||||
thread_idx,
|
||||
TileDequanterB::kUseSharedMemory ? cutlass::make_Coord(0, 0) : tb_offset_B);
|
||||
tb_offset_B);
|
||||
|
||||
MmaQuantArguments mma_quant_args = prepare_quant_args(
|
||||
params, threadblock_offset, problem_idx, gemm_k, gemm_n, thread_idx);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
@@ -950,7 +956,7 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
|
||||
accumulators,
|
||||
iterator_A,
|
||||
iterator_B,
|
||||
tile_dequanter_B,
|
||||
mma_quant_args,
|
||||
accumulators);
|
||||
|
||||
//
|
||||
|
||||
@@ -205,7 +205,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
|
||||
threadblock_count,
|
||||
epilogue_op,
|
||||
reinterpret_cast<const ElementType*>(A),
|
||||
reinterpret_cast<const CutlassMmaWeightType*>(B),
|
||||
reinterpret_cast<const CutlassMmaKernelType*>(B),
|
||||
reinterpret_cast<const ElementType*>(weight_scales),
|
||||
reinterpret_cast<const ElementType*>(biases),
|
||||
reinterpret_cast<ElementType*>(C),
|
||||
|
||||
Reference in New Issue
Block a user