mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-11-01 12:22:53 +08:00
[CI] Standard unittest (#3606)
* standard unittest * fix bugs * fix script
This commit is contained in:
@@ -647,6 +647,9 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
|
||||
model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
|
||||
process_weights_after_loading_fn(model_sublayer_name, param)
|
||||
if self.tie_word_embeddings:
|
||||
# because we use lazy guard and is not initialized by default
|
||||
if not self.lm_head.linear.weight._is_initialized():
|
||||
self.lm_head.linear.weight.initialize()
|
||||
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
|
||||
|
||||
@paddle.no_grad()
|
||||
|
||||
@@ -11,40 +11,6 @@ cd "$run_path" || exit 1
|
||||
failed_tests_file="failed_tests.log"
|
||||
> "$failed_tests_file"
|
||||
|
||||
##################################
|
||||
# 执行特殊单测case(不符合unittest/pytest格式)
|
||||
##################################
|
||||
special_tests=(
|
||||
"graph_optimization/test_cuda_graph_dynamic_subgraph.py"
|
||||
"graph_optimization/test_cuda_graph_spec_decode.py"
|
||||
"layers/test_quant_layer.py"
|
||||
"operators/test_token_penalty.py"
|
||||
"operators/test_split_fuse.py"
|
||||
"operators/test_flash_mask_attn.py"
|
||||
"operators/test_w4afp8_gemm.py"
|
||||
"model_loader/test_load_ernie_vl.py"
|
||||
"operators/test_tree_mask.py"
|
||||
)
|
||||
|
||||
failed_special=0
|
||||
success_special=0
|
||||
|
||||
for test_file in "${special_tests[@]}"; do
|
||||
if [ -f "$test_file" ]; then
|
||||
echo "Running special test: $test_file"
|
||||
python -m coverage run --parallel-mode "$test_file"
|
||||
status=$?
|
||||
if [ "$status" -ne 0 ]; then
|
||||
echo "$test_file" >> "$failed_tests_file"
|
||||
failed_special=$((failed_special+1))
|
||||
else
|
||||
success_special=$((success_special+1))
|
||||
fi
|
||||
else
|
||||
echo "Warning: $test_file not found"
|
||||
failed_special=$((failed_special+1))
|
||||
fi
|
||||
done
|
||||
|
||||
##################################
|
||||
# 执行 pytest,每个文件单独跑
|
||||
@@ -78,9 +44,8 @@ echo "Pytest failed: $failed_pytest"
|
||||
|
||||
echo "Special tests total: ${#special_tests[@]}"
|
||||
echo "Special tests successful: $success_special"
|
||||
echo "Special tests failed: $failed_special"
|
||||
|
||||
if [ "$failed_pytest" -ne 0 ] || [ "$failed_special" -ne 0 ]; then
|
||||
if [ "$failed_pytest" -ne 0 ]; then
|
||||
echo "Failed test cases are listed in $failed_tests_file"
|
||||
cat "$failed_tests_file"
|
||||
exit 8
|
||||
|
||||
@@ -1,93 +1,99 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import flash_attention_mask
|
||||
|
||||
|
||||
def naive_attn(q_input, k_input, v_input, mask):
|
||||
gqa_group_size = q_input.shape[2] // k_input.shape[2]
|
||||
class TestFlashMaskAttention(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.bsz = 1
|
||||
self.num_head = 8
|
||||
self.num_kv_head = 1
|
||||
self.q_seq_len = 1024
|
||||
self.k_seq_len = 1024
|
||||
self.head_dim = 128
|
||||
np.random.seed(self.q_seq_len)
|
||||
|
||||
q_cur = q_input.transpose([0, 2, 1, 3])
|
||||
k_cur = k_input.transpose([0, 2, 1, 3])
|
||||
v_cur = v_input.transpose([0, 2, 1, 3])
|
||||
out = np.zeros(q_cur.shape, dtype=q_input.dtype)
|
||||
def naive_attn(self, q_input, k_input, v_input, mask):
|
||||
gqa_group_size = q_input.shape[2] // k_input.shape[2]
|
||||
|
||||
for bsz in range(0, q_cur.shape[0]):
|
||||
for hi in range(0, q_cur.shape[1]):
|
||||
qk = np.matmul(q_cur[bsz, hi], k_cur[bsz, hi // gqa_group_size].T) * (1.0 / np.sqrt(q_cur.shape[3]))
|
||||
for i in range(0, qk.shape[0]):
|
||||
qk[i, mask[i] :] = -1000000
|
||||
q_cur = q_input.transpose([0, 2, 1, 3])
|
||||
k_cur = k_input.transpose([0, 2, 1, 3])
|
||||
v_cur = v_input.transpose([0, 2, 1, 3])
|
||||
out = np.zeros(q_cur.shape, dtype=q_input.dtype)
|
||||
|
||||
qk_max = np.expand_dims(qk.max(axis=-1), -1)
|
||||
qk -= qk_max
|
||||
qk = np.exp(qk)
|
||||
for bsz in range(0, q_cur.shape[0]):
|
||||
for hi in range(0, q_cur.shape[1]):
|
||||
qk = np.matmul(q_cur[bsz, hi], k_cur[bsz, hi // gqa_group_size].T) * (1.0 / np.sqrt(q_cur.shape[3]))
|
||||
for i in range(0, qk.shape[0]):
|
||||
qk[i, mask[i] :] = -1000000
|
||||
|
||||
exp_sum = np.expand_dims(qk.sum(axis=-1), -1)
|
||||
exp_sum_inv = 1.0 / exp_sum
|
||||
qk_max = np.expand_dims(qk.max(axis=-1), -1)
|
||||
qk -= qk_max
|
||||
qk = np.exp(qk)
|
||||
|
||||
out[bsz, hi] = (np.matmul(qk, v_cur[bsz, hi // gqa_group_size]) * exp_sum_inv).astype(q_input.dtype)
|
||||
return out
|
||||
exp_sum = np.expand_dims(qk.sum(axis=-1), -1)
|
||||
exp_sum_inv = 1.0 / exp_sum
|
||||
|
||||
out[bsz, hi] = (np.matmul(qk, v_cur[bsz, hi // gqa_group_size]) * exp_sum_inv).astype(q_input.dtype)
|
||||
return out
|
||||
|
||||
def paddle_flash_attn_mask(q_input, k_input, v_input, mask):
|
||||
bsz = q_input.shape[0]
|
||||
cu_seq_q = paddle.arange(bsz + 1) * q_input.shape[1]
|
||||
cu_seq_k = paddle.arange(bsz + 1) * k_input.shape[1]
|
||||
cu_seq_q = cu_seq_q.astype("int32")
|
||||
cu_seq_k = cu_seq_k.astype("int32")
|
||||
seq_len_encoder = paddle.ones(bsz) * q_input.shape[1]
|
||||
seq_len_encoder = seq_len_encoder.astype("int32")
|
||||
q_input = paddle.to_tensor(q_input).astype("bfloat16").reshape([-1, q_input.shape[2], q_input.shape[3]])
|
||||
k_input = paddle.to_tensor(k_input).astype("bfloat16").reshape([-1, k_input.shape[2], k_input.shape[3]])
|
||||
v_input = paddle.to_tensor(v_input).astype("bfloat16").reshape([-1, v_input.shape[2], v_input.shape[3]])
|
||||
v_input_pad = paddle.zeros([v_input.shape[0] + 128, v_input.shape[1], v_input.shape[2]]).astype("bfloat16")
|
||||
v_input_pad[0 : v_input.shape[0]] = v_input
|
||||
mask = paddle.to_tensor(mask).astype("int32")
|
||||
def paddle_flash_attn_mask(self, q_input, k_input, v_input, mask):
|
||||
bsz = q_input.shape[0]
|
||||
cu_seq_q = paddle.arange(bsz + 1) * q_input.shape[1]
|
||||
cu_seq_k = paddle.arange(bsz + 1) * k_input.shape[1]
|
||||
cu_seq_q = cu_seq_q.astype("int32")
|
||||
cu_seq_k = cu_seq_k.astype("int32")
|
||||
seq_len_encoder = paddle.ones(bsz) * q_input.shape[1]
|
||||
seq_len_encoder = seq_len_encoder.astype("int32")
|
||||
q_input = paddle.to_tensor(q_input).astype("bfloat16").reshape([-1, q_input.shape[2], q_input.shape[3]])
|
||||
k_input = paddle.to_tensor(k_input).astype("bfloat16").reshape([-1, k_input.shape[2], k_input.shape[3]])
|
||||
v_input = paddle.to_tensor(v_input).astype("bfloat16").reshape([-1, v_input.shape[2], v_input.shape[3]])
|
||||
v_input_pad = paddle.zeros([v_input.shape[0] + 128, v_input.shape[1], v_input.shape[2]]).astype("bfloat16")
|
||||
v_input_pad[0 : v_input.shape[0]] = v_input
|
||||
mask = paddle.to_tensor(mask).astype("int32")
|
||||
|
||||
out = flash_attention_mask(
|
||||
q_input,
|
||||
k_input,
|
||||
v_input_pad,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
seq_len_encoder,
|
||||
mask,
|
||||
int(q_input.shape[1]),
|
||||
int(k_input.shape[1]),
|
||||
int(q_input.shape[2]),
|
||||
int(k_input.shape[0]),
|
||||
int(q_input.shape[0]),
|
||||
int(k_input.shape[0]),
|
||||
)
|
||||
return out
|
||||
out = flash_attention_mask(
|
||||
q_input,
|
||||
k_input,
|
||||
v_input_pad,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
seq_len_encoder,
|
||||
mask,
|
||||
int(q_input.shape[1]),
|
||||
int(k_input.shape[1]),
|
||||
int(q_input.shape[2]),
|
||||
int(k_input.shape[0]),
|
||||
int(q_input.shape[0]),
|
||||
int(k_input.shape[0]),
|
||||
)
|
||||
return out
|
||||
|
||||
def test_flash_attention_mask(self):
|
||||
q_input = np.random.normal(0, 0.5, size=(self.bsz, self.q_seq_len, self.num_head, self.head_dim))
|
||||
k_input = np.random.normal(
|
||||
0, 0.5, size=(self.bsz, self.q_seq_len + self.k_seq_len, self.num_kv_head, self.head_dim)
|
||||
)
|
||||
v_input = np.random.normal(
|
||||
0, 0.5, size=(self.bsz, self.q_seq_len + self.k_seq_len, self.num_kv_head, self.head_dim)
|
||||
)
|
||||
|
||||
def test(bsz, num_head, num_kv_head, q_seq_len, k_seq_len):
|
||||
head_dim = 128
|
||||
q_input = np.random.normal(0, 0.5, size=(bsz, q_seq_len, num_head, head_dim))
|
||||
k_input = np.random.normal(0, 0.5, size=(bsz, q_seq_len + k_seq_len, num_kv_head, head_dim))
|
||||
v_input = np.random.normal(0, 0.5, size=(bsz, q_seq_len + k_seq_len, num_kv_head, head_dim))
|
||||
random_len = np.random.randint(self.q_seq_len // 2, size=2)
|
||||
text_len = random_len[0]
|
||||
image_len = random_len[1]
|
||||
|
||||
random_len = np.random.randint(q_seq_len // 2, size=2)
|
||||
mask = np.array([i + 1 for i in range(0, self.q_seq_len)]) + self.k_seq_len
|
||||
mask[text_len : text_len + image_len] = text_len + image_len + self.k_seq_len
|
||||
|
||||
text_len = random_len[0]
|
||||
image_len = random_len[1]
|
||||
naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask)
|
||||
paddle_attn_out = self.paddle_flash_attn_mask(q_input, k_input, v_input, mask)
|
||||
|
||||
mask = np.array([i + 1 for i in range(0, q_seq_len)]) + k_seq_len
|
||||
|
||||
mask[text_len : text_len + image_len] = text_len + image_len + k_seq_len
|
||||
|
||||
naive_attn_out = naive_attn(q_input, k_input, v_input, mask)
|
||||
paddle_attn_out = paddle_flash_attn_mask(q_input, k_input, v_input, mask)
|
||||
|
||||
assert float((paddle_attn_out.reshape([-1]) - paddle.to_tensor(naive_attn_out).reshape([-1])).max()) <= 0.05
|
||||
max_diff = float((paddle_attn_out.reshape([-1]) - paddle.to_tensor(naive_attn_out).reshape([-1])).max())
|
||||
self.assertLessEqual(max_diff, 0.05)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bsz = 1
|
||||
num_head = 8
|
||||
num_kv_head = 1
|
||||
q_seq_len = 1024
|
||||
k_seq_len = 1024
|
||||
np.random.seed(q_seq_len)
|
||||
test(bsz, num_head, num_kv_head, q_seq_len, k_seq_len)
|
||||
unittest.main()
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""UT for per_channel_fp8_fp8_half_gemm_fused kernel"""
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
|
||||
class Test(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""
|
||||
Initialize the test environment,
|
||||
including setting random seeds and environment variables.
|
||||
"""
|
||||
paddle.seed(2003)
|
||||
os.environ["FLAGS_use_cutlass_device_best_config_path"] = "default"
|
||||
|
||||
def testcase1(self):
|
||||
"""
|
||||
Check if the per_channel_fp8_fp8_half_gemm_fused function works properly.
|
||||
"""
|
||||
prop = paddle.device.cuda.get_device_properties()
|
||||
cc = prop.major * 10 + prop.minor
|
||||
if cc < 89:
|
||||
self.skipTest("per_channel_fp8_fp8_half_gemm_fused only support sm89+")
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
per_channel_fp8_fp8_half_gemm_fused,
|
||||
)
|
||||
|
||||
nks = [[2048, 2048], [2048, 5504], [6144, 2048]]
|
||||
nks = nks + [[4096, 4096], [4096, 12800], [6144, 4096]]
|
||||
nks = nks + [[5120, 5120], [5120, 13824], [15360, 5120]]
|
||||
|
||||
m = [1, 32, 64, 128, 256, 512, 1024, 2048]
|
||||
|
||||
combinations = list(product(m, nks))
|
||||
for m, (n, k) in combinations:
|
||||
A_bf16 = paddle.rand(shape=[m, k], dtype="bfloat16")
|
||||
A_fp8 = paddle.cast(A_bf16, "float8_e4m3fn")
|
||||
B_bf16 = paddle.rand(shape=[n, k], dtype="bfloat16")
|
||||
B_fp8 = B_bf16.astype("float8_e4m3fn")
|
||||
|
||||
scalar_scale = paddle.full([1], 0.5, dtype="float32")
|
||||
channel_scale = paddle.rand(shape=[n], dtype="float32")
|
||||
bias = paddle.rand(shape=[n], dtype="bfloat16")
|
||||
|
||||
result_bf16 = paddle.matmul(A_bf16, B_bf16, transpose_y=True) * scalar_scale * channel_scale + bias
|
||||
result_fp8 = per_channel_fp8_fp8_half_gemm_fused(
|
||||
A_fp8,
|
||||
B_fp8,
|
||||
bias=bias,
|
||||
scalar_scale=scalar_scale,
|
||||
channel_scale=channel_scale,
|
||||
transpose_x=False,
|
||||
transpose_y=True,
|
||||
output_dtype="bfloat16",
|
||||
)
|
||||
# absolute_error = paddle.abs(result_bf16 - result_fp8)
|
||||
# mean_absolute_error = paddle.mean(absolute_error)
|
||||
relative_error = paddle.abs(result_bf16 - result_fp8) / (paddle.abs(result_bf16))
|
||||
mean_relative_error = paddle.mean(relative_error)
|
||||
np.testing.assert_allclose(
|
||||
mean_relative_error.numpy(),
|
||||
np.array([0.001]),
|
||||
rtol=0.001,
|
||||
atol=0.25,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -13,72 +13,81 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""UT for set_stop_value"""
|
||||
import unittest
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import get_mm_split_fuse
|
||||
|
||||
input_ids = []
|
||||
image_type_ids = []
|
||||
grid_thw = []
|
||||
|
||||
class TestSplitFuse(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.grid_thw = [[6, 20, 20], [6, 40, 20]]
|
||||
self.split_fuse_img_size = 16
|
||||
self.split_fuse_text_size = 384 # 1024
|
||||
self.max_seq_len = 2048
|
||||
self.image_token_id = 100295
|
||||
|
||||
def split_grid(origin_grid_thw):
|
||||
# 划分grid_thw,该函数用于视频场景
|
||||
# origin_grid_thw = [6, 10, 12] ---> [2, 10, 12, 2, 10, 12, 2, 10, 12]
|
||||
grid_thw = []
|
||||
for t, h, w in origin_grid_thw:
|
||||
if t > 2:
|
||||
num_groups = t // 2
|
||||
remainder = t % 2
|
||||
for _ in range(num_groups):
|
||||
grid_thw.extend([2, h, w])
|
||||
if remainder > 0:
|
||||
grid_thw.extend([remainder, h, w])
|
||||
else:
|
||||
grid_thw.extend([t, h, w])
|
||||
return grid_thw
|
||||
def split_grid(self, origin_grid_thw):
|
||||
# 划分grid_thw,该函数用于视频场景
|
||||
# origin_grid_thw = [6, 10, 12] ---> [2, 10, 12, 2, 10, 12, 2, 10, 12]
|
||||
grid_thw = []
|
||||
for t, h, w in origin_grid_thw:
|
||||
if t > 2:
|
||||
num_groups = t // 2
|
||||
remainder = t % 2
|
||||
for _ in range(num_groups):
|
||||
grid_thw.extend([2, h, w])
|
||||
if remainder > 0:
|
||||
grid_thw.extend([remainder, h, w])
|
||||
else:
|
||||
grid_thw.extend([t, h, w])
|
||||
return grid_thw
|
||||
|
||||
def test_get_mm_split_fuse(self):
|
||||
grid_thw = self.split_grid(self.grid_thw)
|
||||
image_bs = len(grid_thw) // 3
|
||||
image_type_ids = [0] * image_bs
|
||||
|
||||
# 随机拼接input_ids: [txt0+img1+tx1+img2]
|
||||
input_ids = [2] * 19
|
||||
img1 = [self.image_token_id] * 100 * 3
|
||||
txt1 = [3] * 19
|
||||
img2 = [self.image_token_id] * 200 * 3
|
||||
input_ids.extend(img1)
|
||||
input_ids.extend(txt1)
|
||||
input_ids.extend(img2)
|
||||
|
||||
seq_len = len(input_ids)
|
||||
input_ids_tensor = paddle.to_tensor(input_ids, dtype="int64")
|
||||
image_type_ids_tensor = paddle.to_tensor(image_type_ids, dtype="int32")
|
||||
is_image_token = paddle.where(input_ids_tensor == self.image_token_id, 1, 0)
|
||||
image_token_sum = paddle.cumsum(is_image_token) # 前缀和
|
||||
image_token_sum = paddle.concat([paddle.zeros([1], dtype="int64"), image_token_sum])
|
||||
|
||||
grid_thw_tensor = paddle.to_tensor(grid_thw, dtype="int64")
|
||||
image_chunk_selections, split_fuse_cur_seq_lens = get_mm_split_fuse(
|
||||
input_ids_tensor.cpu(),
|
||||
image_type_ids_tensor.cast("int32").cpu(),
|
||||
image_token_sum.cast("int32").cpu(),
|
||||
grid_thw_tensor.cpu(),
|
||||
self.image_token_id,
|
||||
image_bs,
|
||||
0,
|
||||
seq_len,
|
||||
self.split_fuse_img_size,
|
||||
self.split_fuse_text_size,
|
||||
self.max_seq_len,
|
||||
)
|
||||
|
||||
# Verify the outputs are not None
|
||||
self.assertIsNotNone(image_chunk_selections)
|
||||
self.assertIsNotNone(split_fuse_cur_seq_lens)
|
||||
|
||||
# Verify the shapes are as expected
|
||||
self.assertEqual(len(image_chunk_selections.shape), 1)
|
||||
self.assertEqual(len(split_fuse_cur_seq_lens.shape), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
grid_thw = [[6, 20, 20], [6, 40, 20]]
|
||||
grid_thw = split_grid(grid_thw)
|
||||
image_bs = len(grid_thw) // 3
|
||||
image_type_ids = [0] * image_bs
|
||||
# 随机拼接input_ids: [txt0+img1+tx1+img2]
|
||||
input_ids = [2] * 19
|
||||
img1 = [100295] * 100 * 3
|
||||
txt1 = [3] * 19
|
||||
img2 = [100295] * 200 * 3
|
||||
input_ids.extend(img1)
|
||||
input_ids.extend(txt1)
|
||||
input_ids.extend(img2)
|
||||
|
||||
split_fuse_img_size = 16
|
||||
split_fuse_text_size = 384 # 1024
|
||||
|
||||
seq_len = len(input_ids)
|
||||
input_ids_tensor = paddle.to_tensor(input_ids, dtype="int64")
|
||||
image_type_ids_tensor = paddle.to_tensor(image_type_ids, dtype="int32")
|
||||
is_image_token = paddle.where(input_ids_tensor == 100295, 1, 0)
|
||||
image_token_sum = paddle.cumsum(is_image_token) # 前缀和
|
||||
image_token_sum = paddle.concat([paddle.zeros([1], dtype="int64"), image_token_sum])
|
||||
|
||||
grid_thw_tensor = paddle.to_tensor(grid_thw, dtype="int64")
|
||||
image_chunk_selections, split_fuse_cur_seq_lens = get_mm_split_fuse(
|
||||
input_ids_tensor.cpu(),
|
||||
image_type_ids_tensor.cast("int32").cpu(),
|
||||
image_token_sum.cast("int32").cpu(),
|
||||
grid_thw_tensor.cpu(),
|
||||
100295,
|
||||
image_bs,
|
||||
0,
|
||||
seq_len,
|
||||
split_fuse_img_size,
|
||||
split_fuse_text_size,
|
||||
2048,
|
||||
)
|
||||
|
||||
print("seq_len: ", seq_len)
|
||||
print("grid_thw", grid_thw_tensor)
|
||||
print("image_chunk_selections: ", image_chunk_selections)
|
||||
print("split_fuse_cur_seq_lens: ", split_fuse_cur_seq_lens)
|
||||
unittest.main()
|
||||
|
||||
@@ -13,40 +13,58 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""UT for get_token_penalty"""
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import get_token_penalty_once
|
||||
|
||||
paddle.seed(2023)
|
||||
|
||||
pre_ids = paddle.randint(0, 10000, (8, 1000))
|
||||
pre_ids[:, -1] = pre_ids[:, -2]
|
||||
print(pre_ids)
|
||||
logits = paddle.rand(shape=[8, 10000], dtype="float16")
|
||||
penalty_scores = np.array([1.2] * 8).astype(np.float16).reshape(-1, 1)
|
||||
penalty_scores = paddle.to_tensor(penalty_scores)
|
||||
class TestTokenPenalty(unittest.TestCase):
|
||||
def setUp(self):
|
||||
paddle.seed(2023)
|
||||
self.pre_ids = paddle.randint(0, 10000, (8, 1000))
|
||||
self.pre_ids[:, -1] = self.pre_ids[:, -2]
|
||||
self.logits = paddle.rand(shape=[8, 10000], dtype="float16")
|
||||
self.penalty_scores = np.array([1.2] * 8).astype(np.float16).reshape(-1, 1)
|
||||
self.penalty_scores = paddle.to_tensor(self.penalty_scores)
|
||||
|
||||
print("logits[0][pre_ids[0]]: ", logits[0][pre_ids[0]])
|
||||
res = get_token_penalty_once(pre_ids, logits, penalty_scores)
|
||||
for i in range(8):
|
||||
print(f"res[{i}]:{res[i][pre_ids[i]]}")
|
||||
def test_token_penalty_once(self):
|
||||
res = get_token_penalty_once(self.pre_ids, self.logits, self.penalty_scores)
|
||||
|
||||
# 验证结果形状
|
||||
self.assertEqual(res.shape, self.logits.shape)
|
||||
|
||||
# 验证惩罚逻辑
|
||||
for i in range(8):
|
||||
original_values = self.logits[i][self.pre_ids[i]]
|
||||
penalized_values = res[i][self.pre_ids[i]]
|
||||
# 检查是否应用了惩罚
|
||||
for orig, penal in zip(original_values.numpy(), penalized_values.numpy()):
|
||||
if orig < 0:
|
||||
self.assertLess(penal, orig, "负值应该乘以惩罚因子")
|
||||
else:
|
||||
self.assertLess(penal, orig, "正值应该除以惩罚因子")
|
||||
|
||||
def test_compare_with_naive_implementation(self):
|
||||
res = get_token_penalty_once(self.pre_ids, self.logits, self.penalty_scores)
|
||||
|
||||
# 朴素实现
|
||||
score = paddle.index_sample(self.logits, self.pre_ids)
|
||||
score = paddle.where(score < 0, score * self.penalty_scores, score / self.penalty_scores)
|
||||
|
||||
bsz = paddle.shape(self.logits)[0]
|
||||
bsz_range = paddle.arange(start=bsz * 0, end=bsz, step=bsz / bsz, name="bsz_range", dtype="int64").unsqueeze(
|
||||
-1
|
||||
)
|
||||
input_ids = self.pre_ids + bsz_range * self.logits.shape[-1]
|
||||
res2 = paddle.scatter(self.logits.flatten(), input_ids.flatten(), score.flatten()).reshape(self.logits.shape)
|
||||
|
||||
# 比较两种实现的结果差异
|
||||
max_diff = (res - res2).abs().max().item()
|
||||
self.assertLess(max_diff, 1e-5)
|
||||
|
||||
|
||||
input_ids = pre_ids
|
||||
score = paddle.index_sample(logits, input_ids)
|
||||
score = paddle.where(score < 0, score * penalty_scores, score / penalty_scores)
|
||||
|
||||
bsz = paddle.shape(logits)[0] # TODO: Bsz as input for inference with dynamic batch_size
|
||||
bsz_range = paddle.arange(start=bsz * 0, end=bsz, step=bsz / bsz, name="bsz_range", dtype="int64").unsqueeze(-1)
|
||||
input_ids = input_ids + bsz_range * logits.shape[-1]
|
||||
res2 = paddle.scatter(logits.flatten(), input_ids.flatten(), score.flatten()).reshape(logits.shape)
|
||||
print("-------------------------------------------")
|
||||
for i in range(8):
|
||||
print(res2[i][pre_ids[i]])
|
||||
|
||||
print("res_sub:")
|
||||
for i in range(8):
|
||||
print(res2[i][pre_ids[i]] - res[i][pre_ids[i]])
|
||||
|
||||
print((res.numpy() - res2.numpy()).sum())
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import math
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
@@ -10,351 +11,331 @@ from fastdeploy.model_executor.layers.attention.ops import (
|
||||
get_block_shape_and_split_kv_block,
|
||||
)
|
||||
|
||||
paddle.seed(0)
|
||||
|
||||
max_seq_len = 32768
|
||||
encoder_max_partition_size = max_seq_len
|
||||
max_partition_size = max_seq_len
|
||||
class TestTreeMask(unittest.TestCase):
|
||||
def setUp(self):
|
||||
paddle.seed(0)
|
||||
self.max_seq_len = 32768
|
||||
self.encoder_max_partition_size = self.max_seq_len
|
||||
self.max_partition_size = self.max_seq_len
|
||||
|
||||
max_dec_len = 1024
|
||||
bsz = 64
|
||||
run_time = 10
|
||||
warm_up = 2
|
||||
block_size = 64
|
||||
head_dim = 128
|
||||
num_q_head = 20
|
||||
num_kv_head = 4
|
||||
dtype = "bfloat16"
|
||||
self.max_dec_len = 1024
|
||||
self.bsz = 64
|
||||
self.run_time = 10
|
||||
self.warm_up = 2
|
||||
self.block_size = 64
|
||||
self.head_dim = 128
|
||||
self.num_q_head = 20
|
||||
self.num_kv_head = 4
|
||||
self.dtype = "bfloat16"
|
||||
|
||||
rope_3d = False
|
||||
use_neox_rotary_style = False
|
||||
CURRENT_Q = [None]
|
||||
TOTAL_K = []
|
||||
TOTAL_V = []
|
||||
self.rope_3d = False
|
||||
self.use_neox_rotary_style = False
|
||||
self.CURRENT_Q = [None]
|
||||
self.TOTAL_K = []
|
||||
self.TOTAL_V = []
|
||||
|
||||
# Initialize cache and block tables
|
||||
block_num_per_seq = (self.max_seq_len + self.block_size - 1) // self.block_size
|
||||
max_block_num = block_num_per_seq * self.bsz
|
||||
cache_shape = (
|
||||
max_block_num,
|
||||
self.num_kv_head,
|
||||
self.block_size,
|
||||
self.head_dim,
|
||||
)
|
||||
|
||||
def split_qkv(qkv, bsz, seq_len, num_q_head, num_kv_head, head_dim):
|
||||
# [token_num, (num_q_head + 2 * num_kv_head) * head_dim]
|
||||
qkv = qkv.reshape([bsz, seq_len, -1, head_dim])
|
||||
q = qkv[:, :, :num_q_head, :]
|
||||
# [bsz, seq_len, num_q_head, head_dim]
|
||||
CURRENT_Q[0] = q
|
||||
self.cache_k = paddle.zeros(shape=cache_shape).astype(self.dtype)
|
||||
self.cache_v = paddle.zeros(shape=cache_shape).astype(self.dtype)
|
||||
|
||||
# [bsz, seq_len, num_kv_head, head_dim]
|
||||
k = qkv[:, :, num_q_head : num_q_head + num_kv_head, :]
|
||||
TOTAL_K.append(k)
|
||||
self.block_tables = paddle.zeros(shape=(self.bsz, block_num_per_seq), dtype="int32")
|
||||
|
||||
# [bsz, seq_len, num_kv_head, head_dim]
|
||||
v = qkv[:, :, num_q_head + num_kv_head :, :]
|
||||
TOTAL_V.append(v)
|
||||
free_list = list(range(max_block_num - 1, -1, -1))
|
||||
|
||||
for i in range(self.bsz):
|
||||
need_block_num = (self.max_seq_len + self.block_size - 1) // self.block_size
|
||||
for j in range(need_block_num):
|
||||
block_id = free_list.pop()
|
||||
self.block_tables[i, j] = block_id
|
||||
|
||||
def get_padding_offset(bsz, seq_lens_this_time, seq_lens_decoder):
|
||||
batch_id_per_token = []
|
||||
cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32")
|
||||
cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32")
|
||||
cum_seq_len_q = 0
|
||||
cum_seq_len_k = 0
|
||||
for i in range(bsz):
|
||||
seq_len_now = seq_lens_this_time[i]
|
||||
seq_len_dec_now = seq_lens_decoder[i]
|
||||
for j in range(seq_len_now):
|
||||
batch_id_per_token.append(i)
|
||||
cum_seq_len_q += seq_len_now
|
||||
cum_seq_len_k += seq_len_now + seq_len_dec_now
|
||||
cu_seqlens_q[i + 1] = cum_seq_len_q
|
||||
cu_seqlens_k[i + 1] = cum_seq_len_k
|
||||
return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k
|
||||
def tearDown(self):
|
||||
self.CURRENT_Q = [None]
|
||||
self.TOTAL_K = []
|
||||
self.TOTAL_V = []
|
||||
|
||||
def split_qkv(self, qkv, bsz, seq_len):
|
||||
qkv = qkv.reshape([bsz, seq_len, -1, self.head_dim])
|
||||
q = qkv[:, :, : self.num_q_head, :]
|
||||
self.CURRENT_Q[0] = q
|
||||
|
||||
# block_table
|
||||
block_num_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
max_block_num = block_num_per_seq * bsz
|
||||
cache_shape = (
|
||||
max_block_num,
|
||||
num_kv_head,
|
||||
block_size,
|
||||
head_dim,
|
||||
)
|
||||
k = qkv[:, :, self.num_q_head : self.num_q_head + self.num_kv_head, :]
|
||||
self.TOTAL_K.append(k)
|
||||
|
||||
cache_k = paddle.zeros(shape=cache_shape).astype(dtype)
|
||||
cache_v = paddle.zeros(shape=cache_shape).astype(dtype)
|
||||
v = qkv[:, :, self.num_q_head + self.num_kv_head :, :]
|
||||
self.TOTAL_V.append(v)
|
||||
|
||||
block_tables = paddle.zeros(shape=(bsz, block_num_per_seq), dtype="int32")
|
||||
def get_padding_offset(self, bsz, seq_lens_this_time, seq_lens_decoder):
|
||||
batch_id_per_token = []
|
||||
cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32")
|
||||
cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32")
|
||||
cum_seq_len_q = 0
|
||||
cum_seq_len_k = 0
|
||||
for i in range(bsz):
|
||||
seq_len_now = seq_lens_this_time[i]
|
||||
seq_len_dec_now = seq_lens_decoder[i]
|
||||
for j in range(seq_len_now):
|
||||
batch_id_per_token.append(i)
|
||||
cum_seq_len_q += seq_len_now
|
||||
cum_seq_len_k += seq_len_now + seq_len_dec_now
|
||||
cu_seqlens_q[i + 1] = cum_seq_len_q
|
||||
cu_seqlens_k[i + 1] = cum_seq_len_k
|
||||
return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k
|
||||
|
||||
free_list = list(range(max_block_num - 1, -1, -1))
|
||||
def ref_attention(self, q, k, v, mask):
|
||||
q = q.transpose([0, 2, 1, 3])
|
||||
if len(k) > 1:
|
||||
k = paddle.concat(k, axis=1)
|
||||
else:
|
||||
k = k[0]
|
||||
k = k.transpose([0, 2, 1, 3])
|
||||
if len(v) > 1:
|
||||
v = paddle.concat(v, axis=1)
|
||||
else:
|
||||
v = v[0]
|
||||
v = v.transpose([0, 2, 1, 3])
|
||||
total_len = k.shape[2]
|
||||
|
||||
for i in range(bsz):
|
||||
need_block_num = (max_seq_len + block_size - 1) // block_size
|
||||
for j in range(need_block_num):
|
||||
block_id = free_list.pop()
|
||||
block_tables[i, j] = block_id
|
||||
scores = (
|
||||
q.reshape([self.bsz, self.num_kv_head, -1, self.head_dim])
|
||||
@ k.transpose([0, 1, 3, 2])
|
||||
* (1.0 / math.sqrt(self.head_dim))
|
||||
)
|
||||
scores = scores.reshape([self.bsz, self.num_q_head, -1, total_len])
|
||||
|
||||
if mask is not None:
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0).unsqueeze(0)
|
||||
elif mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
scores = paddle.add(scores, mask)
|
||||
weights = F.softmax(scores, axis=-1)
|
||||
|
||||
def ref_attention(q, k, v, num_q_head, num_kv_head, head_dim, mask):
|
||||
q = q.transpose([0, 2, 1, 3])
|
||||
if len(k) > 1:
|
||||
k = paddle.concat(k, axis=1)
|
||||
else:
|
||||
k = k[0]
|
||||
k = k.transpose([0, 2, 1, 3])
|
||||
if len(v) > 1:
|
||||
v = paddle.concat(v, axis=1)
|
||||
else:
|
||||
v = v[0]
|
||||
v = v.transpose([0, 2, 1, 3])
|
||||
total_len = k.shape[2]
|
||||
o = weights.reshape([self.bsz, self.num_kv_head, -1, total_len]) @ v
|
||||
return (
|
||||
o.reshape([self.bsz, self.num_q_head, -1, self.head_dim])
|
||||
.transpose([0, 2, 1, 3])
|
||||
.reshape([-1, self.num_q_head, self.head_dim])
|
||||
)
|
||||
|
||||
scores = q.reshape([bsz, num_kv_head, -1, head_dim]) @ k.transpose([0, 1, 3, 2]) * (1.0 / math.sqrt(head_dim))
|
||||
scores = scores.reshape([bsz, num_q_head, -1, total_len])
|
||||
def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None):
|
||||
if prefill:
|
||||
seq_lens_enc = [
|
||||
q_len,
|
||||
] * self.bsz
|
||||
else:
|
||||
seq_lens_enc = [
|
||||
0,
|
||||
] * self.bsz
|
||||
|
||||
if mask is not None:
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0).unsqueeze(0) # [1,1,q_len,kv_len]
|
||||
elif mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1) # [bsz,1,q_len,kv_len]
|
||||
scores = paddle.add(scores, mask)
|
||||
weights = F.softmax(scores, axis=-1)
|
||||
|
||||
o = weights.reshape([bsz, num_kv_head, -1, total_len]) @ v
|
||||
return o.reshape([bsz, num_q_head, -1, head_dim]).transpose([0, 2, 1, 3]).reshape([-1, num_q_head, head_dim])
|
||||
|
||||
|
||||
def clear_param():
|
||||
global CURRENT_Q, TOTAL_K, TOTAL_V
|
||||
CURRENT_Q = [None]
|
||||
TOTAL_K = []
|
||||
TOTAL_V = []
|
||||
|
||||
|
||||
def test_append_c16_attention(q_len, kv_len, prefill=False, attn_mask=None):
|
||||
if prefill:
|
||||
seq_lens_enc = [
|
||||
seq_lens_dec = [
|
||||
kv_len,
|
||||
] * self.bsz
|
||||
seq_lens_cur = [
|
||||
q_len,
|
||||
] * bsz
|
||||
else:
|
||||
seq_lens_enc = [
|
||||
0,
|
||||
] * bsz
|
||||
] * self.bsz
|
||||
token_num = sum(seq_lens_cur)
|
||||
decoder_step_token_num = 1 if prefill else q_len
|
||||
|
||||
seq_lens_dec = [
|
||||
kv_len,
|
||||
] * bsz
|
||||
seq_lens_cur = [
|
||||
q_len,
|
||||
] * bsz
|
||||
token_num = sum(seq_lens_cur)
|
||||
decoder_step_token_num = 1 if prefill else q_len
|
||||
seq_lens_encoder = paddle.to_tensor(seq_lens_enc, "int32")
|
||||
seq_lens_this_time = paddle.to_tensor(seq_lens_cur, "int32")
|
||||
seq_lens_decoder = paddle.to_tensor(seq_lens_dec, "int32")
|
||||
|
||||
seq_lens_encoder = paddle.to_tensor(seq_lens_enc, "int32")
|
||||
seq_lens_this_time = paddle.to_tensor(seq_lens_cur, "int32")
|
||||
seq_lens_decoder = paddle.to_tensor(seq_lens_dec, "int32")
|
||||
batch_id_per_token, cu_seqlens_q, cu_seqlens_k = self.get_padding_offset(
|
||||
self.bsz, seq_lens_this_time, seq_lens_decoder
|
||||
)
|
||||
|
||||
batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(bsz, seq_lens_this_time, seq_lens_decoder)
|
||||
qkv_varlen_shape = [token_num, (self.num_q_head + 2 * self.num_kv_head) * self.head_dim]
|
||||
rotary_embs_shape = [
|
||||
2,
|
||||
1,
|
||||
self.max_seq_len,
|
||||
1,
|
||||
self.head_dim if self.use_neox_rotary_style else self.head_dim // 2,
|
||||
]
|
||||
|
||||
# random data
|
||||
qkv_varlen_shape = [token_num, (num_q_head + 2 * num_kv_head) * head_dim]
|
||||
qkv = paddle.randn(shape=qkv_varlen_shape).astype(self.dtype)
|
||||
self.split_qkv(qkv, self.bsz, q_len)
|
||||
|
||||
rotary_embs_shape = [2, 1, max_seq_len, 1, head_dim if use_neox_rotary_style else head_dim // 2]
|
||||
# qkv_bias_shape = [num_q_head + 2 * num_kv_head, head_dim]
|
||||
rotary_embs = paddle.randn(shape=rotary_embs_shape).astype("float32")
|
||||
rotary_embs[0, :, :, :, :] = 1
|
||||
rotary_embs[1, :, :, :, :] = 0
|
||||
|
||||
qkv = paddle.randn(shape=qkv_varlen_shape).astype(dtype)
|
||||
cache_k_scale = None
|
||||
cache_v_scale = None
|
||||
cache_k_out_scale = None
|
||||
cache_v_out_scale = None
|
||||
|
||||
# save q, k, v for ref
|
||||
split_qkv(qkv, bsz, q_len, num_q_head, num_kv_head, head_dim)
|
||||
encoder_block_shape_q = 64
|
||||
decoder_block_shape_q = 16
|
||||
|
||||
rotary_embs = paddle.randn(shape=rotary_embs_shape).astype("float32")
|
||||
rotary_embs[0, :, :, :, :] = 1
|
||||
rotary_embs[1, :, :, :, :] = 0
|
||||
|
||||
# qkv_scale = None
|
||||
# qkv_bias = None
|
||||
|
||||
cache_k_scale = None
|
||||
cache_v_scale = None
|
||||
cache_k_out_scale = None
|
||||
cache_v_out_scale = None
|
||||
# shift_bias = None
|
||||
# smooth_weight = None
|
||||
|
||||
encoder_block_shape_q = 64
|
||||
decoder_block_shape_q = 16
|
||||
|
||||
decode_max_tile_size = (
|
||||
bsz
|
||||
* (decoder_step_token_num * (num_q_head // num_kv_head) + decoder_block_shape_q - 1)
|
||||
/ decoder_block_shape_q
|
||||
)
|
||||
decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
|
||||
decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
|
||||
decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
|
||||
max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
|
||||
paddle.device.synchronize()
|
||||
(
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
max_len_kv,
|
||||
) = get_block_shape_and_split_kv_block(
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
max_len_tensor_cpu,
|
||||
encoder_block_shape_q,
|
||||
decoder_block_shape_q,
|
||||
num_q_head // num_kv_head,
|
||||
block_size,
|
||||
decoder_step_token_num,
|
||||
)
|
||||
s_time = 0
|
||||
for i in range(run_time + warm_up):
|
||||
if i == warm_up:
|
||||
s_time = time.time()
|
||||
out = append_attention(
|
||||
qkv,
|
||||
cache_k,
|
||||
cache_v,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
decode_max_tile_size = (
|
||||
self.bsz
|
||||
* (decoder_step_token_num * (self.num_q_head // self.num_kv_head) + decoder_block_shape_q - 1)
|
||||
/ decoder_block_shape_q
|
||||
)
|
||||
decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
|
||||
decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
|
||||
decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
|
||||
max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
|
||||
paddle.device.synchronize()
|
||||
(
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
max_len_kv,
|
||||
) = get_block_shape_and_split_kv_block(
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
max_len_tensor_cpu,
|
||||
max_len_kv,
|
||||
rotary_embs,
|
||||
attn_mask, # attn_mask
|
||||
None,
|
||||
None,
|
||||
cache_k_scale,
|
||||
cache_v_scale,
|
||||
cache_k_out_scale,
|
||||
cache_v_out_scale,
|
||||
None, # cache_k_zp
|
||||
None, # cache_v_zp
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
1e-6,
|
||||
"bf16",
|
||||
"none", # cache_quant_type
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_seq_len,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0, # out_linear_in_scale
|
||||
encoder_block_shape_q, # encoder_block_shape_q
|
||||
decoder_block_shape_q, # decoder_block_shape_q
|
||||
max_partition_size, # max_partition_size
|
||||
encoder_max_partition_size, # encoder_max_partition_size
|
||||
decoder_step_token_num, # speculate_max_draft_token_num
|
||||
True, # causal
|
||||
decoder_step_token_num > 1, # speculate_decoder
|
||||
encoder_block_shape_q,
|
||||
decoder_block_shape_q,
|
||||
self.num_q_head // self.num_kv_head,
|
||||
self.block_size,
|
||||
decoder_step_token_num,
|
||||
)
|
||||
paddle.device.synchronize()
|
||||
e_time = time.time()
|
||||
print(f"mean infer time: {np.mean((e_time - s_time) * 1000 / run_time):.2f}")
|
||||
return out[0].reshape([token_num, num_q_head, head_dim])
|
||||
s_time = 0
|
||||
for i in range(self.run_time + self.warm_up):
|
||||
if i == self.warm_up:
|
||||
s_time = time.time()
|
||||
out = append_attention(
|
||||
qkv,
|
||||
self.cache_k,
|
||||
self.cache_v,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
self.block_tables,
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
max_len_tensor_cpu,
|
||||
max_len_kv,
|
||||
rotary_embs,
|
||||
attn_mask,
|
||||
None,
|
||||
None,
|
||||
cache_k_scale,
|
||||
cache_v_scale,
|
||||
cache_k_out_scale,
|
||||
cache_v_out_scale,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
1e-6,
|
||||
"bf16",
|
||||
"none",
|
||||
self.use_neox_rotary_style,
|
||||
self.rope_3d,
|
||||
self.max_seq_len,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0,
|
||||
encoder_block_shape_q,
|
||||
decoder_block_shape_q,
|
||||
self.max_partition_size,
|
||||
self.encoder_max_partition_size,
|
||||
decoder_step_token_num,
|
||||
True,
|
||||
decoder_step_token_num > 1,
|
||||
)
|
||||
paddle.device.synchronize()
|
||||
e_time = time.time()
|
||||
print(f"mean infer time: {np.mean((e_time - s_time) * 1000 / self.run_time):.2f}")
|
||||
return out[0].reshape([token_num, self.num_q_head, self.head_dim])
|
||||
|
||||
def test_naive_speculative_decoding(self):
|
||||
prefill_len = 8192
|
||||
dec_len_q = 5
|
||||
total_len = prefill_len + dec_len_q
|
||||
mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
|
||||
mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
|
||||
self.run_append_c16_attention(prefill_len, 0, True)
|
||||
dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False)
|
||||
|
||||
def test_naive_speculative_decoding(num_q_head, num_kv_head, head_dim):
|
||||
prefill_len = 8192
|
||||
dec_len_q = 5
|
||||
total_len = prefill_len + dec_len_q
|
||||
mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
|
||||
mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
|
||||
test_append_c16_attention(prefill_len, 0, True)
|
||||
dec_out = test_append_c16_attention(dec_len_q, prefill_len, False)
|
||||
ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask)
|
||||
np.testing.assert_allclose(
|
||||
ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
|
||||
)
|
||||
|
||||
ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask)
|
||||
np.testing.assert_allclose(
|
||||
ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
|
||||
)
|
||||
def test_mask(self):
|
||||
prefill_len = 8192
|
||||
dec_len_q = 5
|
||||
total_len = prefill_len + dec_len_q
|
||||
mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
|
||||
mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
|
||||
|
||||
mask_append_attn = mask[:, :, prefill_len:]
|
||||
mask_append_attn = paddle.where(
|
||||
mask_append_attn == 1,
|
||||
paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
|
||||
paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
|
||||
)
|
||||
|
||||
def test_mask(num_q_head, num_kv_head, head_dim):
|
||||
prefill_len = 8192
|
||||
dec_len_q = 5
|
||||
total_len = prefill_len + dec_len_q
|
||||
mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
|
||||
mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
|
||||
self.run_append_c16_attention(prefill_len, 0, True)
|
||||
dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
|
||||
|
||||
mask_append_attn = mask[:, :, prefill_len:]
|
||||
mask_append_attn = paddle.where(
|
||||
mask_append_attn == 1,
|
||||
paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
|
||||
paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
|
||||
)
|
||||
ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask_ref)
|
||||
|
||||
test_append_c16_attention(prefill_len, 0, True)
|
||||
dec_out = test_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
|
||||
np.testing.assert_allclose(
|
||||
ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
|
||||
)
|
||||
|
||||
ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask_ref)
|
||||
def test_tree_mask(self):
|
||||
prefill_len = 8192
|
||||
dec_len_q = 5
|
||||
total_len = prefill_len + dec_len_q
|
||||
mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
|
||||
mask[:, 2, prefill_len + 1] = 0
|
||||
mask[:, 3, prefill_len + 2] = 0
|
||||
mask[:, 4, prefill_len + 1] = 0
|
||||
mask[:, 4, prefill_len + 3] = 0
|
||||
|
||||
np.testing.assert_allclose(
|
||||
ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
|
||||
)
|
||||
mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
|
||||
|
||||
mask_append_attn = mask[:, :, prefill_len:]
|
||||
mask_append_attn = paddle.where(
|
||||
mask_append_attn == 1,
|
||||
paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
|
||||
paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
|
||||
)
|
||||
|
||||
def test_tree_mask(num_q_head, num_kv_head, head_dim):
|
||||
# tree
|
||||
# [N, N+1, N+1, N+2, N+2]
|
||||
# N [0, -inf, -inf, -inf, -inf]
|
||||
# N+1 [0, 0, -inf, -inf, -inf]
|
||||
# N+1 [0, -inf, 0, -inf, -inf]
|
||||
# N+2 [0, 0, -inf, 0, -inf]
|
||||
# N+2 [0, -inf, 0, -inf, 0]
|
||||
prefill_len = 8192
|
||||
dec_len_q = 5
|
||||
total_len = prefill_len + dec_len_q
|
||||
mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
|
||||
mask[:, 2, prefill_len + 1] = 0
|
||||
mask[:, 3, prefill_len + 2] = 0
|
||||
mask[:, 4, prefill_len + 1] = 0
|
||||
mask[:, 4, prefill_len + 3] = 0
|
||||
|
||||
mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
|
||||
|
||||
mask_append_attn = mask[:, :, prefill_len:]
|
||||
mask_append_attn = paddle.where(
|
||||
mask_append_attn == 1,
|
||||
paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
|
||||
paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
|
||||
)
|
||||
|
||||
test_append_c16_attention(prefill_len, 0, True)
|
||||
dec_out = test_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
|
||||
ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask_ref)
|
||||
np.testing.assert_allclose(
|
||||
ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
|
||||
)
|
||||
self.run_append_c16_attention(prefill_len, 0, True)
|
||||
dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
|
||||
ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask_ref)
|
||||
np.testing.assert_allclose(
|
||||
ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
test_naive_speculative_decoding(num_q_head, num_kv_head, head_dim)
|
||||
clear_param()
|
||||
|
||||
test_mask(num_q_head, num_kv_head, head_dim)
|
||||
clear_param()
|
||||
|
||||
test_tree_mask(num_q_head, num_kv_head, head_dim)
|
||||
unittest.main()
|
||||
|
||||
@@ -12,92 +12,97 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm, w4afp8_gemm_weight_convert
|
||||
|
||||
|
||||
def w4afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_dequant_scale, BATCH, N):
|
||||
all_tokens = int(tokens.sum())
|
||||
out = paddle.zeros([all_tokens, N], dtype="bfloat16")
|
||||
pre_fix_token = 0
|
||||
for i in range(BATCH):
|
||||
input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :]
|
||||
weight = (weight_quant[i] - 7.0) * weight_dequant_scale[i]
|
||||
out_i = paddle.matmul(input, weight.astype("bfloat16"), transpose_y=True)
|
||||
out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i
|
||||
pre_fix_token += tokens[i]
|
||||
return out
|
||||
class TestW4AFP8GEMM(unittest.TestCase):
|
||||
def setUp(self):
|
||||
paddle.seed(0)
|
||||
self.tokens_per_group = 256
|
||||
self.N = 256
|
||||
self.K = 256
|
||||
self.BATCH = 1
|
||||
self.TokenPadding = 0
|
||||
|
||||
tokens = [self.tokens_per_group] * self.BATCH
|
||||
self.tokens_perfix_sum = np.cumsum(tokens)
|
||||
|
||||
self.tokens = paddle.to_tensor(tokens, dtype="int64")
|
||||
self.tokens_perfix_sum = paddle.to_tensor(self.tokens_perfix_sum, dtype="int64")
|
||||
self.all_tokens = int(self.tokens.sum())
|
||||
|
||||
self.input_fp8 = paddle.randn([self.all_tokens, self.K], dtype="bfloat16").astype(paddle.float8_e4m3fn)
|
||||
self.input_bf16 = self.input_fp8.astype("bfloat16")
|
||||
self.weight = paddle.randn([self.BATCH, self.N, self.K], dtype="bfloat16") / 10
|
||||
|
||||
self.weight_scale = 7 / self.weight.abs().max(axis=-1).reshape([self.BATCH, self.N, 1])
|
||||
self.weight_quant = (self.weight * self.weight_scale).astype("int") + 7
|
||||
self.weight_quant = paddle.clip(self.weight_quant, 0, 14)
|
||||
self.weight_quant = self.weight_quant.astype("bfloat16")
|
||||
self.weight_dequant_scale = 1 / self.weight_scale.astype("float32")
|
||||
self.input_row_sum = self.input_bf16.sum(axis=1) * -7 / 512
|
||||
self.max_tokens = int(self.tokens.max())
|
||||
|
||||
def w4afp8_gemm_naive(self, input_bf16, weight_quant, tokens, weight_dequant_scale):
|
||||
all_tokens = int(tokens.sum())
|
||||
out = paddle.zeros([all_tokens, self.N], dtype="bfloat16")
|
||||
pre_fix_token = 0
|
||||
for i in range(self.BATCH):
|
||||
input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :]
|
||||
weight = (weight_quant[i] - 7.0) * weight_dequant_scale[i]
|
||||
out_i = paddle.matmul(input, weight.astype("bfloat16"), transpose_y=True)
|
||||
out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i
|
||||
pre_fix_token += tokens[i]
|
||||
return out
|
||||
|
||||
def permute_scale(self, weight_scale):
|
||||
weight_scale = weight_scale.reshape([self.BATCH, self.N])
|
||||
temp = paddle.zeros([16])
|
||||
for b in range(self.BATCH):
|
||||
for n in range(0, self.N, 16):
|
||||
temp[:] = weight_scale[b, n : n + 16]
|
||||
for j in range(0, 16, 2):
|
||||
weight_scale[b, n + j] = temp[j // 2]
|
||||
weight_scale[b, n + j + 1] = temp[j // 2 + 8]
|
||||
return weight_scale
|
||||
|
||||
def test_w4afp8_gemm(self):
|
||||
out_naive = self.w4afp8_gemm_naive(self.input_bf16, self.weight_quant, self.tokens, self.weight_dequant_scale)
|
||||
|
||||
weight_dequant_scale = paddle.to_tensor(self.permute_scale(self.weight_dequant_scale) * 512)
|
||||
weight_int4 = w4afp8_gemm_weight_convert(self.weight_quant.astype("uint8").cpu())
|
||||
|
||||
if self.TokenPadding == 0:
|
||||
out_cuda = w4afp8_gemm(
|
||||
self.input_fp8,
|
||||
weight_int4.cuda(),
|
||||
self.tokens_perfix_sum,
|
||||
self.input_row_sum.astype("float32"),
|
||||
weight_dequant_scale.astype("float32"),
|
||||
int(self.TokenPadding),
|
||||
self.max_tokens,
|
||||
True,
|
||||
)
|
||||
else:
|
||||
out_cuda = w4afp8_gemm(
|
||||
self.input_fp8,
|
||||
weight_int4.cuda(),
|
||||
self.tokens,
|
||||
self.input_row_sum.astype("float32"),
|
||||
weight_dequant_scale.astype("float32"),
|
||||
int(self.TokenPadding),
|
||||
self.max_tokens,
|
||||
True,
|
||||
)
|
||||
|
||||
gap = (out_cuda - out_naive).abs()
|
||||
self.assertLess(float(gap.mean()), 0.07)
|
||||
|
||||
|
||||
def permute_scale(weight_scale):
|
||||
weight_scale = weight_scale.reshape([BATCH, N])
|
||||
temp = paddle.zeros([16])
|
||||
for b in range(BATCH):
|
||||
for n in range(0, N, 16):
|
||||
temp[:] = weight_scale[b, n : n + 16]
|
||||
for j in range(0, 16, 2):
|
||||
weight_scale[b, n + j] = temp[j // 2]
|
||||
weight_scale[b, n + j + 1] = temp[j // 2 + 8]
|
||||
return weight_scale
|
||||
|
||||
|
||||
paddle.seed(0)
|
||||
tokens_per_group = 256
|
||||
N = 256
|
||||
K = 256
|
||||
BATCH = 1
|
||||
TokenPadding = 0
|
||||
|
||||
tokens = [tokens_per_group] * BATCH
|
||||
tokens_perfix_sum = np.cumsum(tokens)
|
||||
|
||||
|
||||
tokens = paddle.to_tensor(tokens, dtype="int64")
|
||||
tokens_perfix_sum = paddle.to_tensor(tokens_perfix_sum, dtype="int64")
|
||||
|
||||
all_tokens = int(tokens.sum())
|
||||
|
||||
input_fp8 = paddle.randn([all_tokens, K], dtype="bfloat16").astype(paddle.float8_e4m3fn)
|
||||
input_bf16 = input_fp8.astype("bfloat16")
|
||||
weight = paddle.randn([BATCH, N, K], dtype="bfloat16") / 10
|
||||
|
||||
weight_scale = 7 / weight.abs().max(axis=-1).reshape([BATCH, N, 1])
|
||||
weight_quant = (weight * weight_scale).astype("int") + 7
|
||||
weight_quant = paddle.clip(weight_quant, 0, 14)
|
||||
weight_quant = weight_quant.astype("bfloat16")
|
||||
weight_dequant_scale = 1 / weight_scale.astype("float32")
|
||||
input_row_sum = input_bf16.sum(axis=1) * -7 / 512
|
||||
max_tokens = int(tokens.max())
|
||||
|
||||
out_naive = w4afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_dequant_scale, BATCH, N)
|
||||
weight_dequant_scale = paddle.to_tensor(permute_scale(weight_dequant_scale) * 512)
|
||||
|
||||
weight_int4 = w4afp8_gemm_weight_convert(weight_quant.astype("uint8").cpu())
|
||||
|
||||
if TokenPadding == 0:
|
||||
out_cuda = w4afp8_gemm(
|
||||
input_fp8,
|
||||
weight_int4.cuda(),
|
||||
tokens_perfix_sum,
|
||||
input_row_sum.astype("float32"),
|
||||
weight_dequant_scale.astype("float32"),
|
||||
int(TokenPadding),
|
||||
max_tokens,
|
||||
True,
|
||||
)
|
||||
else:
|
||||
out_cuda = w4afp8_gemm(
|
||||
input_fp8,
|
||||
weight_int4.cuda(),
|
||||
tokens,
|
||||
input_row_sum.astype("float32"),
|
||||
weight_dequant_scale.astype("float32"),
|
||||
int(TokenPadding),
|
||||
max_tokens,
|
||||
True,
|
||||
)
|
||||
|
||||
gap = (out_cuda - out_naive).abs()
|
||||
assert float(gap.mean()) < 0.07
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user