Mirror of https://github.com/PaddlePaddle/FastDeploy.git
Fix int32 to uint32 casting issues causing -1 to become a large positive number
Co-authored-by: yuanlehome <23653004+yuanlehome@users.noreply.github.com>
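Why the old signatures are a bug: `const uint32_t attn_mask_len = -1` silently converts the -1 sentinel to UINT32_MAX (4294967295), so a bound check such as `q_idx < attn_mask_len` can never fail. The patch makes the sentinel a signed `int32_t` and only casts after checking the sign. A minimal standalone sketch of that conversion (my illustration, not code from this patch):

#include <cstdint>
#include <cstdio>

// Standalone illustration (not part of the patch): -1 assigned to a uint32_t
// wraps to UINT32_MAX, so an "is this row inside the mask?" check always passes.
int main() {
  const uint32_t as_unsigned = -1;  // wraps to 4294967295
  const int32_t as_signed = -1;     // keeps the sentinel value

  uint32_t q_idx = 7;
  std::printf("as_unsigned = %u\n", as_unsigned);
  std::printf("old check: q_idx < as_unsigned -> %d\n", q_idx < as_unsigned);  // always 1
  std::printf("new check: as_signed > 0 && q_idx < cast(as_signed) -> %d\n",
              as_signed > 0 && q_idx < static_cast<uint32_t>(as_signed));      // 0
  return 0;
}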
@@ -435,7 +435,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
     float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
     OutT *__restrict__ out,
     const int speculate_max_draft_token_num = 5,
-    const uint32_t attn_mask_len = -1) {
+    const int32_t attn_mask_len = -1) {
   constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
   static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
   static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
@@ -1089,7 +1089,7 @@ void MultiQueryAppendAttention(
       chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
     }

-    uint32_t attn_mask_len;
+    int32_t attn_mask_len;
     if (attn_mask) {
       attn_mask_len = attn_mask.get().shape()[1];
     } else {
@@ -533,7 +533,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
     float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
     OutT *__restrict__ out,
     const int speculate_max_draft_token_num = 5,
-    const uint32_t attn_mask_len = -1) {
+    const int32_t attn_mask_len = -1) {
   constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
   constexpr uint32_t num_vecs_per_head_k =
       HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
@@ -1313,7 +1313,7 @@ void MultiQueryAppendC4Attention(
     }

     const int num_chunks = div_up(max_seq_len, chunk_size);
-    uint32_t attn_mask_len;
+    int32_t attn_mask_len;
     if (attn_mask) {
       attn_mask_len = attn_mask.get().shape()[1];
     } else {
@@ -540,7 +540,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
     float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
     OutT *__restrict__ out,
     const int speculate_max_draft_token_num = 5,
-    const uint32_t attn_mask_len = -1) {
+    const int32_t attn_mask_len = -1) {
   constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
   constexpr uint32_t num_vecs_per_head_k =
       HEAD_DIM / num_elems_per_128b<CacheT>();
@@ -1372,7 +1372,7 @@ void MultiQueryAppendC8Attention(
     }

     const int num_chunks = div_up(max_seq_len, chunk_size);
-    uint32_t attn_mask_len;
+    int32_t attn_mask_len;
     if (attn_mask) {
       attn_mask_len = attn_mask.get().shape()[1];
     } else {
@@ -1026,7 +1026,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
     const uint32_t qo_len,
     const uint32_t kv_len,
     const uint32_t chunk_end,
-    const uint32_t attn_mask_len,
+    const int32_t attn_mask_len,
     float (*s_frag)[num_frags_z][8],
     const int *mask_offset = nullptr) {
   const uint32_t tx = threadIdx.x;
@@ -1050,7 +1050,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
             (causal
                  ? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
                  : kv_idx >= chunk_end);
-        if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
+        if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && attn_mask_len > 0 && q_idx < static_cast<uint32_t>(attn_mask_len)) {
           const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
           bool mask = attn_mask[mask_idx];
           out_of_boundary |= mask;
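For reference, here is the new guard from mask_s pulled out into a tiny host-side sketch (the values below are invented; only the condition mirrors the patch): with attn_mask_len == -1 the added attn_mask_len > 0 test short-circuits before the unsigned comparison, so the mask lookup is skipped instead of indexing out of bounds.

#include <cassert>
#include <cstdint>

// Hedged sketch of the guard logic from mask_s above, not the kernel itself.
// Parameter names mirror the kernel; the call-site values are illustrative only.
bool mask_lookup_allowed(const bool* attn_mask, uint32_t q_idx, uint32_t kv_idx,
                         uint32_t kv_len, uint32_t qo_len, uint32_t chunk_end,
                         int32_t attn_mask_len) {
  return attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end &&
         attn_mask_len > 0 && q_idx < static_cast<uint32_t>(attn_mask_len);
}

int main() {
  bool mask[4] = {false, true, false, true};
  // attn_mask_len == -1 (no mask provided): the lookup is now skipped.
  assert(!mask_lookup_allowed(mask, /*q_idx=*/0, /*kv_idx=*/3,
                              /*kv_len=*/4, /*qo_len=*/2, /*chunk_end=*/8, -1));
  // A real mask length re-enables the lookup for in-range rows.
  assert(mask_lookup_allowed(mask, 0, 3, 4, 2, 8, /*attn_mask_len=*/2));
  return 0;
}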
@@ -23,7 +23,7 @@
 template <typename OutputType>
 void DisPatchWFp8AFp8Gemm(
     const cutlass::float_e4m3_t* input,
-    const uint32_t* sparse_idx,
+    const int32_t* sparse_idx,
     const cutlass::float_e4m3_t* weight,
     const int * tokens,
     const float * weight_scale,
@@ -80,7 +80,7 @@ void WFp8AFp8Gemm(
   if (is_bfloat16) {
     DisPatchWFp8AFp8Gemm(
         reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
-        reinterpret_cast<const uint32_t*>(sparse_idx.data<int32_t>()),
+        sparse_idx.data<int32_t>(),
         reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<phi::dtype::float8_e4m3fn>()),
         tokens.data<int>(),
         weight_scale.data<float>(),
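The DisPatchWFp8AFp8Gemm change is the same idea applied to a pointer type: viewing an int32 index buffer through a uint32_t* turns any negative entry into a huge unsigned index, so the patch keeps sparse_idx as const int32_t* end to end and drops the reinterpret_cast at the call site. A small sketch of that reinterpretation (whether sparse_idx actually carries -1 entries is my assumption, not stated by the patch):

#include <cstdint>
#include <cstdio>

// Sketch only: shows what viewing an int32 index buffer through a uint32_t*
// does to a -1 entry. The buffer contents here are invented for illustration.
int main() {
  int32_t sparse_idx[3] = {0, -1, 42};
  const uint32_t* as_u32 = reinterpret_cast<const uint32_t*>(sparse_idx);
  std::printf("%d reinterpreted as uint32_t -> %u\n", sparse_idx[1], as_u32[1]);  // -1 -> 4294967295
  return 0;
}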