mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
add w4afp8 offline script (#3636)
This commit is contained in:
@@ -83,10 +83,18 @@ void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
|
||||
}}
|
||||
"""
|
||||
|
||||
gemm_case = [[256, 256, 1, 0]]
|
||||
gemm_case = [
|
||||
[8192, 3584, 8, 0], # eb45T ffn1
|
||||
[8192, 3584, 8, 2048], # eb45T ffn1
|
||||
[7168, 8192, 8, 0], # eb45T ffn2
|
||||
[7168, 8192, 8, 2048], # eb45T ffn2
|
||||
]
|
||||
|
||||
dtype = ["BF16"]
|
||||
|
||||
use_fast_compile = True
|
||||
n_range = [256] if use_fast_compile else [i for i in range(16, 257, 16)]
|
||||
|
||||
|
||||
def get_cutlass_type(type):
|
||||
if type == "BF16":
|
||||
@@ -100,7 +108,7 @@ template_head_file.write(gemm_template_head)
|
||||
|
||||
for type in dtype:
|
||||
for case in gemm_case:
|
||||
for n in range(16, 257, 16):
|
||||
for n in n_range:
|
||||
template_head_file.write(
|
||||
gemm_template_case.format(
|
||||
M=case[0],
|
||||
@@ -176,7 +184,7 @@ for type in dtype:
|
||||
template_head_file.write("\n")
|
||||
|
||||
for case in gemm_case:
|
||||
for n in range(16, 257, 16):
|
||||
for n in n_range:
|
||||
template_head_file.write(
|
||||
""" }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
|
||||
w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
|
||||
|
Reference in New Issue
Block a user