From 7bafcf1df3be3f4626accb71aedad3ff5694a99e Mon Sep 17 00:00:00 2001 From: K11OntheBoat Date: Fri, 28 Nov 2025 14:23:44 +0800 Subject: [PATCH] [OP]Remove extra H2D in DeepGemm (#5262) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: K11OntheBoat <“ruianmaidanglao@163.com”> --- custom_ops/0001-DeepGEMM-95e81b3.patch | 58 ++++++++++++++++---------- 1 file changed, 35 insertions(+), 23 deletions(-) diff --git a/custom_ops/0001-DeepGEMM-95e81b3.patch b/custom_ops/0001-DeepGEMM-95e81b3.patch index c3f409c14..eb828a1b5 100644 --- a/custom_ops/0001-DeepGEMM-95e81b3.patch +++ b/custom_ops/0001-DeepGEMM-95e81b3.patch @@ -1,22 +1,22 @@ -From 5112002c155dceecc5e5983cdb67157e4f5400e2 Mon Sep 17 00:00:00 2001 -From: minghaipeng -Date: Wed, 25 Jun 2025 15:05:24 +0800 -Subject: [PATCH] DeepGEMM 95e81b3 +From 7008a3c8b7fe833c952f27a5ab3848c485f02b5d Mon Sep 17 00:00:00 2001 +From: K11OntheBoat <“ruianmaidanglao@163.com”> +Date: Thu, 27 Nov 2025 14:38:47 +0800 +Subject: [PATCH] Remove extra H2D in DeepGemm --- - deep_gemm/__init__.py | 2 +- - deep_gemm/include/deep_gemm/scheduler.cuh | 2 +- - deep_gemm/jit/compiler.py | 2 +- - deep_gemm/jit/interleave_ffma.py | 2 +- - deep_gemm/jit/runtime.py | 4 +- - deep_gemm/jit/template.py | 34 ++++---- - deep_gemm/jit_kernels/gemm.py | 44 +++++------ - deep_gemm/jit_kernels/m_grouped_gemm.py | 96 +++++++++++------------ - deep_gemm/jit_kernels/tuner.py | 10 +-- - deep_gemm/jit_kernels/utils.py | 18 +++-- - deep_gemm/paddle_utils.py | 20 +++++ - deep_gemm/utils.py | 30 +++---- - 12 files changed, 143 insertions(+), 121 deletions(-) + deep_gemm/__init__.py | 2 +- + deep_gemm/include/deep_gemm/scheduler.cuh | 2 +- + deep_gemm/jit/compiler.py | 2 +- + deep_gemm/jit/interleave_ffma.py | 2 +- + deep_gemm/jit/runtime.py | 4 +- + deep_gemm/jit/template.py | 34 +++---- + deep_gemm/jit_kernels/gemm.py | 44 ++++----- + deep_gemm/jit_kernels/m_grouped_gemm.py | 104 +++++++++++----------- + deep_gemm/jit_kernels/tuner.py | 10 +-- + deep_gemm/jit_kernels/utils.py | 18 ++-- + deep_gemm/paddle_utils.py | 20 +++++ + deep_gemm/utils.py | 30 +++---- + 12 files changed, 147 insertions(+), 125 deletions(-) create mode 100644 deep_gemm/paddle_utils.py diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py @@ -257,7 +257,7 @@ index cb438b7..44aa0ed 100644 args=args ) diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py -index 3b518c9..ba776bd 100644 +index 3b518c9..b94e65d 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -1,4 +1,4 @@ @@ -299,8 +299,14 @@ index 3b518c9..ba776bd 100644 `m_indices[i]` records the group which the i-th row of the LHS belong to, which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`. Values of `m_indices` in every-m-alignment-block must also be the same. -@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten - m__ = m_indices.numel() +@@ -64,23 +64,23 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten + rhs, rhs_scales = rhs + m, k = lhs.shape + num_groups, n, k_ = rhs.shape +- m_, n_ = out.shape +- m__ = m_indices.numel() ++ # m_, n_ = out.shape ++ # m__ = m_indices.numel() # Type and shape checks - assert m == m_ == m__ and k == k_ and n == n_ @@ -384,8 +390,14 @@ index 3b518c9..ba776bd 100644 the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result. masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute -@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] - num_groups___ = masked_m.numel() +@@ -145,25 +145,25 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] + rhs, rhs_scales = rhs + num_groups, m, k = lhs.shape + num_groups_, n, k_ = rhs.shape +- num_groups__, m_, n_ = out.shape +- num_groups___ = masked_m.numel() ++ # num_groups__, m_, n_ = out.shape ++ # num_groups___ = masked_m.numel() # Type and shape checks - assert num_groups == num_groups_ == num_groups__ == num_groups___ @@ -563,7 +575,7 @@ index 0000000..2326807 +CUDA_HOME = get_cuda_home() \ No newline at end of file diff --git a/deep_gemm/utils.py b/deep_gemm/utils.py -index d5cdd01..5237f09 100644 +index d5cdd01..011f14a 100644 --- a/deep_gemm/utils.py +++ b/deep_gemm/utils.py @@ -1,15 +1,15 @@