diff --git a/tests/layers/test_attention_layer.py b/tests/layers/test_attention_layer.py
index 91bd43eb6..106cb93cd 100644
--- a/tests/layers/test_attention_layer.py
+++ b/tests/layers/test_attention_layer.py
@@ -90,7 +90,6 @@ class TestAttentionPerformance(unittest.TestCase):
             self.attention_layer[i] = Ernie4_5_Attention(self.fd_config, layer_id=i, prefix="test_layer")
             state_dict = self.create_random_attention_state_dict(self.fd_config, prefix="test_layer")
             self.attention_layer[i].load_state_dict(state_dict)
-            self.attention_layer[i].attn.cache_quant_type_str = "block_wise_fp8"
 
         def attn_forward(forward_meta, hidden_states):
             for i in range(num_layers):
@@ -100,6 +99,8 @@ class TestAttentionPerformance(unittest.TestCase):
 
         self.attn_forward = attn_forward
 
+        self.cache_quant_type_str = getattr(self.attention_layer[0].attn, "cache_quant_type_str", "none")
+
         print("===== Initialization Complete =====")
 
     def tearDown(self):
@@ -119,9 +120,10 @@ class TestAttentionPerformance(unittest.TestCase):
         config_dict = {
             "architectures": ["Ernie4_5_MoeForCausalLM"],
             "dtype": "bfloat16",
-            "hidden_size": 4096,
             "max_position_embeddings": 131072,
-            "max_model_len": 36 * 1024 + 1024,
+            "max_model_len": 131072,
+            "head_dim": 128,
+            "hidden_size": 4096,
             "num_attention_heads": 32,
             "num_key_value_heads": 4,
             "num_hidden_layers": 57,
@@ -153,7 +155,9 @@ class TestAttentionPerformance(unittest.TestCase):
             scheduler_config=SchedulerConfig({}),
             load_config=LoadConfig({}),
             quant_config=MixQuantConfig(
-                dense_quant_type="block_wise_fp8", moe_quant_type="block_wise_fp8", kv_cache_quant_type="float8_e4m3fn"
+                dense_quant_type="block_wise_fp8",
+                moe_quant_type="block_wise_fp8",
+                kv_cache_quant_type="float8_e4m3fn",
             ),
             graph_opt_config=GraphOptimizationConfig({}),
             commit_config=CommitConfig(),
@@ -202,7 +206,7 @@ class TestAttentionPerformance(unittest.TestCase):
         mode: ForwardMode,
         fd_config: FDConfig,
         attn_backend: AttentionBackend,
-        use_dynamic_quant: bool = False,
+        cache_quant_type_str: str = "none",
     ) -> ForwardMeta:
         """
        Creates a high-fidelity ForwardMeta object.
@@ -231,30 +235,31 @@ class TestAttentionPerformance(unittest.TestCase):
         block_size = fd_config.cache_config.block_size
         max_model_len = fd_config.model_config.max_model_len
-        num_blocks_per_seq = (max_model_len + block_size - 1) // block_size
-        num_blocks = num_blocks_per_seq * batch_size
+        max_blocks_per_seq = (max_model_len + block_size - 1) // block_size
+        allocated_blocks_per_seq = seq_len // block_size + 1
+        allocated_num_blocks = allocated_blocks_per_seq * batch_size
         head_dim = fd_config.model_config.head_dim
         kv_num_heads_tp = fd_config.model_config.num_key_value_heads // fd_config.parallel_config.tensor_parallel_size
         num_layers = fd_config.model_config.num_hidden_layers
 
         cache_type = fd_config.model_config.dtype
-        if use_dynamic_quant:
+        if cache_quant_type_str != "none":
             cache_type = "uint8"
-        cache_shape = (num_blocks, kv_num_heads_tp, block_size, head_dim)
-        scale_shape = (num_blocks, kv_num_heads_tp, block_size)
+        cache_shape = (allocated_num_blocks, kv_num_heads_tp, block_size, head_dim)
+        scale_shape = (allocated_num_blocks, kv_num_heads_tp, block_size)
         caches = []
         for _ in range(num_layers):
             key_cache = paddle.randint(0, 255, shape=cache_shape, dtype="int32").cast(cache_type)
             value_cache = paddle.randint(0, 255, shape=cache_shape, dtype="int32").cast(cache_type)
             caches.extend([key_cache, value_cache])
-            if use_dynamic_quant:
+            if cache_quant_type_str == "block_wise_fp8":
                 key_cache_scale = paddle.rand(shape=scale_shape, dtype=fd_config.model_config.dtype)
                 value_cache_scale = paddle.rand(shape=scale_shape, dtype=fd_config.model_config.dtype)
                 caches.extend([key_cache_scale, value_cache_scale])
 
-        block_tables = paddle.zeros(shape=(batch_size, num_blocks_per_seq), dtype="int32")
+        block_tables = paddle.zeros(shape=(batch_size, max_blocks_per_seq), dtype="int32")
         for i in range(batch_size):
-            for j in range(num_blocks_per_seq):
-                block_tables[i, j] = i * num_blocks_per_seq + j
+            for j in range(allocated_blocks_per_seq):
+                block_tables[i, j] = i * allocated_blocks_per_seq + j
 
         tmp_position_ids = paddle.arange(fd_config.model_config.max_model_len).reshape((1, -1))
         rope_emb = get_rope(
@@ -294,7 +299,6 @@ class TestAttentionPerformance(unittest.TestCase):
     def test_decode_performance_with_prefill(self):
         # Test parameters
         test_steps = 100
-        use_dynamic_quant = True
         act_tensor_dtype = paddle.bfloat16
 
         # prefill_batch_size = 1
@@ -311,7 +315,7 @@ class TestAttentionPerformance(unittest.TestCase):
         #     mode=ForwardMode.EXTEND,
         #     fd_config=self.fd_config,
         #     attn_backend=self.attn_backend,
-        #     use_dynamic_quant=use_dynamic_quant,
+        #     cache_quant_type_str=self.cache_quant_type_str,
         # )
 
         # self.attn_backend.init_attention_metadata(forward_meta)
@@ -339,6 +343,7 @@ class TestAttentionPerformance(unittest.TestCase):
         # times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:]
         # print(times[-5:])
 
+        # return
 
         # p.stop()
 
@@ -361,7 +366,7 @@ class TestAttentionPerformance(unittest.TestCase):
             mode=ForwardMode.DECODE,
             fd_config=self.fd_config,
             attn_backend=self.attn_backend,
-            use_dynamic_quant=use_dynamic_quant,
+            cache_quant_type_str=self.cache_quant_type_str,
         )
 
         self.attn_backend.init_attention_metadata(forward_meta)