Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-03 07:46:50 +08:00
【New Feature】Support FP8 group GEMM with 2:4 sparsity (#3463)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
* Support 2:4 sparsity
* code style
* Add a macro-definition guard for stmatrix
* code style
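For context on the "macro-definition guard for stmatrix" item: stmatrix is a Hopper-only PTX instruction (sm_90 and later), so any use of it has to be fenced behind an architecture check. Below is a minimal sketch of that guard pattern; it is illustrative only, not this commit's code, and the helper name and fallback path are assumptions.

#include <cstdint>

// Illustrative guard only: stmatrix exists from sm_90 onward, so it is
// fenced by __CUDA_ARCH__, with a plain shared-memory store fallback.
__device__ __forceinline__ void store_matrix_x4(uint32_t* smem_ptr,
                                                uint32_t r0, uint32_t r1,
                                                uint32_t r2, uint32_t r3) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    // Fast path: one stmatrix stores four 8x8 b16 tiles to shared memory.
    uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile(
        "stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n"
        :: "r"(addr), "r"(r0), "r"(r1), "r"(r2), "r"(r3));
#else
    // Fallback for pre-Hopper architectures; the real per-lane layout
    // differs, this only shows the shape of the guard.
    smem_ptr[0] = r0; smem_ptr[1] = r1; smem_ptr[2] = r2; smem_ptr[3] = r3;
#endif
}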
This commit is contained in:
custom_ops/utils/auto_gen_wfp8afp8_sparse_gemm_kernel.py (new file, 207 lines)
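For readers new to 2:4 ("24") sparsity: in every group of four weights, at most two are nonzero, so a kernel stores only the two surviving values plus two-bit position metadata (presumably what the sparse_idx argument below carries). A hypothetical packing sketch follows; the kernel's real metadata layout may differ.

#include <cstdint>

// Hypothetical 2:4 packing, for illustration only.
struct Packed24 {
    uint8_t vals[2];  // the two surviving FP8 bytes of a group of four
    uint8_t idx;      // their positions in the group, two bits each
};

Packed24 pack_group(const uint8_t dense[4]) {
    Packed24 p{{0, 0}, 0};
    int kept = 0;
    for (int i = 0; i < 4 && kept < 2; ++i) {
        if (dense[i] != 0) {
            p.vals[kept] = dense[i];                         // keep the value
            p.idx |= static_cast<uint8_t>(i) << (2 * kept);  // record its slot
            ++kept;
        }
    }
    return p;
}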
@@ -0,0 +1,207 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

file_dir = "./gpu_ops/wfp8afp8_sparse_gemm/"

gemm_template_head = """
#pragma once

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

#include <cute/tensor.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
"""

gemm_template_case = """
void wfp8afp8_sparse_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
    const cutlass::float_e4m3_t * weight,
    const uint32_t * sparse_idx,
    const cutlass::float_e4m3_t * input,
    {cutlass_type} * out,
    const float *weight_scale,
    const int *tokens,
    const int max_tokens,
    cudaStream_t stream);
"""

gemm_template_cu_head = """
#include "paddle/extension.h"
#include "wfp8Afp8_sparse_gemm_template.h"
#include "w8a8_sparse_gemm_kernel.hpp"

"""

gemm_template_cu_template = """
void wfp8afp8_sparse_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
    const cutlass::float_e4m3_t * weight,
    const uint32_t * sparse_idx,
    const cutlass::float_e4m3_t * input,
    {cutlass_type} * out,
    const float *weight_scale,
    const int *tokens,
    const int max_tokens,
    cudaStream_t stream) {{

    constexpr static int M = {M};
    constexpr static int K = {K};
    constexpr static int Batch = {BATCH};
    constexpr static int TokenPackSize = {PADDING};
    constexpr static int kBlockN = {N};
    constexpr static int kBlockN_TAIL = {TAILN};
    constexpr static int kBlockM = 128;
    constexpr static int kBlockK = 128;
    constexpr static int kNWarps = 4 + kBlockM / 16;
    constexpr static int kStages = 5;
    constexpr int kCluster = 1;
    static_assert(K % kBlockK == 0);
    constexpr int kTiles = K / kBlockK;

    using Kernel_traits = Kernel_traits<
        kBlockM, kBlockN, kBlockK, kNWarps, kStages, kTiles,
        M, TokenPackSize, kBlockN_TAIL, kCluster, cutlass::float_e4m3_t,
        {cutlass_type}>;
    run_gemm<cutlass::float_e4m3_t, {cutlass_type},
        Kernel_traits, M, K, Batch, TokenPackSize>
        (weight, sparse_idx, input, out, weight_scale,
        tokens, max_tokens, stream);
}}
"""

gemm_case = [
    [128, 128, 1, 0],
    [7168, 8192, 8, 0],  # eb45T ffn1
]

dtype = ["BF16"]


def get_cutlass_type(type):
    if type == "BF16":
        return "cutlass::bfloat16_t"
    elif type == "FP16":
        return "cutlass::half_t"


template_head_file = open(f"{file_dir}wfp8Afp8_sparse_gemm_template.h", "w")
template_head_file.write(gemm_template_head)

for type in dtype:
    for case in gemm_case:
        for n in range(32, 257, 32):
            template_head_file.write(
                gemm_template_case.format(
                    M=case[0],
                    K=case[1],
                    N=n,
                    BATCH=case[2],
                    TYPE=type,
                    PADDING=case[3],
                    TAILN=0,
                    cutlass_type=get_cutlass_type(type),
                )
            )
            template_head_file.write(
                gemm_template_case.format(
                    M=case[0],
                    K=case[1],
                    N=256,
                    BATCH=case[2],
                    TYPE=type,
                    PADDING=case[3],
                    TAILN=n - 32,
                    cutlass_type=get_cutlass_type(type),
                )
            )

            template_cu_file = open(
                f"{file_dir}wfp8Afp8_sparse_gemm_M{case[0]}_N{n}_TAILN{0}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu",
                "w",
            )
            template_cu_file.write(gemm_template_cu_head)
            template_cu_file.write(
                gemm_template_cu_template.format(
                    M=case[0],
                    K=case[1],
                    N=n,
                    BATCH=case[2],
                    TYPE=type,
                    PADDING=case[3],
                    TAILN=0,
                    cutlass_type=get_cutlass_type(type),
                )
            )

            template_cu_file.close()

            template_cu_file = open(
                f"{file_dir}wfp8Afp8_sparse_gemm_M{case[0]}_N{256}_TAILN{n-32}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu",
                "w",
            )
            template_cu_file.write(gemm_template_cu_head)
            template_cu_file.write(
                gemm_template_cu_template.format(
                    M=case[0],
                    K=case[1],
                    N=256,
                    BATCH=case[2],
                    TYPE=type,
                    PADDING=case[3],
                    TAILN=n - 32,
                    cutlass_type=get_cutlass_type(type),
                )
            )

            template_cu_file.close()

for type in dtype:
    template_head_file.write("\n")
    template_head_file.write(
        """#define SPARSE_GEMM_SWITCH_{TYPE}(_M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN, ...) {{ \\
    if (_M == 0 && _K == 0 && _BATCH == 0 && _TokenPaddingSize == 0 && _kBlockN == 0 && _TailN == 0) {{ \\""".format(
            TYPE=type
        )
    )

    template_head_file.write("\n")

    for case in gemm_case:
        for n in range(32, 257, 32):
            template_head_file.write(
                """    }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
        wfp8afp8_sparse_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
                    M=case[0], K=case[1], N=n, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=0
                )
            )
            template_head_file.write("\n")
            template_head_file.write(
                """    }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
        wfp8afp8_sparse_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
                    M=case[0], K=case[1], N=256, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=n - 32
                )
            )
            template_head_file.write("\n")

    template_head_file.write(
        """    } else { \\
        PADDLE_THROW(phi::errors::Unimplemented("WFp8aFp8 Sparse not supported m=%d k=%d batch=%d token_padding_size=%d kBlockN=%d tailN=%d\\n", _M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN)); \\
    } \\
}"""
    )

template_head_file.close()