Support GPT-OSS-BF16 (#4240)

* [Feature] AppendAtten support sinks & HEAD_DIM=64

* fix bug

* fix bug

* fix bug

* fix bug

* [Feature] support gpt-oss

* fix bug

* add mask

* support-gpt-oss

* support-gpt-oss

* fix long seq

* support wint8

* support wint8

* support wint8

* update test

* change sliding windows init pos

---------

Co-authored-by: ming1753 <ideaminghp@163.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Co-authored-by: ming1753 <61511741+ming1753@users.noreply.github.com>
This commit is contained in:
Haonan Luo
2025-10-20 14:44:58 +08:00
committed by GitHub
parent 80a16c4c87
commit 1b9f351d21
32 changed files with 1502 additions and 172 deletions

View File

@@ -77,6 +77,14 @@ struct prefill_softmax_state_t {
// Normalize the accumulated output vector (no attention-sink variant).
// Divides every lane of `o` by the running softmax denominator `d`;
// NOTE(review): `d` is presumably the accumulated sum of exp-weights for
// this state — confirm against the full prefill_softmax_state_t definition.
__device__ __forceinline__ void normalize() {
  const T denom = static_cast<T>(d);
#pragma unroll
  for (size_t idx = 0; idx < vec_size; ++idx) {
    o[idx] /= denom;
  }
}
__device__ __forceinline__ void normalize(float current_sink) {
const T d_t = static_cast<T>(d + __expf(current_sink - m));
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] /= d_t;
@@ -1028,7 +1036,8 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
const uint32_t chunk_end,
const uint32_t attn_mask_len,
float (*s_frag)[num_frags_z][8],
const int *mask_offset = nullptr) {
const int *mask_offset = nullptr,
const int sliding_window = 0) {
const uint32_t tx = threadIdx.x;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
@@ -1045,11 +1054,21 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
bool out_of_boundary;
if (mask_offset) {
out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true;
} else {
}
else if (sliding_window > 0)
{
bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - (int)qo_len - sliding_window;
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
(causal
? (kv_idx > kv_len + q_idx - qo_len || out_of_window || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
}
else
{
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
bool mask = attn_mask[mask_idx];
@@ -1064,7 +1083,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
s_frag[fx][fz][reg_id] =
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
}
// printf("tid: %d. qk[%u,%u] = %f, mask: %d \n ", threadIdx.x, kv_idx, q_idx, static_cast<float>(s_frag[fx][fz][reg_id]), int(out_of_boundary));
} else {
const uint32_t q_idx = qo_idx_base,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
@@ -1458,6 +1477,33 @@ __device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8],
}
}
// Sink-aware normalization of per-fragment output accumulators.
//
// For each (fragment-row fx, half-row j) the softmax denominator is the
// running sum d[fx][j] plus the attention-sink contribution
// __expf(current_sinks[fx][j] - m[fx][j]), i.e. the sink logit rescaled by
// the running max `m` in the usual log-sum-exp fashion.  Every register of
// o_frag belonging to that row is multiplied by the reciprocal of this
// effective denominator.  (reg_id % 4) / 2 selects which of the two rows a
// given MMA register maps to.
// Uses the fast __expf intrinsic, matching the rest of the softmax path.
template <uint32_t num_frags_x, uint32_t num_frags_y>
__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8],
                                            float (*d)[2],
                                            float (*m)[2],
                                            float (*current_sinks)[2]) {
#pragma unroll
  for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
    // Per-row reciprocal of the sink-augmented denominator.
    float row_rcp[2];
#pragma unroll
    for (uint32_t j = 0; j < 2; ++j) {
      row_rcp[j] =
          1.f / (d[fx][j] + __expf(current_sinks[fx][j] - m[fx][j]));
    }
#pragma unroll
    for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
#pragma unroll
      for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
        o_frag[fx][fy][reg_id] *= row_rcp[(reg_id % 4) / 2];
      }
    }
  }
}
template <uint32_t num_frags_x,
uint32_t num_frags_y,
uint32_t NUM_WARPS,
@@ -2271,6 +2317,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
const int *__restrict__ cu_seqlens_q,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
OutT *__restrict__ out,
const float quant_max_bound,
const float quant_min_bound,
@@ -2354,7 +2401,12 @@ __global__ void merge_multi_chunks_decoder_kernel(
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
st.merge(load_vec, m_tmp, d_tmp);
}
st.normalize();
if (sinks) {
float current_sink = static_cast<float>(sinks[hid]);
st.normalize(current_sink);
} else {
st.normalize();
}
const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size;
AlignedVector<T, vec_size> shift_bias_vec;
@@ -2394,6 +2446,7 @@ __global__ void merge_multi_chunks_v2_kernel(
const int *__restrict__ cu_seqlens_q,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
OutT *__restrict__ out,
const float quant_max_bound,
const float quant_min_bound,
@@ -2511,7 +2564,13 @@ __global__ void merge_multi_chunks_v2_kernel(
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
st.merge(load_vec, m_tmp, d_tmp);
}
st.normalize();
if (sinks) {
float current_sink = static_cast<float>(sinks[hid]);
st.normalize(current_sink);
} else {
st.normalize();
}
const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size;
AlignedVector<T, vec_size> shift_bias_vec;