FastDeploy/tests/batch_invariant/test_batch_invariance_op_logsoftmax.py
Jundong Liu 6f42c37359 [Deterministic] Move paddle version batch invariant pkg to Fastdeploy (#4763)
* Move batch invariant pkg to Fastdeploy

* fix problem and pre-commit

* move test

* Change testcase to FD style

* Add testcase for log_softmax

* Add testcase for mean

* Add testcase for addmm

* fix pre-commit

* API check v0.9

* move to layers and add comment about log_softmax

* Update fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py

Version-control leftovers in the original code comments should indeed be removed

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update tests/batch_invariant/test_batch_invariance_op_mean.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update tests/batch_invariant/test_batch_invariance_op_logsoftmax.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* change comment after copilot fix

* fix bug about addmm

* avoid global effect by enable_torch_proxy

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
2025-12-01 11:25:48 +08:00

126 lines
5.6 KiB
Python

# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/test_batch_invariance.py
import random
import unittest

import paddle

from fastdeploy.model_executor.layers.batch_invariant_ops import (
    set_batch_invariant_mode,
)


class TestBatchInvariantForLogsoftmax(unittest.TestCase):
    def setUp(self):
        """
        Initialize the test environment
        """
        device = "gpu" if paddle.is_compiled_with_cuda() else "cpu"
        paddle.set_device(device)

    def create_softmax_trap_tensor(self, B, D, dtype):
        """
        Constructs a "trap" tensor designed to trigger batch-invariance issues in Softmax/LogSoftmax.
        Inspired by https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/

        Principle:
        The goal is to make the result of `exp(a - max(a))` contain numbers spanning an extremely wide numerical range
        (e.g., 1.0, 1e-5, 1e-10, and many numbers close to 0).
        When summing these numbers using parallel reduction, different summation orders (due to parallelism)
        can produce different accumulated rounding errors, leading to a subtle difference between
        batch (parallel) and single-sample (serial) computation results.
        """
        # 1. Determine the desired values after `exp` and calculate the required input values using log().
        max_val = 20.0
        # Offsets relative to max_val. These offsets result in values spanning vastly different orders of magnitude after exp.
        trap_values = [
            max_val,  # Corresponds to exp(a-max) -> 1.0
            max_val - 4.6,  # Corresponds to exp(a-max) -> ~1e-2
            max_val - 11.5,  # Corresponds to exp(a-max) -> ~1e-5
            max_val - 23.0,  # Corresponds to exp(a-max) -> ~1e-10
        ]
        # 2. Create a background tensor filled with a very large negative number.
        background_val = -1000.0
        a = paddle.full((B, D), background_val, dtype=dtype)
        # 3. Scatter these "trap" values at random positions in each row.
        for i in range(B):
            # Randomly shuffle the positions of the trap values for each row to increase non-determinism.
            indices = random.sample(range(D), k=len(trap_values))
            for j, val in enumerate(trap_values):
                a[i, indices[j]] = val
        return a

    def test_batch_invariance(self, B: int = 2048, D: int = 4096, dtype=paddle.float32):
        a = self.create_softmax_trap_tensor(B, D, dtype)
        # Method 1: log_softmax on batch size 1 (first row)
        out1 = paddle.nn.functional.log_softmax(a[:1])
        # Method 2: log_softmax on full batch, then slice (first row)
        out2 = paddle.nn.functional.log_softmax(a)[:1]
        # Check if results are identical
        diff = (out1 - out2).abs().max()
        return diff.item() == 0, diff

    def run_iters(self, iters=10, ass=False):
        for dtype in [paddle.float32, paddle.bfloat16, paddle.float16]:
            is_deterministic = True
            difflist = []
            for i in range(iters):
                isd, df = self.test_batch_invariance(dtype=dtype)
                is_deterministic = is_deterministic and isd
                difflist.append(df)
            print(
                f"Batch Deterministic: {is_deterministic} run-to-run max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist)-min(difflist)} for {dtype} in {iters} iterations"
            )
            if ass:
                assert max(difflist) == 0

    def test_case(self):
        # Test with standard Paddle (likely to show differences)
        print("Standard Paddle:")
        with set_batch_invariant_mode(False):
            self.run_iters(ass=False)
        # Test with batch-invariant operations
        print("\nBatch-Invariant Mode:")
        with set_batch_invariant_mode(True):
            self.run_iters(ass=True)


if __name__ == "__main__":
    unittest.main()
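
The offsets used in `create_softmax_trap_tensor` can be sanity-checked without Paddle. The following is a minimal standalone sketch in plain Python (float64), not part of the test itself; the variable names `offsets`, `trap_terms`, and `background_term` are introduced here for illustration only. It evaluates `exp(a - max(a))` for each trap value and for the background value, showing that the trap terms span roughly ten orders of magnitude while the background term underflows to zero.

```python
import math

max_val = 20.0
background_val = -1000.0
# Offsets below max_val, as used for the trap values in the test.
offsets = [0.0, 4.6, 11.5, 23.0]

# exp(a - max(a)) for each trap value: a = max_val - offset, max(a) = max_val.
trap_terms = [math.exp((max_val - off) - max_val) for off in offsets]

# exp(a - max(a)) for the background entries; this underflows to 0.0 in float64.
background_term = math.exp(background_val - max_val)

for off, term in zip(offsets, trap_terms):
    print(f"offset {off:>5}: exp(-offset) = {term:.3e}")
print(f"background: exp({background_val - max_val}) = {background_term}")
```

Because the background entries contribute exactly zero, the softmax denominator of each row is determined only by the four trap terms, whose widely different magnitudes are what make the sum sensitive to accumulation order.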
"""
Observation: standard Paddle already produces deterministic results in this test, so the standard implementation may already be batch-invariant.
The dispatcher function `SoftmaxForwardCUDAKernelDriverImpl` in `paddle/phi/kernels/gpudnn/softmax_gpudnn.h`
selects one of four implementations based on `D`:
    1. SwitchWarpSoftmaxForward (one warp processes 1-2 rows)
    2. LaunchKeMatrixSoftmaxForwardKernel (one block processes one row)
    3. LaunchSoftmaxForwardCudnnKernel (the cuDNN implementation)
    4. LaunchNormalSoftmaxForward (within one block, threads with the same threadIdx.x [a "thread column"] cooperate to process one row)
Setting aside the cuDNN implementation, the three custom implementations are almost certainly batch-invariant (this still needs to be verified by someone else).
The determinism of the cuDNN implementation is uncertain.
In practice, this test case (D=4096) is dispatched to the cuDNN implementation,
while Qwen-3 8B is dispatched to the LaunchKeMatrixSoftmaxForwardKernel implementation.
Result:
Standard Paddle:
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float32 in 10 iterations
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float16 in 10 iterations
Batch-Invariant Mode:
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float32 in 10 iterations
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations
Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float16 in 10 iterations
"""
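
The note above reasons that a kernel is batch-invariant when each row is reduced with the same strategy regardless of batch size. The following is a minimal pure-Python sketch of that idea, assuming float64 arithmetic; it is not Paddle's kernel code, and the helper names `row_sum_fixed_order` and `row_sum_split` are hypothetical. It shows that summing the same row with a different number of partial sums can change the result, which is what a batch-size-dependent reduction strategy would risk.

```python
import math


def row_sum_fixed_order(row):
    """Sum a row strictly left to right (one fixed reduction order)."""
    total = 0.0
    for x in row:
        total += x
    return total


def row_sum_split(row, num_chunks):
    """Sum a row as num_chunks partial sums that are then combined.

    This mimics a reduction whose split depends on the launch configuration.
    """
    chunk = math.ceil(len(row) / num_chunks)
    partials = [
        row_sum_fixed_order(row[i : i + chunk]) for i in range(0, len(row), chunk)
    ]
    return row_sum_fixed_order(partials)


# One dominant term followed by many terms too small to register against it.
row = [1.0] + [1e-16] * 4095

serial = row_sum_fixed_order(row)
one_chunk = row_sum_split(row, 1)
two_chunks = row_sum_split(row, 2)

print("serial     :", repr(serial))
print("one chunk  :", repr(one_chunk))
print("two chunks :", repr(two_chunks))
print("same strategy matches :", serial == one_chunk)
print("split strategy matches:", serial == two_chunks)
```

In the serial sum every small term is rounded away against 1.0, whereas the two-chunk sum first accumulates the small terms of the second half and only then adds them to 1.0, so the two results differ. A kernel that always assigns one warp or one block to a row keeps the per-row strategy fixed, which is the property the note relies on.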