[OP]Remove extra H2D in DeepGemm (#5262)

Co-authored-by: K11OntheBoat <“ruianmaidanglao@163.com”>
This commit is contained in:
K11OntheBoat
2025-11-28 14:23:44 +08:00
committed by GitHub
parent 95243f012c
commit 7bafcf1df3

View File

@@ -1,22 +1,22 @@
From 5112002c155dceecc5e5983cdb67157e4f5400e2 Mon Sep 17 00:00:00 2001
From: minghaipeng <minghaipeng@baidu.com>
Date: Wed, 25 Jun 2025 15:05:24 +0800
Subject: [PATCH] DeepGEMM 95e81b3
From 7008a3c8b7fe833c952f27a5ab3848c485f02b5d Mon Sep 17 00:00:00 2001
From: K11OntheBoat <“ruianmaidanglao@163.com>
Date: Thu, 27 Nov 2025 14:38:47 +0800
Subject: [PATCH] Remove extra H2D in DeepGemm
---
deep_gemm/__init__.py | 2 +-
deep_gemm/include/deep_gemm/scheduler.cuh | 2 +-
deep_gemm/jit/compiler.py | 2 +-
deep_gemm/jit/interleave_ffma.py | 2 +-
deep_gemm/jit/runtime.py | 4 +-
deep_gemm/jit/template.py | 34 ++++----
deep_gemm/jit_kernels/gemm.py | 44 +++++------
deep_gemm/jit_kernels/m_grouped_gemm.py | 96 +++++++++++------------
deep_gemm/jit_kernels/tuner.py | 10 +--
deep_gemm/jit_kernels/utils.py | 18 +++--
deep_gemm/paddle_utils.py | 20 +++++
deep_gemm/utils.py | 30 +++----
12 files changed, 143 insertions(+), 121 deletions(-)
deep_gemm/__init__.py | 2 +-
deep_gemm/include/deep_gemm/scheduler.cuh | 2 +-
deep_gemm/jit/compiler.py | 2 +-
deep_gemm/jit/interleave_ffma.py | 2 +-
deep_gemm/jit/runtime.py | 4 +-
deep_gemm/jit/template.py | 34 +++----
deep_gemm/jit_kernels/gemm.py | 44 ++++-----
deep_gemm/jit_kernels/m_grouped_gemm.py | 104 +++++++++++-----------
deep_gemm/jit_kernels/tuner.py | 10 +--
deep_gemm/jit_kernels/utils.py | 18 ++--
deep_gemm/paddle_utils.py | 20 +++++
deep_gemm/utils.py | 30 +++----
12 files changed, 147 insertions(+), 125 deletions(-)
create mode 100644 deep_gemm/paddle_utils.py
diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py
@@ -257,7 +257,7 @@ index cb438b7..44aa0ed 100644
args=args
)
diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py
index 3b518c9..ba776bd 100644
index 3b518c9..b94e65d 100644
--- a/deep_gemm/jit_kernels/m_grouped_gemm.py
+++ b/deep_gemm/jit_kernels/m_grouped_gemm.py
@@ -1,4 +1,4 @@
@@ -299,8 +299,14 @@ index 3b518c9..ba776bd 100644
`m_indices[i]` records the group which the i-th row of the LHS belong to,
which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
Values of `m_indices` in every-m-alignment-block must also be the same.
@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
m__ = m_indices.numel()
@@ -64,23 +64,23 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
rhs, rhs_scales = rhs
m, k = lhs.shape
num_groups, n, k_ = rhs.shape
- m_, n_ = out.shape
- m__ = m_indices.numel()
+ # m_, n_ = out.shape
+ # m__ = m_indices.numel()
# Type and shape checks
- assert m == m_ == m__ and k == k_ and n == n_
@@ -384,8 +390,14 @@ index 3b518c9..ba776bd 100644
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
num_groups___ = masked_m.numel()
@@ -145,25 +145,25 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
rhs, rhs_scales = rhs
num_groups, m, k = lhs.shape
num_groups_, n, k_ = rhs.shape
- num_groups__, m_, n_ = out.shape
- num_groups___ = masked_m.numel()
+ # num_groups__, m_, n_ = out.shape
+ # num_groups___ = masked_m.numel()
# Type and shape checks
- assert num_groups == num_groups_ == num_groups__ == num_groups___
@@ -563,7 +575,7 @@ index 0000000..2326807
+CUDA_HOME = get_cuda_home()
\ No newline at end of file
diff --git a/deep_gemm/utils.py b/deep_gemm/utils.py
index d5cdd01..5237f09 100644
index d5cdd01..011f14a 100644
--- a/deep_gemm/utils.py
+++ b/deep_gemm/utils.py
@@ -1,15 +1,15 @@