mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
make append_attn supports mask_offset (#3138)
* make append_attn supports mask_offset * add unittest
This commit is contained in:
@@ -910,7 +910,8 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
|
||||
const uint32_t qo_len,
|
||||
const uint32_t kv_len,
|
||||
const uint32_t chunk_end,
|
||||
float (*s_frag)[num_frags_z][8]) {
|
||||
float (*s_frag)[num_frags_z][8],
|
||||
const int *mask_offset = nullptr) {
|
||||
const uint32_t tx = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
||||
@@ -924,10 +925,15 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
|
||||
group_size,
|
||||
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
|
||||
8 * (reg_id / 4) + reg_id % 2;
|
||||
const bool out_of_boundary =
|
||||
(causal
|
||||
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
|
||||
: kv_idx >= chunk_end);
|
||||
bool out_of_boundary;
|
||||
if (mask_offset) {
|
||||
out_of_boundary = q_idx < qo_len ? (kv_idx > mask_offset[q_idx]) : true;
|
||||
} else {
|
||||
out_of_boundary =
|
||||
(causal
|
||||
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
|
||||
: kv_idx >= chunk_end);
|
||||
}
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
s_frag[fx][fz][reg_id] =
|
||||
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
|
||||
|
||||
Reference in New Issue
Block a user