supports dynamic Cfp8 (#3767)

* supports dynamic Cfp8

* add unittest
Author: lzy
Date: 2025-09-08 11:41:29 +08:00 (committed by GitHub)
Parent: b5e20e3015
Commit: af49b81ffd
20 changed files with 1417 additions and 225 deletions
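For orientation (this note is not part of the diff): "dynamic Cfp8" refers to the block_wise_fp8 cache_quant_type exercised by the new test, i.e. the KV cache is stored as FP8 values (uint8 storage in the test) together with scales computed dynamically from the data rather than calibrated offline. The sketch below only illustrates the general idea; the per-token-per-head scale granularity (inferred from the (max_block_num, kv_num_head, blocksize) cache-scale shape allocated in the test), the E4M3 format, and the rounding simulation are assumptions, not the kernel's actual implementation.

import numpy as np

FP8_E4M3_MAX = 448.0  # assumed FP8 target format (E4M3); the actual kernel format may differ


def fake_fp8_round(x):
    # Crude simulation of 3-bit-mantissa rounding; ignores subnormals and NaN handling.
    mantissa, exponent = np.frexp(x)
    return np.ldexp(np.round(mantissa * 16.0) / 16.0, exponent)


def quant_dequant_per_token(x):
    # x: [num_tokens, dim_head] slab of one KV head.
    # One dynamic scale per cached token per head is assumed, matching the
    # (max_block_num, kv_num_head, blocksize) cache-scale tensors in the test.
    scale = np.abs(x).max(axis=-1, keepdims=True) / FP8_E4M3_MAX
    scale = np.maximum(scale, 1e-12)  # avoid division by zero for all-zero rows
    q = fake_fp8_round(np.clip(x / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX))
    return q * scale, scale.squeeze(-1)  # dequantized values and the stored scales


x = np.random.randn(64, 128).astype("float32")
x_dq, _ = quant_dequant_per_token(x)
print("max abs quant/dequant error:", np.abs(x - x_dq).max())

The real kernel additionally packs the quantized values into the uint8 cache blocks and consumes the scales inside the attention computation; the round trip here is only meant to show why a dynamically chosen per-token scale keeps the relative error bounded.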

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import time
 import unittest
@@ -20,6 +21,7 @@ import paddle
 from paddle.incubate.nn.functional import fused_rms_norm
 paddle.seed(10)
 np.random.seed(10)
 class RopeEmbedding:
@@ -334,7 +336,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
         self.name = "TestAppendGroupQueryAttnWithRope"
         self.place = paddle.CUDAPlace(0)
         self.batch_size = 1
-        self.q_num_head = 12
+        self.q_num_head = 16
         self.kv_num_head = 2
         self.seq_len = 64
         self.max_dec_len = 64
@@ -347,9 +349,10 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
         self.max_seq_len = self.seq_len + self.max_dec_len
         self.softmax_scale = self.dim_head**-0.5
         self.rope_theta = 10000
-        self.dtype = "float16"
+        self.dtype = "bfloat16"
         self.use_qk_norm = True
         self.use_mask_offset = False
+        self.use_dynamic_quant = False
         self.init_tensor()

     def init_tensor(self):
@@ -391,8 +394,23 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
         )
         self.scale = 1.0 / np.sqrt(self.dim_head)
-        self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
-        self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
+        if self.use_dynamic_quant:
+            self.cache_scale_shape = (
+                self.max_block_num,
+                self.kv_num_head,
+                self.blocksize,
+            )
+            self.cache_k = paddle.zeros(shape=self.cache_shape, dtype="uint8")
+            self.cache_v = paddle.zeros(shape=self.cache_shape, dtype="uint8")
+            self.cache_k_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
+            self.cache_v_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
+            self.key_cache_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype)
+            self.value_cache_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype)
+        else:
+            self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
+            self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
+            self.key_cache_scale = None
+            self.value_cache_scale = None
         self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32")
         for i in range(self.batch_size):
             need_block_num = (self.seq_len + self.max_dec_len + self.blocksize - 1) // self.blocksize
@@ -415,6 +433,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
     def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None):
         paddle.disable_static()
+        print("use_dynamic_quant: ", self.use_dynamic_quant)
         self.token_num = self.seq_len * self.batch_size
         q, k, v, qkv = get_qkv_and_qkv_concat_tensor(
             self.batch_size,
@@ -472,18 +491,17 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
             self.blocksize,
             speculate_max_draft_token_num + 1,
         )
+        if self.use_dynamic_quant:
+            cache_quant_type = "block_wise_fp8"
+        else:
+            cache_quant_type = "none"
-        # Warm up
-        WARM_UP = 1
-        RUN_TIME = 2
-        for i in range(WARM_UP + RUN_TIME):
-            if i == WARM_UP:
-                paddle.device.synchronize()
-                start_time = time.time()
-            out = append_attention(
-                qkv,
-                self.cache_k,
-                self.cache_v,
+        if self.use_dynamic_quant:
+            qkv_copy = copy.deepcopy(qkv)
+            append_attention(
+                qkv_copy,
+                self.cache_k_T,
+                self.cache_v_T,
                 self.seq_lens_encoder,
                 self.seq_lens_decoder,
                 self.seq_lens_this_time,
@@ -519,7 +537,69 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
                 k_norm_weight, # k_norm_weight
                 1e-6,
                 "fp16",
-                "none", # cache_quant_type
+                "none",
+                self.use_neox_rotary_style,
+                False,
+                self.max_seq_len,
+                0.0, # quant_min_bound
+                0.0, # quant_max_bound
+                -1, # out_linear_in_scale
+                64, # encoder_block_shape_q
+                16, # decoder_block_shape_q
+                32768, # max_partition_size
+                32768, # encoder_max_partition_size
+                speculate_max_draft_token_num + 1, # speculate_max_draft_token_num
+                True, # causal
+                False, # speculate_decoder
+            )
+
+        # Warm up
+        WARM_UP = 1
+        RUN_TIME = 2
+        for i in range(WARM_UP + RUN_TIME):
+            if i == WARM_UP:
+                paddle.device.synchronize()
+                start_time = time.time()
+            out = append_attention(
+                qkv,
+                self.cache_k,
+                self.cache_v,
+                self.seq_lens_encoder,
+                self.seq_lens_decoder,
+                self.seq_lens_this_time,
+                self.padding_offset,
+                self.cum_offset,
+                self.block_tables,
+                encoder_batch_ids,
+                encoder_tile_ids_per_batch,
+                encoder_num_blocks,
+                kv_batch_ids,
+                kv_tile_ids_per_batch,
+                kv_num_blocks,
+                self.decoder_batch_ids,
+                self.decoder_tile_ids_per_batch,
+                self.decoder_num_blocks_cpu,
+                self.max_len_tensor_cpu,
+                max_len_kv,
+                self.rope_emb, # rope_emb
+                None, # attn_mask
+                None, # qkv_bias
+                None, # qkv_out_scales
+                self.key_cache_scale, # cache_k_quant_scales
+                self.value_cache_scale, # cache_v_quant_scales
+                None, # cache_k_dequant_scales
+                None, # cache_v_dequant_scales
+                None, # cache_k_zp
+                None, # cache_v_zp
+                None, # linear_shift
+                None, # linear_smooth
+                self.mask_offset, # mask_offset
+                None, # kv_signal_data
+                q_norm_weight, # q_norm_weight
+                k_norm_weight, # k_norm_weight
+                1e-6,
+                "fp16",
+                cache_quant_type,
                 self.use_neox_rotary_style,
                 False,
                 self.max_seq_len,
@@ -537,13 +617,6 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
         paddle.device.synchronize()
         end_time = time.time()
         print(f"[append-attn ut] cost_time:{(end_time - start_time) / RUN_TIME * 1000}ms")
-        naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
-            self.cache_k,
-            self.cache_v,
-            self.batch_size,
-            self.block_tables,
-            self.seq_len,
-        )
         np.testing.assert_allclose(
             out.numpy(),
             out_.numpy(),
@@ -572,13 +645,22 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
         if self.use_mask_offset:
             print("encoder mask_offset: ", self.mask_offset)
         self.cmp_append_attention(attn_mask=self.attention_mask)
-        naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
-            self.cache_k,
-            self.cache_v,
-            self.batch_size,
-            self.block_tables,
-            self.seq_len,
-        )
+        if self.use_dynamic_quant:
+            naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
+                self.cache_k_T,
+                self.cache_v_T,
+                self.batch_size,
+                self.block_tables,
+                self.seq_len,
+            )
+        else:
+            naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
+                self.cache_k,
+                self.cache_v,
+                self.batch_size,
+                self.block_tables,
+                self.seq_len,
+            )
         # decoder
         self.seq_lens_decoder[:] = self.seq_lens_encoder
         self.seq_lens_encoder[:] = 0
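The hunk above routes the reference comparison through block_cache_to_naive_cache, whose body is not part of this diff. As rough orientation only, a helper of that kind plausibly gathers the paged blocks back into a contiguous cache, as in the sketch below; the [max_block_num, kv_num_head, blocksize, dim_head] block layout, the numpy types, and the function name gather_block_cache are assumptions.

import numpy as np


def gather_block_cache(cache, block_tables, batch_size, seq_len):
    # cache: assumed layout [max_block_num, kv_num_head, blocksize, dim_head] (numpy array).
    # block_tables: [batch_size, block_num_per_seq] mapping logical block index to physical block id.
    _, kv_num_head, blocksize, dim_head = cache.shape
    out = np.zeros((batch_size, kv_num_head, seq_len, dim_head), dtype=cache.dtype)
    for b in range(batch_size):
        for pos in range(seq_len):
            block_id = block_tables[b, pos // blocksize]  # physical block holding this token
            out[b, :, pos, :] = cache[block_id, :, pos % blocksize, :]
    return out

In the test such a helper is called once for the key cache and once for the value cache, matching the paired return values used above; for the dynamic-quant path the unquantized cache_k_T/cache_v_T copies written by the reference call are used, so the comparison baseline stays in bfloat16.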
@@ -613,10 +695,10 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
 class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
     def setUp(self):
         paddle.disable_static()
-        self.name = "TestAppendGroupQueryAttnWithRope"
+        self.name = "TestAppendGroupQueryAttnWithNeoXRope"
         self.place = paddle.CUDAPlace(0)
         self.batch_size = 1
-        self.q_num_head = 12
+        self.q_num_head = 16
         self.kv_num_head = 2
         self.seq_len = 64
         self.max_dec_len = 64
@@ -632,6 +714,33 @@ class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
         self.dtype = "float16"
         self.use_qk_norm = False
         self.use_mask_offset = True
+        self.use_dynamic_quant = False
         self.init_tensor()
+
+
+class TestAppendGroupQueryAttnWithRopeDyCfp8(TestAppendGroupQueryAttnWithRope):
+    def setUp(self):
+        paddle.disable_static()
+        self.name = "TestAppendGroupQueryAttnWithRopeDyCfp8"
+        self.place = paddle.CUDAPlace(0)
+        self.batch_size = 1
+        self.q_num_head = 16
+        self.kv_num_head = 2
+        self.seq_len = 64
+        self.max_dec_len = 64
+        self.dim_head = 128
+        self.q_hid_dim = self.q_num_head * self.dim_head
+        self.kv_hid_dim = self.kv_num_head * self.dim_head
+        self.blocksize = 64
+        self.use_neox_rotary_style = False
+        # max_seq_len = self.seq_len + self.max_dec_len
+        self.max_seq_len = self.seq_len + self.max_dec_len
+        self.softmax_scale = self.dim_head**-0.5
+        self.rope_theta = 10000
+        self.dtype = "bfloat16"
+        self.use_qk_norm = True
+        self.use_mask_offset = False
+        self.use_dynamic_quant = True
+        self.init_tensor()
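To run only the new dynamic-Cfp8 case during development, a small driver like the one below can be appended to the test module (the module's file name is not shown in this diff, so this is a hypothetical snippet); it requires a CUDA-enabled Paddle build, since the tests pin themselves to paddle.CUDAPlace(0).

import unittest

if __name__ == "__main__":
    # Hypothetical driver: load and run just TestAppendGroupQueryAttnWithRopeDyCfp8.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestAppendGroupQueryAttnWithRopeDyCfp8)
    unittest.TextTestRunner(verbosity=2).run(suite)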