FastDeploy/custom_ops/0001-DeepGEMM-95e81b3.patch
From 5112002c155dceecc5e5983cdb67157e4f5400e2 Mon Sep 17 00:00:00 2001
From: minghaipeng <minghaipeng@baidu.com>
Date: Wed, 25 Jun 2025 15:05:24 +0800
Subject: [PATCH] DeepGEMM 95e81b3
---
deep_gemm/__init__.py | 2 +-
deep_gemm/include/deep_gemm/scheduler.cuh | 2 +-
deep_gemm/jit/compiler.py | 2 +-
deep_gemm/jit/interleave_ffma.py | 2 +-
deep_gemm/jit/runtime.py | 4 +-
deep_gemm/jit/template.py | 34 ++++----
deep_gemm/jit_kernels/gemm.py | 44 +++++------
deep_gemm/jit_kernels/m_grouped_gemm.py | 96 +++++++++++------------
deep_gemm/jit_kernels/tuner.py | 10 +--
deep_gemm/jit_kernels/utils.py | 18 +++--
deep_gemm/paddle_utils.py | 20 +++++
deep_gemm/utils.py | 30 +++----
12 files changed, 143 insertions(+), 121 deletions(-)
create mode 100644 deep_gemm/paddle_utils.py
diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py
index 15b22ca..63e7fb7 100644
--- a/deep_gemm/__init__.py
+++ b/deep_gemm/__init__.py
@@ -1,4 +1,4 @@
-import torch
+import paddle
from . import jit
from .jit_kernels import (
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
index 9743871..6c97152 100644
--- a/deep_gemm/include/deep_gemm/scheduler.cuh
+++ b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -102,7 +102,7 @@ struct Scheduler {
if constexpr (kGemmType == GemmType::Normal) {
return block_idx * block_size;
} else if constexpr (kGemmType == GemmType::GroupedContiguous) {
- auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
+ auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M));
return offset * shape_dim + block_idx * block_size;
} else if constexpr (kGemmType == GemmType::GroupedMasked) {
return curr_group_idx * shape_dim + block_idx * block_size;
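
A hedged Python sketch of what the scheduler change above guards against: the contiguous grouped layout can presumably mark padded rows with a negative group id (e.g. -1, an assumption based on how the caller pads `m_indices`), and clamping with `max(0, ...)` keeps the computed block offset non-negative. The values below are illustrative only.

# Illustrative only: mirrors the patched offset computation in plain Python.
grouped_layout = [-1, -1, 0, 0, 1, 1]   # hypothetical per-row group ids; -1 marks padding
BLOCK_M, m_block_idx, shape_dim, block_idx, block_size = 1, 0, 4096, 3, 128

offset = max(0, grouped_layout[m_block_idx * BLOCK_M])   # patched: never negative
addr = offset * shape_dim + block_idx * block_size
print(addr)                                              # 384 instead of a negative offset
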
diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py
index c17d466..6fdc52f 100644
--- a/deep_gemm/jit/compiler.py
+++ b/deep_gemm/jit/compiler.py
@@ -4,7 +4,7 @@ import os
import re
import subprocess
import uuid
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME
from typing import Tuple
from . import interleave_ffma
diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py
index fcb377e..db9d6f3 100644
--- a/deep_gemm/jit/interleave_ffma.py
+++ b/deep_gemm/jit/interleave_ffma.py
@@ -3,7 +3,7 @@ import mmap
import os
import re
import subprocess
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME
def run_cuobjdump(file_path):
diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
index 66c370a..4761426 100644
--- a/deep_gemm/jit/runtime.py
+++ b/deep_gemm/jit/runtime.py
@@ -1,6 +1,6 @@
import ctypes
import os
-import torch
+import paddle
from typing import Optional
from .template import map_ctype
@@ -35,7 +35,7 @@ class Runtime:
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
cargs = []
for arg, (name, dtype) in zip(args, self.args):
- if isinstance(arg, torch.Tensor):
+ if isinstance(arg, paddle.Tensor):
assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
else:
assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
diff --git a/deep_gemm/jit/template.py b/deep_gemm/jit/template.py
index ead37f5..51b02c1 100644
--- a/deep_gemm/jit/template.py
+++ b/deep_gemm/jit/template.py
@@ -1,24 +1,24 @@
import copy
import ctypes
import os
-import torch
+import paddle
from typing import Any, Dict, Iterable, Tuple
# Name map for Python `eval`
typename_map: Dict[Any, str] = {
**{t: t.__name__ for t in (bool, int, float)},
- torch.int: 'torch.int',
- torch.float: 'torch.float',
- torch.bfloat16: 'torch.bfloat16',
- torch.float8_e4m3fn: 'torch.float8_e4m3fn',
- torch.cuda.Stream: 'torch.cuda.Stream',
+ paddle.int32: 'paddle.int32',
+ paddle.float32: 'paddle.float32',
+ paddle.bfloat16: 'paddle.bfloat16',
+ paddle.float8_e4m3fn: 'paddle.float8_e4m3fn',
+ paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
}
# `ctype` map for Python casting
ctype_map: Dict[Any, Any] = {
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
- **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
+ **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)},
}
@@ -27,25 +27,25 @@ genc_map = {
bool: ('bool', 'bool'),
int: ('int', 'int'),
float: ('float', 'float'),
- torch.int: ('void*', 'int*'),
- torch.float: ('void*', 'float*'),
- torch.bfloat16: ('void*', '__nv_bfloat16*'),
- torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
- torch.cuda.Stream: ('void*', 'cudaStream_t'),
+ paddle.int32: ('void*', 'int*'),
+ paddle.float32: ('void*', 'float*'),
+ paddle.bfloat16: ('void*', '__nv_bfloat16*'),
+ paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
+ paddle.device.cuda.Stream: ('void*', 'cudaStream_t'),
}
def map_ctype(value: Any) -> Any:
if hasattr(value, 'data_ptr'):
- if value.dtype == torch.int:
+ if value.dtype == paddle.int32:
return ctypes.c_void_p(value.data_ptr())
- elif value.dtype == torch.float:
+ elif value.dtype == paddle.float32:
return ctypes.c_void_p(value.data_ptr())
- elif value.dtype == torch.bfloat16:
+ elif value.dtype == paddle.bfloat16:
return ctypes.c_void_p(value.data_ptr())
- elif value.dtype == torch.float16:
+ elif value.dtype == paddle.float16:
return ctypes.c_void_p(value.data_ptr())
- elif value.dtype == torch.float8_e4m3fn:
+ elif value.dtype == paddle.float8_e4m3fn:
return ctypes.c_void_p(value.data_ptr())
else:
return ctypes.c_void_p(value.data_ptr())
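
A minimal usage sketch of the remapped `map_ctype` above, assuming a CUDA-enabled Paddle build; it only exercises the tensor branch, which hands the raw device pointer to the JIT-compiled kernel.

import ctypes
import paddle
from deep_gemm.jit.template import map_ctype

x = paddle.zeros((4, 4), dtype=paddle.bfloat16)
c_arg = map_ctype(x)                        # tensor branch: ctypes.c_void_p(x.data_ptr())
assert isinstance(c_arg, ctypes.c_void_p)
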
diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index cb438b7..44aa0ed 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -1,5 +1,5 @@
import math
-import torch
+import paddle
from functools import lru_cache
from typing import Tuple
@@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor) -> None:
+def gemm_fp8_fp8_bf16_nt(lhs: Tuple[paddle.Tensor, paddle.Tensor],
+ rhs: Tuple[paddle.Tensor, paddle.Tensor],
+ out: paddle.Tensor) -> None:
"""
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
- this function will do a transposing with a set of slow PyTorch operations.
+ this function will do a transposing with a set of slow paddle operations.
Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
- rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`.
+ rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[n, k]`.
the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
out: the BF16 output tensor of shape `[m, n]`, representing the result.
"""
@@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
n, k_ = rhs.shape
m_, n_ = out.shape
- assert n % 64 == 0 and k % 128 == 0
+ # assert n % 64 == 0 and k % 128 == 0
# Type and shape checks
- assert m == m_ and n == n_ and k == k_
- assert n > 0 and k > 0
- assert lhs_scales.shape == (m, (k + 127) // 128)
- assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
- assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
- assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
- assert out.dtype == torch.bfloat16
- assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
+ # assert m == m_ and n == n_ and k == k_
+ # assert n > 0 and k > 0
+ # assert lhs_scales.shape == (m, (k + 127) // 128)
+ # assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
+ # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
+ # assert out.dtype == paddle.bfloat16
+ # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()
# Do nothing if `m` is zero
if m == 0:
@@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms)
- args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_config[0])
+ args = (lhs, lhs_scales, rhs, rhs_scales, out, m, paddle.device.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
name='gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
@@ -225,10 +225,10 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
space=(),
includes=includes,
- arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
- ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
- ('out', torch.bfloat16), ('m', int),
- ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+ arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32),
+ ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32),
+ ('out', paddle.bfloat16), ('m', int),
+ ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
)
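
A hedged usage sketch of the ported `gemm_fp8_fp8_bf16_nt` above. Shapes follow the docstring; it assumes a Hopper GPU and a Paddle build whose `astype` supports `paddle.float8_e4m3fn`, and the all-ones scales are placeholders rather than meaningful quantization.

import paddle
from deep_gemm import gemm_fp8_fp8_bf16_nt

m, n, k = 256, 512, 1024
lhs = paddle.randn((m, k)).astype(paddle.float8_e4m3fn)          # FP8 activations, [m, k]
lhs_scales = paddle.ones((m, (k + 127) // 128), dtype=paddle.float32)
rhs = paddle.randn((n, k)).astype(paddle.float8_e4m3fn)          # FP8 weights, [n, k] (transposed layout)
rhs_scales = paddle.ones(((n + 127) // 128, (k + 127) // 128), dtype=paddle.float32)
out = paddle.empty((m, n), dtype=paddle.bfloat16)

gemm_fp8_fp8_bf16_nt((lhs, lhs_scales), (rhs, rhs_scales), out)  # result lands in `out`
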
diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py
index 3b518c9..ba776bd 100644
--- a/deep_gemm/jit_kernels/m_grouped_gemm.py
+++ b/deep_gemm/jit_kernels/m_grouped_gemm.py
@@ -1,4 +1,4 @@
-import torch
+import paddle
from typing import Tuple
from .gemm import get_best_configs, get_block_n_padding_for_smem_d
@@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout,
"""
-def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor, m_indices: torch.Tensor) -> None:
+def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[paddle.Tensor, paddle.Tensor],
+ rhs: Tuple[paddle.Tensor, paddle.Tensor],
+ out: paddle.Tensor, m_indices: paddle.Tensor) -> None:
"""
Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
- this function will do a transposing with a set of slow PyTorch operations.
+ this function will do a transposing with a set of slow paddle operations.
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
`get_m_alignment_for_contiguous_layout()` (128).
Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`.
- rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
+ rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`.
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
- m_indices: a tensor of shape `[m_sum]` with type `torch.int`.
+ m_indices: a tensor of shape `[m_sum]` with type `paddle.int32`.
`m_indices[i]` records the group which the i-th row of the LHS belong to,
which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
Values of `m_indices` in every-m-alignment-block must also be the same.
@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
m__ = m_indices.numel()
# Type and shape checks
- assert m == m_ == m__ and k == k_ and n == n_
- assert lhs_scales.shape == (m, (k + 127) // 128)
- assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
- assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
- assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
- assert out.dtype == torch.bfloat16
- assert m_indices.dtype == torch.int32
- assert lhs.is_contiguous() and rhs.is_contiguous()
- assert out.is_contiguous() and m_indices.is_contiguous()
+ # assert m == m_ == m__ and k == k_ and n == n_
+ # assert lhs_scales.shape == (m, (k + 127) // 128)
+ # assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
+ # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
+ # assert out.dtype == paddle.bfloat16
+ # assert m_indices.dtype == paddle.int32
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
+ # assert out.is_contiguous() and m_indices.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()
# Do nothing if `m` is zero
if m == 0:
@@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
args = (lhs, lhs_scales, rhs, rhs_scales, out,
m_indices, m, num_groups,
- torch.cuda.current_stream(), num_sms, smem_config[0])
+ paddle.device.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
@@ -105,11 +105,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
'GEMM_TYPE': 'GroupedContiguous'},
space=(),
includes=includes,
- arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
- ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
- ('out', torch.bfloat16),
- ('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
- ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+ arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32),
+ ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32),
+ ('out', paddle.bfloat16),
+ ('grouped_layout', paddle.int32), ('m', int), ('num_groups', int),
+ ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
)
@@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
runtime(*args)
-def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
+def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[paddle.Tensor, paddle.Tensor],
+ rhs: Tuple[paddle.Tensor, paddle.Tensor],
+ out: paddle.Tensor, masked_m: paddle.Tensor, expected_m: int) -> None:
"""
Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
- this function will do a transposing with a set of slow PyTorch operations.
+ this function will do a transposing with a set of slow paddle operations.
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
should be separately transposed.
Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`.
- rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
+ rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`.
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
num_groups___ = masked_m.numel()
# Type and shape checks
- assert num_groups == num_groups_ == num_groups__ == num_groups___
- assert m == m_ and n == n_ and k == k_
- assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
- assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
- assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
- assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
- assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
- assert out.dtype == torch.bfloat16
- assert masked_m.dtype == torch.int32
- assert lhs.is_contiguous() and rhs.is_contiguous()
- assert out.is_contiguous() and masked_m.is_contiguous()
+ # assert num_groups == num_groups_ == num_groups__ == num_groups___
+ # assert m == m_ and n == n_ and k == k_
+ # assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
+ # assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
+ # assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
+ # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
+ # assert out.dtype == paddle.bfloat16
+ # assert masked_m.dtype == paddle.int32
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
+ # assert out.is_contiguous() and masked_m.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()
# Auto-tuning with compilation
global includes, template
@@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
args = (lhs, lhs_scales, rhs, rhs_scales, out,
masked_m, m,
- torch.cuda.current_stream(), num_sms, smem_config[0])
+ paddle.device.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
@@ -189,11 +189,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
'GEMM_TYPE': 'GroupedMasked'},
space=(),
includes=includes,
- arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
- ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
- ('out', torch.bfloat16),
- ('grouped_layout', torch.int32), ('m', int),
- ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+ arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32),
+ ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32),
+ ('out', paddle.bfloat16),
+ ('grouped_layout', paddle.int32), ('m', int),
+ ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
)
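
A hedged sketch of the contiguous grouped variant above. Shapes follow the docstring, groups are padded to the 128-row alignment mentioned there, and the FP8 casts and all-ones scales are illustrative assumptions only.

import paddle
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_contiguous

num_groups, n, k = 4, 512, 1024
m_per_group = 128                                   # respects get_m_alignment_for_contiguous_layout()
m_sum = num_groups * m_per_group

lhs = paddle.randn((m_sum, k)).astype(paddle.float8_e4m3fn)
lhs_scales = paddle.ones((m_sum, (k + 127) // 128), dtype=paddle.float32)
rhs = paddle.randn((num_groups, n, k)).astype(paddle.float8_e4m3fn)
rhs_scales = paddle.ones((num_groups, (n + 127) // 128, (k + 127) // 128), dtype=paddle.float32)
out = paddle.empty((m_sum, n), dtype=paddle.bfloat16)

# Row i of `lhs` multiplies `rhs[m_indices[i]]`; each 128-row block shares one group id.
m_indices = paddle.arange(num_groups, dtype=paddle.int32).repeat_interleave(m_per_group)

m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((lhs, lhs_scales), (rhs, rhs_scales), out, m_indices)
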
diff --git a/deep_gemm/jit_kernels/tuner.py b/deep_gemm/jit_kernels/tuner.py
index 6ed6749..9e1d70f 100644
--- a/deep_gemm/jit_kernels/tuner.py
+++ b/deep_gemm/jit_kernels/tuner.py
@@ -1,6 +1,6 @@
import copy
import os
-import torch
+import paddle
from typing import Any, Dict
from ..jit import build, cpp_format, generate, Runtime
@@ -51,10 +51,10 @@ class JITTuner:
continue
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
- torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
- torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
+ start_event = paddle.device.cuda.Event(enable_timing=True)
+ end_event = paddle.device.cuda.Event(enable_timing=True)
+ paddle.empty((int(256e6 // 4)), dtype=paddle.int32).zero_()
+ paddle.randn((8192, 8192), dtype=paddle.float32) @ paddle.randn((8192, 8192), dtype=paddle.float32)
start_event.record()
for i in range(20):
assert runtime(*args) == 0
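
For reference, the Paddle timing pattern the tuner now relies on, extracted into a standalone hedged sketch: flush L2 with a large buffer, run a big matmul to hide launch overhead, then time with CUDA events. The kernel launch in the middle is a placeholder.

import paddle

start_event = paddle.device.cuda.Event(enable_timing=True)
end_event = paddle.device.cuda.Event(enable_timing=True)

paddle.empty([int(256e6 // 4)], dtype=paddle.int32).zero_()       # ~256 MB L2 flush
paddle.randn((8192, 8192), dtype=paddle.float32) @ paddle.randn((8192, 8192), dtype=paddle.float32)

start_event.record()
# ... launch the candidate kernel a fixed number of times here ...
end_event.record()
paddle.device.synchronize()
elapsed_ms = start_event.elapsed_time(end_event)
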
diff --git a/deep_gemm/jit_kernels/utils.py b/deep_gemm/jit_kernels/utils.py
index c6da56b..a17b1b1 100644
--- a/deep_gemm/jit_kernels/utils.py
+++ b/deep_gemm/jit_kernels/utils.py
@@ -1,4 +1,4 @@
-import torch
+import paddle
_num_sms = None
@@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None:
num_sms: the desired maximum SM count for all GEMM kernels to use.
"""
global _num_sms
- assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
+ assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count
_num_sms = num_sms
@@ -25,7 +25,7 @@ def get_num_sms() -> int:
"""
global _num_sms
if _num_sms is None:
- _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
+ _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
return _num_sms
@@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
return ceil_div(x, alignment) * alignment
-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
+def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor:
"""
- Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary.
+ Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along the M axis
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
@@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
m, n = x.shape[-2], x.shape[-1]
aligned_m = get_tma_aligned_size(m, x.element_size())
if x.dim() == 2:
- if x.stride(0) == 1 and x.stride(1) == aligned_m:
+ if x.strides[0] == 1 and x.strides[1] == aligned_m:
return x
x, remove_dim = x.unsqueeze(0), True
b = x.shape[0]
# The last kernel gives a column-major TMA aligned layout
- if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
+ if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m:
return x.squeeze(0) if remove_dim else x
# Normal layout requires transposing
- aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
+ aligned_x = paddle.transpose(
+ paddle.empty((b, n, aligned_m), dtype=x.dtype), perm=[0, 2, 1]
+ )
aligned_x[:, :m, :] = x
aligned_x = aligned_x[:, :m, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x
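
A small hedged sketch of the helper above after the port: the logical shape of the scales is unchanged, while the storage is intended (per the docstring) to be column-major with M padded to the TMA alignment.

import paddle
from deep_gemm.jit_kernels.utils import get_col_major_tma_aligned_tensor, get_tma_aligned_size

m, k = 102, 1024
lhs_scales = paddle.rand((m, (k + 127) // 128), dtype=paddle.float32)

aligned = get_col_major_tma_aligned_tensor(lhs_scales)
print(aligned.shape)                                       # still [m, ceil(k / 128)]
print(get_tma_aligned_size(m, lhs_scales.element_size()))  # 104: M rounded up to 16 bytes / element_size
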
diff --git a/deep_gemm/paddle_utils.py b/deep_gemm/paddle_utils.py
new file mode 100644
index 0000000..2326807
--- /dev/null
+++ b/deep_gemm/paddle_utils.py
@@ -0,0 +1,20 @@
+import os
+
+def get_cuda_home():
+ """Get Cuda home directory"""
+ cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
+ if cuda_home:
+ return cuda_home
+
+ try:
+ which_cmd = "which nvcc"
+
+ nvcc_path = os.popen(which_cmd).read().strip()
+ if nvcc_path:
+ return os.path.dirname(os.path.dirname(nvcc_path))
+ except Exception:
+ pass
+
+ return None
+
+CUDA_HOME = get_cuda_home()
\ No newline at end of file
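
For comparison, a hedged alternative sketch of the same lookup using `shutil.which` instead of shelling out through `os.popen`; it is not part of the patch, and `find_cuda_home` is a hypothetical name intended to mirror `get_cuda_home` above.

import os
import shutil

def find_cuda_home():
    """Locate the CUDA toolkit root, mirroring the logic of get_cuda_home()."""
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
    if cuda_home:
        return cuda_home
    nvcc = shutil.which("nvcc")
    # nvcc normally sits at <CUDA_HOME>/bin/nvcc, so strip two path components.
    return os.path.dirname(os.path.dirname(nvcc)) if nvcc else None
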
diff --git a/deep_gemm/utils.py b/deep_gemm/utils.py
index d5cdd01..5237f09 100644
--- a/deep_gemm/utils.py
+++ b/deep_gemm/utils.py
@@ -1,15 +1,15 @@
import os
import sys
import time
-import torch
-import torch.distributed as dist
+import paddle
+import paddle.distributed as dist
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
high_precision: bool = False):
# Flush L2 cache with 256 MB data
- torch.cuda.synchronize()
- cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
+ paddle.device.synchronize()
+ cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
cache.zero_()
# Warmup
@@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,
# Add a large kernel to eliminate the CPU launch overhead
if high_precision:
- x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
- y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
+ x = paddle.randn((8192, 8192), dtype=paddle.float32)
+ y = paddle.randn((8192, 8192), dtype=paddle.float32)
x @ y
# Testing
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
+ start_event = paddle.device.cuda.Event(enable_timing=True)
+ end_event = paddle.device.cuda.Event(enable_timing=True)
start_event.record()
for i in range(num_tests):
fn()
end_event.record()
- torch.cuda.synchronize()
+ paddle.device.synchronize()
return start_event.elapsed_time(end_event) / num_tests
@@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
with suppress():
- schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
- profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
+ scheduler = paddle.profiler.make_scheduler(closed=0, ready=1, record=1, repeat=1) if not using_nsys else None
+ profiler = paddle.profiler.Profiler(targets=[paddle.profiler.ProfilerTarget.CPU, paddle.profiler.ProfilerTarget.GPU], scheduler=scheduler) if not using_nsys else empty_suppress()
with profiler:
for i in range(2):
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
if barrier_comm_profiling:
- lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
- rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
+ lhs = paddle.randn((8192, 8192), dtype=paddle.float32)
+ rhs = paddle.randn((8192, 8192), dtype=paddle.float32)
lhs @ rhs
- dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
+ dist.all_reduce(paddle.ones(1, dtype=paddle.float32))
for _ in range(num_tests):
if sleep_between_tests > 0.0:
time.sleep(sleep_between_tests)
if flush_l2:
- torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
+ paddle.empty(flush_l2_size, dtype=paddle.int32).zero_()
fn()
if not using_nsys:
--
2.43.0