fix mask_offset in append_attn (#3745)

* fix mask_offset in append_attn

* fix test
lizhenyun01 authored on 2025-08-31 15:03:16 +08:00, committed by GitHub
parent 753772ace8
commit bed09ae8f8
8 changed files with 46 additions and 44 deletions


@@ -42,19 +42,22 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
     const uint32_t elem_cnt,
     const int kv_num_heads,
     const bool rope_3d,
-    const T* q_norm_weight,
-    const T* k_norm_weight,
+    const float* q_norm_weight,
+    const float* k_norm_weight,
     const float rms_norm_eps) {
   using LoadT = AlignedVector<T, VecSize>;
   using LoadBiasT = AlignedVector<T, VecSize>;
   using LoadKVT = AlignedVector<T, VecSize>;
   constexpr int HalfVecSize = VecSize / 2;
   using LoadEmbT = AlignedVector<float, HalfVecSize>;
+  using LoadFloat = AlignedVector<float, VecSize>;
   LoadT src_vec;
   LoadBiasT out_vec;
   LoadKVT cache_vec;
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
+  LoadFloat tmp_vec;
+  LoadFloat q_norm_vec, k_norm_vec;

   int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
   int64_t all_warp_num = gridDim.x * blockDim.y;
@@ -105,10 +108,8 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
         float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
         float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
         thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
-        out_vec[2 * i] =
-            static_cast<T>(tmp1);
-        out_vec[2 * i + 1] =
-            static_cast<T>(tmp2);
+        tmp_vec[2 * i] = tmp1;
+        tmp_vec[2 * i + 1] = tmp2;
       } else {
         out_vec[2 * i] = src_vec[2 * i];
         out_vec[2 * i + 1] = src_vec[2 * i + 1];
@@ -119,17 +120,17 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
     float row_variance =
         max(warp_m2 / head_size, 0.0f);
     float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
-    LoadT q_norm_vec, k_norm_vec;
     if (hi < num_heads) {  // q
-      Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
+      Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
 #pragma unroll
       for (int i = 0; i < VecSize; i++) {
-        out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
+        out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
       }
     } else {  // k
-      Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
+      Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
       for (int i = 0; i < VecSize; i++) {
-        out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
+        out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
       }
     }
   }
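For reference, the arithmetic these kernel hunks implement can be sketched outside CUDA. The following is a hypothetical NumPy reference (not the kernel itself) for the rotated path of one head row, assuming a row of head_size elements, RoPE tables of length head_size // 2, and float32 norm weights; the function name is made up for illustration. It mirrors the float32 accumulation in tmp_vec and the Rsqrt(row_variance + rms_norm_eps) scaling above.

import numpy as np

def rope_then_rmsnorm_reference(x, cos, sin, norm_weight, rms_norm_eps=1e-6, out_dtype=np.float16):
    # Rotate (even, odd) pairs with RoPE in float32, accumulate the sum of
    # squares of the rotated values, then apply RMS norm with float32 weights,
    # casting back to the cache dtype only at the very end.
    x = x.astype(np.float32)
    left, right = x[0::2], x[1::2]
    rotated = np.empty_like(x)
    rotated[0::2] = left * cos - right * sin      # tmp_vec[2 * i]
    rotated[1::2] = right * cos + left * sin      # tmp_vec[2 * i + 1]
    row_variance = max(float(np.sum(rotated * rotated)) / x.size, 0.0)  # warp_m2 / head_size
    row_inv_var = 1.0 / np.sqrt(row_variance + rms_norm_eps)            # Rsqrt(row_variance + eps)
    return (rotated * row_inv_var * norm_weight.astype(np.float32)).astype(out_dtype)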


@@ -39,8 +39,8 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
     const cudaStream_t& stream,
     const bool use_neox_style,
     const bool rope_3d,
-    const T* q_norm_weight,
-    const T* k_norm_weight,
+    const float* q_norm_weight,
+    const float* k_norm_weight,
     const float rms_norm_eps) {
   const uint32_t elem_nums =
       use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
@@ -569,8 +569,8 @@ void DecoderWriteCacheWithRoPEKernel(
         stream,
         use_neox_rotary_style,
         rope_3d,
-        reinterpret_cast<const DataType_*>(q_norm_weight.get().data<T>()),
-        reinterpret_cast<const DataType_*>(k_norm_weight.get().data<T>()),
+        q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
+        k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
         rms_norm_eps);
   } else {
     PD_THROW(


@@ -431,16 +431,19 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
     const int seq_len,
     const int last_dim,
     const bool rope_3d,
-    const T* q_norm_weight,
-    const T* k_norm_weight,
+    const float* q_norm_weight,
+    const float* k_norm_weight,
     const float rms_norm_eps
 ) {
   using LoadT = AlignedVector<T, VecSize>;
   constexpr int HalfVecSize = VecSize / 2;
   using LoadEmbT = AlignedVector<float, HalfVecSize>;
+  using LoadFloat = AlignedVector<float, VecSize>;
   LoadT src_vec;
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
+  LoadFloat tmp_vec;
+  LoadFloat q_norm_vec, k_norm_vec;
   int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
   int64_t all_warp_num = gridDim.x * blockDim.y;
   const int half_lastdim = last_dim / 2;
@@ -477,25 +480,25 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
       const float sin_tmp = sin_emb_vec[i];
       float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
       float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
-      src_vec[2 * i] = static_cast<T>(tmp1);
-      src_vec[2 * i + 1] = static_cast<T>(tmp2);
+      tmp_vec[2 * i] = tmp1;
+      tmp_vec[2 * i + 1] = tmp2;
       thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
     }
     WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
     float row_variance =
         max(warp_m2 / last_dim, 0.0f);
     float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
-    LoadT q_norm_vec, k_norm_vec;
     if (hi < q_num_head) {
-      Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
+      Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
 #pragma unroll
       for (int i = 0; i < VecSize; i++) {
-        src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
+        src_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
       }
     } else {
-      Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
+      Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
       for (int i = 0; i < VecSize; i++) {
-        src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
+        src_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
       }
     }
     Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
@@ -1695,8 +1698,8 @@ void gqa_rotary_qk_norm_variable(
     const cudaStream_t &stream,
     bool use_neox_style = false,
     bool rope_3d = false,
-    const T *q_norm_weight = nullptr,
-    const T *k_norm_weight = nullptr,
+    const float *q_norm_weight = nullptr,
+    const float *k_norm_weight = nullptr,
     const float rms_norm_eps = 1e-6) {
   int64_t elem_nums =
       qkv_out_scales


@@ -80,8 +80,8 @@ void EncoderWriteCacheWithRopeKernel(
         stream,
         use_neox_style,
         rope_3d,
-        q_norm_weight ? q_norm_weight.get().data<T>() : nullptr,
-        k_norm_weight ? k_norm_weight.get().data<T>() : nullptr,
+        q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
+        k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
         rms_norm_eps);
   } else {
     PD_THROW(


@@ -63,7 +63,6 @@ class AppendAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    mask_offset: Optional[paddle.Tensor] = None
     _fuse_kernel_compute_dtype: str = "bf16"

     # pd_disaggregation
@@ -142,7 +141,6 @@ class AppendAttentionBackend(AttentionBackend):
         metadata.block_tables = forward_meta.block_tables
         metadata.rotary_embs = forward_meta.rotary_embs
         metadata.attn_mask = forward_meta.attn_mask
-        metadata.mask_offset = forward_meta.attn_mask_offsets
         metadata.pre_caches_length = forward_meta.pre_caches_length
         (
             metadata.encoder_batch_ids,
@@ -303,7 +301,7 @@ class AppendAttentionBackend(AttentionBackend):
getattr(layer, "cache_v_zp", None), getattr(layer, "cache_v_zp", None),
layer.linear_shift, layer.linear_shift,
layer.linear_smooth, layer.linear_smooth,
metadata.mask_offset, forward_meta.attn_mask_offsets,
metadata.kv_signal_data_list[layer.layer_id], metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None), getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None), getattr(layer, "k_norm_weight", None),
@@ -358,7 +356,7 @@ class AppendAttentionBackend(AttentionBackend):
getattr(layer, "cache_v_zp", None), getattr(layer, "cache_v_zp", None),
layer.linear_shift, layer.linear_shift,
layer.linear_smooth, layer.linear_smooth,
metadata.mask_offset, forward_meta.attn_mask_offsets,
metadata.kv_signal_data_list[layer.layer_id], metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None), getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None), getattr(layer, "k_norm_weight", None),


@@ -163,14 +163,14 @@ class Attention(nn.Layer):
     def init_weight(self):
         self.q_norm_weight = self.create_parameter(
             shape=[self.qk_head_dim],
-            dtype=self._dtype,
+            dtype="float32",
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )

         self.k_norm_weight = self.create_parameter(
             shape=[self.qk_head_dim],
-            dtype=self._dtype,
+            dtype="float32",
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
@@ -184,8 +184,8 @@ class Attention(nn.Layer):
         if self.use_qk_norm:
             q_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.q_norm_key + ".weight")))
             k_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.k_norm_key + ".weight")))
-            self.q_norm_weight.set_value(q_norm_weight_tensor)
-            self.k_norm_weight.set_value(k_norm_weight_tensor)
+            self.q_norm_weight.set_value(q_norm_weight_tensor.astype("float32"))
+            self.k_norm_weight.set_value(k_norm_weight_tensor.astype("float32"))

     def forward(
         self,
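The layer-side hunks above keep the q/k norm weights in float32 regardless of the layer's compute dtype and cast checkpoint tensors on load. A condensed, hypothetical Paddle sketch of that pattern (the class name, qk_head_dim argument, and load_from helper are placeholders for illustration, not the real Attention layer):

import paddle
from paddle import nn

class QKNormWeightsSketch(nn.Layer):
    # The norm weight lives in float32 even when the rest of the layer
    # computes in bf16/fp16; loaded tensors are cast before assignment.
    def __init__(self, qk_head_dim: int):
        super().__init__()
        self.q_norm_weight = self.create_parameter(
            shape=[qk_head_dim],
            dtype="float32",
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

    def load_from(self, q_norm_weight_tensor: paddle.Tensor) -> None:
        # Checkpoints are typically stored in the model dtype; cast to float32
        # before set_value so the parameter keeps its fixed dtype.
        self.q_norm_weight.set_value(q_norm_weight_tensor.astype("float32"))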


@@ -250,8 +250,8 @@ def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head
 def apply_qk_norm(head_dim, dtype, q, k):
     q_norm_weight = np.random.random([head_dim]) / 10
     k_norm_weight = np.random.random([head_dim]) / 10
-    q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype=dtype)
-    k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype=dtype)
+    q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
+    k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
     print("q:", q.shape)
     print("k:", k.shape)
     bs, q_num_head, seq_len, dim_head = q.shape
@@ -260,9 +260,9 @@ def apply_qk_norm(head_dim, dtype, q, k):
     q = q.reshape([-1, head_dim])
     k = k.reshape([-1, head_dim])
     print("q:", q)
-    q = fused_rms_norm(q, q_norm_weight_tensor, None, 1e-5)[0]
+    q = fused_rms_norm(q.astype("float32"), q_norm_weight_tensor, None, 1e-5)[0].astype(dtype)
     print("q after norm:", q)
-    k = fused_rms_norm(k, k_norm_weight_tensor, None, 1e-5)[0]
+    k = fused_rms_norm(k.astype("float32"), k_norm_weight_tensor, None, 1e-5)[0].astype(dtype)
     q = q.reshape([-1, q_num_head, seq_len, dim_head])
     k = k.reshape([-1, kv_num_head, seq_len, dim_head])
     return q, k, q_norm_weight_tensor, k_norm_weight_tensor


@@ -250,8 +250,8 @@ def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head
 def apply_qk_norm(head_dim, dtype, q, k):
     q_norm_weight = np.random.random([head_dim]) / 10
     k_norm_weight = np.random.random([head_dim]) / 10
-    q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype=dtype)
-    k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype=dtype)
+    q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
+    k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
     print("q:", q.shape)
     print("k:", k.shape)
     bs, q_num_head, seq_len, dim_head = q.shape
@@ -260,9 +260,9 @@ def apply_qk_norm(head_dim, dtype, q, k):
     q = q.reshape([-1, head_dim])
     k = k.reshape([-1, head_dim])
     print("q:", q)
-    q = fused_rms_norm(q, q_norm_weight_tensor, None, 1e-5)[0]
+    q = fused_rms_norm(q.astype("float32"), q_norm_weight_tensor, None, 1e-5)[0].astype(dtype)
     print("q after norm:", q)
-    k = fused_rms_norm(k, k_norm_weight_tensor, None, 1e-5)[0]
+    k = fused_rms_norm(k.astype("float32"), k_norm_weight_tensor, None, 1e-5)[0].astype(dtype)
     q = q.reshape([-1, q_num_head, seq_len, dim_head])
     k = k.reshape([-1, kv_num_head, seq_len, dim_head])
     return q, k, q_norm_weight_tensor, k_norm_weight_tensor