mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
fix mask_offset in append_attn (#3745)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* fix mask_offset in append_attn * fix test
This commit is contained in:
@@ -250,8 +250,8 @@ def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head
|
||||
def apply_qk_norm(head_dim, dtype, q, k):
|
||||
q_norm_weight = np.random.random([head_dim]) / 10
|
||||
k_norm_weight = np.random.random([head_dim]) / 10
|
||||
q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype=dtype)
|
||||
k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype=dtype)
|
||||
q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
|
||||
k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
|
||||
print("q:", q.shape)
|
||||
print("k:", k.shape)
|
||||
bs, q_num_head, seq_len, dim_head = q.shape
|
||||
@@ -260,9 +260,9 @@ def apply_qk_norm(head_dim, dtype, q, k):
|
||||
q = q.reshape([-1, head_dim])
|
||||
k = k.reshape([-1, head_dim])
|
||||
print("q:", q)
|
||||
q = fused_rms_norm(q, q_norm_weight_tensor, None, 1e-5)[0]
|
||||
q = fused_rms_norm(q.astype("float32"), q_norm_weight_tensor, None, 1e-5)[0].astype(dtype)
|
||||
print("q after norm:", q)
|
||||
k = fused_rms_norm(k, k_norm_weight_tensor, None, 1e-5)[0]
|
||||
k = fused_rms_norm(k.astype("float32"), k_norm_weight_tensor, None, 1e-5)[0].astype(dtype)
|
||||
q = q.reshape([-1, q_num_head, seq_len, dim_head])
|
||||
k = k.reshape([-1, kv_num_head, seq_len, dim_head])
|
||||
return q, k, q_norm_weight_tensor, k_norm_weight_tensor
|
||||
|
Reference in New Issue
Block a user