mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Support GPT-OSS-BF16 (#4240)
* [Feature] AppendAtten support sinks & HEAD_DIM=64 * fix bug * fix bug * fix bug * fix bug * [Feature] support gpt-oss * fix bug * add mask * support-gpt-oss * support-gpt-oss * fix long seq * support wint8 * support wint8 * support wint8 * update test * change sliding windows init pos --------- Co-authored-by: ming1753 <ideaminghp@163.com> Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Co-authored-by: ming1753 <61511741+ming1753@users.noreply.github.com>
This commit is contained in:
@@ -77,6 +77,14 @@ struct prefill_softmax_state_t {
|
||||
|
||||
// Finalize the accumulated attention output: divide every element of the
// output vector `o` by the running softmax denominator `d` (standard
// online-softmax normalization, no attention-sink term).
// NOTE(review): `d` is cast to T (possibly bf16/half) before the divide,
// matching the surrounding code's precision convention — confirm intended.
__device__ __forceinline__ void normalize() {
  const T denom = static_cast<T>(d);
#pragma unroll
  for (size_t idx = 0; idx < vec_size; ++idx) {
    o[idx] /= denom;
  }
}
|
||||
|
||||
__device__ __forceinline__ void normalize(float current_sink) {
|
||||
const T d_t = static_cast<T>(d + __expf(current_sink - m));
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
o[i] /= d_t;
|
||||
@@ -1028,7 +1036,8 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
|
||||
const uint32_t chunk_end,
|
||||
const uint32_t attn_mask_len,
|
||||
float (*s_frag)[num_frags_z][8],
|
||||
const int *mask_offset = nullptr) {
|
||||
const int *mask_offset = nullptr,
|
||||
const int sliding_window = 0) {
|
||||
const uint32_t tx = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
||||
@@ -1045,11 +1054,21 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
|
||||
bool out_of_boundary;
|
||||
if (mask_offset) {
|
||||
out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true;
|
||||
} else {
|
||||
}
|
||||
else if (sliding_window > 0)
|
||||
{
|
||||
bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - (int)qo_len - sliding_window;
|
||||
out_of_boundary =
|
||||
(causal
|
||||
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
|
||||
: kv_idx >= chunk_end);
|
||||
(causal
|
||||
? (kv_idx > kv_len + q_idx - qo_len || out_of_window || (kv_idx >= chunk_end))
|
||||
: kv_idx >= chunk_end);
|
||||
}
|
||||
else
|
||||
{
|
||||
out_of_boundary =
|
||||
(causal
|
||||
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
|
||||
: kv_idx >= chunk_end);
|
||||
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
|
||||
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
|
||||
bool mask = attn_mask[mask_idx];
|
||||
@@ -1064,7 +1083,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
|
||||
s_frag[fx][fz][reg_id] =
|
||||
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
|
||||
}
|
||||
// printf("tid: %d. qk[%u,%u] = %f, mask: %d \n ", threadIdx.x, kv_idx, q_idx, static_cast<float>(s_frag[fx][fz][reg_id]), int(out_of_boundary));
|
||||
|
||||
} else {
|
||||
const uint32_t q_idx = qo_idx_base,
|
||||
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
|
||||
@@ -1458,6 +1477,33 @@ __device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8],
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize output fragments by the softmax denominator, including an
// attention-sink contribution: the effective denominator for row (fx, j)
// is d[fx][j] + exp(current_sinks[fx][j] - m[fx][j]), where m is the
// running row maximum from the online-softmax recurrence.
// The row for a given register is selected by (reg_id % 4) / 2 —
// NOTE(review): this indexing is assumed to match the MMA fragment layout
// used by the non-sink normalize_d overload; confirm against that code.
template <uint32_t num_frags_x, uint32_t num_frags_y>
__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8],
                                            float (*d)[2],
                                            float (*m)[2],
                                            float (*current_sinks)[2]) {
  // Precompute one reciprocal per (fragment-row, half) so the hot inner
  // loop multiplies instead of dividing.
  float inv_denom[num_frags_x][2];
#pragma unroll
  for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
    for (uint32_t j = 0; j < 2; ++j) {
      inv_denom[fx][j] =
          1.f / (d[fx][j] + __expf(current_sinks[fx][j] - m[fx][j]));
    }
  }

#pragma unroll
  for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
    for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
#pragma unroll
      for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
        o_frag[fx][fy][reg_id] *= inv_denom[fx][(reg_id % 4) / 2];
      }
    }
  }
}
|
||||
|
||||
template <uint32_t num_frags_x,
|
||||
uint32_t num_frags_y,
|
||||
uint32_t NUM_WARPS,
|
||||
@@ -2271,6 +2317,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ sinks, // [q_num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
@@ -2354,7 +2401,12 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
||||
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
|
||||
st.merge(load_vec, m_tmp, d_tmp);
|
||||
}
|
||||
st.normalize();
|
||||
if (sinks) {
|
||||
float current_sink = static_cast<float>(sinks[hid]);
|
||||
st.normalize(current_sink);
|
||||
} else {
|
||||
st.normalize();
|
||||
}
|
||||
|
||||
const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size;
|
||||
AlignedVector<T, vec_size> shift_bias_vec;
|
||||
@@ -2394,6 +2446,7 @@ __global__ void merge_multi_chunks_v2_kernel(
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ sinks, // [q_num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
@@ -2511,7 +2564,13 @@ __global__ void merge_multi_chunks_v2_kernel(
|
||||
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
|
||||
st.merge(load_vec, m_tmp, d_tmp);
|
||||
}
|
||||
st.normalize();
|
||||
|
||||
if (sinks) {
|
||||
float current_sink = static_cast<float>(sinks[hid]);
|
||||
st.normalize(current_sink);
|
||||
} else {
|
||||
st.normalize();
|
||||
}
|
||||
|
||||
const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size;
|
||||
AlignedVector<T, vec_size> shift_bias_vec;
|
||||
|
||||
Reference in New Issue
Block a user