add w4afp8 offline script (#3636)

This commit is contained in:
Yuan Xiaolan
2025-08-29 17:56:05 +08:00
committed by GitHub
parent f677c032c0
commit c71ee0831c
12 changed files with 163 additions and 37 deletions

View File

@@ -83,10 +83,18 @@ void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
}}
"""
gemm_case = [[256, 256, 1, 0]]
gemm_case = [
[8192, 3584, 8, 0], # eb45T ffn1
[8192, 3584, 8, 2048], # eb45T ffn1
[7168, 8192, 8, 0], # eb45T ffn2
[7168, 8192, 8, 2048], # eb45T ffn2
]
dtype = ["BF16"]
use_fast_compile = True
n_range = [256] if use_fast_compile else [i for i in range(16, 257, 16)]
def get_cutlass_type(type):
if type == "BF16":
@@ -100,7 +108,7 @@ template_head_file.write(gemm_template_head)
for type in dtype:
for case in gemm_case:
for n in range(16, 257, 16):
for n in n_range:
template_head_file.write(
gemm_template_case.format(
M=case[0],
@@ -176,7 +184,7 @@ for type in dtype:
template_head_file.write("\n")
for case in gemm_case:
for n in range(16, 257, 16):
for n in n_range:
template_head_file.write(
""" }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(