Commit from https://github.com/PaddlePaddle/FastDeploy.git:

【Infer】Improve block_wise_fp8 performance of the triton_moe_backend (#2942)
@@ -552,7 +552,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
             weight_list.append(quant_weight)
             weight_scale_list.append(scale)
         quanted_weight = paddle.stack(weight_list, axis=0)
-        quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous()
+        quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous().view(paddle.float8_e4m3fn)
         create_and_set_parameter(layer, weight_name, quanted_weight)

         quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
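The change above moves the reinterpret-as-FP8 view out of the forward pass: the packed weight parameter is stored as float8_e4m3fn once, at load time. A minimal sketch of the packing pattern (shapes, dtypes, and names are illustrative, and it assumes paddle.Tensor.view supports same-width dtype reinterpretation, which the diff itself relies on):

    import paddle

    num_experts, N, K = 4, 256, 128
    # Per-expert quantized weights; int8 stands in for raw FP8 bit patterns.
    weight_list = [paddle.zeros([N, K], dtype="int8") for _ in range(num_experts)]

    packed = paddle.stack(weight_list, axis=0)         # [E, N, K]
    packed = packed.transpose([0, 2, 1]).contiguous()  # [E, K, N], contiguous for the kernel
    packed = packed.view(paddle.float8_e4m3fn)         # reinterpret the bits once, at load time

Later kernel launches can then pass the parameter directly, with no per-call .view() (see the two launch-site hunks below).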
@@ -606,11 +606,14 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
             "num_warps": 4,
             "num_stages": 3,
         }
-        from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
+        from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func

-        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
-            topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
-        )
+        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
+            topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
+        # cache13 = create_empty_tensor(tuple([token_num * top_k * max(N1, N2)]), x.dtype)
+        cache13 = paddle.empty([token_num * top_k * max(N1, N2)], dtype=x.dtype)
+        intermediate_cache1 = cache13[:token_num * top_k * N1].view(
+            [token_num * top_k, N1])
         max_num_tokens_padded = sorted_token_ids.shape[0]

         grid = (
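The new cache13 buffer is the core of the optimization: one flat allocation, sized for the larger of the two intermediate activations, backs both intermediate_cache1 and (later, after swiglu) intermediate_cache3 as overlapping views. A minimal sketch with illustrative sizes:

    import paddle

    token_num, top_k = 8, 2
    N1, N2 = 1024, 512  # up_gate_proj and down_proj output widths (illustrative)

    # One scratch buffer for both intermediates; no second allocation.
    cache13 = paddle.empty([token_num * top_k * max(N1, N2)], dtype="bfloat16")

    intermediate_cache1 = cache13[: token_num * top_k * N1].view([token_num * top_k, N1])
    intermediate_cache3 = cache13[: token_num * top_k * N2].view([token_num * top_k, N2])

The two views alias the same storage, which is safe here because cache3 is only written after cache1 has been fully consumed by the swiglu activation.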
@@ -622,13 +625,9 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):

         x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, self.quant_config.weight_block_size[0])

-        cache13 = paddle.empty([token_num * top_k * max(N1, N2)], dtype=x.dtype)
-        intermediate_cache1 = cache13[: token_num * top_k * N1].view([token_num * top_k, N1])
-        intermediate_cache3 = cache13[: token_num * top_k * N2].view([token_num * top_k, N2])
-
         fused_moe_kernel_paddle[grid](
             x_q,
-            layer.up_gate_proj_weight.view(paddle.float8_e4m3fn),
+            layer.up_gate_proj_weight,
             intermediate_cache1,
             x_scale,
             layer.up_gate_proj_weight_scale,
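Together with the previous hunk, this removes two pieces of per-forward overhead from the hot path: the scratch buffer is now allocated once, before the first kernel launch, and the up_gate_proj weights are passed to the kernel directly because they are already stored as float8_e4m3fn at load time (first hunk above).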
@@ -670,9 +669,11 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):

         intermediate_cache2 = paddle.incubate.nn.functional.swiglu(intermediate_cache1)

-        grid = (
-            ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * ceil_div(hidden_size, config["BLOCK_SIZE_N"]),
-        )
+        intermediate_cache3 = cache13[:token_num * top_k * N2].view(
+            [token_num * top_k, N2])
+
+        grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
+                ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )

         x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
             intermediate_cache2, self.quant_config.weight_block_size[0]
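For reference, a self-contained sketch of the launch-grid arithmetic above, with illustrative sizes (ceil_div mirrors the helper used in the file):

    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    max_num_tokens_padded = 4096
    hidden_size = 7168
    config = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}

    # One Triton program per (M-block, N-block) tile, flattened into a 1-D grid.
    grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"])
            * ceil_div(hidden_size, config["BLOCK_SIZE_N"]),)
    assert grid == (64 * 112,)  # == (7168,)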
@@ -680,7 +681,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):

         fused_moe_kernel_paddle[grid](
             x_q,
-            layer.down_proj_weight.view(paddle.float8_e4m3fn),
+            layer.down_proj_weight,
             intermediate_cache3,
             x_scale,
             layer.down_proj_weight_scale,

@@ -14,6 +14,7 @@
 # limitations under the License.
 """

+import functools
 from typing import Tuple, Union

 import numpy as np
@@ -375,3 +376,17 @@ def create_and_set_parameter(layer: nn.Layer, name: str, tensor: paddle.Tensor):
         ),
     )
     getattr(layer, name).set_value(tensor)
+
+
+@functools.cache
+def create_empty_tensor(shape: Tuple[int, ...], dtype: Union[paddle.dtype, str]) -> paddle.Tensor:
+    """
+    Creates and caches an empty tensor with the specified shape and data type.
+
+    Args:
+        shape (Tuple[int, ...]): A tuple representing the dimensions of the tensor.
+        dtype (Union[paddle.dtype, str]): The data type for the tensor, such as 'bfloat16', 'float16', etc.
+
+    Returns:
+        paddle.Tensor: An empty tensor with the specified shape and data type.
+    """
+    return paddle.empty(list(shape), dtype=dtype)
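Because of @functools.cache, calling the helper twice with equal arguments returns the very same tensor object; this is also why shape must be a hashable tuple (the commented-out call site earlier wraps it in tuple(...)). A short sketch of the resulting semantics:

    t1 = create_empty_tensor((1024,), "bfloat16")
    t2 = create_empty_tensor((1024,), "bfloat16")
    assert t1 is t2  # memoized: one allocation, reused as shared scratch

Callers must therefore treat the contents as uninitialized scratch that any other call site with the same shape and dtype may overwrite. In this commit the helper is only defined; the MoE path keeps a plain paddle.empty and leaves the create_empty_tensor call commented out.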