mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[OP]Remove extra H2D in DeepGemm (#5262)
Co-authored-by: K11OntheBoat <“ruianmaidanglao@163.com”>
This commit is contained in:
@@ -1,22 +1,22 @@
|
||||
From 5112002c155dceecc5e5983cdb67157e4f5400e2 Mon Sep 17 00:00:00 2001
|
||||
From: minghaipeng <minghaipeng@baidu.com>
|
||||
Date: Wed, 25 Jun 2025 15:05:24 +0800
|
||||
Subject: [PATCH] DeepGEMM 95e81b3
|
||||
From 7008a3c8b7fe833c952f27a5ab3848c485f02b5d Mon Sep 17 00:00:00 2001
|
||||
From: K11OntheBoat <“ruianmaidanglao@163.com”>
|
||||
Date: Thu, 27 Nov 2025 14:38:47 +0800
|
||||
Subject: [PATCH] Remove extra H2D in DeepGemm
|
||||
|
||||
---
|
||||
deep_gemm/__init__.py | 2 +-
|
||||
deep_gemm/include/deep_gemm/scheduler.cuh | 2 +-
|
||||
deep_gemm/jit/compiler.py | 2 +-
|
||||
deep_gemm/jit/interleave_ffma.py | 2 +-
|
||||
deep_gemm/jit/runtime.py | 4 +-
|
||||
deep_gemm/jit/template.py | 34 ++++----
|
||||
deep_gemm/jit_kernels/gemm.py | 44 +++++------
|
||||
deep_gemm/jit_kernels/m_grouped_gemm.py | 96 +++++++++++------------
|
||||
deep_gemm/jit_kernels/tuner.py | 10 +--
|
||||
deep_gemm/jit_kernels/utils.py | 18 +++--
|
||||
deep_gemm/paddle_utils.py | 20 +++++
|
||||
deep_gemm/utils.py | 30 +++----
|
||||
12 files changed, 143 insertions(+), 121 deletions(-)
|
||||
deep_gemm/__init__.py | 2 +-
|
||||
deep_gemm/include/deep_gemm/scheduler.cuh | 2 +-
|
||||
deep_gemm/jit/compiler.py | 2 +-
|
||||
deep_gemm/jit/interleave_ffma.py | 2 +-
|
||||
deep_gemm/jit/runtime.py | 4 +-
|
||||
deep_gemm/jit/template.py | 34 +++----
|
||||
deep_gemm/jit_kernels/gemm.py | 44 ++++-----
|
||||
deep_gemm/jit_kernels/m_grouped_gemm.py | 104 +++++++++++-----------
|
||||
deep_gemm/jit_kernels/tuner.py | 10 +--
|
||||
deep_gemm/jit_kernels/utils.py | 18 ++--
|
||||
deep_gemm/paddle_utils.py | 20 +++++
|
||||
deep_gemm/utils.py | 30 +++----
|
||||
12 files changed, 147 insertions(+), 125 deletions(-)
|
||||
create mode 100644 deep_gemm/paddle_utils.py
|
||||
|
||||
diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py
|
||||
@@ -257,7 +257,7 @@ index cb438b7..44aa0ed 100644
|
||||
args=args
|
||||
)
|
||||
diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py
|
||||
index 3b518c9..ba776bd 100644
|
||||
index 3b518c9..b94e65d 100644
|
||||
--- a/deep_gemm/jit_kernels/m_grouped_gemm.py
|
||||
+++ b/deep_gemm/jit_kernels/m_grouped_gemm.py
|
||||
@@ -1,4 +1,4 @@
|
||||
@@ -299,8 +299,14 @@ index 3b518c9..ba776bd 100644
|
||||
`m_indices[i]` records the group which the i-th row of the LHS belong to,
|
||||
which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
|
||||
Values of `m_indices` in every-m-alignment-block must also be the same.
|
||||
@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
m__ = m_indices.numel()
|
||||
@@ -64,23 +64,23 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
rhs, rhs_scales = rhs
|
||||
m, k = lhs.shape
|
||||
num_groups, n, k_ = rhs.shape
|
||||
- m_, n_ = out.shape
|
||||
- m__ = m_indices.numel()
|
||||
+ # m_, n_ = out.shape
|
||||
+ # m__ = m_indices.numel()
|
||||
|
||||
# Type and shape checks
|
||||
- assert m == m_ == m__ and k == k_ and n == n_
|
||||
@@ -384,8 +390,14 @@ index 3b518c9..ba776bd 100644
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
|
||||
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
|
||||
@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
num_groups___ = masked_m.numel()
|
||||
@@ -145,25 +145,25 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
rhs, rhs_scales = rhs
|
||||
num_groups, m, k = lhs.shape
|
||||
num_groups_, n, k_ = rhs.shape
|
||||
- num_groups__, m_, n_ = out.shape
|
||||
- num_groups___ = masked_m.numel()
|
||||
+ # num_groups__, m_, n_ = out.shape
|
||||
+ # num_groups___ = masked_m.numel()
|
||||
|
||||
# Type and shape checks
|
||||
- assert num_groups == num_groups_ == num_groups__ == num_groups___
|
||||
@@ -563,7 +575,7 @@ index 0000000..2326807
|
||||
+CUDA_HOME = get_cuda_home()
|
||||
\ No newline at end of file
|
||||
diff --git a/deep_gemm/utils.py b/deep_gemm/utils.py
|
||||
index d5cdd01..5237f09 100644
|
||||
index d5cdd01..011f14a 100644
|
||||
--- a/deep_gemm/utils.py
|
||||
+++ b/deep_gemm/utils.py
|
||||
@@ -1,15 +1,15 @@
|
||||
|
||||
Reference in New Issue
Block a user