mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Deterministic] Move paddle version batch invariant pkg to Fastdeploy (#4763)
* Move batch invariant pkg to Fastdeploy
* fix problem and pre-commit
* move test
* Change testcase to FD style
* Add testcase for log_softmax
* Add testcase for mean
* Add testcase for addmm
* fix pre-commit
* API check v0.9
* move to layers and add comment about log_softmax
* Update fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py
  (The version-control leftovers in the original code's comments should indeed be removed.)
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update tests/batch_invariant/test_batch_invariance_op_mean.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update tests/batch_invariant/test_batch_invariance_op_logsoftmax.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* change comment after copilot fix
* fix bug about addmm
* avoid global effect by enable_torch_proxy
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
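For orientation, a minimal usage sketch of the API this commit adds (illustrative only; assumes a CUDA build of Paddle with Triton installed, and mirrors how the tests below drive the context manager):

import paddle
from fastdeploy.model_executor.layers.batch_invariant_ops import set_batch_invariant_mode

a = paddle.randn([2048, 4096])
b = paddle.randn([4096, 4096])

# Inside the context, matmul/addmm/log_softmax/mean are routed to the
# batch-invariant Triton kernels; the original ops are restored on exit.
with set_batch_invariant_mode(True):
    out1 = paddle.mm(a[:1], b)   # batch size 1
    out2 = paddle.mm(a, b)[:1]   # full batch, then slice
    # the tests below assert exactly this bit-equality
    assert float((out1 - out2).abs().max()) == 0.0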
fastdeploy/model_executor/layers/batch_invariant_ops/__init__.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from .batch_invariant_ops import (
    AttentionBlockSize,
    disable_batch_invariant_mode,
    enable_batch_invariant_mode,
    get_batch_invariant_attention_block_size,
    is_batch_invariant_mode_enabled,
    log_softmax,
    matmul_persistent,
    mean_dim,
    set_batch_invariant_mode,
)

__version__ = "0.1.0"

__all__ = [
    "set_batch_invariant_mode",
    "is_batch_invariant_mode_enabled",
    "disable_batch_invariant_mode",
    "enable_batch_invariant_mode",
    "matmul_persistent",
    "log_softmax",
    "mean_dim",
    "get_batch_invariant_attention_block_size",
    "AttentionBlockSize",
]
fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py (new file, 586 lines)
@@ -0,0 +1,586 @@
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/batch_invariant_ops.py

import contextlib
import os
from collections import namedtuple
from collections.abc import Callable
from typing import Any, Dict

import paddle
import triton
import triton.language as tl

__all__ = [
    "set_batch_invariant_mode",
    "is_batch_invariant_mode_enabled",
    "disable_batch_invariant_mode",
    "enable_batch_invariant_mode",
]
def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any, args: Dict[str, Any]) -> Dict[str, Any]:
    ret = {}
    m, n, k = args["M"], args["N"], args["K"]
    ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]"
    if "tiles_per_update" in args:
        ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}, tiles_per_update={args['tiles_per_update']:02}]"
    if "c_ptr" in args:
        bytes_per_elem = args["c_ptr"].element_size()
    else:
        bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
    ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k
    ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n)
    return ret
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n
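For intuition, here is a pure-Python mirror of `_compute_pid` (illustrative, not part of the diff): it shows how consecutive tile IDs are swizzled so that a group of `GROUP_SIZE_M` tile-rows is walked column-by-column, which improves L2 reuse of the `a` tiles.

def compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M):
    # Same arithmetic as the Triton helper above, minus the device types.
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n

# num_pid_m=4, num_pid_n=4, GROUP_SIZE_M=2 => num_pid_in_group=8; tiles 0..7
# cover output tile-rows 0-1 across all four columns before moving to rows 2-3:
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)]
print([compute_pid(t, 8, 4, 2) for t in range(8)])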
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_persistent(
    a_ptr,
    b_ptr,
    c_ptr,  #
    bias_ptr,
    M,
    N,
    K,  #
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,  #
    BLOCK_SIZE_N: tl.constexpr,  #
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
    NUM_SMS: tl.constexpr,  #
    A_LARGE: tl.constexpr,
    B_LARGE: tl.constexpr,
    C_LARGE: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    tile_id_c = start_pid - NUM_SMS

    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
        pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
        start_m = pid_m * BLOCK_SIZE_M
        start_n = pid_n * BLOCK_SIZE_N
        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
        if A_LARGE:
            offs_am = offs_am.to(tl.int64)
        if B_LARGE:
            offs_bn = offs_bn.to(tl.int64)
        offs_am = tl.where(offs_am < M, offs_am, 0)
        offs_bn = tl.where(offs_bn < N, offs_bn, 0)
        offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
        offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            if A_LARGE or B_LARGE:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
            else:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
            b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

            a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0)
            accumulator = tl.dot(a, b, accumulator)

        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        if C_LARGE:
            offs_cm = offs_cm.to(tl.int64)
            offs_cn = offs_cn.to(tl.int64)
        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        if HAS_BIAS:
            bias_ptrs = bias_ptr + offs_cn
            bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32)
            accumulator += bias
        c = accumulator.to(c_ptr.dtype.element_ty)
        tl.store(c_ptrs, c, mask=c_mask)
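The batch-invariance of this kernel comes from fixed tiling: the block sizes are compile-time constants and the K loop always reduces BLOCK_SIZE_K-wide slices in the same serial order for a given output tile, regardless of how many rows M are in the batch. A sketch of the grid sizing that `matmul_persistent` below uses (illustrative; mirrors its `grid` closure):

import triton

def persistent_grid(M, N, NUM_SMS, BLOCK_SIZE_M, BLOCK_SIZE_N):
    # At most NUM_SMS programs are launched; each walks tiles with stride
    # NUM_SMS, so adding rows to the batch adds tiles but never reorders a
    # given tile's K-loop reduction.
    num_tiles = triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)
    return (min(NUM_SMS, num_tiles),)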
def get_compute_units():
    """
    Returns the number of streaming multiprocessors (SMs) or equivalent compute units
    for the available accelerator. Assigns the value to NUM_SMS.
    """
    NUM_SMS = None

    if paddle.is_compiled_with_cuda():
        try:
            paddle.device.get_device()  # Triton + Paddle may fail to resolve the device here
            device_properties = paddle.device.cuda.get_device_properties(0)
            NUM_SMS = device_properties.multi_processor_count
        except Exception:
            print("Could not get CUDA device properties. Falling back to CPU threads.")
            # TODO(liujundong): Paddle lacks a torch.get_num_threads() equivalent for the *configured* thread count.
            # Using os.cpu_count() (total logical cores) as a fallback, which may not be correct.
            # Must check downstream logic to determine if this impacts correctness.
            NUM_SMS = os.cpu_count()
    else:
        print("No CUDA device available. Using CPU.")
        # For CPU, use the number of CPU cores
        NUM_SMS = os.cpu_count()

    return NUM_SMS
def matmul_persistent(a: paddle.Tensor, b: paddle.Tensor, bias: paddle.Tensor | None = None):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"
    assert bias is None or bias.dim() == 1, "Currently assuming bias is 1D, let Horace know if you run into this"

    NUM_SMS = get_compute_units()
    M, K = a.shape
    K, N = b.shape
    dtype = a.dtype
    # Allocate the output. Paddle places it on the current device, matching the inputs,
    # so no explicit device argument is needed.
    c = paddle.empty((M, N), dtype=dtype)

    # 1D launch kernel where each block gets its own program.
    def grid(META):
        return (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])),)

    configs = {
        paddle.bfloat16: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
        paddle.float16: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 256,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
        paddle.float32: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
    }
    # print(a.device, b.device, c.device)
    matmul_kernel_persistent[grid](
        a,
        b,
        c,  #
        bias,
        M,
        N,
        K,  #
        a.stride(0),
        a.stride(1),  #
        b.stride(0),
        b.stride(1),  #
        c.stride(0),
        c.stride(1),  #
        NUM_SMS=NUM_SMS,  #
        # The Triton compiler (when used with Paddle) cannot handle these flags
        # as booleans, so cast them to int explicitly.
        A_LARGE=int(a.numel() > 2**31),
        B_LARGE=int(b.numel() > 2**31),
        C_LARGE=int(c.numel() > 2**31),
        HAS_BIAS=int(bias is not None),
        **configs[dtype],
    )
    return c
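A quick sanity check one might run against Paddle's built-in matmul (illustrative; assumes a CUDA device, and bit-exact agreement with cuBLAS is not expected, only closeness):

import paddle

a = paddle.randn([128, 256], dtype=paddle.float32)
b = paddle.randn([256, 64], dtype=paddle.float32)
c = matmul_persistent(a, b)
ref = paddle.matmul(a, b)
print(float((c - ref).abs().max()))  # small, but not necessarily exactly 0.0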
@triton.jit
def _log_softmax_kernel(
    input_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute log_softmax along the last dimension of a 2D tensor.
    Each block handles one row of the input tensor.
    """
    # Get the row index for this block
    row_idx = tl.program_id(0).to(tl.int64)

    # Compute base pointers for input and output rows
    row_start_ptr = input_ptr + row_idx * input_row_stride
    output_row_start_ptr = output_ptr + row_idx * output_row_stride

    # Step 1: Find maximum value in the row for numerical stability
    max_val = -float("inf")
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols

        # Load values
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf"))

        # Update maximum
        max_val = tl.max(tl.maximum(vals, max_val))

    # Step 2: Compute sum of exp(x - max_val)
    sum_exp = 0.0
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols

        # Load values
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)

        # Compute exp(x - max_val) and accumulate
        exp_vals = tl.exp(vals - max_val)
        sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0))

    # Compute log(sum_exp)
    log_sum_exp = tl.log(sum_exp)

    # Step 3: Compute final log_softmax values: x - max_val - log_sum_exp
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols

        # Load values
        vals = tl.load(row_start_ptr + col_idx, mask=mask)

        # Compute log_softmax
        output = vals - max_val - log_sum_exp

        # Store results
        tl.store(output_row_start_ptr + col_idx, output, mask=mask)
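The kernel implements the usual three-pass, numerically stable formulation. A plain-Python reference of the math it computes per row (illustrative):

import math

def log_softmax_ref(row):
    # log_softmax(x)_i = x_i - m - log(sum_j exp(x_j - m)), with m = max_j x_j
    m = max(row)
    log_sum_exp = math.log(sum(math.exp(v - m) for v in row))
    return [v - m - log_sum_exp for v in row]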
def log_softmax(input: paddle.Tensor, axis: int = -1) -> paddle.Tensor:
    """
    Compute log_softmax using Triton kernel.

    Args:
        input: Input tensor
        axis: Dimension along which to compute log_softmax (only -1 or last dim supported)

    Returns:
        Tensor with log_softmax applied along the specified dimension
    """
    # print("You are using triton impl for log_softmax")
    if axis != -1 and axis != input.ndim - 1:
        raise ValueError("This implementation only supports log_softmax along the last dimension")

    # Flatten all dimensions except the last one
    original_shape = input.shape
    input_2d = input.reshape(-1, input.shape[-1])
    input_2d = input_2d.contiguous()

    n_rows, n_cols = input_2d.shape

    # Allocate output tensor
    output = paddle.empty_like(input_2d)

    # Choose block size based on the number of columns
    BLOCK_SIZE = 1024

    # Launch kernel with one block per row
    grid = (n_rows,)
    _log_softmax_kernel[grid](
        input_2d,
        output,
        input_2d.stride(0),
        output.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    # Reshape output back to original shape
    return output.reshape(original_shape)
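An illustrative comparison against Paddle's own op (small numerical differences are possible; the point of this kernel is a fixed per-row reduction order, not bit-equality with cuDNN):

import paddle

x = paddle.randn([4, 1000])
ours = log_softmax(x, axis=-1)
ref = paddle.nn.functional.log_softmax(x, axis=-1)
print(float((ours - ref).abs().max()))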
@triton.jit
def mean_kernel(
    input_ptr,
    output_ptr,
    input_stride0,
    input_stride1,
    input_stride2,
    output_stride0,
    output_stride1,
    M,  # size before reduction dim
    N,  # size of reduction dim
    K,  # size after reduction dim
    BLOCK_SIZE: tl.constexpr,
):
    """
    Kernel for computing mean along a single dimension.
    Input is viewed as (M, N, K) where N is the dimension being reduced.
    """
    # Program ID gives us which output element we're computing
    pid = tl.program_id(0)

    # Compute output indices
    m_idx = pid // K
    k_idx = pid % K

    # Bounds check
    if m_idx >= M or k_idx >= K:
        return

    # Accumulate sum across reduction dimension
    acc = 0.0
    for n_start in range(0, N, BLOCK_SIZE):
        n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
        mask = n_offsets < N

        # Calculate input indices
        input_idx = m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2

        # Load and accumulate
        vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
        acc += tl.sum(vals)

    # Compute mean and store
    mean_val = acc / N
    output_idx = m_idx * output_stride0 + k_idx * output_stride1
    tl.store(output_ptr + output_idx, mean_val)
def mean_dim(
    input: paddle.Tensor, dim: int, keepdim: bool = False, dtype: paddle.dtype | None = None
) -> paddle.Tensor:
    """
    Triton implementation of paddle.mean with single dimension reduction.

    Args:
        input: Input tensor
        dim: Single dimension along which to compute mean
        keepdim: Whether to keep the reduced dimension
        dtype: Output dtype. If None, uses input dtype (or float32 for integer inputs)

    Returns:
        Tensor with mean values along specified dimension
    """
    # Validate inputs
    assert input.is_cuda, "Input must be a CUDA tensor"
    assert -input.ndim <= dim < input.ndim, f"Invalid dimension {dim} for tensor with {input.ndim} dimensions"

    # Handle negative dim
    if dim < 0:
        dim = dim + input.ndim

    # Handle dtype
    if dtype is None:
        if input.dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.int64]:
            dtype = paddle.float32
        else:
            dtype = input.dtype

    # Convert input to appropriate dtype if needed
    if input.dtype != dtype:
        input = input.to(dtype)

    # Get input shape and strides
    shape = list(input.shape)

    # Calculate dimensions for kernel
    M = 1
    for i in range(dim):
        M *= shape[i]

    N = shape[dim]

    K = 1
    for i in range(dim + 1, len(shape)):
        K *= shape[i]

    # Reshape input to 3D view (M, N, K)
    input_3d = input.reshape(M, N, K)

    # Create output shape
    if keepdim:
        output_shape = shape.copy()
        output_shape[dim] = 1
    else:
        output_shape = shape[:dim] + shape[dim + 1 :]

    # Create output tensor
    output = paddle.empty(output_shape, dtype=dtype)

    # Reshape output for kernel
    if keepdim:
        output_2d = output.reshape(M, 1, K).squeeze(1)
    else:
        output_2d = output.reshape(M, K)

    # Launch kernel
    grid = (M * K,)
    BLOCK_SIZE = 1024

    mean_kernel[grid](
        input_3d,
        output_2d,
        input_3d.stride(0),
        input_3d.stride(1),
        input_3d.stride(2),
        output_2d.stride(0),
        output_2d.stride(1) if output_2d.ndim > 1 else 0,
        M,
        N,
        K,
        BLOCK_SIZE,
    )

    return output
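Illustrative use of `mean_dim` (assumes a CUDA device): reducing `dim=1` of a `[2, 3, 4, 5]` tensor is viewed by the kernel as `M=2, N=3, K=20`, launching `M * K = 40` programs that each average `N = 3` strided elements.

import paddle

x = paddle.randn([2, 3, 4, 5])
out = mean_dim(x, dim=1)    # shape [2, 4, 5]
ref = paddle.mean(x, axis=1)
print(float((out - ref).abs().max()))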
def mm_batch_invariant(a, b, transpose_x=False, transpose_y=False):
    if transpose_x:
        a = a.T
    if transpose_y:
        b = b.T
    return matmul_persistent(a, b)
def addmm_batch_invariant(
    input: paddle.Tensor, x: paddle.Tensor, y: paddle.Tensor, beta: float = 1.0, alpha: float = 1.0
) -> paddle.Tensor:
    """
    We need to achieve `Out = alpha * (x @ y) + beta * input`,
    but matmul_persistent only computes `x @ y + bias` (matching aten::addmm in
    torch; paddle._C_ops.addmm takes more parameters).
    So we rewrite `alpha * (x @ y) + beta * input = alpha * ((x @ y) + (beta / alpha) * input)`
    to minimize the effect on performance.
    """
    matmul_result = matmul_persistent(a=x, b=y, bias=input * beta / alpha)
    result = alpha * matmul_result
    return result
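A quick numeric check of the algebraic identity the docstring relies on (illustrative):

import paddle

x = paddle.randn([4, 8])
y = paddle.randn([8, 16])
bias = paddle.randn([16])
alpha, beta = 2.0, 3.0
lhs = alpha * paddle.matmul(x, y) + beta * bias
rhs = alpha * (paddle.matmul(x, y) + (beta / alpha) * bias)
print(float((lhs - rhs).abs().max()))  # ~0 up to floating-point rounding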
def _log_softmax_batch_invariant(x: paddle.Tensor, axis: int = -1) -> paddle.Tensor:
    return log_softmax(input=x, axis=axis)
def mean_batch_invariant(
    x: paddle.Tensor, axis: list[int] = [], keepdim: bool = False, dtype: paddle.dtype | None = None, out=None
) -> paddle.Tensor:
    assert dtype is None or dtype == paddle.float32, f"unsupported dtype: {dtype}"
    if type(axis) is int:  # axis may arrive as int | Sequence[int]
        result = mean_dim(x, axis, keepdim=keepdim)
    elif len(axis) == 1:
        result = mean_dim(x, axis[0], keepdim=keepdim)
    else:
        assert x.dtype in {paddle.float16, paddle.bfloat16, paddle.float32}, "only float types supported for now"
        n_elems = 1
        for d in axis:
            n_elems *= x.shape[d]
        result = paddle.sum(x, axis=axis, keepdim=keepdim, dtype=paddle.float32) / n_elems

    # Handle out parameter if provided
    if out is not None:
        out.copy_(result)
        return out
    return result
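For the multi-axis path, the fallback accumulates in float32 and divides by the element count; an equivalent spelled-out form (illustrative):

import paddle

x = paddle.randn([4, 8, 16], dtype=paddle.bfloat16)
m1 = mean_batch_invariant(x, axis=[0, 2])
# Same computation written out by hand: float32 sum over axes 0 and 2,
# divided by the number of reduced elements (4 * 16).
m2 = paddle.sum(x, axis=[0, 2], keepdim=False, dtype=paddle.float32) / (4 * 16)
print(float((m1 - m2).abs().max()))  # 0.0: identical arithmetic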
_original_ops = {"mm": None, "addmm": None, "log_softmax": None, "mean": None}

_batch_invariant_MODE = False


def is_batch_invariant_mode_enabled():
    return _batch_invariant_MODE


def enable_batch_invariant_mode():
    global _batch_invariant_MODE, _original_ops
    if _batch_invariant_MODE:
        return

    if hasattr(paddle, "compat") and hasattr(paddle.compat, "enable_torch_proxy"):
        paddle.compat.enable_torch_proxy()
        # TODO(liujundong): Enabling the torch proxy has a global effect.
        # Do NOT call this function at module import time, otherwise it may affect
        # other test cases during pytest collection
        # (e.g. "Could not import module 'PretrainedTokenizer'" or "No module named 'paddle.distributed.tensor'").
        # No other side effects have been observed yet, but they should be watched for.
    else:
        raise RuntimeError(
            "Unable to enable batch-invariant mode: Paddle version is too old. Please upgrade PaddlePaddle."
        )

    _original_ops["mm"] = paddle._C_ops.matmul
    _original_ops["addmm"] = paddle._C_ops.addmm
    _original_ops["log_softmax"] = paddle._C_ops.log_softmax
    _original_ops["mean"] = paddle._C_ops.mean

    paddle._C_ops.matmul = mm_batch_invariant
    paddle._C_ops.addmm = addmm_batch_invariant
    paddle._C_ops.log_softmax = _log_softmax_batch_invariant
    paddle._C_ops.mean = mean_batch_invariant

    _batch_invariant_MODE = True
def disable_batch_invariant_mode():
    global _batch_invariant_MODE, _original_ops
    if not _batch_invariant_MODE:
        return

    if _original_ops["mm"]:
        paddle._C_ops.matmul = _original_ops["mm"]
    if _original_ops["addmm"]:
        paddle._C_ops.addmm = _original_ops["addmm"]
    if _original_ops["log_softmax"]:
        paddle._C_ops.log_softmax = _original_ops["log_softmax"]
    if _original_ops["mean"]:
        paddle._C_ops.mean = _original_ops["mean"]

    _batch_invariant_MODE = False


@contextlib.contextmanager
def set_batch_invariant_mode(enabled: bool = True):
    global _batch_invariant_MODE, _original_ops
    old_mode = _batch_invariant_MODE
    if enabled:
        enable_batch_invariant_mode()
    else:
        disable_batch_invariant_mode()
    yield
    if old_mode:
        enable_batch_invariant_mode()
    else:
        disable_batch_invariant_mode()


AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"])


def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
    return AttentionBlockSize(block_m=16, block_n=16)
tests/batch_invariant/test_batch_invariance_op_addmm.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/test_batch_invariance.py

import unittest

import paddle

from fastdeploy.model_executor.layers.batch_invariant_ops import (
    set_batch_invariant_mode,
)


class TestBatchInvariantForAddmm(unittest.TestCase):
    def setUp(self):
        """
        Initialize the test environment
        """
        device = "gpu" if paddle.is_compiled_with_cuda() else "cpu"
        paddle.set_device(device)

    def test_batch_invariance(self, B: int = 2048, D: int = 4096, dtype=paddle.float32):
        a = paddle.linspace(-100, 100, B * D, dtype=dtype).reshape(B, D)
        b = paddle.linspace(-100, 100, D * D, dtype=dtype).reshape(D, D)

        # Method 1: Matrix-vector multiplication and add (batch size 1)
        out1 = paddle.addmm(a[:1].squeeze(0), a[:1], b)

        # Method 2: Matrix-matrix multiplication and add, then slice (full batch)
        out2 = paddle.addmm(a[:1].squeeze(0), a, b)[:1]

        # Check if results are identical
        diff = (out1 - out2).abs().max()
        return diff.item() == 0, diff

    def run_iters(self, iters=10, ass=False):
        for dtype in [paddle.float32, paddle.bfloat16]:
            is_deterministic = True
            difflist = []
            for i in range(iters):
                isd, df = self.test_batch_invariance(dtype=dtype)
                is_deterministic = is_deterministic and isd
                difflist.append(df)
            print(
                f"Batch Deterministic: {is_deterministic} run-to-run max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist)-min(difflist)} for {dtype} in {iters} iterations"
            )
            if ass:
                assert max(difflist) == 0

    def test_case(self):
        # Test with standard Paddle (likely to show differences)
        print("Standard Paddle:")
        with set_batch_invariant_mode(False):
            self.run_iters(ass=False)
        # Test with batch-invariant operations
        print("\nBatch-Invariant Mode:")
        with set_batch_invariant_mode(True):
            self.run_iters(ass=True)


if __name__ == "__main__":
    unittest.main()

"""
Standard Paddle:
Batch Deterministic: False run-to-run max/min/diff 10.7294921875/10.7294921875/0.0 for paddle.float32 in 10 iterations
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations

Batch-Invariant Mode:
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float32 in 10 iterations
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations
"""
tests/batch_invariant/test_batch_invariance_op_logsoftmax.py (new file, 125 lines)
@@ -0,0 +1,125 @@
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/test_batch_invariance.py

import random
import unittest

import paddle

from fastdeploy.model_executor.layers.batch_invariant_ops import (
    set_batch_invariant_mode,
)


class TestBatchInvariantForLogsoftmax(unittest.TestCase):
    def setUp(self):
        """
        Initialize the test environment
        """
        device = "gpu" if paddle.is_compiled_with_cuda() else "cpu"
        paddle.set_device(device)

    def create_softmax_trap_tensor(self, B, D, dtype):
        """
        Constructs a "trap" tensor designed to trigger batch-invariance issues in Softmax/LogSoftmax.
        Inspired by https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/

        Principle:
        The goal is to make the result of `exp(a - max(a))` contain numbers spanning an extremely wide numerical range
        (e.g., 1.0, 1e-5, 1e-10, and many numbers close to 0).
        When summing these numbers with a parallel reduction, different summation orders (due to parallelism)
        can produce different accumulated rounding errors, leading to a subtle difference between
        batch (parallel) and single-sample (serial) computation results.
        """
        # 1. Determine the desired values after `exp` and calculate the required input values using log().
        max_val = 20.0

        # Offsets relative to max_val. These offsets result in values spanning vastly different orders of magnitude after exp.
        trap_values = [
            max_val,  # Corresponds to exp(a-max) -> 1.0
            max_val - 4.6,  # Corresponds to exp(a-max) -> ~1e-2
            max_val - 11.5,  # Corresponds to exp(a-max) -> ~1e-5
            max_val - 23.0,  # Corresponds to exp(a-max) -> ~1e-10
        ]

        # 2. Create a background tensor filled with a very large negative number.
        background_val = -1000.0
        a = paddle.full((B, D), background_val, dtype=dtype)

        # 3. Scatter these "trap" values at random positions in each row.
        for i in range(B):
            # Randomly shuffle the positions of the trap values for each row to increase non-determinism.
            indices = random.sample(range(D), k=len(trap_values))
            for j, val in enumerate(trap_values):
                a[i, indices[j]] = val

        return a

    def test_batch_invariance(self, B: int = 2048, D: int = 4096, dtype=paddle.float32):
        a = self.create_softmax_trap_tensor(B, D, dtype)

        # Method 1: log_softmax on batch size 1 (first row)
        out1 = paddle.nn.functional.log_softmax(a[:1])

        # Method 2: log_softmax on full batch, then slice (first row)
        out2 = paddle.nn.functional.log_softmax(a)[:1]

        # Check if results are identical
        diff = (out1 - out2).abs().max()
        return diff.item() == 0, diff

    def run_iters(self, iters=10, ass=False):
        for dtype in [paddle.float32, paddle.bfloat16, paddle.float16]:
            is_deterministic = True
            difflist = []
            for i in range(iters):
                isd, df = self.test_batch_invariance(dtype=dtype)
                is_deterministic = is_deterministic and isd
                difflist.append(df)
            print(
                f"Batch Deterministic: {is_deterministic} run-to-run max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist)-min(difflist)} for {dtype} in {iters} iterations"
            )
            if ass:
                assert max(difflist) == 0

    def test_case(self):
        # Test with standard Paddle (likely to show differences)
        print("Standard Paddle:")
        with set_batch_invariant_mode(False):
            self.run_iters(ass=False)
        # Test with batch-invariant operations
        print("\nBatch-Invariant Mode:")
        with set_batch_invariant_mode(True):
            self.run_iters(ass=True)


if __name__ == "__main__":
    unittest.main()

"""
Even in standard Paddle we get deterministic results here, so perhaps the standard implementation is already batch-invariant?

After reviewing the four implementations called by the dispatcher function `SoftmaxForwardCUDAKernelDriverImpl` (dispatched by 'D')
in `paddle/phi/kernels/gpudnn/softmax_gpudnn.h`:

1. SwitchWarpSoftmaxForward (one warp processes 1-2 rows)
2. LaunchKeMatrixSoftmaxForwardKernel (one block processes one row)
3. LaunchSoftmaxForwardCudnnKernel (the cuDNN implementation)
4. LaunchNormalSoftmaxForward (within one block, threads sharing the same threadIdx.x [a "thread column"] cooperate on one row)

Excluding the cuDNN implementation, the other three custom implementations are almost certainly batch-invariant (this still needs a second review).
The determinism of the cuDNN implementation is uncertain.

However, in practice this test case (D=4096) is dispatched to the cuDNN implementation,
while Qwen-3 8B is dispatched to the LaunchKeMatrixSoftmaxForwardKernel implementation.

Result:

Standard Paddle:
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float32 in 10 iterations
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float16 in 10 iterations

Batch-Invariant Mode:
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float32 in 10 iterations
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float16 in 10 iterations
"""
tests/batch_invariant/test_batch_invariance_op_mean.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/test_batch_invariance.py

import unittest

import paddle

from fastdeploy.model_executor.layers.batch_invariant_ops import (
    set_batch_invariant_mode,
)


class TestBatchInvariantForMean(unittest.TestCase):
    def setUp(self):
        """
        Initialize the test environment
        """
        device = "gpu" if paddle.is_compiled_with_cuda() else "cpu"
        paddle.set_device(device)

    def test_batch_invariance(self, B: int = 2048, D: int = 4096, dtype=paddle.float32):
        a = paddle.linspace(-100, 100, B * D, dtype=dtype).reshape(B, D)

        # Method 1: Mean reduction over last axis (batch size 1)
        out1 = paddle.mean(a[:1], axis=-1)

        # Method 2: Mean reduction over last axis (full batch)
        out2 = paddle.mean(a, axis=-1)[:1]

        # Check if results are identical
        diff = (out1 - out2).abs().max()
        return diff.item() == 0, diff

    def run_iters(self, iters=10, ass=False):
        for dtype in [paddle.float32, paddle.bfloat16]:
            is_deterministic = True
            difflist = []
            for i in range(iters):
                isd, df = self.test_batch_invariance(dtype=dtype)
                is_deterministic = is_deterministic and isd
                difflist.append(df)
            print(
                f"Batch Deterministic: {is_deterministic} run-to-run max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist)-min(difflist)} for {dtype} in {iters} iterations"
            )
            if ass:
                assert max(difflist) == 0

    def test_case(self):
        # Test with standard Paddle (likely to show differences)
        print("Standard Paddle:")
        with set_batch_invariant_mode(False):
            self.run_iters(ass=False)
        # Test with batch-invariant operations
        print("\nBatch-Invariant Mode:")
        with set_batch_invariant_mode(True):
            self.run_iters(ass=True)


if __name__ == "__main__":
    unittest.main()

"""
Standard Paddle:
Batch Deterministic: False run-to-run max/min/diff 7.62939453125e-06/7.62939453125e-06/0.0 for paddle.float32 in 10 iterations
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations

Batch-Invariant Mode:
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float32 in 10 iterations
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations
"""
tests/batch_invariant/test_batch_invariance_op_mm.py (new file, 70 lines)
@@ -0,0 +1,70 @@
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/test_batch_invariance.py

import unittest

import paddle

from fastdeploy.model_executor.layers.batch_invariant_ops import (
    set_batch_invariant_mode,
)


class TestBatchInvariantForMM(unittest.TestCase):
    def setUp(self):
        """
        Initialize the test environment
        """
        device = "gpu" if paddle.is_compiled_with_cuda() else "cpu"
        paddle.set_device(device)

    def test_batch_invariance(self, B: int = 2048, D: int = 4096, dtype=paddle.float32):
        a = paddle.linspace(-100, 100, B * D, dtype=dtype).reshape(B, D)
        b = paddle.linspace(-100, 100, D * D, dtype=dtype).reshape(D, D)

        # Method 1: Matrix-vector multiplication (batch size 1)
        out1 = paddle.mm(a[:1], b)

        # Method 2: Matrix-matrix multiplication, then slice (full batch)
        out2 = paddle.mm(a, b)[:1]

        # Check if results are identical
        diff = (out1 - out2).abs().max()
        return diff.item() == 0, diff

    def run_iters(self, iters=10, ass=False):
        for dtype in [paddle.float32, paddle.bfloat16]:
            is_deterministic = True
            difflist = []
            for i in range(iters):
                isd, df = self.test_batch_invariance(dtype=dtype)
                is_deterministic = is_deterministic and isd
                difflist.append(df)
            print(
                f"Batch Deterministic: {is_deterministic} run-to-run max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist)-min(difflist)} for {dtype} in {iters} iterations"
            )
            if ass:
                assert max(difflist) == 0

    def test_case(self):
        # Test with standard Paddle (likely to show differences)
        print("Standard Paddle:")
        with set_batch_invariant_mode(False):
            self.run_iters(ass=False)
        # Test with batch-invariant operations
        print("\nBatch-Invariant Mode:")
        with set_batch_invariant_mode(True):
            self.run_iters(ass=True)


if __name__ == "__main__":
    unittest.main()

"""
Standard Paddle:
Batch Deterministic: False run-to-run max/min/diff 10.7294921875/10.7294921875/0.0 for paddle.float32 in 10 iterations
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations

Batch-Invariant Mode:
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float32 in 10 iterations
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations
"""