Revert "【New Feature】W4afp8 supports per group quantization (#4272)" (#4854)

This reverts commit 93fcf7e4ec.
This commit is contained in:
YuBaoku
2025-11-06 17:48:28 +08:00
committed by GitHub
parent 3478d20262
commit 819b2dbbae
26 changed files with 1718 additions and 4378 deletions

View File

@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
file_dir = "./gpu_ops/w4afp8_gemm/"
@@ -32,12 +30,12 @@ gemm_template_head = """
#include <cutlass/numeric_types.h>
"""
gemm_template_case = """
void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
const cutlass::float_e4m3_t * weight,
const cutlass::float_e4m3_t * input,
{cutlass_type} * out,
const float *weight_scale,
const float * input_dequant_scale,
const float *input_row_sum,
const int64_t *tokens,
const int64_t max_tokens,
cudaStream_t stream);
@@ -50,22 +48,22 @@ gemm_template_cu_head = """
"""
gemm_template_cu_template = """
void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
const cutlass::float_e4m3_t * weight,
const cutlass::float_e4m3_t * input,
{cutlass_type} * out,
const float *weight_scale,
const float * input_dequant_scale,
const float *input_row_sum,
const int64_t *tokens,
const int64_t max_tokens,
cudaStream_t stream) {{
constexpr static int M = {M};
constexpr static int K = {K};
constexpr static int EXPERTS = {EXPERTS};
constexpr static int Batch = {BATCH};
constexpr static int TokenPackSize = {PADDING};
constexpr static int kBlockN = {N};
constexpr static int kGroupSize = {GROUPSIZE};
constexpr static int kBlockN_TAIL = {TAILN};
constexpr static int kBlockM = 128;
constexpr static int kBlockK = 128;
constexpr static int kNWarps = 4 + kBlockM / 16;
@@ -76,24 +74,22 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
using Kernel_traits = Kernel_traits<
kBlockM, kBlockN, kBlockK, kNWarps, kStages, kTiles,
M, K, TokenPackSize, kGroupSize, kCluster, cutlass::float_e4m3_t,
M, TokenPackSize, kBlockN_TAIL, kCluster, cutlass::float_e4m3_t,
{cutlass_type}>;
run_gemm<cutlass::float_e4m3_t, {cutlass_type},
Kernel_traits, M, K, EXPERTS, TokenPackSize, kGroupSize>
(weight, input, out, weight_scale, input_dequant_scale, tokens, max_tokens, stream);
Kernel_traits, M, K, Batch, TokenPackSize>
(weight, input, out, weight_scale,
input_row_sum, tokens, max_tokens, stream);
}}
"""
# [M, K, Number of experts, token Padding Size, weight K group size]
gemm_case = [
[8192, 3584, 16, 0, 128], # eb45T ffn1
[8192, 3584, 16, 512, 128], # eb45T ffn1
[7168, 8192, 16, 0, 128], # eb45T ffn2
[7168, 8192, 16, 512, 128], # eb45T ffn2
[1792, 8192, 64, 0, 8192], # eb45t ffn1
[8192, 896, 64, 0, 896], # eb45t ffn2
[1792, 8192, 64, 0, 128], # eb45t ffn1
[8192, 896, 64, 0, 128], # eb45t ffn2
[8192, 3584, 8, 0], # eb45T ffn1
[8192, 3584, 8, 2048], # eb45T ffn1
[7168, 8192, 8, 0], # eb45T ffn2
[7168, 8192, 8, 2048], # eb45T ffn2
[1792, 8192, 64, 0], # eb45t ffn1
[8192, 896, 64, 0], # eb45t ffn2
]
dtype = ["BF16"]
@@ -101,19 +97,6 @@ dtype = ["BF16"]
use_fast_compile = True
n_range = [256] if use_fast_compile else [i for i in range(16, 257, 16)]
all_cu_files = []
for type in dtype:
for case in gemm_case:
for n in n_range:
all_cu_files.append(f"w4afp8_gemm_M{case[0]}_N{n}_G{case[4]}_K{case[1]}_E{case[2]}_P{case[3]}_{type}.cu")
for file_path, empty_list, file_name_list in os.walk(file_dir):
for file_name in file_name_list:
if re.match(r"^w4afp8_gemm_M\d+_N\d+_.*\.cu$", file_name):
if file_name not in all_cu_files:
print("delete w4afp8 kernel file", file_path + file_name)
os.remove(file_path + file_name)
def get_cutlass_type(type):
if type == "BF16":
@@ -133,16 +116,28 @@ for type in dtype:
M=case[0],
K=case[1],
N=n,
EXPERTS=case[2],
BATCH=case[2],
TYPE=type,
PADDING=case[3],
GROUPSIZE=case[4],
TAILN=0,
cutlass_type=get_cutlass_type(type),
)
)
template_head_file.write(
gemm_template_case.format(
M=case[0],
K=case[1],
N=256,
BATCH=case[2],
TYPE=type,
PADDING=case[3],
TAILN=n - 16,
cutlass_type=get_cutlass_type(type),
)
)
template_cu_file = open(
f"{file_dir}w4afp8_gemm_M{case[0]}_N{n}_G{case[4]}_K{case[1]}_E{case[2]}_P{case[3]}_{type}.cu", "w"
f"{file_dir}w4afp8_gemm_M{case[0]}_N{n}_TAILN{0}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w"
)
template_cu_file.write(gemm_template_cu_head)
template_cu_file.write(
@@ -150,10 +145,29 @@ for type in dtype:
M=case[0],
K=case[1],
N=n,
EXPERTS=case[2],
BATCH=case[2],
TYPE=type,
PADDING=case[3],
GROUPSIZE=case[4],
TAILN=0,
cutlass_type=get_cutlass_type(type),
)
)
template_cu_file.close()
template_cu_file = open(
f"{file_dir}w4afp8_gemm_M{case[0]}_N{256}_TAILN{n-16}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w"
)
template_cu_file.write(gemm_template_cu_head)
template_cu_file.write(
gemm_template_cu_template.format(
M=case[0],
K=case[1],
N=256,
BATCH=case[2],
TYPE=type,
PADDING=case[3],
TAILN=n - 16,
cutlass_type=get_cutlass_type(type),
)
)
@@ -163,8 +177,8 @@ for type in dtype:
for type in dtype:
template_head_file.write("\n")
template_head_file.write(
"""#define GEMM_SWITCH_{TYPE}(_M, _K, _EXPERTS, _TokenPaddingSize, _kBlockN, _GROUPSIZE, ...) {{ \\
if (_M == 0 && _K == 0 && _EXPERTS == 0 && _TokenPaddingSize == 0 && _kBlockN == 0 && _GROUPSIZE == 0) {{ \\""".format(
"""#define GEMM_SWITCH_{TYPE}(_M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN, ...) {{ \\
if (_M == 0 && _K == 0 && _BATCH == 0 && _TokenPaddingSize == 0 && _kBlockN == 0 && _TailN == 0) {{ \\""".format(
TYPE=type
)
)
@@ -174,16 +188,23 @@ for type in dtype:
for case in gemm_case:
for n in n_range:
template_head_file.write(
""" }} else if (_M == {M} && _K == {K} && _EXPERTS == {EXPERTS} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _GROUPSIZE == {GROUPSIZE}) {{ \\
w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
M=case[0], K=case[1], N=n, EXPERTS=case[2], TYPE=type, PADDING=case[3], GROUPSIZE=case[4]
""" }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
M=case[0], K=case[1], N=n, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=0
)
)
template_head_file.write("\n")
template_head_file.write(
""" }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\
w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format(
M=case[0], K=case[1], N=256, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=n - 16
)
)
template_head_file.write("\n")
template_head_file.write(
""" } else { \\
PADDLE_THROW(phi::errors::Unimplemented("W4aFp8 not supported m=%d k=%d experts=%d token_padding_size=%d kBlockN=%d groupsize=%d\\n", _M, _K, _EXPERTS, _TokenPaddingSize, _kBlockN, _GROUPSIZE)); \\
PADDLE_THROW(phi::errors::Unimplemented("W4aFp8 not supported m=%d k=%d batch=%d token_padding_size=%d kBlockN=%d tailN=%d\\n", _M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN)); \\
} \\
}"""
)