# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""UT for cutlass_fp8_fp8_half_gemm_fused"""
|
|
import paddle
|
|
|
|
from fastdeploy.utils import llm_logger as logger
|
|
|
|
|
|
def tune_cutlass_fp8_fp8_half_gemm_fused(
|
|
ns: list,
|
|
ks: list,
|
|
m_min: int = 32,
|
|
m_max: int = 32768,
|
|
):
|
|
"""
|
|
Tune fp8 gemm.
|
|
"""
|
|
assert len(ns) == len(ks), "The length of `ns` must be equal to that of `ks`"
|
|
try:
|
|
from fastdeploy.model_executor.ops.gpu import cutlass_fp8_fp8_half_gemm_fused
|
|
except ImportError:
|
|
logger.warning(
|
|
"From fastdeploy.model_executor.ops.gpu import cutlass_fp8_fp8_half_gemm_fused failed, \
|
|
fp8 is only support cuda arch 89+."
|
|
)
|
|
return
|
|
paddle.seed(2003)
|
|
for m in range(m_min, m_max + 32, 32):
|
|
if m > m_max:
|
|
break
|
|
for idx in range(len(ns)):
|
|
n = ns[idx]
|
|
k = ks[idx]
|
|
A = paddle.rand(shape=[m, k], dtype="bfloat16").astype("float8_e4m3fn")
|
|
B = paddle.rand(shape=[n, k], dtype="bfloat16").astype("float8_e4m3fn")
|
|
cutlass_fp8_fp8_half_gemm_fused(
|
|
A,
|
|
B,
|
|
bias=None,
|
|
transpose_x=False,
|
|
transpose_y=True,
|
|
output_dtype="bfloat16",
|
|
scale=0.5,
|
|
activation_type="identity",
|
|
)
|
|
paddle.device.cuda.empty_cache()
|
|
|
|
|
|
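# Illustrative note (the shapes are placeholders, not values from this repo): ns and ks
# are paired element-wise, so tune_cutlass_fp8_fp8_half_gemm_fused(ns=[4096, 12288],
# ks=[4096, 4096]) sweeps m in steps of 32 and, with transpose_y=True, exercises the
# gemms (m, 4096) x (4096, 4096)^T and (m, 4096) x (12288, 4096)^T.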
def tune_cutlass_fp8_fp8_fp8_dual_gemm_fused(
    ns: list,
    ks: list,
    m_min: int = 32,
    m_max: int = 32768,
):
    """
    Tune fp8 dual-gemm.
    """
    assert len(ns) == len(ks), "The length of `ns` must be equal to that of `ks`"
    try:
        from fastdeploy.model_executor.ops.gpu import (
            cutlass_fp8_fp8_fp8_dual_gemm_fused,
        )
    except ImportError:
        logger.warning(
            "Import of cutlass_fp8_fp8_fp8_dual_gemm_fused from fastdeploy.model_executor.ops.gpu failed; "
            "fp8 is only supported on CUDA arch 89+."
        )
        return
    paddle.seed(2003)
    for m in range(m_min, m_max + 32, 32):
        if m > m_max:
            break
        for idx in range(len(ns)):
            n = ns[idx]
            k = ks[idx]
            A = paddle.rand(shape=[m, k], dtype="bfloat16").astype("float8_e4m3fn")
            # Two weight matrices of the same shape feed the fused dual gemm (e.g. gate/up projections).
            B0 = paddle.rand(shape=[n, k], dtype="bfloat16").astype("float8_e4m3fn")
            B1 = paddle.rand(shape=[n, k], dtype="bfloat16").astype("float8_e4m3fn")
            cutlass_fp8_fp8_fp8_dual_gemm_fused(
                A,
                B0,
                B1,
                bias0=None,
                bias1=None,
                transpose_x=False,
                transpose_y=True,
                scale0=0.1,
                scale1=0.1,
                scale_out=0.5,
                activation_type="swiglu",
            )
        paddle.device.cuda.empty_cache()
def tune_per_channel_fp8_gemm_fused(
    ns: list,
    ks: list,
    m_min: int = 32,
    m_max: int = 32768,
):
    """
    Tune per-channel quant gemm.
    """
    assert len(ns) == len(ks), "The length of `ns` must be equal to that of `ks`"
    try:
        from fastdeploy.model_executor.ops.gpu import (
            per_channel_fp8_fp8_half_gemm_fused,
        )
    except ImportError:
        logger.warning(
            "Import of per_channel_fp8_fp8_half_gemm_fused from fastdeploy.model_executor.ops.gpu failed; "
            "fp8 is only supported on CUDA arch 89+."
        )
        return
    paddle.seed(2003)
    for m in range(m_min, m_max + 32, 32):
        if m > m_max:
            break
        for idx in range(len(ns)):
            n = ns[idx]
            k = ks[idx]
            A = paddle.rand(shape=[m, k], dtype="bfloat16").astype("float8_e4m3fn")
            B = paddle.rand(shape=[n, k], dtype="bfloat16").astype("float8_e4m3fn")
            # One per-tensor scale plus one scale per output channel (shape [n]).
            scalar_scale = paddle.full([1], 0.168, dtype="float32")
            channel_scale = paddle.rand(shape=[n], dtype="float32")

            per_channel_fp8_fp8_half_gemm_fused(
                A,
                B,
                bias=None,
                scalar_scale=scalar_scale,
                channel_scale=channel_scale,
                transpose_x=False,
                transpose_y=True,
                output_dtype="bfloat16",
            )
        paddle.device.cuda.empty_cache()
def tune_blockwise_fp8_gemm_fused(
    ns: list,
    ks: list,
    m_min: int = 32,
    m_max: int = 32768,
):
    """
    Tune blockwise quant gemm.
    """
    assert len(ns) == len(ks), "The length of `ns` must be equal to that of `ks`"
    try:
        from fastdeploy.model_executor.ops.gpu import (
            cutlass_fp8_fp8_half_block_gemm_fused,
        )
    except ImportError:
        logger.warning(
            "Import of cutlass_fp8_fp8_half_block_gemm_fused from fastdeploy.model_executor.ops.gpu failed; "
            "fp8 is only supported on CUDA arch 90+."
        )
        return
    paddle.seed(2003)
    for m in range(m_min, m_max + 32, 32):
        if m > m_max:
            break
        for idx in range(len(ns)):
            n = ns[idx]
            k = ks[idx]
            # Scale grid sizes for 128-wide quantization blocks (ceil division).
            scale_n = (n + 128 - 1) // 128
            scale_k = (k + 128 - 1) // 128
            A = paddle.rand(shape=[m, k], dtype="bfloat16").astype("float8_e4m3fn")
            B = paddle.rand(shape=[n, k], dtype="bfloat16").astype("float8_e4m3fn")
            a_scale = paddle.randn([scale_k, m], dtype="float32")
            b_scale = paddle.randn([scale_n, scale_k], dtype="float32")

            # NOTE: the kwarg spellings below follow the custom op's signature.
            cutlass_fp8_fp8_half_block_gemm_fused(
                A,
                B,
                x_sacle=a_scale,
                y_sacle=b_scale,
                bias=None,
                transpose_x=False,
                transpose_y=True,
                output_dtype="bfloat16",
            )
        paddle.device.cuda.empty_cache()
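# Illustrative sketch, not part of FastDeploy's tuning entry points: the (n, k)
# pairs and the reduced m_max below are placeholder values chosen only to show
# how the tuners above are invoked together. Each tuner logs a warning and
# returns early if its fp8 op could not be imported.
if __name__ == "__main__":
    example_ns = [4096, 12288]  # hypothetical output dims (placeholders)
    example_ks = [4096, 4096]  # hypothetical reduction dims (placeholders)
    tune_cutlass_fp8_fp8_half_gemm_fused(example_ns, example_ks, m_min=32, m_max=256)
    tune_cutlass_fp8_fp8_fp8_dual_gemm_fused(example_ns, example_ks, m_min=32, m_max=256)
    tune_per_channel_fp8_gemm_fused(example_ns, example_ks, m_min=32, m_max=256)
    tune_blockwise_fp8_gemm_fused(example_ns, example_ks, m_min=32, m_max=256)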