Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-31 20:02:53 +08:00)

Commit 240bdac2a4

* support fa3 backend run in pd disaggregated
* delete use_fast_ffn

From 5112002c155dceecc5e5983cdb67157e4f5400e2 Mon Sep 17 00:00:00 2001
From: minghaipeng <minghaipeng@baidu.com>
Date: Wed, 25 Jun 2025 15:05:24 +0800
Subject: [PATCH] DeepGEMM 95e81b3

---
 deep_gemm/__init__.py                     |  2 +-
 deep_gemm/include/deep_gemm/scheduler.cuh |  2 +-
 deep_gemm/jit/compiler.py                 |  2 +-
 deep_gemm/jit/interleave_ffma.py          |  2 +-
 deep_gemm/jit/runtime.py                  |  4 +-
 deep_gemm/jit/template.py                 | 34 ++++----
 deep_gemm/jit_kernels/gemm.py             | 44 +++++------
 deep_gemm/jit_kernels/m_grouped_gemm.py   | 96 +++++++++++------------
 deep_gemm/jit_kernels/tuner.py            | 10 +--
 deep_gemm/jit_kernels/utils.py            | 18 +++--
 deep_gemm/paddle_utils.py                 | 20 +++++
 deep_gemm/utils.py                        | 30 +++----
 12 files changed, 143 insertions(+), 121 deletions(-)
 create mode 100644 deep_gemm/paddle_utils.py

diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py
index 15b22ca..63e7fb7 100644
--- a/deep_gemm/__init__.py
+++ b/deep_gemm/__init__.py
@@ -1,4 +1,4 @@
-import torch
+import paddle

 from . import jit
 from .jit_kernels import (
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
index 9743871..6c97152 100644
--- a/deep_gemm/include/deep_gemm/scheduler.cuh
+++ b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -102,7 +102,7 @@ struct Scheduler {
         if constexpr (kGemmType == GemmType::Normal) {
             return block_idx * block_size;
         } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
-            auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
+            auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M));
             return offset * shape_dim + block_idx * block_size;
         } else if constexpr (kGemmType == GemmType::GroupedMasked) {
             return curr_group_idx * shape_dim + block_idx * block_size;
diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py
index c17d466..6fdc52f 100644
--- a/deep_gemm/jit/compiler.py
+++ b/deep_gemm/jit/compiler.py
@@ -4,7 +4,7 @@ import os
 import re
 import subprocess
 import uuid
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME
 from typing import Tuple

 from . import interleave_ffma
diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py
index fcb377e..db9d6f3 100644
--- a/deep_gemm/jit/interleave_ffma.py
+++ b/deep_gemm/jit/interleave_ffma.py
@@ -3,7 +3,7 @@ import mmap
 import os
 import re
 import subprocess
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME


 def run_cuobjdump(file_path):
diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
index 66c370a..4761426 100644
--- a/deep_gemm/jit/runtime.py
+++ b/deep_gemm/jit/runtime.py
@@ -1,6 +1,6 @@
 import ctypes
 import os
-import torch
+import paddle
 from typing import Optional

 from .template import map_ctype
@@ -35,7 +35,7 @@ class Runtime:
         assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
         cargs = []
         for arg, (name, dtype) in zip(args, self.args):
-            if isinstance(arg, torch.Tensor):
+            if isinstance(arg, paddle.Tensor):
                 assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
             else:
                 assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
diff --git a/deep_gemm/jit/template.py b/deep_gemm/jit/template.py
index ead37f5..51b02c1 100644
--- a/deep_gemm/jit/template.py
+++ b/deep_gemm/jit/template.py
@@ -1,24 +1,24 @@
 import copy
 import ctypes
 import os
-import torch
+import paddle
 from typing import Any, Dict, Iterable, Tuple


 # Name map for Python `eval`
 typename_map: Dict[Any, str] = {
     **{t: t.__name__ for t in (bool, int, float)},
-    torch.int: 'torch.int',
-    torch.float: 'torch.float',
-    torch.bfloat16: 'torch.bfloat16',
-    torch.float8_e4m3fn: 'torch.float8_e4m3fn',
-    torch.cuda.Stream: 'torch.cuda.Stream',
+    paddle.int32: 'paddle.int32',
+    paddle.float32: 'paddle.float32',
+    paddle.bfloat16: 'paddle.bfloat16',
+    paddle.float8_e4m3fn: 'paddle.float8_e4m3fn',
+    paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
 }

 # `ctype` map for Python casting
 ctype_map: Dict[Any, Any] = {
     **{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
-    **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
+    **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)},
 }


@@ -27,25 +27,25 @@ genc_map = {
     bool: ('bool', 'bool'),
     int: ('int', 'int'),
     float: ('float', 'float'),
-    torch.int: ('void*', 'int*'),
-    torch.float: ('void*', 'float*'),
-    torch.bfloat16: ('void*', '__nv_bfloat16*'),
-    torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
-    torch.cuda.Stream: ('void*', 'cudaStream_t'),
+    paddle.int32: ('void*', 'int*'),
+    paddle.float32: ('void*', 'float*'),
+    paddle.bfloat16: ('void*', '__nv_bfloat16*'),
+    paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
+    paddle.device.cuda.Stream: ('void*', 'cudaStream_t'),
 }


 def map_ctype(value: Any) -> Any:
     if hasattr(value, 'data_ptr'):
-        if value.dtype == torch.int:
+        if value.dtype == paddle.int32:
             return ctypes.c_void_p(value.data_ptr())
-        elif value.dtype == torch.float:
+        elif value.dtype == paddle.float32:
             return ctypes.c_void_p(value.data_ptr())
-        elif value.dtype == torch.bfloat16:
+        elif value.dtype == paddle.bfloat16:
             return ctypes.c_void_p(value.data_ptr())
-        elif value.dtype == torch.float16:
+        elif value.dtype == paddle.float16:
             return ctypes.c_void_p(value.data_ptr())
-        elif value.dtype == torch.float8_e4m3fn:
+        elif value.dtype == paddle.float8_e4m3fn:
             return ctypes.c_void_p(value.data_ptr())
         else:
             return ctypes.c_void_p(value.data_ptr())
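The template.py hunk above swaps the argument-marshaling maps so that every mapped paddle dtype is handed to the JIT kernel as a raw void pointer. Below is a minimal sketch of that convention, built only from calls that appear in the patch (paddle tensors expose data_ptr(); built-in scalars fall back to the ctypes casts used in ctype_map). The helper name to_void_ptr is ours, not part of the patch.

import ctypes
import paddle

def to_void_ptr(value):
    # Tensors of any mapped dtype are passed by device pointer (void*).
    if hasattr(value, "data_ptr"):
        return ctypes.c_void_p(value.data_ptr())
    # Built-in scalars keep their natural ctypes type, mirroring ctype_map.
    if isinstance(value, bool):
        return ctypes.c_bool(value)
    if isinstance(value, int):
        return ctypes.c_int(value)
    return ctypes.c_float(value)

x = paddle.ones([4, 4], dtype=paddle.bfloat16)
print(to_void_ptr(x))    # c_void_p wrapping the tensor's device address
print(to_void_ptr(128))  # c_int(128)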
diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index cb438b7..44aa0ed 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -1,5 +1,5 @@
 import math
-import torch
+import paddle
 from functools import lru_cache
 from typing import Tuple

@@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config


-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
-                         rhs: Tuple[torch.Tensor, torch.Tensor],
-                         out: torch.Tensor) -> None:
+def gemm_fp8_fp8_bf16_nt(lhs: Tuple[paddle.Tensor, paddle.Tensor],
+                         rhs: Tuple[paddle.Tensor, paddle.Tensor],
+                         out: paddle.Tensor) -> None:
     """
     Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
     LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
     RHS and RHS scaling factors are required to be transposed.
     The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
-        this function will do a transposing with a set of slow PyTorch operations.
+        this function will do a transposing with a set of slow paddle operations.

     Arguments:
-        lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
+        lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
              the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
-        rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`.
+        rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[n, k]`.
             the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
         out: the BF16 output tensor of shape `[m, n]`, representing the result.
     """
@@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     n, k_ = rhs.shape
     m_, n_ = out.shape

-    assert n % 64 == 0 and k % 128 == 0
+    # assert n % 64 == 0 and k % 128 == 0

     # Type and shape checks
-    assert m == m_ and n == n_ and k == k_
-    assert n > 0 and k > 0
-    assert lhs_scales.shape == (m, (k + 127) // 128)
-    assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
-    assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
-    assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
-    assert out.dtype == torch.bfloat16
-    assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
+    # assert m == m_ and n == n_ and k == k_
+    # assert n > 0 and k > 0
+    # assert lhs_scales.shape == (m, (k + 127) // 128)
+    # assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
+    # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
+    # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
+    # assert out.dtype == paddle.bfloat16
+    # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()

     # LHS scales must be transposed for TMA load, but not for RHS scales
     # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
     lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
-    assert rhs_scales.is_contiguous()
+    # assert rhs_scales.is_contiguous()

     # Do nothing if `m` is zero
     if m == 0:
@@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     global includes, template
     num_sms = get_num_sms()
     num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms)
-    args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_config[0])
+    args = (lhs, lhs_scales, rhs, rhs_scales, out, m, paddle.device.cuda.current_stream(), num_sms, smem_config[0])
     runtime = jit_tuner.compile_and_tune(
         name='gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
@@ -225,10 +225,10 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
               'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
         space=(),
         includes=includes,
-        arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
-                  ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
-                  ('out', torch.bfloat16), ('m', int),
-                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+        arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32),
+                  ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32),
+                  ('out', paddle.bfloat16), ('m', int),
+                  ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)),
         template=template,
         args=args
     )
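For reference, the scale-tensor shapes required by the docstring above can be computed ahead of time. This is plain shape bookkeeping derived from the docstring, not an API from the patch; the sample m, n, k values are arbitrary.

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

m, n, k = 256, 4096, 7168
lhs_scale_shape = (m, ceil_div(k, 128))                 # (256, 56): FP32, 1x128 scaling of LHS
rhs_scale_shape = (ceil_div(n, 128), ceil_div(k, 128))  # (32, 56): FP32, 128x128 scaling of RHS
out_shape = (m, n)                                      # BF16 result
print(lhs_scale_shape, rhs_scale_shape, out_shape)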
diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py
index 3b518c9..ba776bd 100644
--- a/deep_gemm/jit_kernels/m_grouped_gemm.py
+++ b/deep_gemm/jit_kernels/m_grouped_gemm.py
@@ -1,4 +1,4 @@
-import torch
+import paddle
 from typing import Tuple

 from .gemm import get_best_configs, get_block_n_padding_for_smem_d
@@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout,
 """


-def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
-                                              rhs: Tuple[torch.Tensor, torch.Tensor],
-                                              out: torch.Tensor, m_indices: torch.Tensor) -> None:
+def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[paddle.Tensor, paddle.Tensor],
+                                              rhs: Tuple[paddle.Tensor, paddle.Tensor],
+                                              out: paddle.Tensor, m_indices: paddle.Tensor) -> None:
     """
     Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
     LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
     RHS and RHS scaling factors are required to be transposed.
     The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
-        this function will do a transposing with a set of slow PyTorch operations.
+        this function will do a transposing with a set of slow paddle operations.
     On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
         `get_m_alignment_for_contiguous_layout()` (128).

     Arguments:
-        lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
+        lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
             the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`.
-        rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
+        rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`.
             the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
         out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
-        m_indices: a tensor of shape `[m_sum]` with type `torch.int`.
+        m_indices: a tensor of shape `[m_sum]` with type `paddle.int32`.
            `m_indices[i]` records the group which the i-th row of the LHS belong to,
             which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
             Values of `m_indices` in every-m-alignment-block must also be the same.
@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
     m__ = m_indices.numel()

     # Type and shape checks
-    assert m == m_ == m__ and k == k_ and n == n_
-    assert lhs_scales.shape == (m, (k + 127) // 128)
-    assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
-    assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
-    assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
-    assert out.dtype == torch.bfloat16
-    assert m_indices.dtype == torch.int32
-    assert lhs.is_contiguous() and rhs.is_contiguous()
-    assert out.is_contiguous() and m_indices.is_contiguous()
+    # assert m == m_ == m__ and k == k_ and n == n_
+    # assert lhs_scales.shape == (m, (k + 127) // 128)
+    # assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
+    # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
+    # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
+    # assert out.dtype == paddle.bfloat16
+    # assert m_indices.dtype == paddle.int32
+    # assert lhs.is_contiguous() and rhs.is_contiguous()
+    # assert out.is_contiguous() and m_indices.is_contiguous()

     # LHS scales must be transposed for TMA load, but not for RHS scales
     lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
-    assert rhs_scales.is_contiguous()
+    # assert rhs_scales.is_contiguous()

     # Do nothing if `m` is zero
     if m == 0:
@@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
     num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
     args = (lhs, lhs_scales, rhs, rhs_scales, out,
             m_indices, m, num_groups,
-            torch.cuda.current_stream(), num_sms, smem_config[0])
+            paddle.device.cuda.current_stream(), num_sms, smem_config[0])
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
@@ -105,11 +105,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
               'GEMM_TYPE': 'GroupedContiguous'},
         space=(),
         includes=includes,
-        arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
-                  ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
-                  ('out', torch.bfloat16),
-                  ('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
-                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+        arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32),
+                  ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32),
+                  ('out', paddle.bfloat16),
+                  ('grouped_layout', paddle.int32), ('m', int), ('num_groups', int),
+                  ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)),
         template=template,
         args=args
     )
@@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
     runtime(*args)


-def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
-                                          rhs: Tuple[torch.Tensor, torch.Tensor],
-                                          out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
+def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[paddle.Tensor, paddle.Tensor],
+                                          rhs: Tuple[paddle.Tensor, paddle.Tensor],
+                                          out: paddle.Tensor, masked_m: paddle.Tensor, expected_m: int) -> None:
     """
     Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
     LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
     RHS and RHS scaling factors are required to be transposed.
     The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
-        this function will do a transposing with a set of slow PyTorch operations.
+        this function will do a transposing with a set of slow paddle operations.
     Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
         should be separately transposed.

     Arguments:
-        lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
+        lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
             the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`.
-        rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
+        rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`.
             the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
         out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
         masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     num_groups___ = masked_m.numel()

     # Type and shape checks
-    assert num_groups == num_groups_ == num_groups__ == num_groups___
-    assert m == m_ and n == n_ and k == k_
-    assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
-    assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
-    assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
-    assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
-    assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
-    assert out.dtype == torch.bfloat16
-    assert masked_m.dtype == torch.int32
-    assert lhs.is_contiguous() and rhs.is_contiguous()
-    assert out.is_contiguous() and masked_m.is_contiguous()
+    # assert num_groups == num_groups_ == num_groups__ == num_groups___
+    # assert m == m_ and n == n_ and k == k_
+    # assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
+    # assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
+    # assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
+    # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
+    # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
+    # assert out.dtype == paddle.bfloat16
+    # assert masked_m.dtype == paddle.int32
+    # assert lhs.is_contiguous() and rhs.is_contiguous()
+    # assert out.is_contiguous() and masked_m.is_contiguous()

     # LHS scales must be transposed for TMA load, but not for RHS scales
     lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
-    assert rhs_scales.is_contiguous()
+    # assert rhs_scales.is_contiguous()

     # Auto-tuning with compilation
     global includes, template
@@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]

     args = (lhs, lhs_scales, rhs, rhs_scales, out,
             masked_m, m,
-            torch.cuda.current_stream(), num_sms, smem_config[0])
+            paddle.device.cuda.current_stream(), num_sms, smem_config[0])
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
@@ -189,11 +189,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
               'GEMM_TYPE': 'GroupedMasked'},
         space=(),
         includes=includes,
-        arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
-                  ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
-                  ('out', torch.bfloat16),
-                  ('grouped_layout', torch.int32), ('m', int),
-                  ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+        arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32),
+                  ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32),
+                  ('out', paddle.bfloat16),
+                  ('grouped_layout', paddle.int32), ('m', int),
+                  ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)),
         template=template,
         args=args
     )
diff --git a/deep_gemm/jit_kernels/tuner.py b/deep_gemm/jit_kernels/tuner.py
index 6ed6749..9e1d70f 100644
--- a/deep_gemm/jit_kernels/tuner.py
+++ b/deep_gemm/jit_kernels/tuner.py
@@ -1,6 +1,6 @@
 import copy
 import os
-import torch
+import paddle
 from typing import Any, Dict

 from ..jit import build, cpp_format, generate, Runtime
@@ -51,10 +51,10 @@ class JITTuner:
                     continue

                 # Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
-                start_event = torch.cuda.Event(enable_timing=True)
-                end_event = torch.cuda.Event(enable_timing=True)
-                torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
-                torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
+                start_event = paddle.device.cuda.Event(enable_timing=True)
+                end_event = paddle.device.cuda.Event(enable_timing=True)
+                paddle.empty((int(256e6 // 4)), dtype=paddle.int32).zero_()
+                paddle.randn((8192, 8192), dtype=paddle.float32) @ paddle.randn((8192, 8192), dtype=paddle.float32)
                 start_event.record()
                 for i in range(20):
                     assert runtime(*args) == 0
diff --git a/deep_gemm/jit_kernels/utils.py b/deep_gemm/jit_kernels/utils.py
index c6da56b..a17b1b1 100644
--- a/deep_gemm/jit_kernels/utils.py
+++ b/deep_gemm/jit_kernels/utils.py
@@ -1,4 +1,4 @@
-import torch
+import paddle

 _num_sms = None

@@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None:
         num_sms: the desired maximum SM count for all GEMM kernels to use.
     """
     global _num_sms
-    assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
+    assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count
     _num_sms = num_sms


@@ -25,7 +25,7 @@ def get_num_sms() -> int:
     """
     global _num_sms
     if _num_sms is None:
-        _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
+        _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
     return _num_sms


@@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
     return ceil_div(x, alignment) * alignment


-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
+def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor:
     """
-    Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary.
+    Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
     If the input tensor is already column-major layout and 16-byte aligned along the M axis
         (thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.

@@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
     m, n = x.shape[-2], x.shape[-1]
     aligned_m = get_tma_aligned_size(m, x.element_size())
     if x.dim() == 2:
-        if x.stride(0) == 1 and x.stride(1) == aligned_m:
+        if x.strides[0] == 1 and x.strides[1] == aligned_m:
             return x
         x, remove_dim = x.unsqueeze(0), True

     b = x.shape[0]

     # The last kernel gives a column-major TMA aligned layout
-    if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
+    if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m:
         return x.squeeze(0) if remove_dim else x

     # Normal layout requires transposing
-    aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
+    aligned_x = paddle.transpose(
+        paddle.empty((b, n, aligned_m), dtype=x.dtype), perm=[0, 2, 1]
+    )
     aligned_x[:, :m, :] = x
     aligned_x = aligned_x[:, :m, :]
     return aligned_x.squeeze(0) if remove_dim else aligned_x
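A worked instance of the alignment rule the docstring above refers to (16-byte alignment along the M axis). The body of get_tma_aligned_size is only partially visible in this hunk, so the 16-byte constant below is an assumption taken from that note, and the function here is a standalone sketch rather than the library code.

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def tma_aligned_size(m: int, element_size: int, alignment_bytes: int = 16) -> int:
    # Round M up so one column of the scale tensor spans a whole number of 16-byte units.
    assert alignment_bytes % element_size == 0
    alignment = alignment_bytes // element_size
    return ceil_div(m, alignment) * alignment

print(tma_aligned_size(129, 4))  # FP32 scales: 4-element alignment, so 129 rounds up to 132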
diff --git a/deep_gemm/paddle_utils.py b/deep_gemm/paddle_utils.py
new file mode 100644
index 0000000..2326807
--- /dev/null
+++ b/deep_gemm/paddle_utils.py
@@ -0,0 +1,20 @@
+import os
+
+def get_cuda_home():
+    """Get Cuda home directory"""
+    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
+    if cuda_home:
+        return cuda_home
+
+    try:
+        which_cmd = "which nvcc"
+
+        nvcc_path = os.popen(which_cmd).read().strip()
+        if nvcc_path:
+            return os.path.dirname(os.path.dirname(nvcc_path))
+    except Exception:
+        pass
+
+    return None
+
+CUDA_HOME = get_cuda_home()
\ No newline at end of file
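A short usage sketch for the new helper: CUDA_HOME resolves from $CUDA_HOME or $CUDA_PATH, or from `which nvcc` as a fallback, and may be None. The sketch assumes the patched deep_gemm package is importable; the nvcc path join is ours, not part of the patch.

import os
from deep_gemm.paddle_utils import CUDA_HOME

# Locate nvcc under the detected toolkit root, if one was found.
nvcc = os.path.join(CUDA_HOME, "bin", "nvcc") if CUDA_HOME else None
print(CUDA_HOME, nvcc)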
diff --git a/deep_gemm/utils.py b/deep_gemm/utils.py
index d5cdd01..5237f09 100644
--- a/deep_gemm/utils.py
+++ b/deep_gemm/utils.py
@@ -1,15 +1,15 @@
 import os
 import sys
 import time
-import torch
-import torch.distributed as dist
+import paddle
+import paddle.distributed as dist


 def bench(fn, num_warmups: int = 5, num_tests: int = 10,
           high_precision: bool = False):
     # Flush L2 cache with 256 MB data
-    torch.cuda.synchronize()
-    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
+    paddle.device.synchronize()
+    cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
     cache.zero_()

     # Warmup
@@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,

     # Add a large kernel to eliminate the CPU launch overhead
     if high_precision:
-        x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
-        y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
+        x = paddle.randn((8192, 8192), dtype=paddle.float32)
+        y = paddle.randn((8192, 8192), dtype=paddle.float32)
         x @ y

     # Testing
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
+    start_event = paddle.device.cuda.Event(enable_timing=True)
+    end_event = paddle.device.cuda.Event(enable_timing=True)
     start_event.record()
     for i in range(num_tests):
         fn()
     end_event.record()
-    torch.cuda.synchronize()
+    paddle.device.synchronize()

     return start_event.elapsed_time(end_event) / num_tests

@@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
     # Profile
     suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
     with suppress():
-        schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
-        profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
+        scheduler = paddle.profiler.make_scheduler(closed=0, ready=1, record=1, repeat=1) if not using_nsys else None
+        profiler = paddle.profiler.Profiler(targets=[paddle.profiler.ProfilerTarget.CPU, paddle.profiler.ProfilerTarget.GPU], scheduler=scheduler) if not using_nsys else empty_suppress()
         with profiler:
             for i in range(2):
                 # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
                 if barrier_comm_profiling:
-                    lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
-                    rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
+                    lhs = paddle.randn((8192, 8192), dtype=paddle.float32)
+                    rhs = paddle.randn((8192, 8192), dtype=paddle.float32)
                     lhs @ rhs
-                    dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
+                    dist.all_reduce(paddle.ones(1, dtype=paddle.float32))
                 for _ in range(num_tests):
                     if sleep_between_tests > 0.0:
                         time.sleep(sleep_between_tests)
                     if flush_l2:
-                        torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
+                        paddle.empty(flush_l2_size, dtype=paddle.int32).zero_()
                     fn()

                 if not using_nsys:
--
2.43.0
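Taken together, the bench() changes map torch.cuda timing onto paddle's CUDA event API. A standalone sketch of that pattern, built only from calls that appear in the diff; a CUDA-enabled paddle build and GPU are assumed.

import paddle

paddle.device.synchronize()
# Flush L2 with a 256 MB buffer, as bench() does before timing.
cache = paddle.empty((int(256e6 // 4),), dtype=paddle.int32)
cache.zero_()

x = paddle.randn((8192, 8192), dtype=paddle.float32)
y = paddle.randn((8192, 8192), dtype=paddle.float32)

start_event = paddle.device.cuda.Event(enable_timing=True)
end_event = paddle.device.cuda.Event(enable_timing=True)
start_event.record()
_ = x @ y
end_event.record()
paddle.device.synchronize()
print(f"{start_event.elapsed_time(end_event):.3f} ms")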