Fix int32-to-uint32 casting issues causing -1 to become a large positive number

Co-authored-by: yuanlehome <23653004+yuanlehome@users.noreply.github.com>
copilot-swe-agent[bot]
2025-09-17 10:55:36 +00:00
parent 3e319c0f90
commit 0f2b609496
5 changed files with 10 additions and 10 deletions
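
For context, the failure mode is plain integer wrap-around: assigning -1 to an unsigned 32-bit parameter yields 4294967295, so a sentinel meant to mean "no attention mask" looks like an enormous valid length and bounds checks against it pass for any realistic index. A minimal host-side sketch of the wrap-around (standalone illustration, not the kernel code):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Assigning -1 to an unsigned 32-bit value wraps to UINT32_MAX.
  const uint32_t bad_len = -1;   // becomes 4294967295
  const int32_t good_len = -1;   // stays -1

  const uint32_t q_idx = 7;
  // With the unsigned sentinel, the bounds check passes for any realistic q_idx,
  // so every row is treated as if a mask row exists for it.
  printf("q_idx < bad_len -> %d\n", q_idx < bad_len);  // 1
  // With the signed sentinel plus an explicit positivity guard,
  // the "no mask" case is rejected as intended.
  printf("guarded check   -> %d\n",
         good_len > 0 && q_idx < static_cast<uint32_t>(good_len));  // 0
  return 0;
}
```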

View File

@@ -435,7 +435,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
-const uint32_t attn_mask_len = -1) {
+const int32_t attn_mask_len = -1) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
@@ -1089,7 +1089,7 @@ void MultiQueryAppendAttention(
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
-uint32_t attn_mask_len;
+int32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {

View File

@@ -533,7 +533,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
-const uint32_t attn_mask_len = -1) {
+const int32_t attn_mask_len = -1) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
@@ -1313,7 +1313,7 @@ void MultiQueryAppendC4Attention(
}
const int num_chunks = div_up(max_seq_len, chunk_size);
-uint32_t attn_mask_len;
+int32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {

View File

@@ -540,7 +540,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
-const uint32_t attn_mask_len = -1) {
+const int32_t attn_mask_len = -1) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / num_elems_per_128b<CacheT>();
@@ -1372,7 +1372,7 @@ void MultiQueryAppendC8Attention(
}
const int num_chunks = div_up(max_seq_len, chunk_size);
-uint32_t attn_mask_len;
+int32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {

View File

@@ -1026,7 +1026,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
const uint32_t qo_len,
const uint32_t kv_len,
const uint32_t chunk_end,
-const uint32_t attn_mask_len,
+const int32_t attn_mask_len,
float (*s_frag)[num_frags_z][8],
const int *mask_offset = nullptr) {
const uint32_t tx = threadIdx.x;
@@ -1050,7 +1050,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
-if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
+if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && attn_mask_len > 0 && q_idx < static_cast<uint32_t>(attn_mask_len)) {
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
bool mask = attn_mask[mask_idx];
out_of_boundary |= mask;
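
Note that switching attn_mask_len to int32_t alone would not be enough here: comparing the uint32_t q_idx against a signed length converts the signed side to unsigned under the usual arithmetic conversions, so -1 would again read as 4294967295. Hence the added attn_mask_len > 0 guard and the explicit cast. A minimal sketch of the guarded check (illustrative helper, not the kernel itself):

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors the guard added in mask_s: reject the -1 "no mask" sentinel
// before comparing, then cast explicitly so the comparison is well-defined.
bool in_mask_row(uint32_t q_idx, int32_t attn_mask_len) {
  // A naive "q_idx < attn_mask_len" would promote -1 to 4294967295
  // and accept every q_idx.
  return attn_mask_len > 0 && q_idx < static_cast<uint32_t>(attn_mask_len);
}

int main() {
  printf("%d\n", in_mask_row(5, 128));  // 1: row 5 is covered by the mask
  printf("%d\n", in_mask_row(5, -1));   // 0: sentinel "no mask"
  return 0;
}
```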

View File

@@ -23,7 +23,7 @@
template <typename OutputType>
void DisPatchWFp8AFp8Gemm(
const cutlass::float_e4m3_t* input,
-const uint32_t* sparse_idx,
+const int32_t* sparse_idx,
const cutlass::float_e4m3_t* weight,
const int * tokens,
const float * weight_scale,
@@ -80,7 +80,7 @@ void WFp8AFp8Gemm(
if (is_bfloat16) {
DisPatchWFp8AFp8Gemm(
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
-reinterpret_cast<const uint32_t*>(sparse_idx.data<int32_t>()),
+sparse_idx.data<int32_t>(),
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<phi::dtype::float8_e4m3fn>()),
tokens.data<int>(),
weight_scale.data<float>(),
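
The dropped reinterpret_cast matters for the same reason: any -1 sentinel stored in sparse_idx was being read through a uint32_t view as 4294967295, while passing the int32_t pointer through unchanged preserves the sentinel. A small standalone sketch of the hazard (illustration only, not the dispatch code):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // A -1 sentinel stored as int32_t...
  const int32_t idx[] = {0, 1, -1};
  // ...read back through a uint32_t* view becomes 4294967295.
  const uint32_t* viewed = reinterpret_cast<const uint32_t*>(idx);
  printf("%u\n", viewed[2]);  // 4294967295
  // Keeping the pointer as const int32_t* preserves the sentinel.
  printf("%d\n", idx[2]);     // -1
  return 0;
}
```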