diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
index 2df1e924f..eb1d4d719 100644
--- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
@@ -647,6 +647,9 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
             model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
             process_weights_after_loading_fn(model_sublayer_name, param)
         if self.tie_word_embeddings:
+            # The lm_head weight is created under LazyGuard and is not initialized by default.
+            if not self.lm_head.linear.weight._is_initialized():
+                self.lm_head.linear.weight.initialize()
             self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
 
     @paddle.no_grad()
diff --git a/scripts/coverage_run.sh b/scripts/coverage_run.sh
index ad3b47a06..6cea12d8c 100644
--- a/scripts/coverage_run.sh
+++ b/scripts/coverage_run.sh
@@ -11,40 +11,6 @@ cd "$run_path" || exit 1
 failed_tests_file="failed_tests.log"
 > "$failed_tests_file"
 
-##################################
-# 执行特殊单测case(不符合unittest/pytest格式)
-##################################
-special_tests=(
-    "graph_optimization/test_cuda_graph_dynamic_subgraph.py"
-    "graph_optimization/test_cuda_graph_spec_decode.py"
-    "layers/test_quant_layer.py"
-    "operators/test_token_penalty.py"
-    "operators/test_split_fuse.py"
-    "operators/test_flash_mask_attn.py"
-    "operators/test_w4afp8_gemm.py"
-    "model_loader/test_load_ernie_vl.py"
-    "operators/test_tree_mask.py"
-)
-
-failed_special=0
-success_special=0
-
-for test_file in "${special_tests[@]}"; do
-    if [ -f "$test_file" ]; then
-        echo "Running special test: $test_file"
-        python -m coverage run --parallel-mode "$test_file"
-        status=$?
-        if [ "$status" -ne 0 ]; then
-            echo "$test_file" >> "$failed_tests_file"
-            failed_special=$((failed_special+1))
-        else
-            success_special=$((success_special+1))
-        fi
-    else
-        echo "Warning: $test_file not found"
-        failed_special=$((failed_special+1))
-    fi
-done
 ##################################
 # 执行 pytest,每个文件单独跑
 ##################################
@@ -78,9 +44,6 @@ echo "Pytest failed: $failed_pytest"
 
-echo "Special tests total: ${#special_tests[@]}"
-echo "Special tests successful: $success_special"
-echo "Special tests failed: $failed_special"
 
-if [ "$failed_pytest" -ne 0 ] || [ "$failed_special" -ne 0 ]; then
+if [ "$failed_pytest" -ne 0 ]; then
     echo "Failed test cases are listed in $failed_tests_file"
     cat "$failed_tests_file"
     exit 8
diff --git a/tests/operators/test_flash_mask_attn.py b/tests/operators/test_flash_mask_attn.py
index 1b2361dc1..07be15769 100644
--- a/tests/operators/test_flash_mask_attn.py
+++ b/tests/operators/test_flash_mask_attn.py
@@ -1,93 +1,99 @@
+import unittest
+
 import numpy as np
 import paddle
 
 from fastdeploy.model_executor.ops.gpu import flash_attention_mask
 
 
-def naive_attn(q_input, k_input, v_input, mask):
-    gqa_group_size = q_input.shape[2] // k_input.shape[2]
+class TestFlashMaskAttention(unittest.TestCase):
+    def setUp(self):
+        self.bsz = 1
+        self.num_head = 8
+        self.num_kv_head = 1
+        self.q_seq_len = 1024
+        self.k_seq_len = 1024
+        self.head_dim = 128
+        np.random.seed(self.q_seq_len)
 
-    q_cur = q_input.transpose([0, 2, 1, 3])
-    k_cur = k_input.transpose([0, 2, 1, 3])
-    v_cur = v_input.transpose([0, 2, 1, 3])
-    out = np.zeros(q_cur.shape, dtype=q_input.dtype)
+    def naive_attn(self, q_input, k_input, v_input, mask):
+        gqa_group_size = q_input.shape[2] // k_input.shape[2]
 
-    for bsz in range(0, q_cur.shape[0]):
-        for hi in range(0, q_cur.shape[1]):
-            qk = np.matmul(q_cur[bsz, hi], k_cur[bsz, hi // gqa_group_size].T) * (1.0 / np.sqrt(q_cur.shape[3]))
-            for i in range(0, qk.shape[0]):
-                qk[i, mask[i] :] = -1000000
+        q_cur = q_input.transpose([0, 2, 1, 3])
+        k_cur = k_input.transpose([0, 2, 1, 3])
+        v_cur = v_input.transpose([0, 2, 1, 3])
+        out = np.zeros(q_cur.shape, dtype=q_input.dtype)
 
-            qk_max = np.expand_dims(qk.max(axis=-1), -1)
-            qk -= qk_max
-            qk = np.exp(qk)
+        for bsz in range(0, q_cur.shape[0]):
+            for hi in range(0, q_cur.shape[1]):
+                qk = np.matmul(q_cur[bsz, hi], k_cur[bsz, hi // gqa_group_size].T) * (1.0 / np.sqrt(q_cur.shape[3]))
+                for i in range(0, qk.shape[0]):
+                    qk[i, mask[i] :] = -1000000
 
-            exp_sum = np.expand_dims(qk.sum(axis=-1), -1)
-            exp_sum_inv = 1.0 / exp_sum
+                qk_max = np.expand_dims(qk.max(axis=-1), -1)
+                qk -= qk_max
+                qk = np.exp(qk)
 
-            out[bsz, hi] = (np.matmul(qk, v_cur[bsz, hi // gqa_group_size]) * exp_sum_inv).astype(q_input.dtype)
-    return out
+                exp_sum = np.expand_dims(qk.sum(axis=-1), -1)
+                exp_sum_inv = 1.0 / exp_sum
+                out[bsz, hi] = (np.matmul(qk, v_cur[bsz, hi // gqa_group_size]) * exp_sum_inv).astype(q_input.dtype)
+        return out
 
-
-def paddle_flash_attn_mask(q_input, k_input, v_input, mask):
-    bsz = q_input.shape[0]
-    cu_seq_q = paddle.arange(bsz + 1) * q_input.shape[1]
-    cu_seq_k = paddle.arange(bsz + 1) * k_input.shape[1]
-    cu_seq_q = cu_seq_q.astype("int32")
-    cu_seq_k = cu_seq_k.astype("int32")
-    seq_len_encoder = paddle.ones(bsz) * q_input.shape[1]
-    seq_len_encoder = seq_len_encoder.astype("int32")
-    q_input = paddle.to_tensor(q_input).astype("bfloat16").reshape([-1, q_input.shape[2], q_input.shape[3]])
-    k_input = paddle.to_tensor(k_input).astype("bfloat16").reshape([-1, k_input.shape[2], k_input.shape[3]])
-    v_input = paddle.to_tensor(v_input).astype("bfloat16").reshape([-1, v_input.shape[2], v_input.shape[3]])
-    v_input_pad = paddle.zeros([v_input.shape[0] + 128, v_input.shape[1], v_input.shape[2]]).astype("bfloat16")
-    v_input_pad[0 : v_input.shape[0]] = v_input
-    mask = paddle.to_tensor(mask).astype("int32")
+    def paddle_flash_attn_mask(self, q_input, k_input, v_input, mask):
+        bsz = q_input.shape[0]
+        cu_seq_q = paddle.arange(bsz + 1) * q_input.shape[1]
+        cu_seq_k = paddle.arange(bsz + 1) * k_input.shape[1]
+        cu_seq_q = cu_seq_q.astype("int32")
+        cu_seq_k = cu_seq_k.astype("int32")
+        seq_len_encoder = paddle.ones(bsz) * q_input.shape[1]
+        seq_len_encoder = seq_len_encoder.astype("int32")
+        q_input = paddle.to_tensor(q_input).astype("bfloat16").reshape([-1, q_input.shape[2], q_input.shape[3]])
+        k_input = paddle.to_tensor(k_input).astype("bfloat16").reshape([-1, k_input.shape[2], k_input.shape[3]])
+        v_input = paddle.to_tensor(v_input).astype("bfloat16").reshape([-1, v_input.shape[2], v_input.shape[3]])
+        v_input_pad = paddle.zeros([v_input.shape[0] + 128, v_input.shape[1], v_input.shape[2]]).astype("bfloat16")
+        v_input_pad[0 : v_input.shape[0]] = v_input
+        mask = paddle.to_tensor(mask).astype("int32")
 
-    out = flash_attention_mask(
-        q_input,
-        k_input,
-        v_input_pad,
-        cu_seq_q,
-        cu_seq_k,
-        seq_len_encoder,
-        mask,
-        int(q_input.shape[1]),
-        int(k_input.shape[1]),
-        int(q_input.shape[2]),
-        int(k_input.shape[0]),
-        int(q_input.shape[0]),
-        int(k_input.shape[0]),
-    )
-    return out
+        out = flash_attention_mask(
+            q_input,
+            k_input,
+            v_input_pad,
+            cu_seq_q,
+            cu_seq_k,
+            seq_len_encoder,
+            mask,
+            int(q_input.shape[1]),
+            int(k_input.shape[1]),
+            int(q_input.shape[2]),
+            int(k_input.shape[0]),
+            int(q_input.shape[0]),
+            int(k_input.shape[0]),
+        )
+        return out
 
+    def test_flash_attention_mask(self):
+        q_input = np.random.normal(0, 0.5, size=(self.bsz, self.q_seq_len, self.num_head, self.head_dim))
+        k_input = np.random.normal(
+            0, 0.5, size=(self.bsz, self.q_seq_len + self.k_seq_len, self.num_kv_head, self.head_dim)
+        )
+        v_input = np.random.normal(
+            0, 0.5, size=(self.bsz, self.q_seq_len + self.k_seq_len, self.num_kv_head, self.head_dim)
+        )
 
-def test(bsz, num_head, num_kv_head, q_seq_len, k_seq_len):
-    head_dim = 128
-    q_input = np.random.normal(0, 0.5, size=(bsz, q_seq_len, num_head, head_dim))
-    k_input = np.random.normal(0, 0.5, size=(bsz, q_seq_len + k_seq_len, num_kv_head, head_dim))
-    v_input = np.random.normal(0, 0.5, size=(bsz, q_seq_len + k_seq_len, num_kv_head, head_dim))
+        random_len = np.random.randint(self.q_seq_len // 2, size=2)
+        text_len = random_len[0]
+        image_len = random_len[1]
 
-    random_len = np.random.randint(q_seq_len // 2, size=2)
+        mask = np.array([i + 1 for i in range(0, self.q_seq_len)]) + self.k_seq_len
+        mask[text_len : text_len + image_len] = text_len + image_len + self.k_seq_len
 
-    text_len = random_len[0]
-    image_len = random_len[1]
+        naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask)
+        paddle_attn_out = self.paddle_flash_attn_mask(q_input, k_input, v_input, mask)
 
-    mask = np.array([i + 1 for i in range(0, q_seq_len)]) + k_seq_len
-
-    mask[text_len : text_len + image_len] = text_len + image_len + k_seq_len
-
-    naive_attn_out = naive_attn(q_input, k_input, v_input, mask)
-    paddle_attn_out = paddle_flash_attn_mask(q_input, k_input, v_input, mask)
-
-    assert float((paddle_attn_out.reshape([-1]) - paddle.to_tensor(naive_attn_out).reshape([-1])).max()) <= 0.05
+        max_diff = float((paddle_attn_out.reshape([-1]) - paddle.to_tensor(naive_attn_out).reshape([-1])).max())
+        self.assertLessEqual(max_diff, 0.05)
 
 
 if __name__ == "__main__":
-    bsz = 1
-    num_head = 8
-    num_kv_head = 1
-    q_seq_len = 1024
-    k_seq_len = 1024
-    np.random.seed(q_seq_len)
-    test(bsz, num_head, num_kv_head, q_seq_len, k_seq_len)
+    unittest.main()
diff --git a/tests/operators/test_perchannel_gemm.py b/tests/operators/test_perchannel_gemm.py
deleted file mode 100644
index 02bc33651..000000000
--- a/tests/operators/test_perchannel_gemm.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""UT for per_channel_fp8_fp8_half_gemm_fused kernel"""
-
-import os
-import unittest
-from itertools import product
-
-import numpy as np
-import paddle
-
-
-class Test(unittest.TestCase):
-    def setUp(self):
-        """
-        Initialize the test environment,
-        including setting random seeds and environment variables.
-        """
-        paddle.seed(2003)
-        os.environ["FLAGS_use_cutlass_device_best_config_path"] = "default"
-
-    def testcase1(self):
-        """
-        Check if the per_channel_fp8_fp8_half_gemm_fused function works properly.
-        """
-        prop = paddle.device.cuda.get_device_properties()
-        cc = prop.major * 10 + prop.minor
-        if cc < 89:
-            self.skipTest("per_channel_fp8_fp8_half_gemm_fused only support sm89+")
-
-        from fastdeploy.model_executor.ops.gpu import (
-            per_channel_fp8_fp8_half_gemm_fused,
-        )
-
-        nks = [[2048, 2048], [2048, 5504], [6144, 2048]]
-        nks = nks + [[4096, 4096], [4096, 12800], [6144, 4096]]
-        nks = nks + [[5120, 5120], [5120, 13824], [15360, 5120]]
-
-        m = [1, 32, 64, 128, 256, 512, 1024, 2048]
-
-        combinations = list(product(m, nks))
-        for m, (n, k) in combinations:
-            A_bf16 = paddle.rand(shape=[m, k], dtype="bfloat16")
-            A_fp8 = paddle.cast(A_bf16, "float8_e4m3fn")
-            B_bf16 = paddle.rand(shape=[n, k], dtype="bfloat16")
-            B_fp8 = B_bf16.astype("float8_e4m3fn")
-
-            scalar_scale = paddle.full([1], 0.5, dtype="float32")
-            channel_scale = paddle.rand(shape=[n], dtype="float32")
-            bias = paddle.rand(shape=[n], dtype="bfloat16")
-
-            result_bf16 = paddle.matmul(A_bf16, B_bf16, transpose_y=True) * scalar_scale * channel_scale + bias
-            result_fp8 = per_channel_fp8_fp8_half_gemm_fused(
-                A_fp8,
-                B_fp8,
-                bias=bias,
-                scalar_scale=scalar_scale,
-                channel_scale=channel_scale,
-                transpose_x=False,
-                transpose_y=True,
-                output_dtype="bfloat16",
-            )
-            # absolute_error = paddle.abs(result_bf16 - result_fp8)
-            # mean_absolute_error = paddle.mean(absolute_error)
-            relative_error = paddle.abs(result_bf16 - result_fp8) / (paddle.abs(result_bf16))
-            mean_relative_error = paddle.mean(relative_error)
-            np.testing.assert_allclose(
-                mean_relative_error.numpy(),
-                np.array([0.001]),
-                rtol=0.001,
-                atol=0.25,
-            )
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/tests/operators/test_split_fuse.py b/tests/operators/test_split_fuse.py
index ee0ea9e52..6e14104cd 100644
--- a/tests/operators/test_split_fuse.py
+++ b/tests/operators/test_split_fuse.py
@@ -13,72 +13,81 @@
 # limitations under the License.
 """UT for set_stop_value"""
+import unittest
+
 import paddle
 
 from fastdeploy.model_executor.ops.gpu import get_mm_split_fuse
 
-input_ids = []
-image_type_ids = []
-grid_thw = []
 
+class TestSplitFuse(unittest.TestCase):
+    def setUp(self):
+        self.grid_thw = [[6, 20, 20], [6, 40, 20]]
+        self.split_fuse_img_size = 16
+        self.split_fuse_text_size = 384  # 1024
+        self.max_seq_len = 2048
+        self.image_token_id = 100295
 
-def split_grid(origin_grid_thw):
-    # 划分grid_thw,该函数用于视频场景
-    # origin_grid_thw = [6, 10, 12] ---> [2, 10, 12, 2, 10, 12, 2, 10, 12]
-    grid_thw = []
-    for t, h, w in origin_grid_thw:
-        if t > 2:
-            num_groups = t // 2
-            remainder = t % 2
-            for _ in range(num_groups):
-                grid_thw.extend([2, h, w])
-            if remainder > 0:
-                grid_thw.extend([remainder, h, w])
-        else:
-            grid_thw.extend([t, h, w])
-    return grid_thw
+    def split_grid(self, origin_grid_thw):
+        # Split grid_thw; this function is used for the video scenario, e.g.
+        # origin_grid_thw = [6, 10, 12] ---> [2, 10, 12, 2, 10, 12, 2, 10, 12]
+        grid_thw = []
+        for t, h, w in origin_grid_thw:
+            if t > 2:
+                num_groups = t // 2
+                remainder = t % 2
+                for _ in range(num_groups):
+                    grid_thw.extend([2, h, w])
+                if remainder > 0:
+                    grid_thw.extend([remainder, h, w])
+            else:
+                grid_thw.extend([t, h, w])
+        return grid_thw
+
+    def test_get_mm_split_fuse(self):
+        grid_thw = self.split_grid(self.grid_thw)
+        image_bs = len(grid_thw) // 3
+        image_type_ids = [0] * image_bs
+
+        # Concatenate input_ids as [txt0+img1+txt1+img2]
+        input_ids = [2] * 19
+        img1 = [self.image_token_id] * 100 * 3
+        txt1 = [3] * 19
+        img2 = [self.image_token_id] * 200 * 3
+        input_ids.extend(img1)
+        input_ids.extend(txt1)
+        input_ids.extend(img2)
+
+        seq_len = len(input_ids)
+        input_ids_tensor = paddle.to_tensor(input_ids, dtype="int64")
+        image_type_ids_tensor = paddle.to_tensor(image_type_ids, dtype="int32")
+        is_image_token = paddle.where(input_ids_tensor == self.image_token_id, 1, 0)
+        image_token_sum = paddle.cumsum(is_image_token)  # prefix sum
+        image_token_sum = paddle.concat([paddle.zeros([1], dtype="int64"), image_token_sum])
+
+        grid_thw_tensor = paddle.to_tensor(grid_thw, dtype="int64")
+        image_chunk_selections, split_fuse_cur_seq_lens = get_mm_split_fuse(
+            input_ids_tensor.cpu(),
+            image_type_ids_tensor.cast("int32").cpu(),
+            image_token_sum.cast("int32").cpu(),
+            grid_thw_tensor.cpu(),
+            self.image_token_id,
+            image_bs,
+            0,
+            seq_len,
+            self.split_fuse_img_size,
+            self.split_fuse_text_size,
+            self.max_seq_len,
+        )
+
+        # Verify the outputs are not None
+        self.assertIsNotNone(image_chunk_selections)
+        self.assertIsNotNone(split_fuse_cur_seq_lens)
+
+        # Verify the shapes are as expected
+        self.assertEqual(len(image_chunk_selections.shape), 1)
+        self.assertEqual(len(split_fuse_cur_seq_lens.shape), 1)
 
 
 if __name__ == "__main__":
-    grid_thw = [[6, 20, 20], [6, 40, 20]]
-    grid_thw = split_grid(grid_thw)
-    image_bs = len(grid_thw) // 3
-    image_type_ids = [0] * image_bs
-    # 随机拼接input_ids: [txt0+img1+tx1+img2]
-    input_ids = [2] * 19
-    img1 = [100295] * 100 * 3
-    txt1 = [3] * 19
-    img2 = [100295] * 200 * 3
-    input_ids.extend(img1)
-    input_ids.extend(txt1)
-    input_ids.extend(img2)
-
-    split_fuse_img_size = 16
-    split_fuse_text_size = 384  # 1024
-
-    seq_len = len(input_ids)
-    input_ids_tensor = paddle.to_tensor(input_ids, dtype="int64")
-    image_type_ids_tensor = paddle.to_tensor(image_type_ids, dtype="int32")
-    is_image_token = paddle.where(input_ids_tensor == 100295, 1, 0)
-    image_token_sum = paddle.cumsum(is_image_token)  # 前缀和
-    image_token_sum = paddle.concat([paddle.zeros([1], dtype="int64"), image_token_sum])
-
-    grid_thw_tensor = paddle.to_tensor(grid_thw, dtype="int64")
-    image_chunk_selections, split_fuse_cur_seq_lens = get_mm_split_fuse(
-        input_ids_tensor.cpu(),
-        image_type_ids_tensor.cast("int32").cpu(),
-        image_token_sum.cast("int32").cpu(),
-        grid_thw_tensor.cpu(),
-        100295,
-        image_bs,
-        0,
-        seq_len,
-        split_fuse_img_size,
-        split_fuse_text_size,
-        2048,
-    )
-
-    print("seq_len: ", seq_len)
-    print("grid_thw", grid_thw_tensor)
-    print("image_chunk_selections: ", image_chunk_selections)
-    print("split_fuse_cur_seq_lens: ", split_fuse_cur_seq_lens)
+    unittest.main()
diff --git a/tests/operators/test_token_penalty.py b/tests/operators/test_token_penalty.py
index 6114fb175..9745f221a 100644
--- a/tests/operators/test_token_penalty.py
+++ b/tests/operators/test_token_penalty.py
@@ -13,40 +13,58 @@
 # limitations under the License.
 """UT for get_token_penalty"""
+import unittest
+
 import numpy as np
 import paddle
 
 from fastdeploy.model_executor.ops.gpu import get_token_penalty_once
 
-paddle.seed(2023)
-pre_ids = paddle.randint(0, 10000, (8, 1000))
-pre_ids[:, -1] = pre_ids[:, -2]
-print(pre_ids)
-logits = paddle.rand(shape=[8, 10000], dtype="float16")
-penalty_scores = np.array([1.2] * 8).astype(np.float16).reshape(-1, 1)
-penalty_scores = paddle.to_tensor(penalty_scores)
 
+class TestTokenPenalty(unittest.TestCase):
+    def setUp(self):
+        paddle.seed(2023)
+        self.pre_ids = paddle.randint(0, 10000, (8, 1000))
+        self.pre_ids[:, -1] = self.pre_ids[:, -2]
+        self.logits = paddle.rand(shape=[8, 10000], dtype="float16")
+        self.penalty_scores = np.array([1.2] * 8).astype(np.float16).reshape(-1, 1)
+        self.penalty_scores = paddle.to_tensor(self.penalty_scores)
 
-print("logits[0][pre_ids[0]]: ", logits[0][pre_ids[0]])
-res = get_token_penalty_once(pre_ids, logits, penalty_scores)
-for i in range(8):
-    print(f"res[{i}]:{res[i][pre_ids[i]]}")
+    def test_token_penalty_once(self):
+        res = get_token_penalty_once(self.pre_ids, self.logits, self.penalty_scores)
+
+        # Check the output shape
+        self.assertEqual(res.shape, self.logits.shape)
+
+        # Check the penalty logic
+        for i in range(8):
+            original_values = self.logits[i][self.pre_ids[i]]
+            penalized_values = res[i][self.pre_ids[i]]
+            # Check that the penalty has been applied
+            for orig, penal in zip(original_values.numpy(), penalized_values.numpy()):
+                if orig < 0:
+                    self.assertLess(penal, orig, "negative values should be multiplied by the penalty factor")
+                else:
+                    self.assertLess(penal, orig, "positive values should be divided by the penalty factor")
+
+    def test_compare_with_naive_implementation(self):
+        res = get_token_penalty_once(self.pre_ids, self.logits, self.penalty_scores)
+
+        # Naive reference implementation
+        score = paddle.index_sample(self.logits, self.pre_ids)
+        score = paddle.where(score < 0, score * self.penalty_scores, score / self.penalty_scores)
+
+        bsz = paddle.shape(self.logits)[0]
+        bsz_range = paddle.arange(start=bsz * 0, end=bsz, step=bsz / bsz, name="bsz_range", dtype="int64").unsqueeze(
+            -1
+        )
+        input_ids = self.pre_ids + bsz_range * self.logits.shape[-1]
+        res2 = paddle.scatter(self.logits.flatten(), input_ids.flatten(), score.flatten()).reshape(self.logits.shape)
+
+        # Compare the results of the two implementations
+        max_diff = (res - res2).abs().max().item()
+        self.assertLess(max_diff, 1e-5)
 
-input_ids = pre_ids
-score = paddle.index_sample(logits, input_ids)
-score = paddle.where(score < 0, score * penalty_scores, score / penalty_scores)
-
-bsz = paddle.shape(logits)[0]  # TODO: Bsz as input for inference with dynamic batch_size
-bsz_range = paddle.arange(start=bsz * 0, end=bsz, step=bsz / bsz, name="bsz_range", dtype="int64").unsqueeze(-1)
-input_ids = input_ids + bsz_range * logits.shape[-1]
-res2 = paddle.scatter(logits.flatten(), input_ids.flatten(), score.flatten()).reshape(logits.shape)
-print("-------------------------------------------")
-for i in range(8):
-    print(res2[i][pre_ids[i]])
-
-print("res_sub:")
-for i in range(8):
-    print(res2[i][pre_ids[i]] - res[i][pre_ids[i]])
-
-print((res.numpy() - res2.numpy()).sum())
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/operators/test_tree_mask.py b/tests/operators/test_tree_mask.py
index ee8be1b3c..10e55a4b1 100644
--- a/tests/operators/test_tree_mask.py
+++ b/tests/operators/test_tree_mask.py
@@ -1,5 +1,6 @@
 import math
 import time
+import unittest
 
 import numpy as np
 import paddle
@@ -10,351 +11,331 @@
 from fastdeploy.model_executor.layers.attention.ops import (
     get_block_shape_and_split_kv_block,
 )
 
-paddle.seed(0)
-max_seq_len = 32768
-encoder_max_partition_size = max_seq_len
-max_partition_size = max_seq_len
 
-max_dec_len = 1024
-bsz = 64
-run_time = 10
-warm_up = 2
-block_size = 64
-head_dim = 128
-num_q_head = 20
-num_kv_head = 4
-dtype = "bfloat16"
+class TestTreeMask(unittest.TestCase):
+    def setUp(self):
+        paddle.seed(0)
+        self.max_seq_len = 32768
+        self.encoder_max_partition_size = self.max_seq_len
+        self.max_partition_size = self.max_seq_len
 
-rope_3d = False
-use_neox_rotary_style = False
-CURRENT_Q = [None]
-TOTAL_K = []
-TOTAL_V = []
+        self.max_dec_len = 1024
+        self.bsz = 64
+        self.run_time = 10
+        self.warm_up = 2
+        self.block_size = 64
+        self.head_dim = 128
+        self.num_q_head = 20
+        self.num_kv_head = 4
+        self.dtype = "bfloat16"
 
+        self.rope_3d = False
+        self.use_neox_rotary_style = False
+        self.CURRENT_Q = [None]
+        self.TOTAL_K = []
+        self.TOTAL_V = []
 
-def split_qkv(qkv, bsz, seq_len, num_q_head, num_kv_head, head_dim):
-    # [token_num, (num_q_head + 2 * num_kv_head) * head_dim]
-    qkv = qkv.reshape([bsz, seq_len, -1, head_dim])
-    q = qkv[:, :, :num_q_head, :]
-    # [bsz, seq_len, num_q_head, head_dim]
-    CURRENT_Q[0] = q
+        # Initialize cache and block tables
+        block_num_per_seq = (self.max_seq_len + self.block_size - 1) // self.block_size
+        max_block_num = block_num_per_seq * self.bsz
+        cache_shape = (
+            max_block_num,
+            self.num_kv_head,
+            self.block_size,
+            self.head_dim,
+        )
 
-    # [bsz, seq_len, num_kv_head, head_dim]
-    k = qkv[:, :, num_q_head : num_q_head + num_kv_head, :]
-    TOTAL_K.append(k)
+        self.cache_k = paddle.zeros(shape=cache_shape).astype(self.dtype)
+        self.cache_v = paddle.zeros(shape=cache_shape).astype(self.dtype)
 
-    # [bsz, seq_len, num_kv_head, head_dim]
-    v = qkv[:, :, num_q_head + num_kv_head :, :]
-    TOTAL_V.append(v)
+        self.block_tables = paddle.zeros(shape=(self.bsz, block_num_per_seq), dtype="int32")
 
+        free_list = list(range(max_block_num - 1, -1, -1))
+        for i in range(self.bsz):
+            need_block_num = (self.max_seq_len + self.block_size - 1) // self.block_size
+            for j in range(need_block_num):
+                block_id = free_list.pop()
+                self.block_tables[i, j] = block_id
 
-def get_padding_offset(bsz, seq_lens_this_time, seq_lens_decoder):
-    batch_id_per_token = []
-    cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32")
-    cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32")
-    cum_seq_len_q = 0
-    cum_seq_len_k = 0
-    for i in range(bsz):
-        seq_len_now = seq_lens_this_time[i]
-        seq_len_dec_now = seq_lens_decoder[i]
-        for j in range(seq_len_now):
-            batch_id_per_token.append(i)
-        cum_seq_len_q += seq_len_now
-        cum_seq_len_k += seq_len_now + seq_len_dec_now
-        cu_seqlens_q[i + 1] = cum_seq_len_q
-        cu_seqlens_k[i + 1] = cum_seq_len_k
-    return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k
+    def tearDown(self):
+        self.CURRENT_Q = [None]
+        self.TOTAL_K = []
+        self.TOTAL_V = []
 
+    def split_qkv(self, qkv, bsz, seq_len):
+        qkv = qkv.reshape([bsz, seq_len, -1, self.head_dim])
+        q = qkv[:, :, : self.num_q_head, :]
+        self.CURRENT_Q[0] = q
 
-# block_table
-block_num_per_seq = (max_seq_len + block_size - 1) // block_size
-max_block_num = block_num_per_seq * bsz
-cache_shape = (
-    max_block_num,
-    num_kv_head,
-    block_size,
-    head_dim,
-)
+        k = qkv[:, :, self.num_q_head : self.num_q_head + self.num_kv_head, :]
+        self.TOTAL_K.append(k)
 
-cache_k = paddle.zeros(shape=cache_shape).astype(dtype)
-cache_v = paddle.zeros(shape=cache_shape).astype(dtype)
+        v = qkv[:, :, self.num_q_head + self.num_kv_head :, :]
+        self.TOTAL_V.append(v)
 
-block_tables = paddle.zeros(shape=(bsz, block_num_per_seq), dtype="int32")
+    def get_padding_offset(self, bsz, seq_lens_this_time, seq_lens_decoder):
+        batch_id_per_token = []
+        cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32")
+        cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32")
+        cum_seq_len_q = 0
+        cum_seq_len_k = 0
+        for i in range(bsz):
+            seq_len_now = seq_lens_this_time[i]
+            seq_len_dec_now = seq_lens_decoder[i]
+            for j in range(seq_len_now):
+                batch_id_per_token.append(i)
+            cum_seq_len_q += seq_len_now
+            cum_seq_len_k += seq_len_now + seq_len_dec_now
+            cu_seqlens_q[i + 1] = cum_seq_len_q
+            cu_seqlens_k[i + 1] = cum_seq_len_k
+        return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k
 
-free_list = list(range(max_block_num - 1, -1, -1))
+    def ref_attention(self, q, k, v, mask):
+        q = q.transpose([0, 2, 1, 3])
+        if len(k) > 1:
+            k = paddle.concat(k, axis=1)
+        else:
+            k = k[0]
+        k = k.transpose([0, 2, 1, 3])
+        if len(v) > 1:
+            v = paddle.concat(v, axis=1)
+        else:
+            v = v[0]
+        v = v.transpose([0, 2, 1, 3])
+        total_len = k.shape[2]
 
-for i in range(bsz):
-    need_block_num = (max_seq_len + block_size - 1) // block_size
-    for j in range(need_block_num):
-        block_id = free_list.pop()
-        block_tables[i, j] = block_id
+        scores = (
+            q.reshape([self.bsz, self.num_kv_head, -1, self.head_dim])
+            @ k.transpose([0, 1, 3, 2])
+            * (1.0 / math.sqrt(self.head_dim))
+        )
+        scores = scores.reshape([self.bsz, self.num_q_head, -1, total_len])
+        if mask is not None:
+            if mask.ndim == 2:
+                mask = mask.unsqueeze(0).unsqueeze(0)
+            elif mask.ndim == 3:
+                mask = mask.unsqueeze(1)
+            scores = paddle.add(scores, mask)
+        weights = F.softmax(scores, axis=-1)
 
-def ref_attention(q, k, v, num_q_head, num_kv_head, head_dim, mask):
-    q = q.transpose([0, 2, 1, 3])
-    if len(k) > 1:
-        k = paddle.concat(k, axis=1)
-    else:
-        k = k[0]
-    k = k.transpose([0, 2, 1, 3])
-    if len(v) > 1:
-        v = paddle.concat(v, axis=1)
-    else:
-        v = v[0]
-    v = v.transpose([0, 2, 1, 3])
-    total_len = k.shape[2]
+        o = weights.reshape([self.bsz, self.num_kv_head, -1, total_len]) @ v
+        return (
+            o.reshape([self.bsz, self.num_q_head, -1, self.head_dim])
+            .transpose([0, 2, 1, 3])
+            .reshape([-1, self.num_q_head, self.head_dim])
+        )
 
-    scores = q.reshape([bsz, num_kv_head, -1, head_dim]) @ k.transpose([0, 1, 3, 2]) * (1.0 / math.sqrt(head_dim))
-    scores = scores.reshape([bsz, num_q_head, -1, total_len])
+    def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None):
+        if prefill:
+            seq_lens_enc = [
+                q_len,
+            ] * self.bsz
+        else:
+            seq_lens_enc = [
+                0,
+            ] * self.bsz
 
-    if mask is not None:
-        if mask.ndim == 2:
-            mask = mask.unsqueeze(0).unsqueeze(0)  # [1,1,q_len,kv_len]
-        elif mask.ndim == 3:
-            mask = mask.unsqueeze(1)  # [bsz,1,q_len,kv_len]
-        scores = paddle.add(scores, mask)
-    weights = F.softmax(scores, axis=-1)
-
-    o = weights.reshape([bsz, num_kv_head, -1, total_len]) @ v
-    return o.reshape([bsz, num_q_head, -1, head_dim]).transpose([0, 2, 1, 3]).reshape([-1, num_q_head, head_dim])
-
-
-def clear_param():
-    global CURRENT_Q, TOTAL_K, TOTAL_V
-    CURRENT_Q = [None]
-    TOTAL_K = []
-    TOTAL_V = []
-
-
-def test_append_c16_attention(q_len, kv_len, prefill=False, attn_mask=None):
-    if prefill:
-        seq_lens_enc = [
+        seq_lens_dec = [
+            kv_len,
+        ] * self.bsz
+        seq_lens_cur = [
             q_len,
-        ] * bsz
-    else:
-        seq_lens_enc = [
-            0,
-        ] * bsz
+        ] * self.bsz
+        token_num = sum(seq_lens_cur)
+        decoder_step_token_num = 1 if prefill else q_len
 
-    seq_lens_dec = [
-        kv_len,
-    ] * bsz
-    seq_lens_cur = [
-        q_len,
-    ] * bsz
-    token_num = sum(seq_lens_cur)
-    decoder_step_token_num = 1 if prefill else q_len
+        seq_lens_encoder = paddle.to_tensor(seq_lens_enc, "int32")
+        seq_lens_this_time = paddle.to_tensor(seq_lens_cur, "int32")
+        seq_lens_decoder = paddle.to_tensor(seq_lens_dec, "int32")
 
-    seq_lens_encoder = paddle.to_tensor(seq_lens_enc, "int32")
-    seq_lens_this_time = paddle.to_tensor(seq_lens_cur, "int32")
-    seq_lens_decoder = paddle.to_tensor(seq_lens_dec, "int32")
+        batch_id_per_token, cu_seqlens_q, cu_seqlens_k = self.get_padding_offset(
+            self.bsz, seq_lens_this_time, seq_lens_decoder
+        )
 
-    batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(bsz, seq_lens_this_time, seq_lens_decoder)
+        qkv_varlen_shape = [token_num, (self.num_q_head + 2 * self.num_kv_head) * self.head_dim]
+        rotary_embs_shape = [
+            2,
+            1,
+            self.max_seq_len,
+            1,
+            self.head_dim if self.use_neox_rotary_style else self.head_dim // 2,
+        ]
 
-    # random data
-    qkv_varlen_shape = [token_num, (num_q_head + 2 * num_kv_head) * head_dim]
+        qkv = paddle.randn(shape=qkv_varlen_shape).astype(self.dtype)
+        self.split_qkv(qkv, self.bsz, q_len)
 
-    rotary_embs_shape = [2, 1, max_seq_len, 1, head_dim if use_neox_rotary_style else head_dim // 2]
-    # qkv_bias_shape = [num_q_head + 2 * num_kv_head, head_dim]
+        rotary_embs = paddle.randn(shape=rotary_embs_shape).astype("float32")
+        rotary_embs[0, :, :, :, :] = 1
+        rotary_embs[1, :, :, :, :] = 0
 
-    qkv = paddle.randn(shape=qkv_varlen_shape).astype(dtype)
+        cache_k_scale = None
+        cache_v_scale = None
+        cache_k_out_scale = None
+        cache_v_out_scale = None
 
-    # save q, k, v for ref
-    split_qkv(qkv, bsz, q_len, num_q_head, num_kv_head, head_dim)
+        encoder_block_shape_q = 64
+        decoder_block_shape_q = 16
 
-    rotary_embs = paddle.randn(shape=rotary_embs_shape).astype("float32")
-    rotary_embs[0, :, :, :, :] = 1
-    rotary_embs[1, :, :, :, :] = 0
-
-    # qkv_scale = None
-    # qkv_bias = None
-
-    cache_k_scale = None
-    cache_v_scale = None
-    cache_k_out_scale = None
-    cache_v_out_scale = None
-    # shift_bias = None
-    # smooth_weight = None
-
-    encoder_block_shape_q = 64
-    decoder_block_shape_q = 16
-
-    decode_max_tile_size = (
-        bsz
-        * (decoder_step_token_num * (num_q_head // num_kv_head) + decoder_block_shape_q - 1)
-        / decoder_block_shape_q
-    )
-    decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
-    decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
-    decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
-    max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
-    paddle.device.synchronize()
-    (
-        encoder_batch_ids,
-        encoder_tile_ids_per_batch,
-        encoder_num_blocks,
-        kv_batch_ids,
-        kv_tile_ids_per_batch,
-        kv_num_blocks,
-        max_len_kv,
-    ) = get_block_shape_and_split_kv_block(
-        seq_lens_encoder,
-        seq_lens_decoder,
-        seq_lens_this_time,
-        decoder_batch_ids,
-        decoder_tile_ids_per_batch,
-        decoder_num_blocks,
-        max_len_tensor_cpu,
-        encoder_block_shape_q,
-        decoder_block_shape_q,
-        num_q_head // num_kv_head,
-        block_size,
-        decoder_step_token_num,
-    )
-    s_time = 0
-    for i in range(run_time + warm_up):
-        if i == warm_up:
-            s_time = time.time()
-        out = append_attention(
-            qkv,
-            cache_k,
-            cache_v,
-            seq_lens_encoder,
-            seq_lens_decoder,
-            seq_lens_this_time,
-            batch_id_per_token,
-            cu_seqlens_q,
-            block_tables,
+        decode_max_tile_size = (
+            self.bsz
+            * (decoder_step_token_num * (self.num_q_head // self.num_kv_head) + decoder_block_shape_q - 1)
+            / decoder_block_shape_q
+        )
+        decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
+        decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
+        decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
+        max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
+        paddle.device.synchronize()
+        (
            encoder_batch_ids,
            encoder_tile_ids_per_batch,
            encoder_num_blocks,
            kv_batch_ids,
            kv_tile_ids_per_batch,
            kv_num_blocks,
+            max_len_kv,
+        ) = get_block_shape_and_split_kv_block(
+            seq_lens_encoder,
+            seq_lens_decoder,
+            seq_lens_this_time,
            decoder_batch_ids,
            decoder_tile_ids_per_batch,
            decoder_num_blocks,
            max_len_tensor_cpu,
-            max_len_kv,
-            rotary_embs,
-            attn_mask,  # attn_mask
-            None,
-            None,
-            cache_k_scale,
-            cache_v_scale,
-            cache_k_out_scale,
-            cache_v_out_scale,
-            None,  # cache_k_zp
-            None,  # cache_v_zp
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            1e-6,
-            "bf16",
-            "none",  # cache_quant_type
-            use_neox_rotary_style,
-            rope_3d,
-            max_seq_len,
-            0.0,
-            0.0,
-            -1.0,  # out_linear_in_scale
-            encoder_block_shape_q,  # encoder_block_shape_q
-            decoder_block_shape_q,  # decoder_block_shape_q
-            max_partition_size,  # max_partition_size
-            encoder_max_partition_size,  # encoder_max_partition_size
-            decoder_step_token_num,  # speculate_max_draft_token_num
-            True,  # causal
-            decoder_step_token_num > 1,  # speculate_decoder
+            encoder_block_shape_q,
+            decoder_block_shape_q,
+            self.num_q_head // self.num_kv_head,
+            self.block_size,
+            decoder_step_token_num,
        )
-    paddle.device.synchronize()
-    e_time = time.time()
-    print(f"mean infer time: {np.mean((e_time - s_time) * 1000 / run_time):.2f}")
-    return out[0].reshape([token_num, num_q_head, head_dim])
+        s_time = 0
+        for i in range(self.run_time + self.warm_up):
+            if i == self.warm_up:
+                s_time = time.time()
+            out = append_attention(
+                qkv,
+                self.cache_k,
+                self.cache_v,
+                seq_lens_encoder,
+                seq_lens_decoder,
+                seq_lens_this_time,
+                batch_id_per_token,
+                cu_seqlens_q,
+                self.block_tables,
+                encoder_batch_ids,
+                encoder_tile_ids_per_batch,
+                encoder_num_blocks,
+                kv_batch_ids,
+                kv_tile_ids_per_batch,
+                kv_num_blocks,
+                decoder_batch_ids,
+                decoder_tile_ids_per_batch,
+                decoder_num_blocks,
+                max_len_tensor_cpu,
+                max_len_kv,
+                rotary_embs,
+                attn_mask,
+                None,
+                None,
+                cache_k_scale,
+                cache_v_scale,
+                cache_k_out_scale,
+                cache_v_out_scale,
+                None,
+                None,
+                None,
+                None,
+                None,
+                None,
+                None,
+                None,
+                1e-6,
+                "bf16",
+                "none",
+                self.use_neox_rotary_style,
+                self.rope_3d,
+                self.max_seq_len,
+                0.0,
+                0.0,
+                -1.0,
+                encoder_block_shape_q,
+                decoder_block_shape_q,
+                self.max_partition_size,
+                self.encoder_max_partition_size,
+                decoder_step_token_num,
+                True,
+                decoder_step_token_num > 1,
+            )
+        paddle.device.synchronize()
+        e_time = time.time()
+        print(f"mean infer time: {np.mean((e_time - s_time) * 1000 / self.run_time):.2f}")
+        return out[0].reshape([token_num, self.num_q_head, self.head_dim])
 
+    def test_naive_speculative_decoding(self):
+        prefill_len = 8192
+        dec_len_q = 5
+        total_len = prefill_len + dec_len_q
+        mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
+        mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
+        self.run_append_c16_attention(prefill_len, 0, True)
+        dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False)
 
-def test_naive_speculative_decoding(num_q_head, num_kv_head, head_dim):
-    prefill_len = 8192
-    dec_len_q = 5
-    total_len = prefill_len + dec_len_q
-    mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
-    mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
-    test_append_c16_attention(prefill_len, 0, True)
-    dec_out = test_append_c16_attention(dec_len_q, prefill_len, False)
+        ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask)
+        np.testing.assert_allclose(
+            ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
+        )
 
-    ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask)
-    np.testing.assert_allclose(
-        ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
-    )
+    def test_mask(self):
+        prefill_len = 8192
+        dec_len_q = 5
+        total_len = prefill_len + dec_len_q
+        mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
+        mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
+        mask_append_attn = mask[:, :, prefill_len:]
+        mask_append_attn = paddle.where(
+            mask_append_attn == 1,
+            paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
+            paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
+        )
 
-def test_mask(num_q_head, num_kv_head, head_dim):
-    prefill_len = 8192
-    dec_len_q = 5
-    total_len = prefill_len + dec_len_q
-    mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
-    mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
+        self.run_append_c16_attention(prefill_len, 0, True)
+        dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
 
-    mask_append_attn = mask[:, :, prefill_len:]
-    mask_append_attn = paddle.where(
-        mask_append_attn == 1,
-        paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
-        paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
-    )
+        ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask_ref)
 
-    test_append_c16_attention(prefill_len, 0, True)
-    dec_out = test_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
+        np.testing.assert_allclose(
+            ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
+        )
 
-    ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask_ref)
+    def test_tree_mask(self):
+        prefill_len = 8192
+        dec_len_q = 5
+        total_len = prefill_len + dec_len_q
+        mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
+        mask[:, 2, prefill_len + 1] = 0
+        mask[:, 3, prefill_len + 2] = 0
+        mask[:, 4, prefill_len + 1] = 0
+        mask[:, 4, prefill_len + 3] = 0
 
-    np.testing.assert_allclose(
-        ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
-    )
+        mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
+        mask_append_attn = mask[:, :, prefill_len:]
+        mask_append_attn = paddle.where(
+            mask_append_attn == 1,
+            paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
+            paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
+        )
 
-def test_tree_mask(num_q_head, num_kv_head, head_dim):
-    # tree
-    #     [N, N+1, N+1, N+2, N+2]
-    # N   [0, -inf, -inf, -inf, -inf]
-    # N+1 [0, 0, -inf, -inf, -inf]
-    # N+1 [0, -inf, 0, -inf, -inf]
-    # N+2 [0, 0, -inf, 0, -inf]
-    # N+2 [0, -inf, 0, -inf, 0]
-    prefill_len = 8192
-    dec_len_q = 5
-    total_len = prefill_len + dec_len_q
-    mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
-    mask[:, 2, prefill_len + 1] = 0
-    mask[:, 3, prefill_len + 2] = 0
-    mask[:, 4, prefill_len + 1] = 0
-    mask[:, 4, prefill_len + 3] = 0
-
-    mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
-
-    mask_append_attn = mask[:, :, prefill_len:]
-    mask_append_attn = paddle.where(
-        mask_append_attn == 1,
-        paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
-        paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
-    )
-
-    test_append_c16_attention(prefill_len, 0, True)
-    dec_out = test_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
-    ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask_ref)
-    np.testing.assert_allclose(
-        ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
-    )
+        self.run_append_c16_attention(prefill_len, 0, True)
+        dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
+        ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask_ref)
+        np.testing.assert_allclose(
+            ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
+        )
 
 
 if __name__ == "__main__":
-
-    test_naive_speculative_decoding(num_q_head, num_kv_head, head_dim)
-    clear_param()
-
-    test_mask(num_q_head, num_kv_head, head_dim)
-    clear_param()
-
-    test_tree_mask(num_q_head, num_kv_head, head_dim)
+    unittest.main()
diff --git a/tests/operators/test_w4afp8_gemm.py b/tests/operators/test_w4afp8_gemm.py
index 65240c2d4..66ec408df 100644
--- a/tests/operators/test_w4afp8_gemm.py
+++ b/tests/operators/test_w4afp8_gemm.py
@@ -12,92 +12,97 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import unittest
+
 import numpy as np
 import paddle
 
 from fastdeploy.model_executor.ops.gpu import w4afp8_gemm, w4afp8_gemm_weight_convert
 
 
-def w4afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_dequant_scale, BATCH, N):
-    all_tokens = int(tokens.sum())
-    out = paddle.zeros([all_tokens, N], dtype="bfloat16")
-    pre_fix_token = 0
-    for i in range(BATCH):
-        input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :]
-        weight = (weight_quant[i] - 7.0) * weight_dequant_scale[i]
-        out_i = paddle.matmul(input, weight.astype("bfloat16"), transpose_y=True)
-        out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i
-        pre_fix_token += tokens[i]
-    return out
+class TestW4AFP8GEMM(unittest.TestCase):
+    def setUp(self):
+        paddle.seed(0)
+        self.tokens_per_group = 256
+        self.N = 256
+        self.K = 256
+        self.BATCH = 1
+        self.TokenPadding = 0
+
+        tokens = [self.tokens_per_group] * self.BATCH
+        self.tokens_perfix_sum = np.cumsum(tokens)
+
+        self.tokens = paddle.to_tensor(tokens, dtype="int64")
+        self.tokens_perfix_sum = paddle.to_tensor(self.tokens_perfix_sum, dtype="int64")
+        self.all_tokens = int(self.tokens.sum())
+
+        self.input_fp8 = paddle.randn([self.all_tokens, self.K], dtype="bfloat16").astype(paddle.float8_e4m3fn)
+        self.input_bf16 = self.input_fp8.astype("bfloat16")
+        self.weight = paddle.randn([self.BATCH, self.N, self.K], dtype="bfloat16") / 10
+
+        self.weight_scale = 7 / self.weight.abs().max(axis=-1).reshape([self.BATCH, self.N, 1])
+        self.weight_quant = (self.weight * self.weight_scale).astype("int") + 7
+        self.weight_quant = paddle.clip(self.weight_quant, 0, 14)
+        self.weight_quant = self.weight_quant.astype("bfloat16")
+        self.weight_dequant_scale = 1 / self.weight_scale.astype("float32")
+        self.input_row_sum = self.input_bf16.sum(axis=1) * -7 / 512
+        self.max_tokens = int(self.tokens.max())
+
+    def w4afp8_gemm_naive(self, input_bf16, weight_quant, tokens, weight_dequant_scale):
+        all_tokens = int(tokens.sum())
+        out = paddle.zeros([all_tokens, self.N], dtype="bfloat16")
+        pre_fix_token = 0
+        for i in range(self.BATCH):
+            input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :]
+            weight = (weight_quant[i] - 7.0) * weight_dequant_scale[i]
+            out_i = paddle.matmul(input, weight.astype("bfloat16"), transpose_y=True)
+            out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i
+            pre_fix_token += tokens[i]
+        return out
+
+    def permute_scale(self, weight_scale):
+        weight_scale = weight_scale.reshape([self.BATCH, self.N])
+        temp = paddle.zeros([16])
+        for b in range(self.BATCH):
+            for n in range(0, self.N, 16):
+                temp[:] = weight_scale[b, n : n + 16]
+                for j in range(0, 16, 2):
+                    weight_scale[b, n + j] = temp[j // 2]
+                    weight_scale[b, n + j + 1] = temp[j // 2 + 8]
+        return weight_scale
+
+    def test_w4afp8_gemm(self):
+        out_naive = self.w4afp8_gemm_naive(self.input_bf16, self.weight_quant, self.tokens, self.weight_dequant_scale)
+
+        weight_dequant_scale = paddle.to_tensor(self.permute_scale(self.weight_dequant_scale) * 512)
+        weight_int4 = w4afp8_gemm_weight_convert(self.weight_quant.astype("uint8").cpu())
+
+        if self.TokenPadding == 0:
+            out_cuda = w4afp8_gemm(
+                self.input_fp8,
+                weight_int4.cuda(),
+                self.tokens_perfix_sum,
+                self.input_row_sum.astype("float32"),
+                weight_dequant_scale.astype("float32"),
+                int(self.TokenPadding),
+                self.max_tokens,
+                True,
+            )
+        else:
+            out_cuda = w4afp8_gemm(
+                self.input_fp8,
+                weight_int4.cuda(),
+                self.tokens,
+                self.input_row_sum.astype("float32"),
+                weight_dequant_scale.astype("float32"),
+                int(self.TokenPadding),
+                self.max_tokens,
+                True,
+            )
+
+        gap = (out_cuda - out_naive).abs()
+        self.assertLess(float(gap.mean()), 0.07)
 
-def permute_scale(weight_scale):
-    weight_scale = weight_scale.reshape([BATCH, N])
-    temp = paddle.zeros([16])
-    for b in range(BATCH):
-        for n in range(0, N, 16):
-            temp[:] = weight_scale[b, n : n + 16]
-            for j in range(0, 16, 2):
-                weight_scale[b, n + j] = temp[j // 2]
-                weight_scale[b, n + j + 1] = temp[j // 2 + 8]
-    return weight_scale
-
-
-paddle.seed(0)
-tokens_per_group = 256
-N = 256
-K = 256
-BATCH = 1
-TokenPadding = 0
-
-tokens = [tokens_per_group] * BATCH
-tokens_perfix_sum = np.cumsum(tokens)
-
-
-tokens = paddle.to_tensor(tokens, dtype="int64")
-tokens_perfix_sum = paddle.to_tensor(tokens_perfix_sum, dtype="int64")
-
-all_tokens = int(tokens.sum())
-
-input_fp8 = paddle.randn([all_tokens, K], dtype="bfloat16").astype(paddle.float8_e4m3fn)
-input_bf16 = input_fp8.astype("bfloat16")
-weight = paddle.randn([BATCH, N, K], dtype="bfloat16") / 10
-
-weight_scale = 7 / weight.abs().max(axis=-1).reshape([BATCH, N, 1])
-weight_quant = (weight * weight_scale).astype("int") + 7
-weight_quant = paddle.clip(weight_quant, 0, 14)
-weight_quant = weight_quant.astype("bfloat16")
-weight_dequant_scale = 1 / weight_scale.astype("float32")
-input_row_sum = input_bf16.sum(axis=1) * -7 / 512
-max_tokens = int(tokens.max())
-
-out_naive = w4afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_dequant_scale, BATCH, N)
-weight_dequant_scale = paddle.to_tensor(permute_scale(weight_dequant_scale) * 512)
-
-weight_int4 = w4afp8_gemm_weight_convert(weight_quant.astype("uint8").cpu())
-
-if TokenPadding == 0:
-    out_cuda = w4afp8_gemm(
-        input_fp8,
-        weight_int4.cuda(),
-        tokens_perfix_sum,
-        input_row_sum.astype("float32"),
-        weight_dequant_scale.astype("float32"),
-        int(TokenPadding),
-        max_tokens,
-        True,
-    )
-else:
-    out_cuda = w4afp8_gemm(
-        input_fp8,
-        weight_int4.cuda(),
-        tokens,
-        input_row_sum.astype("float32"),
-        weight_dequant_scale.astype("float32"),
-        int(TokenPadding),
-        max_tokens,
-        True,
-    )
-
-gap = (out_cuda - out_naive).abs()
-assert float(gap.mean()) < 0.07
+
+if __name__ == "__main__":
+    unittest.main()