Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-09-27 04:46:16 +08:00)
Fix int32-to-uint32 casting issues causing -1 to become a large positive number
Co-authored-by: yuanlehome <23653004+yuanlehome@users.noreply.github.com>
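For context, a minimal standalone C++ sketch of the problem this commit fixes (not code from the repository; variable names are illustrative): the sentinel -1 stored in a uint32_t wraps to UINT32_MAX, so a bound such as q_idx < attn_mask_len is always satisfied, whereas a signed int32_t keeps the sentinel and can be guarded with attn_mask_len > 0, as the patched mask_s check does.

#include <cstdint>
#include <cstdio>

int main() {
  // Old behaviour: -1 assigned to an unsigned 32-bit value wraps around.
  const uint32_t len_unsigned = -1;  // becomes 4294967295
  // New behaviour: a signed 32-bit value keeps the sentinel.
  const int32_t len_signed = -1;

  const uint32_t q_idx = 7;  // illustrative query index
  printf("uint32_t sentinel: %u, q_idx < len -> %d\n",
         len_unsigned, q_idx < len_unsigned);   // 4294967295, 1 (check always passes)
  printf("int32_t  sentinel: %d, len > 0     -> %d\n",
         len_signed, len_signed > 0);           // -1, 0 (guard rejects the sentinel)
  return 0;
}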
@@ -435,7 +435,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
     float *__restrict__ tmp_d,  // [token_num, num_chunks, num_heads]
     OutT *__restrict__ out,
     const int speculate_max_draft_token_num = 5,
-    const uint32_t attn_mask_len = -1) {
+    const int32_t attn_mask_len = -1) {
   constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
   static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
   static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
@@ -1089,7 +1089,7 @@ void MultiQueryAppendAttention(
     chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
   }
 
-  uint32_t attn_mask_len;
+  int32_t attn_mask_len;
   if (attn_mask) {
     attn_mask_len = attn_mask.get().shape()[1];
   } else {

@@ -533,7 +533,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
     float *__restrict__ tmp_d,  // [token_num, num_chunks, num_heads]
     OutT *__restrict__ out,
     const int speculate_max_draft_token_num = 5,
-    const uint32_t attn_mask_len = -1) {
+    const int32_t attn_mask_len = -1) {
   constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
   constexpr uint32_t num_vecs_per_head_k =
       HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
@@ -1313,7 +1313,7 @@ void MultiQueryAppendC4Attention(
   }
 
   const int num_chunks = div_up(max_seq_len, chunk_size);
-  uint32_t attn_mask_len;
+  int32_t attn_mask_len;
   if (attn_mask) {
     attn_mask_len = attn_mask.get().shape()[1];
   } else {

@@ -540,7 +540,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
     float *__restrict__ tmp_d,  // [token_num, num_chunks, num_heads]
     OutT *__restrict__ out,
     const int speculate_max_draft_token_num = 5,
-    const uint32_t attn_mask_len = -1) {
+    const int32_t attn_mask_len = -1) {
   constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
   constexpr uint32_t num_vecs_per_head_k =
       HEAD_DIM / num_elems_per_128b<CacheT>();
@@ -1372,7 +1372,7 @@ void MultiQueryAppendC8Attention(
   }
 
   const int num_chunks = div_up(max_seq_len, chunk_size);
-  uint32_t attn_mask_len;
+  int32_t attn_mask_len;
   if (attn_mask) {
     attn_mask_len = attn_mask.get().shape()[1];
   } else {

@@ -1026,7 +1026,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
                                        const uint32_t qo_len,
                                        const uint32_t kv_len,
                                        const uint32_t chunk_end,
-                                       const uint32_t attn_mask_len,
+                                       const int32_t attn_mask_len,
                                        float (*s_frag)[num_frags_z][8],
                                        const int *mask_offset = nullptr) {
   const uint32_t tx = threadIdx.x;
@@ -1050,7 +1050,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
             (causal
                  ? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
                  : kv_idx >= chunk_end);
-        if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
+        if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && attn_mask_len > 0 && q_idx < static_cast<uint32_t>(attn_mask_len)) {
           const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
           bool mask = attn_mask[mask_idx];
           out_of_boundary |= mask;

@@ -23,7 +23,7 @@
 template <typename OutputType>
 void DisPatchWFp8AFp8Gemm(
     const cutlass::float_e4m3_t* input,
-    const uint32_t* sparse_idx,
+    const int32_t* sparse_idx,
     const cutlass::float_e4m3_t* weight,
     const int * tokens,
     const float * weight_scale,
@@ -80,7 +80,7 @@ void WFp8AFp8Gemm(
   if (is_bfloat16) {
     DisPatchWFp8AFp8Gemm(
         reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
-        reinterpret_cast<const uint32_t*>(sparse_idx.data<int32_t>()),
+        sparse_idx.data<int32_t>(),
         reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<phi::dtype::float8_e4m3fn>()),
         tokens.data<int>(),
         weight_scale.data<float>(),