【Infer】Improve the performance of block_wise_fp8 in triton_moe_backend (#2942)

Author: chen
Date: 2025-07-23 13:02:50 +08:00
Committed by: GitHub
Parent: e51f018577
Commit: ad202272ed
2 changed files with 30 additions and 14 deletions

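In rough terms, and as a minimal sketch rather than the actual FastDeploy code (the helper name pack_fp8_weight below is illustrative, not part of the diff): the per-expert quantized weights are reinterpreted as paddle.float8_e4m3fn once, when the layer parameters are created, instead of calling .view(paddle.float8_e4m3fn) on them inside every forward pass.

import paddle

def pack_fp8_weight(weight_list):
    # Illustrative helper (not from the diff): stack the per-expert
    # quantized weights, transpose them into the layout the Triton kernel
    # expects, and reinterpret the bytes as float8_e4m3fn once at load time.
    w = paddle.stack(weight_list, axis=0)
    w = w.transpose([0, 2, 1]).contiguous()
    return w.view(paddle.float8_e4m3fn)  # dtype view only, no data copy

With the parameter already stored as fp8, the kernel launches in the hunks below can pass layer.up_gate_proj_weight and layer.down_proj_weight directly, which is the main point of this commit.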

@@ -552,7 +552,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
             weight_list.append(quant_weight)
             weight_scale_list.append(scale)
         quanted_weight = paddle.stack(weight_list, axis=0)
-        quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous()
+        quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous().view(paddle.float8_e4m3fn)
         create_and_set_parameter(layer, weight_name, quanted_weight)
         quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
@@ -606,11 +606,14 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
"num_warps": 4,
"num_stages": 3,
}
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
)
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
# cache13 = create_empty_tensor(tuple([token_num * top_k * max(N1, N2)]), x.dtype)
cache13 = paddle.empty([token_num * top_k * max(N1, N2)], dtype=x.dtype)
intermediate_cache1 = cache13[:token_num * top_k * N1].view(
[token_num * top_k, N1])
max_num_tokens_padded = sorted_token_ids.shape[0]
grid = (
@@ -622,13 +625,9 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, self.quant_config.weight_block_size[0])
-        cache13 = paddle.empty([token_num * top_k * max(N1, N2)], dtype=x.dtype)
-        intermediate_cache1 = cache13[: token_num * top_k * N1].view([token_num * top_k, N1])
-        intermediate_cache3 = cache13[: token_num * top_k * N2].view([token_num * top_k, N2])
         fused_moe_kernel_paddle[grid](
             x_q,
-            layer.up_gate_proj_weight.view(paddle.float8_e4m3fn),
+            layer.up_gate_proj_weight,
             intermediate_cache1,
             x_scale,
             layer.up_gate_proj_weight_scale,
@@ -670,9 +669,11 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         intermediate_cache2 = paddle.incubate.nn.functional.swiglu(intermediate_cache1)
-        grid = (
-            ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * ceil_div(hidden_size, config["BLOCK_SIZE_N"]),
-        )
+        intermediate_cache3 = cache13[:token_num * top_k * N2].view(
+            [token_num * top_k, N2])
+        grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
+                ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
         x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
             intermediate_cache2, self.quant_config.weight_block_size[0]
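For the activation buffers touched in the hunks above, the idea is to carve both intermediate caches out of one pre-allocated flat buffer; a minimal standalone sketch follows, where token_num, top_k, N1, N2 are made-up sizes, not values from the code:

import paddle

# One flat buffer large enough for either intermediate activation is
# allocated up front; the two caches are then taken as reshaped slices of
# it, as in the diff, so no second full-size allocation happens between
# the two fused_moe kernel launches.
token_num, top_k, N1, N2 = 8, 2, 1024, 512  # illustrative sizes only
cache13 = paddle.empty([token_num * top_k * max(N1, N2)], dtype="float16")

intermediate_cache1 = cache13[:token_num * top_k * N1].view([token_num * top_k, N1])
# ... first kernel writes intermediate_cache1, swiglu produces cache2 ...
intermediate_cache3 = cache13[:token_num * top_k * N2].view([token_num * top_k, N2])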
@@ -680,7 +681,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         fused_moe_kernel_paddle[grid](
             x_q,
-            layer.down_proj_weight.view(paddle.float8_e4m3fn),
+            layer.down_proj_weight,
             intermediate_cache3,
             x_scale,
             layer.down_proj_weight_scale,