Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-31 20:02:53 +08:00)
[CI] Standard unittest (#3606)

* standard unittest
* fix bugs
* fix script
```diff
@@ -647,6 +647,9 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
             model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
             process_weights_after_loading_fn(model_sublayer_name, param)
         if self.tie_word_embeddings:
+            # because we use a lazy guard, the weight is not initialized by default
+            if not self.lm_head.linear.weight._is_initialized():
+                self.lm_head.linear.weight.initialize()
             self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
 
     @paddle.no_grad()
```
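
The three added lines guard the tied-embedding copy: when the model is constructed lazily, `lm_head.linear.weight` has no storage yet, so calling `set_value` on it directly would fail. A minimal sketch of the same pattern in isolation, assuming `paddle.LazyGuard` and the private `_is_initialized()`/`initialize()` methods behave the way the diff implies:

```python
import paddle

# Parameters declared under LazyGuard are not materialized at construction.
with paddle.LazyGuard():
    lm_head = paddle.nn.Linear(512, 32000, bias_attr=False)

w = lm_head.weight
if not w._is_initialized():  # mirrors the guard added by the commit
    w.initialize()           # materialize storage first
w.set_value(paddle.zeros([512, 32000]))  # then copy the tied weights in
```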
```diff
@@ -11,40 +11,6 @@ cd "$run_path" || exit 1
 failed_tests_file="failed_tests.log"
 > "$failed_tests_file"
 
-##################################
-# Run special test cases (not in unittest/pytest format)
-##################################
-special_tests=(
-    "graph_optimization/test_cuda_graph_dynamic_subgraph.py"
-    "graph_optimization/test_cuda_graph_spec_decode.py"
-    "layers/test_quant_layer.py"
-    "operators/test_token_penalty.py"
-    "operators/test_split_fuse.py"
-    "operators/test_flash_mask_attn.py"
-    "operators/test_w4afp8_gemm.py"
-    "model_loader/test_load_ernie_vl.py"
-    "operators/test_tree_mask.py"
-)
-
-failed_special=0
-success_special=0
-
-for test_file in "${special_tests[@]}"; do
-    if [ -f "$test_file" ]; then
-        echo "Running special test: $test_file"
-        python -m coverage run --parallel-mode "$test_file"
-        status=$?
-        if [ "$status" -ne 0 ]; then
-            echo "$test_file" >> "$failed_tests_file"
-            failed_special=$((failed_special+1))
-        else
-            success_special=$((success_special+1))
-        fi
-    else
-        echo "Warning: $test_file not found"
-        failed_special=$((failed_special+1))
-    fi
-done
-
 ##################################
 # Run pytest, one file per run
 ##################################
@@ -78,9 +44,8 @@ echo "Pytest failed: $failed_pytest"
 
 echo "Special tests total: ${#special_tests[@]}"
 echo "Special tests successful: $success_special"
-echo "Special tests failed: $failed_special"
 
-if [ "$failed_pytest" -ne 0 ] || [ "$failed_special" -ne 0 ]; then
+if [ "$failed_pytest" -ne 0 ]; then
     echo "Failed test cases are listed in $failed_tests_file"
     cat "$failed_tests_file"
     exit 8
```
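
The whole special-case loop could go because this commit converts those nine script-style files into `unittest.TestCase` suites, so the generic per-file pytest loop that follows picks them up; the separate counters and the `$failed_special` branch of the final exit check become unnecessary. The conversion pattern, as a minimal illustrative sketch (names are made up, not from the diff):

```python
import unittest


class TestExample(unittest.TestCase):
    def test_addition(self):
        self.assertEqual(1 + 1, 2)


if __name__ == "__main__":
    # Still runnable as "python test_example.py", and now also
    # collectible by "pytest test_example.py".
    unittest.main()
```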
```diff
@@ -1,93 +1,99 @@
+import unittest
+
 import numpy as np
 import paddle
 
 from fastdeploy.model_executor.ops.gpu import flash_attention_mask
 
 
-def naive_attn(q_input, k_input, v_input, mask):
-    gqa_group_size = q_input.shape[2] // k_input.shape[2]
-
-    q_cur = q_input.transpose([0, 2, 1, 3])
-    k_cur = k_input.transpose([0, 2, 1, 3])
-    v_cur = v_input.transpose([0, 2, 1, 3])
-    out = np.zeros(q_cur.shape, dtype=q_input.dtype)
-
-    for bsz in range(0, q_cur.shape[0]):
-        for hi in range(0, q_cur.shape[1]):
-            qk = np.matmul(q_cur[bsz, hi], k_cur[bsz, hi // gqa_group_size].T) * (1.0 / np.sqrt(q_cur.shape[3]))
-            for i in range(0, qk.shape[0]):
-                qk[i, mask[i] :] = -1000000
-
-            qk_max = np.expand_dims(qk.max(axis=-1), -1)
-            qk -= qk_max
-            qk = np.exp(qk)
-
-            exp_sum = np.expand_dims(qk.sum(axis=-1), -1)
-            exp_sum_inv = 1.0 / exp_sum
-
-            out[bsz, hi] = (np.matmul(qk, v_cur[bsz, hi // gqa_group_size]) * exp_sum_inv).astype(q_input.dtype)
-    return out
-
-
-def paddle_flash_attn_mask(q_input, k_input, v_input, mask):
-    bsz = q_input.shape[0]
-    cu_seq_q = paddle.arange(bsz + 1) * q_input.shape[1]
-    cu_seq_k = paddle.arange(bsz + 1) * k_input.shape[1]
-    cu_seq_q = cu_seq_q.astype("int32")
-    cu_seq_k = cu_seq_k.astype("int32")
-    seq_len_encoder = paddle.ones(bsz) * q_input.shape[1]
-    seq_len_encoder = seq_len_encoder.astype("int32")
-    q_input = paddle.to_tensor(q_input).astype("bfloat16").reshape([-1, q_input.shape[2], q_input.shape[3]])
-    k_input = paddle.to_tensor(k_input).astype("bfloat16").reshape([-1, k_input.shape[2], k_input.shape[3]])
-    v_input = paddle.to_tensor(v_input).astype("bfloat16").reshape([-1, v_input.shape[2], v_input.shape[3]])
-    v_input_pad = paddle.zeros([v_input.shape[0] + 128, v_input.shape[1], v_input.shape[2]]).astype("bfloat16")
-    v_input_pad[0 : v_input.shape[0]] = v_input
-    mask = paddle.to_tensor(mask).astype("int32")
-
-    out = flash_attention_mask(
-        q_input,
-        k_input,
-        v_input_pad,
-        cu_seq_q,
-        cu_seq_k,
-        seq_len_encoder,
-        mask,
-        int(q_input.shape[1]),
-        int(k_input.shape[1]),
-        int(q_input.shape[2]),
-        int(k_input.shape[0]),
-        int(q_input.shape[0]),
-        int(k_input.shape[0]),
-    )
-    return out
-
-
-def test(bsz, num_head, num_kv_head, q_seq_len, k_seq_len):
-    head_dim = 128
-    q_input = np.random.normal(0, 0.5, size=(bsz, q_seq_len, num_head, head_dim))
-    k_input = np.random.normal(0, 0.5, size=(bsz, q_seq_len + k_seq_len, num_kv_head, head_dim))
-    v_input = np.random.normal(0, 0.5, size=(bsz, q_seq_len + k_seq_len, num_kv_head, head_dim))
-
-    random_len = np.random.randint(q_seq_len // 2, size=2)
-
-    text_len = random_len[0]
-    image_len = random_len[1]
-
-    mask = np.array([i + 1 for i in range(0, q_seq_len)]) + k_seq_len
-
-    mask[text_len : text_len + image_len] = text_len + image_len + k_seq_len
-
-    naive_attn_out = naive_attn(q_input, k_input, v_input, mask)
-    paddle_attn_out = paddle_flash_attn_mask(q_input, k_input, v_input, mask)
-
-    assert float((paddle_attn_out.reshape([-1]) - paddle.to_tensor(naive_attn_out).reshape([-1])).max()) <= 0.05
+class TestFlashMaskAttention(unittest.TestCase):
+    def setUp(self):
+        self.bsz = 1
+        self.num_head = 8
+        self.num_kv_head = 1
+        self.q_seq_len = 1024
+        self.k_seq_len = 1024
+        self.head_dim = 128
+        np.random.seed(self.q_seq_len)
+
+    def naive_attn(self, q_input, k_input, v_input, mask):
+        gqa_group_size = q_input.shape[2] // k_input.shape[2]
+
+        q_cur = q_input.transpose([0, 2, 1, 3])
+        k_cur = k_input.transpose([0, 2, 1, 3])
+        v_cur = v_input.transpose([0, 2, 1, 3])
+        out = np.zeros(q_cur.shape, dtype=q_input.dtype)
+
+        for bsz in range(0, q_cur.shape[0]):
+            for hi in range(0, q_cur.shape[1]):
+                qk = np.matmul(q_cur[bsz, hi], k_cur[bsz, hi // gqa_group_size].T) * (1.0 / np.sqrt(q_cur.shape[3]))
+                for i in range(0, qk.shape[0]):
+                    qk[i, mask[i] :] = -1000000
+
+                qk_max = np.expand_dims(qk.max(axis=-1), -1)
+                qk -= qk_max
+                qk = np.exp(qk)
+
+                exp_sum = np.expand_dims(qk.sum(axis=-1), -1)
+                exp_sum_inv = 1.0 / exp_sum
+
+                out[bsz, hi] = (np.matmul(qk, v_cur[bsz, hi // gqa_group_size]) * exp_sum_inv).astype(q_input.dtype)
+        return out
+
+    def paddle_flash_attn_mask(self, q_input, k_input, v_input, mask):
+        bsz = q_input.shape[0]
+        cu_seq_q = paddle.arange(bsz + 1) * q_input.shape[1]
+        cu_seq_k = paddle.arange(bsz + 1) * k_input.shape[1]
+        cu_seq_q = cu_seq_q.astype("int32")
+        cu_seq_k = cu_seq_k.astype("int32")
+        seq_len_encoder = paddle.ones(bsz) * q_input.shape[1]
+        seq_len_encoder = seq_len_encoder.astype("int32")
+        q_input = paddle.to_tensor(q_input).astype("bfloat16").reshape([-1, q_input.shape[2], q_input.shape[3]])
+        k_input = paddle.to_tensor(k_input).astype("bfloat16").reshape([-1, k_input.shape[2], k_input.shape[3]])
+        v_input = paddle.to_tensor(v_input).astype("bfloat16").reshape([-1, v_input.shape[2], v_input.shape[3]])
+        v_input_pad = paddle.zeros([v_input.shape[0] + 128, v_input.shape[1], v_input.shape[2]]).astype("bfloat16")
+        v_input_pad[0 : v_input.shape[0]] = v_input
+        mask = paddle.to_tensor(mask).astype("int32")
+
+        out = flash_attention_mask(
+            q_input,
+            k_input,
+            v_input_pad,
+            cu_seq_q,
+            cu_seq_k,
+            seq_len_encoder,
+            mask,
+            int(q_input.shape[1]),
+            int(k_input.shape[1]),
+            int(q_input.shape[2]),
+            int(k_input.shape[0]),
+            int(q_input.shape[0]),
+            int(k_input.shape[0]),
+        )
+        return out
+
+    def test_flash_attention_mask(self):
+        q_input = np.random.normal(0, 0.5, size=(self.bsz, self.q_seq_len, self.num_head, self.head_dim))
+        k_input = np.random.normal(
+            0, 0.5, size=(self.bsz, self.q_seq_len + self.k_seq_len, self.num_kv_head, self.head_dim)
+        )
+        v_input = np.random.normal(
+            0, 0.5, size=(self.bsz, self.q_seq_len + self.k_seq_len, self.num_kv_head, self.head_dim)
+        )
+
+        random_len = np.random.randint(self.q_seq_len // 2, size=2)
+        text_len = random_len[0]
+        image_len = random_len[1]
+
+        mask = np.array([i + 1 for i in range(0, self.q_seq_len)]) + self.k_seq_len
+        mask[text_len : text_len + image_len] = text_len + image_len + self.k_seq_len
+
+        naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask)
+        paddle_attn_out = self.paddle_flash_attn_mask(q_input, k_input, v_input, mask)
+
+        max_diff = float((paddle_attn_out.reshape([-1]) - paddle.to_tensor(naive_attn_out).reshape([-1])).max())
+        self.assertLessEqual(max_diff, 0.05)
 
 
 if __name__ == "__main__":
-    bsz = 1
-    num_head = 8
-    num_kv_head = 1
-    q_seq_len = 1024
-    k_seq_len = 1024
-    np.random.seed(q_seq_len)
-    test(bsz, num_head, num_kv_head, q_seq_len, k_seq_len)
+    unittest.main()
```
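
This hunk appears to be operators/test_flash_mask_attn.py from the special-tests list above. Note that `mask` here is a vector of integers rather than a dense matrix: `mask[i]` is the number of key positions visible to query `i`, and everything from that index on is forced to -1000000 before the softmax. A tiny NumPy illustration of how the test builds it, with made-up sizes (the real test uses 1024 for both lengths and draws the span lengths at random):

```python
import numpy as np

q_seq_len, k_seq_len = 4, 2  # toy sizes
text_len, image_len = 1, 2   # the real test draws these at random

# Causal visibility over the cached keys: query i sees k_seq_len + i + 1 keys.
mask = np.arange(1, q_seq_len + 1) + k_seq_len
# Queries inside the image span attend bidirectionally within it, so every
# one of them sees through to the end of the span.
mask[text_len : text_len + image_len] = text_len + image_len + k_seq_len
print(mask)  # [3 5 5 6]
```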
| @@ -1,88 +0,0 @@ | |||||||
| # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
|  |  | ||||||
| """UT for per_channel_fp8_fp8_half_gemm_fused kernel""" |  | ||||||
|  |  | ||||||
| import os |  | ||||||
| import unittest |  | ||||||
| from itertools import product |  | ||||||
|  |  | ||||||
| import numpy as np |  | ||||||
| import paddle |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Test(unittest.TestCase): |  | ||||||
|     def setUp(self): |  | ||||||
|         """ |  | ||||||
|             Initialize the test environment, |  | ||||||
|         including setting random seeds and environment variables. |  | ||||||
|         """ |  | ||||||
|         paddle.seed(2003) |  | ||||||
|         os.environ["FLAGS_use_cutlass_device_best_config_path"] = "default" |  | ||||||
|  |  | ||||||
|     def testcase1(self): |  | ||||||
|         """ |  | ||||||
|         Check if the per_channel_fp8_fp8_half_gemm_fused function works properly. |  | ||||||
|         """ |  | ||||||
|         prop = paddle.device.cuda.get_device_properties() |  | ||||||
|         cc = prop.major * 10 + prop.minor |  | ||||||
|         if cc < 89: |  | ||||||
|             self.skipTest("per_channel_fp8_fp8_half_gemm_fused only support sm89+") |  | ||||||
|  |  | ||||||
|         from fastdeploy.model_executor.ops.gpu import ( |  | ||||||
|             per_channel_fp8_fp8_half_gemm_fused, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         nks = [[2048, 2048], [2048, 5504], [6144, 2048]] |  | ||||||
|         nks = nks + [[4096, 4096], [4096, 12800], [6144, 4096]] |  | ||||||
|         nks = nks + [[5120, 5120], [5120, 13824], [15360, 5120]] |  | ||||||
|  |  | ||||||
|         m = [1, 32, 64, 128, 256, 512, 1024, 2048] |  | ||||||
|  |  | ||||||
|         combinations = list(product(m, nks)) |  | ||||||
|         for m, (n, k) in combinations: |  | ||||||
|             A_bf16 = paddle.rand(shape=[m, k], dtype="bfloat16") |  | ||||||
|             A_fp8 = paddle.cast(A_bf16, "float8_e4m3fn") |  | ||||||
|             B_bf16 = paddle.rand(shape=[n, k], dtype="bfloat16") |  | ||||||
|             B_fp8 = B_bf16.astype("float8_e4m3fn") |  | ||||||
|  |  | ||||||
|             scalar_scale = paddle.full([1], 0.5, dtype="float32") |  | ||||||
|             channel_scale = paddle.rand(shape=[n], dtype="float32") |  | ||||||
|             bias = paddle.rand(shape=[n], dtype="bfloat16") |  | ||||||
|  |  | ||||||
|             result_bf16 = paddle.matmul(A_bf16, B_bf16, transpose_y=True) * scalar_scale * channel_scale + bias |  | ||||||
|             result_fp8 = per_channel_fp8_fp8_half_gemm_fused( |  | ||||||
|                 A_fp8, |  | ||||||
|                 B_fp8, |  | ||||||
|                 bias=bias, |  | ||||||
|                 scalar_scale=scalar_scale, |  | ||||||
|                 channel_scale=channel_scale, |  | ||||||
|                 transpose_x=False, |  | ||||||
|                 transpose_y=True, |  | ||||||
|                 output_dtype="bfloat16", |  | ||||||
|             ) |  | ||||||
|             # absolute_error = paddle.abs(result_bf16 - result_fp8) |  | ||||||
|             # mean_absolute_error = paddle.mean(absolute_error) |  | ||||||
|             relative_error = paddle.abs(result_bf16 - result_fp8) / (paddle.abs(result_bf16)) |  | ||||||
|             mean_relative_error = paddle.mean(relative_error) |  | ||||||
|             np.testing.assert_allclose( |  | ||||||
|                 mean_relative_error.numpy(), |  | ||||||
|                 np.array([0.001]), |  | ||||||
|                 rtol=0.001, |  | ||||||
|                 atol=0.25, |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     unittest.main() |  | ||||||
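
The file deleted above, the per_channel_fp8_fp8_half_gemm_fused test, checked accuracy through a mean relative error rather than an elementwise comparison, which tolerates the occasional outlier that FP8 rounding produces while still bounding the average drift. The assertion pattern in isolation, with made-up values:

```python
import numpy as np

ref = np.array([1.0, 2.0, 4.0])     # e.g. the bf16 matmul result
out = np.array([1.01, 1.98, 4.05])  # e.g. the quantized kernel result

relative_error = np.abs(ref - out) / np.abs(ref)
# Passes as long as the *mean* relative error stays within atol of ~0.
np.testing.assert_allclose(relative_error.mean(), np.array([0.001]), rtol=0.001, atol=0.25)
```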
```diff
@@ -13,72 +13,81 @@
 # limitations under the License.
 
 """UT for set_stop_value"""
+import unittest
+
 import paddle
 
 from fastdeploy.model_executor.ops.gpu import get_mm_split_fuse
 
-input_ids = []
-image_type_ids = []
-grid_thw = []
-
-
-def split_grid(origin_grid_thw):
-    # Split grid_thw; this helper is for the video case
-    # origin_grid_thw = [6, 10, 12] ---> [2, 10, 12, 2, 10, 12, 2, 10, 12]
-    grid_thw = []
-    for t, h, w in origin_grid_thw:
-        if t > 2:
-            num_groups = t // 2
-            remainder = t % 2
-            for _ in range(num_groups):
-                grid_thw.extend([2, h, w])
-            if remainder > 0:
-                grid_thw.extend([remainder, h, w])
-        else:
-            grid_thw.extend([t, h, w])
-    return grid_thw
+
+class TestSplitFuse(unittest.TestCase):
+    def setUp(self):
+        self.grid_thw = [[6, 20, 20], [6, 40, 20]]
+        self.split_fuse_img_size = 16
+        self.split_fuse_text_size = 384  # 1024
+        self.max_seq_len = 2048
+        self.image_token_id = 100295
+
+    def split_grid(self, origin_grid_thw):
+        # Split grid_thw; this helper is for the video case
+        # origin_grid_thw = [6, 10, 12] ---> [2, 10, 12, 2, 10, 12, 2, 10, 12]
+        grid_thw = []
+        for t, h, w in origin_grid_thw:
+            if t > 2:
+                num_groups = t // 2
+                remainder = t % 2
+                for _ in range(num_groups):
+                    grid_thw.extend([2, h, w])
+                if remainder > 0:
+                    grid_thw.extend([remainder, h, w])
+            else:
+                grid_thw.extend([t, h, w])
+        return grid_thw
+
+    def test_get_mm_split_fuse(self):
+        grid_thw = self.split_grid(self.grid_thw)
+        image_bs = len(grid_thw) // 3
+        image_type_ids = [0] * image_bs
+
+        # concatenate input_ids: [txt0 + img1 + txt1 + img2]
+        input_ids = [2] * 19
+        img1 = [self.image_token_id] * 100 * 3
+        txt1 = [3] * 19
+        img2 = [self.image_token_id] * 200 * 3
+        input_ids.extend(img1)
+        input_ids.extend(txt1)
+        input_ids.extend(img2)
+
+        seq_len = len(input_ids)
+        input_ids_tensor = paddle.to_tensor(input_ids, dtype="int64")
+        image_type_ids_tensor = paddle.to_tensor(image_type_ids, dtype="int32")
+        is_image_token = paddle.where(input_ids_tensor == self.image_token_id, 1, 0)
+        image_token_sum = paddle.cumsum(is_image_token)  # prefix sum
+        image_token_sum = paddle.concat([paddle.zeros([1], dtype="int64"), image_token_sum])
+
+        grid_thw_tensor = paddle.to_tensor(grid_thw, dtype="int64")
+        image_chunk_selections, split_fuse_cur_seq_lens = get_mm_split_fuse(
+            input_ids_tensor.cpu(),
+            image_type_ids_tensor.cast("int32").cpu(),
+            image_token_sum.cast("int32").cpu(),
+            grid_thw_tensor.cpu(),
+            self.image_token_id,
+            image_bs,
+            0,
+            seq_len,
+            self.split_fuse_img_size,
+            self.split_fuse_text_size,
+            self.max_seq_len,
+        )
+
+        # Verify the outputs are not None
+        self.assertIsNotNone(image_chunk_selections)
+        self.assertIsNotNone(split_fuse_cur_seq_lens)
+
+        # Verify the shapes are as expected
+        self.assertEqual(len(image_chunk_selections.shape), 1)
+        self.assertEqual(len(split_fuse_cur_seq_lens.shape), 1)
 
 
 if __name__ == "__main__":
-    grid_thw = [[6, 20, 20], [6, 40, 20]]
-    grid_thw = split_grid(grid_thw)
-    image_bs = len(grid_thw) // 3
-    image_type_ids = [0] * image_bs
-    # concatenate input_ids: [txt0 + img1 + txt1 + img2]
-    input_ids = [2] * 19
-    img1 = [100295] * 100 * 3
-    txt1 = [3] * 19
-    img2 = [100295] * 200 * 3
-    input_ids.extend(img1)
-    input_ids.extend(txt1)
-    input_ids.extend(img2)
-
-    split_fuse_img_size = 16
-    split_fuse_text_size = 384  # 1024
-
-    seq_len = len(input_ids)
-    input_ids_tensor = paddle.to_tensor(input_ids, dtype="int64")
-    image_type_ids_tensor = paddle.to_tensor(image_type_ids, dtype="int32")
-    is_image_token = paddle.where(input_ids_tensor == 100295, 1, 0)
-    image_token_sum = paddle.cumsum(is_image_token)  # prefix sum
-    image_token_sum = paddle.concat([paddle.zeros([1], dtype="int64"), image_token_sum])
-
-    grid_thw_tensor = paddle.to_tensor(grid_thw, dtype="int64")
-    image_chunk_selections, split_fuse_cur_seq_lens = get_mm_split_fuse(
-        input_ids_tensor.cpu(),
-        image_type_ids_tensor.cast("int32").cpu(),
-        image_token_sum.cast("int32").cpu(),
-        grid_thw_tensor.cpu(),
-        100295,
-        image_bs,
-        0,
-        seq_len,
-        split_fuse_img_size,
-        split_fuse_text_size,
-        2048,
-    )
-
-    print("seq_len: ", seq_len)
-    print("grid_thw", grid_thw_tensor)
-    print("image_chunk_selections: ", image_chunk_selections)
-    print("split_fuse_cur_seq_lens: ", split_fuse_cur_seq_lens)
+    unittest.main()
```
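
This hunk appears to be operators/test_split_fuse.py. The prefix sum handed to `get_mm_split_fuse` lets the kernel count image tokens in any slice of `input_ids` with two lookups instead of a scan. The trick standalone, using the same 100295 image token id as the test (toy ids otherwise):

```python
import paddle

IMG = 100295  # image placeholder token id used by the test
ids = paddle.to_tensor([2, IMG, IMG, 3, IMG], dtype="int64")

is_image = paddle.where(ids == IMG, 1, 0)
prefix = paddle.concat([paddle.zeros([1], dtype="int64"), paddle.cumsum(is_image)])
print(prefix)  # [0, 0, 1, 2, 2, 3]; images in ids[a:b] == prefix[b] - prefix[a]
```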
```diff
@@ -13,40 +13,58 @@
 # limitations under the License.
 
 """UT for get_token_penalty"""
+import unittest
+
 import numpy as np
 import paddle
 
 from fastdeploy.model_executor.ops.gpu import get_token_penalty_once
 
-paddle.seed(2023)
-
-pre_ids = paddle.randint(0, 10000, (8, 1000))
-pre_ids[:, -1] = pre_ids[:, -2]
-print(pre_ids)
-logits = paddle.rand(shape=[8, 10000], dtype="float16")
-penalty_scores = np.array([1.2] * 8).astype(np.float16).reshape(-1, 1)
-penalty_scores = paddle.to_tensor(penalty_scores)
-
-print("logits[0][pre_ids[0]]: ", logits[0][pre_ids[0]])
-res = get_token_penalty_once(pre_ids, logits, penalty_scores)
-for i in range(8):
-    print(f"res[{i}]:{res[i][pre_ids[i]]}")
-
-input_ids = pre_ids
-score = paddle.index_sample(logits, input_ids)
-score = paddle.where(score < 0, score * penalty_scores, score / penalty_scores)
-
-bsz = paddle.shape(logits)[0]  # TODO: Bsz as input for inference with dynamic batch_size
-bsz_range = paddle.arange(start=bsz * 0, end=bsz, step=bsz / bsz, name="bsz_range", dtype="int64").unsqueeze(-1)
-input_ids = input_ids + bsz_range * logits.shape[-1]
-res2 = paddle.scatter(logits.flatten(), input_ids.flatten(), score.flatten()).reshape(logits.shape)
-print("-------------------------------------------")
-for i in range(8):
-    print(res2[i][pre_ids[i]])
-
-print("res_sub:")
-for i in range(8):
-    print(res2[i][pre_ids[i]] - res[i][pre_ids[i]])
-
-print((res.numpy() - res2.numpy()).sum())
+
+class TestTokenPenalty(unittest.TestCase):
+    def setUp(self):
+        paddle.seed(2023)
+        self.pre_ids = paddle.randint(0, 10000, (8, 1000))
+        self.pre_ids[:, -1] = self.pre_ids[:, -2]
+        self.logits = paddle.rand(shape=[8, 10000], dtype="float16")
+        self.penalty_scores = np.array([1.2] * 8).astype(np.float16).reshape(-1, 1)
+        self.penalty_scores = paddle.to_tensor(self.penalty_scores)
+
+    def test_token_penalty_once(self):
+        res = get_token_penalty_once(self.pre_ids, self.logits, self.penalty_scores)
+
+        # verify the output shape
+        self.assertEqual(res.shape, self.logits.shape)
+
+        # verify the penalty logic
+        for i in range(8):
+            original_values = self.logits[i][self.pre_ids[i]]
+            penalized_values = res[i][self.pre_ids[i]]
+            # check that the penalty was applied
+            for orig, penal in zip(original_values.numpy(), penalized_values.numpy()):
+                if orig < 0:
+                    self.assertLess(penal, orig, "negative values should be multiplied by the penalty factor")
+                else:
+                    self.assertLess(penal, orig, "positive values should be divided by the penalty factor")
+
+    def test_compare_with_naive_implementation(self):
+        res = get_token_penalty_once(self.pre_ids, self.logits, self.penalty_scores)
+
+        # naive reference implementation
+        score = paddle.index_sample(self.logits, self.pre_ids)
+        score = paddle.where(score < 0, score * self.penalty_scores, score / self.penalty_scores)
+
+        bsz = paddle.shape(self.logits)[0]
+        bsz_range = paddle.arange(start=bsz * 0, end=bsz, step=bsz / bsz, name="bsz_range", dtype="int64").unsqueeze(
+            -1
+        )
+        input_ids = self.pre_ids + bsz_range * self.logits.shape[-1]
+        res2 = paddle.scatter(self.logits.flatten(), input_ids.flatten(), score.flatten()).reshape(self.logits.shape)
+
+        # compare the outputs of the two implementations
+        max_diff = (res - res2).abs().max().item()
+        self.assertLess(max_diff, 1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main()
```
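
This hunk appears to be operators/test_token_penalty.py. Both new test methods pin down the rule that `get_token_penalty_once` applies to previously generated tokens: with penalty p > 1, a seen token's logit always moves toward lower likelihood, multiplied by p when negative and divided by p when positive. Stated directly:

```python
import paddle

p = 1.2
seen_logits = paddle.to_tensor([-2.0, 3.0])
penalized = paddle.where(seen_logits < 0, seen_logits * p, seen_logits / p)
print(penalized)  # [-2.4, 2.5]; both strictly below the originals
```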
```diff
@@ -1,5 +1,6 @@
 import math
 import time
+import unittest
 
 import numpy as np
 import paddle
```
| @@ -10,351 +11,331 @@ from fastdeploy.model_executor.layers.attention.ops import ( | |||||||
|     get_block_shape_and_split_kv_block, |     get_block_shape_and_split_kv_block, | ||||||
| ) | ) | ||||||
|  |  | ||||||
| paddle.seed(0) |  | ||||||
|  |  | ||||||
| max_seq_len = 32768 | class TestTreeMask(unittest.TestCase): | ||||||
| encoder_max_partition_size = max_seq_len |     def setUp(self): | ||||||
| max_partition_size = max_seq_len |         paddle.seed(0) | ||||||
|  |         self.max_seq_len = 32768 | ||||||
|  |         self.encoder_max_partition_size = self.max_seq_len | ||||||
|  |         self.max_partition_size = self.max_seq_len | ||||||
|  |  | ||||||
| max_dec_len = 1024 |         self.max_dec_len = 1024 | ||||||
| bsz = 64 |         self.bsz = 64 | ||||||
| run_time = 10 |         self.run_time = 10 | ||||||
| warm_up = 2 |         self.warm_up = 2 | ||||||
| block_size = 64 |         self.block_size = 64 | ||||||
| head_dim = 128 |         self.head_dim = 128 | ||||||
| num_q_head = 20 |         self.num_q_head = 20 | ||||||
| num_kv_head = 4 |         self.num_kv_head = 4 | ||||||
| dtype = "bfloat16" |         self.dtype = "bfloat16" | ||||||
|  |  | ||||||
| rope_3d = False |         self.rope_3d = False | ||||||
| use_neox_rotary_style = False |         self.use_neox_rotary_style = False | ||||||
| CURRENT_Q = [None] |         self.CURRENT_Q = [None] | ||||||
| TOTAL_K = [] |         self.TOTAL_K = [] | ||||||
| TOTAL_V = [] |         self.TOTAL_V = [] | ||||||
|  |  | ||||||
|  |         # Initialize cache and block tables | ||||||
|  |         block_num_per_seq = (self.max_seq_len + self.block_size - 1) // self.block_size | ||||||
|  |         max_block_num = block_num_per_seq * self.bsz | ||||||
|  |         cache_shape = ( | ||||||
|  |             max_block_num, | ||||||
|  |             self.num_kv_head, | ||||||
|  |             self.block_size, | ||||||
|  |             self.head_dim, | ||||||
|  |         ) | ||||||
|  |  | ||||||
| def split_qkv(qkv, bsz, seq_len, num_q_head, num_kv_head, head_dim): |         self.cache_k = paddle.zeros(shape=cache_shape).astype(self.dtype) | ||||||
|     # [token_num, (num_q_head + 2 * num_kv_head) * head_dim] |         self.cache_v = paddle.zeros(shape=cache_shape).astype(self.dtype) | ||||||
|     qkv = qkv.reshape([bsz, seq_len, -1, head_dim]) |  | ||||||
|     q = qkv[:, :, :num_q_head, :] |  | ||||||
|     # [bsz,  seq_len, num_q_head, head_dim] |  | ||||||
|     CURRENT_Q[0] = q |  | ||||||
|  |  | ||||||
|     # [bsz,  seq_len, num_kv_head, head_dim] |         self.block_tables = paddle.zeros(shape=(self.bsz, block_num_per_seq), dtype="int32") | ||||||
|     k = qkv[:, :, num_q_head : num_q_head + num_kv_head, :] |  | ||||||
|     TOTAL_K.append(k) |  | ||||||
|  |  | ||||||
|     # [bsz,  seq_len, num_kv_head, head_dim] |         free_list = list(range(max_block_num - 1, -1, -1)) | ||||||
|     v = qkv[:, :, num_q_head + num_kv_head :, :] |  | ||||||
|     TOTAL_V.append(v) |  | ||||||
|  |  | ||||||
|  |         for i in range(self.bsz): | ||||||
|  |             need_block_num = (self.max_seq_len + self.block_size - 1) // self.block_size | ||||||
|  |             for j in range(need_block_num): | ||||||
|  |                 block_id = free_list.pop() | ||||||
|  |                 self.block_tables[i, j] = block_id | ||||||
|  |  | ||||||
| def get_padding_offset(bsz, seq_lens_this_time, seq_lens_decoder): |     def tearDown(self): | ||||||
|     batch_id_per_token = [] |         self.CURRENT_Q = [None] | ||||||
|     cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32") |         self.TOTAL_K = [] | ||||||
|     cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32") |         self.TOTAL_V = [] | ||||||
|     cum_seq_len_q = 0 |  | ||||||
|     cum_seq_len_k = 0 |  | ||||||
|     for i in range(bsz): |  | ||||||
|         seq_len_now = seq_lens_this_time[i] |  | ||||||
|         seq_len_dec_now = seq_lens_decoder[i] |  | ||||||
|         for j in range(seq_len_now): |  | ||||||
|             batch_id_per_token.append(i) |  | ||||||
|         cum_seq_len_q += seq_len_now |  | ||||||
|         cum_seq_len_k += seq_len_now + seq_len_dec_now |  | ||||||
|         cu_seqlens_q[i + 1] = cum_seq_len_q |  | ||||||
|         cu_seqlens_k[i + 1] = cum_seq_len_k |  | ||||||
|     return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k |  | ||||||
|  |  | ||||||
|  |     def split_qkv(self, qkv, bsz, seq_len): | ||||||
|  |         qkv = qkv.reshape([bsz, seq_len, -1, self.head_dim]) | ||||||
|  |         q = qkv[:, :, : self.num_q_head, :] | ||||||
|  |         self.CURRENT_Q[0] = q | ||||||
|  |  | ||||||
| # block_table |         k = qkv[:, :, self.num_q_head : self.num_q_head + self.num_kv_head, :] | ||||||
| block_num_per_seq = (max_seq_len + block_size - 1) // block_size |         self.TOTAL_K.append(k) | ||||||
| max_block_num = block_num_per_seq * bsz |  | ||||||
| cache_shape = ( |  | ||||||
|     max_block_num, |  | ||||||
|     num_kv_head, |  | ||||||
|     block_size, |  | ||||||
|     head_dim, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| cache_k = paddle.zeros(shape=cache_shape).astype(dtype) |         v = qkv[:, :, self.num_q_head + self.num_kv_head :, :] | ||||||
| cache_v = paddle.zeros(shape=cache_shape).astype(dtype) |         self.TOTAL_V.append(v) | ||||||
|  |  | ||||||
| block_tables = paddle.zeros(shape=(bsz, block_num_per_seq), dtype="int32") |     def get_padding_offset(self, bsz, seq_lens_this_time, seq_lens_decoder): | ||||||
|  |         batch_id_per_token = [] | ||||||
|  |         cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32") | ||||||
|  |         cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32") | ||||||
|  |         cum_seq_len_q = 0 | ||||||
|  |         cum_seq_len_k = 0 | ||||||
|  |         for i in range(bsz): | ||||||
|  |             seq_len_now = seq_lens_this_time[i] | ||||||
|  |             seq_len_dec_now = seq_lens_decoder[i] | ||||||
|  |             for j in range(seq_len_now): | ||||||
|  |                 batch_id_per_token.append(i) | ||||||
|  |             cum_seq_len_q += seq_len_now | ||||||
|  |             cum_seq_len_k += seq_len_now + seq_len_dec_now | ||||||
|  |             cu_seqlens_q[i + 1] = cum_seq_len_q | ||||||
|  |             cu_seqlens_k[i + 1] = cum_seq_len_k | ||||||
|  |         return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k | ||||||
|  |  | ||||||
| free_list = list(range(max_block_num - 1, -1, -1)) |     def ref_attention(self, q, k, v, mask): | ||||||
|  |         q = q.transpose([0, 2, 1, 3]) | ||||||
|  |         if len(k) > 1: | ||||||
|  |             k = paddle.concat(k, axis=1) | ||||||
|  |         else: | ||||||
|  |             k = k[0] | ||||||
|  |         k = k.transpose([0, 2, 1, 3]) | ||||||
|  |         if len(v) > 1: | ||||||
|  |             v = paddle.concat(v, axis=1) | ||||||
|  |         else: | ||||||
|  |             v = v[0] | ||||||
|  |         v = v.transpose([0, 2, 1, 3]) | ||||||
|  |         total_len = k.shape[2] | ||||||
|  |  | ||||||
| for i in range(bsz): |         scores = ( | ||||||
|     need_block_num = (max_seq_len + block_size - 1) // block_size |             q.reshape([self.bsz, self.num_kv_head, -1, self.head_dim]) | ||||||
|     for j in range(need_block_num): |             @ k.transpose([0, 1, 3, 2]) | ||||||
|         block_id = free_list.pop() |             * (1.0 / math.sqrt(self.head_dim)) | ||||||
|         block_tables[i, j] = block_id |         ) | ||||||
|  |         scores = scores.reshape([self.bsz, self.num_q_head, -1, total_len]) | ||||||
|  |  | ||||||
|  |         if mask is not None: | ||||||
|  |             if mask.ndim == 2: | ||||||
|  |                 mask = mask.unsqueeze(0).unsqueeze(0) | ||||||
|  |             elif mask.ndim == 3: | ||||||
|  |                 mask = mask.unsqueeze(1) | ||||||
|  |             scores = paddle.add(scores, mask) | ||||||
|  |         weights = F.softmax(scores, axis=-1) | ||||||
|  |  | ||||||
| def ref_attention(q, k, v, num_q_head, num_kv_head, head_dim, mask): |         o = weights.reshape([self.bsz, self.num_kv_head, -1, total_len]) @ v | ||||||
|     q = q.transpose([0, 2, 1, 3]) |         return ( | ||||||
|     if len(k) > 1: |             o.reshape([self.bsz, self.num_q_head, -1, self.head_dim]) | ||||||
|         k = paddle.concat(k, axis=1) |             .transpose([0, 2, 1, 3]) | ||||||
|     else: |             .reshape([-1, self.num_q_head, self.head_dim]) | ||||||
|         k = k[0] |         ) | ||||||
|     k = k.transpose([0, 2, 1, 3]) |  | ||||||
|     if len(v) > 1: |  | ||||||
|         v = paddle.concat(v, axis=1) |  | ||||||
|     else: |  | ||||||
|         v = v[0] |  | ||||||
|     v = v.transpose([0, 2, 1, 3]) |  | ||||||
|     total_len = k.shape[2] |  | ||||||
|  |  | ||||||
|     scores = q.reshape([bsz, num_kv_head, -1, head_dim]) @ k.transpose([0, 1, 3, 2]) * (1.0 / math.sqrt(head_dim)) |     def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None): | ||||||
|     scores = scores.reshape([bsz, num_q_head, -1, total_len]) |         if prefill: | ||||||
|  |             seq_lens_enc = [ | ||||||
|  |                 q_len, | ||||||
|  |             ] * self.bsz | ||||||
|  |         else: | ||||||
|  |             seq_lens_enc = [ | ||||||
|  |                 0, | ||||||
|  |             ] * self.bsz | ||||||
|  |  | ||||||
|     if mask is not None: |         seq_lens_dec = [ | ||||||
|         if mask.ndim == 2: |             kv_len, | ||||||
|             mask = mask.unsqueeze(0).unsqueeze(0)  # [1,1,q_len,kv_len] |         ] * self.bsz | ||||||
|         elif mask.ndim == 3: |         seq_lens_cur = [ | ||||||
|             mask = mask.unsqueeze(1)  # [bsz,1,q_len,kv_len] |  | ||||||
|         scores = paddle.add(scores, mask) |  | ||||||
|     weights = F.softmax(scores, axis=-1) |  | ||||||
|  |  | ||||||
|     o = weights.reshape([bsz, num_kv_head, -1, total_len]) @ v |  | ||||||
|     return o.reshape([bsz, num_q_head, -1, head_dim]).transpose([0, 2, 1, 3]).reshape([-1, num_q_head, head_dim]) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def clear_param(): |  | ||||||
|     global CURRENT_Q, TOTAL_K, TOTAL_V |  | ||||||
|     CURRENT_Q = [None] |  | ||||||
|     TOTAL_K = [] |  | ||||||
|     TOTAL_V = [] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_append_c16_attention(q_len, kv_len, prefill=False, attn_mask=None): |  | ||||||
|     if prefill: |  | ||||||
|         seq_lens_enc = [ |  | ||||||
|             q_len, |             q_len, | ||||||
|         ] * bsz |         ] * self.bsz | ||||||
|     else: |         token_num = sum(seq_lens_cur) | ||||||
|         seq_lens_enc = [ |         decoder_step_token_num = 1 if prefill else q_len | ||||||
|             0, |  | ||||||
|         ] * bsz |  | ||||||
|  |  | ||||||
|     seq_lens_dec = [ |         seq_lens_encoder = paddle.to_tensor(seq_lens_enc, "int32") | ||||||
|         kv_len, |         seq_lens_this_time = paddle.to_tensor(seq_lens_cur, "int32") | ||||||
|     ] * bsz |         seq_lens_decoder = paddle.to_tensor(seq_lens_dec, "int32") | ||||||
|     seq_lens_cur = [ |  | ||||||
|         q_len, |  | ||||||
|     ] * bsz |  | ||||||
|     token_num = sum(seq_lens_cur) |  | ||||||
|     decoder_step_token_num = 1 if prefill else q_len |  | ||||||
|  |  | ||||||
|     seq_lens_encoder = paddle.to_tensor(seq_lens_enc, "int32") |         batch_id_per_token, cu_seqlens_q, cu_seqlens_k = self.get_padding_offset( | ||||||
|     seq_lens_this_time = paddle.to_tensor(seq_lens_cur, "int32") |             self.bsz, seq_lens_this_time, seq_lens_decoder | ||||||
|     seq_lens_decoder = paddle.to_tensor(seq_lens_dec, "int32") |         ) | ||||||
|  |  | ||||||
|     batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(bsz, seq_lens_this_time, seq_lens_decoder) |         qkv_varlen_shape = [token_num, (self.num_q_head + 2 * self.num_kv_head) * self.head_dim] | ||||||
|  |         rotary_embs_shape = [ | ||||||
|  |             2, | ||||||
|  |             1, | ||||||
|  |             self.max_seq_len, | ||||||
|  |             1, | ||||||
|  |             self.head_dim if self.use_neox_rotary_style else self.head_dim // 2, | ||||||
|  |         ] | ||||||
|  |  | ||||||
|     # random data |         qkv = paddle.randn(shape=qkv_varlen_shape).astype(self.dtype) | ||||||
|     qkv_varlen_shape = [token_num, (num_q_head + 2 * num_kv_head) * head_dim] |         self.split_qkv(qkv, self.bsz, q_len) | ||||||
|  |  | ||||||
|     rotary_embs_shape = [2, 1, max_seq_len, 1, head_dim if use_neox_rotary_style else head_dim // 2] |         rotary_embs = paddle.randn(shape=rotary_embs_shape).astype("float32") | ||||||
|     # qkv_bias_shape = [num_q_head + 2 * num_kv_head, head_dim] |         rotary_embs[0, :, :, :, :] = 1 | ||||||
|  |         rotary_embs[1, :, :, :, :] = 0 | ||||||
|  |  | ||||||
|     qkv = paddle.randn(shape=qkv_varlen_shape).astype(dtype) |         cache_k_scale = None | ||||||
|  |         cache_v_scale = None | ||||||
|  |         cache_k_out_scale = None | ||||||
|  |         cache_v_out_scale = None | ||||||
|  |  | ||||||
|     # save q, k, v for ref |         encoder_block_shape_q = 64 | ||||||
|     split_qkv(qkv, bsz, q_len, num_q_head, num_kv_head, head_dim) |         decoder_block_shape_q = 16 | ||||||
|  |  | ||||||
|     rotary_embs = paddle.randn(shape=rotary_embs_shape).astype("float32") |         decode_max_tile_size = ( | ||||||
|     rotary_embs[0, :, :, :, :] = 1 |             self.bsz | ||||||
|     rotary_embs[1, :, :, :, :] = 0 |             * (decoder_step_token_num * (self.num_q_head // self.num_kv_head) + decoder_block_shape_q - 1) | ||||||
|  |             / decoder_block_shape_q | ||||||
|     # qkv_scale = None |         ) | ||||||
|     # qkv_bias = None |         decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") | ||||||
|  |         decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") | ||||||
|     cache_k_scale = None |         decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory() | ||||||
|     cache_v_scale = None |         max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu() | ||||||
|     cache_k_out_scale = None |         paddle.device.synchronize() | ||||||
|     cache_v_out_scale = None |         ( | ||||||
|     # shift_bias = None |  | ||||||
|     # smooth_weight = None |  | ||||||
|  |  | ||||||
|     encoder_block_shape_q = 64 |  | ||||||
|     decoder_block_shape_q = 16 |  | ||||||
|  |  | ||||||
|     decode_max_tile_size = ( |  | ||||||
|         bsz |  | ||||||
|         * (decoder_step_token_num * (num_q_head // num_kv_head) + decoder_block_shape_q - 1) |  | ||||||
|         / decoder_block_shape_q |  | ||||||
|     ) |  | ||||||
|     decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") |  | ||||||
|     decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") |  | ||||||
|     decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory() |  | ||||||
|     max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu() |  | ||||||
|     paddle.device.synchronize() |  | ||||||
|     ( |  | ||||||
|         encoder_batch_ids, |  | ||||||
|         encoder_tile_ids_per_batch, |  | ||||||
|         encoder_num_blocks, |  | ||||||
|         kv_batch_ids, |  | ||||||
|         kv_tile_ids_per_batch, |  | ||||||
|         kv_num_blocks, |  | ||||||
|         max_len_kv, |  | ||||||
|     ) = get_block_shape_and_split_kv_block( |  | ||||||
|         seq_lens_encoder, |  | ||||||
|         seq_lens_decoder, |  | ||||||
|         seq_lens_this_time, |  | ||||||
|         decoder_batch_ids, |  | ||||||
|         decoder_tile_ids_per_batch, |  | ||||||
|         decoder_num_blocks, |  | ||||||
|         max_len_tensor_cpu, |  | ||||||
|         encoder_block_shape_q, |  | ||||||
|         decoder_block_shape_q, |  | ||||||
|         num_q_head // num_kv_head, |  | ||||||
|         block_size, |  | ||||||
|         decoder_step_token_num, |  | ||||||
|     ) |  | ||||||
|     s_time = 0 |  | ||||||
|     for i in range(run_time + warm_up): |  | ||||||
|         if i == warm_up: |  | ||||||
|             s_time = time.time() |  | ||||||
|         out = append_attention( |  | ||||||
|             qkv, |  | ||||||
|             cache_k, |  | ||||||
|             cache_v, |  | ||||||
|             seq_lens_encoder, |  | ||||||
|             seq_lens_decoder, |  | ||||||
|             seq_lens_this_time, |  | ||||||
|             batch_id_per_token, |  | ||||||
|             cu_seqlens_q, |  | ||||||
|             block_tables, |  | ||||||
|             encoder_batch_ids, |             encoder_batch_ids, | ||||||
|             encoder_tile_ids_per_batch, |             encoder_tile_ids_per_batch, | ||||||
|             encoder_num_blocks, |             encoder_num_blocks, | ||||||
|             kv_batch_ids, |             kv_batch_ids, | ||||||
|             kv_tile_ids_per_batch, |             kv_tile_ids_per_batch, | ||||||
|             kv_num_blocks, |             kv_num_blocks, | ||||||
|  |             max_len_kv, | ||||||
|  |         ) = get_block_shape_and_split_kv_block( | ||||||
|  |             seq_lens_encoder, | ||||||
|  |             seq_lens_decoder, | ||||||
|  |             seq_lens_this_time, | ||||||
|             decoder_batch_ids, |             decoder_batch_ids, | ||||||
|             decoder_tile_ids_per_batch, |             decoder_tile_ids_per_batch, | ||||||
|             decoder_num_blocks, |             decoder_num_blocks, | ||||||
|             max_len_tensor_cpu, |             max_len_tensor_cpu, | ||||||
|             max_len_kv, |             encoder_block_shape_q, | ||||||
|             rotary_embs, |             decoder_block_shape_q, | ||||||
|             attn_mask,  # attn_mask |             self.num_q_head // self.num_kv_head, | ||||||
|             None, |             self.block_size, | ||||||
|             None, |             decoder_step_token_num, | ||||||
|             cache_k_scale, |  | ||||||
|             cache_v_scale, |  | ||||||
|             cache_k_out_scale, |  | ||||||
|             cache_v_out_scale, |  | ||||||
|             None,  # cache_k_zp |  | ||||||
|             None,  # cache_v_zp |  | ||||||
|             None, |  | ||||||
|             None, |  | ||||||
|             None, |  | ||||||
|             None, |  | ||||||
|             None, |  | ||||||
|             None, |  | ||||||
|             1e-6, |  | ||||||
|             "bf16", |  | ||||||
|             "none",  # cache_quant_type |  | ||||||
|             use_neox_rotary_style, |  | ||||||
|             rope_3d, |  | ||||||
|             max_seq_len, |  | ||||||
|             0.0, |  | ||||||
|             0.0, |  | ||||||
|             -1.0,  # out_linear_in_scale |  | ||||||
|             encoder_block_shape_q,  # encoder_block_shape_q |  | ||||||
|             decoder_block_shape_q,  # decoder_block_shape_q |  | ||||||
|             max_partition_size,  # max_partition_size |  | ||||||
|             encoder_max_partition_size,  # encoder_max_partition_size |  | ||||||
|             decoder_step_token_num,  # speculate_max_draft_token_num |  | ||||||
|             True,  # causal |  | ||||||
|             decoder_step_token_num > 1,  # speculate_decoder |  | ||||||
|         ) |         ) | ||||||
|         paddle.device.synchronize() |         s_time = 0 | ||||||
|     e_time = time.time() |         for i in range(self.run_time + self.warm_up): | ||||||
|     print(f"mean infer time: {np.mean((e_time - s_time) * 1000 / run_time):.2f}") |             if i == self.warm_up: | ||||||
|     return out[0].reshape([token_num, num_q_head, head_dim]) |                 s_time = time.time() | ||||||
|  |             out = append_attention( | ||||||
|  |                 qkv, | ||||||
|  |                 self.cache_k, | ||||||
|  |                 self.cache_v, | ||||||
|  |                 seq_lens_encoder, | ||||||
|  |                 seq_lens_decoder, | ||||||
|  |                 seq_lens_this_time, | ||||||
|  |                 batch_id_per_token, | ||||||
|  |                 cu_seqlens_q, | ||||||
|  |                 self.block_tables, | ||||||
|  |                 encoder_batch_ids, | ||||||
|  |                 encoder_tile_ids_per_batch, | ||||||
|  |                 encoder_num_blocks, | ||||||
|  |                 kv_batch_ids, | ||||||
|  |                 kv_tile_ids_per_batch, | ||||||
|  |                 kv_num_blocks, | ||||||
|  |                 decoder_batch_ids, | ||||||
|  |                 decoder_tile_ids_per_batch, | ||||||
|  |                 decoder_num_blocks, | ||||||
|  |                 max_len_tensor_cpu, | ||||||
|  |                 max_len_kv, | ||||||
|  |                 rotary_embs, | ||||||
|  |                 attn_mask, | ||||||
|  |                 None, | ||||||
|  |                 None, | ||||||
|  |                 cache_k_scale, | ||||||
|  |                 cache_v_scale, | ||||||
|  |                 cache_k_out_scale, | ||||||
|  |                 cache_v_out_scale, | ||||||
|  |                 None, | ||||||
|  |                 None, | ||||||
|  |                 None, | ||||||
|  |                 None, | ||||||
|  |                 None, | ||||||
|  |                 None, | ||||||
|  |                 None, | ||||||
|  |                 None, | ||||||
|  |                 1e-6, | ||||||
|  |                 "bf16", | ||||||
|  |                 "none", | ||||||
|  |                 self.use_neox_rotary_style, | ||||||
|  |                 self.rope_3d, | ||||||
|  |                 self.max_seq_len, | ||||||
|  |                 0.0, | ||||||
|  |                 0.0, | ||||||
|  |                 -1.0, | ||||||
|  |                 encoder_block_shape_q, | ||||||
|  |                 decoder_block_shape_q, | ||||||
|  |                 self.max_partition_size, | ||||||
|  |                 self.encoder_max_partition_size, | ||||||
|  |                 decoder_step_token_num, | ||||||
|  |                 True, | ||||||
|  |                 decoder_step_token_num > 1, | ||||||
|  |             ) | ||||||
|  |             paddle.device.synchronize() | ||||||
|  |         e_time = time.time() | ||||||
|  |         print(f"mean infer time: {np.mean((e_time - s_time) * 1000 / self.run_time):.2f}") | ||||||
|  |         return out[0].reshape([token_num, self.num_q_head, self.head_dim]) | ||||||
|  |  | ||||||
|  |     def test_naive_speculative_decoding(self): | ||||||
|  |         prefill_len = 8192 | ||||||
|  |         dec_len_q = 5 | ||||||
|  |         total_len = prefill_len + dec_len_q | ||||||
|  |         mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len) | ||||||
|  |         mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf"))) | ||||||
|  |         self.run_append_c16_attention(prefill_len, 0, True) | ||||||
|  |         dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False) | ||||||
|  |  | ||||||
| def test_naive_speculative_decoding(num_q_head, num_kv_head, head_dim): |         ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask) | ||||||
|     prefill_len = 8192 |         np.testing.assert_allclose( | ||||||
|     dec_len_q = 5 |             ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03 | ||||||
|     total_len = prefill_len + dec_len_q |         ) | ||||||
|     mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len) |  | ||||||
|     mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf"))) |  | ||||||
|     test_append_c16_attention(prefill_len, 0, True) |  | ||||||
|     dec_out = test_append_c16_attention(dec_len_q, prefill_len, False) |  | ||||||
|  |  | ||||||
|     ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask) |     def test_mask(self): | ||||||
|     np.testing.assert_allclose( |         prefill_len = 8192 | ||||||
-        ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
-    )
-
-
-def test_mask(num_q_head, num_kv_head, head_dim):
-    prefill_len = 8192
-    dec_len_q = 5
-    total_len = prefill_len + dec_len_q
-    mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
-    mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
-
-    mask_append_attn = mask[:, :, prefill_len:]
-    mask_append_attn = paddle.where(
-        mask_append_attn == 1,
-        paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
-        paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
-    )
-
-    test_append_c16_attention(prefill_len, 0, True)
-    dec_out = test_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
-
-    ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask_ref)
-
-    np.testing.assert_allclose(
-        ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
-    )
-
-
-def test_tree_mask(num_q_head, num_kv_head, head_dim):
-    # tree
-    #       [N,   N+1,    N+1,    N+2,    N+2]
-    # N     [0,   -inf,   -inf,   -inf,   -inf]
-    # N+1   [0,   0,      -inf,   -inf,   -inf]
-    # N+1   [0,   -inf,   0,      -inf,   -inf]
-    # N+2   [0,   0,      -inf,   0,      -inf]
-    # N+2   [0,   -inf,   0,      -inf,   0]
-    prefill_len = 8192
-    dec_len_q = 5
-    total_len = prefill_len + dec_len_q
-    mask = paddle.tril(paddle.ones((bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
-    mask[:, 2, prefill_len + 1] = 0
-    mask[:, 3, prefill_len + 2] = 0
-    mask[:, 4, prefill_len + 1] = 0
-    mask[:, 4, prefill_len + 3] = 0
-
-    mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
-
-    mask_append_attn = mask[:, :, prefill_len:]
-    mask_append_attn = paddle.where(
-        mask_append_attn == 1,
-        paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
-        paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
-    )
-
-    test_append_c16_attention(prefill_len, 0, True)
-    dec_out = test_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
-    ref_out = ref_attention(CURRENT_Q[0], TOTAL_K, TOTAL_V, num_q_head, num_kv_head, head_dim, mask_ref)
-    np.testing.assert_allclose(
-        ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
-    )
+        dec_len_q = 5
+        total_len = prefill_len + dec_len_q
+        mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
+        mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
+
+        mask_append_attn = mask[:, :, prefill_len:]
+        mask_append_attn = paddle.where(
+            mask_append_attn == 1,
+            paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
+            paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
+        )
+
+        self.run_append_c16_attention(prefill_len, 0, True)
+        dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
+
+        ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask_ref)
+
+        np.testing.assert_allclose(
+            ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
+        )
+
+    def test_tree_mask(self):
+        prefill_len = 8192
+        dec_len_q = 5
+        total_len = prefill_len + dec_len_q
+        mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
+        mask[:, 2, prefill_len + 1] = 0
+        mask[:, 3, prefill_len + 2] = 0
+        mask[:, 4, prefill_len + 1] = 0
+        mask[:, 4, prefill_len + 3] = 0
+
+        mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
+
+        mask_append_attn = mask[:, :, prefill_len:]
+        mask_append_attn = paddle.where(
+            mask_append_attn == 1,
+            paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
+            paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
+        )
+
+        self.run_append_c16_attention(prefill_len, 0, True)
+        dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
+        ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask_ref)
+        np.testing.assert_allclose(
+            ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
+        )
 
 
 if __name__ == "__main__":
-
-    test_naive_speculative_decoding(num_q_head, num_kv_head, head_dim)
-    clear_param()
-
-    test_mask(num_q_head, num_kv_head, head_dim)
-    clear_param()
-
-    test_tree_mask(num_q_head, num_kv_head, head_dim)
+    unittest.main()
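The test_mask and test_tree_mask cases above check the append-attention kernel against a plain PaddlePaddle reference. Both build one mask and feed it in two conventions: the reference path gets an additive mask (0 keeps a position, -inf blocks it), while run_append_c16_attention gets the same information as booleans with True marking a blocked position. test_tree_mask encodes a small speculative-decoding draft tree: one draft token at step N, two alternatives at N+1, and two at N+2, each row attending to the full prefill plus its own ancestors only. A minimal standalone sketch of the 5x5 draft block of that mask (not part of the test; prefill length shrunk to 0 and bsz = 1 purely for illustration, with 1 marking a position the row may attend to):

import paddle

bsz, dec_len_q = 1, 5
# Causal base: draft token i attends to draft tokens 0..i.
mask = paddle.tril(paddle.ones((bsz, dec_len_q, dec_len_q), dtype="float32"))
# Cut the links that cross branches of the tree, as in the test above.
mask[:, 2, 1] = 0  # second N+1 candidate does not see the first N+1 candidate
mask[:, 3, 2] = 0  # first N+2 candidate descends from the first N+1 branch
mask[:, 4, 1] = 0  # second N+2 candidate descends from the second N+1 branch
mask[:, 4, 3] = 0  # ...and does not see the first N+2 candidate
print(mask[0])
# [[1, 0, 0, 0, 0],
#  [1, 1, 0, 0, 0],
#  [1, 0, 1, 0, 0],
#  [1, 1, 0, 1, 0],
#  [1, 0, 1, 0, 1]]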
@@ -12,92 +12,97 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import unittest
+
 import numpy as np
 import paddle
 
 from fastdeploy.model_executor.ops.gpu import w4afp8_gemm, w4afp8_gemm_weight_convert
 
 
-def w4afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_dequant_scale, BATCH, N):
-    all_tokens = int(tokens.sum())
-    out = paddle.zeros([all_tokens, N], dtype="bfloat16")
-    pre_fix_token = 0
-    for i in range(BATCH):
-        input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :]
-        weight = (weight_quant[i] - 7.0) * weight_dequant_scale[i]
-        out_i = paddle.matmul(input, weight.astype("bfloat16"), transpose_y=True)
-        out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i
-        pre_fix_token += tokens[i]
-    return out
-
-
-def permute_scale(weight_scale):
-    weight_scale = weight_scale.reshape([BATCH, N])
-    temp = paddle.zeros([16])
-    for b in range(BATCH):
-        for n in range(0, N, 16):
-            temp[:] = weight_scale[b, n : n + 16]
-            for j in range(0, 16, 2):
-                weight_scale[b, n + j] = temp[j // 2]
-                weight_scale[b, n + j + 1] = temp[j // 2 + 8]
-    return weight_scale
-
-
-paddle.seed(0)
-tokens_per_group = 256
-N = 256
-K = 256
-BATCH = 1
-TokenPadding = 0
-
-tokens = [tokens_per_group] * BATCH
-tokens_perfix_sum = np.cumsum(tokens)
-
-
-tokens = paddle.to_tensor(tokens, dtype="int64")
-tokens_perfix_sum = paddle.to_tensor(tokens_perfix_sum, dtype="int64")
-
-all_tokens = int(tokens.sum())
-
-input_fp8 = paddle.randn([all_tokens, K], dtype="bfloat16").astype(paddle.float8_e4m3fn)
-input_bf16 = input_fp8.astype("bfloat16")
-weight = paddle.randn([BATCH, N, K], dtype="bfloat16") / 10
-
-weight_scale = 7 / weight.abs().max(axis=-1).reshape([BATCH, N, 1])
-weight_quant = (weight * weight_scale).astype("int") + 7
-weight_quant = paddle.clip(weight_quant, 0, 14)
-weight_quant = weight_quant.astype("bfloat16")
-weight_dequant_scale = 1 / weight_scale.astype("float32")
-input_row_sum = input_bf16.sum(axis=1) * -7 / 512
-max_tokens = int(tokens.max())
-
-out_naive = w4afp8_gemm_naive(input_bf16, weight_quant, tokens, weight_dequant_scale, BATCH, N)
-weight_dequant_scale = paddle.to_tensor(permute_scale(weight_dequant_scale) * 512)
-
-weight_int4 = w4afp8_gemm_weight_convert(weight_quant.astype("uint8").cpu())
-
-if TokenPadding == 0:
-    out_cuda = w4afp8_gemm(
-        input_fp8,
-        weight_int4.cuda(),
-        tokens_perfix_sum,
-        input_row_sum.astype("float32"),
-        weight_dequant_scale.astype("float32"),
-        int(TokenPadding),
-        max_tokens,
-        True,
-    )
-else:
-    out_cuda = w4afp8_gemm(
-        input_fp8,
-        weight_int4.cuda(),
-        tokens,
-        input_row_sum.astype("float32"),
-        weight_dequant_scale.astype("float32"),
-        int(TokenPadding),
-        max_tokens,
-        True,
-    )
-
-gap = (out_cuda - out_naive).abs()
-assert float(gap.mean()) < 0.07
+class TestW4AFP8GEMM(unittest.TestCase):
+    def setUp(self):
+        paddle.seed(0)
+        self.tokens_per_group = 256
+        self.N = 256
+        self.K = 256
+        self.BATCH = 1
+        self.TokenPadding = 0
+
+        tokens = [self.tokens_per_group] * self.BATCH
+        self.tokens_perfix_sum = np.cumsum(tokens)
+
+        self.tokens = paddle.to_tensor(tokens, dtype="int64")
+        self.tokens_perfix_sum = paddle.to_tensor(self.tokens_perfix_sum, dtype="int64")
+        self.all_tokens = int(self.tokens.sum())
+
+        self.input_fp8 = paddle.randn([self.all_tokens, self.K], dtype="bfloat16").astype(paddle.float8_e4m3fn)
+        self.input_bf16 = self.input_fp8.astype("bfloat16")
+        self.weight = paddle.randn([self.BATCH, self.N, self.K], dtype="bfloat16") / 10
+
+        self.weight_scale = 7 / self.weight.abs().max(axis=-1).reshape([self.BATCH, self.N, 1])
+        self.weight_quant = (self.weight * self.weight_scale).astype("int") + 7
+        self.weight_quant = paddle.clip(self.weight_quant, 0, 14)
+        self.weight_quant = self.weight_quant.astype("bfloat16")
+        self.weight_dequant_scale = 1 / self.weight_scale.astype("float32")
+        self.input_row_sum = self.input_bf16.sum(axis=1) * -7 / 512
+        self.max_tokens = int(self.tokens.max())
+
+    def w4afp8_gemm_naive(self, input_bf16, weight_quant, tokens, weight_dequant_scale):
+        all_tokens = int(tokens.sum())
+        out = paddle.zeros([all_tokens, self.N], dtype="bfloat16")
+        pre_fix_token = 0
+        for i in range(self.BATCH):
+            input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :]
+            weight = (weight_quant[i] - 7.0) * weight_dequant_scale[i]
+            out_i = paddle.matmul(input, weight.astype("bfloat16"), transpose_y=True)
+            out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i
+            pre_fix_token += tokens[i]
+        return out
+
+    def permute_scale(self, weight_scale):
+        weight_scale = weight_scale.reshape([self.BATCH, self.N])
+        temp = paddle.zeros([16])
+        for b in range(self.BATCH):
+            for n in range(0, self.N, 16):
+                temp[:] = weight_scale[b, n : n + 16]
+                for j in range(0, 16, 2):
+                    weight_scale[b, n + j] = temp[j // 2]
+                    weight_scale[b, n + j + 1] = temp[j // 2 + 8]
+        return weight_scale
+
+    def test_w4afp8_gemm(self):
+        out_naive = self.w4afp8_gemm_naive(self.input_bf16, self.weight_quant, self.tokens, self.weight_dequant_scale)
+
+        weight_dequant_scale = paddle.to_tensor(self.permute_scale(self.weight_dequant_scale) * 512)
+        weight_int4 = w4afp8_gemm_weight_convert(self.weight_quant.astype("uint8").cpu())
+
+        if self.TokenPadding == 0:
+            out_cuda = w4afp8_gemm(
+                self.input_fp8,
+                weight_int4.cuda(),
+                self.tokens_perfix_sum,
+                self.input_row_sum.astype("float32"),
+                weight_dequant_scale.astype("float32"),
+                int(self.TokenPadding),
+                self.max_tokens,
+                True,
+            )
+        else:
+            out_cuda = w4afp8_gemm(
+                self.input_fp8,
+                weight_int4.cuda(),
+                self.tokens,
+                self.input_row_sum.astype("float32"),
+                weight_dequant_scale.astype("float32"),
+                int(self.TokenPadding),
+                self.max_tokens,
+                True,
+            )
+
+        gap = (out_cuda - out_naive).abs()
+        self.assertLess(float(gap.mean()), 0.07)
+
+
+if __name__ == "__main__":
+    unittest.main()
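For reference, the weight path in the test above is a 15-level 4-bit scheme: each output channel is scaled so its largest magnitude maps to 7, cast to int (astype("int") truncates toward zero rather than rounding), shifted by the zero-point 7, and clipped to [0, 14]; the naive GEMM dequantizes with (w_q - 7) * weight_dequant_scale. A numpy-only sketch of that round trip (variable names here are illustrative, not from the test):

import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal(256).astype(np.float32) / 10  # one output channel

scale = 7.0 / np.abs(w).max()                         # per-channel scale, as in setUp
w_q = np.clip(np.trunc(w * scale) + 7.0, 0.0, 14.0)   # quantize: codes in [0, 14]
w_dq = (w_q - 7.0) / scale                            # dequantize, as in w4afp8_gemm_naive

# Truncation keeps the reconstruction within one quantization step.
assert np.abs(w - w_dq).max() < 1.0 / scale + 1e-6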
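The input_row_sum term looks like the usual zero-point correction, under our reading of the test (the kernel itself is not part of this diff): with zero-point 7, (q - 7) contracted with x equals q contracted with x minus 7 * sum(x), so a kernel that multiplies the raw unsigned codes can add one per-row correction instead of subtracting 7 from every weight; the -7/512 here appears to pair with the * 512 applied to the permuted scales. A quick numpy check of the identity:

import numpy as np

rng = np.random.default_rng(1)
q = rng.integers(0, 15, size=(4, 8)).astype(np.float64)  # 4-bit codes, zero-point 7
x = rng.standard_normal((3, 8))                          # activations

lhs = x @ (q - 7.0).T                                    # GEMM on dequantized codes
rhs = x @ q.T - 7.0 * x.sum(axis=1, keepdims=True)       # raw GEMM + row correction
assert np.allclose(lhs, rhs)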
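Finally, permute_scale reorders the per-channel dequant scales in groups of 16, interleaving the first and second halves of each group; presumably this mirrors the output layout of the w4afp8 kernel's accumulator tiles (an assumption, since the kernel source is not shown here). The pattern it produces on indices 0..15:

import numpy as np

idx = np.arange(16)
permuted = np.empty_like(idx)
permuted[0::2] = idx[:8]   # even slots take the first half of the group
permuted[1::2] = idx[8:]   # odd slots take the second half
print(permuted)  # [ 0  8  1  9  2 10  3 11  4 12  5 13  6 14  7 15]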