diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh
index cc537e46c..618f1c32a 100644
--- a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh
@@ -435,7 +435,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
     float *__restrict__ tmp_d,  // [token_num, num_chunks, num_heads]
     OutT *__restrict__ out,
     const int speculate_max_draft_token_num = 5,
-    const uint32_t attn_mask_len = -1) {
+    const int32_t attn_mask_len = -1) {
   constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
   static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
   static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
@@ -1089,7 +1089,7 @@ void MultiQueryAppendAttention(
       chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
     }
 
-    uint32_t attn_mask_len;
+    int32_t attn_mask_len;
     if (attn_mask) {
       attn_mask_len = attn_mask.get().shape()[1];
     } else {
diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh
index 49317bfdf..eb7afe975 100644
--- a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh
@@ -533,7 +533,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
     float *__restrict__ tmp_d,  // [token_num, num_chunks, num_heads]
     OutT *__restrict__ out,
     const int speculate_max_draft_token_num = 5,
-    const uint32_t attn_mask_len = -1) {
+    const int32_t attn_mask_len = -1) {
   constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
   constexpr uint32_t num_vecs_per_head_k =
       HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
@@ -1313,7 +1313,7 @@ void MultiQueryAppendC4Attention(
     }
     const int num_chunks = div_up(max_seq_len, chunk_size);
 
-    uint32_t attn_mask_len;
+    int32_t attn_mask_len;
     if (attn_mask) {
       attn_mask_len = attn_mask.get().shape()[1];
     } else {
diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh
index b2fe4c6f6..57e854451 100644
--- a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh
@@ -540,7 +540,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
     float *__restrict__ tmp_d,  // [token_num, num_chunks, num_heads]
    OutT *__restrict__ out,
     const int speculate_max_draft_token_num = 5,
-    const uint32_t attn_mask_len = -1) {
+    const int32_t attn_mask_len = -1) {
   constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
   constexpr uint32_t num_vecs_per_head_k =
       HEAD_DIM / num_elems_per_128b<CacheT>();
@@ -1372,7 +1372,7 @@ void MultiQueryAppendC8Attention(
     }
     const int num_chunks = div_up(max_seq_len, chunk_size);
 
-    uint32_t attn_mask_len;
+    int32_t attn_mask_len;
     if (attn_mask) {
       attn_mask_len = attn_mask.get().shape()[1];
     } else {
diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh
index 24787e8b7..6c932b054 100644
--- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh
+++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh
@@ -1026,7 +1026,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
                                        const uint32_t qo_len,
                                        const uint32_t kv_len,
                                        const uint32_t chunk_end,
-                                       const uint32_t attn_mask_len,
+                                       const int32_t attn_mask_len,
                                        float (*s_frag)[num_frags_z][8],
                                        const int *mask_offset = nullptr) {
   const uint32_t tx = threadIdx.x;
@@ -1050,7 +1050,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
             (causal
                  ? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
                  : kv_idx >= chunk_end);
-        if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
+        if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && attn_mask_len > 0 && q_idx < static_cast<uint32_t>(attn_mask_len)) {
          const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
          bool mask = attn_mask[mask_idx];
          out_of_boundary |= mask;
diff --git a/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8afp8_sparse_gemm.cu b/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8afp8_sparse_gemm.cu
index 03d3c16a1..f9f07c9d9 100644
--- a/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8afp8_sparse_gemm.cu
+++ b/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8afp8_sparse_gemm.cu
@@ -23,7 +23,7 @@
 template
 void DisPatchWFp8AFp8Gemm(
     const cutlass::float_e4m3_t* input,
-    const uint32_t* sparse_idx,
+    const int32_t* sparse_idx,
     const cutlass::float_e4m3_t* weight,
     const int * tokens,
     const float * weight_scale,
@@ -80,7 +80,7 @@ void WFp8AFp8Gemm(
   if (is_bfloat16) {
     DisPatchWFp8AFp8Gemm(
         reinterpret_cast<const cutlass::float_e4m3_t*>(input.data()),
-        reinterpret_cast<const uint32_t*>(sparse_idx.data()),
+        sparse_idx.data(),
         reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data()),
         tokens.data(),
        weight_scale.data(),
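The type changes above guard against an unsigned-wraparound pitfall: with a uint32_t parameter, the default of -1 becomes UINT32_MAX, so a check like q_idx < attn_mask_len passes even when no attention mask was supplied. The standalone host-only sketch below is not part of the patch; the variable names merely mirror the kernel parameters to show why the patched condition tests attn_mask_len > 0 before casting for the comparison.

// Hypothetical illustration (not from the patch): signed sentinel + validity
// check vs. an unsigned length whose default of -1 wraps around.
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t q_idx = 7;  // query index, always unsigned in the kernel

  // Old behaviour: -1 converted to uint32_t is 4294967295, so the mask
  // branch would be taken even though no mask was provided.
  const uint32_t attn_mask_len_unsigned = -1;
  std::printf("unsigned sentinel = %u, branch taken = %d\n",
              attn_mask_len_unsigned, q_idx < attn_mask_len_unsigned);

  // New behaviour: keep the sentinel signed, reject non-positive lengths
  // first, then cast only for the comparison against the unsigned index.
  const int32_t attn_mask_len_signed = -1;
  const bool branch_taken =
      attn_mask_len_signed > 0 &&
      q_idx < static_cast<uint32_t>(attn_mask_len_signed);
  std::printf("signed sentinel = %d, branch taken = %d\n",
              attn_mask_len_signed, branch_taken);
  return 0;
}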