polish code with new pre-commit rule (#2923)
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" UT for cutlass_fp8_fp8_half_gemm_fused """
+"""UT for cutlass_fp8_fp8_half_gemm_fused"""
 import paddle
 
 from fastdeploy.utils import llm_logger as logger
@@ -26,14 +26,14 @@ def tune_cutlass_fp8_fp8_half_gemm_fused(
     """
     Tune fp8 gemm.
     """
-    assert len(ns) == len(
-        ks), "The length of `ns` must be equal to that of `ks`"
+    assert len(ns) == len(ks), "The length of `ns` must be equal to that of `ks`"
     try:
         from fastdeploy.model_executor.ops.gpu import cutlass_fp8_fp8_half_gemm_fused
     except ImportError:
         logger.warning(
             "From fastdeploy.model_executor.ops.gpu import cutlass_fp8_fp8_half_gemm_fused failed, \
-            fp8 is only support cuda arch 89+.")
+            fp8 is only support cuda arch 89+."
+        )
         return
     paddle.seed(2003)
     for m in range(m_min, m_max + 32, 32):
@@ -42,10 +42,8 @@ def tune_cutlass_fp8_fp8_half_gemm_fused(
         for idx in range(len(ns)):
             n = ns[idx]
             k = ks[idx]
-            A = paddle.rand(shape=[m, k],
-                            dtype="bfloat16").astype("float8_e4m3fn")
-            B = paddle.rand(shape=[n, k],
-                            dtype="bfloat16").astype("float8_e4m3fn")
+            A = paddle.rand(shape=[m, k], dtype="bfloat16").astype("float8_e4m3fn")
+            B = paddle.rand(shape=[n, k], dtype="bfloat16").astype("float8_e4m3fn")
             cutlass_fp8_fp8_half_gemm_fused(
                 A,
                 B,
@@ -68,14 +66,16 @@ def tune_cutlass_fp8_fp8_fp8_dual_gemm_fused(
     """
     Tune fp8 dual-gemm.
     """
-    assert len(ns) == len(
-        ks), "The length of `ns` must be equal to that of `ks`"
+    assert len(ns) == len(ks), "The length of `ns` must be equal to that of `ks`"
     try:
-        from fastdeploy.model_executor.ops.gpu import cutlass_fp8_fp8_fp8_dual_gemm_fused
+        from fastdeploy.model_executor.ops.gpu import (
+            cutlass_fp8_fp8_fp8_dual_gemm_fused,
+        )
     except ImportError:
         logger.warning(
             "From fastdeploy.model_executor.ops.gpu import cutlass_fp8_fp8_fp8_dual_gemm_fused failed, \
-            fp8 is only support cuda arch 89+.")
+            fp8 is only support cuda arch 89+."
+        )
         return
     paddle.seed(2003)
     for m in range(m_min, m_max + 32, 32):
@@ -84,12 +84,9 @@ def tune_cutlass_fp8_fp8_fp8_dual_gemm_fused(
         for idx in range(len(ns)):
             n = ns[idx]
             k = ks[idx]
-            A = paddle.rand(shape=[m, k],
-                            dtype="bfloat16").astype("float8_e4m3fn")
-            B0 = paddle.rand(shape=[n, k],
-                             dtype="bfloat16").astype("float8_e4m3fn")
-            B1 = paddle.rand(shape=[n, k],
-                             dtype="bfloat16").astype("float8_e4m3fn")
+            A = paddle.rand(shape=[m, k], dtype="bfloat16").astype("float8_e4m3fn")
+            B0 = paddle.rand(shape=[n, k], dtype="bfloat16").astype("float8_e4m3fn")
+            B1 = paddle.rand(shape=[n, k], dtype="bfloat16").astype("float8_e4m3fn")
             cutlass_fp8_fp8_fp8_dual_gemm_fused(
                 A,
                 B0,
@@ -115,14 +112,16 @@ def tune_per_channel_fp8_gemm_fused(
     """
     Tune per-channel quant gemm.
     """
-    assert len(ns) == len(
-        ks), "The length of `ns` must be equal to that of `ks`"
+    assert len(ns) == len(ks), "The length of `ns` must be equal to that of `ks`"
     try:
-        from fastdeploy.model_executor.ops.gpu import per_channel_fp8_fp8_half_gemm_fused
+        from fastdeploy.model_executor.ops.gpu import (
+            per_channel_fp8_fp8_half_gemm_fused,
+        )
     except ImportError:
         logger.warning(
             "From fastdeploy.model_executor.ops.gpu import per_channel_fp8_fp8_half_gemm_fused failed, \
-            fp8 is only support cuda arch 89+.")
+            fp8 is only support cuda arch 89+."
+        )
         return
     paddle.seed(2003)
     for m in range(m_min, m_max + 32, 32):
@@ -131,10 +130,8 @@ def tune_per_channel_fp8_gemm_fused(
         for idx in range(len(ns)):
             n = ns[idx]
             k = ks[idx]
-            A = paddle.rand(shape=[m, k],
-                            dtype="bfloat16").astype("float8_e4m3fn")
-            B = paddle.rand(shape=[n, k],
-                            dtype="bfloat16").astype("float8_e4m3fn")
+            A = paddle.rand(shape=[m, k], dtype="bfloat16").astype("float8_e4m3fn")
+            B = paddle.rand(shape=[n, k], dtype="bfloat16").astype("float8_e4m3fn")
             scalar_scale = paddle.full([1], 0.168, dtype="float32")
             channel_scale = paddle.rand(shape=[n], dtype="float32")
 
@@ -160,14 +157,16 @@ def tune_blockwise_fp8_gemm_fused(
     """
     Tune per-channel quant gemm.
     """
-    assert len(ns) == len(
-        ks), "The length of `ns` must be equal to that of `ks`"
+    assert len(ns) == len(ks), "The length of `ns` must be equal to that of `ks`"
     try:
-        from fastdeploy.model_executor.ops.gpu import cutlass_fp8_fp8_half_block_gemm_fused
+        from fastdeploy.model_executor.ops.gpu import (
+            cutlass_fp8_fp8_half_block_gemm_fused,
+        )
     except ImportError:
         logger.warning(
             "From fastdeploy.model_executor.ops.gpu import cutlass_fp8_fp8_half_block_gemm_fused failed, \
-            fp8 is only support cuda arch 90+.")
+            fp8 is only support cuda arch 90+."
+        )
        return
     paddle.seed(2003)
     for m in range(m_min, m_max + 32, 32):
@@ -178,10 +177,8 @@ def tune_blockwise_fp8_gemm_fused(
             k = ks[idx]
             scale_n = (n + 128 - 1) // 128
             scale_k = (k + 128 - 1) // 128
-            A = paddle.rand(shape=[m, k],
-                            dtype="bfloat16").astype("float8_e4m3fn")
-            B = paddle.rand(shape=[n, k],
-                            dtype="bfloat16").astype("float8_e4m3fn")
+            A = paddle.rand(shape=[m, k], dtype="bfloat16").astype("float8_e4m3fn")
+            B = paddle.rand(shape=[n, k], dtype="bfloat16").astype("float8_e4m3fn")
             a_scale = paddle.randn([scale_k, m], dtype="float32")
             b_scale = paddle.randn([scale_n, scale_k], dtype="float32")
 
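For context, a minimal sketch of how one of these tuning helpers might be driven. The full `def` lines are truncated in this diff, so the parameter set `(ns, ks, m_min, m_max)` is inferred from the loop bodies above and the import path of the tuning module is not shown here; treat both as assumptions, not the confirmed API.

# Hypothetical usage sketch; assumes this tuning module is importable and that
# the signature is (ns, ks, m_min, m_max), inferred from the loops in the diff.
ns = [2048, 4096]      # output dimensions to tune, one per weight shape
ks = [4096, 4096]      # matching reduction dimensions; len(ks) must equal len(ns)
tune_cutlass_fp8_fp8_half_gemm_fused(ns, ks, m_min=32, m_max=1024)

The helper itself sweeps m in steps of 32 between m_min and m_max and runs the fused fp8 GEMM once per (m, n, k), which only succeeds on GPUs new enough for the fp8 kernels (the warnings above mention CUDA arch 89+/90+).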