diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh
index 52c8ad5af..c8273cd3c 100644
--- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh
@@ -42,19 +42,22 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
     const uint32_t elem_cnt,
     const int kv_num_heads,
     const bool rope_3d,
-    const T* q_norm_weight,
-    const T* k_norm_weight,
+    const float* q_norm_weight,
+    const float* k_norm_weight,
     const float rms_norm_eps) {
   using LoadT = AlignedVector<T, VecSize>;
   using LoadBiasT = AlignedVector<T, VecSize>;
   using LoadKVT = AlignedVector<T, VecSize>;
   constexpr int HalfVecSize = VecSize / 2;
   using LoadEmbT = AlignedVector<float, HalfVecSize>;
+  using LoadFloat = AlignedVector<float, VecSize>;
   LoadT src_vec;
   LoadBiasT out_vec;
   LoadKVT cache_vec;
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
+  LoadFloat tmp_vec;
+  LoadFloat q_norm_vec, k_norm_vec;
   int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
   int64_t all_warp_num = gridDim.x * blockDim.y;
@@ -105,10 +108,8 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
       float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
       float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
       thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
-      out_vec[2 * i] =
-          static_cast<T>(tmp1);
-      out_vec[2 * i + 1] =
-          static_cast<T>(tmp2);
+      tmp_vec[2 * i] = tmp1;
+      tmp_vec[2 * i + 1] = tmp2;
     } else {
       out_vec[2 * i] = src_vec[2 * i];
       out_vec[2 * i + 1] = src_vec[2 * i + 1];
@@ -119,17 +120,17 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
     float row_variance = max(warp_m2 / head_size, 0.0f);
     float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
-    LoadT q_norm_vec, k_norm_vec;
+
     if (hi < num_heads) {  // q
-      Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
+      Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
 #pragma unroll
       for (int i = 0; i < VecSize; i++) {
-        out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
+        out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
       }
     } else {  // k
-      Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
+      Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
       for (int i = 0; i < VecSize; i++) {
-        out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
+        out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
       }
     }
   }
diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu
index ffee65ee0..68b22968b 100644
--- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu
+++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu
@@ -39,8 +39,8 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
                                       const cudaStream_t& stream,
                                       const bool use_neox_style,
                                       const bool rope_3d,
-                                      const T* q_norm_weight,
-                                      const T* k_norm_weight,
+                                      const float* q_norm_weight,
+                                      const float* k_norm_weight,
                                       const float rms_norm_eps) {
   const uint32_t elem_nums =
       use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
@@ -569,8 +569,8 @@ void DecoderWriteCacheWithRoPEKernel(
             stream,
             use_neox_rotary_style,
             rope_3d,
-            reinterpret_cast<const T*>(q_norm_weight.get().data()),
-            reinterpret_cast<const T*>(k_norm_weight.get().data()),
+            q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
+            k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
             rms_norm_eps);
       } else {
         PD_THROW(
diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh
index 1b14a577d..44489bae0 100644
--- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh
@@ -431,16 +431,19 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
     const int seq_len,
     const int last_dim,
     const bool rope_3d,
-    const T* q_norm_weight,
-    const T* k_norm_weight,
+    const float* q_norm_weight,
+    const float* k_norm_weight,
     const float rms_norm_eps
 ) {
   using LoadT = AlignedVector<T, VecSize>;
   constexpr int HalfVecSize = VecSize / 2;
   using LoadEmbT = AlignedVector<float, HalfVecSize>;
+  using LoadFloat = AlignedVector<float, VecSize>;
   LoadT src_vec;
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
+  LoadFloat tmp_vec;
+  LoadFloat q_norm_vec, k_norm_vec;
   int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
   int64_t all_warp_num = gridDim.x * blockDim.y;
   const int half_lastdim = last_dim / 2;
@@ -477,25 +480,25 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
       const float sin_tmp = sin_emb_vec[i];
       float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
       float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
-      src_vec[2 * i] = static_cast<T>(tmp1);
-      src_vec[2 * i + 1] = static_cast<T>(tmp2);
+      tmp_vec[2 * i] = tmp1;
+      tmp_vec[2 * i + 1] = tmp2;
       thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
     }
     WelfordWarpAllReduce(thread_m2, &warp_m2);
     float row_variance = max(warp_m2 / last_dim, 0.0f);
     float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
-    LoadT q_norm_vec, k_norm_vec;
+
     if (hi < q_num_head) {
-      Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
+      Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
 #pragma unroll
       for (int i = 0; i < VecSize; i++) {
-        src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
+        src_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
       }
     } else {
-      Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
+      Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
       for (int i = 0; i < VecSize; i++) {
-        src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
+        src_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
       }
     }
     Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
@@ -1695,8 +1698,8 @@ void gqa_rotary_qk_norm_variable(
     const cudaStream_t &stream,
     bool use_neox_style = false,
     bool rope_3d = false,
-    const T *q_norm_weight = nullptr,
-    const T *k_norm_weight = nullptr,
+    const float *q_norm_weight = nullptr,
+    const float *k_norm_weight = nullptr,
     const float rms_norm_eps = 1e-6) {
   int64_t elem_nums =
       qkv_out_scales
diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h
index 1e5d79878..5af84e73f 100644
--- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h
+++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h
@@ -80,8 +80,8 @@ void EncoderWriteCacheWithRopeKernel(
         stream,
         use_neox_style,
         rope_3d,
-        q_norm_weight ? q_norm_weight.get().data<T>() : nullptr,
-        k_norm_weight ? k_norm_weight.get().data<T>() : nullptr,
+        q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
+        k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
         rms_norm_eps);
   } else {
     PD_THROW(
diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
index 029764c63..29d570e23 100644
--- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
@@ -63,7 +63,6 @@ class AppendAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    mask_offset: Optional[paddle.Tensor] = None
     _fuse_kernel_compute_dtype: str = "bf16"
 
     # pd_disaggregation
@@ -142,7 +141,6 @@ class AppendAttentionBackend(AttentionBackend):
         metadata.block_tables = forward_meta.block_tables
         metadata.rotary_embs = forward_meta.rotary_embs
         metadata.attn_mask = forward_meta.attn_mask
-        metadata.mask_offset = forward_meta.attn_mask_offsets
         metadata.pre_caches_length = forward_meta.pre_caches_length
         (
             metadata.encoder_batch_ids,
@@ -303,7 +301,7 @@ class AppendAttentionBackend(AttentionBackend):
                 getattr(layer, "cache_v_zp", None),
                 layer.linear_shift,
                 layer.linear_smooth,
-                metadata.mask_offset,
+                forward_meta.attn_mask_offsets,
                 metadata.kv_signal_data_list[layer.layer_id],
                 getattr(layer, "q_norm_weight", None),
                 getattr(layer, "k_norm_weight", None),
@@ -358,7 +356,7 @@ class AppendAttentionBackend(AttentionBackend):
                 getattr(layer, "cache_v_zp", None),
                 layer.linear_shift,
                 layer.linear_smooth,
-                metadata.mask_offset,
+                forward_meta.attn_mask_offsets,
                 metadata.kv_signal_data_list[layer.layer_id],
                 getattr(layer, "q_norm_weight", None),
                 getattr(layer, "k_norm_weight", None),
diff --git a/fastdeploy/model_executor/layers/attention/attention.py b/fastdeploy/model_executor/layers/attention/attention.py
index 20e0e5efe..7b3581de2 100644
--- a/fastdeploy/model_executor/layers/attention/attention.py
+++ b/fastdeploy/model_executor/layers/attention/attention.py
@@ -163,14 +163,14 @@ class Attention(nn.Layer):
     def init_weight(self):
         self.q_norm_weight = self.create_parameter(
             shape=[self.qk_head_dim],
-            dtype=self._dtype,
+            dtype="float32",
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
 
         self.k_norm_weight = self.create_parameter(
             shape=[self.qk_head_dim],
-            dtype=self._dtype,
+            dtype="float32",
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
@@ -184,8 +184,8 @@ class Attention(nn.Layer):
         if self.use_qk_norm:
             q_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.q_norm_key + ".weight")))
             k_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.k_norm_key + ".weight")))
-            self.q_norm_weight.set_value(q_norm_weight_tensor)
-            self.k_norm_weight.set_value(k_norm_weight_tensor)
+            self.q_norm_weight.set_value(q_norm_weight_tensor.astype("float32"))
+            self.k_norm_weight.set_value(k_norm_weight_tensor.astype("float32"))
 
     def forward(
         self,
diff --git a/tests/layers/test_append_attention.py b/tests/layers/test_append_attention.py
index b9fbbf4d6..af345a324 100644
--- a/tests/layers/test_append_attention.py
+++ b/tests/layers/test_append_attention.py
@@ -250,8 +250,8 @@ def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head
 def apply_qk_norm(head_dim, dtype, q, k):
     q_norm_weight = np.random.random([head_dim]) / 10
     k_norm_weight = np.random.random([head_dim]) / 10
-    q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype=dtype)
-    k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype=dtype)
+    q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
+    k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
     print("q:", q.shape)
     print("k:", k.shape)
     bs, q_num_head, seq_len, dim_head = q.shape
@@ -260,9 +260,9 @@ def apply_qk_norm(head_dim, dtype, q, k):
     q = q.reshape([-1, head_dim])
     k = k.reshape([-1, head_dim])
     print("q:", q)
-    q = fused_rms_norm(q, q_norm_weight_tensor, None, 1e-5)[0]
+    q = fused_rms_norm(q.astype("float32"), q_norm_weight_tensor, None, 1e-5)[0].astype(dtype)
     print("q after norm:", q)
-    k = fused_rms_norm(k, k_norm_weight_tensor, None, 1e-5)[0]
+    k = fused_rms_norm(k.astype("float32"), k_norm_weight_tensor, None, 1e-5)[0].astype(dtype)
     q = q.reshape([-1, q_num_head, seq_len, dim_head])
     k = k.reshape([-1, kv_num_head, seq_len, dim_head])
     return q, k, q_norm_weight_tensor, k_norm_weight_tensor
diff --git a/tests/layers/test_append_attention_with_output.py b/tests/layers/test_append_attention_with_output.py
index 3c6f427cd..b7cd2e25b 100644
--- a/tests/layers/test_append_attention_with_output.py
+++ b/tests/layers/test_append_attention_with_output.py
@@ -250,8 +250,8 @@ def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head
 def apply_qk_norm(head_dim, dtype, q, k):
     q_norm_weight = np.random.random([head_dim]) / 10
     k_norm_weight = np.random.random([head_dim]) / 10
-    q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype=dtype)
-    k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype=dtype)
+    q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
+    k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
     print("q:", q.shape)
     print("k:", k.shape)
     bs, q_num_head, seq_len, dim_head = q.shape
@@ -260,9 +260,9 @@ def apply_qk_norm(head_dim, dtype, q, k):
     q = q.reshape([-1, head_dim])
     k = k.reshape([-1, head_dim])
     print("q:", q)
-    q = fused_rms_norm(q, q_norm_weight_tensor, None, 1e-5)[0]
+    q = fused_rms_norm(q.astype("float32"), q_norm_weight_tensor, None, 1e-5)[0].astype(dtype)
     print("q after norm:", q)
-    k = fused_rms_norm(k, k_norm_weight_tensor, None, 1e-5)[0]
+    k = fused_rms_norm(k.astype("float32"), k_norm_weight_tensor, None, 1e-5)[0].astype(dtype)
     q = q.reshape([-1, q_num_head, seq_len, dim_head])
     k = k.reshape([-1, kv_num_head, seq_len, dim_head])
     return q, k, q_norm_weight_tensor, k_norm_weight_tensor