[Bug fix] fix compile bug when sm < 89 (#2738)

ming1753 authored 2025-07-08 11:24:52 +08:00, committed by GitHub
parent ef6649a577
commit 1eb8ea7328
2 changed files with 44 additions and 46 deletions


@@ -129,18 +129,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             True, # apply_norm_weight,
             False,
         )
-        intermediate_cache1 = paddle.empty(
+        ffn1_out = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
             dtype=x.dtype,
         )
-        intermediate_cache2 = paddle.empty(
-            (token_num * top_k, moe_intermediate_size),
-            dtype=x.dtype,
-        )
-        intermediate_cache3 = paddle.empty(
-            (token_num * top_k, hidden_size),
-            dtype=x.dtype,
-        )
         config = {
             "BLOCK_SIZE_M": 32,
@@ -158,7 +150,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
         fused_moe_kernel_paddle[grid](
             x,
             layer.moe_ffn1_weight,
-            intermediate_cache1,
+            ffn1_out,
             None,
             layer.moe_ffn1_weight_scale,
             None,
@@ -174,8 +166,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             stride_be=layer.moe_ffn1_weight.strides[0],
             stride_bk=layer.moe_ffn1_weight.strides[1],
             stride_bn=layer.moe_ffn1_weight.strides[2],
-            stride_cm=intermediate_cache1.strides[0],
-            stride_cn=intermediate_cache1.strides[1],
+            stride_cm=ffn1_out.strides[0],
+            stride_cn=ffn1_out.strides[1],
             #
             stride_asm=-1,
             stride_ask=-1,
@@ -197,16 +189,21 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
         )
-        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
-            intermediate_cache1)
+        ffn2_input = paddle.incubate.nn.functional.swiglu(
+            ffn1_out)
+        ffn2_out = paddle.empty(
+            (token_num * top_k, hidden_size),
+            dtype=x.dtype,
+        )
         grid = (
             ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
             ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
         fused_moe_kernel_paddle[grid](
-            intermediate_cache2,
+            ffn2_input,
             layer.moe_ffn2_weight,
-            intermediate_cache3,
+            ffn2_out,
             None,
             layer.moe_ffn2_weight_scale,
             topk_weights,
@@ -217,13 +214,13 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             token_num * top_k,
             N=hidden_size,
             K=moe_intermediate_size,
-            stride_am=intermediate_cache2.strides[0],
-            stride_ak=intermediate_cache2.strides[1],
+            stride_am=ffn2_input.strides[0],
+            stride_ak=ffn2_input.strides[1],
             stride_be=layer.moe_ffn2_weight.strides[0],
             stride_bk=layer.moe_ffn2_weight.strides[1],
             stride_bn=layer.moe_ffn2_weight.strides[2],
-            stride_cm=intermediate_cache3.strides[0],
-            stride_cn=intermediate_cache3.strides[1],
+            stride_cm=ffn2_out.strides[0],
+            stride_cn=ffn2_out.strides[1],
             stride_asm=-1,
             stride_ask=-1,
             stride_bse=layer.moe_ffn2_weight_scale.strides[0],
@@ -244,8 +241,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
         )
-        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
-        out = intermediate_cache3.sum(axis=1)
+        ffn2_out.reshape_([token_num, top_k, hidden_size])
+        out = ffn2_out.sum(axis=1)
         return out
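
As a reading aid for the renamed buffers above: ffn1_out holds the gate/up projection output, SwiGLU halves its last dimension to produce ffn2_input, ffn2_out is now allocated only right before the second GEMM, and the result is reduced over the top_k experts per token. The sketch below illustrates just that shape flow with made-up sizes; swiglu_reference is a hypothetical stand-in for paddle.incubate.nn.functional.swiglu used only to show the shape contract, and a plain matmul stands in for the Triton kernel launches in the diff.

import paddle
import paddle.nn.functional as F

def swiglu_reference(x):
    # Stand-in for paddle.incubate.nn.functional.swiglu(x): split the last
    # dimension in half, gate one half with SiLU, multiply by the other.
    gate, up = paddle.chunk(x, 2, axis=-1)
    return F.silu(gate) * up

# Hypothetical sizes, chosen only to make the shapes visible.
token_num, top_k, hidden_size, moe_intermediate_size = 4, 2, 8, 16

# ffn1_out: first grouped GEMM output, shape [token_num * top_k, 2 * I].
ffn1_out = paddle.randn([token_num * top_k, moe_intermediate_size * 2])

# ffn2_input: SwiGLU halves the last dimension, shape [token_num * top_k, I].
ffn2_input = swiglu_reference(ffn1_out)
assert ffn2_input.shape == [token_num * top_k, moe_intermediate_size]

# ffn2_out: in the real code this buffer is allocated right before the second
# Triton GEMM, which writes into it; here a plain matmul stands in for it.
ffn2_out = paddle.matmul(ffn2_input, paddle.randn([moe_intermediate_size, hidden_size]))

# Final reduction: view as [token_num, top_k, hidden_size], sum over top_k.
out = ffn2_out.reshape([token_num, top_k, hidden_size]).sum(axis=1)
assert out.shape == [token_num, hidden_size]
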
@@ -343,18 +340,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             False,
         )
-        intermediate_cache1 = paddle.empty(
+        ffn1_out = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
             dtype=x.dtype,
         )
-        intermediate_cache2 = paddle.empty(
-            (token_num * top_k, moe_intermediate_size),
-            dtype=x.dtype,
-        )
-        intermediate_cache3 = paddle.empty(
-            (token_num * top_k, hidden_size),
-            dtype=x.dtype,
-        )
         config_ffn1 = {
             "BLOCK_SIZE_M": 32,
@@ -381,7 +370,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
         fused_moe_kernel_paddle[grid](
             permute_x,
             layer.moe_ffn1_weight,
-            intermediate_cache1,
+            ffn1_out,
             layer.moe_ffn1_in_scale,
             layer.moe_ffn1_weight_scale,
             None,
@@ -397,8 +386,8 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             stride_be=layer.moe_ffn1_weight.strides[0],
             stride_bk=layer.moe_ffn1_weight.strides[1],
             stride_bn=layer.moe_ffn1_weight.strides[2],
-            stride_cm=intermediate_cache1.strides[0],
-            stride_cn=intermediate_cache1.strides[1],
+            stride_cm=ffn1_out.strides[0],
+            stride_cn=ffn1_out.strides[1],
             #
             stride_asm=-1, # only used in blockwise fp8
             stride_ask=-1, # only used in blockwise fp8
@@ -420,11 +409,11 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             even_Ks=hidden_size % config_ffn1["BLOCK_SIZE_K"] == 0,
         )
-        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
-            intermediate_cache1)
+        ffn2_input = paddle.incubate.nn.functional.swiglu(
+            ffn1_out)
-        intermediate_cache2 = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
-            intermediate_cache2,
+        ffn2_input = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
+            ffn2_input,
             scale=layer.moe_ffn2_in_scale,
             topk_ids=topk_ids,
             top_k=top_k,
@@ -438,14 +427,19 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
"GROUP_SIZE_M": 1,
}
ffn2_out = paddle.empty(
(token_num * top_k, hidden_size),
dtype=x.dtype,
)
grid = (
ceil_div(max_possible_num_post_padded, config_ffn2["BLOCK_SIZE_M"]) *
ceil_div(hidden_size, config_ffn2["BLOCK_SIZE_N"]), )
fused_moe_kernel_paddle[grid](
intermediate_cache2,
ffn2_input,
layer.moe_ffn2_weight,
intermediate_cache3,
ffn2_out,
layer.moe_ffn2_in_scale,
layer.moe_ffn2_weight_scale,
topk_weights,
@@ -456,13 +450,13 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             token_num * top_k,
             N=hidden_size,
             K=moe_intermediate_size,
-            stride_am=intermediate_cache2.strides[0],
-            stride_ak=intermediate_cache2.strides[1],
+            stride_am=ffn2_input.strides[0],
+            stride_ak=ffn2_input.strides[1],
             stride_be=layer.moe_ffn2_weight.strides[0],
             stride_bk=layer.moe_ffn2_weight.strides[1],
             stride_bn=layer.moe_ffn2_weight.strides[2],
-            stride_cm=intermediate_cache3.strides[0],
-            stride_cn=intermediate_cache3.strides[1],
+            stride_cm=ffn2_out.strides[0],
+            stride_cn=ffn2_out.strides[1],
             stride_asm=-1,
             stride_ask=-1,
             stride_bse=-1,
@@ -483,8 +477,8 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
             even_Ks=moe_intermediate_size % config_ffn2["BLOCK_SIZE_K"] == 0,
         )
-        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
-        out = intermediate_cache3.sum(axis=1)
+        ffn2_out.reshape_([token_num, top_k, hidden_size])
+        out = ffn2_out.sum(axis=1)
         if layer.tp_size > 1:
             tensor_model_parallel_all_reduce(out)
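
A note on the launch grid that both FFN2 launches above reuse: the grid is one-dimensional, with one Triton program per (M-block, N-block) pair. A minimal worked example, assuming the usual integer ceiling division for ceil_div and made-up sizes:

def ceil_div(a, b):
    # Usual integer ceiling division assumed for the ceil_div helper above.
    return (a + b - 1) // b

# Hypothetical sizes for illustration only.
max_possible_num_post_padded = 96   # padded token * top_k rows (M)
hidden_size = 7168                  # output width of the FFN2 GEMM (N)
config_ffn2 = {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64}

# One program per (M-block, N-block) pair, flattened into a 1-D grid.
grid = (ceil_div(max_possible_num_post_padded, config_ffn2["BLOCK_SIZE_M"]) *
        ceil_div(hidden_size, config_ffn2["BLOCK_SIZE_N"]), )

print(grid)  # (336,) = 3 M-blocks * 112 N-blocks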