diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
index 430f3104b..f24936138 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
@@ -552,7 +552,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
                 weight_list.append(quant_weight)
                 weight_scale_list.append(scale)
             quanted_weight = paddle.stack(weight_list, axis=0)
-            quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous()
+            quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous().view(paddle.float8_e4m3fn)
             create_and_set_parameter(layer, weight_name, quanted_weight)
 
             quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
@@ -606,11 +606,14 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
             "num_warps": 4,
             "num_stages": 3,
         }
-        from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
+        from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
 
-        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
-            topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
-        )
+        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
+            topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
+        # cache13 = create_empty_tensor(tuple([token_num * top_k * max(N1, N2)]), x.dtype)
+        cache13 = paddle.empty([token_num * top_k * max(N1, N2)], dtype=x.dtype)
+        intermediate_cache1 = cache13[:token_num * top_k * N1].view(
+            [token_num * top_k, N1])
 
         max_num_tokens_padded = sorted_token_ids.shape[0]
         grid = (
@@ -622,13 +625,9 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
 
         x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, self.quant_config.weight_block_size[0])
 
-        cache13 = paddle.empty([token_num * top_k * max(N1, N2)], dtype=x.dtype)
-        intermediate_cache1 = cache13[: token_num * top_k * N1].view([token_num * top_k, N1])
-        intermediate_cache3 = cache13[: token_num * top_k * N2].view([token_num * top_k, N2])
-
         fused_moe_kernel_paddle[grid](
             x_q,
-            layer.up_gate_proj_weight.view(paddle.float8_e4m3fn),
+            layer.up_gate_proj_weight,
             intermediate_cache1,
             x_scale,
             layer.up_gate_proj_weight_scale,
@@ -670,9 +669,11 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
 
         intermediate_cache2 = paddle.incubate.nn.functional.swiglu(intermediate_cache1)
 
-        grid = (
-            ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * ceil_div(hidden_size, config["BLOCK_SIZE_N"]),
-        )
+        intermediate_cache3 = cache13[:token_num * top_k * N2].view(
+            [token_num * top_k, N2])
+
+        grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
+                ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
 
         x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
             intermediate_cache2, self.quant_config.weight_block_size[0]
@@ -680,7 +681,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         )
         fused_moe_kernel_paddle[grid](
             x_q,
-            layer.down_proj_weight.view(paddle.float8_e4m3fn),
+            layer.down_proj_weight,
             intermediate_cache3,
             x_scale,
             layer.down_proj_weight_scale,
diff --git a/fastdeploy/model_executor/layers/utils.py b/fastdeploy/model_executor/layers/utils.py
index 7ea753889..fa057965f 100644
--- a/fastdeploy/model_executor/layers/utils.py
+++ b/fastdeploy/model_executor/layers/utils.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 
+import functools
 from typing import Tuple, Union
 
 import numpy as np
@@ -375,3 +376,17 @@ def create_and_set_parameter(layer: nn.Layer, name: str, tensor: paddle.Tensor):
         ),
     )
     getattr(layer, name).set_value(tensor)
+
+@functools.cache
+def create_empty_tensor(shape: Tuple[int, ...], dtype: Union[paddle.dtype, str]) -> paddle.Tensor:
+    """
+    Creates and caches an empty tensor with the specified shape and data type.
+
+    Args:
+        shape (Tuple[int, ...]): A tuple representing the dimensions of the tensor.
+        dtype (Union[paddle.dtype, str]): The data type for the tensor, such as 'bfloat16', 'float16', etc.
+
+    Returns:
+        paddle.Tensor: An empty tensor with the specified shape and data type.
+    """
+    return paddle.empty(list(shape), dtype=dtype)