Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-09-26 20:41:53 +08:00

* support fa3 backend run in pd disaggregated * delete use_fast_ffn
643 lines, 31 KiB
From 5112002c155dceecc5e5983cdb67157e4f5400e2 Mon Sep 17 00:00:00 2001
From: minghaipeng <minghaipeng@baidu.com>
Date: Wed, 25 Jun 2025 15:05:24 +0800
Subject: [PATCH] DeepGEMM 95e81b3

---
 deep_gemm/__init__.py | 2 +-
 deep_gemm/include/deep_gemm/scheduler.cuh | 2 +-
 deep_gemm/jit/compiler.py | 2 +-
 deep_gemm/jit/interleave_ffma.py | 2 +-
 deep_gemm/jit/runtime.py | 4 +-
 deep_gemm/jit/template.py | 34 ++++----
 deep_gemm/jit_kernels/gemm.py | 44 +++++------
 deep_gemm/jit_kernels/m_grouped_gemm.py | 96 +++++++++++------------
 deep_gemm/jit_kernels/tuner.py | 10 +--
 deep_gemm/jit_kernels/utils.py | 18 +++--
 deep_gemm/paddle_utils.py | 20 +++++
 deep_gemm/utils.py | 30 +++----
 12 files changed, 143 insertions(+), 121 deletions(-)
 create mode 100644 deep_gemm/paddle_utils.py

diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py
index 15b22ca..63e7fb7 100644
--- a/deep_gemm/__init__.py
+++ b/deep_gemm/__init__.py
@@ -1,4 +1,4 @@
-import torch
+import paddle

from . import jit
from .jit_kernels import (
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
index 9743871..6c97152 100644
--- a/deep_gemm/include/deep_gemm/scheduler.cuh
+++ b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -102,7 +102,7 @@ struct Scheduler {
if constexpr (kGemmType == GemmType::Normal) {
return block_idx * block_size;
} else if constexpr (kGemmType == GemmType::GroupedContiguous) {
- auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
+ auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M));
return offset * shape_dim + block_idx * block_size;
} else if constexpr (kGemmType == GemmType::GroupedMasked) {
return curr_group_idx * shape_dim + block_idx * block_size;
diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py
index c17d466..6fdc52f 100644
--- a/deep_gemm/jit/compiler.py
+++ b/deep_gemm/jit/compiler.py
@@ -4,7 +4,7 @@ import os
import re
import subprocess
import uuid
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME
from typing import Tuple

from . import interleave_ffma
diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py
index fcb377e..db9d6f3 100644
--- a/deep_gemm/jit/interleave_ffma.py
+++ b/deep_gemm/jit/interleave_ffma.py
@@ -3,7 +3,7 @@ import mmap
import os
import re
import subprocess
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME


def run_cuobjdump(file_path):
diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
index 66c370a..4761426 100644
--- a/deep_gemm/jit/runtime.py
+++ b/deep_gemm/jit/runtime.py
@@ -1,6 +1,6 @@
import ctypes
import os
-import torch
+import paddle
from typing import Optional

from .template import map_ctype
@@ -35,7 +35,7 @@ class Runtime:
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
cargs = []
for arg, (name, dtype) in zip(args, self.args):
- if isinstance(arg, torch.Tensor):
+ if isinstance(arg, paddle.Tensor):
assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
else:
assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
diff --git a/deep_gemm/jit/template.py b/deep_gemm/jit/template.py
index ead37f5..51b02c1 100644
--- a/deep_gemm/jit/template.py
+++ b/deep_gemm/jit/template.py
@@ -1,24 +1,24 @@
import copy
import ctypes
import os
-import torch
+import paddle
from typing import Any, Dict, Iterable, Tuple


# Name map for Python `eval`
typename_map: Dict[Any, str] = {
**{t: t.__name__ for t in (bool, int, float)},
- torch.int: 'torch.int',
- torch.float: 'torch.float',
- torch.bfloat16: 'torch.bfloat16',
- torch.float8_e4m3fn: 'torch.float8_e4m3fn',
- torch.cuda.Stream: 'torch.cuda.Stream',
+ paddle.int32: 'paddle.int32',
+ paddle.float32: 'paddle.float32',
+ paddle.bfloat16: 'paddle.bfloat16',
+ paddle.float8_e4m3fn: 'paddle.float8_e4m3fn',
+ paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
}

# `ctype` map for Python casting
ctype_map: Dict[Any, Any] = {
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
- **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
+ **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)},
}


@@ -27,25 +27,25 @@ genc_map = {
bool: ('bool', 'bool'),
int: ('int', 'int'),
float: ('float', 'float'),
- torch.int: ('void*', 'int*'),
- torch.float: ('void*', 'float*'),
- torch.bfloat16: ('void*', '__nv_bfloat16*'),
- torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
- torch.cuda.Stream: ('void*', 'cudaStream_t'),
+ paddle.int32: ('void*', 'int*'),
+ paddle.float32: ('void*', 'float*'),
+ paddle.bfloat16: ('void*', '__nv_bfloat16*'),
+ paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
+ paddle.device.cuda.Stream: ('void*', 'cudaStream_t'),
}


def map_ctype(value: Any) -> Any:
if hasattr(value, 'data_ptr'):
- if value.dtype == torch.int:
+ if value.dtype == paddle.int32:
return ctypes.c_void_p(value.data_ptr())
- elif value.dtype == torch.float:
+ elif value.dtype == paddle.float32:
return ctypes.c_void_p(value.data_ptr())
- elif value.dtype == torch.bfloat16:
+ elif value.dtype == paddle.bfloat16:
return ctypes.c_void_p(value.data_ptr())
- elif value.dtype == torch.float16:
+ elif value.dtype == paddle.float16:
return ctypes.c_void_p(value.data_ptr())
- elif value.dtype == torch.float8_e4m3fn:
+ elif value.dtype == paddle.float8_e4m3fn:
return ctypes.c_void_p(value.data_ptr())
else:
return ctypes.c_void_p(value.data_ptr())
diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index cb438b7..44aa0ed 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -1,5 +1,5 @@
import math
-import torch
+import paddle
from functools import lru_cache
from typing import Tuple

@@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config


-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor) -> None:
+def gemm_fp8_fp8_bf16_nt(lhs: Tuple[paddle.Tensor, paddle.Tensor],
+ rhs: Tuple[paddle.Tensor, paddle.Tensor],
+ out: paddle.Tensor) -> None:
"""
Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
- this function will do a transposing with a set of slow PyTorch operations.
+ this function will do a transposing with a set of slow Paddle operations.

Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
- rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`.
+ rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[n, k]`.
the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
out: the BF16 output tensor of shape `[m, n]`, representing the result.
"""
@@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
n, k_ = rhs.shape
m_, n_ = out.shape

- assert n % 64 == 0 and k % 128 == 0
+ # assert n % 64 == 0 and k % 128 == 0

# Type and shape checks
- assert m == m_ and n == n_ and k == k_
- assert n > 0 and k > 0
- assert lhs_scales.shape == (m, (k + 127) // 128)
- assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
- assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
- assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
- assert out.dtype == torch.bfloat16
- assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
+ # assert m == m_ and n == n_ and k == k_
+ # assert n > 0 and k > 0
+ # assert lhs_scales.shape == (m, (k + 127) // 128)
+ # assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
+ # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
+ # assert out.dtype == paddle.bfloat16
+ # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()

# LHS scales must be transposed for TMA load, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()

# Do nothing if `m` is zero
if m == 0:
@@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms)
- args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_config[0])
+ args = (lhs, lhs_scales, rhs, rhs_scales, out, m, paddle.device.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
name='gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
@@ -225,10 +225,10 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
space=(),
includes=includes,
- arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
- ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
- ('out', torch.bfloat16), ('m', int),
- ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+ arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32),
+ ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32),
+ ('out', paddle.bfloat16), ('m', int),
+ ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
)
diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py
index 3b518c9..ba776bd 100644
--- a/deep_gemm/jit_kernels/m_grouped_gemm.py
+++ b/deep_gemm/jit_kernels/m_grouped_gemm.py
@@ -1,4 +1,4 @@
-import torch
+import paddle
from typing import Tuple

from .gemm import get_best_configs, get_block_n_padding_for_smem_d
@@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout,
"""


-def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor, m_indices: torch.Tensor) -> None:
+def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[paddle.Tensor, paddle.Tensor],
+ rhs: Tuple[paddle.Tensor, paddle.Tensor],
+ out: paddle.Tensor, m_indices: paddle.Tensor) -> None:
"""
Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
- this function will do a transposing with a set of slow PyTorch operations.
+ this function will do a transposing with a set of slow Paddle operations.
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
`get_m_alignment_for_contiguous_layout()` (128).

Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`.
- rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
+ rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`.
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
- m_indices: a tensor of shape `[m_sum]` with type `torch.int`.
+ m_indices: a tensor of shape `[m_sum]` with type `paddle.int32`.
`m_indices[i]` records the group which the i-th row of the LHS belong to,
which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
Values of `m_indices` in every-m-alignment-block must also be the same.
@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
m__ = m_indices.numel()

# Type and shape checks
- assert m == m_ == m__ and k == k_ and n == n_
- assert lhs_scales.shape == (m, (k + 127) // 128)
- assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
- assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
- assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
- assert out.dtype == torch.bfloat16
- assert m_indices.dtype == torch.int32
- assert lhs.is_contiguous() and rhs.is_contiguous()
- assert out.is_contiguous() and m_indices.is_contiguous()
+ # assert m == m_ == m__ and k == k_ and n == n_
+ # assert lhs_scales.shape == (m, (k + 127) // 128)
+ # assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
+ # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
+ # assert out.dtype == paddle.bfloat16
+ # assert m_indices.dtype == paddle.int32
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
+ # assert out.is_contiguous() and m_indices.is_contiguous()

# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()

# Do nothing if `m` is zero
if m == 0:
@@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
args = (lhs, lhs_scales, rhs, rhs_scales, out,
m_indices, m, num_groups,
- torch.cuda.current_stream(), num_sms, smem_config[0])
+ paddle.device.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
@@ -105,11 +105,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
'GEMM_TYPE': 'GroupedContiguous'},
space=(),
includes=includes,
- arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
- ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
- ('out', torch.bfloat16),
- ('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
- ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+ arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32),
+ ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32),
+ ('out', paddle.bfloat16),
+ ('grouped_layout', paddle.int32), ('m', int), ('num_groups', int),
+ ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
)
@@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
runtime(*args)


-def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
+def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[paddle.Tensor, paddle.Tensor],
+ rhs: Tuple[paddle.Tensor, paddle.Tensor],
+ out: paddle.Tensor, masked_m: paddle.Tensor, expected_m: int) -> None:
"""
Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
- this function will do a transposing with a set of slow PyTorch operations.
+ this function will do a transposing with a set of slow Paddle operations.
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
should be separately transposed.

Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`.
- rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
+ rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`.
the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
num_groups___ = masked_m.numel()

# Type and shape checks
- assert num_groups == num_groups_ == num_groups__ == num_groups___
- assert m == m_ and n == n_ and k == k_
- assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
- assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
- assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
- assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
- assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
- assert out.dtype == torch.bfloat16
- assert masked_m.dtype == torch.int32
- assert lhs.is_contiguous() and rhs.is_contiguous()
- assert out.is_contiguous() and masked_m.is_contiguous()
+ # assert num_groups == num_groups_ == num_groups__ == num_groups___
+ # assert m == m_ and n == n_ and k == k_
+ # assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
+ # assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
+ # assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
+ # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
+ # assert out.dtype == paddle.bfloat16
+ # assert masked_m.dtype == paddle.int32
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
+ # assert out.is_contiguous() and masked_m.is_contiguous()

# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()

# Auto-tuning with compilation
global includes, template
@@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]

args = (lhs, lhs_scales, rhs, rhs_scales, out,
masked_m, m,
- torch.cuda.current_stream(), num_sms, smem_config[0])
+ paddle.device.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
@@ -189,11 +189,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
'GEMM_TYPE': 'GroupedMasked'},
space=(),
includes=includes,
- arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
- ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
- ('out', torch.bfloat16),
- ('grouped_layout', torch.int32), ('m', int),
- ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+ arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32),
+ ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32),
+ ('out', paddle.bfloat16),
+ ('grouped_layout', paddle.int32), ('m', int),
+ ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
)
diff --git a/deep_gemm/jit_kernels/tuner.py b/deep_gemm/jit_kernels/tuner.py
index 6ed6749..9e1d70f 100644
--- a/deep_gemm/jit_kernels/tuner.py
+++ b/deep_gemm/jit_kernels/tuner.py
@@ -1,6 +1,6 @@
import copy
import os
-import torch
+import paddle
from typing import Any, Dict

from ..jit import build, cpp_format, generate, Runtime
@@ -51,10 +51,10 @@ class JITTuner:
continue

# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
- torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
- torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
+ start_event = paddle.device.cuda.Event(enable_timing=True)
+ end_event = paddle.device.cuda.Event(enable_timing=True)
+ paddle.empty((int(256e6 // 4)), dtype=paddle.int32).zero_()
+ paddle.randn((8192, 8192), dtype=paddle.float32) @ paddle.randn((8192, 8192), dtype=paddle.float32)
start_event.record()
for i in range(20):
assert runtime(*args) == 0
diff --git a/deep_gemm/jit_kernels/utils.py b/deep_gemm/jit_kernels/utils.py
index c6da56b..a17b1b1 100644
--- a/deep_gemm/jit_kernels/utils.py
+++ b/deep_gemm/jit_kernels/utils.py
@@ -1,4 +1,4 @@
-import torch
+import paddle

_num_sms = None

@@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None:
num_sms: the desired maximum SM count for all GEMM kernels to use.
"""
global _num_sms
- assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
+ assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count
_num_sms = num_sms


@@ -25,7 +25,7 @@ def get_num_sms() -> int:
"""
global _num_sms
if _num_sms is None:
- _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
+ _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
return _num_sms


@@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
return ceil_div(x, alignment) * alignment


-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
+def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor:
"""
- Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary.
+ Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along the M axis
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.

@@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
m, n = x.shape[-2], x.shape[-1]
aligned_m = get_tma_aligned_size(m, x.element_size())
if x.dim() == 2:
- if x.stride(0) == 1 and x.stride(1) == aligned_m:
+ if x.strides[0] == 1 and x.strides[1] == aligned_m:
return x
x, remove_dim = x.unsqueeze(0), True

b = x.shape[0]

# The last kernel gives a column-major TMA aligned layout
- if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
+ if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m:
return x.squeeze(0) if remove_dim else x

# Normal layout requires transposing
- aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
+ aligned_x = paddle.transpose(
+ paddle.empty((b, n, aligned_m), dtype=x.dtype), perm=[0, 2, 1]
+ )
aligned_x[:, :m, :] = x
aligned_x = aligned_x[:, :m, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x
diff --git a/deep_gemm/paddle_utils.py b/deep_gemm/paddle_utils.py
new file mode 100644
index 0000000..2326807
--- /dev/null
+++ b/deep_gemm/paddle_utils.py
@@ -0,0 +1,20 @@
+import os
+
+def get_cuda_home():
+ """Get CUDA home directory"""
+ cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
+ if cuda_home:
+ return cuda_home
+
+ try:
+ which_cmd = "which nvcc"
+
+ nvcc_path = os.popen(which_cmd).read().strip()
+ if nvcc_path:
+ return os.path.dirname(os.path.dirname(nvcc_path))
+ except Exception:
+ pass
+
+ return None
+
+CUDA_HOME = get_cuda_home()
\ No newline at end of file
diff --git a/deep_gemm/utils.py b/deep_gemm/utils.py
index d5cdd01..5237f09 100644
--- a/deep_gemm/utils.py
+++ b/deep_gemm/utils.py
@@ -1,15 +1,15 @@
import os
import sys
import time
-import torch
-import torch.distributed as dist
+import paddle
+import paddle.distributed as dist


def bench(fn, num_warmups: int = 5, num_tests: int = 10,
high_precision: bool = False):
# Flush L2 cache with 256 MB data
- torch.cuda.synchronize()
- cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
+ paddle.device.synchronize()
+ cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
cache.zero_()

# Warmup
@@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,

# Add a large kernel to eliminate the CPU launch overhead
if high_precision:
- x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
- y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
+ x = paddle.randn((8192, 8192), dtype=paddle.float32)
+ y = paddle.randn((8192, 8192), dtype=paddle.float32)
x @ y

# Testing
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
+ start_event = paddle.device.cuda.Event(enable_timing=True)
+ end_event = paddle.device.cuda.Event(enable_timing=True)
start_event.record()
for i in range(num_tests):
fn()
end_event.record()
- torch.cuda.synchronize()
+ paddle.device.synchronize()

return start_event.elapsed_time(end_event) / num_tests

@@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
with suppress():
- schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
- profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
+ scheduler = paddle.profiler.make_scheduler(closed=0, ready=1, record=1, repeat=1) if not using_nsys else None
+ profiler = paddle.profiler.Profiler(targets=[paddle.profiler.ProfilerTarget.CPU, paddle.profiler.ProfilerTarget.GPU], scheduler=scheduler) if not using_nsys else empty_suppress()
with profiler:
for i in range(2):
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
if barrier_comm_profiling:
- lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
- rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
+ lhs = paddle.randn((8192, 8192), dtype=paddle.float32)
+ rhs = paddle.randn((8192, 8192), dtype=paddle.float32)
lhs @ rhs
- dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
+ dist.all_reduce(paddle.ones(1, dtype=paddle.float32))
for _ in range(num_tests):
if sleep_between_tests > 0.0:
time.sleep(sleep_between_tests)
if flush_l2:
- torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
+ paddle.empty(flush_l2_size, dtype=paddle.int32).zero_()
fn()

if not using_nsys:
--
2.43.0
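
Illustrative usage sketch (not part of the patch): the snippet below shows how the Paddle-ported gemm_fp8_fp8_bf16_nt entry point above might be invoked once the patch is applied. It assumes a Paddle build that exposes paddle.float8_e4m3fn and allows casting to it; the shapes, scale values, and variable names here are invented for illustration only.

# Minimal sketch, assuming a CUDA-enabled Paddle with FP8 dtype support.
# Scales follow the docstring layout: 1x128 for LHS, 128x128 for RHS.
import paddle
import deep_gemm

m, n, k = 128, 4096, 7168  # example sizes; n, k chosen as multiples of 64/128

lhs = paddle.randn((m, k), dtype=paddle.float32).astype(paddle.float8_e4m3fn)
lhs_scales = paddle.ones((m, (k + 127) // 128), dtype=paddle.float32)
rhs = paddle.randn((n, k), dtype=paddle.float32).astype(paddle.float8_e4m3fn)
rhs_scales = paddle.ones(((n + 127) // 128, (k + 127) // 128), dtype=paddle.float32)
out = paddle.empty((m, n), dtype=paddle.bfloat16)

# FP8 x FP8 -> BF16 GEMM; stream and SM-count plumbing is handled internally
# via paddle.device.cuda.current_stream() as wired up in this patch.
deep_gemm.gemm_fp8_fp8_bf16_nt((lhs, lhs_scales), (rhs, rhs_scales), out)
print(out.shape)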