mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-31 11:56:44 +08:00)

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """ UT for cutlass_fp8_fp8_half_gemm_fused """
 | |
| import paddle
 | |
| 
 | |
| from fastdeploy.utils import llm_logger as logger
 | |
| 
 | |
| 
 | |
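# Each tuner below follows the same pattern: for every (n, k) pair, sweep the
# GEMM M dimension in steps of 32 over [m_min, m_max], run the fused kernel
# once on random fp8_e4m3fn inputs so that the underlying CUTLASS kernel can be
# profiled for that shape, and then release cached GPU memory.
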

def tune_cutlass_fp8_fp8_half_gemm_fused(
    ns: list,
    ks: list,
    m_min: int = 32,
    m_max: int = 32768,
):
    """
    Tune fp8 gemm.
    """
    assert len(ns) == len(ks), "The length of `ns` must be equal to that of `ks`"
    try:
        from fastdeploy.model_executor.ops.gpu import cutlass_fp8_fp8_half_gemm_fused
    except ImportError:
        logger.warning(
            "Import of cutlass_fp8_fp8_half_gemm_fused from "
            "fastdeploy.model_executor.ops.gpu failed; fp8 is only supported on "
            "CUDA arch 89+.")
        return
    paddle.seed(2003)
    # Sweep M in steps of 32; stop once M would exceed m_max.
    for m in range(m_min, m_max + 32, 32):
        if m > m_max:
            break
        for n, k in zip(ns, ks):
            # Random bf16 operands cast to fp8_e4m3fn; B is laid out as [n, k]
            # and transposed inside the kernel.
            A = paddle.rand(shape=[m, k], dtype="bfloat16").astype("float8_e4m3fn")
            B = paddle.rand(shape=[n, k], dtype="bfloat16").astype("float8_e4m3fn")
            cutlass_fp8_fp8_half_gemm_fused(
                A,
                B,
                bias=None,
                transpose_x=False,
                transpose_y=True,
                output_dtype="bfloat16",
                scale=0.5,
                activation_type="identity",
            )
            paddle.device.cuda.empty_cache()


def tune_cutlass_fp8_fp8_fp8_dual_gemm_fused(
    ns: list,
    ks: list,
    m_min: int = 32,
    m_max: int = 32768,
):
    """
    Tune fp8 dual-gemm.
    """
    assert len(ns) == len(ks), "The length of `ns` must be equal to that of `ks`"
    try:
        from fastdeploy.model_executor.ops.gpu import cutlass_fp8_fp8_fp8_dual_gemm_fused
    except ImportError:
        logger.warning(
            "Import of cutlass_fp8_fp8_fp8_dual_gemm_fused from "
            "fastdeploy.model_executor.ops.gpu failed; fp8 is only supported on "
            "CUDA arch 89+.")
        return
    paddle.seed(2003)
    for m in range(m_min, m_max + 32, 32):
        if m > m_max:
            break
        for n, k in zip(ns, ks):
            A = paddle.rand(shape=[m, k], dtype="bfloat16").astype("float8_e4m3fn")
            # Two weight matrices for the dual GEMM, fused with the swiglu
            # activation.
            B0 = paddle.rand(shape=[n, k], dtype="bfloat16").astype("float8_e4m3fn")
            B1 = paddle.rand(shape=[n, k], dtype="bfloat16").astype("float8_e4m3fn")
            cutlass_fp8_fp8_fp8_dual_gemm_fused(
                A,
                B0,
                B1,
                bias0=None,
                bias1=None,
                transpose_x=False,
                transpose_y=True,
                scale0=0.1,
                scale1=0.1,
                scale_out=0.5,
                activation_type="swiglu",
            )
            paddle.device.cuda.empty_cache()


def tune_per_channel_fp8_gemm_fused(
    ns: list,
    ks: list,
    m_min: int = 32,
    m_max: int = 32768,
):
    """
    Tune per-channel quant gemm.
    """
    assert len(ns) == len(ks), "The length of `ns` must be equal to that of `ks`"
    try:
        from fastdeploy.model_executor.ops.gpu import per_channel_fp8_fp8_half_gemm_fused
    except ImportError:
        logger.warning(
            "Import of per_channel_fp8_fp8_half_gemm_fused from "
            "fastdeploy.model_executor.ops.gpu failed; fp8 is only supported on "
            "CUDA arch 89+.")
        return
    paddle.seed(2003)
    for m in range(m_min, m_max + 32, 32):
        if m > m_max:
            break
        for n, k in zip(ns, ks):
            A = paddle.rand(shape=[m, k], dtype="bfloat16").astype("float8_e4m3fn")
            B = paddle.rand(shape=[n, k], dtype="bfloat16").astype("float8_e4m3fn")
            # A scalar (per-tensor) scale plus a per-output-channel scale of shape [n].
            scalar_scale = paddle.full([1], 0.168, dtype="float32")
            channel_scale = paddle.rand(shape=[n], dtype="float32")

            per_channel_fp8_fp8_half_gemm_fused(
                A,
                B,
                bias=None,
                scalar_scale=scalar_scale,
                channel_scale=channel_scale,
                transpose_x=False,
                transpose_y=True,
                output_dtype="bfloat16",
            )
            paddle.device.cuda.empty_cache()


def tune_blockwise_fp8_gemm_fused(
    ns: list,
    ks: list,
    m_min: int = 32,
    m_max: int = 32768,
):
    """
    Tune blockwise quant gemm.
    """
    assert len(ns) == len(ks), "The length of `ns` must be equal to that of `ks`"
    try:
        from fastdeploy.model_executor.ops.gpu import cutlass_fp8_fp8_half_block_gemm_fused
    except ImportError:
        logger.warning(
            "Import of cutlass_fp8_fp8_half_block_gemm_fused from "
            "fastdeploy.model_executor.ops.gpu failed; fp8 is only supported on "
            "CUDA arch 90+.")
        return
    paddle.seed(2003)
    for m in range(m_min, m_max + 32, 32):
        if m > m_max:
            break
        for n, k in zip(ns, ks):
            # Scale grids: one entry per 128-wide block along n and k.
            scale_n = (n + 128 - 1) // 128
            scale_k = (k + 128 - 1) // 128
            A = paddle.rand(shape=[m, k], dtype="bfloat16").astype("float8_e4m3fn")
            B = paddle.rand(shape=[n, k], dtype="bfloat16").astype("float8_e4m3fn")
            a_scale = paddle.randn([scale_k, m], dtype="float32")
            b_scale = paddle.randn([scale_n, scale_k], dtype="float32")

            cutlass_fp8_fp8_half_block_gemm_fused(
                A,
                B,
                # Keyword spellings follow the custom op's parameter names.
                x_sacle=a_scale,
                y_sacle=b_scale,
                bias=None,
                transpose_x=False,
                transpose_y=True,
                output_dtype="bfloat16",
            )
            paddle.device.cuda.empty_cache()
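

# Hedged usage sketch (not part of the original module): the (n, k) pairs and
# the small m_max cap below are hypothetical examples only; real callers would
# pass the weight shapes of the model being served, and each tuner requires a
# GPU with the CUDA arch noted in its warning (89+ / 90+).
if __name__ == "__main__":
    example_ns = [4096, 4096]    # hypothetical output dimensions
    example_ks = [4096, 12288]   # hypothetical reduction dimensions
    tune_cutlass_fp8_fp8_half_gemm_fused(example_ns, example_ks, m_min=32, m_max=256)
    tune_cutlass_fp8_fp8_fp8_dual_gemm_fused(example_ns, example_ks, m_min=32, m_max=256)
    tune_per_channel_fp8_gemm_fused(example_ns, example_ks, m_min=32, m_max=256)
    tune_blockwise_fp8_gemm_fused(example_ns, example_ks, m_min=32, m_max=256)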
