mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Sync v2.0 version of code to github repo
This commit is contained in:
643
custom_ops/0001-DeepGEMM-95e81b3.patch
Normal file
643
custom_ops/0001-DeepGEMM-95e81b3.patch
Normal file
@@ -0,0 +1,643 @@
|
||||
From 5112002c155dceecc5e5983cdb67157e4f5400e2 Mon Sep 17 00:00:00 2001
|
||||
From: minghaipeng <minghaipeng@baidu.com>
|
||||
Date: Wed, 25 Jun 2025 15:05:24 +0800
|
||||
Subject: [PATCH] DeepGEMM 95e81b3
|
||||
|
||||
---
|
||||
deep_gemm/__init__.py | 2 +-
|
||||
deep_gemm/include/deep_gemm/scheduler.cuh | 2 +-
|
||||
deep_gemm/jit/compiler.py | 2 +-
|
||||
deep_gemm/jit/interleave_ffma.py | 2 +-
|
||||
deep_gemm/jit/runtime.py | 4 +-
|
||||
deep_gemm/jit/template.py | 34 ++++----
|
||||
deep_gemm/jit_kernels/gemm.py | 44 +++++------
|
||||
deep_gemm/jit_kernels/m_grouped_gemm.py | 96 +++++++++++------------
|
||||
deep_gemm/jit_kernels/tuner.py | 10 +--
|
||||
deep_gemm/jit_kernels/utils.py | 18 +++--
|
||||
deep_gemm/paddle_utils.py | 20 +++++
|
||||
deep_gemm/utils.py | 30 +++----
|
||||
12 files changed, 143 insertions(+), 121 deletions(-)
|
||||
create mode 100644 deep_gemm/paddle_utils.py
|
||||
|
||||
diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py
|
||||
index 15b22ca..63e7fb7 100644
|
||||
--- a/deep_gemm/__init__.py
|
||||
+++ b/deep_gemm/__init__.py
|
||||
@@ -1,4 +1,4 @@
|
||||
-import torch
|
||||
+import paddle
|
||||
|
||||
from . import jit
|
||||
from .jit_kernels import (
|
||||
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
|
||||
index 9743871..6c97152 100644
|
||||
--- a/deep_gemm/include/deep_gemm/scheduler.cuh
|
||||
+++ b/deep_gemm/include/deep_gemm/scheduler.cuh
|
||||
@@ -102,7 +102,7 @@ struct Scheduler {
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
return block_idx * block_size;
|
||||
} else if constexpr (kGemmType == GemmType::GroupedContiguous) {
|
||||
- auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
|
||||
+ auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M));
|
||||
return offset * shape_dim + block_idx * block_size;
|
||||
} else if constexpr (kGemmType == GemmType::GroupedMasked) {
|
||||
return curr_group_idx * shape_dim + block_idx * block_size;
|
||||
diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py
|
||||
index c17d466..6fdc52f 100644
|
||||
--- a/deep_gemm/jit/compiler.py
|
||||
+++ b/deep_gemm/jit/compiler.py
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
import re
|
||||
import subprocess
|
||||
import uuid
|
||||
-from torch.utils.cpp_extension import CUDA_HOME
|
||||
+from ..paddle_utils import CUDA_HOME
|
||||
from typing import Tuple
|
||||
|
||||
from . import interleave_ffma
|
||||
diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py
|
||||
index fcb377e..db9d6f3 100644
|
||||
--- a/deep_gemm/jit/interleave_ffma.py
|
||||
+++ b/deep_gemm/jit/interleave_ffma.py
|
||||
@@ -3,7 +3,7 @@ import mmap
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
-from torch.utils.cpp_extension import CUDA_HOME
|
||||
+from ..paddle_utils import CUDA_HOME
|
||||
|
||||
|
||||
def run_cuobjdump(file_path):
|
||||
diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
|
||||
index 66c370a..4761426 100644
|
||||
--- a/deep_gemm/jit/runtime.py
|
||||
+++ b/deep_gemm/jit/runtime.py
|
||||
@@ -1,6 +1,6 @@
|
||||
import ctypes
|
||||
import os
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Optional
|
||||
|
||||
from .template import map_ctype
|
||||
@@ -35,7 +35,7 @@ class Runtime:
|
||||
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
|
||||
cargs = []
|
||||
for arg, (name, dtype) in zip(args, self.args):
|
||||
- if isinstance(arg, torch.Tensor):
|
||||
+ if isinstance(arg, paddle.Tensor):
|
||||
assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
|
||||
else:
|
||||
assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
|
||||
diff --git a/deep_gemm/jit/template.py b/deep_gemm/jit/template.py
|
||||
index ead37f5..51b02c1 100644
|
||||
--- a/deep_gemm/jit/template.py
|
||||
+++ b/deep_gemm/jit/template.py
|
||||
@@ -1,24 +1,24 @@
|
||||
import copy
|
||||
import ctypes
|
||||
import os
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Any, Dict, Iterable, Tuple
|
||||
|
||||
|
||||
# Name map for Python `eval`
|
||||
typename_map: Dict[Any, str] = {
|
||||
**{t: t.__name__ for t in (bool, int, float)},
|
||||
- torch.int: 'torch.int',
|
||||
- torch.float: 'torch.float',
|
||||
- torch.bfloat16: 'torch.bfloat16',
|
||||
- torch.float8_e4m3fn: 'torch.float8_e4m3fn',
|
||||
- torch.cuda.Stream: 'torch.cuda.Stream',
|
||||
+ paddle.int32: 'paddle.int32',
|
||||
+ paddle.float32: 'paddle.float32',
|
||||
+ paddle.bfloat16: 'paddle.bfloat16',
|
||||
+ paddle.float8_e4m3fn: 'paddle.float8_e4m3fn',
|
||||
+ paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
|
||||
}
|
||||
|
||||
# `ctype` map for Python casting
|
||||
ctype_map: Dict[Any, Any] = {
|
||||
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
|
||||
- **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
|
||||
+ **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)},
|
||||
}
|
||||
|
||||
|
||||
@@ -27,25 +27,25 @@ genc_map = {
|
||||
bool: ('bool', 'bool'),
|
||||
int: ('int', 'int'),
|
||||
float: ('float', 'float'),
|
||||
- torch.int: ('void*', 'int*'),
|
||||
- torch.float: ('void*', 'float*'),
|
||||
- torch.bfloat16: ('void*', '__nv_bfloat16*'),
|
||||
- torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
|
||||
- torch.cuda.Stream: ('void*', 'cudaStream_t'),
|
||||
+ paddle.int32: ('void*', 'int*'),
|
||||
+ paddle.float32: ('void*', 'float*'),
|
||||
+ paddle.bfloat16: ('void*', '__nv_bfloat16*'),
|
||||
+ paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
|
||||
+ paddle.device.cuda.Stream: ('void*', 'cudaStream_t'),
|
||||
}
|
||||
|
||||
|
||||
def map_ctype(value: Any) -> Any:
|
||||
if hasattr(value, 'data_ptr'):
|
||||
- if value.dtype == torch.int:
|
||||
+ if value.dtype == paddle.int32:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
- elif value.dtype == torch.float:
|
||||
+ elif value.dtype == paddle.float32:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
- elif value.dtype == torch.bfloat16:
|
||||
+ elif value.dtype == paddle.bfloat16:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
- elif value.dtype == torch.float16:
|
||||
+ elif value.dtype == paddle.float16:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
- elif value.dtype == torch.float8_e4m3fn:
|
||||
+ elif value.dtype == paddle.float8_e4m3fn:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
else:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
|
||||
index cb438b7..44aa0ed 100644
|
||||
--- a/deep_gemm/jit_kernels/gemm.py
|
||||
+++ b/deep_gemm/jit_kernels/gemm.py
|
||||
@@ -1,5 +1,5 @@
|
||||
import math
|
||||
-import torch
|
||||
+import paddle
|
||||
from functools import lru_cache
|
||||
from typing import Tuple
|
||||
|
||||
@@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
|
||||
|
||||
|
||||
-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- out: torch.Tensor) -> None:
|
||||
+def gemm_fp8_fp8_bf16_nt(lhs: Tuple[paddle.Tensor, paddle.Tensor],
|
||||
+ rhs: Tuple[paddle.Tensor, paddle.Tensor],
|
||||
+ out: paddle.Tensor) -> None:
|
||||
"""
|
||||
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
|
||||
- this function will do a transposing with a set of slow PyTorch operations.
|
||||
+ this function will do a transposing with a set of slow paddle operations.
|
||||
|
||||
Arguments:
|
||||
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
|
||||
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
|
||||
- rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`.
|
||||
+ rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[n, k]`.
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[m, n]`, representing the result.
|
||||
"""
|
||||
@@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
|
||||
- assert n % 64 == 0 and k % 128 == 0
|
||||
+ # assert n % 64 == 0 and k % 128 == 0
|
||||
|
||||
# Type and shape checks
|
||||
- assert m == m_ and n == n_ and k == k_
|
||||
- assert n > 0 and k > 0
|
||||
- assert lhs_scales.shape == (m, (k + 127) // 128)
|
||||
- assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
|
||||
- assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
- assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
- assert out.dtype == torch.bfloat16
|
||||
- assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
|
||||
+ # assert m == m_ and n == n_ and k == k_
|
||||
+ # assert n > 0 and k > 0
|
||||
+ # assert lhs_scales.shape == (m, (k + 127) // 128)
|
||||
+ # assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
|
||||
+ # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
|
||||
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
|
||||
+ # assert out.dtype == paddle.bfloat16
|
||||
+ # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
- assert rhs_scales.is_contiguous()
|
||||
+ # assert rhs_scales.is_contiguous()
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
@@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms)
|
||||
- args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_config[0])
|
||||
+ args = (lhs, lhs_scales, rhs, rhs_scales, out, m, paddle.device.cuda.current_stream(), num_sms, smem_config[0])
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
name='gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
@@ -225,10 +225,10 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
|
||||
space=(),
|
||||
includes=includes,
|
||||
- arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
- ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
- ('out', torch.bfloat16), ('m', int),
|
||||
- ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
+ arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32),
|
||||
+ ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32),
|
||||
+ ('out', paddle.bfloat16), ('m', int),
|
||||
+ ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
)
|
||||
diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py
|
||||
index 3b518c9..ba776bd 100644
|
||||
--- a/deep_gemm/jit_kernels/m_grouped_gemm.py
|
||||
+++ b/deep_gemm/jit_kernels/m_grouped_gemm.py
|
||||
@@ -1,4 +1,4 @@
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Tuple
|
||||
|
||||
from .gemm import get_best_configs, get_block_n_padding_for_smem_d
|
||||
@@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout,
|
||||
"""
|
||||
|
||||
|
||||
-def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- out: torch.Tensor, m_indices: torch.Tensor) -> None:
|
||||
+def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[paddle.Tensor, paddle.Tensor],
|
||||
+ rhs: Tuple[paddle.Tensor, paddle.Tensor],
|
||||
+ out: paddle.Tensor, m_indices: paddle.Tensor) -> None:
|
||||
"""
|
||||
Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
|
||||
- this function will do a transposing with a set of slow PyTorch operations.
|
||||
+ this function will do a transposing with a set of slow Pypaddle operations.
|
||||
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
|
||||
`get_m_alignment_for_contiguous_layout()` (128).
|
||||
|
||||
Arguments:
|
||||
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
|
||||
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`.
|
||||
- rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
|
||||
+ rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`.
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
|
||||
- m_indices: a tensor of shape `[m_sum]` with type `torch.int`.
|
||||
+ m_indices: a tensor of shape `[m_sum]` with type `paddle.int`.
|
||||
`m_indices[i]` records the group which the i-th row of the LHS belong to,
|
||||
which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
|
||||
Values of `m_indices` in every-m-alignment-block must also be the same.
|
||||
@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
m__ = m_indices.numel()
|
||||
|
||||
# Type and shape checks
|
||||
- assert m == m_ == m__ and k == k_ and n == n_
|
||||
- assert lhs_scales.shape == (m, (k + 127) // 128)
|
||||
- assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
||||
- assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
- assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
- assert out.dtype == torch.bfloat16
|
||||
- assert m_indices.dtype == torch.int32
|
||||
- assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
- assert out.is_contiguous() and m_indices.is_contiguous()
|
||||
+ # assert m == m_ == m__ and k == k_ and n == n_
|
||||
+ # assert lhs_scales.shape == (m, (k + 127) // 128)
|
||||
+ # assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
||||
+ # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
|
||||
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
|
||||
+ # assert out.dtype == paddle.bfloat16
|
||||
+ # assert m_indices.dtype == paddle.int32
|
||||
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
+ # assert out.is_contiguous() and m_indices.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
- assert rhs_scales.is_contiguous()
|
||||
+ # assert rhs_scales.is_contiguous()
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
@@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
||||
m_indices, m, num_groups,
|
||||
- torch.cuda.current_stream(), num_sms, smem_config[0])
|
||||
+ paddle.device.cuda.current_stream(), num_sms, smem_config[0])
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
@@ -105,11 +105,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
'GEMM_TYPE': 'GroupedContiguous'},
|
||||
space=(),
|
||||
includes=includes,
|
||||
- arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
- ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
- ('out', torch.bfloat16),
|
||||
- ('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
|
||||
- ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
+ arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32),
|
||||
+ ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32),
|
||||
+ ('out', paddle.bfloat16),
|
||||
+ ('grouped_layout', paddle.int32), ('m', int), ('num_groups', int),
|
||||
+ ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
)
|
||||
@@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
runtime(*args)
|
||||
|
||||
|
||||
-def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
|
||||
+def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[paddle.Tensor, paddle.Tensor],
|
||||
+ rhs: Tuple[paddle.Tensor, paddle.Tensor],
|
||||
+ out: paddle.Tensor, masked_m: paddle.Tensor, expected_m: int) -> None:
|
||||
"""
|
||||
Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
|
||||
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
|
||||
RHS and RHS scaling factors are required to be transposed.
|
||||
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
|
||||
- this function will do a transposing with a set of slow PyTorch operations.
|
||||
+ this function will do a transposing with a set of slow paddle operations.
|
||||
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
|
||||
should be separately transposed.
|
||||
|
||||
Arguments:
|
||||
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
||||
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
||||
the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`.
|
||||
- rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
|
||||
+ rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`.
|
||||
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
|
||||
out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
|
||||
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
|
||||
@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
num_groups___ = masked_m.numel()
|
||||
|
||||
# Type and shape checks
|
||||
- assert num_groups == num_groups_ == num_groups__ == num_groups___
|
||||
- assert m == m_ and n == n_ and k == k_
|
||||
- assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
|
||||
- assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
|
||||
- assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
||||
- assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
- assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
- assert out.dtype == torch.bfloat16
|
||||
- assert masked_m.dtype == torch.int32
|
||||
- assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
- assert out.is_contiguous() and masked_m.is_contiguous()
|
||||
+ # assert num_groups == num_groups_ == num_groups__ == num_groups___
|
||||
+ # assert m == m_ and n == n_ and k == k_
|
||||
+ # assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
|
||||
+ # assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
|
||||
+ # assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
|
||||
+ # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
|
||||
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
|
||||
+ # assert out.dtype == paddle.bfloat16
|
||||
+ # assert masked_m.dtype == paddle.int32
|
||||
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
+ # assert out.is_contiguous() and masked_m.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
- assert rhs_scales.is_contiguous()
|
||||
+ # assert rhs_scales.is_contiguous()
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
@@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
||||
masked_m, m,
|
||||
- torch.cuda.current_stream(), num_sms, smem_config[0])
|
||||
+ paddle.device.cuda.current_stream(), num_sms, smem_config[0])
|
||||
runtime = jit_tuner.compile_and_tune(
|
||||
name='m_grouped_gemm_fp8_fp8_bf16_nt',
|
||||
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
|
||||
@@ -189,11 +189,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
'GEMM_TYPE': 'GroupedMasked'},
|
||||
space=(),
|
||||
includes=includes,
|
||||
- arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
|
||||
- ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
|
||||
- ('out', torch.bfloat16),
|
||||
- ('grouped_layout', torch.int32), ('m', int),
|
||||
- ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
+ arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32),
|
||||
+ ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32),
|
||||
+ ('out', paddle.bfloat16),
|
||||
+ ('grouped_layout', paddle.int32), ('m', int),
|
||||
+ ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)),
|
||||
template=template,
|
||||
args=args
|
||||
)
|
||||
diff --git a/deep_gemm/jit_kernels/tuner.py b/deep_gemm/jit_kernels/tuner.py
|
||||
index 6ed6749..9e1d70f 100644
|
||||
--- a/deep_gemm/jit_kernels/tuner.py
|
||||
+++ b/deep_gemm/jit_kernels/tuner.py
|
||||
@@ -1,6 +1,6 @@
|
||||
import copy
|
||||
import os
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Any, Dict
|
||||
|
||||
from ..jit import build, cpp_format, generate, Runtime
|
||||
@@ -51,10 +51,10 @@ class JITTuner:
|
||||
continue
|
||||
|
||||
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
|
||||
- start_event = torch.cuda.Event(enable_timing=True)
|
||||
- end_event = torch.cuda.Event(enable_timing=True)
|
||||
- torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
|
||||
- torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
+ start_event = paddle.device.cuda.Event(enable_timing=True)
|
||||
+ end_event = paddle.device.cuda.Event(enable_timing=True)
|
||||
+ paddle.empty((int(256e6 // 4)), dtype=paddle.int32).zero_()
|
||||
+ paddle.randn((8192, 8192), dtype=paddle.float32) @ paddle.randn((8192, 8192), dtype=paddle.float32)
|
||||
start_event.record()
|
||||
for i in range(20):
|
||||
assert runtime(*args) == 0
|
||||
diff --git a/deep_gemm/jit_kernels/utils.py b/deep_gemm/jit_kernels/utils.py
|
||||
index c6da56b..a17b1b1 100644
|
||||
--- a/deep_gemm/jit_kernels/utils.py
|
||||
+++ b/deep_gemm/jit_kernels/utils.py
|
||||
@@ -1,4 +1,4 @@
|
||||
-import torch
|
||||
+import paddle
|
||||
|
||||
_num_sms = None
|
||||
|
||||
@@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None:
|
||||
num_sms: the desired maximum SM count for all GEMM kernels to use.
|
||||
"""
|
||||
global _num_sms
|
||||
- assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
+ assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count
|
||||
_num_sms = num_sms
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ def get_num_sms() -> int:
|
||||
"""
|
||||
global _num_sms
|
||||
if _num_sms is None:
|
||||
- _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
+ _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
|
||||
return _num_sms
|
||||
|
||||
|
||||
@@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
|
||||
return ceil_div(x, alignment) * alignment
|
||||
|
||||
|
||||
-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
+def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor:
|
||||
"""
|
||||
- Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary.
|
||||
+ Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
|
||||
If the input tensor is already column-major layout and 16-byte aligned along the M axis
|
||||
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
|
||||
|
||||
@@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
m, n = x.shape[-2], x.shape[-1]
|
||||
aligned_m = get_tma_aligned_size(m, x.element_size())
|
||||
if x.dim() == 2:
|
||||
- if x.stride(0) == 1 and x.stride(1) == aligned_m:
|
||||
+ if x.strides[0] == 1 and x.strides[1] == aligned_m:
|
||||
return x
|
||||
x, remove_dim = x.unsqueeze(0), True
|
||||
|
||||
b = x.shape[0]
|
||||
|
||||
# The last kernel gives a column-major TMA aligned layout
|
||||
- if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
|
||||
+ if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m:
|
||||
return x.squeeze(0) if remove_dim else x
|
||||
|
||||
# Normal layout requires transposing
|
||||
- aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
|
||||
+ aligned_x = paddle.transpose(
|
||||
+ paddle.empty((b, n, aligned_m), dtype=x.dtype), perm=[0, 2, 1]
|
||||
+ )
|
||||
aligned_x[:, :m, :] = x
|
||||
aligned_x = aligned_x[:, :m, :]
|
||||
return aligned_x.squeeze(0) if remove_dim else aligned_x
|
||||
diff --git a/deep_gemm/paddle_utils.py b/deep_gemm/paddle_utils.py
|
||||
new file mode 100644
|
||||
index 0000000..2326807
|
||||
--- /dev/null
|
||||
+++ b/deep_gemm/paddle_utils.py
|
||||
@@ -0,0 +1,20 @@
|
||||
+import os
|
||||
+
|
||||
+def get_cuda_home():
|
||||
+ """Get Cuda home directory"""
|
||||
+ cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
|
||||
+ if cuda_home:
|
||||
+ return cuda_home
|
||||
+
|
||||
+ try:
|
||||
+ which_cmd = "which nvcc"
|
||||
+
|
||||
+ nvcc_path = os.popen(which_cmd).read().strip()
|
||||
+ if nvcc_path:
|
||||
+ return os.path.dirname(os.path.dirname(nvcc_path))
|
||||
+ except Exception:
|
||||
+ pass
|
||||
+
|
||||
+ return None
|
||||
+
|
||||
+CUDA_HOME = get_cuda_home()
|
||||
\ No newline at end of file
|
||||
diff --git a/deep_gemm/utils.py b/deep_gemm/utils.py
|
||||
index d5cdd01..5237f09 100644
|
||||
--- a/deep_gemm/utils.py
|
||||
+++ b/deep_gemm/utils.py
|
||||
@@ -1,15 +1,15 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
-import torch
|
||||
-import torch.distributed as dist
|
||||
+import paddle
|
||||
+import paddle.distributed as dist
|
||||
|
||||
|
||||
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
||||
high_precision: bool = False):
|
||||
# Flush L2 cache with 256 MB data
|
||||
- torch.cuda.synchronize()
|
||||
- cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
|
||||
+ paddle.device.cuda.synchronize()
|
||||
+ cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
|
||||
cache.zero_()
|
||||
|
||||
# Warmup
|
||||
@@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
||||
|
||||
# Add a large kernel to eliminate the CPU launch overhead
|
||||
if high_precision:
|
||||
- x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
- y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
+ x = paddle.randn((8192, 8192), dtype=paddle.float32)
|
||||
+ y = paddle.randn((8192, 8192), dtype=paddle.float32)
|
||||
x @ y
|
||||
|
||||
# Testing
|
||||
- start_event = torch.cuda.Event(enable_timing=True)
|
||||
- end_event = torch.cuda.Event(enable_timing=True)
|
||||
+ start_event = paddle.device.cuda.Event(enable_timing=True)
|
||||
+ end_event = paddle.device.cuda.Event(enable_timing=True)
|
||||
start_event.record()
|
||||
for i in range(num_tests):
|
||||
fn()
|
||||
end_event.record()
|
||||
- torch.cuda.synchronize()
|
||||
+ paddle.device.synchronize()
|
||||
|
||||
return start_event.elapsed_time(end_event) / num_tests
|
||||
|
||||
@@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
|
||||
# Profile
|
||||
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
|
||||
with suppress():
|
||||
- schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
|
||||
- profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
|
||||
+ scheduler = paddle.profiler.make_scheduler(closed=0, ready=1, record=1, repeat=1) if not using_nsys else None
|
||||
+ profiler = paddle.profiler.Profiler(targets=[paddle.profiler.ProfilerTarget.CPU, paddle.profiler.ProfilerTarget.GPU], scheduler=scheduler) if not using_nsys else empty_suppress()
|
||||
with profiler:
|
||||
for i in range(2):
|
||||
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
|
||||
if barrier_comm_profiling:
|
||||
- lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
- rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
+ lhs = paddle.randn((8192, 8192), dtype=paddle.float32)
|
||||
+ rhs = paddle.randn((8192, 8192), dtype=paddle.float32)
|
||||
lhs @ rhs
|
||||
- dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
|
||||
+ dist.all_reduce(paddle.ones(1, dtype=paddle.float32))
|
||||
for _ in range(num_tests):
|
||||
if sleep_between_tests > 0.0:
|
||||
time.sleep(sleep_between_tests)
|
||||
if flush_l2:
|
||||
- torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
|
||||
+ paddle.empty(flush_l2_size, dtype=paddle.int32).zero_()
|
||||
fn()
|
||||
|
||||
if not using_nsys:
|
||||
--
|
||||
2.43.0
|
||||
|
||||
Reference in New Issue
Block a user