Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-16 13:41:30 +08:00)
[Optimize] Optimize tensorwise fp8 performance (#2729)
* [Optimize] Optimize tensorwise fp8 performance
@@ -468,6 +468,28 @@ std::vector<paddle::Tensor> NoauxTc(
     int topk,
     float routed_scaling_factor);
 
+paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
+    const paddle::Tensor& x,
+    const paddle::Tensor& y,
+    const paddle::optional<paddle::Tensor>& bias,
+    bool trans_x,
+    bool trans_y,
+    float scale,  // only support per-tensor quantization
+    std::string output_dtype,
+    std::string activation_type);
+
+paddle::Tensor MoeFusedHadamardQuantFp8Func(
+    const paddle::Tensor &input,
+    const paddle::Tensor &scale,
+    const paddle::Tensor &topk_ids,
+    const int top_k,
+    const int intermediate_size,
+    const bool tiled);
+
+paddle::Tensor FusedHadamardQuantFp8Func(
+    const paddle::Tensor &input,
+    const float scale);
 
 PYBIND11_MODULE(fastdeploy_ops, m) {
 
   m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -697,38 +719,21 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
         "text_image_gather_scatter function");
 
   m.def("count_tokens_per_expert_func", &count_tokens_per_expert_func);
 
   m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel);
 
   m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi,
-        py::arg("a"),
-        py::arg("c_or_none"),
-        py::arg("b_q_weight"),
-        py::arg("b_scales"),
-        py::arg("global_scale_or_none"),
-        py::arg("b_zeros_or_none"),
-        py::arg("g_idx_or_none"),
-        py::arg("perm_or_none"),
-        py::arg("workspace"),
-        py::arg("sorted_token_ids"),
-        py::arg("expert_ids"),
-        py::arg("num_tokens_post_padded"),
-        py::arg("topk_weights"),
-        py::arg("moe_block_size"),
-        py::arg("top_k"),
-        py::arg("mul_topk_weights"),
-        py::arg("is_ep"),
-        py::arg("b_q_type_str"),
-        py::arg("size_m"),
-        py::arg("size_n"),
-        py::arg("size_k"),
-        py::arg("is_k_full"),
-        py::arg("use_atomic_add"),
-        py::arg("use_fp32_reduce"),
-        py::arg("is_zp_float"));
+        py::arg("a"), py::arg("c_or_none"), py::arg("b_q_weight"),
+        py::arg("b_scales"), py::arg("global_scale_or_none"), py::arg("b_zeros_or_none"),
+        py::arg("g_idx_or_none"), py::arg("perm_or_none"), py::arg("workspace"), py::arg("sorted_token_ids"),
+        py::arg("expert_ids"), py::arg("num_tokens_post_padded"), py::arg("topk_weights"), py::arg("moe_block_size"),
+        py::arg("top_k"), py::arg("mul_topk_weights"), py::arg("is_ep"), py::arg("b_q_type_str"),
+        py::arg("size_m"), py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"), py::arg("use_atomic_add"),
+        py::arg("use_fp32_reduce"), py::arg("is_zp_float"));
 
   m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch,
         "get_position_ids_and_mask_encoder_batch function");
 
 
 /**
  * cutlass_scaled_mm.cu
  * cutlass_scaled_mm
@@ -753,6 +758,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant,
         "dynamic_per_token_scaled_fp8_quant function",
         py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub"));
 
   m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, "decode_mla_write_cache function");
 
   m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel, "prefill_mla_write_cache function");
@@ -762,4 +768,16 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("multi_head_latent_attention", &MultiHeadLatentAttention, "multi_head_latent_attention function");
 
   m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
 
+  m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func,
+        py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),
+        py::arg("transpose_y"), py::arg("scale"), py::arg("output_dtype"),
+        py::arg("activation_type"), "cutlass_fp8_fp8_half_gemm_fused function");
+
+  m.def("moe_fused_hadamard_quant_fp8", &MoeFusedHadamardQuantFp8Func,
+        py::arg("input"), py::arg("scale"), py::arg("topk_ids"),
+        py::arg("top_k"), py::arg("intermediate_size"), py::arg("tiled"), "moe_fused_hadamard_quant_fp8 function");
+
+  m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func,
+        py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function");
 }
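Note: the three bindings above expose the new kernels to Python under fastdeploy.model_executor.ops.gpu, which is how the quantization layers later in this commit call them. A minimal usage sketch of the dense path follows; the tensor shapes, the scale value, the transpose flags, bias=None, and the output_dtype/activation_type strings are illustrative assumptions, not values taken from the repository.

import paddle
from fastdeploy.model_executor.ops.gpu import (
    cutlass_fp8_fp8_half_gemm_fused, fused_hadamard_quant_fp8)

x = paddle.randn([16, 1024]).cast("bfloat16")                 # activations (assumed bf16)
w_fp8 = paddle.randn([1024, 1024]).astype("float8_e4m3fn")    # pre-quantized weight (assumed layout)

# Hadamard-rotate and per-tensor quantize the activation in one fused kernel.
x_fp8 = fused_hadamard_quant_fp8(x, scale=0.05)

# Fused FP8 x FP8 -> bf16 GEMM; "bfloat16" matches the branch in the .cu file,
# the empty activation_type is an assumed "no activation" placeholder.
out = cutlass_fp8_fp8_half_gemm_fused(
    x_fp8, w_fp8, bias=None, transpose_x=False, transpose_y=True,
    scale=0.05, output_dtype="bfloat16", activation_type="")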
@@ -19,7 +19,7 @@
 #include "fp8_fp8_half_cuda_core_gemm.h"
 
 
-std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
+paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
     const paddle::Tensor& x,
     const paddle::Tensor& y,
     const paddle::optional<paddle::Tensor>& bias,
@@ -142,7 +142,7 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
   {
     if(output_dtype == "bfloat16") {
       cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(params);
 
     } else {
       cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(params);
     }
@@ -174,7 +174,21 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
         fuse_gemm_config};
     fp8_fp8_gemm_scale_bias_act(params);
   }
-  return {out};
+  return out;
+}
+
+std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
+    const paddle::Tensor& x,
+    const paddle::Tensor& y,
+    const paddle::optional<paddle::Tensor>& bias,
+    bool trans_x,
+    bool trans_y,
+    float scale,  // only support per-tensor quantization
+    std::string output_dtype,
+    std::string activation_type) {
+  return {cutlass_fp8_fp8_half_gemm_func(
+      x, y, bias, trans_x, trans_y, scale,
+      output_dtype, activation_type)};
 }
 
 std::vector<std::vector<int64_t>> CutlassFp8Fp8HalfGemmFusedInferShape(
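The refactor keeps the original vector-returning custom op as a thin wrapper and exposes the single-tensor cutlass_fp8_fp8_half_gemm_func for direct reuse. The scale argument is a single per-tensor factor, so the quantize/dequantize contract it implies can be sketched in Paddle as below; the values are illustrative and the casts mirror ones already used elsewhere in this commit.

import paddle

x = paddle.randn([4, 128])                    # fp32 activations, illustrative
scale = float(x.abs().max()) / 448.0          # 448 is the e4m3 finite maximum
x_fp8 = (x / scale).astype("float8_e4m3fn")   # per-tensor quantization
x_back = x_fp8.cast("float32") * scale        # approximate reconstruction
print(float((x - x_back).abs().max()))        # quantization error stays small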
custom_ops/gpu_ops/fused_hadamard_quant_fp8.cu (new file, 198 lines)
@@ -0,0 +1,198 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <algorithm>
+#include "helper.h"
+
+__device__ __forceinline__ void hadamard32_warp(__nv_bfloat16& x) {
+  int lane_id = threadIdx.x % 32;
+#pragma unroll
+  for (int step = 0; step < 5; ++step) {
+    const int lane_mask = 1 << step;
+    const __nv_bfloat16 sign = (lane_id & lane_mask) ? -1.f : 1.f;
+    __nv_bfloat16 x_val_other = __shfl_xor_sync(0xffffffff, x, lane_mask);
+    x = sign * x + x_val_other;
+  }
+}
+
+__global__ void MoeFusedHadamardQuantFp8Kernel(
+    const __nv_bfloat16* __restrict__ input,
+    const float* __restrict__ scale,
+    const int64_t* __restrict__ topk_ids,
+    __nv_fp8_e4m3* out,
+    const int top_k,
+    const int intermediate_size,
+    const int64_t numel
+) {
+  int64_t out_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (out_idx >= numel) return;
+
+  int64_t token_idx = out_idx / (top_k * intermediate_size);
+  int64_t topk_idx = (out_idx / intermediate_size) % top_k;
+  int64_t inter_idx = out_idx % intermediate_size;
+
+  int64_t input_idx = token_idx * intermediate_size + inter_idx;
+  if (input_idx >= numel / top_k) return;
+
+  int64_t expert_id = topk_ids[token_idx * top_k + topk_idx];
+  float scale_value = scale[expert_id];
+
+  __nv_bfloat16 x = input[input_idx];
+  hadamard32_warp(x);
+
+  float x_fp32 = __bfloat162float(x);
+  float quantized = x_fp32 / scale_value;
+  out[out_idx] = static_cast<__nv_fp8_e4m3>(quantized);
+}
+
+__global__ void MoeFusedHadamardQuantFp8TiledKernel(
+    const __nv_bfloat16* __restrict__ input,
+    const float* __restrict__ scale,
+    const int64_t* __restrict__ topk_ids,
+    __nv_fp8_e4m3* out,
+    const int top_k,
+    const int intermediate_size,
+    const int64_t numel
+) {
+  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= numel) return;
+
+  int64_t token_idx = idx / intermediate_size;
+  int64_t expert_id = topk_ids[token_idx];
+  float scale_value = scale[expert_id];
+
+  __nv_bfloat16 x = input[idx];
+  hadamard32_warp(x);
+
+  float x_fp32 = __bfloat162float(x);
+  float quantized = x_fp32 / scale_value;
+  out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
+}
+
+std::vector<paddle::Tensor> MoeFusedHadamardQuantFp8(
+    const paddle::Tensor &input,
+    const paddle::Tensor &scale,
+    const paddle::Tensor &topk_ids,
+    const int top_k,
+    const int intermediate_size,
+    const bool tiled) {
+  int64_t numel = input.numel();
+  if (!tiled) numel *= top_k;
+  paddle::Tensor out = GetEmptyTensor(
+      {numel / intermediate_size, intermediate_size},
+      paddle::DataType::FLOAT8_E4M3FN,
+      input.place());
+  constexpr int64_t thread_per_block = 256;
+  int64_t block_per_grid = (numel + thread_per_block - 1) / thread_per_block;
+  auto stream = input.stream();
+
+  if (tiled) {
+    MoeFusedHadamardQuantFp8TiledKernel<<<block_per_grid, thread_per_block, 0, stream>>>(
+        reinterpret_cast<const __nv_bfloat16*>(input.data<paddle::bfloat16>()),
+        scale.data<float>(),
+        topk_ids.data<int64_t>(),
+        reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
+        top_k,
+        intermediate_size,
+        numel
+    );
+  } else {
+    MoeFusedHadamardQuantFp8Kernel<<<block_per_grid, thread_per_block, 0, stream>>>(
+        reinterpret_cast<const __nv_bfloat16*>(input.data<phi::dtype::bfloat16>()),
+        scale.data<float>(),
+        topk_ids.data<int64_t>(),
+        reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
+        top_k,
+        intermediate_size,
+        numel
+    );
+  }
+  return {out};
+}
+
+PD_BUILD_STATIC_OP(moe_fused_hadamard_quant_fp8)
+    .Inputs({"input", "scale", "topk_ids"})
+    .Outputs({"output"})
+    .Attrs({"top_k: int",
+            "intermediate_size: int",
+            "tiled: bool"})
+    .SetKernelFn(PD_KERNEL(MoeFusedHadamardQuantFp8));
+
+paddle::Tensor MoeFusedHadamardQuantFp8Func(
+    const paddle::Tensor &input,
+    const paddle::Tensor &scale,
+    const paddle::Tensor &topk_ids,
+    const int top_k,
+    const int intermediate_size,
+    const bool tiled) {
+  return MoeFusedHadamardQuantFp8(input, scale, topk_ids, top_k, intermediate_size, tiled)[0];
+}
+
+__global__ void FusedHadamardQuantFp8Kernel(
+    const __nv_bfloat16* __restrict__ input,
+    __nv_fp8_e4m3* out,
+    const float scale,
+    const int64_t numel) {
+  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= numel) return;
+
+  __nv_bfloat16 x = input[idx];
+  hadamard32_warp(x);
+
+  float x_fp32 = __bfloat162float(x);
+  float quantized = x_fp32 / scale;
+  out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
+}
+
+std::vector<paddle::Tensor> FusedHadamardQuantFp8(
+    const paddle::Tensor &input,
+    const float scale) {
+  int64_t numel = input.numel();
+  paddle::Tensor out = GetEmptyTensor(
+      input.dims(),
+      paddle::DataType::FLOAT8_E4M3FN,
+      input.place());
+  constexpr int64_t thread_per_block = 256;
+  int64_t block_per_grid = (numel + thread_per_block - 1) / thread_per_block;
+  auto stream = input.stream();
+  FusedHadamardQuantFp8Kernel<<<block_per_grid, thread_per_block, 0, stream>>>(
+      reinterpret_cast<const __nv_bfloat16*>(input.data<paddle::bfloat16>()),
+      reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
+      scale,
+      numel
+  );
+  return {out};
+}
+
+PD_BUILD_STATIC_OP(fused_hadamard_quant_fp8)
+    .Inputs({"input"})
+    .Outputs({"output"})
+    .Attrs({"scale: float"})
+    .SetKernelFn(PD_KERNEL(FusedHadamardQuantFp8));
+
+paddle::Tensor FusedHadamardQuantFp8Func(
+    const paddle::Tensor &input,
+    const float scale) {
+  return FusedHadamardQuantFp8(input, scale)[0];
+}
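Note: hadamard32_warp applies an unnormalized 32-point Walsh-Hadamard transform within each warp, which maps to groups of 32 consecutive elements given the one-thread-per-element launch with a 256-thread block, before the per-tensor or per-expert scaling and the FP8 cast. Below is a NumPy sketch of that butterfly, cross-checked against an explicit Sylvester-construction Hadamard matrix; the helper name is illustrative and the check ignores bf16 rounding.

import numpy as np

def hadamard32_butterfly(x):
    # Mirrors hadamard32_warp: at step s, lane i combines its value with the
    # value held by lane i ^ (1 << s), negating its own term when bit s of
    # the lane id is set.
    x = np.asarray(x, dtype=np.float64).copy()
    lanes = np.arange(32)
    for step in range(5):
        mask = 1 << step
        partner = x[..., lanes ^ mask]
        sign = np.where(lanes & mask, -1.0, 1.0)
        x = sign * x + partner
    return x

# Sylvester construction of the (symmetric, unnormalized) 32x32 Hadamard matrix.
H = np.array([[1.0]])
for _ in range(5):
    H = np.block([[H, H], [H, -H]])

x = np.random.randn(32)
assert np.allclose(hadamard32_butterfly(x), H @ x)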
@@ -442,6 +442,7 @@ elif paddle.is_compiled_with_cuda():
             "gpu_ops/scaled_gemm_f8_i4_f16_weight_quantize.cu",
             "gpu_ops/cutlass_kernels/cutlass_heuristic.cu",
             "gpu_ops/cutlass_kernels/cutlass_preprocessors.cu",
+            "gpu_ops/fused_hadamard_quant_fp8.cu"
         ]
 
         sources += find_end_files(fp8_auto_gen_directory, ".cu")
@@ -1,5 +1,5 @@
 """
-# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,8 +20,7 @@ from paddle import nn
 import fastdeploy
 from fastdeploy.distributed.communication_op import \
     tensor_model_parallel_all_reduce
-from fastdeploy.model_executor.layers.utils import (create_hadamard_matrix_map,
-                                                    get_tensor)
+from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.utils import ceil_div
 
 from ..quantization.quant_base import QuantMethodBase
@@ -272,8 +271,8 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             layer.moe_intermediate_size, layer.hidden_size
         ]
 
-        ffn1_tensor = paddle.stack(ffn1_tensor, axis=0)
-        ffn2_tensor = paddle.stack(ffn2_tensor, axis=0)
+        ffn1_tensor = paddle.stack(ffn1_tensor, axis=0).view(paddle.float8_e4m3fn)
+        ffn2_tensor = paddle.stack(ffn2_tensor, axis=0).view(paddle.float8_e4m3fn)
 
         added_wfp8afp8_attrs = [
             "moe_ffn1_weight", "moe_ffn2_weight", "moe_ffn1_weight_scale",
@@ -309,7 +308,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
                     dtype=weight_tensor.dtype,
                     default_initializer=paddle.nn.initializer.Constant(0),
                 ))
-            getattr(layer, name).set_value(weight_tensor)
+            if weight_tensor.dtype == paddle.float8_e4m3fn:
+                getattr(layer, name).copy_(weight_tensor, False)
+            else:
+                getattr(layer, name).set_value(weight_tensor)
 
     def create_weights(self, layer: nn.Layer, state_dict):
         """
@@ -333,13 +335,13 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
         moe_intermediate_size = layer.moe_intermediate_size
         hidden_size = layer.hidden_size
 
-        scores = paddle.nn.functional.softmax(gate_out, axis=-1)
-        topk_weights, topk_ids = paddle.topk(scores,
-                                             k=top_k,
-                                             axis=-1,
-                                             sorted=False)
-        topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
+        topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+            gate_out,
+            layer.gate_correction_bias,
+            top_k,
+            True,  # apply_norm_weight,
+            False,
+        )
 
         intermediate_cache1 = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
@@ -354,34 +356,31 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             dtype=x.dtype,
         )
 
-        config = {
+        config_ffn1 = {
             "BLOCK_SIZE_M": 32,
             "BLOCK_SIZE_N": 128,
-            "BLOCK_SIZE_K": 128,
+            "BLOCK_SIZE_K": 256,
             "GROUP_SIZE_M": 1,
         }
 
         sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
-            topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
+            topk_ids, num_local_experts, config_ffn1["BLOCK_SIZE_M"])
         max_possible_num_post_padded = sorted_token_ids.shape[0]
         grid = (
-            ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
-            ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
+            ceil_div(max_possible_num_post_padded, config_ffn1["BLOCK_SIZE_M"]) *
+            ceil_div(moe_intermediate_size * 2, config_ffn1["BLOCK_SIZE_N"]), )
 
-        adamard_matrix = create_hadamard_matrix_map[hidden_size]
-        x = paddle.matmul(x.cast("float32"), adamard_matrix)
-
-        permute_x = x[:, None, :].tile([1, top_k, 1])
-        permute_x = permute_x.reshape([-1, hidden_size])
-
-        quant_activation_scale = layer.moe_ffn1_in_scale[topk_ids].reshape(
-            [-1, 1])
-        permute_x = permute_x / quant_activation_scale
-        permute_x = permute_x.astype("float8_e4m3fn")
+        permute_x = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
+            x,
+            scale=layer.moe_ffn1_in_scale,
+            topk_ids=topk_ids,
+            top_k=top_k,
+            intermediate_size=hidden_size,
+            tiled=False)
 
         fused_moe_kernel_paddle[grid](
             permute_x,
-            layer.moe_ffn1_weight.view(paddle.float8_e4m3fn),
+            layer.moe_ffn1_weight,
             intermediate_cache1,
             layer.moe_ffn1_in_scale,
             layer.moe_ffn1_weight_scale,
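Note: with tiled=False, the fused op replaces the removed matmul/tile/divide/cast sequence with a single pass: it reads a [token_num, hidden_size] activation, expands it to [token_num * top_k, hidden_size], and quantizes each row with the in-scale of the expert chosen for that (token, k) slot, after the warp-level Hadamard rotation from the new kernel file. A shape-level sketch of that contract, inferred from the kernel and this call site (sizes and variable names are illustrative):

import paddle
from fastdeploy.model_executor.ops.gpu import moe_fused_hadamard_quant_fp8

token_num, hidden_size, top_k, num_experts = 8, 1024, 2, 16   # illustrative sizes
x = paddle.randn([token_num, hidden_size]).cast("bfloat16")
ffn1_in_scale = paddle.rand([num_experts])                    # one scale per expert
topk_ids = paddle.randint(0, num_experts, [token_num, top_k]) # int64 routing ids

permute_x = moe_fused_hadamard_quant_fp8(
    x, scale=ffn1_in_scale, topk_ids=topk_ids,
    top_k=top_k, intermediate_size=hidden_size, tiled=False)
# permute_x: [token_num * top_k, hidden_size] in float8_e4m3fn,
# row (t, k) scaled by ffn1_in_scale[topk_ids[t, k]].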
@@ -409,36 +408,43 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             group_n=-1,
             group_k=-1,
             # Meta-parameters
-            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
-            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
-            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
-            GROUP_SIZE_M=config["GROUP_SIZE_M"],
+            BLOCK_SIZE_M=config_ffn1["BLOCK_SIZE_M"],
+            BLOCK_SIZE_N=config_ffn1["BLOCK_SIZE_N"],
+            BLOCK_SIZE_K=config_ffn1["BLOCK_SIZE_K"],
+            GROUP_SIZE_M=config_ffn1["GROUP_SIZE_M"],
             MUL_ROUTED_WEIGHT=False,
             top_k=1,
             compute_type_enum=1,
             use_fp8_w8a8=True,
             use_int8_w8a16=False,
-            even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
+            even_Ks=hidden_size % config_ffn1["BLOCK_SIZE_K"] == 0,
         )
 
         intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
             intermediate_cache1)
 
-        hadamard_matrix = create_hadamard_matrix_map[moe_intermediate_size]
-        intermediate_cache2 = paddle.matmul(
-            intermediate_cache2.cast("float32"), hadamard_matrix)
-        quant_activation_scale = layer.moe_ffn2_in_scale[topk_ids].reshape(
-            [-1, 1])
-        intermediate_cache2 = intermediate_cache2 / quant_activation_scale
-        intermediate_cache2 = intermediate_cache2.astype("float8_e4m3fn")
+        intermediate_cache2 = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
+            intermediate_cache2,
+            scale=layer.moe_ffn2_in_scale,
+            topk_ids=topk_ids,
+            top_k=top_k,
+            intermediate_size=moe_intermediate_size,
+            tiled=True)
 
+        config_ffn2 = {
+            "BLOCK_SIZE_M": 32,
+            "BLOCK_SIZE_N": 128,
+            "BLOCK_SIZE_K": 64,
+            "GROUP_SIZE_M": 1,
+        }
+
         grid = (
-            ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
-            ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
+            ceil_div(max_possible_num_post_padded, config_ffn2["BLOCK_SIZE_M"]) *
+            ceil_div(hidden_size, config_ffn2["BLOCK_SIZE_N"]), )
 
         fused_moe_kernel_paddle[grid](
             intermediate_cache2,
-            layer.moe_ffn2_weight.view(paddle.float8_e4m3fn),
+            layer.moe_ffn2_weight,
             intermediate_cache3,
             layer.moe_ffn2_in_scale,
             layer.moe_ffn2_weight_scale,
@@ -465,16 +471,16 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             group_n=-1,
             group_k=-1,
             # Meta-parameters
-            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
-            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
-            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
-            GROUP_SIZE_M=config["GROUP_SIZE_M"],
+            BLOCK_SIZE_M=config_ffn2["BLOCK_SIZE_M"],
+            BLOCK_SIZE_N=config_ffn2["BLOCK_SIZE_N"],
+            BLOCK_SIZE_K=config_ffn2["BLOCK_SIZE_K"],
+            GROUP_SIZE_M=config_ffn2["GROUP_SIZE_M"],
             MUL_ROUTED_WEIGHT=True,
             top_k=1,
             compute_type_enum=1,
             use_fp8_w8a8=True,
             use_int8_w8a16=False,
-            even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
+            even_Ks=moe_intermediate_size % config_ffn2["BLOCK_SIZE_K"] == 0,
         )
 
         intermediate_cache3.reshape_([token_num, top_k, hidden_size])
@@ -15,8 +15,6 @@
 """
 from typing import Optional
 
-import paddle
-
 from fastdeploy.model_executor.layers.moe import FusedMoE
 
 from ..utils import get_tensor
@@ -113,15 +111,10 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):
         """
         compute!
         """
-        from fastdeploy.model_executor.ops.gpu import \
-            cutlass_fp8_fp8_half_gemm_fused
+        from fastdeploy.model_executor.ops.gpu import (
+            cutlass_fp8_fp8_half_gemm_fused, fused_hadamard_quant_fp8)
 
-        from ..utils import create_hadamard_matrix_map
+        fp8_x = fused_hadamard_quant_fp8(x, scale=self.act_scale)
 
-        hadamard_matrix = create_hadamard_matrix_map[x.shape[-1]]
-        new_x = paddle.matmul(x.cast("float32"), hadamard_matrix)
-        fp8_x = new_x / self.act_scale
-        fp8_x = fp8_x.astype("float8_e4m3fn")
-
         linear_out = cutlass_fp8_fp8_half_gemm_fused(
             fp8_x,