Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 16:48:03 +08:00)
fix mask_offset in append_attn (#3745)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* fix mask_offset in append_attn
* fix test
@@ -42,19 +42,22 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
     const uint32_t elem_cnt,
     const int kv_num_heads,
     const bool rope_3d,
-    const T* q_norm_weight,
-    const T* k_norm_weight,
+    const float* q_norm_weight,
+    const float* k_norm_weight,
     const float rms_norm_eps) {
   using LoadT = AlignedVector<T, VecSize>;
   using LoadBiasT = AlignedVector<T, VecSize>;
   using LoadKVT = AlignedVector<T, VecSize>;
   constexpr int HalfVecSize = VecSize / 2;
   using LoadEmbT = AlignedVector<float, HalfVecSize>;
+  using LoadFloat = AlignedVector<float, VecSize>;
   LoadT src_vec;
   LoadBiasT out_vec;
   LoadKVT cache_vec;
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
+  LoadFloat tmp_vec;
+  LoadFloat q_norm_vec, k_norm_vec;

   int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
   int64_t all_warp_num = gridDim.x * blockDim.y;
@@ -105,10 +108,8 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
         float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
         float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
         thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
-        out_vec[2 * i] =
-            static_cast<T>(tmp1);
-        out_vec[2 * i + 1] =
-            static_cast<T>(tmp2);
+        tmp_vec[2 * i] = tmp1;
+        tmp_vec[2 * i + 1] = tmp2;
       } else {
         out_vec[2 * i] = src_vec[2 * i];
         out_vec[2 * i + 1] = src_vec[2 * i + 1];
@@ -119,17 +120,17 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
       float row_variance =
           max(warp_m2 / head_size, 0.0f);
       float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
-      LoadT q_norm_vec, k_norm_vec;
       if (hi < num_heads) {  // q
-        Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
+        Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
 #pragma unroll
         for (int i = 0; i < VecSize; i++) {
-          out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
+          out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
         }
       } else {  // k
-        Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
+        Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
         for (int i = 0; i < VecSize; i++) {
-          out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
+          out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
         }
       }
     }
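Taken together, the three hunks above change the decode-path kernel so that the rotary-embedded q/k values are staged in a float32 tmp_vec, the RMS statistic is reduced in float32, and the float32 q_norm_weight / k_norm_weight are applied before the single cast back to T. A minimal NumPy sketch of that numerical pattern (illustrative only, not the CUDA kernel; the function name and shapes are assumptions):

    import numpy as np

    def qk_rms_norm_fp32(x_half, gamma_fp32, eps=1e-6):
        """RMSNorm over the last (head_dim) axis with float32 accumulation.

        x_half:     rotary-embedded q or k, conceptually stored in bf16/fp16.
        gamma_fp32: per-channel norm weight kept in float32, as in the diff.
        """
        x = x_half.astype(np.float32)                       # like tmp_vec in the kernel
        mean_sq = np.mean(x * x, axis=-1, keepdims=True)    # analogue of the warp reduction
        inv_rms = 1.0 / np.sqrt(mean_sq + eps)              # Rsqrt(row_variance + rms_norm_eps)
        out = x * inv_rms * gamma_fp32                      # multiply by the fp32 gain first
        return out.astype(x_half.dtype)                     # cast back to T only at the end

    # Example: head_dim = 128, half-precision storage
    q = np.random.randn(4, 128).astype(np.float16)
    gamma = np.random.random(128).astype(np.float32) / 10
    q_normed = qk_rms_norm_fp32(q, gamma)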
@@ -39,8 +39,8 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
     const cudaStream_t& stream,
     const bool use_neox_style,
     const bool rope_3d,
-    const T* q_norm_weight,
-    const T* k_norm_weight,
+    const float* q_norm_weight,
+    const float* k_norm_weight,
     const float rms_norm_eps) {
   const uint32_t elem_nums =
       use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
@@ -569,8 +569,8 @@ void DecoderWriteCacheWithRoPEKernel(
         stream,
         use_neox_rotary_style,
         rope_3d,
-        reinterpret_cast<const DataType_*>(q_norm_weight.get().data<T>()),
-        reinterpret_cast<const DataType_*>(k_norm_weight.get().data<T>()),
+        q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
+        k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
         rms_norm_eps);
   } else {
     PD_THROW(
@@ -431,16 +431,19 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
     const int seq_len,
     const int last_dim,
     const bool rope_3d,
-    const T* q_norm_weight,
-    const T* k_norm_weight,
+    const float* q_norm_weight,
+    const float* k_norm_weight,
     const float rms_norm_eps
 ) {
   using LoadT = AlignedVector<T, VecSize>;
   constexpr int HalfVecSize = VecSize / 2;
   using LoadEmbT = AlignedVector<float, HalfVecSize>;
+  using LoadFloat = AlignedVector<float, VecSize>;
   LoadT src_vec;
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
+  LoadFloat tmp_vec;
+  LoadFloat q_norm_vec, k_norm_vec;
   int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
   int64_t all_warp_num = gridDim.x * blockDim.y;
   const int half_lastdim = last_dim / 2;
@@ -477,25 +480,25 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
       const float sin_tmp = sin_emb_vec[i];
       float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
       float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
-      src_vec[2 * i] = static_cast<T>(tmp1);
-      src_vec[2 * i + 1] = static_cast<T>(tmp2);
+      tmp_vec[2 * i] = tmp1;
+      tmp_vec[2 * i + 1] = tmp2;
       thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
     }
     WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
     float row_variance =
         max(warp_m2 / last_dim, 0.0f);
     float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
-    LoadT q_norm_vec, k_norm_vec;
     if (hi < q_num_head) {
-      Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
+      Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
 #pragma unroll
       for (int i = 0; i < VecSize; i++) {
-        src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
+        src_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
       }
     } else {
-      Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
+      Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
       for (int i = 0; i < VecSize; i++) {
-        src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
+        src_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
       }
     }
     Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
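The GQA variable-length rotary kernel follows the same pattern as the decode-cache kernel above: rotary results are staged in a float32 tmp_vec, the sum of squares is reduced in float32, and the float32 norm weights are applied before src_vec is cast back to T and stored.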
@@ -1695,8 +1698,8 @@ void gqa_rotary_qk_norm_variable(
     const cudaStream_t &stream,
     bool use_neox_style = false,
     bool rope_3d = false,
-    const T *q_norm_weight = nullptr,
-    const T *k_norm_weight = nullptr,
+    const float *q_norm_weight = nullptr,
+    const float *k_norm_weight = nullptr,
     const float rms_norm_eps = 1e-6) {
   int64_t elem_nums =
       qkv_out_scales
@@ -80,8 +80,8 @@ void EncoderWriteCacheWithRopeKernel(
         stream,
         use_neox_style,
         rope_3d,
-        q_norm_weight ? q_norm_weight.get().data<T>() : nullptr,
-        k_norm_weight ? k_norm_weight.get().data<T>() : nullptr,
+        q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
+        k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
         rms_norm_eps);
   } else {
     PD_THROW(
@@ -63,7 +63,6 @@ class AppendAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    mask_offset: Optional[paddle.Tensor] = None
     _fuse_kernel_compute_dtype: str = "bf16"

     # pd_disaggregation
@@ -142,7 +141,6 @@ class AppendAttentionBackend(AttentionBackend):
         metadata.block_tables = forward_meta.block_tables
         metadata.rotary_embs = forward_meta.rotary_embs
         metadata.attn_mask = forward_meta.attn_mask
-        metadata.mask_offset = forward_meta.attn_mask_offsets
        metadata.pre_caches_length = forward_meta.pre_caches_length
         (
             metadata.encoder_batch_ids,
@@ -303,7 +301,7 @@ class AppendAttentionBackend(AttentionBackend):
                 getattr(layer, "cache_v_zp", None),
                 layer.linear_shift,
                 layer.linear_smooth,
-                metadata.mask_offset,
+                forward_meta.attn_mask_offsets,
                 metadata.kv_signal_data_list[layer.layer_id],
                 getattr(layer, "q_norm_weight", None),
                 getattr(layer, "k_norm_weight", None),
@@ -358,7 +356,7 @@ class AppendAttentionBackend(AttentionBackend):
                 getattr(layer, "cache_v_zp", None),
                 layer.linear_shift,
                 layer.linear_smooth,
-                metadata.mask_offset,
+                forward_meta.attn_mask_offsets,
                 metadata.kv_signal_data_list[layer.layer_id],
                 getattr(layer, "q_norm_weight", None),
                 getattr(layer, "k_norm_weight", None),
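In the attention backend, the cached metadata.mask_offset field is dropped and both call sites now read forward_meta.attn_mask_offsets directly. A toy sketch of the difference between the two access patterns, using hypothetical stand-in classes rather than FastDeploy's own (the motivation suggested here is an assumption):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ForwardMeta:            # hypothetical stand-in, not FastDeploy's ForwardMeta
        attn_mask_offsets: Optional[list] = None

    @dataclass
    class CachedMetadata:         # hypothetical stand-in for the per-backend metadata
        mask_offset: Optional[list] = None

    meta = ForwardMeta(attn_mask_offsets=None)

    # Cached pattern: snapshot taken when metadata is built; later updates are not seen.
    cached = CachedMetadata(mask_offset=meta.attn_mask_offsets)
    meta.attn_mask_offsets = [0, 4, 8]      # value changes after the snapshot
    print(cached.mask_offset)               # None       (stale copy)
    print(meta.attn_mask_offsets)           # [0, 4, 8]  (live value read at call time)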
@@ -163,14 +163,14 @@ class Attention(nn.Layer):
     def init_weight(self):
         self.q_norm_weight = self.create_parameter(
             shape=[self.qk_head_dim],
-            dtype=self._dtype,
+            dtype="float32",
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )

         self.k_norm_weight = self.create_parameter(
             shape=[self.qk_head_dim],
-            dtype=self._dtype,
+            dtype="float32",
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
@@ -184,8 +184,8 @@ class Attention(nn.Layer):
         if self.use_qk_norm:
             q_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.q_norm_key + ".weight")))
             k_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.k_norm_key + ".weight")))
-            self.q_norm_weight.set_value(q_norm_weight_tensor)
-            self.k_norm_weight.set_value(k_norm_weight_tensor)
+            self.q_norm_weight.set_value(q_norm_weight_tensor.astype("float32"))
+            self.k_norm_weight.set_value(k_norm_weight_tensor.astype("float32"))

     def forward(
         self,
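On the Python side, the q/k norm weights are now created in float32 regardless of the layer's activation dtype, and checkpoint tensors are upcast when loaded. A minimal Paddle sketch of that pattern as a standalone layer (illustrative only; not the FastDeploy Attention class):

    import paddle

    class QKNormWeight(paddle.nn.Layer):  # hypothetical helper, for illustration only
        def __init__(self, head_dim: int):
            super().__init__()
            # Keep the RMSNorm gain in float32 even when activations are bf16/fp16.
            self.weight = self.create_parameter(
                shape=[head_dim],
                dtype="float32",
                is_bias=False,
                default_initializer=paddle.nn.initializer.Constant(0),
            )

        def load_from_checkpoint(self, tensor: paddle.Tensor) -> None:
            # Checkpoints may store the gain in bf16/fp16; upcast before assignment.
            self.weight.set_value(tensor.astype("float32"))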
@@ -250,8 +250,8 @@ def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head
 def apply_qk_norm(head_dim, dtype, q, k):
     q_norm_weight = np.random.random([head_dim]) / 10
     k_norm_weight = np.random.random([head_dim]) / 10
-    q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype=dtype)
-    k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype=dtype)
+    q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
+    k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
     print("q:", q.shape)
     print("k:", k.shape)
     bs, q_num_head, seq_len, dim_head = q.shape
@@ -260,9 +260,9 @@ def apply_qk_norm(head_dim, dtype, q, k):
     q = q.reshape([-1, head_dim])
     k = k.reshape([-1, head_dim])
     print("q:", q)
-    q = fused_rms_norm(q, q_norm_weight_tensor, None, 1e-5)[0]
+    q = fused_rms_norm(q.astype("float32"), q_norm_weight_tensor, None, 1e-5)[0].astype(dtype)
     print("q after norm:", q)
-    k = fused_rms_norm(k, k_norm_weight_tensor, None, 1e-5)[0]
+    k = fused_rms_norm(k.astype("float32"), k_norm_weight_tensor, None, 1e-5)[0].astype(dtype)
     q = q.reshape([-1, q_num_head, seq_len, dim_head])
     k = k.reshape([-1, kv_num_head, seq_len, dim_head])
     return q, k, q_norm_weight_tensor, k_norm_weight_tensor
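With this change the test reference also upcasts to float32 before fused_rms_norm and casts back to dtype afterwards, so the baseline applies the same float32 accumulation as the updated kernels when comparing bf16/fp16 outputs.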