Files
FastDeploy/tests/operators/test_fused_rotary_position_encoding.py
Echo-Nie cc6e14d2ec 【Hackathon 9th No.46】add test_fused_rotary_position_encoding (#3848)
* add test_fused_rotary_position_encoding

* 添加版权

* fix according to the review
2025-09-19 17:50:19 +08:00

135 lines
5.6 KiB
Python

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import fused_rotary_position_encoding
class TestFusedRotaryPositionEncoding(unittest.TestCase):
    """Compare the ``fused_rotary_position_encoding`` GPU op with a NumPy reference.

    Each correctness case builds random query/key tensors, runs the fused op
    on copies of them, and checks the result against ``_ref_rotary``.
    """

    def setUp(self):
        """Select the GPU device and seed NumPy so the random inputs are reproducible."""
        paddle.set_device("gpu")
        np.random.seed(42)

    def _make_cos_sin_cache(self, max_position: int, rot_dim: int) -> np.ndarray:
        """Build the rotary cos/sin lookup table.

        Args:
            max_position: Number of positions, i.e. rows of the table.
            rot_dim: Number of channels rotated per head; must be even.

        Returns:
            A float32 array of shape ``[max_position, rot_dim]``. The first
            ``rot_dim // 2`` columns hold the cosines and the remaining
            columns hold the matching sines.
        """
        assert rot_dim % 2 == 0, "rot_dim must be even"
        half_dim = rot_dim // 2
        inv_freq = 1.0 / (10000 ** (np.arange(0, half_dim).astype("float32") / half_dim))
        positions = np.arange(max_position, dtype="float32")
        freqs = np.outer(positions, inv_freq)  # [max_position, half_dim]
        cos_np = np.cos(freqs)
        sin_np = np.sin(freqs)
        return np.concatenate([cos_np, sin_np], axis=1).astype("float32")

    @staticmethod
    def _rotate_pairs(states: np.ndarray, cos: np.ndarray, sin: np.ndarray, is_neox: bool) -> np.ndarray:
        """Return a copy of ``states`` with its channel pairs rotated.

        Args:
            states: Array of shape ``[num_tokens, heads, head_size]``.
            cos: Cosines of shape ``[num_tokens, 1, embed_dim]``.
            sin: Sines of shape ``[num_tokens, 1, embed_dim]``.
            is_neox: If true, pair channel ``i`` with ``embed_dim + i``;
                otherwise pair channel ``2 * i`` with ``2 * i + 1``.

        Returns:
            The rotated copy; ``states`` itself is left untouched.
        """
        embed_dim = cos.shape[-1]
        pair = np.arange(embed_dim)
        if is_neox:
            x_idx, y_idx = pair, pair + embed_dim
        else:
            x_idx, y_idx = 2 * pair, 2 * pair + 1

        rotated = states.copy()
        # Integer-array indexing returns copies, so both halves are read from
        # the unrotated input before anything is written back.
        x = states[..., x_idx]
        y = states[..., y_idx]
        rotated[..., x_idx] = x * cos - y * sin
        rotated[..., y_idx] = y * cos + x * sin
        return rotated

    def _ref_rotary(self, query, key, position_ids, cos_sin_cache, head_size, is_neox):
        """NumPy reference implementation of the rotary position encoding.

        Args:
            query: Array of shape ``[num_tokens, num_heads, head_size]``.
            key: Array of shape ``[num_tokens, num_kv_heads, head_size]``.
            position_ids: Row of ``cos_sin_cache`` to use for each token.
            cos_sin_cache: Table laid out as produced by ``_make_cos_sin_cache``.
            head_size: Not used by the reference; kept so the signature
                mirrors the arguments passed to the fused op.
            is_neox: Selects the channel pairing, see ``_rotate_pairs``.

        Returns:
            ``(query_ref, key_ref)``: rotated copies of the inputs. Channels
            beyond the cache width are returned unchanged.
        """
        num_tokens = query.shape[0]
        embed_dim = cos_sin_cache.shape[1] // 2

        # One cache row per token, with a broadcast axis inserted for the heads.
        rows = cos_sin_cache[position_ids[:num_tokens]]
        cos = rows[:, None, :embed_dim]
        sin = rows[:, None, embed_dim : 2 * embed_dim]

        query_ref = self._rotate_pairs(query, cos, sin, is_neox)
        key_ref = self._rotate_pairs(key, cos, sin, is_neox)
        return query_ref, key_ref

    def _run_op(
        self,
        query_np: np.ndarray,
        key_np: np.ndarray,
        position_ids_np: np.ndarray,
        cos_sin_cache_np: np.ndarray,
        head_size: int,
        is_neox: bool,
    ):
        """Run fused_rotary_position_encoding operator.

        The op's return value is ignored: the query/key tensors are read back
        after the call, so this test relies on the op updating them in place.

        Returns:
            ``(query, key)`` as NumPy arrays read back after the op has run.
        """
        query = paddle.to_tensor(query_np, dtype="float32")
        key = paddle.to_tensor(key_np, dtype="float32")
        position_ids = paddle.to_tensor(position_ids_np, dtype="int32")
        cos_sin_cache = paddle.to_tensor(cos_sin_cache_np, dtype="float32")
        fused_rotary_position_encoding(query, key, position_ids, cos_sin_cache, head_size, is_neox)
        return query.numpy(), key.numpy()

    def _check_correctness(self, num_tokens, num_heads, num_kv_heads, head_size, rot_dim, is_neox):
        """Run the op on random inputs and assert it matches the NumPy reference.

        Token ``t`` is given position id ``t``, and the cos/sin cache has
        exactly ``num_tokens`` rows.
        """
        query_np = np.random.rand(num_tokens, num_heads, head_size).astype("float32")
        key_np = np.random.rand(num_tokens, num_kv_heads, head_size).astype("float32")
        position_ids_np = np.arange(num_tokens, dtype="int32")
        cos_sin_cache_np = self._make_cos_sin_cache(num_tokens, rot_dim)

        query_out, key_out = self._run_op(query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox)
        query_ref, key_ref = self._ref_rotary(query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox)

        np.testing.assert_allclose(query_out, query_ref, rtol=1e-5, atol=1e-6)
        np.testing.assert_allclose(key_out, key_ref, rtol=1e-5, atol=1e-6)

    def test_basic_case(self):
        """Interleaved pairing with rot_dim smaller than head_size."""
        self._check_correctness(num_tokens=4, num_heads=2, num_kv_heads=2, head_size=6, rot_dim=4, is_neox=False)

    def test_neox_mode(self):
        """NeoX pairing with rot_dim equal to head_size."""
        self._check_correctness(num_tokens=3, num_heads=2, num_kv_heads=2, head_size=8, rot_dim=8, is_neox=True)

    def test_large_num_tokens(self):
        """Interleaved pairing over ten tokens."""
        self._check_correctness(num_tokens=10, num_heads=2, num_kv_heads=2, head_size=4, rot_dim=4, is_neox=False)

    def test_exceed_max_tokens(self):
        """The op is expected to raise when given 65537 tokens."""
        num_tokens, num_heads, head_size = 65537, 1, 4
        num_kv_heads, rot_dim = 1, 4
        query_np = np.random.rand(num_tokens, num_heads, head_size).astype("float32")
        key_np = np.random.rand(num_tokens, num_kv_heads, head_size).astype("float32")
        position_ids_np = np.arange(num_tokens, dtype="int32")
        cos_sin_cache_np = self._make_cos_sin_cache(num_tokens, rot_dim)
        # NOTE(review): Exception is very broad; narrow it once the concrete
        # error type raised by the op has been confirmed.
        with self.assertRaises(Exception):
            self._run_op(query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False)
if __name__ == "__main__":
    # Run the test cases in this module when the file is executed directly.
    unittest.main()